| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382 |
- package jwk
- import (
- "crypto/x509"
- "encoding/json"
- "fmt"
- "github.com/lestrrat/go-jwx/jwa"
- "github.com/lestrrat/go-pdebug"
- "github.com/pkg/errors"
- )
- const (
- AlgorithmKey = "alg"
- KeyIDKey = "kid"
- KeyTypeKey = "kty"
- KeyUsageKey = "use"
- KeyOpsKey = "key_ops"
- X509CertChainKey = "x5c"
- X509CertThumbprintKey = "x5t"
- X509CertThumbprintS256Key = "x5t#S256"
- X509URLKey = "x5u"
- )
- type Headers interface {
- Remove(string)
- Get(string) (interface{}, bool)
- Set(string, interface{}) error
- PopulateMap(map[string]interface{}) error
- ExtractMap(map[string]interface{}) error
- Walk(func(string, interface{}) error) error
- Algorithm() string
- KeyID() string
- KeyType() jwa.KeyType
- KeyUsage() string
- KeyOps() []KeyOperation
- X509CertChain() []*x509.Certificate
- X509CertThumbprint() string
- X509CertThumbprintS256() string
- X509URL() string
- }
- type StandardHeaders struct {
- algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4
- keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4
- keyType *jwa.KeyType // https://tools.ietf.org/html/rfc7517#section-4.1
- keyUsage *string // https://tools.ietf.org/html/rfc7517#section-4.2
- keyops []KeyOperation // https://tools.ietf.org/html/rfc7517#section-4.3
- x509CertChain *CertificateChain // https://tools.ietf.org/html/rfc7515#section-4.1.6
- x509CertThumbprint *string // https://tools.ietf.org/html/rfc7515#section-4.1.7
- x509CertThumbprintS256 *string // https://tools.ietf.org/html/rfc7515#section-4.1.8
- x509URL *string // https://tools.ietf.org/html/rfc7515#section-4.1.5
- privateParams map[string]interface{}
- }
- func (h *StandardHeaders) Remove(s string) {
- delete(h.privateParams, s)
- }
- func (h *StandardHeaders) Algorithm() string {
- if v := h.algorithm; v != nil {
- return *v
- }
- return ""
- }
- func (h *StandardHeaders) KeyID() string {
- if v := h.keyID; v != nil {
- return *v
- }
- return ""
- }
- func (h *StandardHeaders) KeyType() jwa.KeyType {
- if v := h.keyType; v != nil {
- return *v
- }
- return jwa.InvalidKeyType
- }
- func (h *StandardHeaders) KeyUsage() string {
- if v := h.keyUsage; v != nil {
- return *v
- }
- return ""
- }
- func (h *StandardHeaders) KeyOps() []KeyOperation {
- return h.keyops
- }
- func (h *StandardHeaders) X509CertChain() []*x509.Certificate {
- return h.x509CertChain.Get()
- }
- func (h *StandardHeaders) X509CertThumbprint() string {
- if v := h.x509CertThumbprint; v != nil {
- return *v
- }
- return ""
- }
- func (h *StandardHeaders) X509CertThumbprintS256() string {
- if v := h.x509CertThumbprintS256; v != nil {
- return *v
- }
- return ""
- }
- func (h *StandardHeaders) X509URL() string {
- if v := h.x509URL; v != nil {
- return *v
- }
- return ""
- }
- func (h *StandardHeaders) Get(name string) (interface{}, bool) {
- switch name {
- case AlgorithmKey:
- v := h.algorithm
- if v == nil {
- return nil, false
- }
- return *v, true
- case KeyIDKey:
- v := h.keyID
- if v == nil {
- return nil, false
- }
- return *v, true
- case KeyTypeKey:
- v := h.keyType
- if v == nil {
- return nil, false
- }
- return *v, true
- case KeyUsageKey:
- v := h.keyUsage
- if v == nil {
- return nil, false
- }
- return *v, true
- case KeyOpsKey:
- v := h.keyops
- if len(v) == 0 {
- return nil, false
- }
- return v, true
- case X509CertChainKey:
- v := h.x509CertChain
- if v == nil {
- return nil, false
- }
- return v.Get(), true
- case X509CertThumbprintKey:
- v := h.x509CertThumbprint
- if v == nil {
- return nil, false
- }
- return *v, true
- case X509CertThumbprintS256Key:
- v := h.x509CertThumbprintS256
- if v == nil {
- return nil, false
- }
- return *v, true
- case X509URLKey:
- v := h.x509URL
- if v == nil {
- return nil, false
- }
- return *v, true
- default:
- v, ok := h.privateParams[name]
- return v, ok
- }
- }
- func (h *StandardHeaders) Set(name string, value interface{}) error {
- switch name {
- case AlgorithmKey:
- switch v := value.(type) {
- case string:
- h.algorithm = &v
- return nil
- case fmt.Stringer:
- s := v.String()
- h.algorithm = &s
- return nil
- }
- return errors.Errorf(`invalid value for %s key: %T`, AlgorithmKey, value)
- case KeyIDKey:
- if v, ok := value.(string); ok {
- h.keyID = &v
- return nil
- }
- return errors.Errorf(`invalid value for %s key: %T`, KeyIDKey, value)
- case KeyTypeKey:
- var acceptor jwa.KeyType
- if err := acceptor.Accept(value); err != nil {
- return errors.Wrapf(err, `invalid value for %s key`, KeyTypeKey)
- }
- h.keyType = &acceptor
- return nil
- case KeyUsageKey:
- if v, ok := value.(string); ok {
- h.keyUsage = &v
- return nil
- }
- return errors.Errorf(`invalid value for %s key: %T`, KeyUsageKey, value)
- case KeyOpsKey:
- if v, ok := value.([]KeyOperation); ok {
- h.keyops = v
- return nil
- }
- return errors.Errorf(`invalid value for %s key: %T`, KeyOpsKey, value)
- case X509CertChainKey:
- var acceptor CertificateChain
- if err := acceptor.Accept(value); err != nil {
- return errors.Wrapf(err, `invalid value for %s key`, X509CertChainKey)
- }
- h.x509CertChain = &acceptor
- return nil
- case X509CertThumbprintKey:
- if v, ok := value.(string); ok {
- h.x509CertThumbprint = &v
- return nil
- }
- return errors.Errorf(`invalid value for %s key: %T`, X509CertThumbprintKey, value)
- case X509CertThumbprintS256Key:
- if v, ok := value.(string); ok {
- h.x509CertThumbprintS256 = &v
- return nil
- }
- return errors.Errorf(`invalid value for %s key: %T`, X509CertThumbprintS256Key, value)
- case X509URLKey:
- if v, ok := value.(string); ok {
- h.x509URL = &v
- return nil
- }
- return errors.Errorf(`invalid value for %s key: %T`, X509URLKey, value)
- default:
- if h.privateParams == nil {
- h.privateParams = map[string]interface{}{}
- }
- h.privateParams[name] = value
- }
- return nil
- }
- func (h StandardHeaders) MarshalJSON() ([]byte, error) {
- m := map[string]interface{}{}
- if err := h.PopulateMap(m); err != nil {
- return nil, errors.Wrap(err, `failed to populate map for serialization`)
- }
- return json.Marshal(m)
- }
- // PopulateMap populates a map with appropriate values that represent
- // the headers as a JSON object. This exists primarily because JWKs are
- // represented as flat objects instead of differentiating the different
- // parts of the message in separate sub objects.
- func (h StandardHeaders) PopulateMap(m map[string]interface{}) error {
- for k, v := range h.privateParams {
- m[k] = v
- }
- if v, ok := h.Get(AlgorithmKey); ok {
- m[AlgorithmKey] = v
- }
- if v, ok := h.Get(KeyIDKey); ok {
- m[KeyIDKey] = v
- }
- if v, ok := h.Get(KeyTypeKey); ok {
- m[KeyTypeKey] = v
- }
- if v, ok := h.Get(KeyUsageKey); ok {
- m[KeyUsageKey] = v
- }
- if v, ok := h.Get(KeyOpsKey); ok {
- m[KeyOpsKey] = v
- }
- if v, ok := h.Get(X509CertChainKey); ok {
- m[X509CertChainKey] = v
- }
- if v, ok := h.Get(X509CertThumbprintKey); ok {
- m[X509CertThumbprintKey] = v
- }
- if v, ok := h.Get(X509CertThumbprintS256Key); ok {
- m[X509CertThumbprintS256Key] = v
- }
- if v, ok := h.Get(X509URLKey); ok {
- m[X509URLKey] = v
- }
- return nil
- }
- // ExtractMap populates the appropriate values from a map that represent
- // the headers as a JSON object. This exists primarily because JWKs are
- // represented as flat objects instead of differentiating the different
- // parts of the message in separate sub objects.
- func (h *StandardHeaders) ExtractMap(m map[string]interface{}) (err error) {
- if pdebug.Enabled {
- g := pdebug.Marker(`jwk.StandardHeaders.ExtractMap`).BindError(&err)
- defer g.End()
- }
- if v, ok := m[AlgorithmKey]; ok {
- if err := h.Set(AlgorithmKey, v); err != nil {
- return errors.Wrapf(err, `failed to set value for key %s`, AlgorithmKey)
- }
- }
- if v, ok := m[KeyIDKey]; ok {
- if err := h.Set(KeyIDKey, v); err != nil {
- return errors.Wrapf(err, `failed to set value for key %s`, KeyIDKey)
- }
- }
- if v, ok := m[KeyTypeKey]; ok {
- if err := h.Set(KeyTypeKey, v); err != nil {
- return errors.Wrapf(err, `failed to set value for key %s`, KeyTypeKey)
- }
- }
- if v, ok := m[KeyUsageKey]; ok {
- if err := h.Set(KeyUsageKey, v); err != nil {
- return errors.Wrapf(err, `failed to set value for key %s`, KeyUsageKey)
- }
- }
- if v, ok := m[KeyOpsKey]; ok {
- if err := h.Set(KeyOpsKey, v); err != nil {
- return errors.Wrapf(err, `failed to set value for key %s`, KeyOpsKey)
- }
- }
- if v, ok := m[X509CertChainKey]; ok {
- if err := h.Set(X509CertChainKey, v); err != nil {
- return errors.Wrapf(err, `failed to set value for key %s`, X509CertChainKey)
- }
- }
- if v, ok := m[X509CertThumbprintKey]; ok {
- if err := h.Set(X509CertThumbprintKey, v); err != nil {
- return errors.Wrapf(err, `failed to set value for key %s`, X509CertThumbprintKey)
- }
- }
- if v, ok := m[X509CertThumbprintS256Key]; ok {
- if err := h.Set(X509CertThumbprintS256Key, v); err != nil {
- return errors.Wrapf(err, `failed to set value for key %s`, X509CertThumbprintS256Key)
- }
- }
- if v, ok := m[X509URLKey]; ok {
- if err := h.Set(X509URLKey, v); err != nil {
- return errors.Wrapf(err, `failed to set value for key %s`, X509URLKey)
- }
- }
- h.privateParams = m
- return nil
- }
- func (h *StandardHeaders) UnmarshalJSON(buf []byte) error {
- var m map[string]interface{}
- if err := json.Unmarshal(buf, &m); err != nil {
- return errors.Wrap(err, `failed to unmarshal headers`)
- }
- return h.ExtractMap(m)
- }
- func (h StandardHeaders) Walk(f func(string, interface{}) error) error {
- for _, key := range []string{AlgorithmKey, KeyIDKey, KeyTypeKey, KeyUsageKey, KeyOpsKey, X509CertChainKey, X509CertThumbprintKey, X509CertThumbprintS256Key, X509URLKey} {
- if v, ok := h.Get(key); ok {
- if err := f(key, v); err != nil {
- return errors.Wrapf(err, `walk function returned error for %s`, key)
- }
- }
- }
- for k, v := range h.privateParams {
- if err := f(k, v); err != nil {
- return errors.Wrapf(err, `walk function returned error for %s`, k)
- }
- }
- return nil
- }
|