From f3a15c401e592eb029a24f121f7253a888262e95 Mon Sep 17 00:00:00 2001 From: andy-stark-redis <164213578+andy-stark-redis@users.noreply.github.com> Date: Fri, 17 Jan 2025 08:29:51 +0000 Subject: [PATCH 1/5] DOC-4560 pipelines/transactions example (#3202) * DOC-4560 basic transaction example * DOC-4560 added pipe/transaction examples --- doctests/pipe_trans_example_test.go | 180 ++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 doctests/pipe_trans_example_test.go diff --git a/doctests/pipe_trans_example_test.go b/doctests/pipe_trans_example_test.go new file mode 100644 index 000000000..ea1dd5b48 --- /dev/null +++ b/doctests/pipe_trans_example_test.go @@ -0,0 +1,180 @@ +// EXAMPLE: pipe_trans_tutorial +// HIDE_START +package example_commands_test + +import ( + "context" + "fmt" + + "github.com/redis/go-redis/v9" +) + +// HIDE_END + +func ExampleClient_transactions() { + ctx := context.Background() + + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "", // no password docs + DB: 0, // use default DB + }) + // REMOVE_START + for i := 0; i < 5; i++ { + rdb.Del(ctx, fmt.Sprintf("seat:%d", i)) + } + + rdb.Del(ctx, "counter:1", "counter:2", "counter:3", "shellpath") + // REMOVE_END + + // STEP_START basic_pipe + pipe := rdb.Pipeline() + + for i := 0; i < 5; i++ { + pipe.Set(ctx, fmt.Sprintf("seat:%v", i), fmt.Sprintf("#%v", i), 0) + } + + cmds, err := pipe.Exec(ctx) + + if err != nil { + panic(err) + } + + for _, c := range cmds { + fmt.Printf("%v;", c.(*redis.StatusCmd).Val()) + } + + fmt.Println("") + // >>> OK;OK;OK;OK;OK; + + pipe = rdb.Pipeline() + + get0Result := pipe.Get(ctx, "seat:0") + get3Result := pipe.Get(ctx, "seat:3") + get4Result := pipe.Get(ctx, "seat:4") + + cmds, err = pipe.Exec(ctx) + + // The results are available only after the pipeline + // has finished executing. + fmt.Println(get0Result.Val()) // >>> #0 + fmt.Println(get3Result.Val()) // >>> #3 + fmt.Println(get4Result.Val()) // >>> #4 + // STEP_END + + // STEP_START basic_pipe_pipelined + var pd0Result *redis.StatusCmd + var pd3Result *redis.StatusCmd + var pd4Result *redis.StatusCmd + + cmds, err = rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error { + pd0Result = (*redis.StatusCmd)(pipe.Get(ctx, "seat:0")) + pd3Result = (*redis.StatusCmd)(pipe.Get(ctx, "seat:3")) + pd4Result = (*redis.StatusCmd)(pipe.Get(ctx, "seat:4")) + return nil + }) + + if err != nil { + panic(err) + } + + // The results are available only after the pipeline + // has finished executing. + fmt.Println(pd0Result.Val()) // >>> #0 + fmt.Println(pd3Result.Val()) // >>> #3 + fmt.Println(pd4Result.Val()) // >>> #4 + // STEP_END + + // STEP_START basic_trans + trans := rdb.TxPipeline() + + trans.IncrBy(ctx, "counter:1", 1) + trans.IncrBy(ctx, "counter:2", 2) + trans.IncrBy(ctx, "counter:3", 3) + + cmds, err = trans.Exec(ctx) + + for _, c := range cmds { + fmt.Println(c.(*redis.IntCmd).Val()) + } + // >>> 1 + // >>> 2 + // >>> 3 + // STEP_END + + // STEP_START basic_trans_txpipelined + var tx1Result *redis.IntCmd + var tx2Result *redis.IntCmd + var tx3Result *redis.IntCmd + + cmds, err = rdb.TxPipelined(ctx, func(trans redis.Pipeliner) error { + tx1Result = trans.IncrBy(ctx, "counter:1", 1) + tx2Result = trans.IncrBy(ctx, "counter:2", 2) + tx3Result = trans.IncrBy(ctx, "counter:3", 3) + return nil + }) + + if err != nil { + panic(err) + } + + fmt.Println(tx1Result.Val()) // >>> 2 + fmt.Println(tx2Result.Val()) // >>> 4 + fmt.Println(tx3Result.Val()) // >>> 6 + // STEP_END + + // STEP_START trans_watch + // Set initial value of `shellpath`. + rdb.Set(ctx, "shellpath", "/usr/syscmds/", 0) + + const maxRetries = 1000 + + // Retry if the key has been changed. + for i := 0; i < maxRetries; i++ { + err := rdb.Watch(ctx, + func(tx *redis.Tx) error { + currentPath, err := rdb.Get(ctx, "shellpath").Result() + newPath := currentPath + ":/usr/mycmds/" + + _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Set(ctx, "shellpath", newPath, 0) + return nil + }) + + return err + }, + "shellpath", + ) + + if err == nil { + // Success. + break + } else if err == redis.TxFailedErr { + // Optimistic lock lost. Retry the transaction. + continue + } else { + // Panic for any other error. + panic(err) + } + } + + fmt.Println(rdb.Get(ctx, "shellpath").Val()) + // >>> /usr/syscmds/:/usr/mycmds/ + // STEP_END + + // Output: + // OK;OK;OK;OK;OK; + // #0 + // #3 + // #4 + // #0 + // #3 + // #4 + // 1 + // 2 + // 3 + // 2 + // 4 + // 6 + // /usr/syscmds/:/usr/mycmds/ +} From 0e3ea5fd6bbd5b4e485bfb5256de82966ccb3c1e Mon Sep 17 00:00:00 2001 From: andy-stark-redis <164213578+andy-stark-redis@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:02:55 +0000 Subject: [PATCH 2/5] DOC-4449 hash command examples (#3229) * DOC-4450 added hgetall and hvals doc examples * DOC-4449 added hgetall and hvals doc examples * DOC-4449 rewrote to avoid Collect and Keys functions (not available in test version of Go) * DOC-4449 replaced slices.Sort function with older alternative * DOC-4449 removed another instance of slices.Sort * DOC-4449 fixed bugs in tests * DOC-4449 try sort.Strings() for sorting key lists --------- Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> --- doctests/cmds_hash_test.go | 114 ++++++++++++++++++++++++++++++++++++- 1 file changed, 111 insertions(+), 3 deletions(-) diff --git a/doctests/cmds_hash_test.go b/doctests/cmds_hash_test.go index f9630a9de..52ade74e9 100644 --- a/doctests/cmds_hash_test.go +++ b/doctests/cmds_hash_test.go @@ -5,6 +5,7 @@ package example_commands_test import ( "context" "fmt" + "sort" "github.com/redis/go-redis/v9" ) @@ -74,8 +75,20 @@ func ExampleClient_hset() { panic(err) } - fmt.Println(res6) - // >>> map[field1:Hello field2:Hi field3:World] + keys := make([]string, 0, len(res6)) + + for key, _ := range res6 { + keys = append(keys, key) + } + + sort.Strings(keys) + + for _, key := range keys { + fmt.Printf("Key: %v, value: %v\n", key, res6[key]) + } + // >>> Key: field1, value: Hello + // >>> Key: field2, value: Hi + // >>> Key: field3, value: World // STEP_END // Output: @@ -84,7 +97,9 @@ func ExampleClient_hset() { // 2 // Hi // World - // map[field1:Hello field2:Hi field3:World] + // Key: field1, value: Hello + // Key: field2, value: Hi + // Key: field3, value: World } func ExampleClient_hget() { @@ -131,3 +146,96 @@ func ExampleClient_hget() { // foo // redis: nil } + +func ExampleClient_hgetall() { + ctx := context.Background() + + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "", // no password + DB: 0, // use default DB + }) + + // REMOVE_START + rdb.Del(ctx, "myhash") + // REMOVE_END + + // STEP_START hgetall + hGetAllResult1, err := rdb.HSet(ctx, "myhash", + "field1", "Hello", + "field2", "World", + ).Result() + + if err != nil { + panic(err) + } + + fmt.Println(hGetAllResult1) // >>> 2 + + hGetAllResult2, err := rdb.HGetAll(ctx, "myhash").Result() + + if err != nil { + panic(err) + } + + keys := make([]string, 0, len(hGetAllResult2)) + + for key, _ := range hGetAllResult2 { + keys = append(keys, key) + } + + sort.Strings(keys) + + for _, key := range keys { + fmt.Printf("Key: %v, value: %v\n", key, hGetAllResult2[key]) + } + // >>> Key: field1, value: Hello + // >>> Key: field2, value: World + // STEP_END + + // Output: + // 2 + // Key: field1, value: Hello + // Key: field2, value: World +} + +func ExampleClient_hvals() { + ctx := context.Background() + + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "", // no password docs + DB: 0, // use default DB + }) + + // REMOVE_START + rdb.Del(ctx, "myhash") + // REMOVE_END + + // STEP_START hvals + hValsResult1, err := rdb.HSet(ctx, "myhash", + "field1", "Hello", + "field2", "World", + ).Result() + + if err != nil { + panic(err) + } + + fmt.Println(hValsResult1) // >>> 2 + + hValsResult2, err := rdb.HVals(ctx, "myhash").Result() + + if err != nil { + panic(err) + } + + sort.Strings(hValsResult2) + + fmt.Println(hValsResult2) // >>> [Hello World] + // STEP_END + + // Output: + // 2 + // [Hello World] +} From efe0f65bf0dde8b04d6a94c84f21a7b3cba56303 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 20 Jan 2025 11:32:10 +0200 Subject: [PATCH 3/5] Order slices of strings to be sure what the output of Println in doctests will be. (#3241) * Sort the slices of strings in doctest to make the output deterministic * fix wording --- doctests/sets_example_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/doctests/sets_example_test.go b/doctests/sets_example_test.go index 7446a2789..2d6504e2b 100644 --- a/doctests/sets_example_test.go +++ b/doctests/sets_example_test.go @@ -5,6 +5,7 @@ package example_commands_test import ( "context" "fmt" + "sort" "github.com/redis/go-redis/v9" ) @@ -215,6 +216,9 @@ func ExampleClient_saddsmembers() { panic(err) } + // Sort the strings in the slice to make sure the output is lexicographical + sort.Strings(res10) + fmt.Println(res10) // >>> [bike:1 bike:2 bike:3] // STEP_END @@ -294,6 +298,10 @@ func ExampleClient_sdiff() { panic(err) } + + // Sort the strings in the slice to make sure the output is lexicographical + sort.Strings(res13) + fmt.Println(res13) // >>> [bike:2 bike:3] // STEP_END @@ -349,6 +357,9 @@ func ExampleClient_multisets() { panic(err) } + // Sort the strings in the slice to make sure the output is lexicographical + sort.Strings(res15) + fmt.Println(res15) // >>> [bike:1 bike:2 bike:3 bike:4] res16, err := rdb.SDiff(ctx, "bikes:racing:france", "bikes:racing:usa", "bikes:racing:italy").Result() @@ -373,6 +384,9 @@ func ExampleClient_multisets() { panic(err) } + // Sort the strings in the slice to make sure the output is lexicographical + sort.Strings(res18) + fmt.Println(res18) // >>> [bike:2 bike:3] // STEP_END From 36e96654db445007593fb1dad8d61e082019b030 Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Wed, 29 Jan 2025 00:40:58 +0200 Subject: [PATCH 4/5] Create Credential Provider for EntraID --- entra_id/credentials_provider.go | 57 ++++++++++++ entra_id/credentials_provider_test.go | 81 ++++++++++++++++ entra_id/entra_id_suite_test.go | 13 +++ entra_id/go.mod | 13 +++ entra_id/go.sum | 4 + entra_id/token_manager.go | 129 ++++++++++++++++++++++++++ entra_id/token_manager_test.go | 125 +++++++++++++++++++++++++ 7 files changed, 422 insertions(+) create mode 100644 entra_id/credentials_provider.go create mode 100644 entra_id/credentials_provider_test.go create mode 100644 entra_id/entra_id_suite_test.go create mode 100644 entra_id/go.mod create mode 100644 entra_id/go.sum create mode 100644 entra_id/token_manager.go create mode 100644 entra_id/token_manager_test.go diff --git a/entra_id/credentials_provider.go b/entra_id/credentials_provider.go new file mode 100644 index 000000000..ea16b0995 --- /dev/null +++ b/entra_id/credentials_provider.go @@ -0,0 +1,57 @@ +package entra_id + +import ( + "log" + "time" +) + +// EntraIdIdentityProvider defines the interface for an identity provider +type EntraIdIdentityProvider interface { + RequestToken(forceRefresh bool) (string, time.Duration, error) +} + +// EntraIdCredentialsProvider manages credentials and token lifecycle +type EntraIdCredentialsProvider struct { + tokenManager *TokenManager + isStreaming bool +} + +// NewEntraIdCredentialsProvider initializes a new credentials provider +func NewEntraIdCredentialsProvider(idp EntraIdIdentityProvider, refreshInterval time.Duration, telemetryEnabled bool) *EntraIdCredentialsProvider { + refreshFunc := func() (string, time.Duration, error) { + return idp.RequestToken(false) + } + + tokenManager := NewTokenManager(refreshFunc, refreshInterval, telemetryEnabled) + + return &EntraIdCredentialsProvider{ + tokenManager: tokenManager, + isStreaming: false, + } +} + +// GetCredentials retrieves the current token or refreshes it if needed +func (cp *EntraIdCredentialsProvider) GetCredentials() (string, error) { + token, valid := cp.tokenManager.GetToken() + if !valid { + if err := cp.tokenManager.RefreshToken(); err != nil { + log.Printf("[EntraIdCredentialsProvider] Failed to refresh token: %v", err) + return "", err + } + token, _ = cp.tokenManager.GetToken() + } + + // Start streaming if not already started + if !cp.isStreaming { + cp.tokenManager.StartAutoRefresh() + cp.isStreaming = true + } + + return token, nil +} + +// Stop stops the credentials provider and cleans up resources +func (cp *EntraIdCredentialsProvider) Stop() { + cp.tokenManager.StopAutoRefresh() + log.Println("[EntraIdCredentialsProvider] Stopped and cleaned up resources.") +} diff --git a/entra_id/credentials_provider_test.go b/entra_id/credentials_provider_test.go new file mode 100644 index 000000000..e6aaf4647 --- /dev/null +++ b/entra_id/credentials_provider_test.go @@ -0,0 +1,81 @@ +package entra_id_test + +import ( + "errors" + "time" + + . "github.com/bsm/ginkgo/v2" + . "github.com/bsm/gomega" + "github.com/go-redis/entra_id" +) + +type MockEntraIdIdentityProvider struct { + Token string + TTL time.Duration + Error error +} + +func (m *MockEntraIdIdentityProvider) RequestToken(forceRefresh bool) (string, time.Duration, error) { + if m.Error != nil { + return "", 0, m.Error + } + return m.Token, m.TTL, nil +} + +var _ = Describe("EntraIdCredentialsProvider", func() { + var ( + provider *entra_id.EntraIdCredentialsProvider + mockIDP *MockEntraIdIdentityProvider + refreshRate time.Duration + ) + + BeforeEach(func() { + refreshRate = 1 * time.Minute + mockIDP = &MockEntraIdIdentityProvider{ + Token: "mock-token", + TTL: 10 * time.Second, + Error: nil, + } + provider = entra_id.NewEntraIdCredentialsProvider(mockIDP, refreshRate, true) + }) + + AfterEach(func() { + provider.Stop() + }) + + Context("Initial Token Retrieval", func() { + It("should retrieve a valid token from the identity provider", func() { + token, err := provider.GetCredentials() + Expect(err).To(BeNil()) + Expect(token).To(Equal("mock-token")) + }) + + It("should return an error if the identity provider fails", func() { + mockIDP.Error = errors.New("identity provider failure") + _, err := provider.GetCredentials() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("identity provider failure")) + }) + }) + + Context("Automatic Token Renewal", func() { + It("should automatically refresh the token when it expires", func() { + token, err := provider.GetCredentials() + Expect(err).To(BeNil()) + Expect(token).To(Equal("mock-token")) + + time.Sleep(11 * time.Second) // Wait for token expiry and auto-refresh + + newToken, err := provider.GetCredentials() + Expect(err).To(BeNil()) + Expect(newToken).To(Equal("mock-token")) // Mock still returns the same token + }) + }) + + Context("Stop Streaming", func() { + It("should stop token renewal and clean up resources when Stop is called", func() { + provider.GetCredentials() // Start streaming + // Ensure no further actions or panics occur after stopping, the stopping ocuur in the AfterEach + }) + }) +}) diff --git a/entra_id/entra_id_suite_test.go b/entra_id/entra_id_suite_test.go new file mode 100644 index 000000000..885c79b62 --- /dev/null +++ b/entra_id/entra_id_suite_test.go @@ -0,0 +1,13 @@ +package entra_id_test + +import ( + "testing" + + . "github.com/bsm/ginkgo/v2" + . "github.com/bsm/gomega" +) + +func TestEntraId(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "EntraId Suite") +} diff --git a/entra_id/go.mod b/entra_id/go.mod new file mode 100644 index 000000000..0871d78fc --- /dev/null +++ b/entra_id/go.mod @@ -0,0 +1,13 @@ +module entra_id + +go 1.22.0 + +toolchain go1.23.1 + +replace github.com/go-redis/entra_id => ./ + +require ( + github.com/bsm/ginkgo/v2 v2.12.0 + github.com/bsm/gomega v1.27.10 + github.com/go-redis/entra_id v0.0.0-00010101000000-000000000000 +) diff --git a/entra_id/go.sum b/entra_id/go.sum new file mode 100644 index 000000000..cf6b62ca8 --- /dev/null +++ b/entra_id/go.sum @@ -0,0 +1,4 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= diff --git a/entra_id/token_manager.go b/entra_id/token_manager.go new file mode 100644 index 000000000..9e921cdcf --- /dev/null +++ b/entra_id/token_manager.go @@ -0,0 +1,129 @@ +package entra_id + +import ( + "log" + "sync" + "time" +) + +type TokenManager struct { + token string + expiresAt time.Time + mutex sync.Mutex + refreshFunc func() (string, time.Duration, error) + stopChan chan struct{} + refreshTicker *time.Ticker + refreshInterval time.Duration + telemetryEnabled bool +} + +// NewTokenManager initializes a new TokenManager. +func NewTokenManager(refreshFunc func() (string, time.Duration, error), refreshInterval time.Duration, telemetryEnabled bool) *TokenManager { + return &TokenManager{ + refreshFunc: refreshFunc, + stopChan: make(chan struct{}), + refreshInterval: refreshInterval, + telemetryEnabled: telemetryEnabled, + } +} + +// SetToken updates the token and its expiration. +func (tm *TokenManager) SetToken(token string, ttl time.Duration) { + tm.mutex.Lock() + defer tm.mutex.Unlock() + tm.token = token + tm.expiresAt = time.Now().Add(ttl) + log.Printf("[TokenManager] Token updated with TTL: %s", ttl) +} + +// GetToken returns the current token if it's still valid. +func (tm *TokenManager) GetToken() (string, bool) { + tm.mutex.Lock() + defer tm.mutex.Unlock() + if time.Now().After(tm.expiresAt) { + return "", false + } + return tm.token, true +} + +// RefreshToken fetches a new token using the provided refresh function. +func (tm *TokenManager) RefreshToken() error { + if tm.refreshFunc == nil { + return nil + } + token, ttl, err := tm.refreshFunc() + if err != nil { + log.Printf("[TokenManager] Failed to refresh token: %v", err) + return err + } + tm.SetToken(token, ttl) + log.Println("[TokenManager] Token refreshed successfully.") + return nil +} + +// StartAutoRefresh starts a goroutine to proactively refresh the token. +func (tm *TokenManager) StartAutoRefresh() { + tm.refreshTicker = time.NewTicker(tm.refreshInterval) + go func() { + for { + select { + case <-tm.refreshTicker.C: + if tm.shouldRefresh() { + log.Println("[TokenManager] Proactively refreshing token...") + if err := tm.RefreshToken(); err != nil { + log.Printf("[TokenManager] Error during token refresh: %v", err) + } + } + case <-tm.stopChan: + log.Println("[TokenManager] Stopping auto-refresh...") + return + } + } + }() +} + +// StopAutoRefresh stops the auto-refresh goroutine and cleans up resources. +func (tm *TokenManager) StopAutoRefresh() { + if tm.refreshTicker != nil { + tm.refreshTicker.Stop() + } + close(tm.stopChan) + log.Println("[TokenManager] Auto-refresh stopped and resources cleaned.") +} + +// shouldRefresh determines if the token should be refreshed. +func (tm *TokenManager) shouldRefresh() bool { + tm.mutex.Lock() + defer tm.mutex.Unlock() + remaining := time.Until(tm.expiresAt) + + // Trigger refresh when less than 20% of TTL remains + return remaining < (tm.refreshInterval / 5) +} + +// MonitorTelemetry adds monitoring for token usage and expiration. +func (tm *TokenManager) MonitorTelemetry() { + if !tm.telemetryEnabled { + return + } + + go func() { + ticker := time.NewTicker(30 * time.Second) // Adjust as needed + defer ticker.Stop() + + for { + select { + case <-ticker.C: + _, valid := tm.GetToken() + if !valid { + log.Println("[TokenManager] Token has expired.") + } else { + log.Printf("[TokenManager] Token is valid: expires in %s", time.Until(tm.expiresAt)) + } + case <-tm.stopChan: + log.Println("[TokenManager] Telemetry monitoring stopped.") + return + } + } + }() +} diff --git a/entra_id/token_manager_test.go b/entra_id/token_manager_test.go new file mode 100644 index 000000000..866288172 --- /dev/null +++ b/entra_id/token_manager_test.go @@ -0,0 +1,125 @@ +package entra_id_test + +import ( + "errors" + "sync" + "time" + + . "github.com/bsm/ginkgo/v2" + . "github.com/bsm/gomega" + "github.com/go-redis/entra_id" +) + +var _ = Describe("TokenManager", func() { + var ( + tokenManager *entra_id.TokenManager + mockRefresh func() (string, time.Duration, error) + ) + + BeforeEach(func() { + mockRefresh = func() (string, time.Duration, error) { + return "new-token", 10 * time.Second, nil + } + tokenManager = entra_id.NewTokenManager(mockRefresh, 1*time.Minute, true) + }) + + AfterEach(func() { + tokenManager.StopAutoRefresh() + }) + + Context("Token Refresh", func() { + It("should refresh the token successfully", func() { + err := tokenManager.RefreshToken() + Expect(err).To(BeNil()) + + token, valid := tokenManager.GetToken() + Expect(valid).To(BeTrue()) + Expect(token).To(Equal("new-token")) + }) + + It("should return an error if the refresh function fails", func() { + failingRefresh := func() (string, time.Duration, error) { + return "", 0, errors.New("refresh failed") + } + tokenManager = entra_id.NewTokenManager(failingRefresh, 1*time.Minute, true) + err := tokenManager.RefreshToken() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("refresh failed")) + }) + }) + + Context("Token Expiry", func() { + It("should return false if the token has expired", func() { + tokenManager.SetToken("expired-token", 1*time.Second) + time.Sleep(2 * time.Second) + + _, valid := tokenManager.GetToken() + Expect(valid).To(BeFalse()) + }) + }) + + Context("Auto-Refresh", func() { + It("should automatically refresh the token before it expires", func() { + refreshed := make(chan struct{}) + mockRefresh := func() (string, time.Duration, error) { + close(refreshed) // Signal that refresh occurred + return "new-token", 10 * time.Second, nil + } + + tokenManager := entra_id.NewTokenManager(mockRefresh, 1*time.Second, false) + tokenManager.SetToken("old-token", 5*time.Second) + tokenManager.StartAutoRefresh() + + select { + case <-refreshed: + // Token refreshed successfully + case <-time.After(6 * time.Second): + Fail("Token refresh did not occur in time") + } + + token, valid := tokenManager.GetToken() + Expect(valid).To(BeTrue()) + Expect(token).To(Equal("new-token")) + + tokenManager.StopAutoRefresh() + }) + + It("should stop auto-refresh when StopAutoRefresh is called", func() { + mockRefresh := func() (string, time.Duration, error) { + return "new-token", 10 * time.Second, nil + } + + tokenManager := entra_id.NewTokenManager(mockRefresh, 1*time.Second, false) + tokenManager.StartAutoRefresh() + tokenManager.StopAutoRefresh() + time.Sleep(2 * time.Second) + + Expect(tokenManager.GetToken()).ToNot(BeNil()) // Ensure no panic or issues after stopping + }) + }) + + Context("Concurrency", func() { + It("should handle concurrent access without race conditions", func() { + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = tokenManager.GetToken() + tokenManager.SetToken("concurrent-token", 10*time.Second) + }() + } + wg.Wait() + }) + }) + + Context("Telemetry", func() { + It("should log token expiration and validity during monitoring", func() { + tokenManager.SetToken("telemetry-token", 2*time.Second) + go tokenManager.MonitorTelemetry() + + time.Sleep(4 * time.Second) + // Log verification can be done by capturing logs (if necessary) + }) + }) +}) From 444fe552fb820520fe0c6470a9da2192099241d5 Mon Sep 17 00:00:00 2001 From: ofekshenawa Date: Wed, 29 Jan 2025 16:18:46 +0200 Subject: [PATCH 5/5] update min go version --- entra_id/go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/entra_id/go.mod b/entra_id/go.mod index 0871d78fc..6cf9abd45 100644 --- a/entra_id/go.mod +++ b/entra_id/go.mod @@ -1,6 +1,6 @@ module entra_id -go 1.22.0 +go 1.18.0 toolchain go1.23.1