ecdsa.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. package jws
  2. import (
  3. "crypto"
  4. "crypto/ecdsa"
  5. "crypto/rand"
  6. "encoding/asn1"
  7. "fmt"
  8. "math/big"
  9. "github.com/lestrrat-go/jwx/internal/keyconv"
  10. "github.com/lestrrat-go/jwx/internal/pool"
  11. "github.com/lestrrat-go/jwx/jwa"
  12. "github.com/pkg/errors"
  13. )
  14. var ecdsaSigners map[jwa.SignatureAlgorithm]*ecdsaSigner
  15. var ecdsaVerifiers map[jwa.SignatureAlgorithm]*ecdsaVerifier
  16. func init() {
  17. algs := map[jwa.SignatureAlgorithm]crypto.Hash{
  18. jwa.ES256: crypto.SHA256,
  19. jwa.ES384: crypto.SHA384,
  20. jwa.ES512: crypto.SHA512,
  21. jwa.ES256K: crypto.SHA256,
  22. }
  23. ecdsaSigners = make(map[jwa.SignatureAlgorithm]*ecdsaSigner)
  24. ecdsaVerifiers = make(map[jwa.SignatureAlgorithm]*ecdsaVerifier)
  25. for alg, hash := range algs {
  26. ecdsaSigners[alg] = &ecdsaSigner{
  27. alg: alg,
  28. hash: hash,
  29. }
  30. ecdsaVerifiers[alg] = &ecdsaVerifier{
  31. alg: alg,
  32. hash: hash,
  33. }
  34. }
  35. }
  36. func newECDSASigner(alg jwa.SignatureAlgorithm) Signer {
  37. return ecdsaSigners[alg]
  38. }
  39. // ecdsaSigners are immutable.
  40. type ecdsaSigner struct {
  41. alg jwa.SignatureAlgorithm
  42. hash crypto.Hash
  43. }
  44. func (es ecdsaSigner) Algorithm() jwa.SignatureAlgorithm {
  45. return es.alg
  46. }
  47. func (es *ecdsaSigner) Sign(payload []byte, key interface{}) ([]byte, error) {
  48. if key == nil {
  49. return nil, errors.New(`missing private key while signing payload`)
  50. }
  51. h := es.hash.New()
  52. if _, err := h.Write(payload); err != nil {
  53. return nil, errors.Wrap(err, "failed to write payload using ecdsa")
  54. }
  55. signer, ok := key.(crypto.Signer)
  56. if ok {
  57. switch key.(type) {
  58. case ecdsa.PrivateKey, *ecdsa.PrivateKey:
  59. // if it's a ecdsa.PrivateKey, it's more efficient to
  60. // go through the non-crypto.Signer route. Set ok to false
  61. ok = false
  62. }
  63. }
  64. var r, s *big.Int
  65. var curveBits int
  66. if ok {
  67. signed, err := signer.Sign(rand.Reader, h.Sum(nil), es.hash)
  68. if err != nil {
  69. return nil, err
  70. }
  71. var p struct {
  72. R *big.Int
  73. S *big.Int
  74. }
  75. if _, err := asn1.Unmarshal(signed, &p); err != nil {
  76. return nil, errors.Wrap(err, `failed to unmarshal ASN1 encoded signature`)
  77. }
  78. // Okay, this is silly, but hear me out. When we use the
  79. // crypto.Signer interface, the PrivateKey is hidden.
  80. // But we need some information about the key (it's bit size).
  81. //
  82. // So while silly, we're going to have to make another call
  83. // here and fetch the Public key.
  84. // This probably means that this should be cached some where.
  85. cpub := signer.Public()
  86. pubkey, ok := cpub.(*ecdsa.PublicKey)
  87. if !ok {
  88. return nil, fmt.Errorf(`expected *ecdsa.PublicKey, got %T`, pubkey)
  89. }
  90. curveBits = pubkey.Curve.Params().BitSize
  91. r = p.R
  92. s = p.S
  93. } else {
  94. var privkey ecdsa.PrivateKey
  95. if err := keyconv.ECDSAPrivateKey(&privkey, key); err != nil {
  96. return nil, errors.Wrapf(err, `failed to retrieve ecdsa.PrivateKey out of %T`, key)
  97. }
  98. curveBits = privkey.Curve.Params().BitSize
  99. rtmp, stmp, err := ecdsa.Sign(rand.Reader, &privkey, h.Sum(nil))
  100. if err != nil {
  101. return nil, errors.Wrap(err, "failed to sign payload using ecdsa")
  102. }
  103. r = rtmp
  104. s = stmp
  105. }
  106. keyBytes := curveBits / 8
  107. // Curve bits do not need to be a multiple of 8.
  108. if curveBits%8 > 0 {
  109. keyBytes++
  110. }
  111. rBytes := r.Bytes()
  112. rBytesPadded := make([]byte, keyBytes)
  113. copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
  114. sBytes := s.Bytes()
  115. sBytesPadded := make([]byte, keyBytes)
  116. copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
  117. out := append(rBytesPadded, sBytesPadded...)
  118. return out, nil
  119. }
  120. // ecdsaVerifiers are immutable.
  121. type ecdsaVerifier struct {
  122. alg jwa.SignatureAlgorithm
  123. hash crypto.Hash
  124. }
  125. func newECDSAVerifier(alg jwa.SignatureAlgorithm) Verifier {
  126. return ecdsaVerifiers[alg]
  127. }
  128. func (v ecdsaVerifier) Algorithm() jwa.SignatureAlgorithm {
  129. return v.alg
  130. }
  131. func (v *ecdsaVerifier) Verify(payload []byte, signature []byte, key interface{}) error {
  132. if key == nil {
  133. return errors.New(`missing public key while verifying payload`)
  134. }
  135. var pubkey ecdsa.PublicKey
  136. if cs, ok := key.(crypto.Signer); ok {
  137. cpub := cs.Public()
  138. switch cpub := cpub.(type) {
  139. case ecdsa.PublicKey:
  140. pubkey = cpub
  141. case *ecdsa.PublicKey:
  142. pubkey = *cpub
  143. default:
  144. return errors.Errorf(`failed to retrieve ecdsa.PublicKey out of crypto.Signer %T`, key)
  145. }
  146. } else {
  147. if err := keyconv.ECDSAPublicKey(&pubkey, key); err != nil {
  148. return errors.Wrapf(err, `failed to retrieve ecdsa.PublicKey out of %T`, key)
  149. }
  150. }
  151. r := pool.GetBigInt()
  152. s := pool.GetBigInt()
  153. defer pool.ReleaseBigInt(r)
  154. defer pool.ReleaseBigInt(s)
  155. n := len(signature) / 2
  156. r.SetBytes(signature[:n])
  157. s.SetBytes(signature[n:])
  158. h := v.hash.New()
  159. if _, err := h.Write(payload); err != nil {
  160. return errors.Wrap(err, "failed to write payload using ecdsa")
  161. }
  162. if !ecdsa.Verify(&pubkey, h.Sum(nil), r, s) {
  163. return errors.New(`failed to verify signature using ecdsa`)
  164. }
  165. return nil
  166. }