| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- package s3
- import (
- "context"
- "crypto/hmac"
- "crypto/sha256"
- "errors"
- "fmt"
- "sync"
- "time"
- "github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/internal/sdk"
- "github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
- "github.com/aws/smithy-go/container/private/cache"
- "github.com/aws/smithy-go/container/private/cache/lru"
- )
const (
	// s3ExpressCacheCap is the number of entries the default provider's LRU
	// cache of S3 Express session credentials can hold.
	s3ExpressCacheCap = 100

	// s3ExpressRefreshWindow is how close to expiry cached session
	// credentials must be before Retrieve starts a background refresh.
	s3ExpressRefreshWindow = time.Minute
)
// cacheKey identifies one cached set of S3 Express session credentials: the
// base SigV4 identity (as an HMAC fingerprint) plus the target bucket.
type cacheKey struct {
	CredentialsHash string // hmac(sigv4 akid, sigv4 secret)
	Bucket          string
}

// Slug flattens the key into a single string, used as the singleflight key
// when deduplicating CreateSession calls.
//
// Plain concatenation is unambiguous when CredentialsHash comes from gethmac,
// which always yields a fixed-length (32-byte) HMAC-SHA256 digest, so the
// bucket name always starts at the same offset.
func (c cacheKey) Slug() string {
	// Direct concatenation gives the same result as fmt.Sprintf("%s%s", ...)
	// without boxing both operands or running the fmt formatter.
	return c.CredentialsHash + c.Bucket
}
- type sessionCredsCache struct {
- mu sync.Mutex
- cache cache.Cache
- }
- func (c *sessionCredsCache) Get(key cacheKey) (*aws.Credentials, bool) {
- c.mu.Lock()
- defer c.mu.Unlock()
- if v, ok := c.cache.Get(key); ok {
- return v.(*aws.Credentials), true
- }
- return nil, false
- }
- func (c *sessionCredsCache) Put(key cacheKey, creds *aws.Credentials) {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.cache.Put(key, creds)
- }
// defaultS3ExpressCredentialsProvider is the default S3Express credentials
// provider. As built by newDefaultS3ExpressCredentialsProvider it uses an LRU
// cache with a capacity of 100 (s3ExpressCacheCap).
//
// Credentials will be refreshed asynchronously when a Retrieve() call is made
// for cached credentials within an expiry window (1 minute, currently
// non-configurable).
type defaultS3ExpressCredentialsProvider struct {
	// sf deduplicates concurrent CreateSession calls that share a cacheKey.
	sf singleflight.Group

	// client issues the CreateSession requests.
	client createSessionAPIClient

	// cache holds minted session credentials; it is shared with providers
	// produced by CloneWithBaseCredentials.
	cache *sessionCredsCache

	// refreshWindow is how close to expiry a cached entry must be before
	// Retrieve starts a background refresh.
	refreshWindow time.Duration

	// v4creds are the base SigV4 credentials (intended for CreateSession);
	// Retrieve hashes them into every cacheKey.
	v4creds aws.CredentialsProvider // underlying credentials used for CreateSession
}
- type createSessionAPIClient interface {
- CreateSession(context.Context, *CreateSessionInput, ...func(*Options)) (*CreateSessionOutput, error)
- }
- func newDefaultS3ExpressCredentialsProvider() *defaultS3ExpressCredentialsProvider {
- return &defaultS3ExpressCredentialsProvider{
- cache: &sessionCredsCache{
- cache: lru.New(s3ExpressCacheCap),
- },
- refreshWindow: s3ExpressRefreshWindow,
- }
- }
- // returns a cloned provider using new base credentials, used when per-op
- // config mutations change the credentials provider
- func (p *defaultS3ExpressCredentialsProvider) CloneWithBaseCredentials(v4creds aws.CredentialsProvider) *defaultS3ExpressCredentialsProvider {
- return &defaultS3ExpressCredentialsProvider{
- client: p.client,
- cache: p.cache,
- refreshWindow: p.refreshWindow,
- v4creds: v4creds,
- }
- }
- func (p *defaultS3ExpressCredentialsProvider) Retrieve(ctx context.Context, bucket string) (aws.Credentials, error) {
- v4creds, err := p.v4creds.Retrieve(ctx)
- if err != nil {
- return aws.Credentials{}, fmt.Errorf("get sigv4 creds: %w", err)
- }
- key := cacheKey{
- CredentialsHash: gethmac(v4creds.AccessKeyID, v4creds.SecretAccessKey),
- Bucket: bucket,
- }
- creds, ok := p.cache.Get(key)
- if !ok || creds.Expired() {
- return p.awaitDoChanRetrieve(ctx, key)
- }
- if creds.Expires.Sub(sdk.NowTime()) <= p.refreshWindow {
- p.doChanRetrieve(ctx, key)
- }
- return *creds, nil
- }
- func (p *defaultS3ExpressCredentialsProvider) doChanRetrieve(ctx context.Context, key cacheKey) <-chan singleflight.Result {
- return p.sf.DoChan(key.Slug(), func() (interface{}, error) {
- return p.retrieve(ctx, key)
- })
- }
- func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) {
- ch := p.doChanRetrieve(ctx, key)
- select {
- case r := <-ch:
- return r.Val.(aws.Credentials), r.Err
- case <-ctx.Done():
- return aws.Credentials{}, errors.New("s3express retrieve credentials canceled")
- }
- }
- func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) {
- resp, err := p.client.CreateSession(ctx, &CreateSessionInput{
- Bucket: aws.String(key.Bucket),
- })
- if err != nil {
- return aws.Credentials{}, err
- }
- creds, err := credentialsFromResponse(resp)
- if err != nil {
- return aws.Credentials{}, err
- }
- p.cache.Put(key, creds)
- return *creds, nil
- }
- func credentialsFromResponse(o *CreateSessionOutput) (*aws.Credentials, error) {
- if o.Credentials == nil {
- return nil, errors.New("s3express session credentials unset")
- }
- if o.Credentials.AccessKeyId == nil || o.Credentials.SecretAccessKey == nil || o.Credentials.SessionToken == nil || o.Credentials.Expiration == nil {
- return nil, errors.New("s3express session credentials missing one or more required fields")
- }
- return &aws.Credentials{
- AccessKeyID: *o.Credentials.AccessKeyId,
- SecretAccessKey: *o.Credentials.SecretAccessKey,
- SessionToken: *o.Credentials.SessionToken,
- CanExpire: true,
- Expires: *o.Credentials.Expiration,
- }, nil
- }
// gethmac returns the raw HMAC-SHA256 of p keyed with key, as a 32-byte binary
// string (not hex-encoded). Retrieve uses it to fingerprint the base
// credentials, so the cache key holds a digest rather than the secret.
func gethmac(p, key string) string {
	mac := hmac.New(sha256.New, []byte(key))
	// hash.Hash documents that Write never returns an error.
	_, _ = mac.Write([]byte(p))
	digest := mac.Sum(nil)
	return string(digest)
}
|