| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
- package clickhouse
- import (
- "bufio"
- "database/sql"
- "database/sql/driver"
- "fmt"
- "io"
- "log"
- "net/url"
- "os"
- "strconv"
- "strings"
- "sync"
- "sync/atomic"
- "time"
- "github.com/ClickHouse/clickhouse-go/lib/binary"
- "github.com/ClickHouse/clickhouse-go/lib/data"
- "github.com/ClickHouse/clickhouse-go/lib/protocol"
- )
// Connection defaults applied by open when the DSN leaves a setting unset.
const (
	// DefaultDatabase is used when the DSN has no (or an empty) "database" parameter.
	DefaultDatabase = "default"
	// DefaultUsername is used when the DSN has no (or an empty) "username" parameter.
	DefaultUsername = "default"

	// DefaultConnTimeout is the dial timeout unless the DSN sets "timeout" (seconds).
	DefaultConnTimeout = 5 * time.Second
	// DefaultReadTimeout is the read timeout unless the DSN sets "read_timeout" (seconds).
	DefaultReadTimeout = 60 * time.Second
	// DefaultWriteTimeout is the write timeout unless the DSN sets "write_timeout" (seconds).
	DefaultWriteTimeout = 60 * time.Second
)
// unixtime is a coarse clock in nanoseconds. It starts at zero and is
// advanced by one second per tick by the goroutine started in init; read it
// through now() and only via sync/atomic.
var unixtime int64

// logOutput is where debug loggers created by open write. It defaults to
// standard output and is replaced by SetLogOutput.
var logOutput io.Writer = os.Stdout

// hostname is this machine's host name. The os.Hostname error is deliberately
// discarded; on failure the value is presumably empty (TODO confirm callers
// tolerate that).
var hostname, _ = os.Hostname()

// poolInit is not used in this part of the file; presumably it guards a
// one-time initialisation elsewhere in the package (TODO confirm).
var poolInit sync.Once
- func init() {
- sql.Register("clickhouse", &bootstrap{})
- go func() {
- for tick := time.Tick(time.Second); ; {
- select {
- case <-tick:
- atomic.AddInt64(&unixtime, int64(time.Second))
- }
- }
- }()
- }
- func now() time.Time {
- return time.Unix(0, atomic.LoadInt64(&unixtime))
- }
- type bootstrap struct{}
- func (d *bootstrap) Open(dsn string) (driver.Conn, error) {
- return Open(dsn)
- }
- // SetLogOutput allows to change output of the default logger
- func SetLogOutput(output io.Writer) {
- logOutput = output
- }
- // Open the connection
- func Open(dsn string) (driver.Conn, error) {
- clickhouse, err := open(dsn)
- if err != nil {
- return nil, err
- }
- return clickhouse, err
- }
- func open(dsn string) (*clickhouse, error) {
- url, err := url.Parse(dsn)
- if err != nil {
- return nil, err
- }
- var (
- hosts = []string{url.Host}
- query = url.Query()
- secure = false
- skipVerify = false
- tlsConfigName = query.Get("tls_config")
- noDelay = true
- compress = false
- database = query.Get("database")
- username = query.Get("username")
- password = query.Get("password")
- blockSize = 1000000
- connTimeout = DefaultConnTimeout
- readTimeout = DefaultReadTimeout
- writeTimeout = DefaultWriteTimeout
- connOpenStrategy = connOpenRandom
- checkConnLiveness = true
- )
- if len(database) == 0 {
- database = DefaultDatabase
- }
- if len(username) == 0 {
- username = DefaultUsername
- }
- if v, err := strconv.ParseBool(query.Get("no_delay")); err == nil {
- noDelay = v
- }
- tlsConfig := getTLSConfigClone(tlsConfigName)
- if tlsConfigName != "" && tlsConfig == nil {
- return nil, fmt.Errorf("invalid tls_config - no config registered under name %s", tlsConfigName)
- }
- secure = tlsConfig != nil
- if v, err := strconv.ParseBool(query.Get("secure")); err == nil {
- secure = v
- }
- if v, err := strconv.ParseBool(query.Get("skip_verify")); err == nil {
- skipVerify = v
- }
- if duration, err := strconv.ParseFloat(query.Get("timeout"), 64); err == nil {
- connTimeout = time.Duration(duration * float64(time.Second))
- }
- if duration, err := strconv.ParseFloat(query.Get("read_timeout"), 64); err == nil {
- readTimeout = time.Duration(duration * float64(time.Second))
- }
- if duration, err := strconv.ParseFloat(query.Get("write_timeout"), 64); err == nil {
- writeTimeout = time.Duration(duration * float64(time.Second))
- }
- if size, err := strconv.ParseInt(query.Get("block_size"), 10, 64); err == nil {
- blockSize = int(size)
- }
- if altHosts := strings.Split(query.Get("alt_hosts"), ","); len(altHosts) != 0 {
- for _, host := range altHosts {
- if len(host) != 0 {
- hosts = append(hosts, host)
- }
- }
- }
- switch query.Get("connection_open_strategy") {
- case "random":
- connOpenStrategy = connOpenRandom
- case "in_order":
- connOpenStrategy = connOpenInOrder
- case "time_random":
- connOpenStrategy = connOpenTimeRandom
- }
- settings, err := makeQuerySettings(query)
- if err != nil {
- return nil, err
- }
- if v, err := strconv.ParseBool(query.Get("compress")); err == nil {
- compress = v
- }
- if v, err := strconv.ParseBool(query.Get("check_connection_liveness")); err == nil {
- checkConnLiveness = v
- }
- if secure {
- // There is no way to check the liveness of a secure connection, as long as there is no access to raw TCP net.Conn
- checkConnLiveness = false
- }
- var (
- ch = clickhouse{
- logf: func(string, ...interface{}) {},
- settings: settings,
- compress: compress,
- blockSize: blockSize,
- checkConnLiveness: checkConnLiveness,
- ServerInfo: data.ServerInfo{
- Timezone: time.Local,
- },
- }
- logger = log.New(logOutput, "[clickhouse]", 0)
- )
- if debug, err := strconv.ParseBool(url.Query().Get("debug")); err == nil && debug {
- ch.logf = logger.Printf
- }
- ch.logf("host(s)=%s, database=%s, username=%s",
- strings.Join(hosts, ", "),
- database,
- username,
- )
- options := connOptions{
- secure: secure,
- tlsConfig: tlsConfig,
- skipVerify: skipVerify,
- hosts: hosts,
- connTimeout: connTimeout,
- readTimeout: readTimeout,
- writeTimeout: writeTimeout,
- noDelay: noDelay,
- openStrategy: connOpenStrategy,
- logf: ch.logf,
- }
- if ch.conn, err = dial(options); err != nil {
- return nil, err
- }
- logger.SetPrefix(fmt.Sprintf("[clickhouse][connect=%d]", ch.conn.ident))
- ch.buffer = bufio.NewWriter(ch.conn)
- ch.decoder = binary.NewDecoderWithCompress(ch.conn)
- ch.encoder = binary.NewEncoderWithCompress(ch.buffer)
- if err := ch.hello(database, username, password); err != nil {
- ch.conn.Close()
- return nil, err
- }
- return &ch, nil
- }
- func (ch *clickhouse) hello(database, username, password string) error {
- ch.logf("[hello] -> %s", ch.ClientInfo)
- {
- ch.encoder.Uvarint(protocol.ClientHello)
- if err := ch.ClientInfo.Write(ch.encoder); err != nil {
- return err
- }
- {
- ch.encoder.String(database)
- ch.encoder.String(username)
- ch.encoder.String(password)
- }
- if err := ch.encoder.Flush(); err != nil {
- return err
- }
- }
- {
- packet, err := ch.decoder.Uvarint()
- if err != nil {
- return err
- }
- switch packet {
- case protocol.ServerException:
- return ch.exception()
- case protocol.ServerHello:
- if err := ch.ServerInfo.Read(ch.decoder); err != nil {
- return err
- }
- case protocol.ServerEndOfStream:
- ch.logf("[bootstrap] <- end of stream")
- return nil
- default:
- return fmt.Errorf("[hello] unexpected packet [%d] from server", packet)
- }
- }
- ch.logf("[hello] <- %s", ch.ServerInfo)
- return nil
- }
|