exec.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. // Copyright (c) 2018 David Crawshaw <david@zentus.com>
  2. // Copyright (c) 2021 Ross Light <rosss@zombiezen.com>
  3. //
  4. // Permission to use, copy, modify, and distribute this software for any
  5. // purpose with or without fee is hereby granted, provided that the above
  6. // copyright notice and this permission notice appear in all copies.
  7. //
  8. // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  9. // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  10. // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  11. // ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  12. // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  13. // ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  14. // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  15. //
  16. // SPDX-License-Identifier: ISC
  17. // Package sqlitex provides utilities for working with SQLite.
  18. package sqlitex
  19. import (
  20. "fmt"
  21. "io"
  22. "reflect"
  23. "strings"
  24. "zombiezen.com/go/sqlite"
  25. "zombiezen.com/go/sqlite/fs"
  26. )
  27. // ExecOptions is the set of optional arguments executing a statement.
  28. type ExecOptions struct {
  29. // Args is the set of positional arguments to bind to the statement.
  30. // The first element in the slice is ?1.
  31. // See https://sqlite.org/lang_expr.html for more details.
  32. //
  33. // Basic reflection on Args is used to map:
  34. //
  35. // integers to BindInt64
  36. // floats to BindFloat
  37. // []byte to BindBytes
  38. // string to BindText
  39. // bool to BindBool
  40. //
  41. // All other kinds are printed using fmt.Sprint(v) and passed to BindText.
  42. Args []interface{}
  43. // Named is the set of named arguments to bind to the statement. Keys must
  44. // start with ':', '@', or '$'. See https://sqlite.org/lang_expr.html for more
  45. // details.
  46. //
  47. // Basic reflection on Named is used to map:
  48. //
  49. // integers to BindInt64
  50. // floats to BindFloat
  51. // []byte to BindBytes
  52. // string to BindText
  53. // bool to BindBool
  54. //
  55. // All other kinds are printed using fmt.Sprint(v) and passed to BindText.
  56. Named map[string]interface{}
  57. // ResultFunc is called for each result row.
  58. // If ResultFunc returns an error then iteration ceases
  59. // and the execution function returns the error value.
  60. ResultFunc func(stmt *sqlite.Stmt) error
  61. }
  62. // Exec executes an SQLite query.
  63. //
  64. // For each result row, the resultFn is called.
  65. // Result values can be read by resultFn using stmt.Column* methods.
  66. // If resultFn returns an error then iteration ceases and Exec returns
  67. // the error value.
  68. //
  69. // Any args provided to Exec are bound to numbered parameters of the
  70. // query using the Stmt Bind* methods. Basic reflection on args is used
  71. // to map:
  72. //
  73. // integers to BindInt64
  74. // floats to BindFloat
  75. // []byte to BindBytes
  76. // string to BindText
  77. // bool to BindBool
  78. //
  79. // All other kinds are printed using fmt.Sprintf("%v", v) and passed
  80. // to BindText.
  81. //
  82. // Exec is implemented using the Stmt prepare mechanism which allows
  83. // better interactions with Go's type system and avoids pitfalls of
  84. // passing a Go closure to cgo.
  85. //
  86. // As Exec is implemented using Conn.Prepare, subsequent calls to Exec
  87. // with the same statement will reuse the cached statement object.
  88. //
  89. // Deprecated: Use Execute.
  90. // Exec skips some argument checks for compatibility with crawshaw.io/sqlite.
  91. func Exec(conn *sqlite.Conn, query string, resultFn func(stmt *sqlite.Stmt) error, args ...interface{}) error {
  92. stmt, err := conn.Prepare(query)
  93. if err != nil {
  94. return annotateErr(err)
  95. }
  96. err = exec(stmt, 0, &ExecOptions{
  97. Args: args,
  98. ResultFunc: resultFn,
  99. })
  100. resetErr := stmt.Reset()
  101. if err == nil {
  102. err = resetErr
  103. }
  104. return err
  105. }
  106. // Execute executes an SQLite query.
  107. //
  108. // As Execute is implemented using Conn.Prepare,
  109. // subsequent calls to Execute with the same statement
  110. // will reuse the cached statement object.
  111. func Execute(conn *sqlite.Conn, query string, opts *ExecOptions) error {
  112. stmt, err := conn.Prepare(query)
  113. if err != nil {
  114. return annotateErr(err)
  115. }
  116. err = exec(stmt, forbidMissing|forbidExtra, opts)
  117. resetErr := stmt.Reset()
  118. if err == nil {
  119. err = resetErr
  120. }
  121. return err
  122. }
  123. // ExecFS is an alias for ExecuteFS.
  124. //
  125. // Deprecated: Call ExecuteFS directly.
  126. func ExecFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
  127. return ExecuteFS(conn, fsys, filename, opts)
  128. }
  129. // ExecuteFS executes the single statement in the given SQL file.
  130. // ExecuteFS is implemented using Conn.Prepare,
  131. // so subsequent calls to ExecuteFS with the same statement
  132. // will reuse the cached statement object.
  133. func ExecuteFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
  134. query, err := readString(fsys, filename)
  135. if err != nil {
  136. return fmt.Errorf("exec: %w", err)
  137. }
  138. stmt, err := conn.Prepare(strings.TrimSpace(query))
  139. if err != nil {
  140. return fmt.Errorf("exec %s: %w", filename, err)
  141. }
  142. err = exec(stmt, forbidMissing|forbidExtra, opts)
  143. resetErr := stmt.Reset()
  144. if err != nil {
  145. // Don't strip the error query: we already do this inside exec.
  146. return fmt.Errorf("exec %s: %w", filename, err)
  147. }
  148. if resetErr != nil {
  149. return fmt.Errorf("exec %s: %w", filename, err)
  150. }
  151. return nil
  152. }
  153. // ExecTransient executes an SQLite query without caching the underlying query.
  154. // The interface is exactly the same as Exec.
  155. // It is the spiritual equivalent of sqlite3_exec.
  156. //
  157. // Deprecated: Use ExecuteTransient.
  158. // ExecTransient skips some argument checks for compatibility with crawshaw.io/sqlite.
  159. func ExecTransient(conn *sqlite.Conn, query string, resultFn func(stmt *sqlite.Stmt) error, args ...interface{}) (err error) {
  160. var stmt *sqlite.Stmt
  161. var trailingBytes int
  162. stmt, trailingBytes, err = conn.PrepareTransient(query)
  163. if err != nil {
  164. return annotateErr(err)
  165. }
  166. defer func() {
  167. ferr := stmt.Finalize()
  168. if err == nil {
  169. err = ferr
  170. }
  171. }()
  172. if trailingBytes != 0 {
  173. return fmt.Errorf("sqlitex.Exec: query %q has trailing bytes", query)
  174. }
  175. return exec(stmt, 0, &ExecOptions{
  176. Args: args,
  177. ResultFunc: resultFn,
  178. })
  179. }
  180. // ExecuteTransient executes an SQLite query without caching the underlying query.
  181. // It is the spiritual equivalent of sqlite3_exec:
  182. // https://www.sqlite.org/c3ref/exec.html
  183. func ExecuteTransient(conn *sqlite.Conn, query string, opts *ExecOptions) (err error) {
  184. var stmt *sqlite.Stmt
  185. var trailingBytes int
  186. stmt, trailingBytes, err = conn.PrepareTransient(query)
  187. if err != nil {
  188. return annotateErr(err)
  189. }
  190. defer func() {
  191. ferr := stmt.Finalize()
  192. if err == nil {
  193. err = ferr
  194. }
  195. }()
  196. if trailingBytes != 0 {
  197. return fmt.Errorf("sqlitex.Exec: query %q has trailing bytes", query)
  198. }
  199. return exec(stmt, forbidMissing|forbidExtra, opts)
  200. }
  201. // ExecTransientFS is an alias for ExecuteTransientFS.
  202. //
  203. // Deprecated: Call ExecuteTransientFS directly.
  204. func ExecTransientFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
  205. return ExecuteTransientFS(conn, fsys, filename, opts)
  206. }
  207. // ExecuteTransientFS executes the single statement in the given SQL file without
  208. // caching the underlying query.
  209. func ExecuteTransientFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
  210. query, err := readString(fsys, filename)
  211. if err != nil {
  212. return fmt.Errorf("exec: %w", err)
  213. }
  214. stmt, _, err := conn.PrepareTransient(strings.TrimSpace(query))
  215. if err != nil {
  216. return fmt.Errorf("exec %s: %w", filename, err)
  217. }
  218. defer stmt.Finalize()
  219. err = exec(stmt, forbidMissing|forbidExtra, opts)
  220. resetErr := stmt.Reset()
  221. if err != nil {
  222. // Don't strip the error query: we already do this inside exec.
  223. return fmt.Errorf("exec %s: %w", filename, err)
  224. }
  225. if resetErr != nil {
  226. return fmt.Errorf("exec %s: %w", filename, err)
  227. }
  228. return nil
  229. }
  230. // PrepareTransientFS prepares an SQL statement from a file that is not cached by
  231. // the Conn. Subsequent calls with the same query will create new Stmts.
  232. // The caller is responsible for calling Finalize on the returned Stmt when the
  233. // Stmt is no longer needed.
  234. func PrepareTransientFS(conn *sqlite.Conn, fsys fs.FS, filename string) (*sqlite.Stmt, error) {
  235. query, err := readString(fsys, filename)
  236. if err != nil {
  237. return nil, fmt.Errorf("prepare: %w", err)
  238. }
  239. stmt, _, err := conn.PrepareTransient(strings.TrimSpace(query))
  240. if err != nil {
  241. return nil, fmt.Errorf("prepare %s: %w", filename, err)
  242. }
  243. return stmt, nil
  244. }
  // Flags accepted by exec that control how strictly arguments are checked.
const (
	// forbidMissing makes it an error to leave any statement parameter unbound.
	forbidMissing = 1 << 0
	// forbidExtra makes it an error for opts.Named to contain a key that
	// matches no parameter of the statement.
	forbidExtra = 1 << 1
)
  249. func exec(stmt *sqlite.Stmt, flags uint8, opts *ExecOptions) (err error) {
  250. paramCount := stmt.BindParamCount()
  251. provided := newBitset(paramCount)
  252. if opts != nil {
  253. if len(opts.Args) > paramCount {
  254. return fmt.Errorf("sqlitex.Exec: %w (len(Args) > BindParamCount(); %d > %d)",
  255. sqlite.ResultRange.ToError(), len(opts.Args), paramCount)
  256. }
  257. for i, arg := range opts.Args {
  258. provided.set(i)
  259. setArg(stmt, i+1, reflect.ValueOf(arg))
  260. }
  261. if err := setNamed(stmt, provided, flags, opts.Named); err != nil {
  262. return err
  263. }
  264. }
  265. if flags&forbidMissing != 0 && !provided.hasAll(paramCount) {
  266. i := provided.firstMissing() + 1
  267. name := stmt.BindParamName(i)
  268. if name == "" {
  269. name = fmt.Sprintf("?%d", i)
  270. }
  271. return fmt.Errorf("sqlitex.Exec: missing argument for %s", name)
  272. }
  273. for {
  274. hasRow, err := stmt.Step()
  275. if err != nil {
  276. return err
  277. }
  278. if !hasRow {
  279. break
  280. }
  281. if opts != nil && opts.ResultFunc != nil {
  282. if err := opts.ResultFunc(stmt); err != nil {
  283. return err
  284. }
  285. }
  286. }
  287. return nil
  288. }
  289. func setArg(stmt *sqlite.Stmt, i int, v reflect.Value) {
  290. switch v.Kind() {
  291. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  292. stmt.BindInt64(i, v.Int())
  293. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  294. stmt.BindInt64(i, int64(v.Uint()))
  295. case reflect.Float32, reflect.Float64:
  296. stmt.BindFloat(i, v.Float())
  297. case reflect.String:
  298. stmt.BindText(i, v.String())
  299. case reflect.Bool:
  300. stmt.BindBool(i, v.Bool())
  301. case reflect.Invalid:
  302. stmt.BindNull(i)
  303. default:
  304. if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
  305. stmt.BindBytes(i, v.Bytes())
  306. } else {
  307. stmt.BindText(i, fmt.Sprint(v.Interface()))
  308. }
  309. }
  310. }
  311. func setNamed(stmt *sqlite.Stmt, provided bitset, flags uint8, args map[string]interface{}) error {
  312. if len(args) == 0 {
  313. return nil
  314. }
  315. var unused map[string]struct{}
  316. if flags&forbidExtra != 0 {
  317. unused = make(map[string]struct{}, len(args))
  318. for k := range args {
  319. unused[k] = struct{}{}
  320. }
  321. }
  322. for i, count := 1, stmt.BindParamCount(); i <= count; i++ {
  323. name := stmt.BindParamName(i)
  324. if name == "" {
  325. continue
  326. }
  327. arg, present := args[name]
  328. if !present {
  329. if flags&forbidMissing != 0 {
  330. // TODO(maybe): Check provided as well?
  331. return fmt.Errorf("missing parameter %s", name)
  332. }
  333. continue
  334. }
  335. delete(unused, name)
  336. provided.set(i - 1)
  337. setArg(stmt, i, reflect.ValueOf(arg))
  338. }
  339. if len(unused) > 0 {
  340. return fmt.Errorf("%w: unknown argument %s", sqlite.ResultRange.ToError(), minStringInSet(unused))
  341. }
  342. return nil
  343. }
  // annotateErr prefixes err with "sqlitex.Exec: ". It wraps with %w, so the
// original error stays reachable through errors.Is, errors.As and Unwrap.
//
// TODO(maybe): when err is a sqlite.Error, record "Exec" in its Loc field
// instead of wrapping it.
func annotateErr(err error) error {
	wrapped := fmt.Errorf("sqlitex.Exec: %w", err)
	return wrapped
}
  356. // ExecScript executes a script of SQL statements.
  357. // It is the same as calling ExecuteScript without options.
  358. func ExecScript(conn *sqlite.Conn, queries string) (err error) {
  359. return ExecuteScript(conn, queries, nil)
  360. }
  361. // ExecuteScript executes a script of SQL statements.
  362. // The script is wrapped in a SAVEPOINT transaction,
  363. // which is rolled back on any error.
  364. //
  365. // opts.ResultFunc is ignored.
  366. func ExecuteScript(conn *sqlite.Conn, queries string, opts *ExecOptions) (err error) {
  367. defer Save(conn)(&err)
  368. unused := make(map[string]struct{})
  369. if opts != nil {
  370. for k := range opts.Named {
  371. unused[k] = struct{}{}
  372. }
  373. }
  374. for {
  375. queries = strings.TrimSpace(queries)
  376. if queries == "" {
  377. break
  378. }
  379. var stmt *sqlite.Stmt
  380. var trailingBytes int
  381. stmt, trailingBytes, err = conn.PrepareTransient(queries)
  382. if err != nil {
  383. return err
  384. }
  385. for i, n := 1, stmt.BindParamCount(); i <= n; i++ {
  386. if name := stmt.BindParamName(i); name != "" {
  387. delete(unused, name)
  388. }
  389. }
  390. usedBytes := len(queries) - trailingBytes
  391. queries = queries[usedBytes:]
  392. err = exec(stmt, forbidMissing, opts)
  393. stmt.Finalize()
  394. if err != nil {
  395. return err
  396. }
  397. }
  398. if len(unused) > 0 {
  399. return fmt.Errorf("%w: unknown argument %s", sqlite.ResultRange.ToError(), minStringInSet(unused))
  400. }
  401. return nil
  402. }
  403. // ExecScriptFS is an alias for ExecuteScriptFS.
  404. //
  405. // Deprecated: Call ExecuteScriptFS directly.
  406. func ExecScriptFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) (err error) {
  407. return ExecuteScriptFS(conn, fsys, filename, opts)
  408. }
  409. // ExecuteScriptFS executes a script of SQL statements from a file.
  410. // The script is wrapped in a SAVEPOINT transaction,
  411. // which is rolled back on any error.
  412. func ExecuteScriptFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) (err error) {
  413. queries, err := readString(fsys, filename)
  414. if err != nil {
  415. return fmt.Errorf("exec: %w", err)
  416. }
  417. if err := ExecuteScript(conn, queries, opts); err != nil {
  418. return fmt.Errorf("exec %s: %w", filename, err)
  419. }
  420. return nil
  421. }
  // bitset is a fixed-size set of small non-negative integers packed into
// 64-bit words: element i lives in word i/64 at bit i%64.
type bitset []uint64

// newBitset returns an empty bitset large enough to hold the values [0, n).
func newBitset(n int) bitset {
	return make(bitset, (n+63)/64)
}

// hasAll reports whether the bitset is a superset of [0, n).
// It returns false if the bitset has too few words to hold n values.
func (bs bitset) hasAll(n int) bool {
	nwords := (n + 63) / 64
	if len(bs) < nwords {
		return false
	}
	fullWords := n / 64
	for _, w := range bs[:fullWords] {
		if w != ^uint64(0) {
			return false
		}
	}
	if fullWords == nwords {
		return true
	}
	// Only the low n%64 bits of the last, partially used word must be set.
	mask := uint64(1)<<(n%64) - 1
	return bs[nwords-1]&mask == mask
}

// firstMissing returns the smallest value not in the set, or len(bs)*64 if
// every representable value is present.
func (bs bitset) firstMissing() int {
	for i, w := range bs {
		if w == ^uint64(0) {
			continue
		}
		for j := 0; j < 64; j++ {
			if w&(1<<j) == 0 {
				return i*64 + j
			}
		}
	}
	return len(bs) * 64
}

// set adds n to the set. n must be in [0, len(bs)*64).
func (bs bitset) set(n int) {
	bs[n/64] |= 1 << (n % 64)
}

// String formats the set as binary digits, most significant word first.
// Each word is zero-padded to its full 64 bits so a digit's position always
// corresponds to the same element.
func (bs bitset) String() string {
	sb := new(strings.Builder)
	for i := len(bs) - 1; i >= 0; i-- {
		fmt.Fprintf(sb, "%064b", bs[i])
	}
	return sb.String()
}
  // minStringInSet returns the lexicographically smallest key of set, or ""
// if set is empty. The result does not depend on map iteration order, even
// when "" is itself a key.
func minStringInSet(set map[string]struct{}) string {
	smallest := ""
	first := true
	for k := range set {
		// An explicit first flag is needed: "" cannot double as "unset"
		// because it is also a valid (and the smallest possible) key.
		if first || k < smallest {
			smallest = k
			first = false
		}
	}
	return smallest
}
  // readString returns the entire contents of filename in fsys as a string.
// Errors from fsys.Open are returned unchanged; a failure while reading the
// opened file is wrapped with the filename.
func readString(fsys fs.FS, filename string) (string, error) {
	f, err := fsys.Open(filename)
	if err != nil {
		return "", err
	}
	// The file is only read, so a Close error carries no useful information.
	defer f.Close()

	var sb strings.Builder
	if _, err := io.Copy(&sb, f); err != nil {
		return "", fmt.Errorf("%s: %w", filename, err)
	}
	return sb.String(), nil
}