| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561 |
- //go:generate ./gen.sh
- // Package jwt implements JSON Web Tokens as described in https://tools.ietf.org/html/rfc7519
- package jwt
- import (
- "bytes"
- "io"
- "io/ioutil"
- "net/http"
- "strings"
- "sync/atomic"
- "github.com/lestrrat-go/backoff/v2"
- "github.com/lestrrat-go/jwx"
- "github.com/lestrrat-go/jwx/internal/json"
- "github.com/lestrrat-go/jwx/jwe"
- "github.com/lestrrat-go/jwx/jwa"
- "github.com/lestrrat-go/jwx/jwk"
- "github.com/lestrrat-go/jwx/jws"
- "github.com/pkg/errors"
- )
// _jwt is the lowercase literal that the `typ` and `cty` header values of
// JWS/JWE envelopes are compared against (after lowercasing) when the
// pedantic option is enabled.
const _jwt = `jwt`
- // Settings controls global settings that are specific to JWTs.
- func Settings(options ...GlobalOption) {
- var flattenAudienceBool bool
- //nolint:forcetypeassert
- for _, option := range options {
- switch option.Ident() {
- case identFlattenAudience{}:
- flattenAudienceBool = option.Value().(bool)
- }
- }
- v := atomic.LoadUint32(&json.FlattenAudience)
- if (v == 1) != flattenAudienceBool {
- var newVal uint32
- if flattenAudienceBool {
- newVal = 1
- }
- atomic.CompareAndSwapUint32(&json.FlattenAudience, v, newVal)
- }
- }
// registry is the package-level registry that RegisterCustomField adds
// name/object pairs to.
var registry = json.NewRegistry()
- // ParseString calls Parse against a string
- func ParseString(s string, options ...ParseOption) (Token, error) {
- return parseBytes([]byte(s), options...)
- }
- // Parse parses the JWT token payload and creates a new `jwt.Token` object.
- // The token must be encoded in either JSON format or compact format.
- //
- // This function can work with encrypted and/or signed tokens. Any combination
- // of JWS and JWE may be applied to the token, but this function will only
- // attempt to verify/decrypt up to 2 levels (i.e. JWS only, JWE only, JWS then
- // JWE, or JWE then JWS)
- //
- // If the token is signed and you want to verify the payload matches the signature,
- // you must pass the jwt.WithVerify(alg, key) or jwt.WithKeySet(jwk.Set) option.
- // If you do not specify these parameters, no verification will be performed.
- //
- // During verification, if the JWS headers specify a key ID (`kid`), the
- // key used for verification must match the specified ID. If you are somehow
- // using a key without a `kid` (which is highly unlikely if you are working
- // with a JWT from a well know provider), you can workaround this by modifying
- // the `jwk.Key` and setting the `kid` header.
- //
- // If you also want to assert the validity of the JWT itself (i.e. expiration
- // and such), use the `Validate()` function on the returned token, or pass the
- // `WithValidate(true)` option. Validate options can also be passed to
- // `Parse`
- //
- // This function takes both ParseOption and ValidateOption types:
- // ParseOptions control the parsing behavior, and ValidateOptions are
- // passed to `Validate()` when `jwt.WithValidate` is specified.
- func Parse(s []byte, options ...ParseOption) (Token, error) {
- return parseBytes(s, options...)
- }
- // ParseReader calls Parse against an io.Reader
- func ParseReader(src io.Reader, options ...ParseOption) (Token, error) {
- // We're going to need the raw bytes regardless. Read it.
- data, err := ioutil.ReadAll(src)
- if err != nil {
- return nil, errors.Wrap(err, `failed to read from token data source`)
- }
- return parseBytes(data, options...)
- }
// parseCtx holds the settings that parseBytes collects from the option
// list, plus working state shared by parse and the verifyJWS* helpers.
type parseCtx struct {
	// decryptParams supplies the algorithm and key for JWE layers; parse
	// refuses JWE payloads when it is nil.
	decryptParams DecryptParameters
	// verifyParams supplies the algorithm and key handed to jws.Verify when
	// no key set or key set provider is configured.
	verifyParams VerifyParameters
	// keySet is the key set searched for the verification key.
	keySet jwk.Set
	// keySetProvider is asked for a key set (via KeySetFrom) when keySet is nil.
	keySetProvider KeySetProvider
	// token is the object the payload is unmarshaled into; parse creates
	// one with New() when it is nil.
	token Token
	// validateOpts are the options implementing ValidateOption; they are
	// forwarded to Validate when validate is true.
	validateOpts []ValidateOption
	// verifyAutoOpts are forwarded to jws.VerifyAuto when verifyAuto is true.
	verifyAutoOpts []jws.VerifyOption
	// localReg holds typed claims registered for this parse call only; it is
	// created lazily when the first typed-claim option is seen.
	localReg *json.Registry
	// inferAlgorithm allows trying the algorithms reported by
	// jws.AlgorithmsForKey when the selected key has no algorithm set.
	inferAlgorithm bool
	// pedantic enables the `typ`/`cty` header checks and rejects payloads
	// of unknown format.
	pedantic bool
	// skipVerification is internal state: verifyJWSWithKeySet sets it while
	// it parses the unverified payload for the key set provider.
	skipVerification bool
	// useDefault permits using the only key of a key set when the token
	// carries no `kid`.
	useDefault bool
	// validate makes parse run Validate on the decoded token.
	validate bool
	// verifyAuto makes verifyJWS delegate to jws.VerifyAuto.
	verifyAuto bool
}
- func parseBytes(data []byte, options ...ParseOption) (Token, error) {
- var ctx parseCtx
- for _, o := range options {
- if v, ok := o.(ValidateOption); ok {
- ctx.validateOpts = append(ctx.validateOpts, v)
- continue
- }
- //nolint:forcetypeassert
- switch o.Ident() {
- case identVerifyAuto{}:
- ctx.verifyAuto = o.Value().(bool)
- case identFetchWhitelist{}:
- ctx.verifyAutoOpts = append(ctx.verifyAutoOpts, jws.WithFetchWhitelist(o.Value().(jwk.Whitelist)))
- case identHTTPClient{}:
- ctx.verifyAutoOpts = append(ctx.verifyAutoOpts, jws.WithHTTPClient(o.Value().(*http.Client)))
- case identFetchBackoff{}:
- ctx.verifyAutoOpts = append(ctx.verifyAutoOpts, jws.WithFetchBackoff(o.Value().(backoff.Policy)))
- case identJWKSetFetcher{}:
- ctx.verifyAutoOpts = append(ctx.verifyAutoOpts, jws.WithJWKSetFetcher(o.Value().(jws.JWKSetFetcher)))
- case identVerify{}:
- ctx.verifyParams = o.Value().(VerifyParameters)
- case identDecrypt{}:
- ctx.decryptParams = o.Value().(DecryptParameters)
- case identKeySet{}:
- ks, ok := o.Value().(jwk.Set)
- if !ok {
- return nil, errors.Errorf(`invalid JWK set passed via WithKeySet() option (%T)`, o.Value())
- }
- ctx.keySet = ks
- case identToken{}:
- token, ok := o.Value().(Token)
- if !ok {
- return nil, errors.Errorf(`invalid token passed via WithToken() option (%T)`, o.Value())
- }
- ctx.token = token
- case identPedantic{}:
- ctx.pedantic = o.Value().(bool)
- case identDefault{}:
- ctx.useDefault = o.Value().(bool)
- case identValidate{}:
- ctx.validate = o.Value().(bool)
- case identTypedClaim{}:
- pair := o.Value().(claimPair)
- if ctx.localReg == nil {
- ctx.localReg = json.NewRegistry()
- }
- ctx.localReg.Register(pair.Name, pair.Value)
- case identInferAlgorithmFromKey{}:
- ctx.inferAlgorithm = o.Value().(bool)
- case identKeySetProvider{}:
- ctx.keySetProvider = o.Value().(KeySetProvider)
- }
- }
- data = bytes.TrimSpace(data)
- return parse(&ctx, data)
- }
// Outcome codes returned by verifyJWS and its helpers next to the payload.
const (
	// _JwsVerifyInvalid is the code verifyJWSWithKeySet and
	// verifyJWSWithParams return together with an error.
	_JwsVerifyInvalid = iota
	// _JwsVerifyDone means the signature was verified and no further
	// envelope is required. The jws.VerifyAuto path also reports this code
	// unconditionally, so callers must inspect the error first.
	_JwsVerifyDone
	// _JwsVerifyExpectNested means the signature was verified and a `cty`
	// header of JWT announced another enveloped layer (pedantic mode only).
	_JwsVerifyExpectNested
	// _JwsVerifySkipped means no verification settings were configured, so
	// nothing was verified and no payload is returned.
	_JwsVerifySkipped
)
- func verifyJWS(ctx *parseCtx, payload []byte) ([]byte, int, error) {
- if ctx.verifyAuto {
- options := ctx.verifyAutoOpts
- verified, err := jws.VerifyAuto(payload, options...)
- return verified, _JwsVerifyDone, err
- }
- // if we have a key set or a provider, use that
- ks := ctx.keySet
- p := ctx.keySetProvider
- if ks != nil || p != nil {
- return verifyJWSWithKeySet(ctx, payload)
- }
- // We can't proceed without verification parameters
- vp := ctx.verifyParams
- if vp == nil {
- return nil, _JwsVerifySkipped, nil
- }
- return verifyJWSWithParams(ctx, payload, vp.Algorithm(), vp.Key())
- }
- func verifyJWSWithKeySet(ctx *parseCtx, payload []byte) ([]byte, int, error) {
- // First, get the JWS message
- msg, err := jws.Parse(payload)
- if err != nil {
- return nil, _JwsVerifyInvalid, errors.Wrap(err, `failed to parse token data as JWS message`)
- }
- ks := ctx.keySet
- if ks == nil { // the caller should have checked ctx.keySet || ctx.keySetProvider
- if p := ctx.keySetProvider; p != nil {
- // "trust" the payload, and parse it so that the provider can do its thing
- ctx.skipVerification = true
- tok, err := parse(ctx, msg.Payload())
- if err != nil {
- return nil, _JwsVerifyInvalid, err
- }
- ctx.skipVerification = false
- v, err := p.KeySetFrom(tok)
- if err != nil {
- return nil, _JwsVerifyInvalid, errors.Wrap(err, `failed to obtain jwk.Set from KeySetProvider`)
- }
- ks = v
- }
- }
- // Bail out early if we don't even have a key in the set
- if ks.Len() == 0 {
- return nil, _JwsVerifyInvalid, errors.New(`empty keyset provided`)
- }
- var key jwk.Key
- // Find the kid. we need the kid, unless the user explicitly
- // specified to use the "default" (the first and only) key in the set
- headers := msg.Signatures()[0].ProtectedHeaders()
- kid := headers.KeyID()
- if kid == "" {
- // If the kid is NOT specified... ctx.useDefault needs to be true, and the
- // JWKs must have exactly one key in it
- if !ctx.useDefault {
- return nil, _JwsVerifyInvalid, errors.New(`failed to find matching key: no key ID ("kid") specified in token`)
- } else if ctx.useDefault && ks.Len() > 1 {
- return nil, _JwsVerifyInvalid, errors.New(`failed to find matching key: no key ID ("kid") specified in token but multiple keys available in key set`)
- }
- // if we got here, then useDefault == true AND there is exactly
- // one key in the set.
- key, _ = ks.Get(0)
- } else {
- // Otherwise we better be able to look up the key, baby.
- v, ok := ks.LookupKeyID(kid)
- if !ok {
- return nil, _JwsVerifyInvalid, errors.Errorf(`failed to find key with key ID %q in key set`, kid)
- }
- key = v
- }
- // We found a key with matching kid. Check fo the algorithm specified in the key.
- // If we find an algorithm in the key, use that.
- if v := key.Algorithm(); v != "" {
- var alg jwa.SignatureAlgorithm
- if err := alg.Accept(v); err != nil {
- return nil, _JwsVerifyInvalid, errors.Wrapf(err, `invalid signature algorithm %s`, key.Algorithm())
- }
- // Okay, we have a valid algorithm, go go
- return verifyJWSWithParams(ctx, payload, alg, key)
- }
- if ctx.inferAlgorithm {
- // Check whether the JWT headers specify a valid
- // algorithm, use it if it's compatible.
- algs, err := jws.AlgorithmsForKey(key)
- if err != nil {
- return nil, _JwsVerifyInvalid, errors.Wrapf(err, `failed to get a list of signature methods for key type %s`, key.KeyType())
- }
- for _, alg := range algs {
- // bail out if the JWT has a `alg` field, and it doesn't match
- if tokAlg := headers.Algorithm(); tokAlg != "" {
- if tokAlg != alg {
- continue
- }
- }
- return verifyJWSWithParams(ctx, payload, alg, key)
- }
- }
- return nil, _JwsVerifyInvalid, errors.New(`failed to match any of the keys`)
- }
- func verifyJWSWithParams(ctx *parseCtx, payload []byte, alg jwa.SignatureAlgorithm, key interface{}) ([]byte, int, error) {
- var m *jws.Message
- var verifyOpts []jws.VerifyOption
- if ctx.pedantic {
- m = jws.NewMessage()
- verifyOpts = []jws.VerifyOption{jws.WithMessage(m)}
- }
- v, err := jws.Verify(payload, alg, key, verifyOpts...)
- if err != nil {
- return nil, _JwsVerifyInvalid, errors.Wrap(err, `failed to verify jws signature`)
- }
- if !ctx.pedantic {
- return v, _JwsVerifyDone, nil
- }
- // This payload could be a JWT+JWS, in which case typ: JWT should be there
- // If its JWT+(JWE or JWS or...)+JWS, then cty should be JWT
- for _, sig := range m.Signatures() {
- hdrs := sig.ProtectedHeaders()
- if strings.ToLower(hdrs.Type()) == _jwt {
- return v, _JwsVerifyDone, nil
- }
- if strings.ToLower(hdrs.ContentType()) == _jwt {
- return v, _JwsVerifyExpectNested, nil
- }
- }
- // Hmmm, it was a JWS and we got... nothing?
- return nil, _JwsVerifyInvalid, errors.Errorf(`expected "typ" or "cty" fields, neither could be found`)
- }
- // verify parameter exists to make sure that we don't accidentally skip
- // over verification just because alg == "" or key == nil or something.
- func parse(ctx *parseCtx, data []byte) (Token, error) {
- payload := data
- const maxDecodeLevels = 2
- // If cty = `JWT`, we expect this to be a nested structure
- var expectNested bool
- OUTER:
- for i := 0; i < maxDecodeLevels; i++ {
- switch kind := jwx.GuessFormat(payload); kind {
- case jwx.JWT:
- if ctx.pedantic {
- if expectNested {
- return nil, errors.Errorf(`expected nested encrypted/signed payload, got raw JWT`)
- }
- }
- if i == 0 {
- // We were NOT enveloped in other formats
- if !ctx.skipVerification {
- if _, _, err := verifyJWS(ctx, payload); err != nil {
- return nil, err
- }
- }
- }
- break OUTER
- case jwx.UnknownFormat:
- // "Unknown" may include invalid JWTs, for example, those who lack "aud"
- // claim. We could be pedantic and reject these
- if ctx.pedantic {
- return nil, errors.Errorf(`invalid JWT`)
- }
- if i == 0 {
- // We were NOT enveloped in other formats
- if !ctx.skipVerification {
- if _, _, err := verifyJWS(ctx, payload); err != nil {
- return nil, err
- }
- }
- }
- break OUTER
- case jwx.JWS:
- // Food for thought: This is going to break if you have multiple layers of
- // JWS enveloping using different keys. It is highly unlikely use case,
- // but it might happen.
- // skipVerification should only be set to true by us. It's used
- // when we just want to parse the JWT out of a payload
- if !ctx.skipVerification {
- // nested return value means:
- // false (next envelope _may_ need to be processed)
- // true (next envelope MUST be processed)
- v, state, err := verifyJWS(ctx, payload)
- if err != nil {
- return nil, err
- }
- if state != _JwsVerifySkipped {
- payload = v
- // We only check for cty and typ if the pedantic flag is enabled
- if !ctx.pedantic {
- continue
- }
- if state == _JwsVerifyExpectNested {
- expectNested = true
- continue OUTER
- }
- // if we're not nested, we found our target. bail out of this loop
- break OUTER
- }
- }
- // No verification.
- m, err := jws.Parse(data)
- if err != nil {
- return nil, errors.Wrap(err, `invalid jws message`)
- }
- payload = m.Payload()
- case jwx.JWE:
- dp := ctx.decryptParams
- if dp == nil {
- return nil, errors.Errorf(`jwt.Parse: cannot proceed with JWE encrypted payload without decryption parameters`)
- }
- var m *jwe.Message
- var decryptOpts []jwe.DecryptOption
- if ctx.pedantic {
- m = jwe.NewMessage()
- decryptOpts = []jwe.DecryptOption{jwe.WithMessage(m)}
- }
- v, err := jwe.Decrypt(data, dp.Algorithm(), dp.Key(), decryptOpts...)
- if err != nil {
- return nil, errors.Wrap(err, `failed to decrypt payload`)
- }
- if !ctx.pedantic {
- payload = v
- continue
- }
- if strings.ToLower(m.ProtectedHeaders().Type()) == _jwt {
- payload = v
- break OUTER
- }
- if strings.ToLower(m.ProtectedHeaders().ContentType()) == _jwt {
- expectNested = true
- payload = v
- continue OUTER
- }
- default:
- return nil, errors.Errorf(`unsupported format (layer: #%d)`, i+1)
- }
- expectNested = false
- }
- if ctx.token == nil {
- ctx.token = New()
- }
- if ctx.localReg != nil {
- dcToken, ok := ctx.token.(TokenWithDecodeCtx)
- if !ok {
- return nil, errors.Errorf(`typed claim was requested, but the token (%T) does not support DecodeCtx`, ctx.token)
- }
- dc := json.NewDecodeCtx(ctx.localReg)
- dcToken.SetDecodeCtx(dc)
- defer func() { dcToken.SetDecodeCtx(nil) }()
- }
- if err := json.Unmarshal(payload, ctx.token); err != nil {
- return nil, errors.Wrap(err, `failed to parse token`)
- }
- if ctx.validate {
- if err := Validate(ctx.token, ctx.validateOpts...); err != nil {
- return nil, err
- }
- }
- return ctx.token, nil
- }
- // Sign is a convenience function to create a signed JWT token serialized in
- // compact form.
- //
- // It accepts either a raw key (e.g. rsa.PrivateKey, ecdsa.PrivateKey, etc)
- // or a jwk.Key, and the name of the algorithm that should be used to sign
- // the token.
- //
- // If the key is a jwk.Key and the key contains a key ID (`kid` field),
- // then it is added to the protected header generated by the signature
- //
- // The algorithm specified in the `alg` parameter must be able to support
- // the type of key you provided, otherwise an error is returned.
- //
- // The protected header will also automatically have the `typ` field set
- // to the literal value `JWT`, unless you provide a custom value for it
- // by jwt.WithHeaders option.
- func Sign(t Token, alg jwa.SignatureAlgorithm, key interface{}, options ...SignOption) ([]byte, error) {
- return NewSerializer().Sign(alg, key, options...).Serialize(t)
- }
- // Equal compares two JWT tokens. Do not use `reflect.Equal` or the like
- // to compare tokens as they will also compare extra detail such as
- // sync.Mutex objects used to control concurrent access.
- //
- // The comparison for values is currently done using a simple equality ("=="),
- // except for time.Time, which uses time.Equal after dropping the monotonic
- // clock and truncating the values to 1 second accuracy.
- //
- // if both t1 and t2 are nil, returns true
- func Equal(t1, t2 Token) bool {
- if t1 == nil && t2 == nil {
- return true
- }
- // we already checked for t1 == t2 == nil, so safe to do this
- if t1 == nil || t2 == nil {
- return false
- }
- j1, err := json.Marshal(t1)
- if err != nil {
- return false
- }
- j2, err := json.Marshal(t2)
- if err != nil {
- return false
- }
- return bytes.Equal(j1, j2)
- }
- func (t *stdToken) Clone() (Token, error) {
- dst := New()
- for _, pair := range t.makePairs() {
- //nolint:forcetypeassert
- key := pair.Key.(string)
- if err := dst.Set(key, pair.Value); err != nil {
- return nil, errors.Wrapf(err, `failed to set %s`, key)
- }
- }
- return dst, nil
- }
- // RegisterCustomField allows users to specify that a private field
- // be decoded as an instance of the specified type. This option has
- // a global effect.
- //
- // For example, suppose you have a custom field `x-birthday`, which
- // you want to represent as a string formatted in RFC3339 in JSON,
- // but want it back as `time.Time`.
- //
- // In that case you would register a custom field as follows
- //
- // jwt.RegisterCustomField(`x-birthday`, timeT)
- //
- // Then `token.Get("x-birthday")` will still return an `interface{}`,
- // but you can convert its type to `time.Time`
- //
- // bdayif, _ := token.Get(`x-birthday`)
- // bday := bdayif.(time.Time)
- //
- func RegisterCustomField(name string, object interface{}) {
- registry.Register(name, object)
- }
|