exec.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. // Copyright (c) 2018 David Crawshaw <david@zentus.com>
  2. // Copyright (c) 2021 Ross Light <rosss@zombiezen.com>
  3. //
  4. // Permission to use, copy, modify, and distribute this software for any
  5. // purpose with or without fee is hereby granted, provided that the above
  6. // copyright notice and this permission notice appear in all copies.
  7. //
  8. // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  9. // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  10. // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  11. // ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  12. // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  13. // ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  14. // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  15. //
  16. // SPDX-License-Identifier: ISC
  17. // Package sqlitex provides utilities for working with SQLite.
  18. package sqlitex
  19. import (
  20. "fmt"
  21. "io"
  22. "reflect"
  23. "strings"
  24. "github.com/go-llsqlite/crawshaw"
  25. "io/fs"
  26. )
  27. // ExecOptions is the set of optional arguments executing a statement.
  28. type ExecOptions struct {
  29. // Args is the set of positional arguments to bind to the statement.
  30. // The first element in the slice is ?1.
  31. // See https://sqlite.org/lang_expr.html for more details.
  32. //
  33. // Basic reflection on Args is used to map:
  34. //
  35. // integers to BindInt64
  36. // floats to BindFloat
  37. // []byte to BindBytes
  38. // string to BindText
  39. // bool to BindBool
  40. //
  41. // All other kinds are printed using fmt.Sprint(v) and passed to BindText.
  42. Args []interface{}
  43. // Named is the set of named arguments to bind to the statement. Keys must
  44. // start with ':', '@', or '$'. See https://sqlite.org/lang_expr.html for more
  45. // details.
  46. //
  47. // Basic reflection on Named is used to map:
  48. //
  49. // integers to BindInt64
  50. // floats to BindFloat
  51. // []byte to BindBytes
  52. // string to BindText
  53. // bool to BindBool
  54. //
  55. // All other kinds are printed using fmt.Sprint(v) and passed to BindText.
  56. Named map[string]interface{}
  57. // ResultFunc is called for each result row.
  58. // If ResultFunc returns an error then iteration ceases
  59. // and the execution function returns the error value.
  60. ResultFunc func(stmt *sqlite.Stmt) error
  61. // Allow unused parameters. SQLite normally treats these as null anyway, so this reverts to the
  62. // default behaviour.
  63. AllowUnused bool
  64. }
  65. // Exec executes an SQLite query.
  66. //
  67. // For each result row, the resultFn is called.
  68. // Result values can be read by resultFn using stmt.Column* methods.
  69. // If resultFn returns an error then iteration ceases and Exec returns
  70. // the error value.
  71. //
  72. // Any args provided to Exec are bound to numbered parameters of the
  73. // query using the Stmt Bind* methods. Basic reflection on args is used
  74. // to map:
  75. //
  76. // integers to BindInt64
  77. // floats to BindFloat
  78. // []byte to BindBytes
  79. // string to BindText
  80. // bool to BindBool
  81. //
  82. // All other kinds are printed using fmt.Sprintf("%v", v) and passed
  83. // to BindText.
  84. //
  85. // Exec is implemented using the Stmt prepare mechanism which allows
  86. // better interactions with Go's type system and avoids pitfalls of
  87. // passing a Go closure to cgo.
  88. //
  89. // As Exec is implemented using Conn.Prepare, subsequent calls to Exec
  90. // with the same statement will reuse the cached statement object.
  91. //
  92. // Deprecated: Use Execute.
  93. // Exec skips some argument checks for compatibility with crawshaw.io/sqlite.
  94. func Exec(conn *sqlite.Conn, query string, resultFn func(stmt *sqlite.Stmt) error, args ...interface{}) error {
  95. stmt, err := conn.Prepare(query)
  96. if err != nil {
  97. return annotateErr(err)
  98. }
  99. err = exec(stmt, 0, &ExecOptions{
  100. Args: args,
  101. ResultFunc: resultFn,
  102. })
  103. resetErr := stmt.Reset()
  104. if err == nil {
  105. err = resetErr
  106. }
  107. return err
  108. }
  109. // Execute executes an SQLite query.
  110. //
  111. // As Execute is implemented using Conn.Prepare,
  112. // subsequent calls to Execute with the same statement
  113. // will reuse the cached statement object.
  114. func Execute(conn *sqlite.Conn, query string, opts *ExecOptions) error {
  115. stmt, err := conn.Prepare(query)
  116. if err != nil {
  117. return annotateErr(err)
  118. }
  119. err = exec(stmt, forbidMissing|forbidExtra, opts)
  120. resetErr := stmt.Reset()
  121. if err == nil {
  122. err = resetErr
  123. }
  124. return err
  125. }
  126. // ExecFS is an alias for ExecuteFS.
  127. //
  128. // Deprecated: Call ExecuteFS directly.
  129. func ExecFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
  130. return ExecuteFS(conn, fsys, filename, opts)
  131. }
  132. // ExecuteFS executes the single statement in the given SQL file.
  133. // ExecuteFS is implemented using Conn.Prepare,
  134. // so subsequent calls to ExecuteFS with the same statement
  135. // will reuse the cached statement object.
  136. func ExecuteFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
  137. query, err := readString(fsys, filename)
  138. if err != nil {
  139. return fmt.Errorf("exec: %w", err)
  140. }
  141. stmt, err := conn.Prepare(strings.TrimSpace(query))
  142. if err != nil {
  143. return fmt.Errorf("exec %s: %w", filename, err)
  144. }
  145. err = exec(stmt, forbidMissing|forbidExtra, opts)
  146. resetErr := stmt.Reset()
  147. if err != nil {
  148. // Don't strip the error query: we already do this inside exec.
  149. return fmt.Errorf("exec %s: %w", filename, err)
  150. }
  151. if resetErr != nil {
  152. return fmt.Errorf("exec %s: %w", filename, err)
  153. }
  154. return nil
  155. }
  156. // ExecTransient executes an SQLite query without caching the underlying query.
  157. // The interface is exactly the same as Exec.
  158. // It is the spiritual equivalent of sqlite3_exec.
  159. //
  160. // Deprecated: Use ExecuteTransient.
  161. // ExecTransient skips some argument checks for compatibility with crawshaw.io/sqlite.
  162. func ExecTransient(conn *sqlite.Conn, query string, resultFn func(stmt *sqlite.Stmt) error, args ...interface{}) (err error) {
  163. var stmt *sqlite.Stmt
  164. var trailingBytes int
  165. stmt, trailingBytes, err = conn.PrepareTransient(query)
  166. if err != nil {
  167. return annotateErr(err)
  168. }
  169. defer func() {
  170. ferr := stmt.Finalize()
  171. if err == nil {
  172. err = ferr
  173. }
  174. }()
  175. if trailingBytes != 0 {
  176. return fmt.Errorf("sqlitex.Exec: query %q has trailing bytes", query)
  177. }
  178. return exec(stmt, 0, &ExecOptions{
  179. Args: args,
  180. ResultFunc: resultFn,
  181. })
  182. }
  183. // ExecuteTransient executes an SQLite query without caching the underlying query.
  184. // It is the spiritual equivalent of sqlite3_exec:
  185. // https://www.sqlite.org/c3ref/exec.html
  186. func ExecuteTransient(conn *sqlite.Conn, query string, opts *ExecOptions) (err error) {
  187. var stmt *sqlite.Stmt
  188. var trailingBytes int
  189. stmt, trailingBytes, err = conn.PrepareTransient(query)
  190. if err != nil {
  191. return annotateErr(err)
  192. }
  193. defer func() {
  194. ferr := stmt.Finalize()
  195. if err == nil {
  196. err = ferr
  197. }
  198. }()
  199. if trailingBytes != 0 {
  200. return fmt.Errorf("sqlitex.Exec: query %q has trailing bytes", query)
  201. }
  202. return exec(stmt, forbidMissing|forbidExtra, opts)
  203. }
  204. // ExecTransientFS is an alias for ExecuteTransientFS.
  205. //
  206. // Deprecated: Call ExecuteTransientFS directly.
  207. func ExecTransientFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
  208. return ExecuteTransientFS(conn, fsys, filename, opts)
  209. }
  210. // ExecuteTransientFS executes the single statement in the given SQL file without
  211. // caching the underlying query.
  212. func ExecuteTransientFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
  213. query, err := readString(fsys, filename)
  214. if err != nil {
  215. return fmt.Errorf("exec: %w", err)
  216. }
  217. stmt, _, err := conn.PrepareTransient(strings.TrimSpace(query))
  218. if err != nil {
  219. return fmt.Errorf("exec %s: %w", filename, err)
  220. }
  221. defer stmt.Finalize()
  222. err = exec(stmt, forbidMissing|forbidExtra, opts)
  223. resetErr := stmt.Reset()
  224. if err != nil {
  225. // Don't strip the error query: we already do this inside exec.
  226. return fmt.Errorf("exec %s: %w", filename, err)
  227. }
  228. if resetErr != nil {
  229. return fmt.Errorf("exec %s: %w", filename, err)
  230. }
  231. return nil
  232. }
  233. // PrepareTransientFS prepares an SQL statement from a file that is not cached by
  234. // the Conn. Subsequent calls with the same query will create new Stmts.
  235. // The caller is responsible for calling Finalize on the returned Stmt when the
  236. // Stmt is no longer needed.
  237. func PrepareTransientFS(conn *sqlite.Conn, fsys fs.FS, filename string) (*sqlite.Stmt, error) {
  238. query, err := readString(fsys, filename)
  239. if err != nil {
  240. return nil, fmt.Errorf("prepare: %w", err)
  241. }
  242. stmt, _, err := conn.PrepareTransient(strings.TrimSpace(query))
  243. if err != nil {
  244. return nil, fmt.Errorf("prepare %s: %w", filename, err)
  245. }
  246. return stmt, nil
  247. }
  248. const (
  249. forbidMissing = 1 << iota
  250. forbidExtra
  251. )
  252. func exec(stmt *sqlite.Stmt, flags uint8, opts *ExecOptions) (err error) {
  253. paramCount := stmt.BindParamCount()
  254. provided := newBitset(paramCount)
  255. if opts != nil {
  256. if len(opts.Args) > paramCount {
  257. return fmt.Errorf("sqlitex.Exec: %w (len(Args) > BindParamCount(); %d > %d)",
  258. sqlite.ResultRange.ToError(), len(opts.Args), paramCount)
  259. }
  260. for i, arg := range opts.Args {
  261. provided.set(i)
  262. setArg(stmt, i+1, reflect.ValueOf(arg))
  263. }
  264. if err := setNamed(stmt, provided, flags, opts.Named); err != nil {
  265. return err
  266. }
  267. }
  268. if flags&forbidMissing != 0 && !provided.hasAll(paramCount) {
  269. i := provided.firstMissing() + 1
  270. name := stmt.BindParamName(i)
  271. if name == "" {
  272. name = fmt.Sprintf("?%d", i)
  273. }
  274. return fmt.Errorf("sqlitex.Exec: missing argument for %s", name)
  275. }
  276. for {
  277. hasRow, err := stmt.Step()
  278. if err != nil {
  279. return err
  280. }
  281. if !hasRow {
  282. break
  283. }
  284. if opts != nil && opts.ResultFunc != nil {
  285. if err := opts.ResultFunc(stmt); err != nil {
  286. return err
  287. }
  288. }
  289. }
  290. return nil
  291. }
  292. func setArg(stmt *sqlite.Stmt, i int, v reflect.Value) {
  293. switch v.Kind() {
  294. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  295. stmt.BindInt64(i, v.Int())
  296. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  297. stmt.BindInt64(i, int64(v.Uint()))
  298. case reflect.Float32, reflect.Float64:
  299. stmt.BindFloat(i, v.Float())
  300. case reflect.String:
  301. stmt.BindText(i, v.String())
  302. case reflect.Bool:
  303. stmt.BindBool(i, v.Bool())
  304. case reflect.Invalid:
  305. stmt.BindNull(i)
  306. default:
  307. if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
  308. stmt.BindBytes(i, v.Bytes())
  309. } else {
  310. stmt.BindText(i, fmt.Sprint(v.Interface()))
  311. }
  312. }
  313. }
  314. func setNamed(stmt *sqlite.Stmt, provided bitset, flags uint8, args map[string]interface{}) error {
  315. if len(args) == 0 {
  316. return nil
  317. }
  318. var unused map[string]struct{}
  319. if flags&forbidExtra != 0 {
  320. unused = make(map[string]struct{}, len(args))
  321. for k := range args {
  322. unused[k] = struct{}{}
  323. }
  324. }
  325. for i, count := 1, stmt.BindParamCount(); i <= count; i++ {
  326. name := stmt.BindParamName(i)
  327. if name == "" {
  328. continue
  329. }
  330. arg, present := args[name]
  331. if !present {
  332. if flags&forbidMissing != 0 {
  333. // TODO(maybe): Check provided as well?
  334. return fmt.Errorf("missing parameter %s", name)
  335. }
  336. continue
  337. }
  338. delete(unused, name)
  339. provided.set(i - 1)
  340. setArg(stmt, i, reflect.ValueOf(arg))
  341. }
  342. if len(unused) > 0 {
  343. return fmt.Errorf("%w: unknown argument %s", sqlite.ResultRange.ToError(), minStringInSet(unused))
  344. }
  345. return nil
  346. }
  347. func annotateErr(err error) error {
  348. // TODO(maybe)
  349. // if err, isError := err.(sqlite.Error); isError {
  350. // if err.Loc == "" {
  351. // err.Loc = "Exec"
  352. // } else {
  353. // err.Loc = "Exec: " + err.Loc
  354. // }
  355. // return err
  356. // }
  357. return fmt.Errorf("sqlitex.Exec: %w", err)
  358. }
  359. // ExecScript executes a script of SQL statements.
  360. // It is the same as calling ExecuteScript without options.
  361. func ExecScript(conn *sqlite.Conn, queries string) (err error) {
  362. return ExecuteScript(conn, queries, nil)
  363. }
  364. // ExecuteScript executes a script of SQL statements.
  365. // The script is wrapped in a SAVEPOINT transaction,
  366. // which is rolled back on any error.
  367. //
  368. // opts.ResultFunc is ignored.
  369. func ExecuteScript(conn *sqlite.Conn, queries string, opts *ExecOptions) (err error) {
  370. defer Save(conn)(&err)
  371. unused := make(map[string]struct{})
  372. if opts != nil {
  373. for k := range opts.Named {
  374. unused[k] = struct{}{}
  375. }
  376. }
  377. for {
  378. queries = strings.TrimSpace(queries)
  379. if queries == "" {
  380. break
  381. }
  382. var stmt *sqlite.Stmt
  383. var trailingBytes int
  384. stmt, trailingBytes, err = conn.PrepareTransient(queries)
  385. if err != nil {
  386. return err
  387. }
  388. for i, n := 1, stmt.BindParamCount(); i <= n; i++ {
  389. if name := stmt.BindParamName(i); name != "" {
  390. delete(unused, name)
  391. }
  392. }
  393. usedBytes := len(queries) - trailingBytes
  394. queries = queries[usedBytes:]
  395. err = exec(stmt, forbidMissing, opts)
  396. stmt.Finalize()
  397. if err != nil {
  398. return err
  399. }
  400. }
  401. if len(unused) > 0 && !opts.AllowUnused {
  402. return fmt.Errorf("%w: unknown argument %s", sqlite.ResultRange.ToError(), minStringInSet(unused))
  403. }
  404. return nil
  405. }
  406. // ExecScriptFS is an alias for ExecuteScriptFS.
  407. //
  408. // Deprecated: Call ExecuteScriptFS directly.
  409. func ExecScriptFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) (err error) {
  410. return ExecuteScriptFS(conn, fsys, filename, opts)
  411. }
  412. // ExecuteScriptFS executes a script of SQL statements from a file.
  413. // The script is wrapped in a SAVEPOINT transaction,
  414. // which is rolled back on any error.
  415. func ExecuteScriptFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) (err error) {
  416. queries, err := readString(fsys, filename)
  417. if err != nil {
  418. return fmt.Errorf("exec: %w", err)
  419. }
  420. if err := ExecuteScript(conn, queries, opts); err != nil {
  421. return fmt.Errorf("exec %s: %w", filename, err)
  422. }
  423. return nil
  424. }
  425. type bitset []uint64
  426. func newBitset(n int) bitset {
  427. return make([]uint64, (n+63)/64)
  428. }
  429. // hasAll reports whether the bitset is a superset of [0, n).
  430. func (bs bitset) hasAll(n int) bool {
  431. nbytes := (n + 63) / 64
  432. if len(bs) < nbytes {
  433. return false
  434. }
  435. fullBytes := n / 64
  436. for _, b := range bs[:fullBytes] {
  437. if b != ^uint64(0) {
  438. return false
  439. }
  440. }
  441. if fullBytes == nbytes {
  442. return true
  443. }
  444. mask := uint64(1)<<(n%64) - 1
  445. return bs[nbytes-1]&mask == mask
  446. }
  447. func (bs bitset) firstMissing() int {
  448. for i, b := range bs {
  449. if b == ^uint64(0) {
  450. continue
  451. }
  452. for j := 0; j < 64; j++ {
  453. if b&(1<<j) == 0 {
  454. return i*64 + j
  455. }
  456. }
  457. }
  458. return len(bs) * 64
  459. }
  460. func (bs bitset) set(n int) {
  461. bs[n/64] |= 1 << (n % 64)
  462. }
  463. func (bs bitset) String() string {
  464. sb := new(strings.Builder)
  465. for i := len(bs) - 1; i >= 0; i-- {
  466. fmt.Fprintf(sb, "%08b", bs[i])
  467. }
  468. return sb.String()
  469. }
  470. func minStringInSet(set map[string]struct{}) string {
  471. min := ""
  472. for k := range set {
  473. if min == "" || k < min {
  474. min = k
  475. }
  476. }
  477. return min
  478. }
  479. func readString(fsys fs.FS, filename string) (string, error) {
  480. f, err := fsys.Open(filename)
  481. if err != nil {
  482. return "", err
  483. }
  484. content := new(strings.Builder)
  485. _, err = io.Copy(content, f)
  486. f.Close()
  487. if err != nil {
  488. return "", fmt.Errorf("%s: %w", filename, err)
  489. }
  490. return content.String(), nil
  491. }