| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- package jwt
- import (
- "fmt"
- "github.com/lestrrat-go/jwx/internal/json"
- "github.com/lestrrat-go/jwx/jwa"
- "github.com/lestrrat-go/jwx/jwe"
- "github.com/lestrrat-go/jwx/jws"
- "github.com/pkg/errors"
- )
// SerializeCtx describes where a SerializeStep sits within the pipeline
// run by Serializer.Serialize.
type SerializeCtx interface {
	// Nested reports whether the pipeline contains more than one
	// user-registered step (i.e. the JWT is being nested).
	Nested() bool
	// Step reports the index of the step currently running. As set by
	// Serializer.Serialize, index 0 is the built-in JSON marshaling step.
	Step() int
}
// serializeCtx is the concrete SerializeCtx that Serializer.Serialize hands
// to every step. A single value is reused and updated between steps.
type serializeCtx struct {
	step   int  // index of the running step (0 is JSON marshaling)
	nested bool // true when more than one user step is registered
}

// Step returns the index of the step currently being executed.
func (c *serializeCtx) Step() int { return c.step }

// Nested reports whether more than one user-registered step is configured.
func (c *serializeCtx) Nested() bool { return c.nested }
- type SerializeStep interface {
- Serialize(SerializeCtx, interface{}) (interface{}, error)
- }
- // Serializer is a generic serializer for JWTs. Whereas other conveinience
- // functions can only do one thing (such as generate a JWS signed JWT),
- // Using this construct you can serialize the token however you want.
- //
- // By default the serializer only marshals the token into a JSON payload.
- // You must set up the rest of the steps that should be taken by the
- // serializer.
- //
- // For example, to marshal the token into JSON, then apply JWS and JWE
- // in that order, you would do:
- //
- // serialized, err := jwt.NewSerialer().
- // Sign(jwa.RS256, key).
- // Encrypt(jwa.RSA_OAEP, key.PublicKey).
- // Serialize(token)
- //
- // The `jwt.Sign()` function is equivalent to
- //
- // serialized, err := jwt.NewSerializer().
- // Sign(...args...).
- // Serialize(token)
- type Serializer struct {
- steps []SerializeStep
- }
- // NewSerializer creates a new empty serializer.
- func NewSerializer() *Serializer {
- return &Serializer{}
- }
- // Reset clears all of the registered steps.
- func (s *Serializer) Reset() *Serializer {
- s.steps = nil
- return s
- }
- // Step adds a new Step to the serialization process
- func (s *Serializer) Step(step SerializeStep) *Serializer {
- s.steps = append(s.steps, step)
- return s
- }
- type jsonSerializer struct{}
- func (jsonSerializer) Serialize(_ SerializeCtx, v interface{}) (interface{}, error) {
- token, ok := v.(Token)
- if !ok {
- return nil, errors.Errorf(`invalid input: expected jwt.Token`)
- }
- buf, err := json.Marshal(token)
- if err != nil {
- return nil, errors.Errorf(`failed to serialize as JSON`)
- }
- return buf, nil
- }
// genericHeader is the minimal header API that setTypeOrCty needs. Both
// jws.Headers and jwe.Headers are passed through it: it can look a key up
// and assign a key.
type genericHeader interface {
	Set(key string, value interface{}) error
	Get(key string) (interface{}, bool)
}
- func setTypeOrCty(ctx SerializeCtx, hdrs genericHeader) error {
- // cty and typ are common between JWE/JWS, so we don't use
- // the constants in jws/jwe package here
- const typKey = `typ`
- const ctyKey = `cty`
- if ctx.Step() == 1 {
- // We are executed immediately after json marshaling
- if _, ok := hdrs.Get(typKey); !ok {
- if err := hdrs.Set(typKey, `JWT`); err != nil {
- return errors.Wrapf(err, `failed to set %s key to "JWT"`, typKey)
- }
- }
- } else {
- if ctx.Nested() {
- // If this is part of a nested sequence, we should set cty = 'JWT'
- // https://datatracker.ietf.org/doc/html/rfc7519#section-5.2
- if err := hdrs.Set(ctyKey, `JWT`); err != nil {
- return errors.Wrapf(err, `failed to set %s key to "JWT"`, ctyKey)
- }
- }
- }
- return nil
- }
- type jwsSerializer struct {
- alg jwa.SignatureAlgorithm
- key interface{}
- options []SignOption
- }
- func (s *jwsSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, error) {
- payload, ok := v.([]byte)
- if !ok {
- return nil, errors.New(`expected []byte as input`)
- }
- var hdrs jws.Headers
- //nolint:forcetypeassert
- for _, option := range s.options {
- switch option.Ident() {
- case identJwsHeaders{}:
- hdrs = option.Value().(jws.Headers)
- }
- }
- if hdrs == nil {
- hdrs = jws.NewHeaders()
- }
- if err := setTypeOrCty(ctx, hdrs); err != nil {
- return nil, err // this is already wrapped
- }
- // JWTs MUST NOT use b64 = false
- // https://datatracker.ietf.org/doc/html/rfc7797#section-7
- if v, ok := hdrs.Get("b64"); ok {
- if bval, bok := v.(bool); bok {
- if !bval { // b64 = false
- return nil, errors.New(`b64 cannot be false for JWTs`)
- }
- }
- }
- return jws.Sign(payload, s.alg, s.key, jws.WithHeaders(hdrs))
- }
- func (s *Serializer) Sign(alg jwa.SignatureAlgorithm, key interface{}, options ...SignOption) *Serializer {
- return s.Step(&jwsSerializer{
- alg: alg,
- key: key,
- options: options,
- })
- }
- type jweSerializer struct {
- keyalg jwa.KeyEncryptionAlgorithm
- key interface{}
- contentalg jwa.ContentEncryptionAlgorithm
- compressalg jwa.CompressionAlgorithm
- options []EncryptOption
- }
- func (s *jweSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, error) {
- payload, ok := v.([]byte)
- if !ok {
- return nil, fmt.Errorf(`expected []byte as input`)
- }
- var hdrs jwe.Headers
- //nolint:forcetypeassert
- for _, option := range s.options {
- switch option.Ident() {
- case identJweHeaders{}:
- hdrs = option.Value().(jwe.Headers)
- }
- }
- if hdrs == nil {
- hdrs = jwe.NewHeaders()
- }
- if err := setTypeOrCty(ctx, hdrs); err != nil {
- return nil, err // this is already wrapped
- }
- return jwe.Encrypt(payload, s.keyalg, s.key, s.contentalg, s.compressalg, jwe.WithProtectedHeaders(hdrs))
- }
- func (s *Serializer) Encrypt(keyalg jwa.KeyEncryptionAlgorithm, key interface{}, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm, options ...EncryptOption) *Serializer {
- return s.Step(&jweSerializer{
- keyalg: keyalg,
- key: key,
- contentalg: contentalg,
- compressalg: compressalg,
- options: options,
- })
- }
- func (s *Serializer) Serialize(t Token) ([]byte, error) {
- steps := make([]SerializeStep, len(s.steps)+1)
- steps[0] = jsonSerializer{}
- for i, step := range s.steps {
- steps[i+1] = step
- }
- var ctx serializeCtx
- ctx.nested = len(s.steps) > 1
- var payload interface{} = t
- for i, step := range steps {
- ctx.step = i
- v, err := step.Serialize(&ctx, payload)
- if err != nil {
- return nil, errors.Wrapf(err, `failed to serialize token at step #%d`, i+1)
- }
- payload = v
- }
- res, ok := payload.([]byte)
- if !ok {
- return nil, errors.New(`invalid serialization produced`)
- }
- return res, nil
- }
|