decrypt.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. package jwe
  2. import (
  3. "crypto/aes"
  4. cryptocipher "crypto/cipher"
  5. "crypto/ecdsa"
  6. "crypto/rsa"
  7. "crypto/sha256"
  8. "crypto/sha512"
  9. "hash"
  10. "golang.org/x/crypto/pbkdf2"
  11. "github.com/lestrrat-go/jwx/internal/keyconv"
  12. "github.com/lestrrat-go/jwx/jwa"
  13. "github.com/lestrrat-go/jwx/jwe/internal/cipher"
  14. "github.com/lestrrat-go/jwx/jwe/internal/content_crypt"
  15. "github.com/lestrrat-go/jwx/jwe/internal/keyenc"
  16. "github.com/lestrrat-go/jwx/x25519"
  17. "github.com/pkg/errors"
  18. )
  19. // Decrypter is responsible for taking various components to decrypt a message.
  20. // its operation is not concurrency safe. You must provide locking yourself
  21. //nolint:govet
  22. type Decrypter struct {
  23. aad []byte
  24. apu []byte
  25. apv []byte
  26. computedAad []byte
  27. iv []byte
  28. keyiv []byte
  29. keysalt []byte
  30. keytag []byte
  31. tag []byte
  32. privkey interface{}
  33. pubkey interface{}
  34. ctalg jwa.ContentEncryptionAlgorithm
  35. keyalg jwa.KeyEncryptionAlgorithm
  36. cipher content_crypt.Cipher
  37. keycount int
  38. }
  39. // NewDecrypter Creates a new Decrypter instance. You must supply the
  40. // rest of parameters via their respective setter methods before
  41. // calling Decrypt().
  42. //
  43. // privkey must be a private key in its "raw" format (i.e. something like
  44. // *rsa.PrivateKey, instead of jwk.Key)
  45. //
  46. // You should consider this object immutable once you assign values to it.
  47. func NewDecrypter(keyalg jwa.KeyEncryptionAlgorithm, ctalg jwa.ContentEncryptionAlgorithm, privkey interface{}) *Decrypter {
  48. return &Decrypter{
  49. ctalg: ctalg,
  50. keyalg: keyalg,
  51. privkey: privkey,
  52. }
  53. }
  54. func (d *Decrypter) AgreementPartyUInfo(apu []byte) *Decrypter {
  55. d.apu = apu
  56. return d
  57. }
  58. func (d *Decrypter) AgreementPartyVInfo(apv []byte) *Decrypter {
  59. d.apv = apv
  60. return d
  61. }
  62. func (d *Decrypter) AuthenticatedData(aad []byte) *Decrypter {
  63. d.aad = aad
  64. return d
  65. }
  66. func (d *Decrypter) ComputedAuthenticatedData(aad []byte) *Decrypter {
  67. d.computedAad = aad
  68. return d
  69. }
  70. func (d *Decrypter) ContentEncryptionAlgorithm(ctalg jwa.ContentEncryptionAlgorithm) *Decrypter {
  71. d.ctalg = ctalg
  72. return d
  73. }
  74. func (d *Decrypter) InitializationVector(iv []byte) *Decrypter {
  75. d.iv = iv
  76. return d
  77. }
  78. func (d *Decrypter) KeyCount(keycount int) *Decrypter {
  79. d.keycount = keycount
  80. return d
  81. }
  82. func (d *Decrypter) KeyInitializationVector(keyiv []byte) *Decrypter {
  83. d.keyiv = keyiv
  84. return d
  85. }
  86. func (d *Decrypter) KeySalt(keysalt []byte) *Decrypter {
  87. d.keysalt = keysalt
  88. return d
  89. }
  90. func (d *Decrypter) KeyTag(keytag []byte) *Decrypter {
  91. d.keytag = keytag
  92. return d
  93. }
  94. // PublicKey sets the public key to be used in decoding EC based encryptions.
  95. // The key must be in its "raw" format (i.e. *ecdsa.PublicKey, instead of jwk.Key)
  96. func (d *Decrypter) PublicKey(pubkey interface{}) *Decrypter {
  97. d.pubkey = pubkey
  98. return d
  99. }
  100. func (d *Decrypter) Tag(tag []byte) *Decrypter {
  101. d.tag = tag
  102. return d
  103. }
  104. func (d *Decrypter) ContentCipher() (content_crypt.Cipher, error) {
  105. if d.cipher == nil {
  106. switch d.ctalg {
  107. case jwa.A128GCM, jwa.A192GCM, jwa.A256GCM, jwa.A128CBC_HS256, jwa.A192CBC_HS384, jwa.A256CBC_HS512:
  108. cipher, err := cipher.NewAES(d.ctalg)
  109. if err != nil {
  110. return nil, errors.Wrapf(err, `failed to build content cipher for %s`, d.ctalg)
  111. }
  112. d.cipher = cipher
  113. default:
  114. return nil, errors.Errorf(`invalid content cipher algorithm (%s)`, d.ctalg)
  115. }
  116. }
  117. return d.cipher, nil
  118. }
  119. func (d *Decrypter) Decrypt(recipientKey, ciphertext []byte) (plaintext []byte, err error) {
  120. cek, keyerr := d.DecryptKey(recipientKey)
  121. if keyerr != nil {
  122. err = errors.Wrap(keyerr, `failed to decrypt key`)
  123. return
  124. }
  125. cipher, ciphererr := d.ContentCipher()
  126. if ciphererr != nil {
  127. err = errors.Wrap(ciphererr, `failed to fetch content crypt cipher`)
  128. return
  129. }
  130. computedAad := d.computedAad
  131. if d.aad != nil {
  132. computedAad = append(append(computedAad, '.'), d.aad...)
  133. }
  134. plaintext, err = cipher.Decrypt(cek, d.iv, ciphertext, d.tag, computedAad)
  135. if err != nil {
  136. err = errors.Wrap(err, `failed to decrypt payload`)
  137. return
  138. }
  139. return plaintext, nil
  140. }
  141. func (d *Decrypter) decryptSymmetricKey(recipientKey, cek []byte) ([]byte, error) {
  142. switch d.keyalg {
  143. case jwa.DIRECT:
  144. return cek, nil
  145. case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
  146. var hashFunc func() hash.Hash
  147. var keylen int
  148. switch d.keyalg {
  149. case jwa.PBES2_HS256_A128KW:
  150. hashFunc = sha256.New
  151. keylen = 16
  152. case jwa.PBES2_HS384_A192KW:
  153. hashFunc = sha512.New384
  154. keylen = 24
  155. case jwa.PBES2_HS512_A256KW:
  156. hashFunc = sha512.New
  157. keylen = 32
  158. }
  159. salt := []byte(d.keyalg)
  160. salt = append(salt, byte(0))
  161. salt = append(salt, d.keysalt...)
  162. cek = pbkdf2.Key(cek, salt, d.keycount, keylen, hashFunc)
  163. fallthrough
  164. case jwa.A128KW, jwa.A192KW, jwa.A256KW:
  165. block, err := aes.NewCipher(cek)
  166. if err != nil {
  167. return nil, errors.Wrap(err, `failed to create new AES cipher`)
  168. }
  169. jek, err := keyenc.Unwrap(block, recipientKey)
  170. if err != nil {
  171. return nil, errors.Wrap(err, `failed to unwrap key`)
  172. }
  173. return jek, nil
  174. case jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW:
  175. if len(d.keyiv) != 12 {
  176. return nil, errors.Errorf("GCM requires 96-bit iv, got %d", len(d.keyiv)*8)
  177. }
  178. if len(d.keytag) != 16 {
  179. return nil, errors.Errorf("GCM requires 128-bit tag, got %d", len(d.keytag)*8)
  180. }
  181. block, err := aes.NewCipher(cek)
  182. if err != nil {
  183. return nil, errors.Wrap(err, `failed to create new AES cipher`)
  184. }
  185. aesgcm, err := cryptocipher.NewGCM(block)
  186. if err != nil {
  187. return nil, errors.Wrap(err, `failed to create new GCM wrap`)
  188. }
  189. ciphertext := recipientKey[:]
  190. ciphertext = append(ciphertext, d.keytag...)
  191. jek, err := aesgcm.Open(nil, d.keyiv, ciphertext, nil)
  192. if err != nil {
  193. return nil, errors.Wrap(err, `failed to decode key`)
  194. }
  195. return jek, nil
  196. default:
  197. return nil, errors.Errorf("decrypt key: unsupported algorithm %s", d.keyalg)
  198. }
  199. }
  200. func (d *Decrypter) DecryptKey(recipientKey []byte) (cek []byte, err error) {
  201. if d.keyalg.IsSymmetric() {
  202. var ok bool
  203. cek, ok = d.privkey.([]byte)
  204. if !ok {
  205. return nil, errors.Errorf("decrypt key: []byte is required as the key to build %s key decrypter (got %T)", d.keyalg, d.privkey)
  206. }
  207. return d.decryptSymmetricKey(recipientKey, cek)
  208. }
  209. k, err := d.BuildKeyDecrypter()
  210. if err != nil {
  211. return nil, errors.Wrap(err, `failed to build key decrypter`)
  212. }
  213. cek, err = k.Decrypt(recipientKey)
  214. if err != nil {
  215. return nil, errors.Wrap(err, `failed to decrypt key`)
  216. }
  217. return cek, nil
  218. }
  219. func (d *Decrypter) BuildKeyDecrypter() (keyenc.Decrypter, error) {
  220. cipher, err := d.ContentCipher()
  221. if err != nil {
  222. return nil, errors.Wrap(err, `failed to fetch content crypt cipher`)
  223. }
  224. switch alg := d.keyalg; alg {
  225. case jwa.RSA1_5:
  226. var privkey rsa.PrivateKey
  227. if err := keyconv.RSAPrivateKey(&privkey, d.privkey); err != nil {
  228. return nil, errors.Wrapf(err, "*rsa.PrivateKey is required as the key to build %s key decrypter", alg)
  229. }
  230. return keyenc.NewRSAPKCS15Decrypt(alg, &privkey, cipher.KeySize()/2), nil
  231. case jwa.RSA_OAEP, jwa.RSA_OAEP_256:
  232. var privkey rsa.PrivateKey
  233. if err := keyconv.RSAPrivateKey(&privkey, d.privkey); err != nil {
  234. return nil, errors.Wrapf(err, "*rsa.PrivateKey is required as the key to build %s key decrypter", alg)
  235. }
  236. return keyenc.NewRSAOAEPDecrypt(alg, &privkey)
  237. case jwa.A128KW, jwa.A192KW, jwa.A256KW:
  238. sharedkey, ok := d.privkey.([]byte)
  239. if !ok {
  240. return nil, errors.Errorf("[]byte is required as the key to build %s key decrypter", alg)
  241. }
  242. return keyenc.NewAES(alg, sharedkey)
  243. case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
  244. switch d.pubkey.(type) {
  245. case x25519.PublicKey:
  246. return keyenc.NewECDHESDecrypt(alg, d.ctalg, d.pubkey, d.apu, d.apv, d.privkey), nil
  247. default:
  248. var pubkey ecdsa.PublicKey
  249. if err := keyconv.ECDSAPublicKey(&pubkey, d.pubkey); err != nil {
  250. return nil, errors.Wrapf(err, "*ecdsa.PublicKey is required as the key to build %s key decrypter", alg)
  251. }
  252. var privkey ecdsa.PrivateKey
  253. if err := keyconv.ECDSAPrivateKey(&privkey, d.privkey); err != nil {
  254. return nil, errors.Wrapf(err, "*ecdsa.PrivateKey is required as the key to build %s key decrypter", alg)
  255. }
  256. return keyenc.NewECDHESDecrypt(alg, d.ctalg, &pubkey, d.apu, d.apv, &privkey), nil
  257. }
  258. default:
  259. return nil, errors.Errorf(`unsupported algorithm for key decryption (%s)`, alg)
  260. }
  261. }