jwk.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. //go:generate go run internal/cmd/genheader/main.go
  2. // Package jwk implements JWK as described in https://tools.ietf.org/html/rfc7517
  3. package jwk
  4. import (
  5. "crypto/ecdsa"
  6. "crypto/rsa"
  7. "encoding/json"
  8. "io/ioutil"
  9. "net/http"
  10. "net/url"
  11. "os"
  12. "github.com/lestrrat/go-jwx/internal/base64"
  13. "github.com/lestrrat/go-jwx/jwa"
  14. "github.com/pkg/errors"
  15. )
  16. // New creates a jwk.Key from the given key.
  17. func New(key interface{}) (Key, error) {
  18. if key == nil {
  19. return nil, errors.New(`jwk.New requires a non-nil key`)
  20. }
  21. switch v := key.(type) {
  22. case *rsa.PrivateKey:
  23. return newRSAPrivateKey(v)
  24. case *rsa.PublicKey:
  25. return newRSAPublicKey(v)
  26. case *ecdsa.PrivateKey:
  27. return newECDSAPrivateKey(v)
  28. case *ecdsa.PublicKey:
  29. return newECDSAPublicKey(v)
  30. case []byte:
  31. return newSymmetricKey(v)
  32. default:
  33. return nil, errors.Errorf(`invalid key type %T`, key)
  34. }
  35. }
  36. // Fetch fetches a JWK resource specified by a URL
  37. func Fetch(urlstring string) (*Set, error) {
  38. u, err := url.Parse(urlstring)
  39. if err != nil {
  40. return nil, errors.Wrap(err, `failed to parse url`)
  41. }
  42. var src []byte
  43. switch u.Scheme {
  44. case "http", "https":
  45. res, err := http.Get(u.String())
  46. if err != nil {
  47. return nil, errors.Wrap(err, "failed to fetch remote JWK")
  48. }
  49. if res.StatusCode != http.StatusOK {
  50. return nil, errors.New("failed to fetch remote JWK (status != 200)")
  51. }
  52. // XXX Check for maximum length to read?
  53. buf, err := ioutil.ReadAll(res.Body)
  54. if err != nil {
  55. return nil, errors.Wrap(err, "failed to read JWK HTTP response body")
  56. }
  57. defer res.Body.Close()
  58. src = buf
  59. case "file":
  60. f, err := os.Open(u.Path)
  61. if err != nil {
  62. return nil, errors.Wrap(err, `failed to open jwk file`)
  63. }
  64. defer f.Close()
  65. buf, err := ioutil.ReadAll(f)
  66. if err != nil {
  67. return nil, errors.Wrap(err, `failed read content from jwk file`)
  68. }
  69. src = buf
  70. default:
  71. return nil, errors.Errorf(`invalid url scheme %s`, u.Scheme)
  72. }
  73. return Parse(src)
  74. }
  75. // FetchHTTP fetches the remote JWK and parses its contents
  76. func FetchHTTP(jwkurl string) (*Set, error) {
  77. res, err := http.Get(jwkurl)
  78. if err != nil {
  79. return nil, errors.Wrap(err, "failed to fetch remote JWK")
  80. }
  81. if res.StatusCode != http.StatusOK {
  82. return nil, errors.New("failed to fetch remote JWK (status != 200)")
  83. }
  84. // XXX Check for maximum length to read?
  85. buf, err := ioutil.ReadAll(res.Body)
  86. if err != nil {
  87. return nil, errors.Wrap(err, "failed to read JWK HTTP response body")
  88. }
  89. defer res.Body.Close()
  90. return Parse(buf)
  91. }
  92. // Parse parses JWK from the incoming byte buffer.
  93. func Parse(buf []byte) (*Set, error) {
  94. m := make(map[string]interface{})
  95. if err := json.Unmarshal(buf, &m); err != nil {
  96. return nil, errors.Wrap(err, "failed to unmarshal JWK")
  97. }
  98. // We must change what the underlying structure that gets decoded
  99. // out of this JSON is based on parameters within the already parsed
  100. // JSON (m). In order to do this, we have to go through the tedious
  101. // task of parsing the contents of this map :/
  102. if _, ok := m["keys"]; ok {
  103. var set Set
  104. if err := set.ExtractMap(m); err != nil {
  105. return nil, errors.Wrap(err, `failed to extract from map`)
  106. }
  107. return &set, nil
  108. }
  109. k, err := constructKey(m)
  110. if err != nil {
  111. return nil, errors.Wrap(err, `failed to construct key from keys`)
  112. }
  113. return &Set{Keys: []Key{k}}, nil
  114. }
  115. // ParseString parses JWK from the incoming string.
  116. func ParseString(s string) (*Set, error) {
  117. return Parse([]byte(s))
  118. }
  119. // LookupKeyID looks for keys matching the given key id. Note that the
  120. // Set *may* contain multiple keys with the same key id
  121. func (s Set) LookupKeyID(kid string) []Key {
  122. var keys []Key
  123. for _, key := range s.Keys {
  124. if key.KeyID() == kid {
  125. keys = append(keys, key)
  126. }
  127. }
  128. return keys
  129. }
  130. func (s *Set) ExtractMap(m map[string]interface{}) error {
  131. raw, ok := m["keys"]
  132. if !ok {
  133. return errors.New("missing 'keys' parameter")
  134. }
  135. v, ok := raw.([]interface{})
  136. if !ok {
  137. return errors.New("invalid 'keys' parameter")
  138. }
  139. var ks Set
  140. for _, c := range v {
  141. conf, ok := c.(map[string]interface{})
  142. if !ok {
  143. return errors.New("invalid element in 'keys'")
  144. }
  145. k, err := constructKey(conf)
  146. if err != nil {
  147. return errors.Wrap(err, `failed to construct key from map`)
  148. }
  149. ks.Keys = append(ks.Keys, k)
  150. }
  151. *s = ks
  152. return nil
  153. }
  154. func constructKey(m map[string]interface{}) (Key, error) {
  155. kty, ok := m["kty"].(string)
  156. if !ok {
  157. return nil, errors.Errorf(`unsupported kty type %T`, m[KeyTypeKey])
  158. }
  159. var key Key
  160. switch jwa.KeyType(kty) {
  161. case jwa.RSA:
  162. if _, ok := m["d"]; ok {
  163. key = &RSAPrivateKey{}
  164. } else {
  165. key = &RSAPublicKey{}
  166. }
  167. case jwa.EC:
  168. if _, ok := m["d"]; ok {
  169. key = &ECDSAPrivateKey{}
  170. } else {
  171. key = &ECDSAPublicKey{}
  172. }
  173. case jwa.OctetSeq:
  174. key = &SymmetricKey{}
  175. default:
  176. return nil, errors.Errorf(`invalid kty %s`, kty)
  177. }
  178. if err := key.ExtractMap(m); err != nil {
  179. return nil, errors.Wrap(err, `failed to extract key from map`)
  180. }
  181. return key, nil
  182. }
  183. func getRequiredKey(m map[string]interface{}, key string) ([]byte, error) {
  184. return getKey(m, key, true)
  185. }
  186. func getOptionalKey(m map[string]interface{}, key string) ([]byte, error) {
  187. return getKey(m, key, false)
  188. }
  189. func getKey(m map[string]interface{}, key string, required bool) ([]byte, error) {
  190. v, ok := m[key]
  191. if !ok {
  192. if !required {
  193. return nil, errors.Errorf(`missing parameter '%s'`, key)
  194. }
  195. return nil, errors.Errorf(`missing required parameter '%s'`, key)
  196. }
  197. vs, ok := v.(string)
  198. if !ok {
  199. return nil, errors.Errorf(`invalid type for parameter '%s': %T`, key, v)
  200. }
  201. buf, err := base64.DecodeString(vs)
  202. if err != nil {
  203. return nil, errors.Wrapf(err, `failed to base64 decode key %s`, key)
  204. }
  205. return buf, nil
  206. }