cipher.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. package cipher
  2. import (
  3. "crypto/aes"
  4. "crypto/cipher"
  5. "fmt"
  6. "github.com/lestrrat-go/jwx/jwa"
  7. "github.com/lestrrat-go/jwx/jwe/internal/aescbc"
  8. "github.com/lestrrat-go/jwx/jwe/internal/keygen"
  9. "github.com/pkg/errors"
  10. )
  11. var gcm = &gcmFetcher{}
  12. var cbc = &cbcFetcher{}
  13. func (f gcmFetcher) Fetch(key []byte) (cipher.AEAD, error) {
  14. aescipher, err := aes.NewCipher(key)
  15. if err != nil {
  16. return nil, errors.Wrap(err, "cipher: failed to create AES cipher for GCM")
  17. }
  18. aead, err := cipher.NewGCM(aescipher)
  19. if err != nil {
  20. return nil, errors.Wrap(err, `failed to create GCM for cipher`)
  21. }
  22. return aead, nil
  23. }
  24. func (f cbcFetcher) Fetch(key []byte) (cipher.AEAD, error) {
  25. aead, err := aescbc.New(key, aes.NewCipher)
  26. if err != nil {
  27. return nil, errors.Wrap(err, "cipher: failed to create AES cipher for CBC")
  28. }
  29. return aead, nil
  30. }
  31. func (c AesContentCipher) KeySize() int {
  32. return c.keysize
  33. }
  34. func (c AesContentCipher) TagSize() int {
  35. return c.tagsize
  36. }
  37. func NewAES(alg jwa.ContentEncryptionAlgorithm) (*AesContentCipher, error) {
  38. var keysize int
  39. var tagsize int
  40. var fetcher Fetcher
  41. switch alg {
  42. case jwa.A128GCM:
  43. keysize = 16
  44. tagsize = 16
  45. fetcher = gcm
  46. case jwa.A192GCM:
  47. keysize = 24
  48. tagsize = 16
  49. fetcher = gcm
  50. case jwa.A256GCM:
  51. keysize = 32
  52. tagsize = 16
  53. fetcher = gcm
  54. case jwa.A128CBC_HS256:
  55. tagsize = 16
  56. keysize = tagsize * 2
  57. fetcher = cbc
  58. case jwa.A192CBC_HS384:
  59. tagsize = 24
  60. keysize = tagsize * 2
  61. fetcher = cbc
  62. case jwa.A256CBC_HS512:
  63. tagsize = 32
  64. keysize = tagsize * 2
  65. fetcher = cbc
  66. default:
  67. return nil, errors.Errorf("failed to create AES content cipher: invalid algorithm (%s)", alg)
  68. }
  69. return &AesContentCipher{
  70. keysize: keysize,
  71. tagsize: tagsize,
  72. fetch: fetcher,
  73. }, nil
  74. }
  75. func (c AesContentCipher) Encrypt(cek, plaintext, aad []byte) (iv, ciphertext, tag []byte, err error) {
  76. var aead cipher.AEAD
  77. aead, err = c.fetch.Fetch(cek)
  78. if err != nil {
  79. return nil, nil, nil, errors.Wrap(err, "failed to fetch AEAD")
  80. }
  81. // Seal may panic (argh!), so protect ourselves from that
  82. defer func() {
  83. if e := recover(); e != nil {
  84. switch e := e.(type) {
  85. case error:
  86. err = e
  87. default:
  88. err = errors.Errorf("%s", e)
  89. }
  90. err = errors.Wrap(err, "failed to encrypt")
  91. }
  92. }()
  93. var bs keygen.ByteSource
  94. if c.NonceGenerator == nil {
  95. bs, err = keygen.NewRandom(aead.NonceSize()).Generate()
  96. } else {
  97. bs, err = c.NonceGenerator.Generate()
  98. }
  99. if err != nil {
  100. return nil, nil, nil, errors.Wrap(err, "failed to generate nonce")
  101. }
  102. iv = bs.Bytes()
  103. combined := aead.Seal(nil, iv, plaintext, aad)
  104. tagoffset := len(combined) - c.TagSize()
  105. if tagoffset < 0 {
  106. panic(fmt.Sprintf("tag offset is less than 0 (combined len = %d, tagsize = %d)", len(combined), c.TagSize()))
  107. }
  108. tag = combined[tagoffset:]
  109. ciphertext = make([]byte, tagoffset)
  110. copy(ciphertext, combined[:tagoffset])
  111. return
  112. }
  113. func (c AesContentCipher) Decrypt(cek, iv, ciphertxt, tag, aad []byte) (plaintext []byte, err error) {
  114. aead, err := c.fetch.Fetch(cek)
  115. if err != nil {
  116. return nil, errors.Wrap(err, "failed to fetch AEAD data")
  117. }
  118. // Open may panic (argh!), so protect ourselves from that
  119. defer func() {
  120. if e := recover(); e != nil {
  121. switch e := e.(type) {
  122. case error:
  123. err = e
  124. default:
  125. err = errors.Errorf("%s", e)
  126. }
  127. err = errors.Wrap(err, "failed to decrypt")
  128. return
  129. }
  130. }()
  131. combined := make([]byte, len(ciphertxt)+len(tag))
  132. copy(combined, ciphertxt)
  133. copy(combined[len(ciphertxt):], tag)
  134. buf, aeaderr := aead.Open(nil, iv, combined, aad)
  135. if aeaderr != nil {
  136. err = errors.Wrap(aeaderr, `aead.Open failed`)
  137. return
  138. }
  139. plaintext = buf
  140. return
  141. }