connect.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. package clickhouse
  2. import (
  3. "bufio"
  4. "crypto/tls"
  5. "database/sql/driver"
  6. "net"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. )
  // tick counts dial attempts: dial atomically increments it and derives the
// new connection's ident from the result.
var tick int32

// openStrategy selects the order in which dial tries the configured hosts.
type openStrategy int8

// String returns the strategy name used in dial log lines. Anything other
// than connOpenInOrder or connOpenTimeRandom (connOpenRandom as well as any
// unrecognized value) is reported as "random".
func (s openStrategy) String() string {
	name := "random"
	switch s {
	case connOpenInOrder:
		name = "in_order"
	case connOpenTimeRandom:
		name = "time_random"
	}
	return name
}

// Host-selection strategies understood by dial. Numbering starts at 1, so
// the zero value of openStrategy matches none of them.
const (
	connOpenRandom     openStrategy = 1
	connOpenInOrder    openStrategy = 2
	connOpenTimeRandom openStrategy = 3
)
  27. type connOptions struct {
  28. secure, skipVerify bool
  29. tlsConfig *tls.Config
  30. hosts []string
  31. connTimeout, readTimeout, writeTimeout time.Duration
  32. noDelay bool
  33. openStrategy openStrategy
  34. logf func(string, ...interface{})
  35. }
  // DialFunc is a function which can be used to establish the network
// connection. dial in this package calls it with network "tcp", the target
// host address, the configured connection timeout, and — for secure
// connections — the TLS configuration to use (nil for plaintext).
//
// Custom dial functions must be registered with RegisterDial.
type DialFunc func(
	network, address string,
	timeout time.Duration,
	config *tls.Config,
) (net.Conn, error)
  39. var (
  40. customDialLock sync.RWMutex
  41. customDial DialFunc
  42. )
  43. // RegisterDial registers a custom dial function.
  44. func RegisterDial(dial DialFunc) {
  45. customDialLock.Lock()
  46. customDial = dial
  47. customDialLock.Unlock()
  48. }
  49. // DeregisterDial deregisters the custom dial function.
  50. func DeregisterDial() {
  51. customDialLock.Lock()
  52. customDial = nil
  53. customDialLock.Unlock()
  54. }
  55. func dial(options connOptions) (*connect, error) {
  56. var (
  57. err error
  58. abs = func(v int) int {
  59. if v < 0 {
  60. return -1 * v
  61. }
  62. return v
  63. }
  64. conn net.Conn
  65. ident = abs(int(atomic.AddInt32(&tick, 1)))
  66. )
  67. tlsConfig := options.tlsConfig
  68. if options.secure {
  69. if tlsConfig == nil {
  70. tlsConfig = &tls.Config{}
  71. }
  72. tlsConfig.InsecureSkipVerify = options.skipVerify
  73. }
  74. checkedHosts := make(map[int]struct{}, len(options.hosts))
  75. for i := range options.hosts {
  76. var num int
  77. switch options.openStrategy {
  78. case connOpenInOrder:
  79. num = i
  80. case connOpenRandom:
  81. num = (ident + i) % len(options.hosts)
  82. case connOpenTimeRandom:
  83. // select host based on milliseconds
  84. num = int((time.Now().UnixNano()/1000)%1000) % len(options.hosts)
  85. for _, ok := checkedHosts[num]; ok; _, ok = checkedHosts[num] {
  86. num = int(time.Now().UnixNano()) % len(options.hosts)
  87. }
  88. checkedHosts[num] = struct{}{}
  89. }
  90. customDialLock.RLock()
  91. cd := customDial
  92. customDialLock.RUnlock()
  93. switch {
  94. case options.secure:
  95. if cd != nil {
  96. conn, err = cd("tcp", options.hosts[num], options.connTimeout, tlsConfig)
  97. } else {
  98. conn, err = tls.DialWithDialer(
  99. &net.Dialer{
  100. Timeout: options.connTimeout,
  101. },
  102. "tcp",
  103. options.hosts[num],
  104. tlsConfig,
  105. )
  106. }
  107. default:
  108. if cd != nil {
  109. conn, err = cd("tcp", options.hosts[num], options.connTimeout, nil)
  110. } else {
  111. conn, err = net.DialTimeout("tcp", options.hosts[num], options.connTimeout)
  112. }
  113. }
  114. if err == nil {
  115. options.logf(
  116. "[dial] secure=%t, skip_verify=%t, strategy=%s, ident=%d, server=%d -> %s",
  117. options.secure,
  118. options.skipVerify,
  119. options.openStrategy,
  120. ident,
  121. num,
  122. conn.RemoteAddr(),
  123. )
  124. if tcp, ok := conn.(*net.TCPConn); ok {
  125. err = tcp.SetNoDelay(options.noDelay) // Disable or enable the Nagle Algorithm for this tcp socket
  126. if err != nil {
  127. return nil, err
  128. }
  129. }
  130. return &connect{
  131. Conn: conn,
  132. logf: options.logf,
  133. ident: ident,
  134. buffer: bufio.NewReader(conn),
  135. readTimeout: options.readTimeout,
  136. writeTimeout: options.writeTimeout,
  137. }, nil
  138. } else {
  139. options.logf(
  140. "[dial err] secure=%t, skip_verify=%t, strategy=%s, ident=%d, addr=%s\n%#v",
  141. options.secure,
  142. options.skipVerify,
  143. options.openStrategy,
  144. ident,
  145. options.hosts[num],
  146. err,
  147. )
  148. }
  149. }
  150. return nil, err
  151. }
  // connect wraps a net.Conn produced by dial. It reads through a buffered
// reader and tracks when the read and write deadlines were last refreshed.
type connect struct {
	// The embedded transport; connect's own Read, Write and Close methods
	// take precedence over the promoted ones.
	net.Conn

	// logf receives I/O error messages; ident is the identifier dial
	// assigned from the package tick counter.
	logf  func(string, ...interface{})
	ident int

	// buffer wraps Conn for reads; closed records that Close already ran.
	buffer *bufio.Reader
	closed bool

	// Configured deadlines (zero disables them) and the times at which each
	// deadline was last pushed forward.
	readTimeout, writeTimeout                   time.Duration
	lastReadDeadlineTime, lastWriteDeadlineTime time.Time
}
  163. func (conn *connect) Read(b []byte) (int, error) {
  164. var (
  165. n int
  166. err error
  167. total int
  168. dstLen = len(b)
  169. )
  170. if currentTime := now(); conn.readTimeout != 0 && currentTime.Sub(conn.lastReadDeadlineTime) > (conn.readTimeout>>2) {
  171. conn.SetReadDeadline(time.Now().Add(conn.readTimeout))
  172. conn.lastReadDeadlineTime = currentTime
  173. }
  174. for total < dstLen {
  175. if n, err = conn.buffer.Read(b[total:]); err != nil {
  176. conn.logf("[connect] read error: %v", err)
  177. conn.Close()
  178. return n, driver.ErrBadConn
  179. }
  180. total += n
  181. }
  182. return total, nil
  183. }
  184. func (conn *connect) Write(b []byte) (int, error) {
  185. var (
  186. n int
  187. err error
  188. total int
  189. srcLen = len(b)
  190. )
  191. if currentTime := now(); conn.writeTimeout != 0 && currentTime.Sub(conn.lastWriteDeadlineTime) > (conn.writeTimeout>>2) {
  192. conn.SetWriteDeadline(time.Now().Add(conn.writeTimeout))
  193. conn.lastWriteDeadlineTime = currentTime
  194. }
  195. for total < srcLen {
  196. if n, err = conn.Conn.Write(b[total:]); err != nil {
  197. conn.logf("[connect] write error: %v", err)
  198. conn.Close()
  199. return n, driver.ErrBadConn
  200. }
  201. total += n
  202. }
  203. return n, nil
  204. }
  205. func (conn *connect) Close() error {
  206. if !conn.closed {
  207. conn.closed = true
  208. return conn.Conn.Close()
  209. }
  210. return nil
  211. }