stmt.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. package clickhouse
  2. import (
  3. "bytes"
  4. "context"
  5. "database/sql/driver"
  6. "unicode"
  7. "github.com/ClickHouse/clickhouse-go/lib/data"
  8. )
  9. type stmt struct {
  10. ch *clickhouse
  11. query string
  12. counter int
  13. numInput int
  14. isInsert bool
  15. }
  16. var emptyResult = &result{}
  17. type key string
  18. var queryIDKey key
  19. //Put query ID into context and use it in ExecContext or QueryContext
  20. func WithQueryID(ctx context.Context, queryID string) context.Context {
  21. return context.WithValue(ctx, queryIDKey, queryID)
  22. }
  23. func (stmt *stmt) NumInput() int {
  24. switch {
  25. case stmt.ch.block != nil:
  26. return len(stmt.ch.block.Columns)
  27. case stmt.numInput < 0:
  28. return 0
  29. }
  30. return stmt.numInput
  31. }
  32. func (stmt *stmt) Exec(args []driver.Value) (driver.Result, error) {
  33. return stmt.execContext(context.Background(), args)
  34. }
  35. func (stmt *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
  36. dargs := make([]driver.Value, len(args))
  37. for i, nv := range args {
  38. dargs[i] = nv.Value
  39. }
  40. return stmt.execContext(ctx, dargs)
  41. }
  42. func (stmt *stmt) execContext(ctx context.Context, args []driver.Value) (driver.Result, error) {
  43. if stmt.isInsert {
  44. stmt.counter++
  45. if err := stmt.ch.block.AppendRow(args); err != nil {
  46. return nil, err
  47. }
  48. if (stmt.counter % stmt.ch.blockSize) == 0 {
  49. stmt.ch.logf("[exec] flush block")
  50. if err := stmt.ch.writeBlock(stmt.ch.block, ""); err != nil {
  51. return nil, err
  52. }
  53. if err := stmt.ch.encoder.Flush(); err != nil {
  54. return nil, err
  55. }
  56. }
  57. return emptyResult, nil
  58. }
  59. query, externalTables := stmt.bind(convertOldArgs(args))
  60. if err := stmt.ch.sendQuery(ctx, query, externalTables); err != nil {
  61. return nil, err
  62. }
  63. if err := stmt.ch.process(); err != nil {
  64. return nil, err
  65. }
  66. return emptyResult, nil
  67. }
  68. func (stmt *stmt) Query(args []driver.Value) (driver.Rows, error) {
  69. return stmt.queryContext(context.Background(), convertOldArgs(args))
  70. }
  71. func (stmt *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
  72. return stmt.queryContext(ctx, args)
  73. }
  74. func (stmt *stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
  75. finish := stmt.ch.watchCancel(ctx)
  76. query, externalTables := stmt.bind(args)
  77. if err := stmt.ch.sendQuery(ctx, query, externalTables); err != nil {
  78. finish()
  79. return nil, err
  80. }
  81. meta, err := stmt.ch.readMeta()
  82. if err != nil {
  83. finish()
  84. return nil, err
  85. }
  86. rows := rows{
  87. ch: stmt.ch,
  88. finish: finish,
  89. stream: make(chan *data.Block, 50),
  90. columns: meta.ColumnNames(),
  91. blockColumns: meta.Columns,
  92. }
  93. go rows.receiveData()
  94. return &rows, nil
  95. }
  96. func (stmt *stmt) Close() error {
  97. stmt.ch.logf("[stmt] close")
  98. return nil
  99. }
// bind returns stmt.query with its placeholders replaced by values from args,
// together with every ExternalTable argument that was substituted.
//
// When NumInput reports 0, the query is returned verbatim and args are ignored.
// Otherwise the query is scanned rune by rune:
//   - "@name" is replaced by the value of every arg whose Name equals name.
//   - "?" is replaced by the next positional (unnamed) arg, but only in a
//     "keyword position": after one of = < > ( , + - * / [ or after a
//     recognised keyword (limit, offset, like, in, from, join, select,
//     between, or the "and" that closes a BETWEEN). Any other "?" is copied
//     literally.
//
// An ExternalTable value is written as its Name and appended to the returned
// slice. Every other value is rendered through quote.
func (stmt *stmt) bind(args []driver.NamedValue) (string, []ExternalTable) {
	var (
		buf       bytes.Buffer
		index     int  // next positional arg to be consumed by "?"
		keyword   bool // true when a following "?" is in a substitutable position
		inBetween bool // true after BETWEEN until its AND has been matched
		// Keyword matchers fed one rune at a time. Judging by the use below,
		// matchRune presumably reports true when the rune completes the
		// keyword; verify against newMatcher.
		like           = newMatcher("like")
		limit          = newMatcher("limit")
		offset         = newMatcher("offset")
		between        = newMatcher("between")
		and            = newMatcher("and")
		in             = newMatcher("in")
		from           = newMatcher("from")
		join           = newMatcher("join")
		subSelect      = newMatcher("select")
		externalTables = make([]ExternalTable, 0)
	)
	switch {
	case stmt.NumInput() != 0:
		reader := bytes.NewReader([]byte(stmt.query))
		for {
			if char, _, err := reader.ReadRune(); err == nil {
				switch char {
				case '@':
					// Named placeholder. paramParser presumably consumes the
					// identifier following '@' from reader; verify.
					// NOTE(review): the '@' itself is never written. If no arg
					// matches the name, or the parsed name is empty, the
					// placeholder silently disappears from the query. If several
					// args share the name, each of their values is written.
					if param := paramParser(reader); len(param) != 0 {
						for _, v := range args {
							if len(v.Name) != 0 && v.Name == param {
								switch v := v.Value.(type) {
								case ExternalTable:
									buf.WriteString(v.Name)
									externalTables = append(externalTables, v)
								default:
									buf.WriteString(quote(v))
								}
							}
						}
					}
				case '?':
					// Positional placeholder. index never skips named args: if
					// args[index] has a Name, this "?" and every later one stay
					// literal.
					if keyword && index < len(args) && len(args[index].Name) == 0 {
						switch v := args[index].Value.(type) {
						case ExternalTable:
							buf.WriteString(v.Name)
							externalTables = append(externalTables, v)
						default:
							buf.WriteString(quote(v))
						}
						index++
					} else {
						buf.WriteRune(char)
					}
				default:
					switch {
					case
						char == '=',
						char == '<',
						char == '>',
						char == '(',
						char == ',',
						char == '+',
						char == '-',
						char == '*',
						char == '/',
						char == '[':
						keyword = true
					default:
						// NOTE(review): the matchers are chained with ||, so once
						// one reports a match the later ones are not fed this rune,
						// and "and" is fed only while inBetween. If the matchers
						// track partial progress, this can desynchronise them;
						// confirm against newMatcher.
						if limit.matchRune(char) || offset.matchRune(char) || like.matchRune(char) ||
							in.matchRune(char) || from.matchRune(char) || join.matchRune(char) || subSelect.matchRune(char) {
							keyword = true
						} else if between.matchRune(char) {
							keyword = true
							inBetween = true
						} else if inBetween && and.matchRune(char) {
							keyword = true
							inBetween = false
						} else {
							// Whitespace keeps the current state; any other rune
							// clears it.
							keyword = keyword && unicode.IsSpace(char)
						}
					}
					buf.WriteRune(char)
				}
			} else {
				// bytes.Reader.ReadRune fails only at end of input. This break
				// leaves the for loop (it is not inside the rune switch).
				break
			}
		}
	default:
		buf.WriteString(stmt.query)
	}
	return buf.String(), externalTables
}
  189. func convertOldArgs(args []driver.Value) []driver.NamedValue {
  190. dargs := make([]driver.NamedValue, len(args))
  191. for i, v := range args {
  192. dargs[i] = driver.NamedValue{
  193. Ordinal: i + 1,
  194. Value: v,
  195. }
  196. }
  197. return dargs
  198. }