set.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. package jwk
  2. import (
  3. "bytes"
  4. "context"
  5. "fmt"
  6. "sort"
  7. "github.com/lestrrat-go/iter/arrayiter"
  8. "github.com/lestrrat-go/jwx/internal/json"
  9. "github.com/lestrrat-go/jwx/internal/pool"
  10. "github.com/pkg/errors"
  11. )
  12. const keysKey = `keys` // appease linter
  13. // NewSet creates and empty `jwk.Set` object
  14. func NewSet() Set {
  15. return &set{
  16. privateParams: make(map[string]interface{}),
  17. }
  18. }
  19. func (s *set) Set(n string, v interface{}) error {
  20. s.mu.RLock()
  21. defer s.mu.RUnlock()
  22. if n == keysKey {
  23. vl, ok := v.([]Key)
  24. if !ok {
  25. return errors.Errorf(`value for field "keys" must be []jwk.Key`)
  26. }
  27. s.keys = vl
  28. return nil
  29. }
  30. s.privateParams[n] = v
  31. return nil
  32. }
  33. func (s *set) Field(n string) (interface{}, bool) {
  34. s.mu.RLock()
  35. defer s.mu.RUnlock()
  36. v, ok := s.privateParams[n]
  37. return v, ok
  38. }
  39. func (s *set) Get(idx int) (Key, bool) {
  40. s.mu.RLock()
  41. defer s.mu.RUnlock()
  42. if idx >= 0 && idx < len(s.keys) {
  43. return s.keys[idx], true
  44. }
  45. return nil, false
  46. }
  47. func (s *set) Len() int {
  48. s.mu.RLock()
  49. defer s.mu.RUnlock()
  50. return len(s.keys)
  51. }
  52. // indexNL is Index(), but without the locking
  53. func (s *set) indexNL(key Key) int {
  54. for i, k := range s.keys {
  55. if k == key {
  56. return i
  57. }
  58. }
  59. return -1
  60. }
  61. func (s *set) Index(key Key) int {
  62. s.mu.RLock()
  63. defer s.mu.RUnlock()
  64. return s.indexNL(key)
  65. }
  66. func (s *set) Add(key Key) bool {
  67. s.mu.Lock()
  68. defer s.mu.Unlock()
  69. if i := s.indexNL(key); i > -1 {
  70. return false
  71. }
  72. s.keys = append(s.keys, key)
  73. return true
  74. }
  75. func (s *set) Remove(key Key) bool {
  76. s.mu.Lock()
  77. defer s.mu.Unlock()
  78. for i, k := range s.keys {
  79. if k == key {
  80. switch i {
  81. case 0:
  82. s.keys = s.keys[1:]
  83. case len(s.keys) - 1:
  84. s.keys = s.keys[:i]
  85. default:
  86. s.keys = append(s.keys[:i], s.keys[i+1:]...)
  87. }
  88. return true
  89. }
  90. }
  91. return false
  92. }
  93. func (s *set) Clear() {
  94. s.mu.Lock()
  95. defer s.mu.Unlock()
  96. s.keys = nil
  97. }
  98. func (s *set) Iterate(ctx context.Context) KeyIterator {
  99. ch := make(chan *KeyPair, s.Len())
  100. go iterate(ctx, s.keys, ch)
  101. return arrayiter.New(ch)
  102. }
  103. func iterate(ctx context.Context, keys []Key, ch chan *KeyPair) {
  104. defer close(ch)
  105. for i, key := range keys {
  106. pair := &KeyPair{Index: i, Value: key}
  107. select {
  108. case <-ctx.Done():
  109. return
  110. case ch <- pair:
  111. }
  112. }
  113. }
  114. func (s *set) MarshalJSON() ([]byte, error) {
  115. s.mu.RLock()
  116. defer s.mu.RUnlock()
  117. buf := pool.GetBytesBuffer()
  118. defer pool.ReleaseBytesBuffer(buf)
  119. enc := json.NewEncoder(buf)
  120. fields := []string{keysKey}
  121. for k := range s.privateParams {
  122. fields = append(fields, k)
  123. }
  124. sort.Strings(fields)
  125. buf.WriteByte('{')
  126. for i, field := range fields {
  127. if i > 0 {
  128. buf.WriteByte(',')
  129. }
  130. fmt.Fprintf(buf, `%q:`, field)
  131. if field != keysKey {
  132. if err := enc.Encode(s.privateParams[field]); err != nil {
  133. return nil, errors.Wrapf(err, `failed to marshal field %q`, field)
  134. }
  135. } else {
  136. buf.WriteByte('[')
  137. for j, k := range s.keys {
  138. if j > 0 {
  139. buf.WriteByte(',')
  140. }
  141. if err := enc.Encode(k); err != nil {
  142. return nil, errors.Wrapf(err, `failed to marshal key #%d`, i)
  143. }
  144. }
  145. buf.WriteByte(']')
  146. }
  147. }
  148. buf.WriteByte('}')
  149. ret := make([]byte, buf.Len())
  150. copy(ret, buf.Bytes())
  151. return ret, nil
  152. }
  153. func (s *set) UnmarshalJSON(data []byte) error {
  154. s.mu.Lock()
  155. defer s.mu.Unlock()
  156. s.privateParams = make(map[string]interface{})
  157. s.keys = nil
  158. var options []ParseOption
  159. var ignoreParseError bool
  160. if dc := s.dc; dc != nil {
  161. if localReg := dc.Registry(); localReg != nil {
  162. options = append(options, withLocalRegistry(localReg))
  163. }
  164. ignoreParseError = dc.IgnoreParseError()
  165. }
  166. var sawKeysField bool
  167. dec := json.NewDecoder(bytes.NewReader(data))
  168. LOOP:
  169. for {
  170. tok, err := dec.Token()
  171. if err != nil {
  172. return errors.Wrap(err, `error reading token`)
  173. }
  174. switch tok := tok.(type) {
  175. case json.Delim:
  176. // Assuming we're doing everything correctly, we should ONLY
  177. // get either '{' or '}' here.
  178. if tok == '}' { // End of object
  179. break LOOP
  180. } else if tok != '{' {
  181. return errors.Errorf(`expected '{', but got '%c'`, tok)
  182. }
  183. case string:
  184. switch tok {
  185. case "keys":
  186. sawKeysField = true
  187. var list []json.RawMessage
  188. if err := dec.Decode(&list); err != nil {
  189. return errors.Wrap(err, `failed to decode "keys"`)
  190. }
  191. for i, keysrc := range list {
  192. key, err := ParseKey(keysrc, options...)
  193. if err != nil {
  194. if !ignoreParseError {
  195. return errors.Wrapf(err, `failed to decode key #%d in "keys"`, i)
  196. }
  197. continue
  198. }
  199. s.keys = append(s.keys, key)
  200. }
  201. default:
  202. var v interface{}
  203. if err := dec.Decode(&v); err != nil {
  204. return errors.Wrapf(err, `failed to decode value for key %q`, tok)
  205. }
  206. s.privateParams[tok] = v
  207. }
  208. }
  209. }
  210. // This is really silly, but we can only detect the
  211. // lack of the "keys" field after going through the
  212. // entire object once
  213. // Not checking for len(s.keys) == 0, because it could be
  214. // an empty key set
  215. if !sawKeysField {
  216. key, err := ParseKey(data, options...)
  217. if err != nil {
  218. return errors.Wrapf(err, `failed to parse sole key in key set`)
  219. }
  220. s.keys = append(s.keys, key)
  221. }
  222. return nil
  223. }
  224. func (s *set) LookupKeyID(kid string) (Key, bool) {
  225. s.mu.RLock()
  226. defer s.mu.RUnlock()
  227. n := s.Len()
  228. for i := 0; i < n; i++ {
  229. key, ok := s.Get(i)
  230. if !ok {
  231. return nil, false
  232. }
  233. if key.KeyID() == kid {
  234. return key, true
  235. }
  236. }
  237. return nil, false
  238. }
  239. func (s *set) DecodeCtx() DecodeCtx {
  240. s.mu.RLock()
  241. defer s.mu.RUnlock()
  242. return s.dc
  243. }
  244. func (s *set) SetDecodeCtx(dc DecodeCtx) {
  245. s.mu.Lock()
  246. defer s.mu.Unlock()
  247. s.dc = dc
  248. }
  249. func (s *set) Clone() (Set, error) {
  250. s2 := &set{}
  251. s.mu.RLock()
  252. defer s.mu.RUnlock()
  253. s2.keys = make([]Key, len(s.keys))
  254. for i := 0; i < len(s.keys); i++ {
  255. s2.keys[i] = s.keys[i]
  256. }
  257. return s2, nil
  258. }