| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385 |
- package jwt
- import (
- "context"
- "strconv"
- "time"
- "github.com/pkg/errors"
- )
// Clock is the source of "the current time" used during validation.
// Validate() uses a Clock backed by time.Now unless another one is
// supplied through its options.
type Clock interface {
	// Now returns the instant that should be treated as the present.
	Now() time.Time
}
// ClockFunc adapts a plain `func() time.Time` so that it satisfies the
// Clock interface.
type ClockFunc func() time.Time

// Now calls the wrapped function and returns its result unchanged.
func (f ClockFunc) Now() time.Time {
	now := f()
	return now
}
- func isSupportedTimeClaim(c string) error {
- switch c {
- case ExpirationKey, IssuedAtKey, NotBeforeKey:
- return nil
- }
- return NewValidationError(errors.Errorf(`unsupported time claim %s`, strconv.Quote(c)))
- }
- func timeClaim(t Token, clock Clock, c string) time.Time {
- switch c {
- case ExpirationKey:
- return t.Expiration()
- case IssuedAtKey:
- return t.IssuedAt()
- case NotBeforeKey:
- return t.NotBefore()
- case "":
- return clock.Now()
- }
- return time.Time{} // should *NEVER* reach here, but...
- }
- // Validate makes sure that the essential claims stand.
- //
- // See the various `WithXXX` functions for optional parameters
- // that can control the behavior of this method.
- func Validate(t Token, options ...ValidateOption) error {
- ctx := context.Background()
- var clock Clock = ClockFunc(time.Now)
- var skew time.Duration
- var validators = []Validator{
- IsIssuedAtValid(),
- IsExpirationValid(),
- IsNbfValid(),
- }
- for _, o := range options {
- //nolint:forcetypeassert
- switch o.Ident() {
- case identClock{}:
- clock = o.Value().(Clock)
- case identAcceptableSkew{}:
- skew = o.Value().(time.Duration)
- case identContext{}:
- ctx = o.Value().(context.Context)
- case identValidator{}:
- v := o.Value().(Validator)
- switch v := v.(type) {
- case *isInTimeRange:
- if v.c1 != "" {
- if err := isSupportedTimeClaim(v.c1); err != nil {
- return err
- }
- validators = append(validators, IsRequired(v.c1))
- }
- if v.c2 != "" {
- if err := isSupportedTimeClaim(v.c2); err != nil {
- return err
- }
- validators = append(validators, IsRequired(v.c2))
- }
- }
- validators = append(validators, v)
- }
- }
- ctx = SetValidationCtxSkew(ctx, skew)
- ctx = SetValidationCtxClock(ctx, clock)
- for _, v := range validators {
- if err := v.Validate(ctx, t); err != nil {
- return err
- }
- }
- return nil
- }
// isInTimeRange is a Validator that bounds the difference between two
// time values, (c1 - c2), by dur.
type isInTimeRange struct {
	// c1 and c2 name the time claims to compare. An empty name stands for
	// the current time taken from the validation clock.
	c1, c2 string
	// dur is the bound that (c1 - c2) is checked against.
	dur time.Duration
	// less selects the direction of the check: if true, c1 - c2 <= dur
	// must hold; otherwise c1 - c2 >= dur must hold. The acceptable skew
	// widens either bound.
	less bool
}
- // MaxDeltaIs implements the logic behind `WithMaxDelta()` option
- func MaxDeltaIs(c1, c2 string, dur time.Duration) Validator {
- return &isInTimeRange{
- c1: c1,
- c2: c2,
- dur: dur,
- less: true,
- }
- }
- // MinDeltaIs implements the logic behind `WithMinDelta()` option
- func MinDeltaIs(c1, c2 string, dur time.Duration) Validator {
- return &isInTimeRange{
- c1: c1,
- c2: c2,
- dur: dur,
- less: false,
- }
- }
- func (iitr *isInTimeRange) Validate(ctx context.Context, t Token) error {
- clock := ValidationCtxClock(ctx) // MUST be populated
- skew := ValidationCtxSkew(ctx) // MUST be populated
- // We don't check if the claims already exist, because we already did that
- // by piggybacking on `required` check.
- t1 := timeClaim(t, clock, iitr.c1).Truncate(time.Second)
- t2 := timeClaim(t, clock, iitr.c2).Truncate(time.Second)
- if iitr.less { // t1 - t2 <= iitr.dur
- // t1 - t2 < iitr.dur + skew
- if t1.Sub(t2) > iitr.dur+skew {
- return NewValidationError(errors.Errorf(`iitr between %s and %s exceeds %s (skew %s)`, iitr.c1, iitr.c2, iitr.dur, skew))
- }
- } else {
- if t1.Sub(t2) < iitr.dur-skew {
- return NewValidationError(errors.Errorf(`iitr between %s and %s is less than %s (skew %s)`, iitr.c1, iitr.c2, iitr.dur, skew))
- }
- }
- return nil
- }
// ValidationError is the error type produced when a token fails
// validation. It is an ordinary error that additionally carries an
// unexported marker method.
type ValidationError interface {
	error
	isValidationError()
}

// NewValidationError wraps err in a ValidationError. The message of the
// result is the message of err.
func NewValidationError(err error) ValidationError {
	wrapped := validationError{error: err}
	return &wrapped
}

// validationError is the generic ValidationError implementation; it
// embeds the underlying error.
type validationError struct {
	error
}

// isValidationError is a marker method with no behavior.
func (validationError) isValidationError() {}
- var errTokenExpired = NewValidationError(errors.New(`exp not satisfied`))
- var errInvalidIssuedAt = NewValidationError(errors.New(`iat not satisfied`))
- var errTokenNotYetValid = NewValidationError(errors.New(`nbf not satisfied`))
- // ErrTokenExpired returns the immutable error used when `exp` claim
- // is not satisfied
- func ErrTokenExpired() error {
- return errTokenExpired
- }
- // ErrInvalidIssuedAt returns the immutable error used when `iat` claim
- // is not satisfied
- func ErrInvalidIssuedAt() error {
- return errInvalidIssuedAt
- }
- func ErrTokenNotYetValid() error {
- return errTokenNotYetValid
- }
- // Validator describes interface to validate a Token.
- type Validator interface {
- // Validate should return an error if a required conditions is not met.
- // This method will be changed in the next major release to return
- // jwt.ValidationError instead of error to force users to return
- // a validation error even for user-specified validators
- Validate(context.Context, Token) error
- }
- // ValidatorFunc is a type of Validator that does not have any
- // state, that is implemented as a function
- type ValidatorFunc func(context.Context, Token) error
- func (vf ValidatorFunc) Validate(ctx context.Context, tok Token) error {
- return vf(ctx, tok)
- }
// Context keys under which the validation clock and the acceptable skew
// are stored. Distinct empty struct types keep the two keys from
// colliding with each other.
type (
	identValidationCtxClock struct{}
	identValidationCtxSkew  struct{}
)
- func SetValidationCtxClock(ctx context.Context, cl Clock) context.Context {
- return context.WithValue(ctx, identValidationCtxClock{}, cl)
- }
- // ValidationCtxClock returns the Clock object associated with
- // the current validation context. This value will always be available
- // during validation of tokens.
- func ValidationCtxClock(ctx context.Context) Clock {
- //nolint:forcetypeassert
- return ctx.Value(identValidationCtxClock{}).(Clock)
- }
- func SetValidationCtxSkew(ctx context.Context, dur time.Duration) context.Context {
- return context.WithValue(ctx, identValidationCtxSkew{}, dur)
- }
- func ValidationCtxSkew(ctx context.Context) time.Duration {
- //nolint:forcetypeassert
- return ctx.Value(identValidationCtxSkew{}).(time.Duration)
- }
- // IsExpirationValid is one of the default validators that will be executed.
- // It does not need to be specified by users, but it exists as an
- // exported field so that you can check what it does.
- //
- // The supplied context.Context object must have the "clock" and "skew"
- // populated with appropriate values using SetValidationCtxClock() and
- // SetValidationCtxSkew()
- func IsExpirationValid() Validator {
- return ValidatorFunc(isExpirationValid)
- }
- func isExpirationValid(ctx context.Context, t Token) error {
- if tv := t.Expiration(); !tv.IsZero() && tv.Unix() != 0 {
- clock := ValidationCtxClock(ctx) // MUST be populated
- now := clock.Now().Truncate(time.Second)
- ttv := tv.Truncate(time.Second)
- skew := ValidationCtxSkew(ctx) // MUST be populated
- if !now.Before(ttv.Add(skew)) {
- return ErrTokenExpired()
- }
- }
- return nil
- }
- // IsIssuedAtValid is one of the default validators that will be executed.
- // It does not need to be specified by users, but it exists as an
- // exported field so that you can check what it does.
- //
- // The supplied context.Context object must have the "clock" and "skew"
- // populated with appropriate values using SetValidationCtxClock() and
- // SetValidationCtxSkew()
- func IsIssuedAtValid() Validator {
- return ValidatorFunc(isIssuedAtValid)
- }
- func isIssuedAtValid(ctx context.Context, t Token) error {
- if tv := t.IssuedAt(); !tv.IsZero() && tv.Unix() != 0 {
- clock := ValidationCtxClock(ctx) // MUST be populated
- now := clock.Now().Truncate(time.Second)
- ttv := tv.Truncate(time.Second)
- skew := ValidationCtxSkew(ctx) // MUST be populated
- if now.Before(ttv.Add(-1 * skew)) {
- return ErrInvalidIssuedAt()
- }
- }
- return nil
- }
- // IsNbfValid is one of the default validators that will be executed.
- // It does not need to be specified by users, but it exists as an
- // exported field so that you can check what it does.
- //
- // The supplied context.Context object must have the "clock" and "skew"
- // populated with appropriate values using SetValidationCtxClock() and
- // SetValidationCtxSkew()
- func IsNbfValid() Validator {
- return ValidatorFunc(isNbfValid)
- }
- func isNbfValid(ctx context.Context, t Token) error {
- if tv := t.NotBefore(); !tv.IsZero() && tv.Unix() != 0 {
- clock := ValidationCtxClock(ctx) // MUST be populated
- now := clock.Now().Truncate(time.Second)
- ttv := tv.Truncate(time.Second)
- skew := ValidationCtxSkew(ctx) // MUST be populated
- // now cannot be before t, so we check for now > t - skew
- if !now.Equal(ttv) && !now.After(ttv.Add(-1*skew)) {
- return ErrTokenNotYetValid()
- }
- }
- return nil
- }
// claimContainsString is a Validator that requires the list-valued claim
// called name to contain value.
type claimContainsString struct {
	name, value string
}
- // ClaimContainsString can be used to check if the claim called `name`, which is
- // expected to be a list of strings, contains `value`. Currently because of the
- // implementation this will probably only work for `aud` fields.
- func ClaimContainsString(name, value string) Validator {
- return claimContainsString{
- name: name,
- value: value,
- }
- }
- // IsValidationError returns true if the error is a validation error
- func IsValidationError(err error) bool {
- switch err {
- case errTokenExpired, errTokenNotYetValid, errInvalidIssuedAt:
- return true
- default:
- switch err.(type) {
- case *validationError:
- return true
- default:
- return false
- }
- }
- }
- func (ccs claimContainsString) Validate(_ context.Context, t Token) error {
- v, ok := t.Get(ccs.name)
- if !ok {
- return NewValidationError(errors.Errorf(`claim %q not found`, ccs.name))
- }
- list, ok := v.([]string)
- if !ok {
- return NewValidationError(errors.Errorf(`claim %q must be a []string (got %T)`, ccs.name, v))
- }
- var found bool
- for _, v := range list {
- if v == ccs.value {
- found = true
- break
- }
- }
- if !found {
- return NewValidationError(errors.Errorf(`%s not satisfied`, ccs.name))
- }
- return nil
- }
// claimValueIs is a Validator that requires the claim called name to be
// equal to value.
type claimValueIs struct {
	// name is the claim to look up in the token.
	name string
	// value is what the claim is compared against.
	value interface{}
}
- // ClaimValueIs creates a Validator that checks if the value of claim `name`
- // matches `value`. The comparison is done using a simple `==` comparison,
- // and therefore complex comparisons may fail using this code. If you
- // need to do more, use a custom Validator.
- func ClaimValueIs(name string, value interface{}) Validator {
- return &claimValueIs{name: name, value: value}
- }
- func (cv *claimValueIs) Validate(_ context.Context, t Token) error {
- v, ok := t.Get(cv.name)
- if !ok {
- return NewValidationError(errors.Errorf(`%q not satisfied: claim %q does not exist`, cv.name, cv.name))
- }
- if v != cv.value {
- return NewValidationError(errors.Errorf(`%q not satisfied: values do not match`, cv.name))
- }
- return nil
- }
- // IsRequired creates a Validator that checks if the required claim `name`
- // exists in the token
- func IsRequired(name string) Validator {
- return isRequired(name)
- }
- type isRequired string
- func (ir isRequired) Validate(_ context.Context, t Token) error {
- _, ok := t.Get(string(ir))
- if !ok {
- return NewValidationError(errors.Errorf(`required claim %q was not found`, string(ir)))
- }
- return nil
- }
|