savepoint.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. // Copyright (c) 2018 David Crawshaw <david@zentus.com>
  2. //
  3. // Permission to use, copy, modify, and distribute this software for any
  4. // purpose with or without fee is hereby granted, provided that the above
  5. // copyright notice and this permission notice appear in all copies.
  6. //
  7. // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  8. // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  9. // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  10. // ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  11. // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  12. // ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  13. // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  14. package sqlitex
  15. import (
  16. "fmt"
  17. "runtime"
  18. "strings"
  19. "github.com/go-llsqlite/crawshaw"
  20. )
  21. // Save creates a named SQLite transaction using SAVEPOINT.
  22. //
  23. // On success Savepoint returns a releaseFn that will call either
  24. // RELEASE or ROLLBACK depending on whether the parameter *error
  25. // points to a nil or non-nil error. This is designed to be deferred.
  26. //
  27. // Example:
  28. //
  29. // func doWork(conn *sqlite.Conn) (err error) {
  30. // defer sqlitex.Save(conn)(&err)
  31. //
  32. // // ... do work in the transaction
  33. // }
  34. //
  35. // https://www.sqlite.org/lang_savepoint.html
  36. func Save(conn *sqlite.Conn) (releaseFn func(*error)) {
  37. name := "sqlitex.Save" // safe as names can be reused
  38. var pc [3]uintptr
  39. if n := runtime.Callers(0, pc[:]); n > 0 {
  40. frames := runtime.CallersFrames(pc[:n])
  41. if _, more := frames.Next(); more { // runtime.Callers
  42. if _, more := frames.Next(); more { // savepoint.Save
  43. frame, _ := frames.Next() // caller we care about
  44. if frame.Function != "" {
  45. name = frame.Function
  46. }
  47. }
  48. }
  49. }
  50. releaseFn, err := savepoint(conn, name)
  51. if err != nil {
  52. if sqlite.ErrCode(err) == sqlite.SQLITE_INTERRUPT {
  53. return func(errp *error) {
  54. if *errp == nil {
  55. *errp = err
  56. }
  57. }
  58. }
  59. panic(err)
  60. }
  61. return releaseFn
  62. }
  63. func savepoint(conn *sqlite.Conn, name string) (releaseFn func(*error), err error) {
  64. if strings.Contains(name, `"`) {
  65. return nil, fmt.Errorf("sqlitex.Savepoint: invalid name: %q", name)
  66. }
  67. if err := Exec(conn, fmt.Sprintf("SAVEPOINT %q;", name), nil); err != nil {
  68. return nil, err
  69. }
  70. tracer := conn.Tracer()
  71. if tracer != nil {
  72. tracer.Push("TX " + name)
  73. }
  74. releaseFn = func(errp *error) {
  75. if tracer != nil {
  76. tracer.Pop()
  77. }
  78. recoverP := recover()
  79. // If a query was interrupted or if a user exec'd COMMIT or
  80. // ROLLBACK, then everything was already rolled back
  81. // automatically, thus returning the connection to autocommit
  82. // mode.
  83. if conn.GetAutocommit() {
  84. // There is nothing to rollback.
  85. if recoverP != nil {
  86. panic(recoverP)
  87. }
  88. return
  89. }
  90. if *errp == nil && recoverP == nil {
  91. // Success path. Release the savepoint successfully.
  92. *errp = Exec(conn, fmt.Sprintf("RELEASE %q;", name), nil)
  93. if *errp == nil {
  94. return
  95. }
  96. // Possible interrupt. Fall through to the error path.
  97. if conn.GetAutocommit() {
  98. // There is nothing to rollback.
  99. if recoverP != nil {
  100. panic(recoverP)
  101. }
  102. return
  103. }
  104. }
  105. orig := ""
  106. if *errp != nil {
  107. orig = (*errp).Error() + "\n\t"
  108. }
  109. // Error path.
  110. // Always run ROLLBACK even if the connection has been interrupted.
  111. oldDoneCh := conn.SetInterrupt(nil)
  112. defer conn.SetInterrupt(oldDoneCh)
  113. err := Exec(conn, fmt.Sprintf("ROLLBACK TO %q;", name), nil)
  114. if err != nil {
  115. panic(orig + err.Error())
  116. }
  117. err = Exec(conn, fmt.Sprintf("RELEASE %q;", name), nil)
  118. if err != nil {
  119. panic(orig + err.Error())
  120. }
  121. if recoverP != nil {
  122. panic(recoverP)
  123. }
  124. }
  125. return releaseFn, nil
  126. }