auth.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. // Copyright 2021 Ross Light
  2. // SPDX-License-Identifier: ISC
  3. package sqlite
  4. import (
  5. "fmt"
  6. "strings"
  7. "sync"
  8. "modernc.org/libc"
  9. lib "modernc.org/sqlite/lib"
  10. )
// An Authorizer is called during statement preparation to see whether an action
// is allowed by the application. An Authorizer must not modify the database
// connection, including by preparing statements.
//
// See https://sqlite.org/c3ref/set_authorizer.html for a longer explanation.
type Authorizer interface {
	// Authorize reports whether the given action is permitted. The returned
	// AuthResult is passed back to SQLite as the authorizer callback's
	// return code.
	Authorize(Action) AuthResult
}
  19. // SetAuthorizer registers an authorizer for the database connection.
  20. // SetAuthorizer(nil) clears any authorizer previously set.
  21. func (c *Conn) SetAuthorizer(auth Authorizer) error {
  22. if c == nil {
  23. return fmt.Errorf("sqlite: set authorizer: nil connection")
  24. }
  25. if auth == nil {
  26. c.releaseAuthorizer()
  27. res := ResultCode(lib.Xsqlite3_set_authorizer(c.tls, c.conn, 0, 0))
  28. if err := res.ToError(); err != nil {
  29. return fmt.Errorf("sqlite: set authorizer: %w", err)
  30. }
  31. return nil
  32. }
  33. authorizers.mu.Lock()
  34. if authorizers.m == nil {
  35. authorizers.m = make(map[uintptr]Authorizer)
  36. }
  37. authorizers.m[c.conn] = auth
  38. authorizers.mu.Unlock()
  39. xAuth := cFuncPointer(authTrampoline)
  40. res := ResultCode(lib.Xsqlite3_set_authorizer(c.tls, c.conn, xAuth, c.conn))
  41. if err := res.ToError(); err != nil {
  42. return fmt.Errorf("sqlite: set authorizer: %w", err)
  43. }
  44. return nil
  45. }
  46. func (c *Conn) releaseAuthorizer() {
  47. authorizers.mu.Lock()
  48. delete(authorizers.m, c.conn)
  49. authorizers.mu.Unlock()
  50. }
// authorizers is the package-level registry that lets authTrampoline find the
// Go Authorizer for a connection. Entries are added by SetAuthorizer and
// removed by releaseAuthorizer.
var authorizers struct {
	// mu guards m: writers take the full lock, authTrampoline takes a read lock.
	mu sync.RWMutex
	// m is created lazily by SetAuthorizer and is keyed by the connection
	// handle (Conn.conn).
	m map[uintptr]Authorizer // sqlite3* -> Authorizer
}
  55. func authTrampoline(tls *libc.TLS, conn uintptr, op int32, cArg1, cArg2, cDB, cTrigger uintptr) int32 {
  56. authorizers.mu.RLock()
  57. auth := authorizers.m[conn]
  58. authorizers.mu.RUnlock()
  59. return int32(auth.Authorize(Action{
  60. op: OpType(op),
  61. arg1: libc.GoString(cArg1),
  62. arg2: libc.GoString(cArg2),
  63. database: libc.GoString(cDB),
  64. trigger: libc.GoString(cTrigger),
  65. }))
  66. }
// AuthorizeFunc is a function that implements Authorizer. It allows an
// ordinary function with the matching signature to be passed wherever an
// Authorizer is expected.
type AuthorizeFunc func(Action) AuthResult
// Authorize calls f with action and returns its result unchanged.
func (f AuthorizeFunc) Authorize(action Action) AuthResult {
	return f(action)
}
// AuthResult is the result of a call to an Authorizer. The zero value is
// AuthResultOK.
//
// The value is handed back to SQLite as the authorizer callback's return code.
type AuthResult int32
// Possible return values from Authorize.
const (
	// AuthResultOK allows the SQL statement to be compiled.
	AuthResultOK AuthResult = lib.SQLITE_OK
	// AuthResultDeny causes the entire SQL statement to be rejected with an error.
	AuthResultDeny AuthResult = lib.SQLITE_DENY
	// AuthResultIgnore disallows the specific action but allows the SQL statement
	// to continue to be compiled. For OpRead, this substitutes a NULL for the
	// column value. For OpDelete, the DELETE operation proceeds but the truncate
	// optimization is disabled and all rows are deleted individually.
	AuthResultIgnore AuthResult = lib.SQLITE_IGNORE
)
  88. // String returns the C constant name of the result.
  89. func (result AuthResult) String() string {
  90. switch result {
  91. case AuthResultOK:
  92. return "SQLITE_OK"
  93. case AuthResultDeny:
  94. return "SQLITE_DENY"
  95. case AuthResultIgnore:
  96. return "SQLITE_IGNORE"
  97. default:
  98. return fmt.Sprintf("AuthResult(%d)", int32(result))
  99. }
  100. }
// Action represents an action to be authorized.
//
// The fields hold the raw values received by the authorizer callback; use the
// accessor methods to read them by meaning rather than by position.
type Action struct {
	// op is the action code passed to the authorizer callback.
	op OpType
	// arg1 and arg2 are the callback's two positional string arguments.
	// Their meaning depends on op.
	arg1 string
	arg2 string
	// database is the callback's database-name argument.
	database string
	// trigger is the callback's trigger/view-name argument.
	trigger string
}
// The accessor methods on Action translate the positional arguments into named
// concepts. Mapping of argument position to concept at:
// https://sqlite.org/c3ref/c_alter_table.html

// Type returns the type of action being authorized.
func (action Action) Type() OpType {
	return action.op
}
// Accessor returns the name of the inner-most trigger or view that is
// responsible for the access attempt or the empty string if this access attempt
// is directly from top-level SQL code.
//
// Unlike most other accessors, the value does not depend on the action type.
func (action Action) Accessor() string {
	return action.trigger
}
  121. // Database returns the name of the database (e.g. "main", "temp", etc.) this
  122. // action affects or the empty string if not applicable.
  123. func (action Action) Database() string {
  124. switch action.op {
  125. case OpDetach, OpAlterTable:
  126. return action.arg1
  127. default:
  128. return action.database
  129. }
  130. }
  131. // Index returns the name of the index this action affects or the empty string
  132. // if not applicable.
  133. func (action Action) Index() string {
  134. switch action.op {
  135. case OpCreateIndex, OpCreateTempIndex, OpDropIndex, OpDropTempIndex, OpReindex:
  136. return action.arg1
  137. default:
  138. return ""
  139. }
  140. }
  141. // Table returns the name of the table this action affects or the empty string
  142. // if not applicable.
  143. func (action Action) Table() string {
  144. switch action.op {
  145. case OpCreateTable, OpCreateTempTable, OpDelete, OpDropTable, OpDropTempTable, OpInsert, OpRead, OpUpdate, OpAnalyze, OpCreateVTable, OpDropVTable:
  146. return action.arg1
  147. case OpCreateIndex, OpCreateTempIndex, OpCreateTempTrigger, OpCreateTrigger, OpDropIndex, OpDropTempIndex, OpDropTempTrigger, OpDropTrigger, OpAlterTable:
  148. return action.arg2
  149. default:
  150. return ""
  151. }
  152. }
  153. // Trigger returns the name of the trigger this action affects or the empty
  154. // string if not applicable.
  155. func (action Action) Trigger() string {
  156. switch action.op {
  157. case OpCreateTempTrigger, OpCreateTrigger, OpDropTempTrigger, OpDropTrigger:
  158. return action.arg1
  159. default:
  160. return ""
  161. }
  162. }
  163. // View returns the name of the view this action affects or the empty string
  164. // if not applicable.
  165. func (action Action) View() string {
  166. switch action.op {
  167. case OpCreateTempView, OpCreateView, OpDropTempView, OpDropView:
  168. return action.arg1
  169. default:
  170. return ""
  171. }
  172. }
  173. // Pragma returns the name of the action's PRAGMA command or the empty string
  174. // if the action does not represent a PRAGMA command.
  175. // See https://sqlite.org/pragma.html#toc for a list of possible values.
  176. func (action Action) Pragma() string {
  177. if action.op != OpPragma {
  178. return ""
  179. }
  180. return action.arg1
  181. }
  182. // PragmaArg returns the argument to the PRAGMA command or the empty string if
  183. // the action does not represent a PRAGMA command or the PRAGMA command does not
  184. // take an argument.
  185. func (action Action) PragmaArg() string {
  186. if action.op != OpPragma {
  187. return ""
  188. }
  189. return action.arg2
  190. }
  191. // Column returns the name of the column this action affects or the empty string
  192. // if not applicable. For OpRead actions, this will return the empty string if a
  193. // table is referenced but no column values are extracted from that table
  194. // (e.g. a query like "SELECT COUNT(*) FROM tab").
  195. func (action Action) Column() string {
  196. switch action.op {
  197. case OpRead, OpUpdate:
  198. return action.arg2
  199. default:
  200. return ""
  201. }
  202. }
  203. // Operation returns one of "BEGIN", "COMMIT", "RELEASE", or "ROLLBACK" for a
  204. // transaction or savepoint statement or the empty string otherwise.
  205. func (action Action) Operation() string {
  206. switch action.op {
  207. case OpTransaction, OpSavepoint:
  208. return action.arg1
  209. default:
  210. return ""
  211. }
  212. }
  213. // File returns the name of the file being ATTACHed or the empty string if the
  214. // action does not represent an ATTACH DATABASE statement.
  215. func (action Action) File() string {
  216. if action.op != OpAttach {
  217. return ""
  218. }
  219. return action.arg1
  220. }
  221. // Module returns the module name given to the virtual table statement or the
  222. // empty string if the action does not represent a CREATE VIRTUAL TABLE or
  223. // DROP VIRTUAL TABLE statement.
  224. func (action Action) Module() string {
  225. switch action.op {
  226. case OpCreateVTable, OpDropVTable:
  227. return action.arg2
  228. default:
  229. return ""
  230. }
  231. }
  232. // Savepoint returns the name given to the SAVEPOINT statement or the empty
  233. // string if the action does not represent a SAVEPOINT statement.
  234. func (action Action) Savepoint() string {
  235. if action.op != OpSavepoint {
  236. return ""
  237. }
  238. return action.arg2
  239. }
  240. // String returns a debugging representation of the action.
  241. func (action Action) String() string {
  242. sb := new(strings.Builder)
  243. sb.WriteString(action.op.String())
  244. params := []struct {
  245. name, value string
  246. }{
  247. {"database", action.Database()},
  248. {"file", action.File()},
  249. {"trigger", action.Trigger()},
  250. {"index", action.Index()},
  251. {"table", action.Table()},
  252. {"view", action.View()},
  253. {"module", action.Module()},
  254. {"column", action.Column()},
  255. {"operation", action.Operation()},
  256. {"savepoint", action.Savepoint()},
  257. {"pragma", action.Pragma()},
  258. {"arg", action.PragmaArg()},
  259. }
  260. for _, p := range params {
  261. if p.value != "" {
  262. sb.WriteString(" ")
  263. sb.WriteString(p.name)
  264. sb.WriteString(":")
  265. sb.WriteString(p.value)
  266. }
  267. }
  268. return sb.String()
  269. }