provider.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. package ec2rolecreds
  2. import (
  3. "bufio"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "math"
  8. "path"
  9. "strings"
  10. "time"
  11. "github.com/aws/aws-sdk-go-v2/aws"
  12. "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
  13. sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
  14. "github.com/aws/aws-sdk-go-v2/internal/sdk"
  15. "github.com/aws/smithy-go"
  16. "github.com/aws/smithy-go/logging"
  17. "github.com/aws/smithy-go/middleware"
  18. )
  // ProviderName is the name of the EC2 role credentials provider. It is set
  // as the Source of the aws.Credentials values produced by Provider.Retrieve.
  const (
  	ProviderName = "EC2RoleProvider"
  )
  21. // GetMetadataAPIClient provides the interface for an EC2 IMDS API client for the
  22. // GetMetadata operation.
  23. type GetMetadataAPIClient interface {
  24. GetMetadata(context.Context, *imds.GetMetadataInput, ...func(*imds.Options)) (*imds.GetMetadataOutput, error)
  25. }
  26. // A Provider retrieves credentials from the EC2 service, and keeps track if
  27. // those credentials are expired.
  28. //
  29. // The New function must be used to create the with a custom EC2 IMDS client.
  30. //
  31. // p := &ec2rolecreds.New(func(o *ec2rolecreds.Options{
  32. // o.Client = imds.New(imds.Options{/* custom options */})
  33. // })
  34. type Provider struct {
  35. options Options
  36. }
  37. // Options is a list of user settable options for setting the behavior of the Provider.
  38. type Options struct {
  39. // The API client that will be used by the provider to make GetMetadata API
  40. // calls to EC2 IMDS.
  41. //
  42. // If nil, the provider will default to the EC2 IMDS client.
  43. Client GetMetadataAPIClient
  44. // The chain of providers that was used to create this provider
  45. // These values are for reporting purposes and are not meant to be set up directly
  46. CredentialSources []aws.CredentialSource
  47. }
  48. // New returns an initialized Provider value configured to retrieve
  49. // credentials from EC2 Instance Metadata service.
  50. func New(optFns ...func(*Options)) *Provider {
  51. options := Options{}
  52. for _, fn := range optFns {
  53. fn(&options)
  54. }
  55. if options.Client == nil {
  56. options.Client = imds.New(imds.Options{})
  57. }
  58. return &Provider{
  59. options: options,
  60. }
  61. }
  62. // Retrieve retrieves credentials from the EC2 service. Error will be returned
  63. // if the request fails, or unable to extract the desired credentials.
  64. func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
  65. credsList, err := requestCredList(ctx, p.options.Client)
  66. if err != nil {
  67. return aws.Credentials{Source: ProviderName}, err
  68. }
  69. if len(credsList) == 0 {
  70. return aws.Credentials{Source: ProviderName},
  71. fmt.Errorf("unexpected empty EC2 IMDS role list")
  72. }
  73. credsName := credsList[0]
  74. roleCreds, err := requestCred(ctx, p.options.Client, credsName)
  75. if err != nil {
  76. return aws.Credentials{Source: ProviderName}, err
  77. }
  78. creds := aws.Credentials{
  79. AccessKeyID: roleCreds.AccessKeyID,
  80. SecretAccessKey: roleCreds.SecretAccessKey,
  81. SessionToken: roleCreds.Token,
  82. Source: ProviderName,
  83. CanExpire: true,
  84. Expires: roleCreds.Expiration,
  85. }
  86. // Cap role credentials Expires to 1 hour so they can be refreshed more
  87. // often. Jitter will be applied credentials cache if being used.
  88. if anHour := sdk.NowTime().Add(1 * time.Hour); creds.Expires.After(anHour) {
  89. creds.Expires = anHour
  90. }
  91. return creds, nil
  92. }
  93. // HandleFailToRefresh will extend the credentials Expires time if it it is
  94. // expired. If the credentials will not expire within the minimum time, they
  95. // will be returned.
  96. //
  97. // If the credentials cannot expire, the original error will be returned.
  98. func (p *Provider) HandleFailToRefresh(ctx context.Context, prevCreds aws.Credentials, err error) (
  99. aws.Credentials, error,
  100. ) {
  101. if !prevCreds.CanExpire {
  102. return aws.Credentials{}, err
  103. }
  104. if prevCreds.Expires.After(sdk.NowTime().Add(5 * time.Minute)) {
  105. return prevCreds, nil
  106. }
  107. newCreds := prevCreds
  108. randFloat64, err := sdkrand.CryptoRandFloat64()
  109. if err != nil {
  110. return aws.Credentials{}, fmt.Errorf("failed to get random float, %w", err)
  111. }
  112. // Random distribution of [5,15) minutes.
  113. expireOffset := time.Duration(randFloat64*float64(10*time.Minute)) + 5*time.Minute
  114. newCreds.Expires = sdk.NowTime().Add(expireOffset)
  115. logger := middleware.GetLogger(ctx)
  116. logger.Logf(logging.Warn, "Attempting credential expiration extension due to a credential service availability issue. A refresh of these credentials will be attempted again in %v minutes.", math.Floor(expireOffset.Minutes()))
  117. return newCreds, nil
  118. }
  119. // AdjustExpiresBy will adds the passed in duration to the passed in
  120. // credential's Expires time, unless the time until Expires is less than 15
  121. // minutes. Returns the credentials, even if not updated.
  122. func (p *Provider) AdjustExpiresBy(creds aws.Credentials, dur time.Duration) (
  123. aws.Credentials, error,
  124. ) {
  125. if !creds.CanExpire {
  126. return creds, nil
  127. }
  128. if creds.Expires.Before(sdk.NowTime().Add(15 * time.Minute)) {
  129. return creds, nil
  130. }
  131. creds.Expires = creds.Expires.Add(dur)
  132. return creds, nil
  133. }
  // ec2RoleCredRespBody provides the shape for unmarshaling the JSON credential
  // documents returned by EC2 IMDS for a role.
  //
  // The JSON tags name the response keys explicitly. Per the EC2 IMDS
  // documentation, the access key is sent as "AccessKeyId", which differs in
  // case from the Go field name. encoding/json also matches keys
  // case-insensitively, so other casings still decode.
  type ec2RoleCredRespBody struct {
  	// Success state.
  	Expiration      time.Time `json:"Expiration"`
  	AccessKeyID     string    `json:"AccessKeyId"`
  	SecretAccessKey string    `json:"SecretAccessKey"`
  	Token           string    `json:"Token"`

  	// Error state. The caller treats any Code other than "Success"
  	// (case-insensitive) as a failure and reports Code and Message.
  	Code    string `json:"Code"`
  	Message string `json:"Message"`
  }
  // iamSecurityCredsPath is the IMDS path that lists the instance's IAM role
  // names. The credentials for a role are read from this path joined with the
  // role name.
  const (
  	iamSecurityCredsPath = "/iam/security-credentials/"
  )
  147. // requestCredList requests a list of credentials from the EC2 service. If
  148. // there are no credentials, or there is an error making or receiving the
  149. // request
  150. func requestCredList(ctx context.Context, client GetMetadataAPIClient) ([]string, error) {
  151. resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
  152. Path: iamSecurityCredsPath,
  153. })
  154. if err != nil {
  155. return nil, fmt.Errorf("no EC2 IMDS role found, %w", err)
  156. }
  157. defer resp.Content.Close()
  158. credsList := []string{}
  159. s := bufio.NewScanner(resp.Content)
  160. for s.Scan() {
  161. credsList = append(credsList, s.Text())
  162. }
  163. if err := s.Err(); err != nil {
  164. return nil, fmt.Errorf("failed to read EC2 IMDS role, %w", err)
  165. }
  166. return credsList, nil
  167. }
  168. // requestCred requests the credentials for a specific credentials from the EC2 service.
  169. //
  170. // If the credentials cannot be found, or there is an error reading the response
  171. // and error will be returned.
  172. func requestCred(ctx context.Context, client GetMetadataAPIClient, credsName string) (ec2RoleCredRespBody, error) {
  173. resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
  174. Path: path.Join(iamSecurityCredsPath, credsName),
  175. })
  176. if err != nil {
  177. return ec2RoleCredRespBody{},
  178. fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
  179. credsName, err)
  180. }
  181. defer resp.Content.Close()
  182. var respCreds ec2RoleCredRespBody
  183. if err := json.NewDecoder(resp.Content).Decode(&respCreds); err != nil {
  184. return ec2RoleCredRespBody{},
  185. fmt.Errorf("failed to decode %s EC2 IMDS role credentials, %w",
  186. credsName, err)
  187. }
  188. if !strings.EqualFold(respCreds.Code, "Success") {
  189. // If an error code was returned something failed requesting the role.
  190. return ec2RoleCredRespBody{},
  191. fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
  192. credsName,
  193. &smithy.GenericAPIError{Code: respCreds.Code, Message: respCreds.Message})
  194. }
  195. return respCreds, nil
  196. }
  197. // ProviderSources returns the credential chain that was used to construct this provider
  198. func (p *Provider) ProviderSources() []aws.CredentialSource {
  199. if p.options.CredentialSources == nil {
  200. return []aws.CredentialSource{aws.CredentialSourceIMDS}
  201. } // If no source has been set, assume this is used directly which means just call to assume role
  202. return p.options.CredentialSources
  203. }