rsa.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. package jwk
  2. import (
  3. "bytes"
  4. "crypto"
  5. "crypto/rsa"
  6. "encoding/json"
  7. "math/big"
  8. "github.com/lestrrat/go-jwx/internal/base64"
  9. "github.com/lestrrat/go-jwx/jwa"
  10. pdebug "github.com/lestrrat/go-pdebug"
  11. "github.com/pkg/errors"
  12. )
  13. func newRSAPublicKey(key *rsa.PublicKey) (*RSAPublicKey, error) {
  14. if key == nil {
  15. return nil, errors.New(`non-nil rsa.PublicKey required`)
  16. }
  17. var hdr StandardHeaders
  18. hdr.Set(KeyTypeKey, jwa.RSA)
  19. return &RSAPublicKey{
  20. headers: &hdr,
  21. key: key,
  22. }, nil
  23. }
  24. func newRSAPrivateKey(key *rsa.PrivateKey) (*RSAPrivateKey, error) {
  25. if key == nil {
  26. return nil, errors.New(`non-nil rsa.PrivateKey required`)
  27. }
  28. if len(key.Primes) < 2 {
  29. return nil, errors.New("two primes required for RSA private key")
  30. }
  31. var hdr StandardHeaders
  32. hdr.Set(KeyTypeKey, jwa.RSA)
  33. return &RSAPrivateKey{
  34. headers: &hdr,
  35. key: key,
  36. }, nil
  37. }
  38. func (k RSAPrivateKey) PublicKey() (*RSAPublicKey, error) {
  39. return newRSAPublicKey(&k.key.PublicKey)
  40. }
  41. func (k *RSAPublicKey) Materialize() (interface{}, error) {
  42. if k.key == nil {
  43. return nil, errors.New(`key has no rsa.PublicKey associated with it`)
  44. }
  45. return k.key, nil
  46. }
  47. func (k *RSAPrivateKey) Materialize() (interface{}, error) {
  48. if k.key == nil {
  49. return nil, errors.New(`key has no rsa.PrivateKey associated with it`)
  50. }
  51. return k.key, nil
  52. }
  53. func (k RSAPublicKey) MarshalJSON() (buf []byte, err error) {
  54. if pdebug.Enabled {
  55. g := pdebug.Marker("jwk.RSAPublicKey.MarshalJSON").BindError(&err)
  56. defer g.End()
  57. }
  58. m := map[string]interface{}{}
  59. if err := k.PopulateMap(m); err != nil {
  60. return nil, errors.Wrap(err, `failed to populate pulibc key values`)
  61. }
  62. return json.Marshal(m)
  63. }
  64. func (k RSAPublicKey) PopulateMap(m map[string]interface{}) (err error) {
  65. if pdebug.Enabled {
  66. g := pdebug.Marker("jwk.RSAPublicKey.PopulateJSON").BindError(&err)
  67. defer g.End()
  68. }
  69. if err := k.headers.PopulateMap(m); err != nil {
  70. return errors.Wrap(err, `failed to populate header values`)
  71. }
  72. m[`n`] = base64.EncodeToString(k.key.N.Bytes())
  73. m[`e`] = base64.EncodeUint64ToString(uint64(k.key.E))
  74. return nil
  75. }
  76. func (k *RSAPublicKey) UnmarshalJSON(data []byte) (err error) {
  77. if pdebug.Enabled {
  78. g := pdebug.Marker("jwk.RSAPublicKey.UnmarshalJSON").BindError(&err)
  79. defer g.End()
  80. }
  81. m := map[string]interface{}{}
  82. if err := json.Unmarshal(data, &m); err != nil {
  83. return errors.Wrap(err, `failed to unmarshal public key`)
  84. }
  85. if err := k.ExtractMap(m); err != nil {
  86. return errors.Wrap(err, `failed to extract data from map`)
  87. }
  88. return nil
  89. }
  90. func (k *RSAPublicKey) ExtractMap(m map[string]interface{}) (err error) {
  91. if pdebug.Enabled {
  92. g := pdebug.Marker("jwk.RSAPublicKey.ExtractMap").BindError(&err)
  93. defer g.End()
  94. }
  95. const (
  96. eKey = `e`
  97. nKey = `n`
  98. )
  99. nbuf, err := getRequiredKey(m, nKey)
  100. if err != nil {
  101. return errors.Wrapf(err, `failed to get required key %s`, nKey)
  102. }
  103. delete(m, nKey)
  104. ebuf, err := getRequiredKey(m, eKey)
  105. if err != nil {
  106. return errors.Wrapf(err, `failed to get required key %s`, eKey)
  107. }
  108. delete(m, eKey)
  109. var n, e big.Int
  110. n.SetBytes(nbuf)
  111. e.SetBytes(ebuf)
  112. var hdrs StandardHeaders
  113. if err := hdrs.ExtractMap(m); err != nil {
  114. return errors.Wrap(err, `failed to extract header values`)
  115. }
  116. *k = RSAPublicKey{
  117. headers: &hdrs,
  118. key: &rsa.PublicKey{E: int(e.Int64()), N: &n},
  119. }
  120. return nil
  121. }
  122. func (k RSAPrivateKey) MarshalJSON() (buf []byte, err error) {
  123. if pdebug.Enabled {
  124. g := pdebug.Marker("jwk.RSAPrivateKey.MarshalJSON").BindError(&err)
  125. defer g.End()
  126. }
  127. m := make(map[string]interface{})
  128. if err := k.PopulateMap(m); err != nil {
  129. return nil, errors.Wrap(err, `failed to populate private key values`)
  130. }
  131. return json.Marshal(m)
  132. }
  133. func (k RSAPrivateKey) PopulateMap(m map[string]interface{}) (err error) {
  134. if pdebug.Enabled {
  135. g := pdebug.Marker("jwk.RSAPrivateKey.PopulateMap").BindError(&err)
  136. defer g.End()
  137. }
  138. const (
  139. dKey = `d`
  140. pKey = `p`
  141. qKey = `q`
  142. dpKey = `dp`
  143. dqKey = `dq`
  144. qiKey = `qi`
  145. )
  146. if err := k.headers.PopulateMap(m); err != nil {
  147. return errors.Wrap(err, `failed to populate header values`)
  148. }
  149. pubkey, _ := newRSAPublicKey(&k.key.PublicKey)
  150. if err := pubkey.PopulateMap(m); err != nil {
  151. return errors.Wrap(err, `failed to populate public key values`)
  152. }
  153. if err := k.headers.PopulateMap(m); err != nil {
  154. return errors.Wrap(err, `failed to populate header values`)
  155. }
  156. m[dKey] = base64.EncodeToString(k.key.D.Bytes())
  157. m[pKey] = base64.EncodeToString(k.key.Primes[0].Bytes())
  158. m[qKey] = base64.EncodeToString(k.key.Primes[1].Bytes())
  159. if v := k.key.Precomputed.Dp; v != nil {
  160. m[dpKey] = base64.EncodeToString(v.Bytes())
  161. }
  162. if v := k.key.Precomputed.Dq; v != nil {
  163. m[dqKey] = base64.EncodeToString(v.Bytes())
  164. }
  165. if v := k.key.Precomputed.Qinv; v != nil {
  166. m[qiKey] = base64.EncodeToString(v.Bytes())
  167. }
  168. return nil
  169. }
  170. func (k *RSAPrivateKey) UnmarshalJSON(data []byte) (err error) {
  171. if pdebug.Enabled {
  172. g := pdebug.Marker("jwk.RSAPrivateKey.UnmarshalJSON").BindError(&err)
  173. defer g.End()
  174. pdebug.Printf("data --> %s", data)
  175. }
  176. m := map[string]interface{}{}
  177. if err := json.Unmarshal(data, &m); err != nil {
  178. return errors.Wrap(err, `failed to unmarshal public key`)
  179. }
  180. var key RSAPrivateKey
  181. if err := key.ExtractMap(m); err != nil {
  182. return errors.Wrap(err, `failed to extract data from map`)
  183. }
  184. *k = key
  185. return nil
  186. }
  187. func (k *RSAPrivateKey) ExtractMap(m map[string]interface{}) (err error) {
  188. if pdebug.Enabled {
  189. g := pdebug.Marker("jwk.RSAPrivateKey.ExractMap").BindError(&err)
  190. defer g.End()
  191. }
  192. const (
  193. dKey = `d`
  194. pKey = `p`
  195. qKey = `q`
  196. dpKey = `dp`
  197. dqKey = `dq`
  198. qiKey = `qi`
  199. )
  200. dbuf, err := getRequiredKey(m, dKey)
  201. if err != nil {
  202. return errors.Wrap(err, `failed to get required key`)
  203. }
  204. delete(m, dKey)
  205. pbuf, err := getRequiredKey(m, pKey)
  206. if err != nil {
  207. return errors.Wrap(err, `failed to get required key`)
  208. }
  209. delete(m, pKey)
  210. qbuf, err := getRequiredKey(m, qKey)
  211. if err != nil {
  212. return errors.Wrap(err, `failed to get required key`)
  213. }
  214. delete(m, qKey)
  215. var d, q, p big.Int
  216. d.SetBytes(dbuf)
  217. q.SetBytes(qbuf)
  218. p.SetBytes(pbuf)
  219. var dp, dq, qi *big.Int
  220. dpbuf, err := getOptionalKey(m, dpKey)
  221. if err == nil {
  222. delete(m, dpKey)
  223. dp = &big.Int{}
  224. dp.SetBytes(dpbuf)
  225. }
  226. dqbuf, err := getOptionalKey(m, dqKey)
  227. if err == nil {
  228. delete(m, dqKey)
  229. dq = &big.Int{}
  230. dq.SetBytes(dqbuf)
  231. }
  232. qibuf, err := getOptionalKey(m, qiKey)
  233. if err == nil {
  234. delete(m, qiKey)
  235. qi = &big.Int{}
  236. qi.SetBytes(qibuf)
  237. }
  238. var pubkey RSAPublicKey
  239. if err := pubkey.ExtractMap(m); err != nil {
  240. return errors.Wrap(err, `failed to extract fields for public key`)
  241. }
  242. materialized, err := pubkey.Materialize()
  243. if err != nil {
  244. return errors.Wrap(err, `failed to materialize RSA public key`)
  245. }
  246. rsaPubkey := materialized.(*rsa.PublicKey)
  247. var key rsa.PrivateKey
  248. key.PublicKey = *rsaPubkey
  249. key.D = &d
  250. key.Primes = []*big.Int{&p, &q}
  251. if dp != nil {
  252. key.Precomputed.Dp = dp
  253. }
  254. if dq != nil {
  255. key.Precomputed.Dq = dq
  256. }
  257. if qi != nil {
  258. key.Precomputed.Qinv = qi
  259. }
  260. *k = RSAPrivateKey{
  261. headers: pubkey.headers,
  262. key: &key,
  263. }
  264. return nil
  265. }
  266. // Thumbprint returns the JWK thumbprint using the indicated
  267. // hashing algorithm, according to RFC 7638
  268. func (k RSAPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
  269. return rsaThumbprint(hash, &k.key.PublicKey)
  270. }
  271. func (k RSAPublicKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
  272. return rsaThumbprint(hash, k.key)
  273. }
  274. func rsaThumbprint(hash crypto.Hash, key *rsa.PublicKey) ([]byte, error) {
  275. var buf bytes.Buffer
  276. buf.WriteString(`{"e":"`)
  277. buf.WriteString(base64.EncodeUint64ToString(uint64(key.E)))
  278. buf.WriteString(`","kty":"RSA","n":"`)
  279. buf.WriteString(base64.EncodeToString(key.N.Bytes()))
  280. buf.WriteString(`"}`)
  281. h := hash.New()
  282. buf.WriteTo(h)
  283. return h.Sum(nil), nil
  284. }