session.go 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874
  1. // Copyright (c) 2018 David Crawshaw <david@zentus.com>
  2. // Copyright (c) 2021 Ross Light <ross@zombiezen.com>
  3. //
  4. // Permission to use, copy, modify, and distribute this software for any
  5. // purpose with or without fee is hereby granted, provided that the above
  6. // copyright notice and this permission notice appear in all copies.
  7. //
  8. // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  9. // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  10. // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  11. // ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  12. // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  13. // ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  14. // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  15. //
  16. // SPDX-License-Identifier: ISC
  17. package sqlite
  18. import (
  19. "fmt"
  20. "io"
  21. "io/ioutil"
  22. "runtime"
  23. "sync"
  24. "unsafe"
  25. "modernc.org/libc"
  26. "modernc.org/libc/sys/types"
  27. lib "modernc.org/sqlite/lib"
  28. )
  29. // A Session tracks database changes made by a Conn.
  30. // It is used to build changesets.
  31. //
  32. // For more details: https://www.sqlite.org/sessionintro.html
  33. type Session struct {
  34. tls *libc.TLS
  35. ptr uintptr
  36. }
  37. // CreateSession creates a new session object.
  38. // If db is "", then a default of "main" is used.
  39. // It is the caller's responsibility to call Delete when the session is
  40. // no longer needed.
  41. //
  42. // https://www.sqlite.org/session/sqlite3session_create.html
  43. func (c *Conn) CreateSession(db string) (*Session, error) {
  44. if c == nil {
  45. return nil, fmt.Errorf("sqlite: create session: nil connection")
  46. }
  47. var cdb uintptr
  48. if db == "" || db == "main" {
  49. cdb = mainCString
  50. } else {
  51. var err error
  52. cdb, err = libc.CString(db)
  53. if err != nil {
  54. return nil, fmt.Errorf("sqlite: create session: %w", err)
  55. }
  56. defer libc.Xfree(c.tls, cdb)
  57. }
  58. doublePtr, err := malloc(c.tls, ptrSize)
  59. if err != nil {
  60. return nil, fmt.Errorf("sqlite: create session: %w", err)
  61. }
  62. defer libc.Xfree(c.tls, doublePtr)
  63. res := ResultCode(lib.Xsqlite3session_create(c.tls, c.conn, cdb, doublePtr))
  64. if err := res.ToError(); err != nil {
  65. return nil, fmt.Errorf("sqlite: create session: %w", err)
  66. }
  67. s := &Session{
  68. tls: c.tls,
  69. ptr: *(*uintptr)(unsafe.Pointer(doublePtr)),
  70. }
  71. runtime.SetFinalizer(s, func(s *Session) {
  72. if s.ptr != 0 {
  73. panic("open *sqlite.Session garbage collected, call Delete method")
  74. }
  75. })
  76. return s, nil
  77. }
  78. // Delete releases any resources associated with the session.
  79. // It must be called before closing the Conn the session is attached to.
  80. func (s *Session) Delete() {
  81. if s.ptr == 0 {
  82. panic("Session.Delete called twice on same session")
  83. }
  84. lib.Xsqlite3session_delete(s.tls, s.ptr)
  85. s.ptr = 0
  86. s.tls = nil
  87. }
  88. // Enable enables recording of changes after a previous call to Disable.
  89. // New sessions start enabled.
  90. //
  91. // https://www.sqlite.org/session/sqlite3session_enable.html
  92. func (s *Session) Enable() {
  93. if s.ptr == 0 {
  94. panic("Session.Enable called on deleted session")
  95. }
  96. lib.Xsqlite3session_enable(s.tls, s.ptr, 1)
  97. }
  98. // Disable disables recording of changes.
  99. //
  100. // https://www.sqlite.org/session/sqlite3session_enable.html
  101. func (s *Session) Disable() {
  102. if s.ptr == 0 {
  103. panic("Session.Disable called on deleted session")
  104. }
  105. lib.Xsqlite3session_enable(s.tls, s.ptr, 0)
  106. }
  107. // Attach attaches a table to the session object.
  108. // Changes made to the table will be tracked by the session.
  109. // An empty tableName attaches all the tables in the database.
  110. func (s *Session) Attach(tableName string) error {
  111. if s.ptr == 0 {
  112. return fmt.Errorf("sqlite: attach table %q to session: session deleted", tableName)
  113. }
  114. var ctable uintptr
  115. if tableName != "" {
  116. var err error
  117. ctable, err = libc.CString(tableName)
  118. if err != nil {
  119. return fmt.Errorf("sqlite: attach table %q to session: %v", tableName, err)
  120. }
  121. defer libc.Xfree(s.tls, ctable)
  122. }
  123. res := ResultCode(lib.Xsqlite3session_attach(s.tls, s.ptr, ctable))
  124. if err := res.ToError(); err != nil {
  125. return fmt.Errorf("sqlite: attach table %q to session: %w", tableName, err)
  126. }
  127. return nil
  128. }
  129. // Diff appends the difference between two tables (srcDB and the session DB) to
  130. // the session. The two tables must have the same name and schema.
  131. //
  132. // https://www.sqlite.org/session/sqlite3session_diff.html
  133. func (s *Session) Diff(srcDB, tableName string) error {
  134. if s.ptr == 0 {
  135. return fmt.Errorf("sqlite: diff table %q: session deleted", tableName)
  136. }
  137. errMsgPtr, err := malloc(s.tls, ptrSize)
  138. if err != nil {
  139. return fmt.Errorf("sqlite: diff table %q: %v", tableName, err)
  140. }
  141. defer libc.Xfree(s.tls, errMsgPtr)
  142. csrcDB, err := libc.CString(srcDB)
  143. if err != nil {
  144. return fmt.Errorf("sqlite: diff table %q: %v", tableName, err)
  145. }
  146. defer libc.Xfree(s.tls, csrcDB)
  147. ctable, err := libc.CString(tableName)
  148. if err != nil {
  149. return fmt.Errorf("sqlite: diff table %q: %v", tableName, err)
  150. }
  151. defer libc.Xfree(s.tls, ctable)
  152. res := ResultCode(lib.Xsqlite3session_diff(s.tls, s.ptr, csrcDB, ctable, errMsgPtr))
  153. if err := res.ToError(); err != nil {
  154. cerrMsg := *(*uintptr)(unsafe.Pointer(errMsgPtr))
  155. if cerrMsg == 0 {
  156. return fmt.Errorf("sqlite: diff table %q: %w", tableName, err)
  157. }
  158. errMsg := libc.GoString(cerrMsg)
  159. lib.Xsqlite3_free(s.tls, cerrMsg)
  160. return fmt.Errorf("sqlite: diff table %q: %w (%s)", tableName, err, errMsg)
  161. }
  162. return nil
  163. }
  164. // WriteChangeset generates a changeset from a session.
  165. //
  166. // https://www.sqlite.org/session/sqlite3session_changeset.html
  167. func (s *Session) WriteChangeset(w io.Writer) error {
  168. if s.ptr == 0 {
  169. return fmt.Errorf("sqlite: write session changeset: session deleted")
  170. }
  171. xOutput, pOut := registerStreamWriter(w)
  172. defer unregisterStreamWriter(pOut)
  173. res := ResultCode(lib.Xsqlite3session_changeset_strm(s.tls, s.ptr, xOutput, pOut))
  174. if err := res.ToError(); err != nil {
  175. return fmt.Errorf("sqlite: write session changeset: %w", err)
  176. }
  177. return nil
  178. }
  179. // WritePatchset generates a patchset from a session.
  180. //
  181. // https://www.sqlite.org/session/sqlite3session_patchset.html
  182. func (s *Session) WritePatchset(w io.Writer) error {
  183. if s.ptr == 0 {
  184. return fmt.Errorf("sqlite: write session patchset: session deleted")
  185. }
  186. xOutput, pOut := registerStreamWriter(w)
  187. defer unregisterStreamWriter(pOut)
  188. res := ResultCode(lib.Xsqlite3session_patchset_strm(s.tls, s.ptr, xOutput, pOut))
  189. if err := res.ToError(); err != nil {
  190. return fmt.Errorf("sqlite: write session patchset: %w", err)
  191. }
  192. return nil
  193. }
  194. // ApplyChangeset applies a changeset to the database.
  195. //
  196. // If filterFn is not nil and the changeset includes changes for a table for
  197. // which the function reports false, then the changes are ignored. If filterFn
  198. // is nil, then all changes are applied.
  199. //
  200. // If a changeset will not apply cleanly, then conflictFn will be called to
  201. // resolve the conflict. See https://www.sqlite.org/session/sqlite3changeset_apply.html
  202. // for more details.
  203. func (c *Conn) ApplyChangeset(r io.Reader, filterFn func(tableName string) bool, conflictFn ConflictHandler) error {
  204. if c == nil {
  205. return fmt.Errorf("sqlite: apply changeset: nil connection")
  206. }
  207. if conflictFn == nil {
  208. return fmt.Errorf("sqlite: apply changeset: no conflict handler provided")
  209. }
  210. xInput, pIn := registerStreamReader(r)
  211. defer unregisterStreamReader(pIn)
  212. appliesIDMu.Lock()
  213. pCtx := appliesIDs.next()
  214. appliesIDMu.Unlock()
  215. applies.Store(pCtx, applyFuncs{
  216. tls: c.tls,
  217. filterFn: filterFn,
  218. conflictFn: conflictFn,
  219. })
  220. defer func() {
  221. applies.Delete(pCtx)
  222. appliesIDMu.Lock()
  223. appliesIDs.reclaim(pCtx)
  224. appliesIDMu.Unlock()
  225. }()
  226. xFilter := uintptr(0)
  227. if filterFn != nil {
  228. xFilter = cFuncPointer(changesetApplyFilter)
  229. }
  230. xConflict := cFuncPointer(changesetApplyConflict)
  231. res := ResultCode(lib.Xsqlite3changeset_apply_strm(c.tls, c.conn, xInput, pIn, xFilter, xConflict, pCtx))
  232. if err := res.ToError(); err != nil {
  233. return fmt.Errorf("sqlite: apply changeset: %w", err)
  234. }
  235. return nil
  236. }
  237. type applyFuncs struct {
  238. tls *libc.TLS
  239. filterFn func(tableName string) bool
  240. conflictFn ConflictHandler
  241. }
  242. var (
  243. applies sync.Map // map[uintptr]applyFuncs
  244. appliesIDMu sync.Mutex
  245. appliesIDs idGen
  246. )
  247. func changesetApplyFilter(tls *libc.TLS, pCtx uintptr, zTab uintptr) int32 {
  248. appliesValue, _ := applies.Load(pCtx)
  249. funcs := appliesValue.(applyFuncs)
  250. tableName := libc.GoString(zTab)
  251. if funcs.filterFn(tableName) {
  252. return 1
  253. } else {
  254. return 0
  255. }
  256. }
  257. func changesetApplyConflict(tls *libc.TLS, pCtx uintptr, eConflict int32, p uintptr) int32 {
  258. appliesValue, _ := applies.Load(pCtx)
  259. funcs := appliesValue.(applyFuncs)
  260. return int32(funcs.conflictFn(ConflictType(eConflict), &ChangesetIterator{
  261. tls: tls,
  262. ptr: p,
  263. }))
  264. }
  265. // ApplyInverseChangeset applies the inverse of a changeset to the database.
  266. // See ApplyChangeset and InvertChangeset for more details.
  267. func (c *Conn) ApplyInverseChangeset(r io.Reader, filterFn func(tableName string) bool, conflictFn ConflictHandler) error {
  268. if c == nil {
  269. return fmt.Errorf("sqlite: apply changeset: nil connection")
  270. }
  271. pr, pw := io.Pipe()
  272. go func() {
  273. err := InvertChangeset(pw, pr)
  274. pw.CloseWithError(err)
  275. }()
  276. err := c.ApplyChangeset(pr, filterFn, conflictFn)
  277. io.Copy(ioutil.Discard, pr) // wait for invert goroutine to finish
  278. return err
  279. }
  280. // InvertChangeset generates an inverted changeset. Applying an inverted
  281. // changeset to a database reverses the effects of applying the uninverted
  282. // changeset.
  283. //
  284. // This function currently assumes that the input is a valid changeset.
  285. // If it is not, the results are undefined.
  286. //
  287. // https://www.sqlite.org/session/sqlite3changeset_invert.html
  288. func InvertChangeset(w io.Writer, r io.Reader) error {
  289. tls := libc.NewTLS()
  290. defer tls.Close()
  291. xInput, pIn := registerStreamReader(r)
  292. defer unregisterStreamReader(pIn)
  293. xOutput, pOut := registerStreamWriter(w)
  294. defer unregisterStreamWriter(pOut)
  295. res := ResultCode(lib.Xsqlite3changeset_invert_strm(tls, xInput, pIn, xOutput, pOut))
  296. if err := res.ToError(); err != nil {
  297. return fmt.Errorf("sqlite: invert changeset: %w", err)
  298. }
  299. return nil
  300. }
  301. // ChangesetIterator is an iterator over a changeset.
  302. type ChangesetIterator struct {
  303. tls *libc.TLS
  304. ptr uintptr
  305. ownTLS bool
  306. pIn uintptr // if non-zero, then must be unregistered at Finalize
  307. }
  308. // NewChangesetIterator returns a new iterator over the contents of the
  309. // changeset. The caller is responsible for calling Finalize on the returned
  310. // iterator.
  311. //
  312. // https://www.sqlite.org/session/sqlite3changeset_start.html
  313. func NewChangesetIterator(r io.Reader) (*ChangesetIterator, error) {
  314. tls := libc.NewTLS()
  315. xInput, pIn := registerStreamReader(r)
  316. pp, err := malloc(tls, ptrSize)
  317. if err != nil {
  318. unregisterStreamReader(pIn)
  319. tls.Close()
  320. return nil, fmt.Errorf("sqlite: start changeset iterator: %v", err)
  321. }
  322. defer libc.Xfree(tls, pp)
  323. res := ResultCode(lib.Xsqlite3changeset_start_strm(tls, pp, xInput, pIn))
  324. if err := res.ToError(); err != nil {
  325. unregisterStreamReader(pIn)
  326. tls.Close()
  327. return nil, fmt.Errorf("sqlite: start changeset iterator: %w", err)
  328. }
  329. iter := &ChangesetIterator{
  330. tls: tls,
  331. ownTLS: true,
  332. ptr: *(*uintptr)(unsafe.Pointer(pp)),
  333. pIn: pIn,
  334. }
  335. runtime.SetFinalizer(iter, func(iter *ChangesetIterator) {
  336. if iter.ptr != 0 {
  337. panic("open *sqlite.ChangesetIterator garbage collected, call Finalize method")
  338. }
  339. })
  340. return iter, nil
  341. }
  342. // Close releases any resources associated with the iterator created with
  343. // NewChangesetIterator.
  344. func (iter *ChangesetIterator) Close() error {
  345. if iter.ptr == 0 {
  346. return fmt.Errorf("sqlite: finalize changeset iterator: called twice on same iterator")
  347. }
  348. res := ResultCode(lib.Xsqlite3changeset_finalize(iter.tls, iter.ptr))
  349. iter.ptr = 0
  350. if iter.ownTLS {
  351. iter.tls.Close()
  352. }
  353. iter.tls = nil
  354. if iter.pIn != 0 {
  355. unregisterStreamReader(iter.pIn)
  356. }
  357. iter.pIn = 0
  358. if err := res.ToError(); err != nil {
  359. return fmt.Errorf("sqlite: finalize changeset iterator: %w", err)
  360. }
  361. return nil
  362. }
  363. // Next advances the iterator to the next change in the changeset.
  364. // It is an error to call Next on an iterator passed to an ApplyChangeset
  365. // conflict handler.
  366. //
  367. // https://www.sqlite.org/session/sqlite3changeset_next.html
  368. func (iter *ChangesetIterator) Next() (rowReturned bool, err error) {
  369. res := ResultCode(lib.Xsqlite3changeset_next(iter.tls, iter.ptr))
  370. switch res {
  371. case ResultRow:
  372. return true, nil
  373. case ResultDone:
  374. return false, nil
  375. default:
  376. return false, fmt.Errorf("sqlite: iterate changeset: %w", res.ToError())
  377. }
  378. }
  379. // ChangesetOperation holds information about a change in a changeset.
  380. type ChangesetOperation struct {
  381. // Type is one of OpInsert, OpDelete, or OpUpdate.
  382. Type OpType
  383. // TableName is the name of the table affected by the change.
  384. TableName string
  385. // NumColumns is the number of columns in the table affected by the change.
  386. NumColumns int
  387. // Indirect is true if the session object "indirect" flag was set when the
  388. // change was made or the change was made by an SQL trigger or foreign key
  389. // action instead of directly as a result of a users SQL statement.
  390. Indirect bool
  391. }
  392. // Operation obtains the current operation from the iterator.
  393. //
  394. // https://www.sqlite.org/session/sqlite3changeset_op.html
  395. func (iter *ChangesetIterator) Operation() (*ChangesetOperation, error) {
  396. if iter.ptr == 0 {
  397. return nil, fmt.Errorf("sqlite: changeset iterator operation: iterator finalized")
  398. }
  399. pzTab, err := malloc(iter.tls, ptrSize)
  400. if err != nil {
  401. return nil, fmt.Errorf("sqlite: changeset iterator operation: %v", err)
  402. }
  403. defer libc.Xfree(iter.tls, pzTab)
  404. pnCol, err := malloc(iter.tls, types.Size_t(unsafe.Sizeof(int32(0))))
  405. if err != nil {
  406. return nil, fmt.Errorf("sqlite: changeset iterator operation: %v", err)
  407. }
  408. defer libc.Xfree(iter.tls, pnCol)
  409. pOp, err := malloc(iter.tls, types.Size_t(unsafe.Sizeof(int32(0))))
  410. if err != nil {
  411. return nil, fmt.Errorf("sqlite: changeset iterator operation: %v", err)
  412. }
  413. defer libc.Xfree(iter.tls, pOp)
  414. pbIndirect, err := malloc(iter.tls, types.Size_t(unsafe.Sizeof(int32(0))))
  415. if err != nil {
  416. return nil, fmt.Errorf("sqlite: changeset iterator operation: %v", err)
  417. }
  418. defer libc.Xfree(iter.tls, pbIndirect)
  419. res := ResultCode(lib.Xsqlite3changeset_op(iter.tls, iter.ptr, pzTab, pnCol, pOp, pbIndirect))
  420. if err := res.ToError(); err != nil {
  421. return nil, fmt.Errorf("sqlite: changeset iterator operation: %w", err)
  422. }
  423. return &ChangesetOperation{
  424. Type: OpType(*(*int32)(unsafe.Pointer(pOp))),
  425. TableName: libc.GoString(*(*uintptr)(unsafe.Pointer(pzTab))),
  426. NumColumns: int(*(*int32)(unsafe.Pointer(pnCol))),
  427. Indirect: *(*int32)(unsafe.Pointer(pbIndirect)) != 0,
  428. }, nil
  429. }
  430. // Old obtains the old row value from an iterator. Column indices start at 0.
  431. // The returned value is valid until the iterator is finalized.
  432. //
  433. // https://www.sqlite.org/session/sqlite3changeset_old.html
  434. func (iter *ChangesetIterator) Old(col int) (Value, error) {
  435. if iter.ptr == 0 {
  436. return Value{}, fmt.Errorf("sqlite: get changeset iterator value: iterator finalized")
  437. }
  438. ppValue, err := malloc(iter.tls, ptrSize)
  439. if err != nil {
  440. return Value{}, fmt.Errorf("sqlite: get changeset iterator value: %v", err)
  441. }
  442. defer libc.Xfree(iter.tls, ppValue)
  443. res := ResultCode(lib.Xsqlite3changeset_old(iter.tls, iter.ptr, int32(col), ppValue))
  444. if err := res.ToError(); err != nil {
  445. return Value{}, fmt.Errorf("sqlite: get changeset iterator value: %w", err)
  446. }
  447. return Value{
  448. tls: iter.tls,
  449. ptrOrType: *(*uintptr)(unsafe.Pointer(ppValue)),
  450. }, nil
  451. }
  452. // New obtains the new row value from an iterator. Column indices start at 0.
  453. // The returned value is valid until the iterator is finalized.
  454. //
  455. // https://www.sqlite.org/session/sqlite3changeset_new.html
  456. func (iter *ChangesetIterator) New(col int) (Value, error) {
  457. if iter.ptr == 0 {
  458. return Value{}, fmt.Errorf("sqlite: get changeset iterator value: iterator finalized")
  459. }
  460. ppValue, err := malloc(iter.tls, ptrSize)
  461. if err != nil {
  462. return Value{}, fmt.Errorf("sqlite: get changeset iterator value: %v", err)
  463. }
  464. defer libc.Xfree(iter.tls, ppValue)
  465. res := ResultCode(lib.Xsqlite3changeset_new(iter.tls, iter.ptr, int32(col), ppValue))
  466. if err := res.ToError(); err != nil {
  467. return Value{}, fmt.Errorf("sqlite: get changeset iterator value: %w", err)
  468. }
  469. return Value{
  470. tls: iter.tls,
  471. ptrOrType: *(*uintptr)(unsafe.Pointer(ppValue)),
  472. }, nil
  473. }
  474. // ConflictValue obtains the conflicting row value from an iterator.
  475. // Column indices start at 0. The returned value is valid until the iterator is
  476. // finalized.
  477. //
  478. // https://www.sqlite.org/session/sqlite3changeset_conflict.html
  479. func (iter *ChangesetIterator) ConflictValue(col int) (Value, error) {
  480. if iter.ptr == 0 {
  481. return Value{}, fmt.Errorf("sqlite: get changeset iterator value: iterator finalized")
  482. }
  483. ppValue, err := malloc(iter.tls, ptrSize)
  484. if err != nil {
  485. return Value{}, fmt.Errorf("sqlite: get changeset iterator value: %v", err)
  486. }
  487. defer libc.Xfree(iter.tls, ppValue)
  488. res := ResultCode(lib.Xsqlite3changeset_conflict(iter.tls, iter.ptr, int32(col), ppValue))
  489. if err := res.ToError(); err != nil {
  490. return Value{}, fmt.Errorf("sqlite: get changeset iterator value: %w", err)
  491. }
  492. return Value{
  493. tls: iter.tls,
  494. ptrOrType: *(*uintptr)(unsafe.Pointer(ppValue)),
  495. }, nil
  496. }
  497. // ForeignKeyConflicts returns the number of foreign key constraint violations.
  498. //
  499. // https://www.sqlite.org/session/sqlite3changeset_fk_conflicts.html
  500. func (iter *ChangesetIterator) ForeignKeyConflicts() (int, error) {
  501. pnOut, err := malloc(iter.tls, types.Size_t(unsafe.Sizeof(int32(0))))
  502. if err != nil {
  503. return 0, fmt.Errorf("sqlite: get number of foreign key conflicts: %v", err)
  504. }
  505. defer libc.Xfree(iter.tls, pnOut)
  506. res := ResultCode(lib.Xsqlite3changeset_fk_conflicts(iter.tls, iter.ptr, pnOut))
  507. if err := res.ToError(); err != nil {
  508. return 0, fmt.Errorf("sqlite: get number of foreign key conflicts: %w", err)
  509. }
  510. return int(*(*int32)(unsafe.Pointer(pnOut))), nil
  511. }
  512. // PrimaryKey returns a map of columns that make up the primary key.
  513. //
  514. // https://www.sqlite.org/session/sqlite3changeset_pk.html
  515. func (iter *ChangesetIterator) PrimaryKey() ([]bool, error) {
  516. pabPK, err := malloc(iter.tls, ptrSize)
  517. if err != nil {
  518. return nil, fmt.Errorf("sqlite: get primary key columns: %v", err)
  519. }
  520. defer libc.Xfree(iter.tls, pabPK)
  521. pnCol, err := malloc(iter.tls, types.Size_t(unsafe.Sizeof(int32(0))))
  522. if err != nil {
  523. return nil, fmt.Errorf("sqlite: get primary key columns: %v", err)
  524. }
  525. defer libc.Xfree(iter.tls, pnCol)
  526. res := ResultCode(lib.Xsqlite3changeset_pk(iter.tls, iter.ptr, pabPK, pnCol))
  527. if err := res.ToError(); err != nil {
  528. return nil, fmt.Errorf("sqlite: get primary key columns: %w", err)
  529. }
  530. c := libc.GoBytes(*(*uintptr)(unsafe.Pointer(pabPK)), int(*(*int32)(unsafe.Pointer(pnCol))))
  531. cols := make([]bool, len(c))
  532. for i := range cols {
  533. cols[i] = c[i] != 0
  534. }
  535. return cols, nil
  536. }
  537. // ConcatChangesets concatenates two changesets into a single changeset.
  538. //
  539. // https://www.sqlite.org/session/sqlite3changeset_concat.html
  540. func ConcatChangesets(w io.Writer, changeset1, changeset2 io.Reader) error {
  541. tls := libc.NewTLS()
  542. defer tls.Close()
  543. xInput1, pIn1 := registerStreamReader(changeset1)
  544. defer unregisterStreamReader(pIn1)
  545. xInput2, pIn2 := registerStreamReader(changeset2)
  546. defer unregisterStreamReader(pIn2)
  547. xOutput, pOut := registerStreamWriter(w)
  548. defer unregisterStreamWriter(pOut)
  549. res := ResultCode(lib.Xsqlite3changeset_concat_strm(tls, xInput1, pIn1, xInput2, pIn2, xOutput, pOut))
  550. if err := res.ToError(); err != nil {
  551. return fmt.Errorf("sqlite: concatenate changesets: %w", err)
  552. }
  553. return nil
  554. }
  555. // A Changegroup is an object used to combine two or more changesets or
  556. // patchesets. The zero value is an empty changegroup.
  557. //
  558. // https://www.sqlite.org/session/changegroup.html
  559. type Changegroup struct {
  560. tls *libc.TLS
  561. ptr uintptr
  562. }
  563. // NewChangegroup returns a new changegroup. The caller is responsible for
  564. // calling Clear on the returned changegroup.
  565. //
  566. // https://www.sqlite.org/session/sqlite3changegroup_new.html
  567. //
  568. // Deprecated: Use new(sqlite.Changegroup) instead, which does not require
  569. // calling Clear until Add is called.
  570. func NewChangegroup() (*Changegroup, error) {
  571. cg := new(Changegroup)
  572. if err := cg.init(); err != nil {
  573. return nil, fmt.Errorf("sqlite: %w", err)
  574. }
  575. return cg, nil
  576. }
  577. func (cg *Changegroup) init() error {
  578. if cg.tls == nil {
  579. cg.tls = libc.NewTLS()
  580. }
  581. if cg.ptr == 0 {
  582. pp, err := malloc(cg.tls, ptrSize)
  583. if err != nil {
  584. cg.tls.Close()
  585. cg.tls = nil
  586. return fmt.Errorf("init changegroup: %v", err)
  587. }
  588. defer libc.Xfree(cg.tls, pp)
  589. res := ResultCode(lib.Xsqlite3changegroup_new(cg.tls, pp))
  590. if err := res.ToError(); err != nil {
  591. cg.tls.Close()
  592. cg.tls = nil
  593. return fmt.Errorf("init changegroup: %w", err)
  594. }
  595. cg.ptr = *(*uintptr)(unsafe.Pointer(pp))
  596. }
  597. return nil
  598. }
  599. // Clear empties the changegroup and releases any resources associated with
  600. // the changegroup. This method may be called multiple times.
  601. func (cg *Changegroup) Clear() {
  602. if cg == nil {
  603. return
  604. }
  605. if cg.ptr != 0 {
  606. lib.Xsqlite3changegroup_delete(cg.tls, cg.ptr)
  607. cg.ptr = 0
  608. }
  609. if cg.tls != nil {
  610. cg.tls.Close()
  611. cg.tls = nil
  612. }
  613. }
  614. // Add adds all changes within the changeset (or patchset) read from r to
  615. // the changegroup. Once Add has been called, it is the caller's responsibility
  616. // to call Clear.
  617. //
  618. // https://www.sqlite.org/session/sqlite3changegroup_add.html
  619. func (cg *Changegroup) Add(r io.Reader) error {
  620. if err := cg.init(); err != nil {
  621. return fmt.Errorf("sqlite: add to changegroup: %w", err)
  622. }
  623. xInput, pIn := registerStreamReader(r)
  624. defer unregisterStreamReader(pIn)
  625. res := ResultCode(lib.Xsqlite3changegroup_add_strm(cg.tls, cg.ptr, xInput, pIn))
  626. if err := res.ToError(); err != nil {
  627. return fmt.Errorf("sqlite: add to changegroup: %w", err)
  628. }
  629. return nil
  630. }
  631. // WriteTo writes the current contents of the changegroup to w.
  632. //
  633. // https://www.sqlite.org/session/sqlite3changegroup_output.html
  634. func (cg *Changegroup) WriteTo(w io.Writer) (n int64, err error) {
  635. // We want to allow uninitialized changegroups to write output without
  636. // forcing the caller to call Clear. In theses cases, we initialize a new
  637. // changegroup that lasts for the length of the WriteTo call.
  638. if cg == nil {
  639. cg = new(Changegroup)
  640. }
  641. if cg.ptr == 0 {
  642. defer cg.Clear()
  643. }
  644. if err := cg.init(); err != nil {
  645. return 0, fmt.Errorf("sqlite: write changegroup: %w", err)
  646. }
  647. wc := &writeCounter{Writer: w}
  648. xOutput, pOut := registerStreamWriter(wc)
  649. defer unregisterStreamWriter(pOut)
  650. res := ResultCode(lib.Xsqlite3changegroup_output_strm(cg.tls, cg.ptr, xOutput, pOut))
  651. if err := res.ToError(); err != nil {
  652. return wc.n, fmt.Errorf("sqlite: write changegroup: %w", err)
  653. }
  654. return wc.n, nil
  655. }
  656. // A ConflictHandler function determines the action to take to resolve a
  657. // conflict while applying a changeset.
  658. //
  659. // https://www.sqlite.org/session/sqlite3changeset_apply.html
  660. type ConflictHandler func(ConflictType, *ChangesetIterator) ConflictAction
  661. // ConflictType is an enumeration of changeset conflict types.
  662. //
  663. // https://www.sqlite.org/session/c_changeset_conflict.html
  664. type ConflictType int32
  665. // Conflict types.
  666. const (
  667. ChangesetData = ConflictType(lib.SQLITE_CHANGESET_DATA)
  668. ChangesetNotFound = ConflictType(lib.SQLITE_CHANGESET_NOTFOUND)
  669. ChangesetConflict = ConflictType(lib.SQLITE_CHANGESET_CONFLICT)
  670. ChangesetConstraint = ConflictType(lib.SQLITE_CHANGESET_CONSTRAINT)
  671. ChangesetForeignKey = ConflictType(lib.SQLITE_CHANGESET_FOREIGN_KEY)
  672. )
  673. // String returns the C constant name of the conflict type.
  674. func (code ConflictType) String() string {
  675. switch code {
  676. case ChangesetData:
  677. return "SQLITE_CHANGESET_DATA"
  678. case ChangesetNotFound:
  679. return "SQLITE_CHANGESET_NOTFOUND"
  680. case ChangesetConflict:
  681. return "SQLITE_CHANGESET_CONFLICT"
  682. case ChangesetConstraint:
  683. return "SQLITE_CHANGESET_CONSTRAINT"
  684. case ChangesetForeignKey:
  685. return "SQLITE_CHANGESET_FOREIGN_KEY"
  686. default:
  687. return fmt.Sprintf("ConflictType(%d)", int32(code))
  688. }
  689. }
  690. // ConflictAction is an enumeration of actions that can be taken in response to
  691. // a changeset conflict. The zero value is ChangesetOmit.
  692. //
  693. // https://www.sqlite.org/session/c_changeset_abort.html
  694. type ConflictAction int32
  695. // Conflict actions.
  696. const (
  697. // ChangesetOmit signals that no special action should be taken. The change
  698. // that caused the conflict will not be applied. The session module continues
  699. // to the next change in the changeset.
  700. ChangesetOmit = ConflictAction(lib.SQLITE_CHANGESET_OMIT)
  701. // ChangesetAbort signals that any changes applied so far should be rolled
  702. // back and the call to ApplyChangeset returns an error whose code
  703. // is ResultAbort.
  704. ChangesetAbort = ConflictAction(lib.SQLITE_CHANGESET_ABORT)
  705. // ChangesetReplace signals a different action depending on the conflict type.
  706. //
  707. // If the conflict type is ChangesetData, ChangesetReplace signals the
  708. // conflicting row should be updated or deleted.
  709. //
  710. // If the conflict type is ChangesetConflict, then ChangesetReplace signals
  711. // that the conflicting row should be removed from the database and a second
  712. // attempt to apply the change should be made. If this second attempt fails,
  713. // the original row is restored to the database before continuing.
  714. //
  715. // For all other conflict types, returning ChangesetReplace will cause
  716. // ApplyChangeset to roll back any changes applied so far and return an error
  717. // whose code is ResultMisuse.
  718. ChangesetReplace = ConflictAction(lib.SQLITE_CHANGESET_REPLACE)
  719. )
  720. // String returns the C constant name of the conflict action.
  721. func (code ConflictAction) String() string {
  722. switch code {
  723. case ChangesetOmit:
  724. return "SQLITE_CHANGESET_OMIT"
  725. case ChangesetAbort:
  726. return "SQLITE_CHANGESET_ABORT"
  727. case ChangesetReplace:
  728. return "SQLITE_CHANGESET_REPLACE"
  729. default:
  730. return fmt.Sprintf("ConflictAction(%d)", int32(code))
  731. }
  732. }
  733. var (
  734. streamReaders sync.Map // map[uintptr]io.Reader
  735. streamReadersIDMu sync.Mutex
  736. streamReadersIDs idGen
  737. )
  738. func registerStreamReader(r io.Reader) (xInput, pIn uintptr) {
  739. xInput = cFuncPointer(sessionStreamInput)
  740. streamReadersIDMu.Lock()
  741. pIn = streamReadersIDs.next()
  742. streamReadersIDMu.Unlock()
  743. streamReaders.Store(pIn, r)
  744. return
  745. }
  746. func unregisterStreamReader(pIn uintptr) {
  747. streamReaders.Delete(pIn)
  748. streamReadersIDMu.Lock()
  749. streamReadersIDs.reclaim(pIn)
  750. streamReadersIDMu.Unlock()
  751. }
  752. // sessionStreamInput is the callback returned by registerSessionReader used
  753. // for the session streaming APIs.
  754. // https://www.sqlite.org/session/sqlite3changegroup_add_strm.html
  755. func sessionStreamInput(tls *libc.TLS, pIn uintptr, pData uintptr, pnData uintptr) int32 {
  756. rval, _ := streamReaders.Load(pIn)
  757. r, _ := rval.(io.Reader)
  758. if r == nil {
  759. return lib.SQLITE_MISUSE
  760. }
  761. n := int(*(*int32)(unsafe.Pointer(pnData)))
  762. n, err := r.Read(libc.GoBytes(pData, n))
  763. *(*int32)(unsafe.Pointer(pnData)) = int32(n)
  764. if n == 0 && err != io.EOF {
  765. // Readers should not return n == 0 && err == nil. However, as per io.Reader
  766. // docs, we can't treat it as an EOF condition.
  767. return lib.SQLITE_IOERR_READ
  768. }
  769. return lib.SQLITE_OK
  770. }
  771. var (
  772. streamWriters sync.Map // map[uintptr]io.Writer
  773. streamWritersIDMu sync.Mutex
  774. streamWritersIDs idGen
  775. )
  776. func registerStreamWriter(w io.Writer) (xOutput, pOut uintptr) {
  777. xOutput = cFuncPointer(sessionStreamOutput)
  778. streamWritersIDMu.Lock()
  779. pOut = streamWritersIDs.next()
  780. streamWritersIDMu.Unlock()
  781. streamWriters.Store(pOut, w)
  782. return
  783. }
  784. func unregisterStreamWriter(pOut uintptr) {
  785. streamWriters.Delete(pOut)
  786. streamWritersIDMu.Lock()
  787. streamWritersIDs.reclaim(pOut)
  788. streamWritersIDMu.Unlock()
  789. }
  790. // sessionStreamOutput is the callback returned by registerSessionWriter used
  791. // for the session streaming APIs.
  792. // https://www.sqlite.org/session/sqlite3changegroup_add_strm.html
  793. func sessionStreamOutput(tls *libc.TLS, pOut uintptr, pData uintptr, nData int32) int32 {
  794. wval, _ := streamWriters.Load(pOut)
  795. w, _ := wval.(io.Writer)
  796. if w == nil {
  797. return lib.SQLITE_MISUSE
  798. }
  799. _, err := w.Write(libc.GoBytes(pData, int(nData)))
  800. if err != nil {
  801. return lib.SQLITE_IOERR_WRITE
  802. }
  803. return lib.SQLITE_OK
  804. }
  // writeCounter wraps an io.Writer and keeps a running total of the byte
// counts reported by the wrapped writer. Bytes reported by a Write that also
// failed are counted too.
type writeCounter struct {
	io.Writer       // destination for every write
	n         int64 // total bytes reported written so far
}

// Write forwards p to the wrapped writer and adds the reported count to wc.n.
func (wc *writeCounter) Write(p []byte) (n int, err error) {
	n, err = wc.Writer.Write(p)
	wc.n += int64(n)
	return n, err
}