token_cache.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. package bearer
  2. import (
  3. "context"
  4. "fmt"
  5. "sync/atomic"
  6. "time"
  7. smithycontext "github.com/aws/smithy-go/context"
  8. "github.com/aws/smithy-go/internal/sync/singleflight"
  9. )
  // timeNow returns the current time. It is a package variable, rather than a
// direct call to time.Now, so that unit tests can override it to control the
// clock observed by the TokenCache.
var timeNow = time.Now
  // TokenCacheOptions provides a set of optional configuration options for the
// TokenCache TokenProvider.
type TokenCacheOptions struct {
	// The duration before the token expires at which the token will be
	// refreshed. If DisableAsyncRefresh is true, RetrieveBearerToken calls
	// made inside this window will block on the refresh.
	//
	// Asynchronous refreshes are deduplicated, and only one will be in flight
	// at a time. If the token expires while an asynchronous refresh is in
	// flight, the next call to RetrieveBearerToken will block until that
	// refresh returns.
	RefreshBeforeExpires time.Duration

	// The timeout within which the underlying TokenProvider's
	// RetrieveBearerToken call must return, or its context will be canceled.
	// Defaults to 0, no timeout.
	//
	// With a 0 timeout, it is possible for the underlying TokenProvider's
	// RetrieveBearerToken call to block forever, preventing subsequent
	// TokenCache attempts to refresh the token.
	//
	// If this timeout is reached, and the provider honors the cancellation,
	// all pending deduplicated calls to TokenCache RetrieveBearerToken will
	// fail with an error.
	RetrieveBearerTokenTimeout time.Duration

	// The minimum duration between asynchronous refresh attempts. If the most
	// recent asynchronous refresh attempt failed less than this duration ago,
	// RetrieveBearerToken returns the current cached token, if not expired,
	// without starting another asynchronous refresh. A successful refresh
	// clears the delay, so only failed attempts are throttled.
	//
	// The asynchronous retrieve is deduplicated across multiple calls to
	// RetrieveBearerToken. It is not a periodic task. It is only performed
	// when the token has not yet expired, the token is within the
	// RefreshBeforeExpires window, and the TokenCache's RetrieveBearerToken
	// method is called.
	//
	// If 0 (the default), there is no minimum delay between asynchronous
	// refresh attempts.
	//
	// If DisableAsyncRefresh is true, this option is ignored.
	AsyncRefreshMinimumDelay time.Duration

	// Sets whether the TokenCache refreshes the token asynchronously in the
	// background instead of blocking for the token to be refreshed. If true,
	// token refreshes are blocking.
	//
	// The first call to RetrieveBearerToken will always be blocking, because
	// there is no cached token yet.
	DisableAsyncRefresh bool
}
  58. // TokenCache provides an utility to cache Bearer Authentication tokens from a
  59. // wrapped TokenProvider. The TokenCache can be has options to configure the
  60. // cache's early and asynchronous refresh of the token.
  61. type TokenCache struct {
  62. options TokenCacheOptions
  63. provider TokenProvider
  64. cachedToken atomic.Value
  65. lastRefreshAttemptTime atomic.Value
  66. sfGroup singleflight.Group
  67. }
  68. // NewTokenCache returns a initialized TokenCache that implements the
  69. // TokenProvider interface. Wrapping the provider passed in. Also taking a set
  70. // of optional functional option parameters to configure the token cache.
  71. func NewTokenCache(provider TokenProvider, optFns ...func(*TokenCacheOptions)) *TokenCache {
  72. var options TokenCacheOptions
  73. for _, fn := range optFns {
  74. fn(&options)
  75. }
  76. return &TokenCache{
  77. options: options,
  78. provider: provider,
  79. }
  80. }
  81. // RetrieveBearerToken returns the token if it could be obtained, or error if a
  82. // valid token could not be retrieved.
  83. //
  84. // The passed in Context's cancel/deadline/timeout will impacting only this
  85. // individual retrieve call and not any other already queued up calls. This
  86. // means underlying provider's RetrieveBearerToken calls could block for ever,
  87. // and not be canceled with the Context. Set RetrieveBearerTokenTimeout to
  88. // provide a timeout, preventing the underlying TokenProvider blocking forever.
  89. //
  90. // By default, if the passed in Context is canceled, all of its values will be
  91. // considered expired. The wrapped TokenProvider will not be able to lookup the
  92. // values from the Context once it is expired. This is done to protect against
  93. // expired values no longer being valid. To disable this behavior, use
  94. // smithy-go's context.WithPreserveExpiredValues to add a value to the Context
  95. // before calling RetrieveBearerToken to enable support for expired values.
  96. //
  97. // Without RetrieveBearerTokenTimeout there is the potential for a underlying
  98. // Provider's RetrieveBearerToken call to sit forever. Blocking in subsequent
  99. // attempts at refreshing the token.
  100. func (p *TokenCache) RetrieveBearerToken(ctx context.Context) (Token, error) {
  101. cachedToken, ok := p.getCachedToken()
  102. if !ok || cachedToken.Expired(timeNow()) {
  103. return p.refreshBearerToken(ctx)
  104. }
  105. // Check if the token should be refreshed before it expires.
  106. refreshToken := cachedToken.Expired(timeNow().Add(p.options.RefreshBeforeExpires))
  107. if !refreshToken {
  108. return cachedToken, nil
  109. }
  110. if p.options.DisableAsyncRefresh {
  111. return p.refreshBearerToken(ctx)
  112. }
  113. p.tryAsyncRefresh(ctx)
  114. return cachedToken, nil
  115. }
  116. // tryAsyncRefresh attempts to asynchronously refresh the token returning the
  117. // already cached token. If it AsyncRefreshMinimumDelay option is not zero, and
  118. // the duration since the last refresh is less than that value, nothing will be
  119. // done.
  120. func (p *TokenCache) tryAsyncRefresh(ctx context.Context) {
  121. if p.options.AsyncRefreshMinimumDelay != 0 {
  122. var lastRefreshAttempt time.Time
  123. if v := p.lastRefreshAttemptTime.Load(); v != nil {
  124. lastRefreshAttempt = v.(time.Time)
  125. }
  126. if timeNow().Before(lastRefreshAttempt.Add(p.options.AsyncRefreshMinimumDelay)) {
  127. return
  128. }
  129. }
  130. // Ignore the returned channel so this won't be blocking, and limit the
  131. // number of additional goroutines created.
  132. p.sfGroup.DoChan("async-refresh", func() (interface{}, error) {
  133. res, err := p.refreshBearerToken(ctx)
  134. if p.options.AsyncRefreshMinimumDelay != 0 {
  135. var refreshAttempt time.Time
  136. if err != nil {
  137. refreshAttempt = timeNow()
  138. }
  139. p.lastRefreshAttemptTime.Store(refreshAttempt)
  140. }
  141. return res, err
  142. })
  143. }
  144. func (p *TokenCache) refreshBearerToken(ctx context.Context) (Token, error) {
  145. resCh := p.sfGroup.DoChan("refresh-token", func() (interface{}, error) {
  146. ctx := smithycontext.WithSuppressCancel(ctx)
  147. if v := p.options.RetrieveBearerTokenTimeout; v != 0 {
  148. var cancel func()
  149. ctx, cancel = context.WithTimeout(ctx, v)
  150. defer cancel()
  151. }
  152. return p.singleRetrieve(ctx)
  153. })
  154. select {
  155. case res := <-resCh:
  156. return res.Val.(Token), res.Err
  157. case <-ctx.Done():
  158. return Token{}, fmt.Errorf("retrieve bearer token canceled, %w", ctx.Err())
  159. }
  160. }
  161. func (p *TokenCache) singleRetrieve(ctx context.Context) (interface{}, error) {
  162. token, err := p.provider.RetrieveBearerToken(ctx)
  163. if err != nil {
  164. return Token{}, fmt.Errorf("failed to retrieve bearer token, %w", err)
  165. }
  166. p.cachedToken.Store(&token)
  167. return token, nil
  168. }
  169. // getCachedToken returns the currently cached token and true if found. Returns
  170. // false if no token is cached.
  171. func (p *TokenCache) getCachedToken() (Token, bool) {
  172. v := p.cachedToken.Load()
  173. if v == nil {
  174. return Token{}, false
  175. }
  176. t := v.(*Token)
  177. if t == nil || t.Value == "" {
  178. return Token{}, false
  179. }
  180. return *t, true
  181. }