| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523 |
- // Copyright (c) 2018 David Crawshaw <david@zentus.com>
- // Copyright (c) 2021 Ross Light <rosss@zombiezen.com>
- //
- // Permission to use, copy, modify, and distribute this software for any
- // purpose with or without fee is hereby granted, provided that the above
- // copyright notice and this permission notice appear in all copies.
- //
- // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
- // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
- // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
- // ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
- // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
- // ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
- // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
- //
- // SPDX-License-Identifier: ISC
- // Package sqlitex provides utilities for working with SQLite.
- package sqlitex
- import (
- "fmt"
- "io"
- "reflect"
- "strings"
- "zombiezen.com/go/sqlite"
- "zombiezen.com/go/sqlite/fs"
- )
- // ExecOptions is the set of optional arguments executing a statement.
- type ExecOptions struct {
- // Args is the set of positional arguments to bind to the statement.
- // The first element in the slice is ?1.
- // See https://sqlite.org/lang_expr.html for more details.
- //
- // Basic reflection on Args is used to map:
- //
- // integers to BindInt64
- // floats to BindFloat
- // []byte to BindBytes
- // string to BindText
- // bool to BindBool
- //
- // All other kinds are printed using fmt.Sprint(v) and passed to BindText.
- Args []interface{}
- // Named is the set of named arguments to bind to the statement. Keys must
- // start with ':', '@', or '$'. See https://sqlite.org/lang_expr.html for more
- // details.
- //
- // Basic reflection on Named is used to map:
- //
- // integers to BindInt64
- // floats to BindFloat
- // []byte to BindBytes
- // string to BindText
- // bool to BindBool
- //
- // All other kinds are printed using fmt.Sprint(v) and passed to BindText.
- Named map[string]interface{}
- // ResultFunc is called for each result row.
- // If ResultFunc returns an error then iteration ceases
- // and the execution function returns the error value.
- ResultFunc func(stmt *sqlite.Stmt) error
- }
- // Exec executes an SQLite query.
- //
- // For each result row, the resultFn is called.
- // Result values can be read by resultFn using stmt.Column* methods.
- // If resultFn returns an error then iteration ceases and Exec returns
- // the error value.
- //
- // Any args provided to Exec are bound to numbered parameters of the
- // query using the Stmt Bind* methods. Basic reflection on args is used
- // to map:
- //
- // integers to BindInt64
- // floats to BindFloat
- // []byte to BindBytes
- // string to BindText
- // bool to BindBool
- //
- // All other kinds are printed using fmt.Sprintf("%v", v) and passed
- // to BindText.
- //
- // Exec is implemented using the Stmt prepare mechanism which allows
- // better interactions with Go's type system and avoids pitfalls of
- // passing a Go closure to cgo.
- //
- // As Exec is implemented using Conn.Prepare, subsequent calls to Exec
- // with the same statement will reuse the cached statement object.
- //
- // Deprecated: Use Execute.
- // Exec skips some argument checks for compatibility with crawshaw.io/sqlite.
- func Exec(conn *sqlite.Conn, query string, resultFn func(stmt *sqlite.Stmt) error, args ...interface{}) error {
- stmt, err := conn.Prepare(query)
- if err != nil {
- return annotateErr(err)
- }
- err = exec(stmt, 0, &ExecOptions{
- Args: args,
- ResultFunc: resultFn,
- })
- resetErr := stmt.Reset()
- if err == nil {
- err = resetErr
- }
- return err
- }
- // Execute executes an SQLite query.
- //
- // As Execute is implemented using Conn.Prepare,
- // subsequent calls to Execute with the same statement
- // will reuse the cached statement object.
- func Execute(conn *sqlite.Conn, query string, opts *ExecOptions) error {
- stmt, err := conn.Prepare(query)
- if err != nil {
- return annotateErr(err)
- }
- err = exec(stmt, forbidMissing|forbidExtra, opts)
- resetErr := stmt.Reset()
- if err == nil {
- err = resetErr
- }
- return err
- }
- // ExecFS is an alias for ExecuteFS.
- //
- // Deprecated: Call ExecuteFS directly.
- func ExecFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
- return ExecuteFS(conn, fsys, filename, opts)
- }
- // ExecuteFS executes the single statement in the given SQL file.
- // ExecuteFS is implemented using Conn.Prepare,
- // so subsequent calls to ExecuteFS with the same statement
- // will reuse the cached statement object.
- func ExecuteFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
- query, err := readString(fsys, filename)
- if err != nil {
- return fmt.Errorf("exec: %w", err)
- }
- stmt, err := conn.Prepare(strings.TrimSpace(query))
- if err != nil {
- return fmt.Errorf("exec %s: %w", filename, err)
- }
- err = exec(stmt, forbidMissing|forbidExtra, opts)
- resetErr := stmt.Reset()
- if err != nil {
- // Don't strip the error query: we already do this inside exec.
- return fmt.Errorf("exec %s: %w", filename, err)
- }
- if resetErr != nil {
- return fmt.Errorf("exec %s: %w", filename, err)
- }
- return nil
- }
- // ExecTransient executes an SQLite query without caching the underlying query.
- // The interface is exactly the same as Exec.
- // It is the spiritual equivalent of sqlite3_exec.
- //
- // Deprecated: Use ExecuteTransient.
- // ExecTransient skips some argument checks for compatibility with crawshaw.io/sqlite.
- func ExecTransient(conn *sqlite.Conn, query string, resultFn func(stmt *sqlite.Stmt) error, args ...interface{}) (err error) {
- var stmt *sqlite.Stmt
- var trailingBytes int
- stmt, trailingBytes, err = conn.PrepareTransient(query)
- if err != nil {
- return annotateErr(err)
- }
- defer func() {
- ferr := stmt.Finalize()
- if err == nil {
- err = ferr
- }
- }()
- if trailingBytes != 0 {
- return fmt.Errorf("sqlitex.Exec: query %q has trailing bytes", query)
- }
- return exec(stmt, 0, &ExecOptions{
- Args: args,
- ResultFunc: resultFn,
- })
- }
- // ExecuteTransient executes an SQLite query without caching the underlying query.
- // It is the spiritual equivalent of sqlite3_exec:
- // https://www.sqlite.org/c3ref/exec.html
- func ExecuteTransient(conn *sqlite.Conn, query string, opts *ExecOptions) (err error) {
- var stmt *sqlite.Stmt
- var trailingBytes int
- stmt, trailingBytes, err = conn.PrepareTransient(query)
- if err != nil {
- return annotateErr(err)
- }
- defer func() {
- ferr := stmt.Finalize()
- if err == nil {
- err = ferr
- }
- }()
- if trailingBytes != 0 {
- return fmt.Errorf("sqlitex.Exec: query %q has trailing bytes", query)
- }
- return exec(stmt, forbidMissing|forbidExtra, opts)
- }
- // ExecTransientFS is an alias for ExecuteTransientFS.
- //
- // Deprecated: Call ExecuteTransientFS directly.
- func ExecTransientFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
- return ExecuteTransientFS(conn, fsys, filename, opts)
- }
- // ExecuteTransientFS executes the single statement in the given SQL file without
- // caching the underlying query.
- func ExecuteTransientFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) error {
- query, err := readString(fsys, filename)
- if err != nil {
- return fmt.Errorf("exec: %w", err)
- }
- stmt, _, err := conn.PrepareTransient(strings.TrimSpace(query))
- if err != nil {
- return fmt.Errorf("exec %s: %w", filename, err)
- }
- defer stmt.Finalize()
- err = exec(stmt, forbidMissing|forbidExtra, opts)
- resetErr := stmt.Reset()
- if err != nil {
- // Don't strip the error query: we already do this inside exec.
- return fmt.Errorf("exec %s: %w", filename, err)
- }
- if resetErr != nil {
- return fmt.Errorf("exec %s: %w", filename, err)
- }
- return nil
- }
- // PrepareTransientFS prepares an SQL statement from a file that is not cached by
- // the Conn. Subsequent calls with the same query will create new Stmts.
- // The caller is responsible for calling Finalize on the returned Stmt when the
- // Stmt is no longer needed.
- func PrepareTransientFS(conn *sqlite.Conn, fsys fs.FS, filename string) (*sqlite.Stmt, error) {
- query, err := readString(fsys, filename)
- if err != nil {
- return nil, fmt.Errorf("prepare: %w", err)
- }
- stmt, _, err := conn.PrepareTransient(strings.TrimSpace(query))
- if err != nil {
- return nil, fmt.Errorf("prepare %s: %w", filename, err)
- }
- return stmt, nil
- }
// Bit flags passed to exec to select which argument checks it enforces.
const (
	// forbidMissing makes it an error to leave any statement parameter unbound.
	forbidMissing = 1 << 0
	// forbidExtra makes it an error to supply a named argument that the
	// statement does not reference.
	forbidExtra = 1 << 1
)
- func exec(stmt *sqlite.Stmt, flags uint8, opts *ExecOptions) (err error) {
- paramCount := stmt.BindParamCount()
- provided := newBitset(paramCount)
- if opts != nil {
- if len(opts.Args) > paramCount {
- return fmt.Errorf("sqlitex.Exec: %w (len(Args) > BindParamCount(); %d > %d)",
- sqlite.ResultRange.ToError(), len(opts.Args), paramCount)
- }
- for i, arg := range opts.Args {
- provided.set(i)
- setArg(stmt, i+1, reflect.ValueOf(arg))
- }
- if err := setNamed(stmt, provided, flags, opts.Named); err != nil {
- return err
- }
- }
- if flags&forbidMissing != 0 && !provided.hasAll(paramCount) {
- i := provided.firstMissing() + 1
- name := stmt.BindParamName(i)
- if name == "" {
- name = fmt.Sprintf("?%d", i)
- }
- return fmt.Errorf("sqlitex.Exec: missing argument for %s", name)
- }
- for {
- hasRow, err := stmt.Step()
- if err != nil {
- return err
- }
- if !hasRow {
- break
- }
- if opts != nil && opts.ResultFunc != nil {
- if err := opts.ResultFunc(stmt); err != nil {
- return err
- }
- }
- }
- return nil
- }
- func setArg(stmt *sqlite.Stmt, i int, v reflect.Value) {
- switch v.Kind() {
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- stmt.BindInt64(i, v.Int())
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- stmt.BindInt64(i, int64(v.Uint()))
- case reflect.Float32, reflect.Float64:
- stmt.BindFloat(i, v.Float())
- case reflect.String:
- stmt.BindText(i, v.String())
- case reflect.Bool:
- stmt.BindBool(i, v.Bool())
- case reflect.Invalid:
- stmt.BindNull(i)
- default:
- if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
- stmt.BindBytes(i, v.Bytes())
- } else {
- stmt.BindText(i, fmt.Sprint(v.Interface()))
- }
- }
- }
- func setNamed(stmt *sqlite.Stmt, provided bitset, flags uint8, args map[string]interface{}) error {
- if len(args) == 0 {
- return nil
- }
- var unused map[string]struct{}
- if flags&forbidExtra != 0 {
- unused = make(map[string]struct{}, len(args))
- for k := range args {
- unused[k] = struct{}{}
- }
- }
- for i, count := 1, stmt.BindParamCount(); i <= count; i++ {
- name := stmt.BindParamName(i)
- if name == "" {
- continue
- }
- arg, present := args[name]
- if !present {
- if flags&forbidMissing != 0 {
- // TODO(maybe): Check provided as well?
- return fmt.Errorf("missing parameter %s", name)
- }
- continue
- }
- delete(unused, name)
- provided.set(i - 1)
- setArg(stmt, i, reflect.ValueOf(arg))
- }
- if len(unused) > 0 {
- return fmt.Errorf("%w: unknown argument %s", sqlite.ResultRange.ToError(), minStringInSet(unused))
- }
- return nil
- }
// annotateErr prefixes err with "sqlitex.Exec: " for the Exec family of
// functions. err is wrapped with %w, so errors.Is and errors.As still see it.
//
// TODO(maybe): when err is a sqlite.Error, record the location in its Loc
// field ("Exec" or "Exec: <previous Loc>") instead of adding a text prefix.
func annotateErr(err error) error {
	wrapped := fmt.Errorf("sqlitex.Exec: %w", err)
	return wrapped
}
- // ExecScript executes a script of SQL statements.
- // It is the same as calling ExecuteScript without options.
- func ExecScript(conn *sqlite.Conn, queries string) (err error) {
- return ExecuteScript(conn, queries, nil)
- }
- // ExecuteScript executes a script of SQL statements.
- // The script is wrapped in a SAVEPOINT transaction,
- // which is rolled back on any error.
- //
- // opts.ResultFunc is ignored.
- func ExecuteScript(conn *sqlite.Conn, queries string, opts *ExecOptions) (err error) {
- defer Save(conn)(&err)
- unused := make(map[string]struct{})
- if opts != nil {
- for k := range opts.Named {
- unused[k] = struct{}{}
- }
- }
- for {
- queries = strings.TrimSpace(queries)
- if queries == "" {
- break
- }
- var stmt *sqlite.Stmt
- var trailingBytes int
- stmt, trailingBytes, err = conn.PrepareTransient(queries)
- if err != nil {
- return err
- }
- for i, n := 1, stmt.BindParamCount(); i <= n; i++ {
- if name := stmt.BindParamName(i); name != "" {
- delete(unused, name)
- }
- }
- usedBytes := len(queries) - trailingBytes
- queries = queries[usedBytes:]
- err = exec(stmt, forbidMissing, opts)
- stmt.Finalize()
- if err != nil {
- return err
- }
- }
- if len(unused) > 0 {
- return fmt.Errorf("%w: unknown argument %s", sqlite.ResultRange.ToError(), minStringInSet(unused))
- }
- return nil
- }
- // ExecScriptFS is an alias for ExecuteScriptFS.
- //
- // Deprecated: Call ExecuteScriptFS directly.
- func ExecScriptFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) (err error) {
- return ExecuteScriptFS(conn, fsys, filename, opts)
- }
- // ExecuteScriptFS executes a script of SQL statements from a file.
- // The script is wrapped in a SAVEPOINT transaction,
- // which is rolled back on any error.
- func ExecuteScriptFS(conn *sqlite.Conn, fsys fs.FS, filename string, opts *ExecOptions) (err error) {
- queries, err := readString(fsys, filename)
- if err != nil {
- return fmt.Errorf("exec: %w", err)
- }
- if err := ExecuteScript(conn, queries, opts); err != nil {
- return fmt.Errorf("exec %s: %w", filename, err)
- }
- return nil
- }
// bitset is a fixed-capacity set of non-negative integers. Bit j of word i
// represents the integer i*64+j.
type bitset []uint64

// newBitset returns an empty bitset with room for at least the integers [0, n).
func newBitset(n int) bitset {
	return make(bitset, (n+63)/64)
}

// hasAll reports whether the bitset is a superset of [0, n).
func (bs bitset) hasAll(n int) bool {
	nwords := (n + 63) / 64
	if len(bs) < nwords {
		return false
	}
	fullWords := n / 64
	for _, w := range bs[:fullWords] {
		if w != ^uint64(0) {
			return false
		}
	}
	if fullWords == nwords {
		return true
	}
	// Only the low n%64 bits of the last, partially used word must be set.
	mask := uint64(1)<<(n%64) - 1
	return bs[nwords-1]&mask == mask
}

// firstMissing returns the smallest integer not in the set, or len(bs)*64
// if every representable integer is present.
func (bs bitset) firstMissing() int {
	for i, w := range bs {
		if w == ^uint64(0) {
			continue
		}
		for j := 0; j < 64; j++ {
			if w&(1<<j) == 0 {
				return i*64 + j
			}
		}
	}
	return len(bs) * 64
}

// set adds n to the set. n must be in [0, len(bs)*64).
func (bs bitset) set(n int) {
	bs[n/64] |= 1 << (n % 64)
}

// String formats the set as a binary number, most significant word first.
// Every word is zero-padded to 64 digits so bit positions stay aligned; the
// previous "%08b" produced variable-width, ambiguous output.
func (bs bitset) String() string {
	sb := new(strings.Builder)
	sb.Grow(len(bs) * 64)
	for i := len(bs) - 1; i >= 0; i-- {
		fmt.Fprintf(sb, "%064b", bs[i])
	}
	return sb.String()
}
// minStringInSet returns the lexicographically smallest key in set, or ""
// if set is empty. It is used to pick a deterministic key for error messages.
// A separate "found" flag is used because "" is itself a valid key; using it
// as the "unset" sentinel made the result depend on map iteration order.
func minStringInSet(set map[string]struct{}) string {
	smallest, found := "", false
	for k := range set {
		if !found || k < smallest {
			smallest, found = k, true
		}
	}
	return smallest
}
- func readString(fsys fs.FS, filename string) (string, error) {
- f, err := fsys.Open(filename)
- if err != nil {
- return "", err
- }
- content := new(strings.Builder)
- _, err = io.Copy(content, f)
- f.Close()
- if err != nil {
- return "", fmt.Errorf("%s: %w", filename, err)
- }
- return content.String(), nil
- }
|