headers.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. package jwk
  2. import (
  3. "crypto/x509"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/lestrrat/go-jwx/jwa"
  7. "github.com/lestrrat/go-pdebug"
  8. "github.com/pkg/errors"
  9. )
  // JSON member names of the key parameters that StandardHeaders stores in
  // dedicated fields. They are listed in the order RFC 7517 section 4
  // defines them.
  const (
  	// KeyTypeKey names the "kty" (key type) parameter.
  	KeyTypeKey = "kty"
  	// KeyUsageKey names the "use" (public key use) parameter.
  	KeyUsageKey = "use"
  	// KeyOpsKey names the "key_ops" (key operations) parameter.
  	KeyOpsKey = "key_ops"
  	// AlgorithmKey names the "alg" (algorithm) parameter.
  	AlgorithmKey = "alg"
  	// KeyIDKey names the "kid" (key ID) parameter.
  	KeyIDKey = "kid"
  	// X509URLKey names the "x5u" (X.509 URL) parameter.
  	X509URLKey = "x5u"
  	// X509CertChainKey names the "x5c" (X.509 certificate chain) parameter.
  	X509CertChainKey = "x5c"
  	// X509CertThumbprintKey names the "x5t" (X.509 certificate SHA-1
  	// thumbprint) parameter.
  	X509CertThumbprintKey = "x5t"
  	// X509CertThumbprintS256Key names the "x5t#S256" (X.509 certificate
  	// SHA-256 thumbprint) parameter.
  	X509CertThumbprintS256Key = "x5t#S256"
  )
  21. type Headers interface {
  22. Remove(string)
  23. Get(string) (interface{}, bool)
  24. Set(string, interface{}) error
  25. PopulateMap(map[string]interface{}) error
  26. ExtractMap(map[string]interface{}) error
  27. Walk(func(string, interface{}) error) error
  28. Algorithm() string
  29. KeyID() string
  30. KeyType() jwa.KeyType
  31. KeyUsage() string
  32. KeyOps() []KeyOperation
  33. X509CertChain() []*x509.Certificate
  34. X509CertThumbprint() string
  35. X509CertThumbprintS256() string
  36. X509URL() string
  37. }
  38. type StandardHeaders struct {
  39. algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4
  40. keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4
  41. keyType *jwa.KeyType // https://tools.ietf.org/html/rfc7517#section-4.1
  42. keyUsage *string // https://tools.ietf.org/html/rfc7517#section-4.2
  43. keyops []KeyOperation // https://tools.ietf.org/html/rfc7517#section-4.3
  44. x509CertChain *CertificateChain // https://tools.ietf.org/html/rfc7515#section-4.1.6
  45. x509CertThumbprint *string // https://tools.ietf.org/html/rfc7515#section-4.1.7
  46. x509CertThumbprintS256 *string // https://tools.ietf.org/html/rfc7515#section-4.1.8
  47. x509URL *string // https://tools.ietf.org/html/rfc7515#section-4.1.5
  48. privateParams map[string]interface{}
  49. }
  50. func (h *StandardHeaders) Remove(s string) {
  51. delete(h.privateParams, s)
  52. }
  53. func (h *StandardHeaders) Algorithm() string {
  54. if v := h.algorithm; v != nil {
  55. return *v
  56. }
  57. return ""
  58. }
  59. func (h *StandardHeaders) KeyID() string {
  60. if v := h.keyID; v != nil {
  61. return *v
  62. }
  63. return ""
  64. }
  65. func (h *StandardHeaders) KeyType() jwa.KeyType {
  66. if v := h.keyType; v != nil {
  67. return *v
  68. }
  69. return jwa.InvalidKeyType
  70. }
  71. func (h *StandardHeaders) KeyUsage() string {
  72. if v := h.keyUsage; v != nil {
  73. return *v
  74. }
  75. return ""
  76. }
  77. func (h *StandardHeaders) KeyOps() []KeyOperation {
  78. return h.keyops
  79. }
  80. func (h *StandardHeaders) X509CertChain() []*x509.Certificate {
  81. return h.x509CertChain.Get()
  82. }
  83. func (h *StandardHeaders) X509CertThumbprint() string {
  84. if v := h.x509CertThumbprint; v != nil {
  85. return *v
  86. }
  87. return ""
  88. }
  89. func (h *StandardHeaders) X509CertThumbprintS256() string {
  90. if v := h.x509CertThumbprintS256; v != nil {
  91. return *v
  92. }
  93. return ""
  94. }
  95. func (h *StandardHeaders) X509URL() string {
  96. if v := h.x509URL; v != nil {
  97. return *v
  98. }
  99. return ""
  100. }
  101. func (h *StandardHeaders) Get(name string) (interface{}, bool) {
  102. switch name {
  103. case AlgorithmKey:
  104. v := h.algorithm
  105. if v == nil {
  106. return nil, false
  107. }
  108. return *v, true
  109. case KeyIDKey:
  110. v := h.keyID
  111. if v == nil {
  112. return nil, false
  113. }
  114. return *v, true
  115. case KeyTypeKey:
  116. v := h.keyType
  117. if v == nil {
  118. return nil, false
  119. }
  120. return *v, true
  121. case KeyUsageKey:
  122. v := h.keyUsage
  123. if v == nil {
  124. return nil, false
  125. }
  126. return *v, true
  127. case KeyOpsKey:
  128. v := h.keyops
  129. if len(v) == 0 {
  130. return nil, false
  131. }
  132. return v, true
  133. case X509CertChainKey:
  134. v := h.x509CertChain
  135. if v == nil {
  136. return nil, false
  137. }
  138. return v.Get(), true
  139. case X509CertThumbprintKey:
  140. v := h.x509CertThumbprint
  141. if v == nil {
  142. return nil, false
  143. }
  144. return *v, true
  145. case X509CertThumbprintS256Key:
  146. v := h.x509CertThumbprintS256
  147. if v == nil {
  148. return nil, false
  149. }
  150. return *v, true
  151. case X509URLKey:
  152. v := h.x509URL
  153. if v == nil {
  154. return nil, false
  155. }
  156. return *v, true
  157. default:
  158. v, ok := h.privateParams[name]
  159. return v, ok
  160. }
  161. }
  162. func (h *StandardHeaders) Set(name string, value interface{}) error {
  163. switch name {
  164. case AlgorithmKey:
  165. switch v := value.(type) {
  166. case string:
  167. h.algorithm = &v
  168. return nil
  169. case fmt.Stringer:
  170. s := v.String()
  171. h.algorithm = &s
  172. return nil
  173. }
  174. return errors.Errorf(`invalid value for %s key: %T`, AlgorithmKey, value)
  175. case KeyIDKey:
  176. if v, ok := value.(string); ok {
  177. h.keyID = &v
  178. return nil
  179. }
  180. return errors.Errorf(`invalid value for %s key: %T`, KeyIDKey, value)
  181. case KeyTypeKey:
  182. var acceptor jwa.KeyType
  183. if err := acceptor.Accept(value); err != nil {
  184. return errors.Wrapf(err, `invalid value for %s key`, KeyTypeKey)
  185. }
  186. h.keyType = &acceptor
  187. return nil
  188. case KeyUsageKey:
  189. if v, ok := value.(string); ok {
  190. h.keyUsage = &v
  191. return nil
  192. }
  193. return errors.Errorf(`invalid value for %s key: %T`, KeyUsageKey, value)
  194. case KeyOpsKey:
  195. if v, ok := value.([]KeyOperation); ok {
  196. h.keyops = v
  197. return nil
  198. }
  199. return errors.Errorf(`invalid value for %s key: %T`, KeyOpsKey, value)
  200. case X509CertChainKey:
  201. var acceptor CertificateChain
  202. if err := acceptor.Accept(value); err != nil {
  203. return errors.Wrapf(err, `invalid value for %s key`, X509CertChainKey)
  204. }
  205. h.x509CertChain = &acceptor
  206. return nil
  207. case X509CertThumbprintKey:
  208. if v, ok := value.(string); ok {
  209. h.x509CertThumbprint = &v
  210. return nil
  211. }
  212. return errors.Errorf(`invalid value for %s key: %T`, X509CertThumbprintKey, value)
  213. case X509CertThumbprintS256Key:
  214. if v, ok := value.(string); ok {
  215. h.x509CertThumbprintS256 = &v
  216. return nil
  217. }
  218. return errors.Errorf(`invalid value for %s key: %T`, X509CertThumbprintS256Key, value)
  219. case X509URLKey:
  220. if v, ok := value.(string); ok {
  221. h.x509URL = &v
  222. return nil
  223. }
  224. return errors.Errorf(`invalid value for %s key: %T`, X509URLKey, value)
  225. default:
  226. if h.privateParams == nil {
  227. h.privateParams = map[string]interface{}{}
  228. }
  229. h.privateParams[name] = value
  230. }
  231. return nil
  232. }
  233. func (h StandardHeaders) MarshalJSON() ([]byte, error) {
  234. m := map[string]interface{}{}
  235. if err := h.PopulateMap(m); err != nil {
  236. return nil, errors.Wrap(err, `failed to populate map for serialization`)
  237. }
  238. return json.Marshal(m)
  239. }
  240. // PopulateMap populates a map with appropriate values that represent
  241. // the headers as a JSON object. This exists primarily because JWKs are
  242. // represented as flat objects instead of differentiating the different
  243. // parts of the message in separate sub objects.
  244. func (h StandardHeaders) PopulateMap(m map[string]interface{}) error {
  245. for k, v := range h.privateParams {
  246. m[k] = v
  247. }
  248. if v, ok := h.Get(AlgorithmKey); ok {
  249. m[AlgorithmKey] = v
  250. }
  251. if v, ok := h.Get(KeyIDKey); ok {
  252. m[KeyIDKey] = v
  253. }
  254. if v, ok := h.Get(KeyTypeKey); ok {
  255. m[KeyTypeKey] = v
  256. }
  257. if v, ok := h.Get(KeyUsageKey); ok {
  258. m[KeyUsageKey] = v
  259. }
  260. if v, ok := h.Get(KeyOpsKey); ok {
  261. m[KeyOpsKey] = v
  262. }
  263. if v, ok := h.Get(X509CertChainKey); ok {
  264. m[X509CertChainKey] = v
  265. }
  266. if v, ok := h.Get(X509CertThumbprintKey); ok {
  267. m[X509CertThumbprintKey] = v
  268. }
  269. if v, ok := h.Get(X509CertThumbprintS256Key); ok {
  270. m[X509CertThumbprintS256Key] = v
  271. }
  272. if v, ok := h.Get(X509URLKey); ok {
  273. m[X509URLKey] = v
  274. }
  275. return nil
  276. }
  277. // ExtractMap populates the appropriate values from a map that represent
  278. // the headers as a JSON object. This exists primarily because JWKs are
  279. // represented as flat objects instead of differentiating the different
  280. // parts of the message in separate sub objects.
  281. func (h *StandardHeaders) ExtractMap(m map[string]interface{}) (err error) {
  282. if pdebug.Enabled {
  283. g := pdebug.Marker(`jwk.StandardHeaders.ExtractMap`).BindError(&err)
  284. defer g.End()
  285. }
  286. if v, ok := m[AlgorithmKey]; ok {
  287. if err := h.Set(AlgorithmKey, v); err != nil {
  288. return errors.Wrapf(err, `failed to set value for key %s`, AlgorithmKey)
  289. }
  290. }
  291. if v, ok := m[KeyIDKey]; ok {
  292. if err := h.Set(KeyIDKey, v); err != nil {
  293. return errors.Wrapf(err, `failed to set value for key %s`, KeyIDKey)
  294. }
  295. }
  296. if v, ok := m[KeyTypeKey]; ok {
  297. if err := h.Set(KeyTypeKey, v); err != nil {
  298. return errors.Wrapf(err, `failed to set value for key %s`, KeyTypeKey)
  299. }
  300. }
  301. if v, ok := m[KeyUsageKey]; ok {
  302. if err := h.Set(KeyUsageKey, v); err != nil {
  303. return errors.Wrapf(err, `failed to set value for key %s`, KeyUsageKey)
  304. }
  305. }
  306. if v, ok := m[KeyOpsKey]; ok {
  307. if err := h.Set(KeyOpsKey, v); err != nil {
  308. return errors.Wrapf(err, `failed to set value for key %s`, KeyOpsKey)
  309. }
  310. }
  311. if v, ok := m[X509CertChainKey]; ok {
  312. if err := h.Set(X509CertChainKey, v); err != nil {
  313. return errors.Wrapf(err, `failed to set value for key %s`, X509CertChainKey)
  314. }
  315. }
  316. if v, ok := m[X509CertThumbprintKey]; ok {
  317. if err := h.Set(X509CertThumbprintKey, v); err != nil {
  318. return errors.Wrapf(err, `failed to set value for key %s`, X509CertThumbprintKey)
  319. }
  320. }
  321. if v, ok := m[X509CertThumbprintS256Key]; ok {
  322. if err := h.Set(X509CertThumbprintS256Key, v); err != nil {
  323. return errors.Wrapf(err, `failed to set value for key %s`, X509CertThumbprintS256Key)
  324. }
  325. }
  326. if v, ok := m[X509URLKey]; ok {
  327. if err := h.Set(X509URLKey, v); err != nil {
  328. return errors.Wrapf(err, `failed to set value for key %s`, X509URLKey)
  329. }
  330. }
  331. h.privateParams = m
  332. return nil
  333. }
  334. func (h *StandardHeaders) UnmarshalJSON(buf []byte) error {
  335. var m map[string]interface{}
  336. if err := json.Unmarshal(buf, &m); err != nil {
  337. return errors.Wrap(err, `failed to unmarshal headers`)
  338. }
  339. return h.ExtractMap(m)
  340. }
  341. func (h StandardHeaders) Walk(f func(string, interface{}) error) error {
  342. for _, key := range []string{AlgorithmKey, KeyIDKey, KeyTypeKey, KeyUsageKey, KeyOpsKey, X509CertChainKey, X509CertThumbprintKey, X509CertThumbprintS256Key, X509URLKey} {
  343. if v, ok := h.Get(key); ok {
  344. if err := f(key, v); err != nil {
  345. return errors.Wrapf(err, `walk function returned error for %s`, key)
  346. }
  347. }
  348. }
  349. for k, v := range h.privateParams {
  350. if err := f(k, v); err != nil {
  351. return errors.Wrapf(err, `walk function returned error for %s`, k)
  352. }
  353. }
  354. return nil
  355. }