| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529 |
- // This file is auto-generated by jwt/internal/cmd/gentoken/main.go. DO NOT EDIT
- package jwt
- import (
- "bytes"
- "context"
- "sort"
- "sync"
- "time"
- "github.com/lestrrat-go/iter/mapiter"
- "github.com/lestrrat-go/jwx/internal/base64"
- "github.com/lestrrat-go/jwx/internal/iter"
- "github.com/lestrrat-go/jwx/internal/json"
- "github.com/lestrrat-go/jwx/internal/pool"
- "github.com/lestrrat-go/jwx/jwt/internal/types"
- "github.com/pkg/errors"
- )
// Names of the registered claims (RFC 7519 section 4.1) that the standard
// token stores in dedicated, type-checked fields.
const (
	AudienceKey   = "aud" // RFC 7519 section 4.1.3
	ExpirationKey = "exp" // RFC 7519 section 4.1.4
	IssuedAtKey   = "iat" // RFC 7519 section 4.1.6
	IssuerKey     = "iss" // RFC 7519 section 4.1.1
	JwtIDKey      = "jti" // RFC 7519 section 4.1.7
	NotBeforeKey  = "nbf" // RFC 7519 section 4.1.5
	SubjectKey    = "sub" // RFC 7519 section 4.1.2
)
// Token represents a generic JWT token.
//
// It provides accessors for the registered claims "aud", "exp", "iat",
// "iss", "jti", "nbf" and "sub", which are type-aware (to an extent).
// Other claims may be accessed via the `Get`/`Set` methods, but their
// types are not taken into consideration at all. If you have non-standard
// claims that you must frequently access, consider creating accessor
// functions like the following:
//
//   func SetFoo(tok jwt.Token, v *CustomType) error
//   func GetFoo(tok jwt.Token) (*CustomType, error)
//
// Embedding jwt.Token into another struct is not recommended, because
// jwt.Token needs to handle private claims, and this really does not
// work well when it is embedded in other structures.
type Token interface {
	// Audience returns the value for the "aud" field of the token
	Audience() []string
	// Expiration returns the value for the "exp" field of the token
	Expiration() time.Time
	// IssuedAt returns the value for the "iat" field of the token
	IssuedAt() time.Time
	// Issuer returns the value for the "iss" field of the token
	Issuer() string
	// JwtID returns the value for the "jti" field of the token
	JwtID() string
	// NotBefore returns the value for the "nbf" field of the token
	NotBefore() time.Time
	// Subject returns the value for the "sub" field of the token
	Subject() string
	// PrivateClaims returns the entire set of fields (claims) in the token
	// *other* than the pre-defined fields such as `iss`, `nbf`, `iat`, etc.
	PrivateClaims() map[string]interface{}
	// Get returns the value of the corresponding field in the token, such as
	// `nbf`, `exp`, `iat`, and other user-defined fields. If the field does not
	// exist in the token, the second return value will be `false`
	//
	// If you need to access fields like `alg`, `kid`, `jku`, etc, you need
	// to access the corresponding fields in the JWS/JWE message. For this,
	// you will need to access them by directly parsing the payload using
	// `jws.Parse` and `jwe.Parse`
	Get(string) (interface{}, bool)
	// Set assigns a value to the corresponding field in the token. Some
	// pre-defined fields such as `nbf`, `iat`, `iss` need their values to
	// be of a specific type. See the other getter methods in this interface
	// for the types of each of these fields
	Set(string, interface{}) error
	// Remove deletes the named field (claim) from the token.
	Remove(string) error
	// Clone returns a copy of the token.
	// NOTE(review): the implementation is not visible here; confirm how
	// deep the copy goes (e.g. whether private claim values are shared).
	Clone() (Token, error)
	// Iterate returns an Iterator over the token's claims as key/value pairs.
	Iterate(context.Context) Iterator
	// Walk passes the token's claims to the given Visitor.
	Walk(context.Context, Visitor) error
	// AsMap returns the token's claims as a map keyed by claim name.
	AsMap(context.Context) (map[string]interface{}, error)
}
// stdToken is the standard Token implementation returned by New.
// Registered claims live in typed fields, where nil means "not set".
// Every other claim lives in privateClaims. The accessor methods read and
// write these fields while holding mu.
type stdToken struct {
	mu            *sync.RWMutex          // held by pointer: copies of the struct share the same lock
	dc            DecodeCtx              // per-object context for decoding
	audience      types.StringList       // https://tools.ietf.org/html/rfc7519#section-4.1.3
	expiration    *types.NumericDate     // https://tools.ietf.org/html/rfc7519#section-4.1.4
	issuedAt      *types.NumericDate     // https://tools.ietf.org/html/rfc7519#section-4.1.6
	issuer        *string                // https://tools.ietf.org/html/rfc7519#section-4.1.1
	jwtID         *string                // https://tools.ietf.org/html/rfc7519#section-4.1.7
	notBefore     *types.NumericDate     // https://tools.ietf.org/html/rfc7519#section-4.1.5
	subject       *string                // https://tools.ietf.org/html/rfc7519#section-4.1.2
	privateClaims map[string]interface{} // all claims other than the registered ones above
}
- // New creates a standard token, with minimal knowledge of
- // possible claims. Standard claims include"aud", "exp", "iat", "iss", "jti", "nbf" and "sub".
- // Convenience accessors are provided for these standard claims
- func New() Token {
- return &stdToken{
- mu: &sync.RWMutex{},
- privateClaims: make(map[string]interface{}),
- }
- }
- func (t *stdToken) Get(name string) (interface{}, bool) {
- t.mu.RLock()
- defer t.mu.RUnlock()
- switch name {
- case AudienceKey:
- if t.audience == nil {
- return nil, false
- }
- v := t.audience.Get()
- return v, true
- case ExpirationKey:
- if t.expiration == nil {
- return nil, false
- }
- v := t.expiration.Get()
- return v, true
- case IssuedAtKey:
- if t.issuedAt == nil {
- return nil, false
- }
- v := t.issuedAt.Get()
- return v, true
- case IssuerKey:
- if t.issuer == nil {
- return nil, false
- }
- v := *(t.issuer)
- return v, true
- case JwtIDKey:
- if t.jwtID == nil {
- return nil, false
- }
- v := *(t.jwtID)
- return v, true
- case NotBeforeKey:
- if t.notBefore == nil {
- return nil, false
- }
- v := t.notBefore.Get()
- return v, true
- case SubjectKey:
- if t.subject == nil {
- return nil, false
- }
- v := *(t.subject)
- return v, true
- default:
- v, ok := t.privateClaims[name]
- return v, ok
- }
- }
- func (t *stdToken) Remove(key string) error {
- t.mu.Lock()
- defer t.mu.Unlock()
- switch key {
- case AudienceKey:
- t.audience = nil
- case ExpirationKey:
- t.expiration = nil
- case IssuedAtKey:
- t.issuedAt = nil
- case IssuerKey:
- t.issuer = nil
- case JwtIDKey:
- t.jwtID = nil
- case NotBeforeKey:
- t.notBefore = nil
- case SubjectKey:
- t.subject = nil
- default:
- delete(t.privateClaims, key)
- }
- return nil
- }
- func (t *stdToken) Set(name string, value interface{}) error {
- t.mu.Lock()
- defer t.mu.Unlock()
- return t.setNoLock(name, value)
- }
- func (t *stdToken) DecodeCtx() DecodeCtx {
- t.mu.RLock()
- defer t.mu.RUnlock()
- return t.dc
- }
- func (t *stdToken) SetDecodeCtx(v DecodeCtx) {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.dc = v
- }
- func (t *stdToken) setNoLock(name string, value interface{}) error {
- switch name {
- case AudienceKey:
- var acceptor types.StringList
- if err := acceptor.Accept(value); err != nil {
- return errors.Wrapf(err, `invalid value for %s key`, AudienceKey)
- }
- t.audience = acceptor
- return nil
- case ExpirationKey:
- var acceptor types.NumericDate
- if err := acceptor.Accept(value); err != nil {
- return errors.Wrapf(err, `invalid value for %s key`, ExpirationKey)
- }
- t.expiration = &acceptor
- return nil
- case IssuedAtKey:
- var acceptor types.NumericDate
- if err := acceptor.Accept(value); err != nil {
- return errors.Wrapf(err, `invalid value for %s key`, IssuedAtKey)
- }
- t.issuedAt = &acceptor
- return nil
- case IssuerKey:
- if v, ok := value.(string); ok {
- t.issuer = &v
- return nil
- }
- return errors.Errorf(`invalid value for %s key: %T`, IssuerKey, value)
- case JwtIDKey:
- if v, ok := value.(string); ok {
- t.jwtID = &v
- return nil
- }
- return errors.Errorf(`invalid value for %s key: %T`, JwtIDKey, value)
- case NotBeforeKey:
- var acceptor types.NumericDate
- if err := acceptor.Accept(value); err != nil {
- return errors.Wrapf(err, `invalid value for %s key`, NotBeforeKey)
- }
- t.notBefore = &acceptor
- return nil
- case SubjectKey:
- if v, ok := value.(string); ok {
- t.subject = &v
- return nil
- }
- return errors.Errorf(`invalid value for %s key: %T`, SubjectKey, value)
- default:
- if t.privateClaims == nil {
- t.privateClaims = map[string]interface{}{}
- }
- t.privateClaims[name] = value
- }
- return nil
- }
- func (t *stdToken) Audience() []string {
- t.mu.RLock()
- defer t.mu.RUnlock()
- if t.audience != nil {
- return t.audience.Get()
- }
- return nil
- }
- func (t *stdToken) Expiration() time.Time {
- t.mu.RLock()
- defer t.mu.RUnlock()
- if t.expiration != nil {
- return t.expiration.Get()
- }
- return time.Time{}
- }
- func (t *stdToken) IssuedAt() time.Time {
- t.mu.RLock()
- defer t.mu.RUnlock()
- if t.issuedAt != nil {
- return t.issuedAt.Get()
- }
- return time.Time{}
- }
- func (t *stdToken) Issuer() string {
- t.mu.RLock()
- defer t.mu.RUnlock()
- if t.issuer != nil {
- return *(t.issuer)
- }
- return ""
- }
- func (t *stdToken) JwtID() string {
- t.mu.RLock()
- defer t.mu.RUnlock()
- if t.jwtID != nil {
- return *(t.jwtID)
- }
- return ""
- }
- func (t *stdToken) NotBefore() time.Time {
- t.mu.RLock()
- defer t.mu.RUnlock()
- if t.notBefore != nil {
- return t.notBefore.Get()
- }
- return time.Time{}
- }
- func (t *stdToken) Subject() string {
- t.mu.RLock()
- defer t.mu.RUnlock()
- if t.subject != nil {
- return *(t.subject)
- }
- return ""
- }
- func (t *stdToken) PrivateClaims() map[string]interface{} {
- t.mu.RLock()
- defer t.mu.RUnlock()
- return t.privateClaims
- }
- func (t *stdToken) makePairs() []*ClaimPair {
- t.mu.RLock()
- defer t.mu.RUnlock()
- pairs := make([]*ClaimPair, 0, 7)
- if t.audience != nil {
- v := t.audience.Get()
- pairs = append(pairs, &ClaimPair{Key: AudienceKey, Value: v})
- }
- if t.expiration != nil {
- v := t.expiration.Get()
- pairs = append(pairs, &ClaimPair{Key: ExpirationKey, Value: v})
- }
- if t.issuedAt != nil {
- v := t.issuedAt.Get()
- pairs = append(pairs, &ClaimPair{Key: IssuedAtKey, Value: v})
- }
- if t.issuer != nil {
- v := *(t.issuer)
- pairs = append(pairs, &ClaimPair{Key: IssuerKey, Value: v})
- }
- if t.jwtID != nil {
- v := *(t.jwtID)
- pairs = append(pairs, &ClaimPair{Key: JwtIDKey, Value: v})
- }
- if t.notBefore != nil {
- v := t.notBefore.Get()
- pairs = append(pairs, &ClaimPair{Key: NotBeforeKey, Value: v})
- }
- if t.subject != nil {
- v := *(t.subject)
- pairs = append(pairs, &ClaimPair{Key: SubjectKey, Value: v})
- }
- for k, v := range t.privateClaims {
- pairs = append(pairs, &ClaimPair{Key: k, Value: v})
- }
- sort.Slice(pairs, func(i, j int) bool {
- return pairs[i].Key.(string) < pairs[j].Key.(string)
- })
- return pairs
- }
- func (t *stdToken) UnmarshalJSON(buf []byte) error {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.audience = nil
- t.expiration = nil
- t.issuedAt = nil
- t.issuer = nil
- t.jwtID = nil
- t.notBefore = nil
- t.subject = nil
- dec := json.NewDecoder(bytes.NewReader(buf))
- LOOP:
- for {
- tok, err := dec.Token()
- if err != nil {
- return errors.Wrap(err, `error reading token`)
- }
- switch tok := tok.(type) {
- case json.Delim:
- // Assuming we're doing everything correctly, we should ONLY
- // get either '{' or '}' here.
- if tok == '}' { // End of object
- break LOOP
- } else if tok != '{' {
- return errors.Errorf(`expected '{', but got '%c'`, tok)
- }
- case string: // Objects can only have string keys
- switch tok {
- case AudienceKey:
- var decoded types.StringList
- if err := dec.Decode(&decoded); err != nil {
- return errors.Wrapf(err, `failed to decode value for key %s`, AudienceKey)
- }
- t.audience = decoded
- case ExpirationKey:
- var decoded types.NumericDate
- if err := dec.Decode(&decoded); err != nil {
- return errors.Wrapf(err, `failed to decode value for key %s`, ExpirationKey)
- }
- t.expiration = &decoded
- case IssuedAtKey:
- var decoded types.NumericDate
- if err := dec.Decode(&decoded); err != nil {
- return errors.Wrapf(err, `failed to decode value for key %s`, IssuedAtKey)
- }
- t.issuedAt = &decoded
- case IssuerKey:
- if err := json.AssignNextStringToken(&t.issuer, dec); err != nil {
- return errors.Wrapf(err, `failed to decode value for key %s`, IssuerKey)
- }
- case JwtIDKey:
- if err := json.AssignNextStringToken(&t.jwtID, dec); err != nil {
- return errors.Wrapf(err, `failed to decode value for key %s`, JwtIDKey)
- }
- case NotBeforeKey:
- var decoded types.NumericDate
- if err := dec.Decode(&decoded); err != nil {
- return errors.Wrapf(err, `failed to decode value for key %s`, NotBeforeKey)
- }
- t.notBefore = &decoded
- case SubjectKey:
- if err := json.AssignNextStringToken(&t.subject, dec); err != nil {
- return errors.Wrapf(err, `failed to decode value for key %s`, SubjectKey)
- }
- default:
- if dc := t.dc; dc != nil {
- if localReg := dc.Registry(); localReg != nil {
- decoded, err := localReg.Decode(dec, tok)
- if err == nil {
- t.setNoLock(tok, decoded)
- continue
- }
- }
- }
- decoded, err := registry.Decode(dec, tok)
- if err == nil {
- t.setNoLock(tok, decoded)
- continue
- }
- return errors.Wrapf(err, `could not decode field %s`, tok)
- }
- default:
- return errors.Errorf(`invalid token %T`, tok)
- }
- }
- return nil
- }
- func (t stdToken) MarshalJSON() ([]byte, error) {
- t.mu.RLock()
- defer t.mu.RUnlock()
- buf := pool.GetBytesBuffer()
- defer pool.ReleaseBytesBuffer(buf)
- buf.WriteByte('{')
- enc := json.NewEncoder(buf)
- for i, pair := range t.makePairs() {
- f := pair.Key.(string)
- if i > 0 {
- buf.WriteByte(',')
- }
- buf.WriteRune('"')
- buf.WriteString(f)
- buf.WriteString(`":`)
- switch f {
- case AudienceKey:
- if err := json.EncodeAudience(enc, pair.Value.([]string)); err != nil {
- return nil, errors.Wrap(err, `failed to encode "aud"`)
- }
- continue
- case ExpirationKey, IssuedAtKey, NotBeforeKey:
- enc.Encode(pair.Value.(time.Time).Unix())
- continue
- }
- switch v := pair.Value.(type) {
- case []byte:
- buf.WriteRune('"')
- buf.WriteString(base64.EncodeToString(v))
- buf.WriteRune('"')
- default:
- if err := enc.Encode(v); err != nil {
- return nil, errors.Wrapf(err, `failed to marshal field %s`, f)
- }
- buf.Truncate(buf.Len() - 1)
- }
- }
- buf.WriteByte('}')
- ret := make([]byte, buf.Len())
- copy(ret, buf.Bytes())
- return ret, nil
- }
- func (t *stdToken) Iterate(ctx context.Context) Iterator {
- pairs := t.makePairs()
- ch := make(chan *ClaimPair, len(pairs))
- go func(ctx context.Context, ch chan *ClaimPair, pairs []*ClaimPair) {
- defer close(ch)
- for _, pair := range pairs {
- select {
- case <-ctx.Done():
- return
- case ch <- pair:
- }
- }
- }(ctx, ch, pairs)
- return mapiter.New(ch)
- }
- func (t *stdToken) Walk(ctx context.Context, visitor Visitor) error {
- return iter.WalkMap(ctx, t, visitor)
- }
- func (t *stdToken) AsMap(ctx context.Context) (map[string]interface{}, error) {
- return iter.AsMap(ctx, t)
- }
|