| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448 |
- package jws
- import (
- "bytes"
- "context"
- "github.com/lestrrat-go/jwx/internal/base64"
- "github.com/lestrrat-go/jwx/internal/json"
- "github.com/lestrrat-go/jwx/internal/pool"
- "github.com/lestrrat-go/jwx/jwk"
- "github.com/pkg/errors"
- )
// collectRawCtx is a zero-size context value whose CollectRaw method always
// reports true. Presumably it is supplied as a DecodeCtx to ask header
// decoding to retain the raw input bytes — confirm against the DecodeCtx
// definition.
type collectRawCtx struct{}

// CollectRaw unconditionally returns true.
func (collectRawCtx) CollectRaw() bool { return true }
- func NewSignature() *Signature {
- return &Signature{}
- }
- func (s *Signature) DecodeCtx() DecodeCtx {
- return s.dc
- }
- func (s *Signature) SetDecodeCtx(dc DecodeCtx) {
- s.dc = dc
- }
- func (s Signature) PublicHeaders() Headers {
- return s.headers
- }
- func (s *Signature) SetPublicHeaders(v Headers) *Signature {
- s.headers = v
- return s
- }
- func (s Signature) ProtectedHeaders() Headers {
- return s.protected
- }
- func (s *Signature) SetProtectedHeaders(v Headers) *Signature {
- s.protected = v
- return s
- }
- func (s Signature) Signature() []byte {
- return s.signature
- }
- func (s *Signature) SetSignature(v []byte) *Signature {
- s.signature = v
- return s
- }
- type signatureUnmarshalProbe struct {
- Header Headers `json:"header,omitempty"`
- Protected *string `json:"protected,omitempty"`
- Signature *string `json:"signature,omitempty"`
- }
- func (s *Signature) UnmarshalJSON(data []byte) error {
- var sup signatureUnmarshalProbe
- sup.Header = NewHeaders()
- if err := json.Unmarshal(data, &sup); err != nil {
- return errors.Wrap(err, `failed to unmarshal signature into temporary struct`)
- }
- s.headers = sup.Header
- if buf := sup.Protected; buf != nil {
- src := []byte(*buf)
- if !bytes.HasPrefix(src, []byte{'{'}) {
- decoded, err := base64.Decode(src)
- if err != nil {
- return errors.Wrap(err, `failed to base64 decode protected headers`)
- }
- src = decoded
- }
- prt := NewHeaders()
- //nolint:forcetypeassert
- prt.(*stdHeaders).SetDecodeCtx(s.DecodeCtx())
- if err := json.Unmarshal(src, prt); err != nil {
- return errors.Wrap(err, `failed to unmarshal protected headers`)
- }
- //nolint:forcetypeassert
- prt.(*stdHeaders).SetDecodeCtx(nil)
- s.protected = prt
- }
- decoded, err := base64.DecodeString(*sup.Signature)
- if err != nil {
- return errors.Wrap(err, `failed to base decode signature`)
- }
- s.signature = decoded
- return nil
- }
// Sign computes a signature over payload using signer and key, stores it in
// s (replacing any previous signature), and returns it.
//
// The headers used for signing come from merging s's public and protected
// headers via mergeHeaders. "alg" is then set from signer.Algorithm(). When
// key is a jwk.Key with a non-empty key ID, "kid" is set from that key.
// The signing input is the base64-encoded merged headers, a '.', and then
// the payload. The payload is base64-encoded unless getB64Value reports
// that the merged headers disable it ("b64").
//
// The first return value is the raw signature in binary format.
// The second return value is the full three-segment serialization
// (e.g. "eyXXXX.XXXXX.XXXX"). When s is marked as detached, the payload
// segment is omitted from this output (e.g. "eyXXXX..XXXX").
func (s *Signature) Sign(payload []byte, signer Signer, key interface{}) ([]byte, []byte, error) {
	// Context scoped to this call, used only for mergeHeaders; it is
	// canceled when Sign returns.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	hdrs, err := mergeHeaders(ctx, s.headers, s.protected)
	if err != nil {
		return nil, nil, errors.Wrap(err, `failed to merge headers`)
	}
	if err := hdrs.Set(AlgorithmKey, signer.Algorithm()); err != nil {
		return nil, nil, errors.Wrap(err, `failed to set "alg"`)
	}
	// If the key is a jwk.Key instance, advertise its key ID (if any) in the
	// "kid" header. The key itself is passed to signer.Sign unchanged.
	if jwkKey, ok := key.(jwk.Key); ok {
		// If we have a key ID specified by this jwk.Key, use that in the header
		if kid := jwkKey.KeyID(); kid != "" {
			if err := hdrs.Set(jwk.KeyIDKey, kid); err != nil {
				return nil, nil, errors.Wrap(err, `set key ID from jwk.Key`)
			}
		}
	}
	hdrbuf, err := json.Marshal(hdrs)
	if err != nil {
		return nil, nil, errors.Wrap(err, `failed to marshal headers`)
	}
	buf := pool.GetBytesBuffer()
	defer pool.ReleaseBytesBuffer(buf)
	buf.WriteString(base64.EncodeToString(hdrbuf))
	buf.WriteByte('.')
	// plen records how many bytes the payload segment occupies in buf so it
	// can be cut out again for detached signatures.
	var plen int
	b64 := getB64Value(hdrs)
	if b64 {
		encoded := base64.EncodeToString(payload)
		plen = len(encoded)
		buf.WriteString(encoded)
	} else {
		// An unencoded payload is written verbatim. A '.' in it would make
		// the segments ambiguous, so it is rejected unless the payload is
		// going to be detached from the output.
		if !s.detached {
			if bytes.Contains(payload, []byte{'.'}) {
				return nil, nil, errors.New(`payload must not contain a "."`)
			}
		}
		plen = len(payload)
		buf.Write(payload)
	}
	signature, err := signer.Sign(buf.Bytes(), key)
	if err != nil {
		return nil, nil, errors.Wrap(err, `failed to sign payload`)
	}
	s.signature = signature
	// Detached payload: the payload segment was signed but must be removed
	// from the returned serialization, leaving "<header>." before the signature.
	if s.detached {
		buf.Truncate(buf.Len() - plen)
	}
	buf.WriteByte('.')
	buf.WriteString(base64.EncodeToString(signature))
	// buf goes back to the pool on return, so hand the caller its own copy.
	ret := make([]byte, buf.Len())
	copy(ret, buf.Bytes())
	return signature, ret, nil
}
- func NewMessage() *Message {
- return &Message{}
- }
- // Clears the internal raw buffer that was accumulated during
- // the verify phase
- func (m *Message) clearRaw() {
- for _, sig := range m.signatures {
- if protected := sig.protected; protected != nil {
- if cr, ok := protected.(*stdHeaders); ok {
- cr.raw = nil
- }
- }
- }
- }
- func (m *Message) SetDecodeCtx(dc DecodeCtx) {
- m.dc = dc
- }
- func (m *Message) DecodeCtx() DecodeCtx {
- return m.dc
- }
- // Payload returns the decoded payload
- func (m Message) Payload() []byte {
- return m.payload
- }
- func (m *Message) SetPayload(v []byte) *Message {
- m.payload = v
- return m
- }
- func (m Message) Signatures() []*Signature {
- return m.signatures
- }
- func (m *Message) AppendSignature(v *Signature) *Message {
- m.signatures = append(m.signatures, v)
- return m
- }
- func (m *Message) ClearSignatures() *Message {
- m.signatures = nil
- return m
- }
- // LookupSignature looks up a particular signature entry using
- // the `kid` value
- func (m Message) LookupSignature(kid string) []*Signature {
- var sigs []*Signature
- for _, sig := range m.signatures {
- if hdr := sig.PublicHeaders(); hdr != nil {
- hdrKeyID := hdr.KeyID()
- if hdrKeyID == kid {
- sigs = append(sigs, sig)
- continue
- }
- }
- if hdr := sig.ProtectedHeaders(); hdr != nil {
- hdrKeyID := hdr.KeyID()
- if hdrKeyID == kid {
- sigs = append(sigs, sig)
- continue
- }
- }
- }
- return sigs
- }
- // This struct is used to first probe for the structure of the
- // incoming JSON object. We then decide how to parse it
- // from the fields that are populated.
- type messageUnmarshalProbe struct {
- Payload *string `json:"payload"`
- Signatures []json.RawMessage `json:"signatures,omitempty"`
- Header Headers `json:"header,omitempty"`
- Protected *string `json:"protected,omitempty"`
- Signature *string `json:"signature,omitempty"`
- }
- func (m *Message) UnmarshalJSON(buf []byte) error {
- m.payload = nil
- m.signatures = nil
- m.b64 = true
- var mup messageUnmarshalProbe
- mup.Header = NewHeaders()
- if err := json.Unmarshal(buf, &mup); err != nil {
- return errors.Wrap(err, `failed to unmarshal into temporary structure`)
- }
- b64 := true
- if mup.Signature == nil { // flattened signature is NOT present
- if len(mup.Signatures) == 0 {
- return errors.New(`required field "signatures" not present`)
- }
- m.signatures = make([]*Signature, 0, len(mup.Signatures))
- for i, rawsig := range mup.Signatures {
- var sig Signature
- sig.SetDecodeCtx(m.DecodeCtx())
- if err := json.Unmarshal(rawsig, &sig); err != nil {
- return errors.Wrapf(err, `failed to unmarshal signature #%d`, i+1)
- }
- sig.SetDecodeCtx(nil)
- if i == 0 {
- if !getB64Value(sig.protected) {
- b64 = false
- }
- } else {
- if b64 != getB64Value(sig.protected) {
- return errors.Errorf(`b64 value must be the same for all signatures`)
- }
- }
- m.signatures = append(m.signatures, &sig)
- }
- } else { // .signature is present, it's a flattened structure
- if len(mup.Signatures) != 0 {
- return errors.New(`invalid format ("signatures" and "signature" keys cannot both be present)`)
- }
- var sig Signature
- sig.headers = mup.Header
- if src := mup.Protected; src != nil {
- decoded, err := base64.DecodeString(*src)
- if err != nil {
- return errors.Wrap(err, `failed to base64 decode flattened protected headers`)
- }
- prt := NewHeaders()
- //nolint:forcetypeassert
- prt.(*stdHeaders).SetDecodeCtx(m.DecodeCtx())
- if err := json.Unmarshal(decoded, prt); err != nil {
- return errors.Wrap(err, `failed to unmarshal flattened protected headers`)
- }
- //nolint:forcetypeassert
- prt.(*stdHeaders).SetDecodeCtx(nil)
- sig.protected = prt
- }
- decoded, err := base64.DecodeString(*mup.Signature)
- if err != nil {
- return errors.Wrap(err, `failed to base64 decode flattened signature`)
- }
- sig.signature = decoded
- m.signatures = []*Signature{&sig}
- b64 = getB64Value(sig.protected)
- }
- if mup.Payload != nil {
- if !b64 { // NOT base64 encoded
- m.payload = []byte(*mup.Payload)
- } else {
- decoded, err := base64.DecodeString(*mup.Payload)
- if err != nil {
- return errors.Wrap(err, `failed to base64 decode payload`)
- }
- m.payload = decoded
- }
- }
- m.b64 = b64
- return nil
- }
- func (m Message) MarshalJSON() ([]byte, error) {
- if len(m.signatures) == 1 {
- return m.marshalFlattened()
- }
- return m.marshalFull()
- }
- func (m Message) marshalFlattened() ([]byte, error) {
- buf := pool.GetBytesBuffer()
- defer pool.ReleaseBytesBuffer(buf)
- sig := m.signatures[0]
- buf.WriteRune('{')
- var wrote bool
- if hdr := sig.headers; hdr != nil {
- hdrjs, err := hdr.MarshalJSON()
- if err != nil {
- return nil, errors.Wrap(err, `failed to marshal "header" (flattened format)`)
- }
- buf.WriteString(`"header":`)
- buf.Write(hdrjs)
- wrote = true
- }
- if wrote {
- buf.WriteRune(',')
- }
- buf.WriteString(`"payload":"`)
- buf.WriteString(base64.EncodeToString(m.payload))
- buf.WriteRune('"')
- if protected := sig.protected; protected != nil {
- protectedbuf, err := protected.MarshalJSON()
- if err != nil {
- return nil, errors.Wrap(err, `failed to marshal "protected" (flattened format)`)
- }
- buf.WriteString(`,"protected":"`)
- buf.WriteString(base64.EncodeToString(protectedbuf))
- buf.WriteRune('"')
- }
- buf.WriteString(`,"signature":"`)
- buf.WriteString(base64.EncodeToString(sig.signature))
- buf.WriteRune('"')
- buf.WriteRune('}')
- ret := make([]byte, buf.Len())
- copy(ret, buf.Bytes())
- return ret, nil
- }
- func (m Message) marshalFull() ([]byte, error) {
- buf := pool.GetBytesBuffer()
- defer pool.ReleaseBytesBuffer(buf)
- buf.WriteString(`{"payload":"`)
- buf.WriteString(base64.EncodeToString(m.payload))
- buf.WriteString(`","signatures":[`)
- for i, sig := range m.signatures {
- if i > 0 {
- buf.WriteRune(',')
- }
- buf.WriteRune('{')
- var wrote bool
- if hdr := sig.headers; hdr != nil {
- hdrbuf, err := hdr.MarshalJSON()
- if err != nil {
- return nil, errors.Wrapf(err, `failed to marshal "header" for signature #%d`, i+1)
- }
- buf.WriteString(`"header":`)
- buf.Write(hdrbuf)
- wrote = true
- }
- if protected := sig.protected; protected != nil {
- protectedbuf, err := protected.MarshalJSON()
- if err != nil {
- return nil, errors.Wrapf(err, `failed to marshal "protected" for signature #%d`, i+1)
- }
- if wrote {
- buf.WriteRune(',')
- }
- buf.WriteString(`"protected":"`)
- buf.WriteString(base64.EncodeToString(protectedbuf))
- buf.WriteRune('"')
- wrote = true
- }
- if wrote {
- buf.WriteRune(',')
- }
- buf.WriteString(`"signature":"`)
- buf.WriteString(base64.EncodeToString(sig.signature))
- buf.WriteString(`"}`)
- }
- buf.WriteString(`]}`)
- ret := make([]byte, buf.Len())
- copy(ret, buf.Bytes())
- return ret, nil
- }
|