helpers.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. package clickhouse
  2. import (
  3. "bytes"
  4. "database/sql/driver"
  5. "fmt"
  6. "reflect"
  7. "regexp"
  8. "strings"
  9. "time"
  10. )
  11. func numInput(query string) int {
  12. var (
  13. count int
  14. args = make(map[string]struct{})
  15. reader = bytes.NewReader([]byte(query))
  16. quote, gravis bool
  17. escape bool
  18. keyword bool
  19. inBetween bool
  20. like = newMatcher("like")
  21. limit = newMatcher("limit")
  22. offset = newMatcher("offset")
  23. between = newMatcher("between")
  24. in = newMatcher("in")
  25. and = newMatcher("and")
  26. from = newMatcher("from")
  27. join = newMatcher("join")
  28. subSelect = newMatcher("select")
  29. )
  30. for {
  31. if char, _, err := reader.ReadRune(); err == nil {
  32. if escape {
  33. escape = false
  34. continue
  35. }
  36. switch char {
  37. case '\\':
  38. if gravis || quote {
  39. escape = true
  40. }
  41. case '\'':
  42. if !gravis {
  43. quote = !quote
  44. }
  45. case '`':
  46. if !quote {
  47. gravis = !gravis
  48. }
  49. }
  50. if quote || gravis {
  51. continue
  52. }
  53. switch {
  54. case char == '?' && keyword:
  55. count++
  56. case char == '@':
  57. if param := paramParser(reader); len(param) != 0 {
  58. if _, found := args[param]; !found {
  59. args[param] = struct{}{}
  60. count++
  61. }
  62. }
  63. case
  64. char == '=',
  65. char == '<',
  66. char == '>',
  67. char == '(',
  68. char == ',',
  69. char == '[',
  70. char == '%':
  71. keyword = true
  72. default:
  73. if limit.matchRune(char) || offset.matchRune(char) || like.matchRune(char) ||
  74. in.matchRune(char) || from.matchRune(char) || join.matchRune(char) || subSelect.matchRune(char) {
  75. keyword = true
  76. } else if between.matchRune(char) {
  77. keyword = true
  78. inBetween = true
  79. } else if inBetween && and.matchRune(char) {
  80. keyword = true
  81. inBetween = false
  82. } else {
  83. keyword = keyword && (char == ' ' || char == '\t' || char == '\n')
  84. }
  85. }
  86. } else {
  87. break
  88. }
  89. }
  90. return count
  91. }
  // paramParser consumes a named-parameter identifier from reader, which is
  // expected to be positioned just after the '@' sigil.
  //
  // It accepts ASCII letters, digits and underscores. It stops at end of input
  // or at the first other rune, which it pushes back so the caller reads it
  // next. The collected name is returned; an empty string means no identifier
  // followed the sigil.
  func paramParser(reader *bytes.Reader) string {
  	isIdentRune := func(r rune) bool {
  		switch {
  		case r == '_':
  			return true
  		case '0' <= r && r <= '9':
  			return true
  		case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z':
  			return true
  		}
  		return false
  	}
  	var ident bytes.Buffer
  	for {
  		r, _, err := reader.ReadRune()
  		if err != nil {
  			return ident.String()
  		}
  		if !isIdentRune(r) {
  			// UnreadRune cannot fail immediately after a successful ReadRune.
  			reader.UnreadRune()
  			return ident.String()
  		}
  		ident.WriteRune(r)
  	}
  }
  // selectRe finds a SELECT keyword surrounded by whitespace. isInsert applies
  // it to the upper-cased query, so the pattern only spells the upper-case form.
  var selectRe = regexp.MustCompile(`\s+SELECT\s+`)

  // isInsert reports whether query is an "INSERT INTO" statement that is not an
  // INSERT ... SELECT.
  //
  // The first two whitespace-separated words must be INSERT and INTO, compared
  // case-insensitively, and at least one more word must follow. The query must
  // also contain no whitespace-delimited SELECT anywhere.
  //
  // NOTE(review): the SELECT check also scans string literals, so a value such
  // as ' select ' inside VALUES makes this return false.
  func isInsert(query string) bool {
  	words := strings.Fields(query)
  	if len(words) < 3 {
  		return false
  	}
  	if !strings.EqualFold(words[0], "INSERT") || !strings.EqualFold(words[1], "INTO") {
  		return false
  	}
  	return !selectRe.MatchString(strings.ToUpper(query))
  }
  115. func quote(v driver.Value) string {
  116. switch v := reflect.ValueOf(v); v.Kind() {
  117. case reflect.Slice:
  118. values := make([]string, 0, v.Len())
  119. for i := 0; i < v.Len(); i++ {
  120. values = append(values, quote(v.Index(i).Interface()))
  121. }
  122. return strings.Join(values, ", ")
  123. }
  124. switch v := v.(type) {
  125. case string:
  126. return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'"
  127. case time.Time:
  128. return formatTime(v)
  129. case nil:
  130. return "null"
  131. }
  132. return fmt.Sprint(v)
  133. }
  // formatTime renders v as a toDateTime('YYYY-MM-DD hh:mm:ss', '<location>')
  // expression. The time is formatted to whole seconds; fractional seconds
  // are dropped. The second argument is v's location name.
  //
  // Only the date/time part goes through time.Format. The location name is
  // appended verbatim, because inside a layout string its characters would be
  // read as layout tokens: "Mon" in "America/Montevideo" would become a
  // weekday name, and "1" in "Etc/GMT+1" would become the month number.
  //
  // NOTE(review): for time.Local the location name is "Local", which is
  // probably not a timezone name the server accepts — confirm callers pass
  // values in a named zone.
  func formatTime(v time.Time) string {
  	return "toDateTime('" + v.Format("2006-01-02 15:04:05") + "', '" + v.Location().String() + "')"
  }