headers.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. package jwe
  2. import (
  3. "context"
  4. "github.com/lestrrat-go/jwx/internal/base64"
  5. "github.com/lestrrat-go/jwx/internal/json"
  6. "github.com/lestrrat-go/iter/mapiter"
  7. "github.com/lestrrat-go/jwx/internal/iter"
  8. "github.com/pkg/errors"
  9. )
  // isZeroer is satisfied by any type with an isZero method. stdHeaders
// implements it by reporting whether every one of its header fields is unset.
type isZeroer interface{ isZero() bool }
  13. func (h *stdHeaders) isZero() bool {
  14. return h.agreementPartyUInfo == nil &&
  15. h.agreementPartyVInfo == nil &&
  16. h.algorithm == nil &&
  17. h.compression == nil &&
  18. h.contentEncryption == nil &&
  19. h.contentType == nil &&
  20. h.critical == nil &&
  21. h.ephemeralPublicKey == nil &&
  22. h.jwk == nil &&
  23. h.jwkSetURL == nil &&
  24. h.keyID == nil &&
  25. h.typ == nil &&
  26. h.x509CertChain == nil &&
  27. h.x509CertThumbprint == nil &&
  28. h.x509CertThumbprintS256 == nil &&
  29. h.x509URL == nil &&
  30. len(h.privateParams) == 0
  31. }
  32. // Iterate returns a channel that successively returns all the
  33. // header name and values.
  34. func (h *stdHeaders) Iterate(ctx context.Context) Iterator {
  35. pairs := h.makePairs()
  36. ch := make(chan *HeaderPair, len(pairs))
  37. go func(ctx context.Context, ch chan *HeaderPair, pairs []*HeaderPair) {
  38. defer close(ch)
  39. for _, pair := range pairs {
  40. select {
  41. case <-ctx.Done():
  42. return
  43. case ch <- pair:
  44. }
  45. }
  46. }(ctx, ch, pairs)
  47. return mapiter.New(ch)
  48. }
  49. func (h *stdHeaders) Walk(ctx context.Context, visitor Visitor) error {
  50. return iter.WalkMap(ctx, h, visitor)
  51. }
  52. func (h *stdHeaders) AsMap(ctx context.Context) (map[string]interface{}, error) {
  53. return iter.AsMap(ctx, h)
  54. }
  55. func (h *stdHeaders) Clone(ctx context.Context) (Headers, error) {
  56. dst := NewHeaders()
  57. if err := h.Copy(ctx, dst); err != nil {
  58. return nil, errors.Wrap(err, `failed to copy header contents to new object`)
  59. }
  60. return dst, nil
  61. }
  62. func (h *stdHeaders) Copy(ctx context.Context, dst Headers) error {
  63. for _, pair := range h.makePairs() {
  64. //nolint:forcetypeassert
  65. key := pair.Key.(string)
  66. if err := dst.Set(key, pair.Value); err != nil {
  67. return errors.Wrapf(err, `failed to set header %q`, key)
  68. }
  69. }
  70. return nil
  71. }
  72. func (h *stdHeaders) Merge(ctx context.Context, h2 Headers) (Headers, error) {
  73. h3 := NewHeaders()
  74. if h != nil {
  75. if err := h.Copy(ctx, h3); err != nil {
  76. return nil, errors.Wrap(err, `failed to copy headers from receiver`)
  77. }
  78. }
  79. if h2 != nil {
  80. if err := h2.Copy(ctx, h3); err != nil {
  81. return nil, errors.Wrap(err, `failed to copy headers from argument`)
  82. }
  83. }
  84. return h3, nil
  85. }
  86. func (h *stdHeaders) Encode() ([]byte, error) {
  87. buf, err := json.Marshal(h)
  88. if err != nil {
  89. return nil, errors.Wrap(err, `failed to marshal headers to JSON prior to encoding`)
  90. }
  91. return base64.Encode(buf), nil
  92. }
  93. func (h *stdHeaders) Decode(buf []byte) error {
  94. // base64 json string -> json object representation of header
  95. decoded, err := base64.Decode(buf)
  96. if err != nil {
  97. return errors.Wrap(err, "failed to unmarshal base64 encoded buffer")
  98. }
  99. if err := json.Unmarshal(decoded, h); err != nil {
  100. return errors.Wrap(err, "failed to unmarshal buffer")
  101. }
  102. return nil
  103. }