func.go 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808
  1. // Copyright (c) 2018 David Crawshaw <david@zentus.com>
  2. // Copyright (c) 2021 Ross Light <ross@zombiezen.com>
  3. //
  4. // Permission to use, copy, modify, and distribute this software for any
  5. // purpose with or without fee is hereby granted, provided that the above
  6. // copyright notice and this permission notice appear in all copies.
  7. //
  8. // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  9. // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  10. // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  11. // ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  12. // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  13. // ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  14. // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  15. //
  16. // SPDX-License-Identifier: ISC
  17. package sqlite
  18. import (
  19. "errors"
  20. "fmt"
  21. "math"
  22. "math/bits"
  23. "strconv"
  24. "strings"
  25. "sync"
  26. "unsafe"
  27. "modernc.org/libc"
  28. "modernc.org/libc/sys/types"
  29. lib "modernc.org/sqlite/lib"
  30. )
// auxdata stores the Go values attached to SQL function arguments through
// Context.SetAuxData. SQLite is only handed an integer ID for each value;
// that ID is the key into m.
var auxdata struct {
	mu  sync.RWMutex            // guards m and ids
	m   map[uintptr]interface{} // ID -> stored value; allocated lazily by SetAuxData
	ids idGen                   // hands out and recycles the IDs used as keys of m
}
// Context is a SQL function execution context.
// It is in no way related to a Go context.Context.
// https://sqlite.org/c3ref/context.html
type Context struct {
	tls *libc.TLS // libc thread state passed as the first argument of every lib call made through this context
	ptr uintptr   // address of the SQLite function context, passed to the lib.Xsqlite3_* context functions
}
  43. // Conn returns the database connection that is calling the SQL function.
  44. func (ctx Context) Conn() *Conn {
  45. connPtr := lib.Xsqlite3_context_db_handle(ctx.tls, ctx.ptr)
  46. allConns.mu.RLock()
  47. defer allConns.mu.RUnlock()
  48. return allConns.table[connPtr]
  49. }
  50. // AuxData returns the auxiliary data associated with the given argument, with
  51. // zero being the leftmost argument, or nil if no such data is present.
  52. //
  53. // Auxiliary data may be used by (non-aggregate) SQL functions to associate
  54. // metadata with argument values. If the same value is passed to multiple
  55. // invocations of the same SQL function during query execution, under some
  56. // circumstances the associated metadata may be preserved. An example of where
  57. // this might be useful is in a regular-expression matching function. The
  58. // compiled version of the regular expression can be stored as metadata
  59. // associated with the pattern string. Then as long as the pattern string
  60. // remains the same, the compiled regular expression can be reused on multiple
  61. // invocations of the same function.
  62. //
  63. // For more details, see https://www.sqlite.org/c3ref/get_auxdata.html
  64. func (ctx Context) AuxData(arg int) interface{} {
  65. id := lib.Xsqlite3_get_auxdata(ctx.tls, ctx.ptr, int32(arg))
  66. if id == 0 {
  67. return nil
  68. }
  69. auxdata.mu.RLock()
  70. defer auxdata.mu.RUnlock()
  71. return auxdata.m[id]
  72. }
  73. // SetAuxData sets the auxiliary data associated with the given argument, with
  74. // zero being the leftmost argument. SQLite is free to discard the metadata at
  75. // any time, including during the call to SetAuxData.
  76. //
  77. // Auxiliary data may be used by (non-aggregate) SQL functions to associate
  78. // metadata with argument values. If the same value is passed to multiple
  79. // invocations of the same SQL function during query execution, under some
  80. // circumstances the associated metadata may be preserved. An example of where
  81. // this might be useful is in a regular-expression matching function. The
  82. // compiled version of the regular expression can be stored as metadata
  83. // associated with the pattern string. Then as long as the pattern string
  84. // remains the same, the compiled regular expression can be reused on multiple
  85. // invocations of the same function.
  86. //
  87. // For more details, see https://www.sqlite.org/c3ref/get_auxdata.html
  88. func (ctx Context) SetAuxData(arg int, data interface{}) {
  89. auxdata.mu.Lock()
  90. id := auxdata.ids.next()
  91. if auxdata.m == nil {
  92. auxdata.m = make(map[uintptr]interface{})
  93. }
  94. auxdata.m[id] = data
  95. auxdata.mu.Unlock()
  96. deleteFn := cFuncPointer(freeAuxData)
  97. lib.Xsqlite3_set_auxdata(ctx.tls, ctx.ptr, int32(arg), id, deleteFn)
  98. }
  99. func freeAuxData(tls *libc.TLS, id uintptr) {
  100. auxdata.mu.Lock()
  101. defer auxdata.mu.Unlock()
  102. delete(auxdata.m, id)
  103. auxdata.ids.reclaim(id)
  104. }
  105. func (ctx Context) result(v Value, err error) {
  106. if err != nil {
  107. ctx.resultError(err)
  108. return
  109. }
  110. if v.tls != nil {
  111. if ctx.tls != v.tls {
  112. ctx.resultError(fmt.Errorf("function result Value from different connection"))
  113. return
  114. }
  115. lib.Xsqlite3_result_value(ctx.tls, ctx.ptr, v.ptrOrType)
  116. return
  117. }
  118. switch ColumnType(v.ptrOrType) {
  119. case 0, TypeNull:
  120. lib.Xsqlite3_result_null(ctx.tls, ctx.ptr)
  121. case TypeInteger:
  122. lib.Xsqlite3_result_int64(ctx.tls, ctx.ptr, v.n)
  123. case TypeFloat:
  124. lib.Xsqlite3_result_double(ctx.tls, ctx.ptr, v.float())
  125. case TypeText:
  126. if len(v.s) == 0 {
  127. lib.Xsqlite3_result_text(ctx.tls, ctx.ptr, emptyCString, 0, sqliteStatic)
  128. } else {
  129. cv, err := libc.CString(v.s)
  130. if err != nil {
  131. ctx.resultError(fmt.Errorf("alloc function result: %w", err))
  132. return
  133. }
  134. lib.Xsqlite3_result_text(ctx.tls, ctx.ptr, cv, int32(len(v.s)), freeFuncPtr)
  135. }
  136. case TypeBlob:
  137. if len(v.s) == 0 {
  138. lib.Xsqlite3_result_blob(ctx.tls, ctx.ptr, emptyCString, 0, sqliteStatic)
  139. } else {
  140. cv, err := malloc(ctx.tls, types.Size_t(len(v.s)))
  141. if err != nil {
  142. ctx.resultError(fmt.Errorf("alloc function result: %w", err))
  143. return
  144. }
  145. copy(libc.GoBytes(cv, len(v.s)), v.s)
  146. lib.Xsqlite3_result_blob(ctx.tls, ctx.ptr, cv, int32(len(v.s)), freeFuncPtr)
  147. }
  148. default:
  149. panic("unknown result Value type")
  150. }
  151. }
  152. func (ctx Context) resultError(err error) {
  153. errstr := err.Error()
  154. cerrstr, err := libc.CString(errstr)
  155. if err != nil {
  156. panic(err)
  157. }
  158. defer libc.Xfree(ctx.tls, cerrstr)
  159. lib.Xsqlite3_result_error(ctx.tls, ctx.ptr, cerrstr, int32(len(errstr)))
  160. lib.Xsqlite3_result_error_code(ctx.tls, ctx.ptr, int32(ErrCode(err)))
  161. }
// Value represents a value that can be stored in a database table.
// The zero value is NULL.
// The accessor methods on Value may perform automatic conversions
// and thus methods on Value must not be called concurrently.
//
// A Value is either a wrapper around a value owned by SQLite (tls != nil)
// or a plain Go value built by one of the constructors such as IntegerValue
// or TextValue (tls == nil).
type Value struct {
	tls       *libc.TLS // non-nil only for values handed over by SQLite
	ptrOrType uintptr   // pointer to sqlite_value if tls != nil, ColumnType otherwise
	s         string    // text or blob payload of a Go-built value
	n         int64     // integer payload, or the IEEE 754 bits of a float; if ptrOrType == 0 and n != 0, then indicates the "nochange" NULL.
}
  172. // IntegerValue returns a new Value representing the given integer.
  173. func IntegerValue(i int64) Value {
  174. return Value{ptrOrType: uintptr(TypeInteger), n: i}
  175. }
  176. // FloatValue returns a new Value representing the given floating-point number.
  177. func FloatValue(f float64) Value {
  178. return Value{ptrOrType: uintptr(TypeFloat), n: int64(math.Float64bits(f))}
  179. }
  180. // TextValue returns a new Value representing the given string.
  181. func TextValue(s string) Value {
  182. return Value{ptrOrType: uintptr(TypeText), s: s}
  183. }
  184. // BlobValue returns a new blob Value, copying the bytes from the given
  185. // byte slice.
  186. func BlobValue(b []byte) Value {
  187. return Value{ptrOrType: uintptr(TypeBlob), s: string(b)}
  188. }
  189. // Unchanged returns a NULL Value for which [Value.NoChange] reports true.
  190. // This is only significant as the return value for [VTableCursor.Column].
  191. func Unchanged() Value {
  192. return Value{n: 1}
  193. }
  194. // Type returns the data type of the value. The result of Type is undefined if
  195. // an automatic type conversion has occurred due to calling one of the other
  196. // accessor methods.
  197. func (v Value) Type() ColumnType {
  198. if v.ptrOrType == 0 {
  199. return TypeNull
  200. }
  201. if v.tls == nil {
  202. return ColumnType(v.ptrOrType)
  203. }
  204. return ColumnType(lib.Xsqlite3_value_type(v.tls, v.ptrOrType))
  205. }
  206. // Conversions follow the table in https://sqlite.org/c3ref/column_blob.html
  207. // Int returns the value as an integer.
  208. func (v Value) Int() int {
  209. return int(v.Int64())
  210. }
  211. // Int64 returns the value as a 64-bit integer.
  212. func (v Value) Int64() int64 {
  213. if v.ptrOrType == 0 {
  214. return 0
  215. }
  216. if v.tls == nil {
  217. switch ColumnType(v.ptrOrType) {
  218. case TypeNull:
  219. return 0
  220. case TypeInteger:
  221. return v.n
  222. case TypeFloat:
  223. return int64(v.float())
  224. case TypeBlob, TypeText:
  225. return castTextToInteger(v.s)
  226. default:
  227. panic("unknown value type")
  228. }
  229. }
  230. return int64(lib.Xsqlite3_value_int64(v.tls, v.ptrOrType))
  231. }
  232. // castTextToInteger emulates the SQLite CAST operator for a TEXT value to
  233. // INTEGER, as documented in https://sqlite.org/lang_expr.html#castexpr
  234. func castTextToInteger(s string) int64 {
  235. const digits = "0123456789"
  236. s = strings.TrimSpace(s)
  237. if len(s) > 0 && (s[0] == '+' || s[0] == '-') {
  238. s = s[:1+len(longestPrefix(s[1:], digits))]
  239. } else {
  240. s = longestPrefix(s, digits)
  241. }
  242. n, _ := strconv.ParseInt(s, 10, 64)
  243. return n
  244. }
  245. func longestPrefix(s string, allowSet string) string {
  246. sloop:
  247. for i := 0; i < len(s); i++ {
  248. for j := 0; j < len(allowSet); j++ {
  249. if s[i] == allowSet[j] {
  250. continue sloop
  251. }
  252. }
  253. return s[:i]
  254. }
  255. return s
  256. }
  257. // Float returns the value as floating-point number
  258. func (v Value) Float() float64 {
  259. if v.ptrOrType == 0 {
  260. return 0
  261. }
  262. if v.tls == nil {
  263. switch ColumnType(v.ptrOrType) {
  264. case TypeNull:
  265. return 0
  266. case TypeInteger:
  267. return float64(v.n)
  268. case TypeFloat:
  269. return v.float()
  270. case TypeBlob, TypeText:
  271. return castTextToReal(v.s)
  272. default:
  273. panic("unknown value type")
  274. }
  275. }
  276. return float64(lib.Xsqlite3_value_double(v.tls, v.ptrOrType))
  277. }
  278. func (v Value) float() float64 { return math.Float64frombits(uint64(v.n)) }
  279. // castTextToReal emulates the SQLite CAST operator for a TEXT value to
  280. // REAL, as documented in https://sqlite.org/lang_expr.html#castexpr
  281. func castTextToReal(s string) float64 {
  282. s = strings.TrimSpace(s)
  283. for ; len(s) > 0; s = s[:len(s)-1] {
  284. n, err := strconv.ParseFloat(s, 64)
  285. if !errors.Is(err, strconv.ErrSyntax) {
  286. return n
  287. }
  288. }
  289. return 0
  290. }
  291. // Text returns the value as a string.
  292. func (v Value) Text() string {
  293. if v.ptrOrType == 0 {
  294. return ""
  295. }
  296. if v.tls == nil {
  297. switch ColumnType(v.ptrOrType) {
  298. case TypeNull:
  299. return ""
  300. case TypeInteger:
  301. return strconv.FormatInt(v.n, 10)
  302. case TypeFloat:
  303. return strconv.FormatFloat(v.float(), 'g', -1, 64)
  304. case TypeText, TypeBlob:
  305. return v.s
  306. default:
  307. panic("unknown value type")
  308. }
  309. }
  310. ptr := lib.Xsqlite3_value_text(v.tls, v.ptrOrType)
  311. return goStringN(ptr, int(lib.Xsqlite3_value_bytes(v.tls, v.ptrOrType)))
  312. }
  313. // Blob returns a copy of the value as a blob.
  314. func (v Value) Blob() []byte {
  315. if v.ptrOrType == 0 {
  316. return nil
  317. }
  318. if v.tls == nil {
  319. switch ColumnType(v.ptrOrType) {
  320. case TypeNull:
  321. return nil
  322. case TypeInteger:
  323. return strconv.AppendInt(nil, v.n, 10)
  324. case TypeFloat:
  325. return strconv.AppendFloat(nil, v.float(), 'g', -1, 64)
  326. case TypeBlob, TypeText:
  327. return []byte(v.s)
  328. default:
  329. panic("unknown value type")
  330. }
  331. }
  332. ptr := lib.Xsqlite3_value_blob(v.tls, v.ptrOrType)
  333. return libc.GoBytes(ptr, int(lib.Xsqlite3_value_bytes(v.tls, v.ptrOrType)))
  334. }
  335. // NoChange reports whether a column
  336. // corresponding to this value in a [VTable.Update] method
  337. // is unchanged by the UPDATE operation
  338. // that the VTable.Update method call was invoked to implement
  339. // and if the prior [VTableCursor.Column] method call that was invoked
  340. // to extract the value for that column returned [Unchanged].
  341. func (v Value) NoChange() bool {
  342. if v.ptrOrType == 0 {
  343. return v.n != 0
  344. }
  345. if v.tls == nil {
  346. return false
  347. }
  348. return lib.Xsqlite3_value_nochange(v.tls, v.ptrOrType) != 0
  349. }
// FunctionImpl describes an [application-defined SQL function].
// Either Scalar or MakeAggregate must be set, but not both.
//
// [application-defined SQL function]: https://sqlite.org/appfunc.html
type FunctionImpl struct {
	// NArgs is the required number of arguments that the function accepts.
	// If NArgs is negative, then the function is variadic.
	// CreateFunction rejects values greater than 127.
	//
	// Multiple function implementations may be registered with the same name
	// with different numbers of required arguments.
	NArgs int
	// Scalar is called when a scalar function is invoked in SQL.
	// The argument Values are not valid past the return of the function.
	Scalar func(ctx Context, args []Value) (Value, error)
	// MakeAggregate is called at the beginning of an evaluation of an aggregate function.
	// It must not return a nil AggregateFunction together with a nil error;
	// that combination is reported to SQLite as an error.
	MakeAggregate func(ctx Context) (AggregateFunction, error)
	// If Deterministic is true, the function must always give the same output
	// when the input parameters are the same. This enables functions to be used
	// in additional contexts like the WHERE clause of partial indexes and enables
	// additional optimizations.
	//
	// See https://sqlite.org/c3ref/c_deterministic.html#sqlitedeterministic for
	// more details.
	Deterministic bool
	// If AllowIndirect is false, then the function may only be invoked from
	// top-level SQL. If AllowIndirect is true, then the function can be used in
	// VIEWs, TRIGGERs, and schema structures (e.g. CHECK constraints and DEFAULT
	// clauses).
	//
	// This is the inverse of SQLITE_DIRECTONLY. See
	// https://sqlite.org/c3ref/c_deterministic.html#sqlitedirectonly for more
	// details. This defaults to false for better security.
	AllowIndirect bool
}
// An AggregateFunction is an invocation of an aggregate function.
// See the documentation for [aggregate function callbacks]
// and [application-defined window functions] for an overview.
//
// [aggregate function callbacks]: https://www.sqlite.org/appfunc.html#the_aggregate_function_callbacks
// [application-defined window functions]: https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions
type AggregateFunction interface {
	// Step is called for each row
	// of an aggregate function's SQL invocation.
	// The argument Values are not valid past the return of the function.
	// A non-nil error is reported to SQLite as the function's result.
	Step(ctx Context, rowArgs []Value) error
	// WindowInverse is called to remove
	// the oldest presently aggregated result of Step
	// from the current window.
	// The arguments are those passed to Step for the row being removed.
	// The argument Values are not valid past the return of the function.
	WindowInverse(ctx Context, rowArgs []Value) error
	// WindowValue is called to get the current value of an aggregate function.
	// It is also called to obtain the final result, immediately before Finalize.
	WindowValue(ctx Context) (Value, error)
	// Finalize is called after all of the aggregate function's input rows
	// have been stepped through.
	// No other methods will be called on the AggregateFunction after calling Finalize.
	Finalize(ctx Context)
}
  408. // CreateFunction registers a Go function with SQLite
  409. // for use in SQL queries.
  410. //
  411. // https://sqlite.org/appfunc.html
  412. func (c *Conn) CreateFunction(name string, impl *FunctionImpl) error {
  413. if c == nil {
  414. return fmt.Errorf("sqlite: create function: nil connection")
  415. }
  416. if name == "" {
  417. return fmt.Errorf("sqlite: create function: no name provided")
  418. }
  419. if impl.NArgs > 127 {
  420. return fmt.Errorf("sqlite: create function %s: too many permitted arguments (%d)", name, impl.NArgs)
  421. }
  422. if impl.Scalar == nil && impl.MakeAggregate == nil {
  423. return fmt.Errorf("sqlite: create function %s: must specify one of Scalar or MakeAggregate", name)
  424. }
  425. if impl.Scalar != nil && impl.MakeAggregate != nil {
  426. return fmt.Errorf("sqlite: create function %s: both Scalar and MakeAggregate specified", name)
  427. }
  428. cname, err := libc.CString(name)
  429. if err != nil {
  430. return fmt.Errorf("sqlite: create function %s: %w", name, err)
  431. }
  432. defer libc.Xfree(c.tls, cname)
  433. eTextRep := int32(lib.SQLITE_UTF8)
  434. if impl.Deterministic {
  435. eTextRep |= lib.SQLITE_DETERMINISTIC
  436. }
  437. if !impl.AllowIndirect {
  438. eTextRep |= lib.SQLITE_DIRECTONLY
  439. }
  440. numArgs := impl.NArgs
  441. if numArgs < 0 {
  442. numArgs = -1
  443. }
  444. var res ResultCode
  445. if impl.Scalar != nil {
  446. xfuncs.mu.Lock()
  447. id := xfuncs.ids.next()
  448. xfuncs.m[id] = impl.Scalar
  449. xfuncs.mu.Unlock()
  450. res = ResultCode(lib.Xsqlite3_create_function_v2(
  451. c.tls,
  452. c.conn,
  453. cname,
  454. int32(numArgs),
  455. eTextRep,
  456. id,
  457. cFuncPointer(funcTrampoline),
  458. 0,
  459. 0,
  460. cFuncPointer(destroyScalarFunc),
  461. ))
  462. } else {
  463. xAggregateFactories.mu.Lock()
  464. id := xAggregateFactories.ids.next()
  465. xAggregateFactories.m[id] = impl.MakeAggregate
  466. xAggregateFactories.mu.Unlock()
  467. res = ResultCode(lib.Xsqlite3_create_window_function(
  468. c.tls,
  469. c.conn,
  470. cname,
  471. int32(numArgs),
  472. eTextRep,
  473. id,
  474. cFuncPointer(stepTrampoline),
  475. cFuncPointer(finalTrampoline),
  476. cFuncPointer(valueTrampoline),
  477. cFuncPointer(inverseTrampoline),
  478. cFuncPointer(destroyAggregateFunc),
  479. ))
  480. }
  481. if err := res.ToError(); err != nil {
  482. return fmt.Errorf("sqlite: create function %s: %w", name, err)
  483. }
  484. return nil
  485. }
// xfuncs is the table of registered scalar function callbacks.
// CreateFunction stores each callback under a fresh ID and hands that ID to
// SQLite as the function's user data; funcTrampoline uses the ID to find the
// callback again, and destroyScalarFunc removes the entry.
var xfuncs = struct {
	mu  sync.RWMutex // guards m and ids
	m   map[uintptr]func(Context, []Value) (Value, error)
	ids idGen
}{
	m: make(map[uintptr]func(Context, []Value) (Value, error)),
}
  493. func funcTrampoline(tls *libc.TLS, ctx uintptr, n int32, valarray uintptr) {
  494. id := lib.Xsqlite3_user_data(tls, ctx)
  495. xfuncs.mu.RLock()
  496. x := xfuncs.m[id]
  497. xfuncs.mu.RUnlock()
  498. vals := make([]Value, 0, int(n))
  499. for ; len(vals) < cap(vals); valarray += uintptr(ptrSize) {
  500. vals = append(vals, Value{
  501. tls: tls,
  502. ptrOrType: *(*uintptr)(unsafe.Pointer(valarray)),
  503. })
  504. }
  505. goCtx := Context{tls: tls, ptr: ctx}
  506. goCtx.result(x(goCtx, vals))
  507. }
  508. func destroyScalarFunc(tls *libc.TLS, id uintptr) {
  509. xfuncs.mu.Lock()
  510. defer xfuncs.mu.Unlock()
  511. delete(xfuncs.m, id)
  512. xfuncs.ids.reclaim(id)
  513. }
var (
	// xAggregateFactories is the table of registered MakeAggregate
	// callbacks. CreateFunction stores each factory under a fresh ID and
	// hands that ID to SQLite as the function's user data;
	// destroyAggregateFunc removes the entry.
	xAggregateFactories = struct {
		mu  sync.RWMutex // guards m and ids
		m   map[uintptr]func(Context) (AggregateFunction, error)
		ids idGen
	}{
		m: make(map[uintptr]func(Context) (AggregateFunction, error)),
	}
	// xAggregateContext is the table of live aggregate invocations.
	// makeAggregate stores each AggregateFunction under a fresh ID and
	// writes that ID into the SQLite aggregate context; finalTrampoline
	// removes the entry.
	xAggregateContext = struct {
		mu  sync.RWMutex // guards m and ids
		m   map[uintptr]AggregateFunction
		ids idGen
	}{
		m: make(map[uintptr]AggregateFunction),
	}
)
  530. func makeAggregate(tls *libc.TLS, ctx uintptr) (AggregateFunction, uintptr) {
  531. goCtx := Context{tls: tls, ptr: ctx}
  532. aggCtx := (*uintptr)(unsafe.Pointer(lib.Xsqlite3_aggregate_context(tls, ctx, int32(ptrSize))))
  533. if aggCtx == nil {
  534. goCtx.resultError(errors.New("insufficient memory for aggregate"))
  535. return nil, 0
  536. }
  537. if *aggCtx != 0 {
  538. // Already created.
  539. xAggregateContext.mu.RLock()
  540. f := xAggregateContext.m[*aggCtx]
  541. xAggregateContext.mu.RUnlock()
  542. return f, *aggCtx
  543. }
  544. factoryID := lib.Xsqlite3_user_data(tls, ctx)
  545. xAggregateFactories.mu.RLock()
  546. factory := xAggregateFactories.m[factoryID]
  547. xAggregateFactories.mu.RUnlock()
  548. f, err := factory(goCtx)
  549. if err != nil {
  550. goCtx.resultError(err)
  551. return nil, 0
  552. }
  553. if f == nil {
  554. goCtx.resultError(errors.New("MakeAggregate function returned nil"))
  555. return nil, 0
  556. }
  557. xAggregateContext.mu.Lock()
  558. *aggCtx = xAggregateContext.ids.next()
  559. xAggregateContext.m[*aggCtx] = f
  560. xAggregateContext.mu.Unlock()
  561. return f, *aggCtx
  562. }
  563. func stepTrampoline(tls *libc.TLS, ctx uintptr, n int32, valarray uintptr) {
  564. x, _ := makeAggregate(tls, ctx)
  565. if x == nil {
  566. return
  567. }
  568. vals := make([]Value, 0, int(n))
  569. for ; len(vals) < cap(vals); valarray += uintptr(ptrSize) {
  570. vals = append(vals, Value{
  571. tls: tls,
  572. ptrOrType: *(*uintptr)(unsafe.Pointer(valarray)),
  573. })
  574. }
  575. goCtx := Context{tls: tls, ptr: ctx}
  576. if err := x.Step(goCtx, vals); err != nil {
  577. goCtx.resultError(err)
  578. }
  579. }
  580. func finalTrampoline(tls *libc.TLS, ctx uintptr) {
  581. x, id := makeAggregate(tls, ctx)
  582. if x == nil {
  583. return
  584. }
  585. goCtx := Context{tls: tls, ptr: ctx}
  586. goCtx.result(x.WindowValue(goCtx))
  587. x.Finalize(goCtx)
  588. xAggregateContext.mu.Lock()
  589. defer xAggregateContext.mu.Unlock()
  590. delete(xAggregateContext.m, id)
  591. xAggregateContext.ids.reclaim(id)
  592. }
  593. func valueTrampoline(tls *libc.TLS, ctx uintptr) {
  594. x, _ := makeAggregate(tls, ctx)
  595. if x == nil {
  596. return
  597. }
  598. goCtx := Context{tls: tls, ptr: ctx}
  599. goCtx.result(x.WindowValue(goCtx))
  600. }
  601. func inverseTrampoline(tls *libc.TLS, ctx uintptr, n int32, valarray uintptr) {
  602. x, _ := makeAggregate(tls, ctx)
  603. if x == nil {
  604. return
  605. }
  606. vals := make([]Value, 0, int(n))
  607. for ; len(vals) < cap(vals); valarray += uintptr(ptrSize) {
  608. vals = append(vals, Value{
  609. tls: tls,
  610. ptrOrType: *(*uintptr)(unsafe.Pointer(valarray)),
  611. })
  612. }
  613. goCtx := Context{tls: tls, ptr: ctx}
  614. if err := x.WindowInverse(goCtx, vals); err != nil {
  615. goCtx.resultError(err)
  616. }
  617. }
  618. func destroyAggregateFunc(tls *libc.TLS, id uintptr) {
  619. xAggregateFactories.mu.Lock()
  620. defer xAggregateFactories.mu.Unlock()
  621. delete(xAggregateFactories.m, id)
  622. xAggregateFactories.ids.reclaim(id)
  623. }
// CollatingFunc is a [collating function/sequence],
// that is, a function that compares two strings.
// The function returns a negative number if a < b,
// a positive number if a > b,
// or zero if a == b.
// Only the sign of the result matters:
// it is normalized to -1, 0 or 1 before being handed to SQLite.
// A collating function must always return the same answer given the same inputs.
// The collating function must obey the following properties for all strings A, B, and C:
//
//  1. If A==B then B==A.
//  2. If A==B and B==C then A==C.
//  3. If A<B then B>A.
//  4. If A<B and B<C then A<C.
//
// [collating function/sequence]: https://www.sqlite.org/datatype3.html#collation
type CollatingFunc func(a, b string) int
  639. // SetCollation sets the [collating function] for the given name.
  640. //
  641. // [collating function]: https://www.sqlite.org/datatype3.html#collation
  642. func (c *Conn) SetCollation(name string, compare CollatingFunc) error {
  643. verb := "create"
  644. if compare == nil {
  645. verb = "remove"
  646. }
  647. if c == nil {
  648. return fmt.Errorf("sqlite: %s collation: nil connection", verb)
  649. }
  650. if name == "" {
  651. return fmt.Errorf("sqlite: %s collation: no name provided", verb)
  652. }
  653. cname, err := libc.CString(name)
  654. if err != nil {
  655. return fmt.Errorf("sqlite: %s collation: no name provided", verb)
  656. }
  657. defer libc.Xfree(c.tls, cname)
  658. if compare == nil {
  659. res := ResultCode(lib.Xsqlite3_create_collation_v2(
  660. c.tls, c.conn, cname, lib.SQLITE_UTF8, 0, 0, 0,
  661. ))
  662. if err := res.ToError(); err != nil {
  663. return fmt.Errorf("sqlite: %s collation %s: %w", verb, name, err)
  664. }
  665. return nil
  666. }
  667. xcollations.mu.Lock()
  668. id := xcollations.ids.next()
  669. xcollations.m[id] = compare
  670. xcollations.mu.Unlock()
  671. res := ResultCode(lib.Xsqlite3_create_collation_v2(
  672. c.tls, c.conn, cname, lib.SQLITE_UTF8, id, cFuncPointer(collationTrampoline), cFuncPointer(destroyCollation),
  673. ))
  674. if err := res.ToError(); err != nil {
  675. destroyCollation(c.tls, id)
  676. return fmt.Errorf("sqlite: %s collation %s: %w", verb, name, err)
  677. }
  678. return nil
  679. }
// xcollations is the table of registered collating functions.
// SetCollation stores each function under a fresh ID and hands that ID to
// SQLite as the collation's user data; collationTrampoline uses the ID to
// find the function again, and destroyCollation removes the entry.
var xcollations = struct {
	mu  sync.RWMutex // guards m and ids
	m   map[uintptr]CollatingFunc
	ids idGen
}{
	m: make(map[uintptr]CollatingFunc),
}
  687. func collationTrampoline(tls *libc.TLS, id uintptr, n1 int32, p1 uintptr, n2 int32, p2 uintptr) int32 {
  688. xcollations.mu.RLock()
  689. f := xcollations.m[id]
  690. xcollations.mu.RUnlock()
  691. s1 := goStringN(p1, int(n1))
  692. s2 := goStringN(p2, int(n2))
  693. switch x := f(s1, s2); {
  694. case x > 0:
  695. return 1
  696. case x < 0:
  697. return -1
  698. default:
  699. return 0
  700. }
  701. }
  702. func destroyCollation(tls *libc.TLS, id uintptr) {
  703. xcollations.mu.Lock()
  704. defer xcollations.mu.Unlock()
  705. delete(xcollations.m, id)
  706. xcollations.ids.reclaim(id)
  707. }
  708. // idGen is an ID generator. The zero value is ready to use.
  709. type idGen struct {
  710. bitset []uint64
  711. }
  712. func (gen *idGen) next() uintptr {
  713. base := uintptr(1)
  714. for i := 0; i < len(gen.bitset); i, base = i+1, base+64 {
  715. b := gen.bitset[i]
  716. if b != 1<<64-1 {
  717. n := uintptr(bits.TrailingZeros64(^b))
  718. gen.bitset[i] |= 1 << n
  719. return base + n
  720. }
  721. }
  722. gen.bitset = append(gen.bitset, 1)
  723. return base
  724. }
  725. func (gen *idGen) reclaim(id uintptr) {
  726. bit := id - 1
  727. gen.bitset[bit/64] &^= 1 << (bit % 64)
  728. }