serialize.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. package jwt
  2. import (
  3. "fmt"
  4. "github.com/lestrrat-go/jwx/internal/json"
  5. "github.com/lestrrat-go/jwx/jwa"
  6. "github.com/lestrrat-go/jwx/jwe"
  7. "github.com/lestrrat-go/jwx/jws"
  8. "github.com/pkg/errors"
  9. )
  // SerializeCtx tells a SerializeStep where in the serialization pipeline
  // it is currently being run.
  type SerializeCtx interface {
  	// Nested reports whether more than one step runs after JSON
  	// marshaling, i.e. whether the resulting JWT is nested.
  	Nested() bool

  	// Step returns the zero-based index of the running step. Index 0 is
  	// the implicit JSON marshaling step.
  	Step() int
  }
  // serializeCtx is the concrete SerializeCtx handed to every step by
  // Serializer.Serialize.
  type serializeCtx struct {
  	step   int  // index of the step currently running (0 = JSON marshaling)
  	nested bool // true when more than one step follows JSON marshaling
  }

  // Step returns the index of the step currently being executed.
  func (sc *serializeCtx) Step() int { return sc.step }

  // Nested reports whether this serialization produces a nested JWT.
  func (sc *serializeCtx) Nested() bool { return sc.nested }
  24. type SerializeStep interface {
  25. Serialize(SerializeCtx, interface{}) (interface{}, error)
  26. }
  27. // Serializer is a generic serializer for JWTs. Whereas other conveinience
  28. // functions can only do one thing (such as generate a JWS signed JWT),
  29. // Using this construct you can serialize the token however you want.
  30. //
  31. // By default the serializer only marshals the token into a JSON payload.
  32. // You must set up the rest of the steps that should be taken by the
  33. // serializer.
  34. //
  35. // For example, to marshal the token into JSON, then apply JWS and JWE
  36. // in that order, you would do:
  37. //
  38. // serialized, err := jwt.NewSerialer().
  39. // Sign(jwa.RS256, key).
  40. // Encrypt(jwa.RSA_OAEP, key.PublicKey).
  41. // Serialize(token)
  42. //
  43. // The `jwt.Sign()` function is equivalent to
  44. //
  45. // serialized, err := jwt.NewSerializer().
  46. // Sign(...args...).
  47. // Serialize(token)
  48. type Serializer struct {
  49. steps []SerializeStep
  50. }
  51. // NewSerializer creates a new empty serializer.
  52. func NewSerializer() *Serializer {
  53. return &Serializer{}
  54. }
  55. // Reset clears all of the registered steps.
  56. func (s *Serializer) Reset() *Serializer {
  57. s.steps = nil
  58. return s
  59. }
  60. // Step adds a new Step to the serialization process
  61. func (s *Serializer) Step(step SerializeStep) *Serializer {
  62. s.steps = append(s.steps, step)
  63. return s
  64. }
  65. type jsonSerializer struct{}
  66. func (jsonSerializer) Serialize(_ SerializeCtx, v interface{}) (interface{}, error) {
  67. token, ok := v.(Token)
  68. if !ok {
  69. return nil, errors.Errorf(`invalid input: expected jwt.Token`)
  70. }
  71. buf, err := json.Marshal(token)
  72. if err != nil {
  73. return nil, errors.Errorf(`failed to serialize as JSON`)
  74. }
  75. return buf, nil
  76. }
  // genericHeader is the subset of header behavior shared by jws.Headers
  // and jwe.Headers that setTypeOrCty relies on.
  type genericHeader interface {
  	// Get looks up key and reports whether it was present.
  	Get(key string) (interface{}, bool)
  	// Set assigns value to key.
  	Set(key string, value interface{}) error
  }
  81. func setTypeOrCty(ctx SerializeCtx, hdrs genericHeader) error {
  82. // cty and typ are common between JWE/JWS, so we don't use
  83. // the constants in jws/jwe package here
  84. const typKey = `typ`
  85. const ctyKey = `cty`
  86. if ctx.Step() == 1 {
  87. // We are executed immediately after json marshaling
  88. if _, ok := hdrs.Get(typKey); !ok {
  89. if err := hdrs.Set(typKey, `JWT`); err != nil {
  90. return errors.Wrapf(err, `failed to set %s key to "JWT"`, typKey)
  91. }
  92. }
  93. } else {
  94. if ctx.Nested() {
  95. // If this is part of a nested sequence, we should set cty = 'JWT'
  96. // https://datatracker.ietf.org/doc/html/rfc7519#section-5.2
  97. if err := hdrs.Set(ctyKey, `JWT`); err != nil {
  98. return errors.Wrapf(err, `failed to set %s key to "JWT"`, ctyKey)
  99. }
  100. }
  101. }
  102. return nil
  103. }
  104. type jwsSerializer struct {
  105. alg jwa.SignatureAlgorithm
  106. key interface{}
  107. options []SignOption
  108. }
  109. func (s *jwsSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, error) {
  110. payload, ok := v.([]byte)
  111. if !ok {
  112. return nil, errors.New(`expected []byte as input`)
  113. }
  114. var hdrs jws.Headers
  115. //nolint:forcetypeassert
  116. for _, option := range s.options {
  117. switch option.Ident() {
  118. case identJwsHeaders{}:
  119. hdrs = option.Value().(jws.Headers)
  120. }
  121. }
  122. if hdrs == nil {
  123. hdrs = jws.NewHeaders()
  124. }
  125. if err := setTypeOrCty(ctx, hdrs); err != nil {
  126. return nil, err // this is already wrapped
  127. }
  128. // JWTs MUST NOT use b64 = false
  129. // https://datatracker.ietf.org/doc/html/rfc7797#section-7
  130. if v, ok := hdrs.Get("b64"); ok {
  131. if bval, bok := v.(bool); bok {
  132. if !bval { // b64 = false
  133. return nil, errors.New(`b64 cannot be false for JWTs`)
  134. }
  135. }
  136. }
  137. return jws.Sign(payload, s.alg, s.key, jws.WithHeaders(hdrs))
  138. }
  139. func (s *Serializer) Sign(alg jwa.SignatureAlgorithm, key interface{}, options ...SignOption) *Serializer {
  140. return s.Step(&jwsSerializer{
  141. alg: alg,
  142. key: key,
  143. options: options,
  144. })
  145. }
  146. type jweSerializer struct {
  147. keyalg jwa.KeyEncryptionAlgorithm
  148. key interface{}
  149. contentalg jwa.ContentEncryptionAlgorithm
  150. compressalg jwa.CompressionAlgorithm
  151. options []EncryptOption
  152. }
  153. func (s *jweSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, error) {
  154. payload, ok := v.([]byte)
  155. if !ok {
  156. return nil, fmt.Errorf(`expected []byte as input`)
  157. }
  158. var hdrs jwe.Headers
  159. //nolint:forcetypeassert
  160. for _, option := range s.options {
  161. switch option.Ident() {
  162. case identJweHeaders{}:
  163. hdrs = option.Value().(jwe.Headers)
  164. }
  165. }
  166. if hdrs == nil {
  167. hdrs = jwe.NewHeaders()
  168. }
  169. if err := setTypeOrCty(ctx, hdrs); err != nil {
  170. return nil, err // this is already wrapped
  171. }
  172. return jwe.Encrypt(payload, s.keyalg, s.key, s.contentalg, s.compressalg, jwe.WithProtectedHeaders(hdrs))
  173. }
  174. func (s *Serializer) Encrypt(keyalg jwa.KeyEncryptionAlgorithm, key interface{}, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm, options ...EncryptOption) *Serializer {
  175. return s.Step(&jweSerializer{
  176. keyalg: keyalg,
  177. key: key,
  178. contentalg: contentalg,
  179. compressalg: compressalg,
  180. options: options,
  181. })
  182. }
  183. func (s *Serializer) Serialize(t Token) ([]byte, error) {
  184. steps := make([]SerializeStep, len(s.steps)+1)
  185. steps[0] = jsonSerializer{}
  186. for i, step := range s.steps {
  187. steps[i+1] = step
  188. }
  189. var ctx serializeCtx
  190. ctx.nested = len(s.steps) > 1
  191. var payload interface{} = t
  192. for i, step := range steps {
  193. ctx.step = i
  194. v, err := step.Serialize(&ctx, payload)
  195. if err != nil {
  196. return nil, errors.Wrapf(err, `failed to serialize token at step #%d`, i+1)
  197. }
  198. payload = v
  199. }
  200. res, ok := payload.([]byte)
  201. if !ok {
  202. return nil, errors.New(`invalid serialization produced`)
  203. }
  204. return res, nil
  205. }