rsa.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. package jwk
  2. import (
  3. "crypto"
  4. "crypto/rsa"
  5. "encoding/binary"
  6. "fmt"
  7. "math/big"
  8. "github.com/lestrrat-go/blackmagic"
  9. "github.com/lestrrat-go/jwx/internal/base64"
  10. "github.com/lestrrat-go/jwx/internal/pool"
  11. "github.com/pkg/errors"
  12. )
  13. func (k *rsaPrivateKey) FromRaw(rawKey *rsa.PrivateKey) error {
  14. k.mu.Lock()
  15. defer k.mu.Unlock()
  16. d, err := bigIntToBytes(rawKey.D)
  17. if err != nil {
  18. return errors.Wrap(err, `invalid rsa.PrivateKey`)
  19. }
  20. k.d = d
  21. l := len(rawKey.Primes)
  22. if l < 0 /* I know, I'm being paranoid */ || l > 2 {
  23. return fmt.Errorf(`invalid number of primes in rsa.PrivateKey: need 0 to 2, but got %d`, len(rawKey.Primes))
  24. }
  25. if l > 0 {
  26. p, err := bigIntToBytes(rawKey.Primes[0])
  27. if err != nil {
  28. return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err)
  29. }
  30. k.p = p
  31. }
  32. if l > 1 {
  33. q, err := bigIntToBytes(rawKey.Primes[1])
  34. if err != nil {
  35. return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err)
  36. }
  37. k.q = q
  38. }
  39. // dp, dq, qi are optional values
  40. if v, err := bigIntToBytes(rawKey.Precomputed.Dp); err == nil {
  41. k.dp = v
  42. }
  43. if v, err := bigIntToBytes(rawKey.Precomputed.Dq); err == nil {
  44. k.dq = v
  45. }
  46. if v, err := bigIntToBytes(rawKey.Precomputed.Qinv); err == nil {
  47. k.qi = v
  48. }
  49. // public key part
  50. n, e, err := rsaPublicKeyByteValuesFromRaw(&rawKey.PublicKey)
  51. if err != nil {
  52. return errors.Wrap(err, `invalid rsa.PrivateKey`)
  53. }
  54. k.n = n
  55. k.e = e
  56. return nil
  57. }
  58. func rsaPublicKeyByteValuesFromRaw(rawKey *rsa.PublicKey) ([]byte, []byte, error) {
  59. n, err := bigIntToBytes(rawKey.N)
  60. if err != nil {
  61. return nil, nil, errors.Wrap(err, `invalid rsa.PublicKey`)
  62. }
  63. data := make([]byte, 8)
  64. binary.BigEndian.PutUint64(data, uint64(rawKey.E))
  65. i := 0
  66. for ; i < len(data); i++ {
  67. if data[i] != 0x0 {
  68. break
  69. }
  70. }
  71. return n, data[i:], nil
  72. }
  73. func (k *rsaPublicKey) FromRaw(rawKey *rsa.PublicKey) error {
  74. k.mu.Lock()
  75. defer k.mu.Unlock()
  76. n, e, err := rsaPublicKeyByteValuesFromRaw(rawKey)
  77. if err != nil {
  78. return errors.Wrap(err, `invalid rsa.PrivateKey`)
  79. }
  80. k.n = n
  81. k.e = e
  82. return nil
  83. }
  84. func (k *rsaPrivateKey) Raw(v interface{}) error {
  85. k.mu.RLock()
  86. defer k.mu.RUnlock()
  87. var d, q, p big.Int // note: do not use from sync.Pool
  88. d.SetBytes(k.d)
  89. q.SetBytes(k.q)
  90. p.SetBytes(k.p)
  91. // optional fields
  92. var dp, dq, qi *big.Int
  93. if len(k.dp) > 0 {
  94. dp = &big.Int{} // note: do not use from sync.Pool
  95. dp.SetBytes(k.dp)
  96. }
  97. if len(k.dq) > 0 {
  98. dq = &big.Int{} // note: do not use from sync.Pool
  99. dq.SetBytes(k.dq)
  100. }
  101. if len(k.qi) > 0 {
  102. qi = &big.Int{} // note: do not use from sync.Pool
  103. qi.SetBytes(k.qi)
  104. }
  105. var key rsa.PrivateKey
  106. pubk := newRSAPublicKey()
  107. pubk.n = k.n
  108. pubk.e = k.e
  109. if err := pubk.Raw(&key.PublicKey); err != nil {
  110. return errors.Wrap(err, `failed to materialize RSA public key`)
  111. }
  112. key.D = &d
  113. key.Primes = []*big.Int{&p, &q}
  114. if dp != nil {
  115. key.Precomputed.Dp = dp
  116. }
  117. if dq != nil {
  118. key.Precomputed.Dq = dq
  119. }
  120. if qi != nil {
  121. key.Precomputed.Qinv = qi
  122. }
  123. key.Precomputed.CRTValues = []rsa.CRTValue{}
  124. return blackmagic.AssignIfCompatible(v, &key)
  125. }
  126. // Raw takes the values stored in the Key object, and creates the
  127. // corresponding *rsa.PublicKey object.
  128. func (k *rsaPublicKey) Raw(v interface{}) error {
  129. k.mu.RLock()
  130. defer k.mu.RUnlock()
  131. var key rsa.PublicKey
  132. n := pool.GetBigInt()
  133. e := pool.GetBigInt()
  134. defer pool.ReleaseBigInt(e)
  135. n.SetBytes(k.n)
  136. e.SetBytes(k.e)
  137. key.N = n
  138. key.E = int(e.Int64())
  139. return blackmagic.AssignIfCompatible(v, &key)
  140. }
  141. func makeRSAPublicKey(v interface {
  142. makePairs() []*HeaderPair
  143. }) (Key, error) {
  144. newKey := NewRSAPublicKey()
  145. // Iterate and copy everything except for the bits that should not be in the public key
  146. for _, pair := range v.makePairs() {
  147. switch pair.Key {
  148. case RSADKey, RSADPKey, RSADQKey, RSAPKey, RSAQKey, RSAQIKey:
  149. continue
  150. default:
  151. //nolint:forcetypeassert
  152. key := pair.Key.(string)
  153. if err := newKey.Set(key, pair.Value); err != nil {
  154. return nil, errors.Wrapf(err, `failed to set field %q`, key)
  155. }
  156. }
  157. }
  158. return newKey, nil
  159. }
  160. func (k *rsaPrivateKey) PublicKey() (Key, error) {
  161. return makeRSAPublicKey(k)
  162. }
  163. func (k *rsaPublicKey) PublicKey() (Key, error) {
  164. return makeRSAPublicKey(k)
  165. }
  166. // Thumbprint returns the JWK thumbprint using the indicated
  167. // hashing algorithm, according to RFC 7638
  168. func (k rsaPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
  169. k.mu.RLock()
  170. defer k.mu.RUnlock()
  171. var key rsa.PrivateKey
  172. if err := k.Raw(&key); err != nil {
  173. return nil, errors.Wrap(err, `failed to materialize RSA private key`)
  174. }
  175. return rsaThumbprint(hash, &key.PublicKey)
  176. }
  177. func (k rsaPublicKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
  178. k.mu.RLock()
  179. defer k.mu.RUnlock()
  180. var key rsa.PublicKey
  181. if err := k.Raw(&key); err != nil {
  182. return nil, errors.Wrap(err, `failed to materialize RSA public key`)
  183. }
  184. return rsaThumbprint(hash, &key)
  185. }
  186. func rsaThumbprint(hash crypto.Hash, key *rsa.PublicKey) ([]byte, error) {
  187. buf := pool.GetBytesBuffer()
  188. defer pool.ReleaseBytesBuffer(buf)
  189. buf.WriteString(`{"e":"`)
  190. buf.WriteString(base64.EncodeUint64ToString(uint64(key.E)))
  191. buf.WriteString(`","kty":"RSA","n":"`)
  192. buf.WriteString(base64.EncodeToString(key.N.Bytes()))
  193. buf.WriteString(`"}`)
  194. h := hash.New()
  195. if _, err := buf.WriteTo(h); err != nil {
  196. return nil, errors.Wrap(err, "failed to write rsaThumbprint")
  197. }
  198. return h.Sum(nil), nil
  199. }