limiter.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. // Unless explicitly stated otherwise all files in this repository are licensed
  2. // under the Apache License Version 2.0.
  3. // This product includes software developed at Datadog (https://www.datadoghq.com/).
  4. // Copyright 2022 Datadog, Inc.
  5. //go:build appsec
  6. // +build appsec
  7. package appsec
  8. import (
  9. "sync/atomic"
  10. "time"
  11. )
  12. // Limiter is used to abstract the rate limiter implementation to only expose the needed function for rate limiting.
  13. // This is for example useful for testing, allowing us to use a modified rate limiter tuned for testing through the same
  14. // interface.
  15. type Limiter interface {
  16. Allow() bool
  17. }
  18. // TokenTicker is a thread-safe and lock-free rate limiter based on a token bucket.
  19. // The idea is to have a goroutine that will update the bucket with fresh tokens at regular intervals using a time.Ticker.
  20. // The advantage of using a goroutine here is that the implementation becomes easily thread-safe using a few
  21. // atomic operations with little overhead overall. TokenTicker.Start() *should* be called before the first call to
  22. // TokenTicker.Allow() and TokenTicker.Stop() *must* be called once done using. Note that calling TokenTicker.Allow()
  23. // before TokenTicker.Start() is valid, but it means the bucket won't be refilling until the call to TokenTicker.Start() is made
  24. type TokenTicker struct {
  25. tokens int64
  26. maxTokens int64
  27. ticker *time.Ticker
  28. stopChan chan struct{}
  29. }
  30. // NewTokenTicker is a utility function that allocates a token ticker, initializes necessary fields and returns it
  31. func NewTokenTicker(tokens, maxTokens int64) *TokenTicker {
  32. return &TokenTicker{
  33. tokens: tokens,
  34. maxTokens: maxTokens,
  35. }
  36. }
  37. // updateBucket performs a select loop to update the token amount in the bucket.
  38. // Used in a goroutine by the rate limiter.
  39. func (t *TokenTicker) updateBucket(ticksChan <-chan time.Time, startTime time.Time, syncChan chan struct{}) {
  40. nsPerToken := time.Second.Nanoseconds() / t.maxTokens
  41. elapsedNs := int64(0)
  42. prevStamp := startTime
  43. for {
  44. select {
  45. case <-t.stopChan:
  46. if syncChan != nil {
  47. close(syncChan)
  48. }
  49. return
  50. case stamp := <-ticksChan:
  51. // Compute the time in nanoseconds that passed between the previous timestamp and this one
  52. // This will be used to know how many tokens can be added into the bucket depending on the limiter rate
  53. elapsedNs += stamp.Sub(prevStamp).Nanoseconds()
  54. if elapsedNs > t.maxTokens*nsPerToken {
  55. elapsedNs = t.maxTokens * nsPerToken
  56. }
  57. prevStamp = stamp
  58. // Update the number of tokens in the bucket if enough nanoseconds have passed
  59. if elapsedNs >= nsPerToken {
  60. // Atomic spin lock to make sure we don't race for `t.tokens`
  61. for {
  62. tokens := atomic.LoadInt64(&t.tokens)
  63. if tokens == t.maxTokens {
  64. break // Bucket is already full, nothing to do
  65. }
  66. inc := elapsedNs / nsPerToken
  67. // Make sure not to add more tokens than we are allowed to into the bucket
  68. if tokens+inc > t.maxTokens {
  69. inc -= (tokens + inc) % t.maxTokens
  70. }
  71. if atomic.CompareAndSwapInt64(&t.tokens, tokens, tokens+inc) {
  72. // Keep track of remaining elapsed ns that were not taken into account for this computation,
  73. // so that increment computation remains precise over time
  74. elapsedNs = elapsedNs % nsPerToken
  75. break
  76. }
  77. }
  78. }
  79. // Sync channel used to signify that the goroutine is done updating the bucket. Used for tests to guarantee
  80. // that the goroutine ticked at least once.
  81. if syncChan != nil {
  82. syncChan <- struct{}{}
  83. }
  84. }
  85. }
  86. }
  87. // Start starts the ticker and launches the goroutine responsible for updating the token bucket.
  88. // The ticker is set to tick at a fixed rate of 500us.
  89. func (t *TokenTicker) Start() {
  90. timeNow := time.Now()
  91. t.ticker = time.NewTicker(500 * time.Microsecond)
  92. t.start(t.ticker.C, timeNow, false)
  93. }
  94. // start is used for internal testing. Controlling the ticker means being able to test per-tick
  95. // rather than per-duration, which is more reliable if the app is under a lot of stress.
  96. // sync is used to decide whether the limiter should create a channel for synchronization with the testing app after a
  97. // bucket update. The limiter is in charge of closing the channel in this case.
  98. func (t *TokenTicker) start(ticksChan <-chan time.Time, startTime time.Time, sync bool) <-chan struct{} {
  99. t.stopChan = make(chan struct{})
  100. var syncChan chan struct{}
  101. if sync {
  102. syncChan = make(chan struct{})
  103. }
  104. go t.updateBucket(ticksChan, startTime, syncChan)
  105. return syncChan
  106. }
  107. // Stop shuts down the rate limiter, taking care stopping the ticker and closing all channels
  108. func (t *TokenTicker) Stop() {
  109. // Stop the ticker only if it has been instantiated (not the case when testing by calling start() directly)
  110. if t.ticker != nil {
  111. t.ticker.Stop()
  112. }
  113. // Close the stop channel only if it has been created. This covers the case where Stop() is called without any prior
  114. // call to Start()
  115. if t.stopChan != nil {
  116. close(t.stopChan)
  117. }
  118. }
  119. // Allow checks and returns whether a token can be retrieved from the bucket and consumed.
  120. // Thread-safe.
  121. func (t *TokenTicker) Allow() bool {
  122. for {
  123. tokens := atomic.LoadInt64(&t.tokens)
  124. if tokens == 0 {
  125. return false
  126. } else if atomic.CompareAndSwapInt64(&t.tokens, tokens, tokens-1) {
  127. return true
  128. }
  129. }
  130. }