express_default.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. package s3
  2. import (
  3. "context"
  4. "crypto/hmac"
  5. "crypto/sha256"
  6. "errors"
  7. "fmt"
  8. "sync"
  9. "time"
  10. "github.com/aws/aws-sdk-go-v2/aws"
  11. "github.com/aws/aws-sdk-go-v2/internal/sdk"
  12. "github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
  13. "github.com/aws/smithy-go/container/private/cache"
  14. "github.com/aws/smithy-go/container/private/cache/lru"
  15. )
  const (
  	// s3ExpressCacheCap is the capacity of the LRU cache that holds S3Express
  	// session credentials in the default provider.
  	s3ExpressCacheCap = 100

  	// s3ExpressRefreshWindow is how close to expiry cached session credentials
  	// may get before a Retrieve call kicks off an asynchronous refresh.
  	s3ExpressRefreshWindow = time.Minute
  )
  // cacheKey identifies one cached S3Express session: the SigV4 identity that
  // created it (as an HMAC fingerprint) and the bucket it is scoped to.
  type cacheKey struct {
  	CredentialsHash string // hmac(sigv4 akid, sigv4 secret)
  	Bucket          string
  }

  // Slug flattens the key into a single string, used as the singleflight key
  // that deduplicates concurrent CreateSession calls.
  //
  // When CredentialsHash comes from gethmac it is a fixed-length 32-byte
  // HMAC-SHA256 digest, so plain concatenation cannot make two distinct keys
  // produce the same slug. Direct concatenation avoids fmt's boxing and
  // reflection and yields exactly the bytes "%s%s" formatting would.
  func (c cacheKey) Slug() string {
  	return c.CredentialsHash + c.Bucket
  }
  25. type sessionCredsCache struct {
  26. mu sync.Mutex
  27. cache cache.Cache
  28. }
  29. func (c *sessionCredsCache) Get(key cacheKey) (*aws.Credentials, bool) {
  30. c.mu.Lock()
  31. defer c.mu.Unlock()
  32. if v, ok := c.cache.Get(key); ok {
  33. return v.(*aws.Credentials), true
  34. }
  35. return nil, false
  36. }
  37. func (c *sessionCredsCache) Put(key cacheKey, creds *aws.Credentials) {
  38. c.mu.Lock()
  39. defer c.mu.Unlock()
  40. c.cache.Put(key, creds)
  41. }
  42. // The default S3Express provider uses an LRU cache with a capacity of 100.
  43. //
  44. // Credentials will be refreshed asynchronously when a Retrieve() call is made
  45. // for cached credentials within an expiry window (1 minute, currently
  46. // non-configurable).
  47. type defaultS3ExpressCredentialsProvider struct {
  48. sf singleflight.Group
  49. client createSessionAPIClient
  50. cache *sessionCredsCache
  51. refreshWindow time.Duration
  52. v4creds aws.CredentialsProvider // underlying credentials used for CreateSession
  53. }
  54. type createSessionAPIClient interface {
  55. CreateSession(context.Context, *CreateSessionInput, ...func(*Options)) (*CreateSessionOutput, error)
  56. }
  57. func newDefaultS3ExpressCredentialsProvider() *defaultS3ExpressCredentialsProvider {
  58. return &defaultS3ExpressCredentialsProvider{
  59. cache: &sessionCredsCache{
  60. cache: lru.New(s3ExpressCacheCap),
  61. },
  62. refreshWindow: s3ExpressRefreshWindow,
  63. }
  64. }
  65. // returns a cloned provider using new base credentials, used when per-op
  66. // config mutations change the credentials provider
  67. func (p *defaultS3ExpressCredentialsProvider) CloneWithBaseCredentials(v4creds aws.CredentialsProvider) *defaultS3ExpressCredentialsProvider {
  68. return &defaultS3ExpressCredentialsProvider{
  69. client: p.client,
  70. cache: p.cache,
  71. refreshWindow: p.refreshWindow,
  72. v4creds: v4creds,
  73. }
  74. }
  75. func (p *defaultS3ExpressCredentialsProvider) Retrieve(ctx context.Context, bucket string) (aws.Credentials, error) {
  76. v4creds, err := p.v4creds.Retrieve(ctx)
  77. if err != nil {
  78. return aws.Credentials{}, fmt.Errorf("get sigv4 creds: %w", err)
  79. }
  80. key := cacheKey{
  81. CredentialsHash: gethmac(v4creds.AccessKeyID, v4creds.SecretAccessKey),
  82. Bucket: bucket,
  83. }
  84. creds, ok := p.cache.Get(key)
  85. if !ok || creds.Expired() {
  86. return p.awaitDoChanRetrieve(ctx, key)
  87. }
  88. if creds.Expires.Sub(sdk.NowTime()) <= p.refreshWindow {
  89. p.doChanRetrieve(ctx, key)
  90. }
  91. return *creds, nil
  92. }
  93. func (p *defaultS3ExpressCredentialsProvider) doChanRetrieve(ctx context.Context, key cacheKey) <-chan singleflight.Result {
  94. return p.sf.DoChan(key.Slug(), func() (interface{}, error) {
  95. return p.retrieve(ctx, key)
  96. })
  97. }
  98. func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) {
  99. ch := p.doChanRetrieve(ctx, key)
  100. select {
  101. case r := <-ch:
  102. return r.Val.(aws.Credentials), r.Err
  103. case <-ctx.Done():
  104. return aws.Credentials{}, errors.New("s3express retrieve credentials canceled")
  105. }
  106. }
  107. func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) {
  108. resp, err := p.client.CreateSession(ctx, &CreateSessionInput{
  109. Bucket: aws.String(key.Bucket),
  110. })
  111. if err != nil {
  112. return aws.Credentials{}, err
  113. }
  114. creds, err := credentialsFromResponse(resp)
  115. if err != nil {
  116. return aws.Credentials{}, err
  117. }
  118. p.cache.Put(key, creds)
  119. return *creds, nil
  120. }
  121. func credentialsFromResponse(o *CreateSessionOutput) (*aws.Credentials, error) {
  122. if o.Credentials == nil {
  123. return nil, errors.New("s3express session credentials unset")
  124. }
  125. if o.Credentials.AccessKeyId == nil || o.Credentials.SecretAccessKey == nil || o.Credentials.SessionToken == nil || o.Credentials.Expiration == nil {
  126. return nil, errors.New("s3express session credentials missing one or more required fields")
  127. }
  128. return &aws.Credentials{
  129. AccessKeyID: *o.Credentials.AccessKeyId,
  130. SecretAccessKey: *o.Credentials.SecretAccessKey,
  131. SessionToken: *o.Credentials.SessionToken,
  132. CanExpire: true,
  133. Expires: *o.Credentials.Expiration,
  134. }, nil
  135. }
  // gethmac returns the HMAC-SHA256 of p keyed by key, as the raw 32-byte digest
  // (binary, not hex-encoded) converted to a string. Retrieve uses it to
  // fingerprint the SigV4 access key/secret pair in a cache key, so the secret
  // itself is never stored in the key.
  func gethmac(p, key string) string {
  	mac := hmac.New(sha256.New, []byte(key))
  	// hash.Hash's Write never returns an error, so the result is discarded.
  	_, _ = mac.Write([]byte(p))
  	digest := mac.Sum(nil)
  	return string(digest)
  }