validate.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. package jwt
  2. import (
  3. "context"
  4. "strconv"
  5. "time"
  6. "github.com/pkg/errors"
  7. )
  8. type Clock interface {
  9. Now() time.Time
  10. }
  11. type ClockFunc func() time.Time
  12. func (f ClockFunc) Now() time.Time {
  13. return f()
  14. }
  15. func isSupportedTimeClaim(c string) error {
  16. switch c {
  17. case ExpirationKey, IssuedAtKey, NotBeforeKey:
  18. return nil
  19. }
  20. return NewValidationError(errors.Errorf(`unsupported time claim %s`, strconv.Quote(c)))
  21. }
  22. func timeClaim(t Token, clock Clock, c string) time.Time {
  23. switch c {
  24. case ExpirationKey:
  25. return t.Expiration()
  26. case IssuedAtKey:
  27. return t.IssuedAt()
  28. case NotBeforeKey:
  29. return t.NotBefore()
  30. case "":
  31. return clock.Now()
  32. }
  33. return time.Time{} // should *NEVER* reach here, but...
  34. }
  35. // Validate makes sure that the essential claims stand.
  36. //
  37. // See the various `WithXXX` functions for optional parameters
  38. // that can control the behavior of this method.
  39. func Validate(t Token, options ...ValidateOption) error {
  40. ctx := context.Background()
  41. var clock Clock = ClockFunc(time.Now)
  42. var skew time.Duration
  43. var validators = []Validator{
  44. IsIssuedAtValid(),
  45. IsExpirationValid(),
  46. IsNbfValid(),
  47. }
  48. for _, o := range options {
  49. //nolint:forcetypeassert
  50. switch o.Ident() {
  51. case identClock{}:
  52. clock = o.Value().(Clock)
  53. case identAcceptableSkew{}:
  54. skew = o.Value().(time.Duration)
  55. case identContext{}:
  56. ctx = o.Value().(context.Context)
  57. case identValidator{}:
  58. v := o.Value().(Validator)
  59. switch v := v.(type) {
  60. case *isInTimeRange:
  61. if v.c1 != "" {
  62. if err := isSupportedTimeClaim(v.c1); err != nil {
  63. return err
  64. }
  65. validators = append(validators, IsRequired(v.c1))
  66. }
  67. if v.c2 != "" {
  68. if err := isSupportedTimeClaim(v.c2); err != nil {
  69. return err
  70. }
  71. validators = append(validators, IsRequired(v.c2))
  72. }
  73. }
  74. validators = append(validators, v)
  75. }
  76. }
  77. ctx = SetValidationCtxSkew(ctx, skew)
  78. ctx = SetValidationCtxClock(ctx, clock)
  79. for _, v := range validators {
  80. if err := v.Validate(ctx, t); err != nil {
  81. return err
  82. }
  83. }
  84. return nil
  85. }
  86. type isInTimeRange struct {
  87. c1 string
  88. c2 string
  89. dur time.Duration
  90. less bool // if true, d =< c1 - c2. otherwise d >= c1 - c2
  91. }
  92. // MaxDeltaIs implements the logic behind `WithMaxDelta()` option
  93. func MaxDeltaIs(c1, c2 string, dur time.Duration) Validator {
  94. return &isInTimeRange{
  95. c1: c1,
  96. c2: c2,
  97. dur: dur,
  98. less: true,
  99. }
  100. }
  101. // MinDeltaIs implements the logic behind `WithMinDelta()` option
  102. func MinDeltaIs(c1, c2 string, dur time.Duration) Validator {
  103. return &isInTimeRange{
  104. c1: c1,
  105. c2: c2,
  106. dur: dur,
  107. less: false,
  108. }
  109. }
  110. func (iitr *isInTimeRange) Validate(ctx context.Context, t Token) error {
  111. clock := ValidationCtxClock(ctx) // MUST be populated
  112. skew := ValidationCtxSkew(ctx) // MUST be populated
  113. // We don't check if the claims already exist, because we already did that
  114. // by piggybacking on `required` check.
  115. t1 := timeClaim(t, clock, iitr.c1).Truncate(time.Second)
  116. t2 := timeClaim(t, clock, iitr.c2).Truncate(time.Second)
  117. if iitr.less { // t1 - t2 <= iitr.dur
  118. // t1 - t2 < iitr.dur + skew
  119. if t1.Sub(t2) > iitr.dur+skew {
  120. return NewValidationError(errors.Errorf(`iitr between %s and %s exceeds %s (skew %s)`, iitr.c1, iitr.c2, iitr.dur, skew))
  121. }
  122. } else {
  123. if t1.Sub(t2) < iitr.dur-skew {
  124. return NewValidationError(errors.Errorf(`iitr between %s and %s is less than %s (skew %s)`, iitr.c1, iitr.c2, iitr.dur, skew))
  125. }
  126. }
  127. return nil
  128. }
  129. type ValidationError interface {
  130. error
  131. isValidationError()
  132. }
  133. func NewValidationError(err error) ValidationError {
  134. return &validationError{error: err}
  135. }
  136. // This is a generic validation error.
  137. type validationError struct {
  138. error
  139. }
  140. func (validationError) isValidationError() {}
  141. var errTokenExpired = NewValidationError(errors.New(`exp not satisfied`))
  142. var errInvalidIssuedAt = NewValidationError(errors.New(`iat not satisfied`))
  143. var errTokenNotYetValid = NewValidationError(errors.New(`nbf not satisfied`))
  144. // ErrTokenExpired returns the immutable error used when `exp` claim
  145. // is not satisfied
  146. func ErrTokenExpired() error {
  147. return errTokenExpired
  148. }
  149. // ErrInvalidIssuedAt returns the immutable error used when `iat` claim
  150. // is not satisfied
  151. func ErrInvalidIssuedAt() error {
  152. return errInvalidIssuedAt
  153. }
  154. func ErrTokenNotYetValid() error {
  155. return errTokenNotYetValid
  156. }
  157. // Validator describes interface to validate a Token.
  158. type Validator interface {
  159. // Validate should return an error if a required conditions is not met.
  160. // This method will be changed in the next major release to return
  161. // jwt.ValidationError instead of error to force users to return
  162. // a validation error even for user-specified validators
  163. Validate(context.Context, Token) error
  164. }
  165. // ValidatorFunc is a type of Validator that does not have any
  166. // state, that is implemented as a function
  167. type ValidatorFunc func(context.Context, Token) error
  168. func (vf ValidatorFunc) Validate(ctx context.Context, tok Token) error {
  169. return vf(ctx, tok)
  170. }
  171. type identValidationCtxClock struct{}
  172. type identValidationCtxSkew struct{}
  173. func SetValidationCtxClock(ctx context.Context, cl Clock) context.Context {
  174. return context.WithValue(ctx, identValidationCtxClock{}, cl)
  175. }
  176. // ValidationCtxClock returns the Clock object associated with
  177. // the current validation context. This value will always be available
  178. // during validation of tokens.
  179. func ValidationCtxClock(ctx context.Context) Clock {
  180. //nolint:forcetypeassert
  181. return ctx.Value(identValidationCtxClock{}).(Clock)
  182. }
  183. func SetValidationCtxSkew(ctx context.Context, dur time.Duration) context.Context {
  184. return context.WithValue(ctx, identValidationCtxSkew{}, dur)
  185. }
  186. func ValidationCtxSkew(ctx context.Context) time.Duration {
  187. //nolint:forcetypeassert
  188. return ctx.Value(identValidationCtxSkew{}).(time.Duration)
  189. }
  190. // IsExpirationValid is one of the default validators that will be executed.
  191. // It does not need to be specified by users, but it exists as an
  192. // exported field so that you can check what it does.
  193. //
  194. // The supplied context.Context object must have the "clock" and "skew"
  195. // populated with appropriate values using SetValidationCtxClock() and
  196. // SetValidationCtxSkew()
  197. func IsExpirationValid() Validator {
  198. return ValidatorFunc(isExpirationValid)
  199. }
  200. func isExpirationValid(ctx context.Context, t Token) error {
  201. if tv := t.Expiration(); !tv.IsZero() && tv.Unix() != 0 {
  202. clock := ValidationCtxClock(ctx) // MUST be populated
  203. now := clock.Now().Truncate(time.Second)
  204. ttv := tv.Truncate(time.Second)
  205. skew := ValidationCtxSkew(ctx) // MUST be populated
  206. if !now.Before(ttv.Add(skew)) {
  207. return ErrTokenExpired()
  208. }
  209. }
  210. return nil
  211. }
  212. // IsIssuedAtValid is one of the default validators that will be executed.
  213. // It does not need to be specified by users, but it exists as an
  214. // exported field so that you can check what it does.
  215. //
  216. // The supplied context.Context object must have the "clock" and "skew"
  217. // populated with appropriate values using SetValidationCtxClock() and
  218. // SetValidationCtxSkew()
  219. func IsIssuedAtValid() Validator {
  220. return ValidatorFunc(isIssuedAtValid)
  221. }
  222. func isIssuedAtValid(ctx context.Context, t Token) error {
  223. if tv := t.IssuedAt(); !tv.IsZero() && tv.Unix() != 0 {
  224. clock := ValidationCtxClock(ctx) // MUST be populated
  225. now := clock.Now().Truncate(time.Second)
  226. ttv := tv.Truncate(time.Second)
  227. skew := ValidationCtxSkew(ctx) // MUST be populated
  228. if now.Before(ttv.Add(-1 * skew)) {
  229. return ErrInvalidIssuedAt()
  230. }
  231. }
  232. return nil
  233. }
  234. // IsNbfValid is one of the default validators that will be executed.
  235. // It does not need to be specified by users, but it exists as an
  236. // exported field so that you can check what it does.
  237. //
  238. // The supplied context.Context object must have the "clock" and "skew"
  239. // populated with appropriate values using SetValidationCtxClock() and
  240. // SetValidationCtxSkew()
  241. func IsNbfValid() Validator {
  242. return ValidatorFunc(isNbfValid)
  243. }
  244. func isNbfValid(ctx context.Context, t Token) error {
  245. if tv := t.NotBefore(); !tv.IsZero() && tv.Unix() != 0 {
  246. clock := ValidationCtxClock(ctx) // MUST be populated
  247. now := clock.Now().Truncate(time.Second)
  248. ttv := tv.Truncate(time.Second)
  249. skew := ValidationCtxSkew(ctx) // MUST be populated
  250. // now cannot be before t, so we check for now > t - skew
  251. if !now.Equal(ttv) && !now.After(ttv.Add(-1*skew)) {
  252. return ErrTokenNotYetValid()
  253. }
  254. }
  255. return nil
  256. }
  257. type claimContainsString struct {
  258. name string
  259. value string
  260. }
  261. // ClaimContainsString can be used to check if the claim called `name`, which is
  262. // expected to be a list of strings, contains `value`. Currently because of the
  263. // implementation this will probably only work for `aud` fields.
  264. func ClaimContainsString(name, value string) Validator {
  265. return claimContainsString{
  266. name: name,
  267. value: value,
  268. }
  269. }
  270. // IsValidationError returns true if the error is a validation error
  271. func IsValidationError(err error) bool {
  272. switch err {
  273. case errTokenExpired, errTokenNotYetValid, errInvalidIssuedAt:
  274. return true
  275. default:
  276. switch err.(type) {
  277. case *validationError:
  278. return true
  279. default:
  280. return false
  281. }
  282. }
  283. }
  284. func (ccs claimContainsString) Validate(_ context.Context, t Token) error {
  285. v, ok := t.Get(ccs.name)
  286. if !ok {
  287. return NewValidationError(errors.Errorf(`claim %q not found`, ccs.name))
  288. }
  289. list, ok := v.([]string)
  290. if !ok {
  291. return NewValidationError(errors.Errorf(`claim %q must be a []string (got %T)`, ccs.name, v))
  292. }
  293. var found bool
  294. for _, v := range list {
  295. if v == ccs.value {
  296. found = true
  297. break
  298. }
  299. }
  300. if !found {
  301. return NewValidationError(errors.Errorf(`%s not satisfied`, ccs.name))
  302. }
  303. return nil
  304. }
  305. type claimValueIs struct {
  306. name string
  307. value interface{}
  308. }
  309. // ClaimValueIs creates a Validator that checks if the value of claim `name`
  310. // matches `value`. The comparison is done using a simple `==` comparison,
  311. // and therefore complex comparisons may fail using this code. If you
  312. // need to do more, use a custom Validator.
  313. func ClaimValueIs(name string, value interface{}) Validator {
  314. return &claimValueIs{name: name, value: value}
  315. }
  316. func (cv *claimValueIs) Validate(_ context.Context, t Token) error {
  317. v, ok := t.Get(cv.name)
  318. if !ok {
  319. return NewValidationError(errors.Errorf(`%q not satisfied: claim %q does not exist`, cv.name, cv.name))
  320. }
  321. if v != cv.value {
  322. return NewValidationError(errors.Errorf(`%q not satisfied: values do not match`, cv.name))
  323. }
  324. return nil
  325. }
  326. // IsRequired creates a Validator that checks if the required claim `name`
  327. // exists in the token
  328. func IsRequired(name string) Validator {
  329. return isRequired(name)
  330. }
  331. type isRequired string
  332. func (ir isRequired) Validate(_ context.Context, t Token) error {
  333. _, ok := t.Get(string(ir))
  334. if !ok {
  335. return NewValidationError(errors.Errorf(`required claim %q was not found`, string(ir)))
  336. }
  337. return nil
  338. }