rows.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. package clickhouse
  2. import (
  3. "database/sql/driver"
  4. "fmt"
  5. "io"
  6. "reflect"
  7. "sync"
  8. "time"
  9. "github.com/ClickHouse/clickhouse-go/lib/column"
  10. "github.com/ClickHouse/clickhouse-go/lib/data"
  11. "github.com/ClickHouse/clickhouse-go/lib/protocol"
  12. )
  // rows is the result-set iterator for a query. The methods defined on it in
// this file (Columns, Next, Close, HasNextResultSet, NextResultSet and the
// ColumnType* family) match the database/sql/driver row interfaces. Blocks
// are produced by receiveData and consumed by Next.
type rows struct {
	// ch is the connection whose decoder receiveData reads server packets from.
	ch *clickhouse
	// err holds the failure recorded by setError; read it via error().
	err error
	// mutex guards err (see error and setError).
	mutex sync.RWMutex
	// finish is invoked by Close once rows.stream has been drained.
	finish func()
	// offset is the index of the next row of block that Next will return.
	offset int
	// block is the block Next is currently reading rows from.
	block *data.Block
	// totals and extremes hold the non-empty Totals/Extremes blocks kept by
	// receiveData; NextResultSet serves each one once and then clears it.
	// NOTE(review): receiveData writes these without holding mutex. If it runs
	// on its own goroutine (presumably — confirm at the call site), readers are
	// only safe after observing rows.stream closed, as Next does before io.EOF.
	totals   *data.Block
	extremes *data.Block
	// stream carries non-empty Data blocks from receiveData to Next;
	// receiveData closes it when it returns.
	stream chan *data.Block
	// columns holds the names returned by Columns; Close resets it to nil.
	columns []string
	// blockColumns describes each result column for the ColumnType* methods.
	blockColumns []column.Column
}
  26. func (rows *rows) Columns() []string {
  27. return rows.columns
  28. }
  29. func (rows *rows) ColumnTypeScanType(idx int) reflect.Type {
  30. return rows.blockColumns[idx].ScanType()
  31. }
  32. func (rows *rows) ColumnTypeDatabaseTypeName(idx int) string {
  33. return rows.blockColumns[idx].CHType()
  34. }
  35. func (rows *rows) Next(dest []driver.Value) error {
  36. if rows.block == nil || int(rows.block.NumRows) <= rows.offset {
  37. switch block, ok := <-rows.stream; true {
  38. case !ok:
  39. if err := rows.error(); err != nil {
  40. return err
  41. }
  42. return io.EOF
  43. default:
  44. rows.block = block
  45. rows.offset = 0
  46. }
  47. }
  48. for i := range dest {
  49. dest[i] = rows.block.Values[i][rows.offset]
  50. }
  51. rows.offset++
  52. return nil
  53. }
  54. func (rows *rows) HasNextResultSet() bool {
  55. return rows.totals != nil || rows.extremes != nil
  56. }
  57. func (rows *rows) NextResultSet() error {
  58. switch {
  59. case rows.totals != nil:
  60. rows.block = rows.totals
  61. rows.offset = 0
  62. rows.totals = nil
  63. case rows.extremes != nil:
  64. rows.block = rows.extremes
  65. rows.offset = 0
  66. rows.extremes = nil
  67. default:
  68. return io.EOF
  69. }
  70. return nil
  71. }
  72. func (rows *rows) receiveData() error {
  73. defer close(rows.stream)
  74. var (
  75. err error
  76. packet uint64
  77. progress *progress
  78. profileInfo *profileInfo
  79. )
  80. for {
  81. if packet, err = rows.ch.decoder.Uvarint(); err != nil {
  82. return rows.setError(err)
  83. }
  84. switch packet {
  85. case protocol.ServerException:
  86. rows.ch.logf("[rows] <- exception")
  87. return rows.setError(rows.ch.exception())
  88. case protocol.ServerProgress:
  89. if progress, err = rows.ch.progress(); err != nil {
  90. return rows.setError(err)
  91. }
  92. rows.ch.logf("[rows] <- progress: rows=%d, bytes=%d, total rows=%d",
  93. progress.rows,
  94. progress.bytes,
  95. progress.totalRows,
  96. )
  97. case protocol.ServerProfileInfo:
  98. if profileInfo, err = rows.ch.profileInfo(); err != nil {
  99. return rows.setError(err)
  100. }
  101. rows.ch.logf("[rows] <- profiling: rows=%d, bytes=%d, blocks=%d", profileInfo.rows, profileInfo.bytes, profileInfo.blocks)
  102. case protocol.ServerData, protocol.ServerTotals, protocol.ServerExtremes:
  103. var (
  104. block *data.Block
  105. begin = time.Now()
  106. )
  107. if block, err = rows.ch.readBlock(); err != nil {
  108. return rows.setError(err)
  109. }
  110. rows.ch.logf("[rows] <- data: packet=%d, columns=%d, rows=%d, elapsed=%s", packet, block.NumColumns, block.NumRows, time.Since(begin))
  111. if block.NumRows == 0 {
  112. continue
  113. }
  114. switch packet {
  115. case protocol.ServerData:
  116. rows.stream <- block
  117. case protocol.ServerTotals:
  118. rows.totals = block
  119. case protocol.ServerExtremes:
  120. rows.extremes = block
  121. }
  122. case protocol.ServerEndOfStream:
  123. rows.ch.logf("[rows] <- end of stream")
  124. return nil
  125. default:
  126. rows.ch.conn.Close()
  127. rows.ch.logf("[rows] unexpected packet [%d]", packet)
  128. return rows.setError(fmt.Errorf("[rows] unexpected packet [%d] from server", packet))
  129. }
  130. }
  131. }
  132. func (rows *rows) Close() error {
  133. rows.ch.logf("[rows] close")
  134. rows.columns = nil
  135. for range rows.stream {
  136. }
  137. rows.finish()
  138. return nil
  139. }
  140. func (rows *rows) error() error {
  141. rows.mutex.RLock()
  142. defer rows.mutex.RUnlock()
  143. return rows.err
  144. }
  145. func (rows *rows) setError(err error) error {
  146. rows.mutex.Lock()
  147. rows.err = err
  148. rows.mutex.Unlock()
  149. return err
  150. }
  151. func (rows *rows) ColumnTypeNullable(idx int) (nullable, ok bool) {
  152. _, ok = rows.blockColumns[idx].(*column.Nullable)
  153. return ok, true
  154. }
  155. func (rows *rows) ColumnTypePrecisionScale(idx int) (precision, scale int64, ok bool) {
  156. decimalVal, ok := rows.blockColumns[idx].(*column.Decimal)
  157. if !ok {
  158. if nullable, nullOk := rows.blockColumns[idx].(*column.Nullable); nullOk {
  159. decimalVal, ok = nullable.GetColumn().(*column.Decimal)
  160. }
  161. }
  162. if ok {
  163. return int64(decimalVal.GetPrecision()), int64(decimalVal.GetScale()), ok
  164. }
  165. return 0, 0, false
  166. }