request_compression.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. // Package requestcompression implements runtime support for smithy-modeled
  2. // request compression.
  3. //
  4. // This package is designated as private and is intended for use only by the
  5. // smithy client runtime. The exported API therein is not considered stable and
  6. // is subject to breaking changes without notice.
  7. package requestcompression
  8. import (
  9. "bytes"
  10. "context"
  11. "fmt"
  12. "github.com/aws/smithy-go/middleware"
  13. "github.com/aws/smithy-go/transport/http"
  14. "io"
  15. )
  16. const MaxRequestMinCompressSizeBytes = 10485760
  17. // Enumeration values for supported compress Algorithms.
  18. const (
  19. GZIP = "gzip"
  20. )
  21. type compressFunc func(io.Reader) ([]byte, error)
  22. var allowedAlgorithms = map[string]compressFunc{
  23. GZIP: gzipCompress,
  24. }
  25. // AddRequestCompression add requestCompression middleware to op stack
  26. func AddRequestCompression(stack *middleware.Stack, disabled bool, minBytes int64, algorithms []string) error {
  27. return stack.Serialize.Add(&requestCompression{
  28. disableRequestCompression: disabled,
  29. requestMinCompressSizeBytes: minBytes,
  30. compressAlgorithms: algorithms,
  31. }, middleware.After)
  32. }
  33. type requestCompression struct {
  34. disableRequestCompression bool
  35. requestMinCompressSizeBytes int64
  36. compressAlgorithms []string
  37. }
  38. // ID returns the ID of the middleware
  39. func (m requestCompression) ID() string {
  40. return "RequestCompression"
  41. }
  42. // HandleSerialize gzip compress the request's stream/body if enabled by config fields
  43. func (m requestCompression) HandleSerialize(
  44. ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
  45. ) (
  46. out middleware.SerializeOutput, metadata middleware.Metadata, err error,
  47. ) {
  48. if m.disableRequestCompression {
  49. return next.HandleSerialize(ctx, in)
  50. }
  51. // still need to check requestMinCompressSizeBytes in case it is out of range after service client config
  52. if m.requestMinCompressSizeBytes < 0 || m.requestMinCompressSizeBytes > MaxRequestMinCompressSizeBytes {
  53. return out, metadata, fmt.Errorf("invalid range for min request compression size bytes %d, must be within 0 and 10485760 inclusively", m.requestMinCompressSizeBytes)
  54. }
  55. req, ok := in.Request.(*http.Request)
  56. if !ok {
  57. return out, metadata, fmt.Errorf("unknown request type %T", req)
  58. }
  59. for _, algorithm := range m.compressAlgorithms {
  60. compressFunc := allowedAlgorithms[algorithm]
  61. if compressFunc != nil {
  62. if stream := req.GetStream(); stream != nil {
  63. size, found, err := req.StreamLength()
  64. if err != nil {
  65. return out, metadata, fmt.Errorf("error while finding request stream length, %v", err)
  66. } else if !found || size < m.requestMinCompressSizeBytes {
  67. return next.HandleSerialize(ctx, in)
  68. }
  69. compressedBytes, err := compressFunc(stream)
  70. if err != nil {
  71. return out, metadata, fmt.Errorf("failed to compress request stream, %v", err)
  72. }
  73. var newReq *http.Request
  74. if newReq, err = req.SetStream(bytes.NewReader(compressedBytes)); err != nil {
  75. return out, metadata, fmt.Errorf("failed to set request stream, %v", err)
  76. }
  77. *req = *newReq
  78. if val := req.Header.Get("Content-Encoding"); val != "" {
  79. req.Header.Set("Content-Encoding", fmt.Sprintf("%s, %s", val, algorithm))
  80. } else {
  81. req.Header.Set("Content-Encoding", algorithm)
  82. }
  83. }
  84. break
  85. }
  86. }
  87. return next.HandleSerialize(ctx, in)
  88. }