token_gen.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. // This file is auto-generated by jwt/internal/cmd/gentoken/main.go. DO NOT EDIT
  2. package jwt
  3. import (
  4. "bytes"
  5. "context"
  6. "sort"
  7. "sync"
  8. "time"
  9. "github.com/lestrrat-go/iter/mapiter"
  10. "github.com/lestrrat-go/jwx/internal/base64"
  11. "github.com/lestrrat-go/jwx/internal/iter"
  12. "github.com/lestrrat-go/jwx/internal/json"
  13. "github.com/lestrrat-go/jwx/internal/pool"
  14. "github.com/lestrrat-go/jwx/jwt/internal/types"
  15. "github.com/pkg/errors"
  16. )
  17. const (
  18. AudienceKey = "aud"
  19. ExpirationKey = "exp"
  20. IssuedAtKey = "iat"
  21. IssuerKey = "iss"
  22. JwtIDKey = "jti"
  23. NotBeforeKey = "nbf"
  24. SubjectKey = "sub"
  25. )
  26. // Token represents a generic JWT token.
  27. // which are type-aware (to an extent). Other claims may be accessed via the `Get`/`Set`
  28. // methods but their types are not taken into consideration at all. If you have non-standard
  29. // claims that you must frequently access, consider creating accessors functions
  30. // like the following
  31. //
  32. // func SetFoo(tok jwt.Token) error
  33. // func GetFoo(tok jwt.Token) (*Customtyp, error)
  34. //
  35. // Embedding jwt.Token into another struct is not recommended, because
  36. // jwt.Token needs to handle private claims, and this really does not
  37. // work well when it is embedded in other structure
  38. type Token interface {
  39. // Audience returns the value for "aud" field of the token
  40. Audience() []string
  41. // Expiration returns the value for "exp" field of the token
  42. Expiration() time.Time
  43. // IssuedAt returns the value for "iat" field of the token
  44. IssuedAt() time.Time
  45. // Issuer returns the value for "iss" field of the token
  46. Issuer() string
  47. // JwtID returns the value for "jti" field of the token
  48. JwtID() string
  49. // NotBefore returns the value for "nbf" field of the token
  50. NotBefore() time.Time
  51. // Subject returns the value for "sub" field of the token
  52. Subject() string
  53. // PrivateClaims return the entire set of fields (claims) in the token
  54. // *other* than the pre-defined fields such as `iss`, `nbf`, `iat`, etc.
  55. PrivateClaims() map[string]interface{}
  56. // Get returns the value of the corresponding field in the token, such as
  57. // `nbf`, `exp`, `iat`, and other user-defined fields. If the field does not
  58. // exist in the token, the second return value will be `false`
  59. //
  60. // If you need to access fields like `alg`, `kid`, `jku`, etc, you need
  61. // to access the corresponding fields in the JWS/JWE message. For this,
  62. // you will need to access them by directly parsing the payload using
  63. // `jws.Parse` and `jwe.Parse`
  64. Get(string) (interface{}, bool)
  65. // Set assigns a value to the corresponding field in the token. Some
  66. // pre-defined fields such as `nbf`, `iat`, `iss` need their values to
  67. // be of a specific type. See the other getter methods in this interface
  68. // for the types of each of these fields
  69. Set(string, interface{}) error
  70. Remove(string) error
  71. Clone() (Token, error)
  72. Iterate(context.Context) Iterator
  73. Walk(context.Context, Visitor) error
  74. AsMap(context.Context) (map[string]interface{}, error)
  75. }
  76. type stdToken struct {
  77. mu *sync.RWMutex
  78. dc DecodeCtx // per-object context for decoding
  79. audience types.StringList // https://tools.ietf.org/html/rfc7519#section-4.1.3
  80. expiration *types.NumericDate // https://tools.ietf.org/html/rfc7519#section-4.1.4
  81. issuedAt *types.NumericDate // https://tools.ietf.org/html/rfc7519#section-4.1.6
  82. issuer *string // https://tools.ietf.org/html/rfc7519#section-4.1.1
  83. jwtID *string // https://tools.ietf.org/html/rfc7519#section-4.1.7
  84. notBefore *types.NumericDate // https://tools.ietf.org/html/rfc7519#section-4.1.5
  85. subject *string // https://tools.ietf.org/html/rfc7519#section-4.1.2
  86. privateClaims map[string]interface{}
  87. }
  88. // New creates a standard token, with minimal knowledge of
  89. // possible claims. Standard claims include"aud", "exp", "iat", "iss", "jti", "nbf" and "sub".
  90. // Convenience accessors are provided for these standard claims
  91. func New() Token {
  92. return &stdToken{
  93. mu: &sync.RWMutex{},
  94. privateClaims: make(map[string]interface{}),
  95. }
  96. }
  97. func (t *stdToken) Get(name string) (interface{}, bool) {
  98. t.mu.RLock()
  99. defer t.mu.RUnlock()
  100. switch name {
  101. case AudienceKey:
  102. if t.audience == nil {
  103. return nil, false
  104. }
  105. v := t.audience.Get()
  106. return v, true
  107. case ExpirationKey:
  108. if t.expiration == nil {
  109. return nil, false
  110. }
  111. v := t.expiration.Get()
  112. return v, true
  113. case IssuedAtKey:
  114. if t.issuedAt == nil {
  115. return nil, false
  116. }
  117. v := t.issuedAt.Get()
  118. return v, true
  119. case IssuerKey:
  120. if t.issuer == nil {
  121. return nil, false
  122. }
  123. v := *(t.issuer)
  124. return v, true
  125. case JwtIDKey:
  126. if t.jwtID == nil {
  127. return nil, false
  128. }
  129. v := *(t.jwtID)
  130. return v, true
  131. case NotBeforeKey:
  132. if t.notBefore == nil {
  133. return nil, false
  134. }
  135. v := t.notBefore.Get()
  136. return v, true
  137. case SubjectKey:
  138. if t.subject == nil {
  139. return nil, false
  140. }
  141. v := *(t.subject)
  142. return v, true
  143. default:
  144. v, ok := t.privateClaims[name]
  145. return v, ok
  146. }
  147. }
  148. func (t *stdToken) Remove(key string) error {
  149. t.mu.Lock()
  150. defer t.mu.Unlock()
  151. switch key {
  152. case AudienceKey:
  153. t.audience = nil
  154. case ExpirationKey:
  155. t.expiration = nil
  156. case IssuedAtKey:
  157. t.issuedAt = nil
  158. case IssuerKey:
  159. t.issuer = nil
  160. case JwtIDKey:
  161. t.jwtID = nil
  162. case NotBeforeKey:
  163. t.notBefore = nil
  164. case SubjectKey:
  165. t.subject = nil
  166. default:
  167. delete(t.privateClaims, key)
  168. }
  169. return nil
  170. }
  171. func (t *stdToken) Set(name string, value interface{}) error {
  172. t.mu.Lock()
  173. defer t.mu.Unlock()
  174. return t.setNoLock(name, value)
  175. }
  176. func (t *stdToken) DecodeCtx() DecodeCtx {
  177. t.mu.RLock()
  178. defer t.mu.RUnlock()
  179. return t.dc
  180. }
  181. func (t *stdToken) SetDecodeCtx(v DecodeCtx) {
  182. t.mu.Lock()
  183. defer t.mu.Unlock()
  184. t.dc = v
  185. }
  186. func (t *stdToken) setNoLock(name string, value interface{}) error {
  187. switch name {
  188. case AudienceKey:
  189. var acceptor types.StringList
  190. if err := acceptor.Accept(value); err != nil {
  191. return errors.Wrapf(err, `invalid value for %s key`, AudienceKey)
  192. }
  193. t.audience = acceptor
  194. return nil
  195. case ExpirationKey:
  196. var acceptor types.NumericDate
  197. if err := acceptor.Accept(value); err != nil {
  198. return errors.Wrapf(err, `invalid value for %s key`, ExpirationKey)
  199. }
  200. t.expiration = &acceptor
  201. return nil
  202. case IssuedAtKey:
  203. var acceptor types.NumericDate
  204. if err := acceptor.Accept(value); err != nil {
  205. return errors.Wrapf(err, `invalid value for %s key`, IssuedAtKey)
  206. }
  207. t.issuedAt = &acceptor
  208. return nil
  209. case IssuerKey:
  210. if v, ok := value.(string); ok {
  211. t.issuer = &v
  212. return nil
  213. }
  214. return errors.Errorf(`invalid value for %s key: %T`, IssuerKey, value)
  215. case JwtIDKey:
  216. if v, ok := value.(string); ok {
  217. t.jwtID = &v
  218. return nil
  219. }
  220. return errors.Errorf(`invalid value for %s key: %T`, JwtIDKey, value)
  221. case NotBeforeKey:
  222. var acceptor types.NumericDate
  223. if err := acceptor.Accept(value); err != nil {
  224. return errors.Wrapf(err, `invalid value for %s key`, NotBeforeKey)
  225. }
  226. t.notBefore = &acceptor
  227. return nil
  228. case SubjectKey:
  229. if v, ok := value.(string); ok {
  230. t.subject = &v
  231. return nil
  232. }
  233. return errors.Errorf(`invalid value for %s key: %T`, SubjectKey, value)
  234. default:
  235. if t.privateClaims == nil {
  236. t.privateClaims = map[string]interface{}{}
  237. }
  238. t.privateClaims[name] = value
  239. }
  240. return nil
  241. }
  242. func (t *stdToken) Audience() []string {
  243. t.mu.RLock()
  244. defer t.mu.RUnlock()
  245. if t.audience != nil {
  246. return t.audience.Get()
  247. }
  248. return nil
  249. }
  250. func (t *stdToken) Expiration() time.Time {
  251. t.mu.RLock()
  252. defer t.mu.RUnlock()
  253. if t.expiration != nil {
  254. return t.expiration.Get()
  255. }
  256. return time.Time{}
  257. }
  258. func (t *stdToken) IssuedAt() time.Time {
  259. t.mu.RLock()
  260. defer t.mu.RUnlock()
  261. if t.issuedAt != nil {
  262. return t.issuedAt.Get()
  263. }
  264. return time.Time{}
  265. }
  266. func (t *stdToken) Issuer() string {
  267. t.mu.RLock()
  268. defer t.mu.RUnlock()
  269. if t.issuer != nil {
  270. return *(t.issuer)
  271. }
  272. return ""
  273. }
  274. func (t *stdToken) JwtID() string {
  275. t.mu.RLock()
  276. defer t.mu.RUnlock()
  277. if t.jwtID != nil {
  278. return *(t.jwtID)
  279. }
  280. return ""
  281. }
  282. func (t *stdToken) NotBefore() time.Time {
  283. t.mu.RLock()
  284. defer t.mu.RUnlock()
  285. if t.notBefore != nil {
  286. return t.notBefore.Get()
  287. }
  288. return time.Time{}
  289. }
  290. func (t *stdToken) Subject() string {
  291. t.mu.RLock()
  292. defer t.mu.RUnlock()
  293. if t.subject != nil {
  294. return *(t.subject)
  295. }
  296. return ""
  297. }
  298. func (t *stdToken) PrivateClaims() map[string]interface{} {
  299. t.mu.RLock()
  300. defer t.mu.RUnlock()
  301. return t.privateClaims
  302. }
  303. func (t *stdToken) makePairs() []*ClaimPair {
  304. t.mu.RLock()
  305. defer t.mu.RUnlock()
  306. pairs := make([]*ClaimPair, 0, 7)
  307. if t.audience != nil {
  308. v := t.audience.Get()
  309. pairs = append(pairs, &ClaimPair{Key: AudienceKey, Value: v})
  310. }
  311. if t.expiration != nil {
  312. v := t.expiration.Get()
  313. pairs = append(pairs, &ClaimPair{Key: ExpirationKey, Value: v})
  314. }
  315. if t.issuedAt != nil {
  316. v := t.issuedAt.Get()
  317. pairs = append(pairs, &ClaimPair{Key: IssuedAtKey, Value: v})
  318. }
  319. if t.issuer != nil {
  320. v := *(t.issuer)
  321. pairs = append(pairs, &ClaimPair{Key: IssuerKey, Value: v})
  322. }
  323. if t.jwtID != nil {
  324. v := *(t.jwtID)
  325. pairs = append(pairs, &ClaimPair{Key: JwtIDKey, Value: v})
  326. }
  327. if t.notBefore != nil {
  328. v := t.notBefore.Get()
  329. pairs = append(pairs, &ClaimPair{Key: NotBeforeKey, Value: v})
  330. }
  331. if t.subject != nil {
  332. v := *(t.subject)
  333. pairs = append(pairs, &ClaimPair{Key: SubjectKey, Value: v})
  334. }
  335. for k, v := range t.privateClaims {
  336. pairs = append(pairs, &ClaimPair{Key: k, Value: v})
  337. }
  338. sort.Slice(pairs, func(i, j int) bool {
  339. return pairs[i].Key.(string) < pairs[j].Key.(string)
  340. })
  341. return pairs
  342. }
// UnmarshalJSON decodes a JSON object into the token while holding the
// write lock.
//
// Registered claims ("aud", "exp", "iat", "iss", "jti", "nbf", "sub") are
// reset first, so any that are absent from buf end up unset. The object is
// then read with a streaming decoder. Registered claims are decoded into
// their dedicated typed fields. Every other key is decoded through the
// per-object registry (when a DecodeCtx with a registry is set), falling
// back to the package-level registry, and is then stored as a private claim.
//
// NOTE(review): privateClaims is not cleared before decoding, so private
// claims already present on the token survive (same-named keys in buf
// overwrite them). Confirm this is intended.
func (t *stdToken) UnmarshalJSON(buf []byte) error {
	t.mu.Lock()
	defer t.mu.Unlock()
	// Drop previously set registered claims; see the function comment.
	t.audience = nil
	t.expiration = nil
	t.issuedAt = nil
	t.issuer = nil
	t.jwtID = nil
	t.notBefore = nil
	t.subject = nil
	dec := json.NewDecoder(bytes.NewReader(buf))
LOOP:
	for {
		// For a well-formed object, dec.Token yields only the delimiters and
		// the keys here; each value is consumed by the Decode/Assign call made
		// for its key below.
		tok, err := dec.Token()
		if err != nil {
			return errors.Wrap(err, `error reading token`)
		}
		switch tok := tok.(type) {
		case json.Delim:
			// Assuming we're doing everything correctly, we should ONLY
			// get either '{' or '}' here.
			if tok == '}' { // End of object
				break LOOP
			} else if tok != '{' {
				return errors.Errorf(`expected '{', but got '%c'`, tok)
			}
		case string: // Objects can only have string keys
			switch tok {
			case AudienceKey:
				var decoded types.StringList
				if err := dec.Decode(&decoded); err != nil {
					return errors.Wrapf(err, `failed to decode value for key %s`, AudienceKey)
				}
				t.audience = decoded
			case ExpirationKey:
				var decoded types.NumericDate
				if err := dec.Decode(&decoded); err != nil {
					return errors.Wrapf(err, `failed to decode value for key %s`, ExpirationKey)
				}
				t.expiration = &decoded
			case IssuedAtKey:
				var decoded types.NumericDate
				if err := dec.Decode(&decoded); err != nil {
					return errors.Wrapf(err, `failed to decode value for key %s`, IssuedAtKey)
				}
				t.issuedAt = &decoded
			case IssuerKey:
				if err := json.AssignNextStringToken(&t.issuer, dec); err != nil {
					return errors.Wrapf(err, `failed to decode value for key %s`, IssuerKey)
				}
			case JwtIDKey:
				if err := json.AssignNextStringToken(&t.jwtID, dec); err != nil {
					return errors.Wrapf(err, `failed to decode value for key %s`, JwtIDKey)
				}
			case NotBeforeKey:
				var decoded types.NumericDate
				if err := dec.Decode(&decoded); err != nil {
					return errors.Wrapf(err, `failed to decode value for key %s`, NotBeforeKey)
				}
				t.notBefore = &decoded
			case SubjectKey:
				if err := json.AssignNextStringToken(&t.subject, dec); err != nil {
					return errors.Wrapf(err, `failed to decode value for key %s`, SubjectKey)
				}
			default:
				// Private claim: try the per-object registry first, then the
				// package-level one.
				// NOTE(review): if localReg.Decode consumes input before it
				// fails, the fallback below would read from the wrong position
				// in the stream; verify Decode's failure behavior.
				if dc := t.dc; dc != nil {
					if localReg := dc.Registry(); localReg != nil {
						decoded, err := localReg.Decode(dec, tok)
						if err == nil {
							// The result of setNoLock is not checked: tok is not
							// a registered claim name in this branch, so setNoLock
							// takes its private-claim path, which returns nil.
							t.setNoLock(tok, decoded)
							continue
						}
					}
				}
				decoded, err := registry.Decode(dec, tok)
				if err == nil {
					t.setNoLock(tok, decoded)
					continue
				}
				return errors.Wrapf(err, `could not decode field %s`, tok)
			}
		default:
			return errors.Errorf(`invalid token %T`, tok)
		}
	}
	return nil
}
  430. func (t stdToken) MarshalJSON() ([]byte, error) {
  431. t.mu.RLock()
  432. defer t.mu.RUnlock()
  433. buf := pool.GetBytesBuffer()
  434. defer pool.ReleaseBytesBuffer(buf)
  435. buf.WriteByte('{')
  436. enc := json.NewEncoder(buf)
  437. for i, pair := range t.makePairs() {
  438. f := pair.Key.(string)
  439. if i > 0 {
  440. buf.WriteByte(',')
  441. }
  442. buf.WriteRune('"')
  443. buf.WriteString(f)
  444. buf.WriteString(`":`)
  445. switch f {
  446. case AudienceKey:
  447. if err := json.EncodeAudience(enc, pair.Value.([]string)); err != nil {
  448. return nil, errors.Wrap(err, `failed to encode "aud"`)
  449. }
  450. continue
  451. case ExpirationKey, IssuedAtKey, NotBeforeKey:
  452. enc.Encode(pair.Value.(time.Time).Unix())
  453. continue
  454. }
  455. switch v := pair.Value.(type) {
  456. case []byte:
  457. buf.WriteRune('"')
  458. buf.WriteString(base64.EncodeToString(v))
  459. buf.WriteRune('"')
  460. default:
  461. if err := enc.Encode(v); err != nil {
  462. return nil, errors.Wrapf(err, `failed to marshal field %s`, f)
  463. }
  464. buf.Truncate(buf.Len() - 1)
  465. }
  466. }
  467. buf.WriteByte('}')
  468. ret := make([]byte, buf.Len())
  469. copy(ret, buf.Bytes())
  470. return ret, nil
  471. }
  472. func (t *stdToken) Iterate(ctx context.Context) Iterator {
  473. pairs := t.makePairs()
  474. ch := make(chan *ClaimPair, len(pairs))
  475. go func(ctx context.Context, ch chan *ClaimPair, pairs []*ClaimPair) {
  476. defer close(ch)
  477. for _, pair := range pairs {
  478. select {
  479. case <-ctx.Done():
  480. return
  481. case ch <- pair:
  482. }
  483. }
  484. }(ctx, ch, pairs)
  485. return mapiter.New(ch)
  486. }
  487. func (t *stdToken) Walk(ctx context.Context, visitor Visitor) error {
  488. return iter.WalkMap(ctx, t, visitor)
  489. }
  490. func (t *stdToken) AsMap(ctx context.Context) (map[string]interface{}, error) {
  491. return iter.AsMap(ctx, t)
  492. }