message.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. package jws
  2. import (
  3. "bytes"
  4. "context"
  5. "github.com/lestrrat-go/jwx/internal/base64"
  6. "github.com/lestrrat-go/jwx/internal/json"
  7. "github.com/lestrrat-go/jwx/internal/pool"
  8. "github.com/lestrrat-go/jwx/jwk"
  9. "github.com/pkg/errors"
  10. )
  11. type collectRawCtx struct{}
  12. func (collectRawCtx) CollectRaw() bool {
  13. return true
  14. }
  15. func NewSignature() *Signature {
  16. return &Signature{}
  17. }
  18. func (s *Signature) DecodeCtx() DecodeCtx {
  19. return s.dc
  20. }
  21. func (s *Signature) SetDecodeCtx(dc DecodeCtx) {
  22. s.dc = dc
  23. }
  24. func (s Signature) PublicHeaders() Headers {
  25. return s.headers
  26. }
  27. func (s *Signature) SetPublicHeaders(v Headers) *Signature {
  28. s.headers = v
  29. return s
  30. }
  31. func (s Signature) ProtectedHeaders() Headers {
  32. return s.protected
  33. }
  34. func (s *Signature) SetProtectedHeaders(v Headers) *Signature {
  35. s.protected = v
  36. return s
  37. }
  38. func (s Signature) Signature() []byte {
  39. return s.signature
  40. }
  41. func (s *Signature) SetSignature(v []byte) *Signature {
  42. s.signature = v
  43. return s
  44. }
  45. type signatureUnmarshalProbe struct {
  46. Header Headers `json:"header,omitempty"`
  47. Protected *string `json:"protected,omitempty"`
  48. Signature *string `json:"signature,omitempty"`
  49. }
  50. func (s *Signature) UnmarshalJSON(data []byte) error {
  51. var sup signatureUnmarshalProbe
  52. sup.Header = NewHeaders()
  53. if err := json.Unmarshal(data, &sup); err != nil {
  54. return errors.Wrap(err, `failed to unmarshal signature into temporary struct`)
  55. }
  56. s.headers = sup.Header
  57. if buf := sup.Protected; buf != nil {
  58. src := []byte(*buf)
  59. if !bytes.HasPrefix(src, []byte{'{'}) {
  60. decoded, err := base64.Decode(src)
  61. if err != nil {
  62. return errors.Wrap(err, `failed to base64 decode protected headers`)
  63. }
  64. src = decoded
  65. }
  66. prt := NewHeaders()
  67. //nolint:forcetypeassert
  68. prt.(*stdHeaders).SetDecodeCtx(s.DecodeCtx())
  69. if err := json.Unmarshal(src, prt); err != nil {
  70. return errors.Wrap(err, `failed to unmarshal protected headers`)
  71. }
  72. //nolint:forcetypeassert
  73. prt.(*stdHeaders).SetDecodeCtx(nil)
  74. s.protected = prt
  75. }
  76. decoded, err := base64.DecodeString(*sup.Signature)
  77. if err != nil {
  78. return errors.Wrap(err, `failed to base decode signature`)
  79. }
  80. s.signature = decoded
  81. return nil
  82. }
  83. // Sign populates the signature field, with a signature generated by
  84. // given the signer object and payload.
  85. //
  86. // The first return value is the raw signature in binary format.
  87. // The second return value s the full three-segment signature
  88. // (e.g. "eyXXXX.XXXXX.XXXX")
  89. func (s *Signature) Sign(payload []byte, signer Signer, key interface{}) ([]byte, []byte, error) {
  90. ctx, cancel := context.WithCancel(context.Background())
  91. defer cancel()
  92. hdrs, err := mergeHeaders(ctx, s.headers, s.protected)
  93. if err != nil {
  94. return nil, nil, errors.Wrap(err, `failed to merge headers`)
  95. }
  96. if err := hdrs.Set(AlgorithmKey, signer.Algorithm()); err != nil {
  97. return nil, nil, errors.Wrap(err, `failed to set "alg"`)
  98. }
  99. // If the key is a jwk.Key instance, obtain the raw key
  100. if jwkKey, ok := key.(jwk.Key); ok {
  101. // If we have a key ID specified by this jwk.Key, use that in the header
  102. if kid := jwkKey.KeyID(); kid != "" {
  103. if err := hdrs.Set(jwk.KeyIDKey, kid); err != nil {
  104. return nil, nil, errors.Wrap(err, `set key ID from jwk.Key`)
  105. }
  106. }
  107. }
  108. hdrbuf, err := json.Marshal(hdrs)
  109. if err != nil {
  110. return nil, nil, errors.Wrap(err, `failed to marshal headers`)
  111. }
  112. buf := pool.GetBytesBuffer()
  113. defer pool.ReleaseBytesBuffer(buf)
  114. buf.WriteString(base64.EncodeToString(hdrbuf))
  115. buf.WriteByte('.')
  116. var plen int
  117. b64 := getB64Value(hdrs)
  118. if b64 {
  119. encoded := base64.EncodeToString(payload)
  120. plen = len(encoded)
  121. buf.WriteString(encoded)
  122. } else {
  123. if !s.detached {
  124. if bytes.Contains(payload, []byte{'.'}) {
  125. return nil, nil, errors.New(`payload must not contain a "."`)
  126. }
  127. }
  128. plen = len(payload)
  129. buf.Write(payload)
  130. }
  131. signature, err := signer.Sign(buf.Bytes(), key)
  132. if err != nil {
  133. return nil, nil, errors.Wrap(err, `failed to sign payload`)
  134. }
  135. s.signature = signature
  136. // Detached payload, this should be removed from the end result
  137. if s.detached {
  138. buf.Truncate(buf.Len() - plen)
  139. }
  140. buf.WriteByte('.')
  141. buf.WriteString(base64.EncodeToString(signature))
  142. ret := make([]byte, buf.Len())
  143. copy(ret, buf.Bytes())
  144. return signature, ret, nil
  145. }
  146. func NewMessage() *Message {
  147. return &Message{}
  148. }
  149. // Clears the internal raw buffer that was accumulated during
  150. // the verify phase
  151. func (m *Message) clearRaw() {
  152. for _, sig := range m.signatures {
  153. if protected := sig.protected; protected != nil {
  154. if cr, ok := protected.(*stdHeaders); ok {
  155. cr.raw = nil
  156. }
  157. }
  158. }
  159. }
  160. func (m *Message) SetDecodeCtx(dc DecodeCtx) {
  161. m.dc = dc
  162. }
  163. func (m *Message) DecodeCtx() DecodeCtx {
  164. return m.dc
  165. }
  166. // Payload returns the decoded payload
  167. func (m Message) Payload() []byte {
  168. return m.payload
  169. }
  170. func (m *Message) SetPayload(v []byte) *Message {
  171. m.payload = v
  172. return m
  173. }
  174. func (m Message) Signatures() []*Signature {
  175. return m.signatures
  176. }
  177. func (m *Message) AppendSignature(v *Signature) *Message {
  178. m.signatures = append(m.signatures, v)
  179. return m
  180. }
  181. func (m *Message) ClearSignatures() *Message {
  182. m.signatures = nil
  183. return m
  184. }
  185. // LookupSignature looks up a particular signature entry using
  186. // the `kid` value
  187. func (m Message) LookupSignature(kid string) []*Signature {
  188. var sigs []*Signature
  189. for _, sig := range m.signatures {
  190. if hdr := sig.PublicHeaders(); hdr != nil {
  191. hdrKeyID := hdr.KeyID()
  192. if hdrKeyID == kid {
  193. sigs = append(sigs, sig)
  194. continue
  195. }
  196. }
  197. if hdr := sig.ProtectedHeaders(); hdr != nil {
  198. hdrKeyID := hdr.KeyID()
  199. if hdrKeyID == kid {
  200. sigs = append(sigs, sig)
  201. continue
  202. }
  203. }
  204. }
  205. return sigs
  206. }
  207. // This struct is used to first probe for the structure of the
  208. // incoming JSON object. We then decide how to parse it
  209. // from the fields that are populated.
  210. type messageUnmarshalProbe struct {
  211. Payload *string `json:"payload"`
  212. Signatures []json.RawMessage `json:"signatures,omitempty"`
  213. Header Headers `json:"header,omitempty"`
  214. Protected *string `json:"protected,omitempty"`
  215. Signature *string `json:"signature,omitempty"`
  216. }
  217. func (m *Message) UnmarshalJSON(buf []byte) error {
  218. m.payload = nil
  219. m.signatures = nil
  220. m.b64 = true
  221. var mup messageUnmarshalProbe
  222. mup.Header = NewHeaders()
  223. if err := json.Unmarshal(buf, &mup); err != nil {
  224. return errors.Wrap(err, `failed to unmarshal into temporary structure`)
  225. }
  226. b64 := true
  227. if mup.Signature == nil { // flattened signature is NOT present
  228. if len(mup.Signatures) == 0 {
  229. return errors.New(`required field "signatures" not present`)
  230. }
  231. m.signatures = make([]*Signature, 0, len(mup.Signatures))
  232. for i, rawsig := range mup.Signatures {
  233. var sig Signature
  234. sig.SetDecodeCtx(m.DecodeCtx())
  235. if err := json.Unmarshal(rawsig, &sig); err != nil {
  236. return errors.Wrapf(err, `failed to unmarshal signature #%d`, i+1)
  237. }
  238. sig.SetDecodeCtx(nil)
  239. if i == 0 {
  240. if !getB64Value(sig.protected) {
  241. b64 = false
  242. }
  243. } else {
  244. if b64 != getB64Value(sig.protected) {
  245. return errors.Errorf(`b64 value must be the same for all signatures`)
  246. }
  247. }
  248. m.signatures = append(m.signatures, &sig)
  249. }
  250. } else { // .signature is present, it's a flattened structure
  251. if len(mup.Signatures) != 0 {
  252. return errors.New(`invalid format ("signatures" and "signature" keys cannot both be present)`)
  253. }
  254. var sig Signature
  255. sig.headers = mup.Header
  256. if src := mup.Protected; src != nil {
  257. decoded, err := base64.DecodeString(*src)
  258. if err != nil {
  259. return errors.Wrap(err, `failed to base64 decode flattened protected headers`)
  260. }
  261. prt := NewHeaders()
  262. //nolint:forcetypeassert
  263. prt.(*stdHeaders).SetDecodeCtx(m.DecodeCtx())
  264. if err := json.Unmarshal(decoded, prt); err != nil {
  265. return errors.Wrap(err, `failed to unmarshal flattened protected headers`)
  266. }
  267. //nolint:forcetypeassert
  268. prt.(*stdHeaders).SetDecodeCtx(nil)
  269. sig.protected = prt
  270. }
  271. decoded, err := base64.DecodeString(*mup.Signature)
  272. if err != nil {
  273. return errors.Wrap(err, `failed to base64 decode flattened signature`)
  274. }
  275. sig.signature = decoded
  276. m.signatures = []*Signature{&sig}
  277. b64 = getB64Value(sig.protected)
  278. }
  279. if mup.Payload != nil {
  280. if !b64 { // NOT base64 encoded
  281. m.payload = []byte(*mup.Payload)
  282. } else {
  283. decoded, err := base64.DecodeString(*mup.Payload)
  284. if err != nil {
  285. return errors.Wrap(err, `failed to base64 decode payload`)
  286. }
  287. m.payload = decoded
  288. }
  289. }
  290. m.b64 = b64
  291. return nil
  292. }
  293. func (m Message) MarshalJSON() ([]byte, error) {
  294. if len(m.signatures) == 1 {
  295. return m.marshalFlattened()
  296. }
  297. return m.marshalFull()
  298. }
  299. func (m Message) marshalFlattened() ([]byte, error) {
  300. buf := pool.GetBytesBuffer()
  301. defer pool.ReleaseBytesBuffer(buf)
  302. sig := m.signatures[0]
  303. buf.WriteRune('{')
  304. var wrote bool
  305. if hdr := sig.headers; hdr != nil {
  306. hdrjs, err := hdr.MarshalJSON()
  307. if err != nil {
  308. return nil, errors.Wrap(err, `failed to marshal "header" (flattened format)`)
  309. }
  310. buf.WriteString(`"header":`)
  311. buf.Write(hdrjs)
  312. wrote = true
  313. }
  314. if wrote {
  315. buf.WriteRune(',')
  316. }
  317. buf.WriteString(`"payload":"`)
  318. buf.WriteString(base64.EncodeToString(m.payload))
  319. buf.WriteRune('"')
  320. if protected := sig.protected; protected != nil {
  321. protectedbuf, err := protected.MarshalJSON()
  322. if err != nil {
  323. return nil, errors.Wrap(err, `failed to marshal "protected" (flattened format)`)
  324. }
  325. buf.WriteString(`,"protected":"`)
  326. buf.WriteString(base64.EncodeToString(protectedbuf))
  327. buf.WriteRune('"')
  328. }
  329. buf.WriteString(`,"signature":"`)
  330. buf.WriteString(base64.EncodeToString(sig.signature))
  331. buf.WriteRune('"')
  332. buf.WriteRune('}')
  333. ret := make([]byte, buf.Len())
  334. copy(ret, buf.Bytes())
  335. return ret, nil
  336. }
  337. func (m Message) marshalFull() ([]byte, error) {
  338. buf := pool.GetBytesBuffer()
  339. defer pool.ReleaseBytesBuffer(buf)
  340. buf.WriteString(`{"payload":"`)
  341. buf.WriteString(base64.EncodeToString(m.payload))
  342. buf.WriteString(`","signatures":[`)
  343. for i, sig := range m.signatures {
  344. if i > 0 {
  345. buf.WriteRune(',')
  346. }
  347. buf.WriteRune('{')
  348. var wrote bool
  349. if hdr := sig.headers; hdr != nil {
  350. hdrbuf, err := hdr.MarshalJSON()
  351. if err != nil {
  352. return nil, errors.Wrapf(err, `failed to marshal "header" for signature #%d`, i+1)
  353. }
  354. buf.WriteString(`"header":`)
  355. buf.Write(hdrbuf)
  356. wrote = true
  357. }
  358. if protected := sig.protected; protected != nil {
  359. protectedbuf, err := protected.MarshalJSON()
  360. if err != nil {
  361. return nil, errors.Wrapf(err, `failed to marshal "protected" for signature #%d`, i+1)
  362. }
  363. if wrote {
  364. buf.WriteRune(',')
  365. }
  366. buf.WriteString(`"protected":"`)
  367. buf.WriteString(base64.EncodeToString(protectedbuf))
  368. buf.WriteRune('"')
  369. wrote = true
  370. }
  371. if wrote {
  372. buf.WriteRune(',')
  373. }
  374. buf.WriteString(`"signature":"`)
  375. buf.WriteString(base64.EncodeToString(sig.signature))
  376. buf.WriteString(`"}`)
  377. }
  378. buf.WriteString(`]}`)
  379. ret := make([]byte, buf.Len())
  380. copy(ret, buf.Bytes())
  381. return ret, nil
  382. }