jwt.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. //go:generate ./gen.sh
  2. // Package jwt implements JSON Web Tokens as described in https://tools.ietf.org/html/rfc7519
  3. package jwt
  4. import (
  5. "bytes"
  6. "io"
  7. "io/ioutil"
  8. "net/http"
  9. "strings"
  10. "sync/atomic"
  11. "github.com/lestrrat-go/backoff/v2"
  12. "github.com/lestrrat-go/jwx"
  13. "github.com/lestrrat-go/jwx/internal/json"
  14. "github.com/lestrrat-go/jwx/jwe"
  15. "github.com/lestrrat-go/jwx/jwa"
  16. "github.com/lestrrat-go/jwx/jwk"
  17. "github.com/lestrrat-go/jwx/jws"
  18. "github.com/pkg/errors"
  19. )
  20. const _jwt = `jwt`
  21. // Settings controls global settings that are specific to JWTs.
  22. func Settings(options ...GlobalOption) {
  23. var flattenAudienceBool bool
  24. //nolint:forcetypeassert
  25. for _, option := range options {
  26. switch option.Ident() {
  27. case identFlattenAudience{}:
  28. flattenAudienceBool = option.Value().(bool)
  29. }
  30. }
  31. v := atomic.LoadUint32(&json.FlattenAudience)
  32. if (v == 1) != flattenAudienceBool {
  33. var newVal uint32
  34. if flattenAudienceBool {
  35. newVal = 1
  36. }
  37. atomic.CompareAndSwapUint32(&json.FlattenAudience, v, newVal)
  38. }
  39. }
  40. var registry = json.NewRegistry()
  41. // ParseString calls Parse against a string
  42. func ParseString(s string, options ...ParseOption) (Token, error) {
  43. return parseBytes([]byte(s), options...)
  44. }
  45. // Parse parses the JWT token payload and creates a new `jwt.Token` object.
  46. // The token must be encoded in either JSON format or compact format.
  47. //
  48. // This function can work with encrypted and/or signed tokens. Any combination
  49. // of JWS and JWE may be applied to the token, but this function will only
  50. // attempt to verify/decrypt up to 2 levels (i.e. JWS only, JWE only, JWS then
  51. // JWE, or JWE then JWS)
  52. //
  53. // If the token is signed and you want to verify the payload matches the signature,
  54. // you must pass the jwt.WithVerify(alg, key) or jwt.WithKeySet(jwk.Set) option.
  55. // If you do not specify these parameters, no verification will be performed.
  56. //
  57. // During verification, if the JWS headers specify a key ID (`kid`), the
  58. // key used for verification must match the specified ID. If you are somehow
  59. // using a key without a `kid` (which is highly unlikely if you are working
  60. // with a JWT from a well know provider), you can workaround this by modifying
  61. // the `jwk.Key` and setting the `kid` header.
  62. //
  63. // If you also want to assert the validity of the JWT itself (i.e. expiration
  64. // and such), use the `Validate()` function on the returned token, or pass the
  65. // `WithValidate(true)` option. Validate options can also be passed to
  66. // `Parse`
  67. //
  68. // This function takes both ParseOption and ValidateOption types:
  69. // ParseOptions control the parsing behavior, and ValidateOptions are
  70. // passed to `Validate()` when `jwt.WithValidate` is specified.
  71. func Parse(s []byte, options ...ParseOption) (Token, error) {
  72. return parseBytes(s, options...)
  73. }
  74. // ParseReader calls Parse against an io.Reader
  75. func ParseReader(src io.Reader, options ...ParseOption) (Token, error) {
  76. // We're going to need the raw bytes regardless. Read it.
  77. data, err := ioutil.ReadAll(src)
  78. if err != nil {
  79. return nil, errors.Wrap(err, `failed to read from token data source`)
  80. }
  81. return parseBytes(data, options...)
  82. }
  83. type parseCtx struct {
  84. decryptParams DecryptParameters
  85. verifyParams VerifyParameters
  86. keySet jwk.Set
  87. keySetProvider KeySetProvider
  88. token Token
  89. validateOpts []ValidateOption
  90. verifyAutoOpts []jws.VerifyOption
  91. localReg *json.Registry
  92. inferAlgorithm bool
  93. pedantic bool
  94. skipVerification bool
  95. useDefault bool
  96. validate bool
  97. verifyAuto bool
  98. }
  99. func parseBytes(data []byte, options ...ParseOption) (Token, error) {
  100. var ctx parseCtx
  101. for _, o := range options {
  102. if v, ok := o.(ValidateOption); ok {
  103. ctx.validateOpts = append(ctx.validateOpts, v)
  104. continue
  105. }
  106. //nolint:forcetypeassert
  107. switch o.Ident() {
  108. case identVerifyAuto{}:
  109. ctx.verifyAuto = o.Value().(bool)
  110. case identFetchWhitelist{}:
  111. ctx.verifyAutoOpts = append(ctx.verifyAutoOpts, jws.WithFetchWhitelist(o.Value().(jwk.Whitelist)))
  112. case identHTTPClient{}:
  113. ctx.verifyAutoOpts = append(ctx.verifyAutoOpts, jws.WithHTTPClient(o.Value().(*http.Client)))
  114. case identFetchBackoff{}:
  115. ctx.verifyAutoOpts = append(ctx.verifyAutoOpts, jws.WithFetchBackoff(o.Value().(backoff.Policy)))
  116. case identJWKSetFetcher{}:
  117. ctx.verifyAutoOpts = append(ctx.verifyAutoOpts, jws.WithJWKSetFetcher(o.Value().(jws.JWKSetFetcher)))
  118. case identVerify{}:
  119. ctx.verifyParams = o.Value().(VerifyParameters)
  120. case identDecrypt{}:
  121. ctx.decryptParams = o.Value().(DecryptParameters)
  122. case identKeySet{}:
  123. ks, ok := o.Value().(jwk.Set)
  124. if !ok {
  125. return nil, errors.Errorf(`invalid JWK set passed via WithKeySet() option (%T)`, o.Value())
  126. }
  127. ctx.keySet = ks
  128. case identToken{}:
  129. token, ok := o.Value().(Token)
  130. if !ok {
  131. return nil, errors.Errorf(`invalid token passed via WithToken() option (%T)`, o.Value())
  132. }
  133. ctx.token = token
  134. case identPedantic{}:
  135. ctx.pedantic = o.Value().(bool)
  136. case identDefault{}:
  137. ctx.useDefault = o.Value().(bool)
  138. case identValidate{}:
  139. ctx.validate = o.Value().(bool)
  140. case identTypedClaim{}:
  141. pair := o.Value().(claimPair)
  142. if ctx.localReg == nil {
  143. ctx.localReg = json.NewRegistry()
  144. }
  145. ctx.localReg.Register(pair.Name, pair.Value)
  146. case identInferAlgorithmFromKey{}:
  147. ctx.inferAlgorithm = o.Value().(bool)
  148. case identKeySetProvider{}:
  149. ctx.keySetProvider = o.Value().(KeySetProvider)
  150. }
  151. }
  152. data = bytes.TrimSpace(data)
  153. return parse(&ctx, data)
  154. }
  155. const (
  156. _JwsVerifyInvalid = iota
  157. _JwsVerifyDone
  158. _JwsVerifyExpectNested
  159. _JwsVerifySkipped
  160. )
  161. func verifyJWS(ctx *parseCtx, payload []byte) ([]byte, int, error) {
  162. if ctx.verifyAuto {
  163. options := ctx.verifyAutoOpts
  164. verified, err := jws.VerifyAuto(payload, options...)
  165. return verified, _JwsVerifyDone, err
  166. }
  167. // if we have a key set or a provider, use that
  168. ks := ctx.keySet
  169. p := ctx.keySetProvider
  170. if ks != nil || p != nil {
  171. return verifyJWSWithKeySet(ctx, payload)
  172. }
  173. // We can't proceed without verification parameters
  174. vp := ctx.verifyParams
  175. if vp == nil {
  176. return nil, _JwsVerifySkipped, nil
  177. }
  178. return verifyJWSWithParams(ctx, payload, vp.Algorithm(), vp.Key())
  179. }
  180. func verifyJWSWithKeySet(ctx *parseCtx, payload []byte) ([]byte, int, error) {
  181. // First, get the JWS message
  182. msg, err := jws.Parse(payload)
  183. if err != nil {
  184. return nil, _JwsVerifyInvalid, errors.Wrap(err, `failed to parse token data as JWS message`)
  185. }
  186. ks := ctx.keySet
  187. if ks == nil { // the caller should have checked ctx.keySet || ctx.keySetProvider
  188. if p := ctx.keySetProvider; p != nil {
  189. // "trust" the payload, and parse it so that the provider can do its thing
  190. ctx.skipVerification = true
  191. tok, err := parse(ctx, msg.Payload())
  192. if err != nil {
  193. return nil, _JwsVerifyInvalid, err
  194. }
  195. ctx.skipVerification = false
  196. v, err := p.KeySetFrom(tok)
  197. if err != nil {
  198. return nil, _JwsVerifyInvalid, errors.Wrap(err, `failed to obtain jwk.Set from KeySetProvider`)
  199. }
  200. ks = v
  201. }
  202. }
  203. // Bail out early if we don't even have a key in the set
  204. if ks.Len() == 0 {
  205. return nil, _JwsVerifyInvalid, errors.New(`empty keyset provided`)
  206. }
  207. var key jwk.Key
  208. // Find the kid. we need the kid, unless the user explicitly
  209. // specified to use the "default" (the first and only) key in the set
  210. headers := msg.Signatures()[0].ProtectedHeaders()
  211. kid := headers.KeyID()
  212. if kid == "" {
  213. // If the kid is NOT specified... ctx.useDefault needs to be true, and the
  214. // JWKs must have exactly one key in it
  215. if !ctx.useDefault {
  216. return nil, _JwsVerifyInvalid, errors.New(`failed to find matching key: no key ID ("kid") specified in token`)
  217. } else if ctx.useDefault && ks.Len() > 1 {
  218. return nil, _JwsVerifyInvalid, errors.New(`failed to find matching key: no key ID ("kid") specified in token but multiple keys available in key set`)
  219. }
  220. // if we got here, then useDefault == true AND there is exactly
  221. // one key in the set.
  222. key, _ = ks.Get(0)
  223. } else {
  224. // Otherwise we better be able to look up the key, baby.
  225. v, ok := ks.LookupKeyID(kid)
  226. if !ok {
  227. return nil, _JwsVerifyInvalid, errors.Errorf(`failed to find key with key ID %q in key set`, kid)
  228. }
  229. key = v
  230. }
  231. // We found a key with matching kid. Check fo the algorithm specified in the key.
  232. // If we find an algorithm in the key, use that.
  233. if v := key.Algorithm(); v != "" {
  234. var alg jwa.SignatureAlgorithm
  235. if err := alg.Accept(v); err != nil {
  236. return nil, _JwsVerifyInvalid, errors.Wrapf(err, `invalid signature algorithm %s`, key.Algorithm())
  237. }
  238. // Okay, we have a valid algorithm, go go
  239. return verifyJWSWithParams(ctx, payload, alg, key)
  240. }
  241. if ctx.inferAlgorithm {
  242. // Check whether the JWT headers specify a valid
  243. // algorithm, use it if it's compatible.
  244. algs, err := jws.AlgorithmsForKey(key)
  245. if err != nil {
  246. return nil, _JwsVerifyInvalid, errors.Wrapf(err, `failed to get a list of signature methods for key type %s`, key.KeyType())
  247. }
  248. for _, alg := range algs {
  249. // bail out if the JWT has a `alg` field, and it doesn't match
  250. if tokAlg := headers.Algorithm(); tokAlg != "" {
  251. if tokAlg != alg {
  252. continue
  253. }
  254. }
  255. return verifyJWSWithParams(ctx, payload, alg, key)
  256. }
  257. }
  258. return nil, _JwsVerifyInvalid, errors.New(`failed to match any of the keys`)
  259. }
  260. func verifyJWSWithParams(ctx *parseCtx, payload []byte, alg jwa.SignatureAlgorithm, key interface{}) ([]byte, int, error) {
  261. var m *jws.Message
  262. var verifyOpts []jws.VerifyOption
  263. if ctx.pedantic {
  264. m = jws.NewMessage()
  265. verifyOpts = []jws.VerifyOption{jws.WithMessage(m)}
  266. }
  267. v, err := jws.Verify(payload, alg, key, verifyOpts...)
  268. if err != nil {
  269. return nil, _JwsVerifyInvalid, errors.Wrap(err, `failed to verify jws signature`)
  270. }
  271. if !ctx.pedantic {
  272. return v, _JwsVerifyDone, nil
  273. }
  274. // This payload could be a JWT+JWS, in which case typ: JWT should be there
  275. // If its JWT+(JWE or JWS or...)+JWS, then cty should be JWT
  276. for _, sig := range m.Signatures() {
  277. hdrs := sig.ProtectedHeaders()
  278. if strings.ToLower(hdrs.Type()) == _jwt {
  279. return v, _JwsVerifyDone, nil
  280. }
  281. if strings.ToLower(hdrs.ContentType()) == _jwt {
  282. return v, _JwsVerifyExpectNested, nil
  283. }
  284. }
  285. // Hmmm, it was a JWS and we got... nothing?
  286. return nil, _JwsVerifyInvalid, errors.Errorf(`expected "typ" or "cty" fields, neither could be found`)
  287. }
  288. // verify parameter exists to make sure that we don't accidentally skip
  289. // over verification just because alg == "" or key == nil or something.
  290. func parse(ctx *parseCtx, data []byte) (Token, error) {
  291. payload := data
  292. const maxDecodeLevels = 2
  293. // If cty = `JWT`, we expect this to be a nested structure
  294. var expectNested bool
  295. OUTER:
  296. for i := 0; i < maxDecodeLevels; i++ {
  297. switch kind := jwx.GuessFormat(payload); kind {
  298. case jwx.JWT:
  299. if ctx.pedantic {
  300. if expectNested {
  301. return nil, errors.Errorf(`expected nested encrypted/signed payload, got raw JWT`)
  302. }
  303. }
  304. if i == 0 {
  305. // We were NOT enveloped in other formats
  306. if !ctx.skipVerification {
  307. if _, _, err := verifyJWS(ctx, payload); err != nil {
  308. return nil, err
  309. }
  310. }
  311. }
  312. break OUTER
  313. case jwx.UnknownFormat:
  314. // "Unknown" may include invalid JWTs, for example, those who lack "aud"
  315. // claim. We could be pedantic and reject these
  316. if ctx.pedantic {
  317. return nil, errors.Errorf(`invalid JWT`)
  318. }
  319. if i == 0 {
  320. // We were NOT enveloped in other formats
  321. if !ctx.skipVerification {
  322. if _, _, err := verifyJWS(ctx, payload); err != nil {
  323. return nil, err
  324. }
  325. }
  326. }
  327. break OUTER
  328. case jwx.JWS:
  329. // Food for thought: This is going to break if you have multiple layers of
  330. // JWS enveloping using different keys. It is highly unlikely use case,
  331. // but it might happen.
  332. // skipVerification should only be set to true by us. It's used
  333. // when we just want to parse the JWT out of a payload
  334. if !ctx.skipVerification {
  335. // nested return value means:
  336. // false (next envelope _may_ need to be processed)
  337. // true (next envelope MUST be processed)
  338. v, state, err := verifyJWS(ctx, payload)
  339. if err != nil {
  340. return nil, err
  341. }
  342. if state != _JwsVerifySkipped {
  343. payload = v
  344. // We only check for cty and typ if the pedantic flag is enabled
  345. if !ctx.pedantic {
  346. continue
  347. }
  348. if state == _JwsVerifyExpectNested {
  349. expectNested = true
  350. continue OUTER
  351. }
  352. // if we're not nested, we found our target. bail out of this loop
  353. break OUTER
  354. }
  355. }
  356. // No verification.
  357. m, err := jws.Parse(data)
  358. if err != nil {
  359. return nil, errors.Wrap(err, `invalid jws message`)
  360. }
  361. payload = m.Payload()
  362. case jwx.JWE:
  363. dp := ctx.decryptParams
  364. if dp == nil {
  365. return nil, errors.Errorf(`jwt.Parse: cannot proceed with JWE encrypted payload without decryption parameters`)
  366. }
  367. var m *jwe.Message
  368. var decryptOpts []jwe.DecryptOption
  369. if ctx.pedantic {
  370. m = jwe.NewMessage()
  371. decryptOpts = []jwe.DecryptOption{jwe.WithMessage(m)}
  372. }
  373. v, err := jwe.Decrypt(data, dp.Algorithm(), dp.Key(), decryptOpts...)
  374. if err != nil {
  375. return nil, errors.Wrap(err, `failed to decrypt payload`)
  376. }
  377. if !ctx.pedantic {
  378. payload = v
  379. continue
  380. }
  381. if strings.ToLower(m.ProtectedHeaders().Type()) == _jwt {
  382. payload = v
  383. break OUTER
  384. }
  385. if strings.ToLower(m.ProtectedHeaders().ContentType()) == _jwt {
  386. expectNested = true
  387. payload = v
  388. continue OUTER
  389. }
  390. default:
  391. return nil, errors.Errorf(`unsupported format (layer: #%d)`, i+1)
  392. }
  393. expectNested = false
  394. }
  395. if ctx.token == nil {
  396. ctx.token = New()
  397. }
  398. if ctx.localReg != nil {
  399. dcToken, ok := ctx.token.(TokenWithDecodeCtx)
  400. if !ok {
  401. return nil, errors.Errorf(`typed claim was requested, but the token (%T) does not support DecodeCtx`, ctx.token)
  402. }
  403. dc := json.NewDecodeCtx(ctx.localReg)
  404. dcToken.SetDecodeCtx(dc)
  405. defer func() { dcToken.SetDecodeCtx(nil) }()
  406. }
  407. if err := json.Unmarshal(payload, ctx.token); err != nil {
  408. return nil, errors.Wrap(err, `failed to parse token`)
  409. }
  410. if ctx.validate {
  411. if err := Validate(ctx.token, ctx.validateOpts...); err != nil {
  412. return nil, err
  413. }
  414. }
  415. return ctx.token, nil
  416. }
  417. // Sign is a convenience function to create a signed JWT token serialized in
  418. // compact form.
  419. //
  420. // It accepts either a raw key (e.g. rsa.PrivateKey, ecdsa.PrivateKey, etc)
  421. // or a jwk.Key, and the name of the algorithm that should be used to sign
  422. // the token.
  423. //
  424. // If the key is a jwk.Key and the key contains a key ID (`kid` field),
  425. // then it is added to the protected header generated by the signature
  426. //
  427. // The algorithm specified in the `alg` parameter must be able to support
  428. // the type of key you provided, otherwise an error is returned.
  429. //
  430. // The protected header will also automatically have the `typ` field set
  431. // to the literal value `JWT`, unless you provide a custom value for it
  432. // by jwt.WithHeaders option.
  433. func Sign(t Token, alg jwa.SignatureAlgorithm, key interface{}, options ...SignOption) ([]byte, error) {
  434. return NewSerializer().Sign(alg, key, options...).Serialize(t)
  435. }
  436. // Equal compares two JWT tokens. Do not use `reflect.Equal` or the like
  437. // to compare tokens as they will also compare extra detail such as
  438. // sync.Mutex objects used to control concurrent access.
  439. //
  440. // The comparison for values is currently done using a simple equality ("=="),
  441. // except for time.Time, which uses time.Equal after dropping the monotonic
  442. // clock and truncating the values to 1 second accuracy.
  443. //
  444. // if both t1 and t2 are nil, returns true
  445. func Equal(t1, t2 Token) bool {
  446. if t1 == nil && t2 == nil {
  447. return true
  448. }
  449. // we already checked for t1 == t2 == nil, so safe to do this
  450. if t1 == nil || t2 == nil {
  451. return false
  452. }
  453. j1, err := json.Marshal(t1)
  454. if err != nil {
  455. return false
  456. }
  457. j2, err := json.Marshal(t2)
  458. if err != nil {
  459. return false
  460. }
  461. return bytes.Equal(j1, j2)
  462. }
  463. func (t *stdToken) Clone() (Token, error) {
  464. dst := New()
  465. for _, pair := range t.makePairs() {
  466. //nolint:forcetypeassert
  467. key := pair.Key.(string)
  468. if err := dst.Set(key, pair.Value); err != nil {
  469. return nil, errors.Wrapf(err, `failed to set %s`, key)
  470. }
  471. }
  472. return dst, nil
  473. }
  474. // RegisterCustomField allows users to specify that a private field
  475. // be decoded as an instance of the specified type. This option has
  476. // a global effect.
  477. //
  478. // For example, suppose you have a custom field `x-birthday`, which
  479. // you want to represent as a string formatted in RFC3339 in JSON,
  480. // but want it back as `time.Time`.
  481. //
  482. // In that case you would register a custom field as follows
  483. //
  484. // jwt.RegisterCustomField(`x-birthday`, timeT)
  485. //
  486. // Then `token.Get("x-birthday")` will still return an `interface{}`,
  487. // but you can convert its type to `time.Time`
  488. //
  489. // bdayif, _ := token.Get(`x-birthday`)
  490. // bday := bdayif.(time.Time)
  491. //
  492. func RegisterCustomField(name string, object interface{}) {
  493. registry.Register(name, object)
  494. }