refresh.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653
  1. package jwk
  2. import (
  3. "context"
  4. "net/http"
  5. "reflect"
  6. "sync"
  7. "time"
  8. "github.com/lestrrat-go/backoff/v2"
  9. "github.com/lestrrat-go/httpcc"
  10. "github.com/pkg/errors"
  11. )
  12. // AutoRefresh is a container that keeps track of jwk.Set object by their source URLs.
  13. // The jwk.Set objects are refreshed automatically behind the scenes.
  14. //
  15. // Before retrieving the jwk.Set objects, the user must pre-register the
  16. // URLs they intend to use by calling `Configure()`
  17. //
  18. // ar := jwk.NewAutoRefresh(ctx)
  19. // ar.Configure(url, options...)
  20. //
  21. // Once registered, you can call `Fetch()` to retrieve the jwk.Set object.
  22. //
  23. // All JWKS objects that are retrieved via the auto-fetch mechanism should be
  24. // treated read-only, as they are shared among the consumers and this object.
  25. type AutoRefresh struct {
  26. errSink chan AutoRefreshError
  27. cache map[string]Set
  28. configureCh chan struct{}
  29. removeCh chan removeReq
  30. fetching map[string]chan struct{}
  31. muErrSink sync.Mutex
  32. muCache sync.RWMutex
  33. muFetching sync.Mutex
  34. muRegistry sync.RWMutex
  35. registry map[string]*target
  36. resetTimerCh chan *resetTimerReq
  37. }
  38. type target struct {
  39. // The backoff policy to use when fetching the JWKS fails
  40. backoff backoff.Policy
  41. // The HTTP client to use. The user may opt to use a client which is
  42. // aware of HTTP caching, or one that goes through a proxy
  43. httpcl HTTPClient
  44. // Interval between refreshes are calculated two ways.
  45. // 1) You can set an explicit refresh interval by using WithRefreshInterval().
  46. // In this mode, it doesn't matter what the HTTP response says in its
  47. // Cache-Control or Expires headers
  48. // 2) You can let us calculate the time-to-refresh based on the key's
  49. // Cache-Control or Expires headers.
  50. // First, the user provides us the absolute minimum interval before
  51. // refreshes. We will never check for refreshes before this specified
  52. // amount of time.
  53. //
  54. // Next, max-age directive in the Cache-Control header is consulted.
  55. // If `max-age` is not present, we skip the following section, and
  56. // proceed to the next option.
  57. // If `max-age > user-supplied minimum interval`, then we use the max-age,
  58. // otherwise the user-supplied minimum interval is used.
  59. //
  60. // Next, the value specified in Expires header is consulted.
  61. // If the header is not present, we skip the following seciont and
  62. // proceed to the next option.
  63. // We take the time until expiration `expires - time.Now()`, and
  64. // if `time-until-expiration > user-supplied minimum interval`, then
  65. // we use the expires value, otherwise the user-supplied minimum interval is used.
  66. //
  67. // If all of the above fails, we used the user-supplied minimum interval
  68. refreshInterval *time.Duration
  69. minRefreshInterval time.Duration
  70. url string
  71. // The timer for refreshing the keyset. should not be set by anyone
  72. // other than the refreshing goroutine
  73. timer *time.Timer
  74. // Semaphore to limit the number of concurrent refreshes in the background
  75. sem chan struct{}
  76. // for debugging, snapshoting
  77. lastRefresh time.Time
  78. nextRefresh time.Time
  79. wl Whitelist
  80. parseOptions []ParseOption
  81. }
  82. type resetTimerReq struct {
  83. t *target
  84. d time.Duration
  85. }
  86. // NewAutoRefresh creates a container that keeps track of JWKS objects which
  87. // are automatically refreshed.
  88. //
  89. // The context object in the argument controls the life-span of the
  90. // auto-refresh worker. If you are using this in a long running process, this
  91. // should mostly be set to a context that ends when the main loop/part of your
  92. // program exits:
  93. //
  94. // func MainLoop() {
  95. // ctx, cancel := context.WithCancel(context.Background())
  96. // defer cancel()
  97. // ar := jwk.AutoRefresh(ctx)
  98. // for ... {
  99. // ...
  100. // }
  101. // }
  102. func NewAutoRefresh(ctx context.Context) *AutoRefresh {
  103. af := &AutoRefresh{
  104. cache: make(map[string]Set),
  105. configureCh: make(chan struct{}),
  106. removeCh: make(chan removeReq),
  107. fetching: make(map[string]chan struct{}),
  108. registry: make(map[string]*target),
  109. resetTimerCh: make(chan *resetTimerReq),
  110. }
  111. go af.refreshLoop(ctx)
  112. return af
  113. }
  114. func (af *AutoRefresh) getCached(url string) (Set, bool) {
  115. af.muCache.RLock()
  116. ks, ok := af.cache[url]
  117. af.muCache.RUnlock()
  118. if ok {
  119. return ks, true
  120. }
  121. return nil, false
  122. }
  // removeReq is sent to the refresh goroutine by Remove(); the outcome of
// the removal is reported back on replyCh.
type removeReq struct {
	url     string     // url to unregister
	replyCh chan error // receives nil on success, or an error
}
  127. // Remove removes `url` from the list of urls being watched by jwk.AutoRefresh.
  128. // If the url is not already registered, returns an error.
  129. func (af *AutoRefresh) Remove(url string) error {
  130. ch := make(chan error)
  131. af.removeCh <- removeReq{replyCh: ch, url: url}
  132. return <-ch
  133. }
  134. // Configure registers the url to be controlled by AutoRefresh, and also
  135. // sets any options associated to it.
  136. //
  137. // Note that options are treated as a whole -- you can't just update
  138. // one value. For example, if you did:
  139. //
  140. // ar.Configure(url, jwk.WithHTTPClient(...))
  141. // ar.Configure(url, jwk.WithRefreshInterval(...))
  142. // The the end result is that `url` is ONLY associated with the options
  143. // given in the second call to `Configure()`, i.e. `jwk.WithRefreshInterval`.
  144. // The other unspecified options, including the HTTP client, is set to
  145. // their default values.
  146. //
  147. // Configuration must propagate between goroutines, and therefore are
  148. // not atomic (But changes should be felt "soon enough" for practical
  149. // purposes)
  150. func (af *AutoRefresh) Configure(url string, options ...AutoRefreshOption) {
  151. var httpcl HTTPClient = http.DefaultClient
  152. var hasRefreshInterval bool
  153. var refreshInterval time.Duration
  154. var wl Whitelist
  155. var parseOptions []ParseOption
  156. minRefreshInterval := time.Hour
  157. bo := backoff.Null()
  158. for _, option := range options {
  159. if v, ok := option.(ParseOption); ok {
  160. parseOptions = append(parseOptions, v)
  161. continue
  162. }
  163. //nolint:forcetypeassert
  164. switch option.Ident() {
  165. case identFetchBackoff{}:
  166. bo = option.Value().(backoff.Policy)
  167. case identRefreshInterval{}:
  168. refreshInterval = option.Value().(time.Duration)
  169. hasRefreshInterval = true
  170. case identMinRefreshInterval{}:
  171. minRefreshInterval = option.Value().(time.Duration)
  172. case identHTTPClient{}:
  173. httpcl = option.Value().(HTTPClient)
  174. case identFetchWhitelist{}:
  175. wl = option.Value().(Whitelist)
  176. }
  177. }
  178. af.muRegistry.Lock()
  179. t, ok := af.registry[url]
  180. if ok {
  181. if t.httpcl != httpcl {
  182. t.httpcl = httpcl
  183. }
  184. if t.minRefreshInterval != minRefreshInterval {
  185. t.minRefreshInterval = minRefreshInterval
  186. }
  187. if t.refreshInterval != nil {
  188. if !hasRefreshInterval {
  189. t.refreshInterval = nil
  190. } else if *t.refreshInterval != refreshInterval {
  191. *t.refreshInterval = refreshInterval
  192. }
  193. } else {
  194. if hasRefreshInterval {
  195. t.refreshInterval = &refreshInterval
  196. }
  197. }
  198. if t.wl != wl {
  199. t.wl = wl
  200. }
  201. t.parseOptions = parseOptions
  202. } else {
  203. t = &target{
  204. backoff: bo,
  205. httpcl: httpcl,
  206. minRefreshInterval: minRefreshInterval,
  207. url: url,
  208. sem: make(chan struct{}, 1),
  209. // This is a placeholder timer so we can call Reset() on it later
  210. // Make it sufficiently in the future so that we don't have bogus
  211. // events firing
  212. timer: time.NewTimer(24 * time.Hour),
  213. wl: wl,
  214. parseOptions: parseOptions,
  215. }
  216. if hasRefreshInterval {
  217. t.refreshInterval = &refreshInterval
  218. }
  219. // Record this in the registry
  220. af.registry[url] = t
  221. }
  222. af.muRegistry.Unlock()
  223. // Tell the backend to reconfigure itself
  224. af.configureCh <- struct{}{}
  225. }
  226. func (af *AutoRefresh) releaseFetching(url string) {
  227. // first delete the entry from the map, then close the channel or
  228. // otherwise we may end up getting multiple groutines doing the fetch
  229. af.muFetching.Lock()
  230. fetchingCh, ok := af.fetching[url]
  231. if !ok {
  232. // Juuuuuuust in case. But shouldn't happen
  233. af.muFetching.Unlock()
  234. return
  235. }
  236. delete(af.fetching, url)
  237. close(fetchingCh)
  238. af.muFetching.Unlock()
  239. }
  240. // IsRegistered checks if `url` is registered already.
  241. func (af *AutoRefresh) IsRegistered(url string) bool {
  242. _, ok := af.getRegistered(url)
  243. return ok
  244. }
  245. // Fetch returns a jwk.Set from the given url.
  246. func (af *AutoRefresh) getRegistered(url string) (*target, bool) {
  247. af.muRegistry.RLock()
  248. t, ok := af.registry[url]
  249. af.muRegistry.RUnlock()
  250. return t, ok
  251. }
  252. // Fetch returns a jwk.Set from the given url.
  253. //
  254. // If it has previously been fetched, then a cached value is returned.
  255. //
  256. // If this the first time `url` was requested, an HTTP request will be
  257. // sent, synchronously.
  258. //
  259. // When accessed via multiple goroutines concurrently, and the cache
  260. // has not been populated yet, only the first goroutine is
  261. // allowed to perform the initialization (HTTP fetch and cache population).
  262. // All other goroutines will be blocked until the operation is completed.
  263. //
  264. // DO NOT modify the jwk.Set object returned by this method, as the
  265. // objects are shared among all consumers and the backend goroutine
  266. func (af *AutoRefresh) Fetch(ctx context.Context, url string) (Set, error) {
  267. if _, ok := af.getRegistered(url); !ok {
  268. return nil, errors.Errorf(`url %s must be configured using "Configure()" first`, url)
  269. }
  270. ks, found := af.getCached(url)
  271. if found {
  272. return ks, nil
  273. }
  274. return af.refresh(ctx, url)
  275. }
  276. // Refresh is the same as Fetch(), except that HTTP fetching is done synchronously.
  277. //
  278. // This is useful when you want to force an HTTP fetch instead of waiting
  279. // for the background goroutine to do it, for example when you want to
  280. // make sure the AutoRefresh cache is warmed up before starting your main loop
  281. func (af *AutoRefresh) Refresh(ctx context.Context, url string) (Set, error) {
  282. if _, ok := af.getRegistered(url); !ok {
  283. return nil, errors.Errorf(`url %s must be configured using "Configure()" first`, url)
  284. }
  285. return af.refresh(ctx, url)
  286. }
  287. func (af *AutoRefresh) refresh(ctx context.Context, url string) (Set, error) {
  288. // To avoid a thundering herd, only one goroutine per url may enter into this
  289. // initial fetch phase.
  290. af.muFetching.Lock()
  291. fetchingCh, fetching := af.fetching[url]
  292. // unlock happens in each of the if/else clauses because we need to perform
  293. // the channel initialization when there is no channel present
  294. if fetching {
  295. af.muFetching.Unlock()
  296. select {
  297. case <-ctx.Done():
  298. return nil, ctx.Err()
  299. case <-fetchingCh:
  300. }
  301. } else {
  302. fetchingCh = make(chan struct{})
  303. af.fetching[url] = fetchingCh
  304. af.muFetching.Unlock()
  305. // Register a cleanup handler, to make sure we always
  306. defer af.releaseFetching(url)
  307. // The first time around, we need to fetch the keyset
  308. if err := af.doRefreshRequest(ctx, url, false); err != nil {
  309. return nil, errors.Wrapf(err, `failed to fetch resource pointed by %s`, url)
  310. }
  311. }
  312. // the cache should now be populated
  313. ks, ok := af.getCached(url)
  314. if !ok {
  315. return nil, errors.New("cache was not populated after explicit refresh")
  316. }
  317. return ks, nil
  318. }
  319. // Keeps looping, while refreshing the KeySet.
  320. func (af *AutoRefresh) refreshLoop(ctx context.Context) {
  321. // reflect.Select() is slow IF we are executing it over and over
  322. // in a very fast iteration, but we assume here that refreshes happen
  323. // seldom enough that being able to call one `select{}` with multiple
  324. // targets / channels outweighs the speed penalty of using reflect.
  325. //
  326. const (
  327. ctxDoneIdx = iota
  328. configureIdx
  329. resetTimerIdx
  330. removeIdx
  331. baseSelcasesLen
  332. )
  333. baseSelcases := make([]reflect.SelectCase, baseSelcasesLen)
  334. baseSelcases[ctxDoneIdx] = reflect.SelectCase{
  335. Dir: reflect.SelectRecv,
  336. Chan: reflect.ValueOf(ctx.Done()),
  337. }
  338. baseSelcases[configureIdx] = reflect.SelectCase{
  339. Dir: reflect.SelectRecv,
  340. Chan: reflect.ValueOf(af.configureCh),
  341. }
  342. baseSelcases[resetTimerIdx] = reflect.SelectCase{
  343. Dir: reflect.SelectRecv,
  344. Chan: reflect.ValueOf(af.resetTimerCh),
  345. }
  346. baseSelcases[removeIdx] = reflect.SelectCase{
  347. Dir: reflect.SelectRecv,
  348. Chan: reflect.ValueOf(af.removeCh),
  349. }
  350. var targets []*target
  351. var selcases []reflect.SelectCase
  352. for {
  353. // It seems silly, but it's much easier to keep track of things
  354. // if we re-build the select cases every iteration
  355. af.muRegistry.RLock()
  356. if cap(targets) < len(af.registry) {
  357. targets = make([]*target, 0, len(af.registry))
  358. } else {
  359. targets = targets[:0]
  360. }
  361. if cap(selcases) < len(af.registry) {
  362. selcases = make([]reflect.SelectCase, 0, len(af.registry)+baseSelcasesLen)
  363. } else {
  364. selcases = selcases[:0]
  365. }
  366. selcases = append(selcases, baseSelcases...)
  367. for _, data := range af.registry {
  368. targets = append(targets, data)
  369. selcases = append(selcases, reflect.SelectCase{
  370. Dir: reflect.SelectRecv,
  371. Chan: reflect.ValueOf(data.timer.C),
  372. })
  373. }
  374. af.muRegistry.RUnlock()
  375. chosen, recv, recvOK := reflect.Select(selcases)
  376. switch chosen {
  377. case ctxDoneIdx:
  378. // <-ctx.Done(). Just bail out of this loop
  379. return
  380. case configureIdx:
  381. // <-configureCh. rebuild the select list from the registry.
  382. // since we're rebuilding everything for each iteration,
  383. // we just need to start the loop all over again
  384. continue
  385. case resetTimerIdx:
  386. // <-resetTimerCh. interrupt polling, and reset the timer on
  387. // a single target. this needs to be handled inside this select
  388. if !recvOK {
  389. continue
  390. }
  391. req := recv.Interface().(*resetTimerReq) //nolint:forcetypeassert
  392. t := req.t
  393. d := req.d
  394. if !t.timer.Stop() {
  395. select {
  396. case <-t.timer.C:
  397. default:
  398. }
  399. }
  400. t.timer.Reset(d)
  401. case removeIdx:
  402. // <-removeCh. remove the URL from future fetching
  403. //nolint:forcetypeassert
  404. req := recv.Interface().(removeReq)
  405. replyCh := req.replyCh
  406. url := req.url
  407. af.muRegistry.Lock()
  408. if _, ok := af.registry[url]; !ok {
  409. replyCh <- errors.Errorf(`invalid url %q (not registered)`, url)
  410. } else {
  411. delete(af.registry, url)
  412. replyCh <- nil
  413. }
  414. af.muRegistry.Unlock()
  415. default:
  416. // Do not fire a refresh in case the channel was closed.
  417. if !recvOK {
  418. continue
  419. }
  420. // Time to refresh a target
  421. t := targets[chosen-baseSelcasesLen]
  422. // Check if there are other goroutines still doing the refresh asynchronously.
  423. // This could happen if the refreshing goroutine is stuck on a backoff
  424. // waiting for the HTTP request to complete.
  425. select {
  426. case t.sem <- struct{}{}:
  427. // There can only be one refreshing goroutine
  428. default:
  429. continue
  430. }
  431. go func() {
  432. //nolint:errcheck
  433. af.doRefreshRequest(ctx, t.url, true)
  434. <-t.sem
  435. }()
  436. }
  437. }
  438. }
  439. func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableBackoff bool) error {
  440. af.muRegistry.RLock()
  441. t, ok := af.registry[url]
  442. if !ok {
  443. af.muRegistry.RUnlock()
  444. return errors.Errorf(`url "%s" is not registered`, url)
  445. }
  446. // In case the refresh fails due to errors in fetching/parsing the JWKS,
  447. // we want to retry. Create a backoff object,
  448. parseOptions := t.parseOptions
  449. fetchOptions := []FetchOption{WithHTTPClient(t.httpcl)}
  450. if enableBackoff {
  451. fetchOptions = append(fetchOptions, WithFetchBackoff(t.backoff))
  452. }
  453. if t.wl != nil {
  454. fetchOptions = append(fetchOptions, WithFetchWhitelist(t.wl))
  455. }
  456. af.muRegistry.RUnlock()
  457. res, err := fetch(ctx, url, fetchOptions...)
  458. if err == nil {
  459. if res.StatusCode != http.StatusOK {
  460. // now, can there be a remote resource that responds with a status code
  461. // other than 200 and still be valid...? naaaaaaahhhhhh....
  462. err = errors.Errorf(`bad response status code (%d)`, res.StatusCode)
  463. } else {
  464. defer res.Body.Close()
  465. keyset, parseErr := ParseReader(res.Body, parseOptions...)
  466. if parseErr == nil {
  467. // Got a new key set. replace the keyset in the target
  468. af.muCache.Lock()
  469. af.cache[url] = keyset
  470. af.muCache.Unlock()
  471. af.muRegistry.RLock()
  472. nextInterval := calculateRefreshDuration(res, t.refreshInterval, t.minRefreshInterval)
  473. af.muRegistry.RUnlock()
  474. rtr := &resetTimerReq{
  475. t: t,
  476. d: nextInterval,
  477. }
  478. select {
  479. case <-ctx.Done():
  480. return ctx.Err()
  481. case af.resetTimerCh <- rtr:
  482. }
  483. now := time.Now()
  484. af.muRegistry.Lock()
  485. t.lastRefresh = now.Local()
  486. t.nextRefresh = now.Add(nextInterval).Local()
  487. af.muRegistry.Unlock()
  488. return nil
  489. }
  490. err = parseErr
  491. }
  492. }
  493. // At this point if err != nil, we know that there was something wrong
  494. // in either the fetching or the parsing. Send this error to be processed,
  495. // but take the extra mileage to not block regular processing by
  496. // discarding the error if we fail to send it through the channel
  497. if err != nil {
  498. select {
  499. case af.errSink <- AutoRefreshError{Error: err, URL: url}:
  500. default:
  501. }
  502. }
  503. // We either failed to perform the HTTP GET, or we failed to parse the
  504. // JWK set. Even in case of errors, we don't delete the old key.
  505. // We persist the old key set, even if it may be stale so the user has something to work with
  506. // TODO: maybe this behavior should be customizable?
  507. // If we failed to get a single time, then queue another fetch in the future.
  508. rtr := &resetTimerReq{
  509. t: t,
  510. d: calculateRefreshDuration(res, t.refreshInterval, t.minRefreshInterval),
  511. }
  512. select {
  513. case <-ctx.Done():
  514. return ctx.Err()
  515. case af.resetTimerCh <- rtr:
  516. }
  517. return err
  518. }
  519. // ErrorSink sets a channel to receive JWK fetch errors, if any.
  520. // Only the errors that occurred *after* the channel was set will be sent.
  521. //
  522. // The user is responsible for properly draining the channel. If the channel
  523. // is not drained properly, errors will be discarded.
  524. //
  525. // To disable, set a nil channel.
  526. func (af *AutoRefresh) ErrorSink(ch chan AutoRefreshError) {
  527. af.muErrSink.Lock()
  528. af.errSink = ch
  529. af.muErrSink.Unlock()
  530. }
  531. func calculateRefreshDuration(res *http.Response, refreshInterval *time.Duration, minRefreshInterval time.Duration) time.Duration {
  532. // This always has precedence
  533. if refreshInterval != nil {
  534. return *refreshInterval
  535. }
  536. if res != nil {
  537. if v := res.Header.Get(`Cache-Control`); v != "" {
  538. dir, err := httpcc.ParseResponse(v)
  539. if err == nil {
  540. maxAge, ok := dir.MaxAge()
  541. if ok {
  542. resDuration := time.Duration(maxAge) * time.Second
  543. if resDuration > minRefreshInterval {
  544. return resDuration
  545. }
  546. return minRefreshInterval
  547. }
  548. // fallthrough
  549. }
  550. // fallthrough
  551. }
  552. if v := res.Header.Get(`Expires`); v != "" {
  553. expires, err := http.ParseTime(v)
  554. if err == nil {
  555. resDuration := time.Until(expires)
  556. if resDuration > minRefreshInterval {
  557. return resDuration
  558. }
  559. return minRefreshInterval
  560. }
  561. // fallthrough
  562. }
  563. }
  564. // Previous fallthroughs are a little redandunt, but hey, it's all good.
  565. return minRefreshInterval
  566. }
  // TargetSnapshot describes the refresh state of one url configured in
// AutoRefresh. It is the element type yielded by the Snapshot method.
type TargetSnapshot struct {
	// URL is the configured url.
	URL string
	// NextRefresh and LastRefresh are, respectively, the time the next
	// refresh is scheduled for and the time of the last successful refresh.
	// Both are zero until the first successful refresh.
	NextRefresh, LastRefresh time.Time
}
  575. func (af *AutoRefresh) Snapshot() <-chan TargetSnapshot {
  576. af.muRegistry.Lock()
  577. ch := make(chan TargetSnapshot, len(af.registry))
  578. for url, t := range af.registry {
  579. ch <- TargetSnapshot{
  580. URL: url,
  581. NextRefresh: t.nextRefresh,
  582. LastRefresh: t.lastRefresh,
  583. }
  584. }
  585. af.muRegistry.Unlock()
  586. close(ch)
  587. return ch
  588. }