tx.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. package stm
  2. import (
  3. "fmt"
  4. "sort"
  5. "sync"
  6. "unsafe"
  7. "github.com/alecthomas/atomic"
  8. )
  9. type txVar interface {
  10. getValue() *atomic.Value[VarValue]
  11. changeValue(any)
  12. getWatchers() *sync.Map
  13. getLock() *sync.Mutex
  14. }
// A Tx represents an atomic transaction.
type Tx struct {
	// reads maps each Var read in the current attempt to the value observed at
	// its first read (see Get); inputsChanged compares these snapshots with
	// the Vars' current values.
	reads map[txVar]VarValue
	// writes buffers values passed to Set; commit applies them to their Vars.
	writes map[txVar]any
	// watching is the set of Vars whose watcher registries currently contain
	// tx; updateWatchers syncs it with reads and recycle empties it.
	watching map[txVar]struct{}
	// locks holds the Var mutexes gathered, sorted and acquired by lockAllVars.
	locks txLocks
	// mu is held by reset while clearing reads/writes and by wait around its
	// inputsChanged checks.
	mu sync.Mutex
	// cond is what wait blocks on. NOTE(review): presumably cond.L is &mu and
	// it is signalled when a watched Var changes; it is initialised elsewhere.
	cond sync.Cond
	// waiting is true while wait is blocked in cond.Wait.
	waiting bool
	// completed is not referenced in this section; per the note in recycle it
	// must stay set indefinitely once the Tx has been used.
	completed bool
	// tries is not referenced in this section; presumably an attempt counter.
	tries int
	// numRetryValues counts the entries Retry has added to the retries
	// profile for this tx; removeRetryProfiles removes them and zeroes it.
	numRetryValues int
}
  28. // Check that none of the logged values have changed since the transaction began.
  29. func (tx *Tx) inputsChanged() bool {
  30. for v, read := range tx.reads {
  31. if read.Changed(v.getValue().Load()) {
  32. return true
  33. }
  34. }
  35. return false
  36. }
  37. // Writes the values in the transaction log to their respective Vars.
  38. func (tx *Tx) commit() {
  39. for v, val := range tx.writes {
  40. v.changeValue(val)
  41. }
  42. }
  43. func (tx *Tx) updateWatchers() {
  44. for v := range tx.watching {
  45. if _, ok := tx.reads[v]; !ok {
  46. delete(tx.watching, v)
  47. v.getWatchers().Delete(tx)
  48. }
  49. }
  50. for v := range tx.reads {
  51. if _, ok := tx.watching[v]; !ok {
  52. v.getWatchers().Store(tx, nil)
  53. tx.watching[v] = struct{}{}
  54. }
  55. }
  56. }
  57. // wait blocks until another transaction modifies any of the Vars read by tx.
  58. func (tx *Tx) wait() {
  59. if len(tx.reads) == 0 {
  60. panic("not waiting on anything")
  61. }
  62. tx.updateWatchers()
  63. tx.mu.Lock()
  64. firstWait := true
  65. for !tx.inputsChanged() {
  66. if !firstWait {
  67. expvars.Add("wakes for unchanged versions", 1)
  68. }
  69. expvars.Add("waits", 1)
  70. tx.waiting = true
  71. tx.cond.Broadcast()
  72. tx.cond.Wait()
  73. tx.waiting = false
  74. firstWait = false
  75. }
  76. tx.mu.Unlock()
  77. }
  78. // Get returns the value of v as of the start of the transaction.
  79. func (v *Var[T]) Get(tx *Tx) T {
  80. // If we previously wrote to v, it will be in the write log.
  81. if val, ok := tx.writes[v]; ok {
  82. return val.(T)
  83. }
  84. // If we haven't previously read v, record its version
  85. vv, ok := tx.reads[v]
  86. if !ok {
  87. vv = v.getValue().Load()
  88. tx.reads[v] = vv
  89. }
  90. return vv.Get().(T)
  91. }
  92. // Set sets the value of a Var for the lifetime of the transaction.
  93. func (v *Var[T]) Set(tx *Tx, val T) {
  94. if v == nil {
  95. panic("nil Var")
  96. }
  97. tx.writes[v] = val
  98. }
// txProfileValue is the key Retry adds to (and removeRetryProfiles removes
// from) the retries profile: the transaction paired with the ordinal of that
// retry, so each Retry call on a Tx gets a distinct key. Fields are positional
// in those callers, so their order matters.
type txProfileValue struct {
	*Tx
	int
}
  103. // Retry aborts the transaction and retries it when a Var changes. You can return from this method
  104. // to satisfy return values, but it should never actually return anything as it panics internally.
  105. func (tx *Tx) Retry() struct{} {
  106. retries.Add(txProfileValue{tx, tx.numRetryValues}, 1)
  107. tx.numRetryValues++
  108. panic(retry)
  109. }
  110. // Assert is a helper function that retries a transaction if the condition is
  111. // not satisfied.
  112. func (tx *Tx) Assert(p bool) {
  113. if !p {
  114. tx.Retry()
  115. }
  116. }
  117. func (tx *Tx) reset() {
  118. tx.mu.Lock()
  119. for k := range tx.reads {
  120. delete(tx.reads, k)
  121. }
  122. for k := range tx.writes {
  123. delete(tx.writes, k)
  124. }
  125. tx.mu.Unlock()
  126. tx.removeRetryProfiles()
  127. tx.resetLocks()
  128. }
  129. func (tx *Tx) removeRetryProfiles() {
  130. for tx.numRetryValues > 0 {
  131. tx.numRetryValues--
  132. retries.Remove(txProfileValue{tx, tx.numRetryValues})
  133. }
  134. }
  135. func (tx *Tx) recycle() {
  136. for v := range tx.watching {
  137. delete(tx.watching, v)
  138. v.getWatchers().Delete(tx)
  139. }
  140. tx.removeRetryProfiles()
  141. // I don't think we can reuse Txs, because the "completed" field should/needs to be set
  142. // indefinitely after use.
  143. //txPool.Put(tx)
  144. }
  145. func (tx *Tx) lockAllVars() {
  146. tx.resetLocks()
  147. tx.collectAllLocks()
  148. tx.sortLocks()
  149. tx.lock()
  150. }
  151. func (tx *Tx) resetLocks() {
  152. tx.locks.clear()
  153. }
  154. func (tx *Tx) collectReadLocks() {
  155. for v := range tx.reads {
  156. tx.locks.append(v.getLock())
  157. }
  158. }
  159. func (tx *Tx) collectAllLocks() {
  160. tx.collectReadLocks()
  161. for v := range tx.writes {
  162. if _, ok := tx.reads[v]; !ok {
  163. tx.locks.append(v.getLock())
  164. }
  165. }
  166. }
  167. func (tx *Tx) sortLocks() {
  168. sort.Sort(&tx.locks)
  169. }
  170. func (tx *Tx) lock() {
  171. for _, l := range tx.locks.mus {
  172. l.Lock()
  173. }
  174. }
  175. func (tx *Tx) unlock() {
  176. for _, l := range tx.locks.mus {
  177. l.Unlock()
  178. }
  179. }
  180. func (tx *Tx) String() string {
  181. return fmt.Sprintf("%[1]T %[1]p", tx)
  182. }
  183. // Dedicated type avoids reflection in sort.Slice.
  184. type txLocks struct {
  185. mus []*sync.Mutex
  186. }
  187. func (me txLocks) Len() int {
  188. return len(me.mus)
  189. }
  190. func (me txLocks) Less(i, j int) bool {
  191. return uintptr(unsafe.Pointer(me.mus[i])) < uintptr(unsafe.Pointer(me.mus[j]))
  192. }
  193. func (me txLocks) Swap(i, j int) {
  194. me.mus[i], me.mus[j] = me.mus[j], me.mus[i]
  195. }
  196. func (me *txLocks) clear() {
  197. me.mus = me.mus[:0]
  198. }
  199. func (me *txLocks) append(mu *sync.Mutex) {
  200. me.mus = append(me.mus, mu)
  201. }