encrypt.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package jwe
  2. import (
  3. "context"
  4. "sync"
  5. "github.com/lestrrat-go/jwx/internal/base64"
  6. "github.com/lestrrat-go/jwx/jwa"
  7. "github.com/pkg/errors"
  8. )
  9. var encryptCtxPool = sync.Pool{
  10. New: func() interface{} {
  11. return &encryptCtx{}
  12. },
  13. }
  14. func getEncryptCtx() *encryptCtx {
  15. //nolint:forcetypeassert
  16. return encryptCtxPool.Get().(*encryptCtx)
  17. }
  18. func releaseEncryptCtx(ctx *encryptCtx) {
  19. ctx.protected = nil
  20. ctx.contentEncrypter = nil
  21. ctx.generator = nil
  22. ctx.keyEncrypters = nil
  23. ctx.compress = jwa.NoCompress
  24. encryptCtxPool.Put(ctx)
  25. }
  26. // Encrypt takes the plaintext and encrypts into a JWE message.
  27. func (e encryptCtx) Encrypt(plaintext []byte) (*Message, error) {
  28. bk, err := e.generator.Generate()
  29. if err != nil {
  30. return nil, errors.Wrap(err, "failed to generate key")
  31. }
  32. cek := bk.Bytes()
  33. if e.protected == nil {
  34. // shouldn't happen, but...
  35. e.protected = NewHeaders()
  36. }
  37. if err := e.protected.Set(ContentEncryptionKey, e.contentEncrypter.Algorithm()); err != nil {
  38. return nil, errors.Wrap(err, `failed to set "enc" in protected header`)
  39. }
  40. compression := e.compress
  41. if compression != jwa.NoCompress {
  42. if err := e.protected.Set(CompressionKey, compression); err != nil {
  43. return nil, errors.Wrap(err, `failed to set "zip" in protected header`)
  44. }
  45. }
  46. // In JWE, multiple recipients may exist -- they receive an
  47. // encrypted version of the CEK, using their key encryption
  48. // algorithm of choice.
  49. recipients := make([]Recipient, len(e.keyEncrypters))
  50. for i, enc := range e.keyEncrypters {
  51. r := NewRecipient()
  52. if err := r.Headers().Set(AlgorithmKey, enc.Algorithm()); err != nil {
  53. return nil, errors.Wrap(err, "failed to set header")
  54. }
  55. if v := enc.KeyID(); v != "" {
  56. if err := r.Headers().Set(KeyIDKey, v); err != nil {
  57. return nil, errors.Wrap(err, "failed to set header")
  58. }
  59. }
  60. enckey, err := enc.Encrypt(cek)
  61. if err != nil {
  62. return nil, errors.Wrap(err, `failed to encrypt key`)
  63. }
  64. if enc.Algorithm() == jwa.ECDH_ES || enc.Algorithm() == jwa.DIRECT {
  65. if len(e.keyEncrypters) > 1 {
  66. return nil, errors.Errorf("unable to support multiple recipients for ECDH-ES")
  67. }
  68. cek = enckey.Bytes()
  69. } else {
  70. if err := r.SetEncryptedKey(enckey.Bytes()); err != nil {
  71. return nil, errors.Wrap(err, "failed to set encrypted key")
  72. }
  73. }
  74. if hp, ok := enckey.(populater); ok {
  75. if err := hp.Populate(r.Headers()); err != nil {
  76. return nil, errors.Wrap(err, "failed to populate")
  77. }
  78. }
  79. recipients[i] = r
  80. }
  81. // If there's only one recipient, you want to include that in the
  82. // protected header
  83. if len(recipients) == 1 {
  84. h, err := e.protected.Merge(context.TODO(), recipients[0].Headers())
  85. if err != nil {
  86. return nil, errors.Wrap(err, "failed to merge protected headers")
  87. }
  88. e.protected = h
  89. }
  90. aad, err := e.protected.Encode()
  91. if err != nil {
  92. return nil, errors.Wrap(err, "failed to base64 encode protected headers")
  93. }
  94. plaintext, err = compress(plaintext, compression)
  95. if err != nil {
  96. return nil, errors.Wrap(err, `failed to compress payload before encryption`)
  97. }
  98. // ...on the other hand, there's only one content cipher.
  99. iv, ciphertext, tag, err := e.contentEncrypter.Encrypt(cek, plaintext, aad)
  100. if err != nil {
  101. return nil, errors.Wrap(err, "failed to encrypt payload")
  102. }
  103. msg := NewMessage()
  104. decodedAad, err := base64.Decode(aad)
  105. if err != nil {
  106. return nil, errors.Wrap(err, "failed to decode base64")
  107. }
  108. if err := msg.Set(AuthenticatedDataKey, decodedAad); err != nil {
  109. return nil, errors.Wrapf(err, `failed to set %s`, AuthenticatedDataKey)
  110. }
  111. if err := msg.Set(CipherTextKey, ciphertext); err != nil {
  112. return nil, errors.Wrapf(err, `failed to set %s`, CipherTextKey)
  113. }
  114. if err := msg.Set(InitializationVectorKey, iv); err != nil {
  115. return nil, errors.Wrapf(err, `failed to set %s`, InitializationVectorKey)
  116. }
  117. if err := msg.Set(ProtectedHeadersKey, e.protected); err != nil {
  118. return nil, errors.Wrapf(err, `failed to set %s`, ProtectedHeadersKey)
  119. }
  120. if err := msg.Set(RecipientsKey, recipients); err != nil {
  121. return nil, errors.Wrapf(err, `failed to set %s`, RecipientsKey)
  122. }
  123. if err := msg.Set(TagKey, tag); err != nil {
  124. return nil, errors.Wrapf(err, `failed to set %s`, TagKey)
  125. }
  126. return msg, nil
  127. }