rsa.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. package jws
  2. import (
  3. "crypto"
  4. "crypto/rand"
  5. "crypto/rsa"
  6. "github.com/lestrrat-go/jwx/internal/keyconv"
  7. "github.com/lestrrat-go/jwx/jwa"
  8. "github.com/pkg/errors"
  9. )
  10. var rsaSigners map[jwa.SignatureAlgorithm]*rsaSigner
  11. var rsaVerifiers map[jwa.SignatureAlgorithm]*rsaVerifier
  12. func init() {
  13. algs := map[jwa.SignatureAlgorithm]struct {
  14. Hash crypto.Hash
  15. PSS bool
  16. }{
  17. jwa.RS256: {
  18. Hash: crypto.SHA256,
  19. },
  20. jwa.RS384: {
  21. Hash: crypto.SHA384,
  22. },
  23. jwa.RS512: {
  24. Hash: crypto.SHA512,
  25. },
  26. jwa.PS256: {
  27. Hash: crypto.SHA256,
  28. PSS: true,
  29. },
  30. jwa.PS384: {
  31. Hash: crypto.SHA384,
  32. PSS: true,
  33. },
  34. jwa.PS512: {
  35. Hash: crypto.SHA512,
  36. PSS: true,
  37. },
  38. }
  39. rsaSigners = make(map[jwa.SignatureAlgorithm]*rsaSigner)
  40. rsaVerifiers = make(map[jwa.SignatureAlgorithm]*rsaVerifier)
  41. for alg, item := range algs {
  42. rsaSigners[alg] = &rsaSigner{
  43. alg: alg,
  44. hash: item.Hash,
  45. pss: item.PSS,
  46. }
  47. rsaVerifiers[alg] = &rsaVerifier{
  48. alg: alg,
  49. hash: item.Hash,
  50. pss: item.PSS,
  51. }
  52. }
  53. }
  54. type rsaSigner struct {
  55. alg jwa.SignatureAlgorithm
  56. hash crypto.Hash
  57. pss bool
  58. }
  59. func newRSASigner(alg jwa.SignatureAlgorithm) Signer {
  60. return rsaSigners[alg]
  61. }
  62. func (rs *rsaSigner) Algorithm() jwa.SignatureAlgorithm {
  63. return rs.alg
  64. }
  65. func (rs *rsaSigner) Sign(payload []byte, key interface{}) ([]byte, error) {
  66. if key == nil {
  67. return nil, errors.New(`missing private key while signing payload`)
  68. }
  69. signer, ok := key.(crypto.Signer)
  70. if !ok {
  71. var privkey rsa.PrivateKey
  72. if err := keyconv.RSAPrivateKey(&privkey, key); err != nil {
  73. return nil, errors.Wrapf(err, `failed to retrieve rsa.PrivateKey out of %T`, key)
  74. }
  75. signer = &privkey
  76. }
  77. h := rs.hash.New()
  78. if _, err := h.Write(payload); err != nil {
  79. return nil, errors.Wrap(err, "failed to write payload to hash")
  80. }
  81. if rs.pss {
  82. return signer.Sign(rand.Reader, h.Sum(nil), &rsa.PSSOptions{
  83. Hash: rs.hash,
  84. SaltLength: rsa.PSSSaltLengthEqualsHash,
  85. })
  86. }
  87. return signer.Sign(rand.Reader, h.Sum(nil), rs.hash)
  88. }
  89. type rsaVerifier struct {
  90. alg jwa.SignatureAlgorithm
  91. hash crypto.Hash
  92. pss bool
  93. }
  94. func newRSAVerifier(alg jwa.SignatureAlgorithm) Verifier {
  95. return rsaVerifiers[alg]
  96. }
  97. func (rv *rsaVerifier) Verify(payload, signature []byte, key interface{}) error {
  98. if key == nil {
  99. return errors.New(`missing public key while verifying payload`)
  100. }
  101. var pubkey rsa.PublicKey
  102. if cs, ok := key.(crypto.Signer); ok {
  103. cpub := cs.Public()
  104. switch cpub := cpub.(type) {
  105. case rsa.PublicKey:
  106. pubkey = cpub
  107. case *rsa.PublicKey:
  108. pubkey = *cpub
  109. default:
  110. return errors.Errorf(`failed to retrieve rsa.PublicKey out of crypto.Signer %T`, key)
  111. }
  112. } else {
  113. if err := keyconv.RSAPublicKey(&pubkey, key); err != nil {
  114. return errors.Wrapf(err, `failed to retrieve rsa.PublicKey out of %T`, key)
  115. }
  116. }
  117. h := rv.hash.New()
  118. if _, err := h.Write(payload); err != nil {
  119. return errors.Wrap(err, "failed to write payload to hash")
  120. }
  121. if rv.pss {
  122. return rsa.VerifyPSS(&pubkey, rv.hash, h.Sum(nil), signature, nil)
  123. }
  124. return rsa.VerifyPKCS1v15(&pubkey, rv.hash, h.Sum(nil), signature)
  125. }