transport.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. package tos
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "math/rand"
  6. "net"
  7. "net/http"
  8. "net/http/httptrace"
  9. "time"
  10. )
  11. type TransportConfig struct {
  12. // MaxIdleConns same as http.Transport MaxIdleConns. Default is 1024.
  13. MaxIdleConns int
  14. // MaxIdleConnsPerHost same as http.Transport MaxIdleConnsPerHost. Default is 1024.
  15. MaxIdleConnsPerHost int
  16. // MaxConnsPerHost same as http.Transport MaxConnsPerHost. Default is no limit.
  17. MaxConnsPerHost int
  18. // RequestTimeout same as http.Client Timeout
  19. // Deprecated: use ReadTimeout or WriteTimeout instead
  20. RequestTimeout time.Duration
  21. // DialTimeout same as net.Dialer Timeout
  22. DialTimeout time.Duration
  23. // KeepAlive same as net.Dialer KeepAlive
  24. KeepAlive time.Duration
  25. // IdleConnTimeout same as http.Transport IdleConnTimeout
  26. IdleConnTimeout time.Duration
  27. // TLSHandshakeTimeout same as http.Transport TLSHandshakeTimeout
  28. TLSHandshakeTimeout time.Duration
  29. // ResponseHeaderTimeout same as http.Transport ResponseHeaderTimeout
  30. ResponseHeaderTimeout time.Duration
  31. // ExpectContinueTimeout same as http.Transport ExpectContinueTimeout
  32. ExpectContinueTimeout time.Duration
  33. // ReadTimeout see net.Conn SetReadDeadline
  34. ReadTimeout time.Duration
  35. // WriteTimeout set net.Conn SetWriteDeadline
  36. WriteTimeout time.Duration
  37. // InsecureSkipVerify set tls.Config InsecureSkipVerify
  38. InsecureSkipVerify bool
  39. // DNSCacheTime Set Dns Cache Time.
  40. DNSCacheTime time.Duration
  41. // Proxy Set http proxy for http client.
  42. Proxy *Proxy
  43. }
  44. type Transport interface {
  45. RoundTrip(context.Context, *Request) (*Response, error)
  46. }
  47. type DefaultTransport struct {
  48. client http.Client
  49. logger Logger
  50. }
  51. func (d *DefaultTransport) WithDefaultTransportLogger(logger Logger) {
  52. d.logger = logger
  53. }
  54. // NewDefaultTransport create a DefaultTransport with config
  55. func NewDefaultTransport(config *TransportConfig) *DefaultTransport {
  56. var r *resolver
  57. if config.DNSCacheTime >= time.Minute {
  58. r = newResolver(config.DNSCacheTime)
  59. }
  60. transport := &http.Transport{
  61. DialContext: (&TimeoutDialer{
  62. Dialer: net.Dialer{
  63. Timeout: config.DialTimeout,
  64. KeepAlive: config.KeepAlive,
  65. },
  66. resolver: r,
  67. ReadTimeout: config.ReadTimeout,
  68. WriteTimeout: config.WriteTimeout,
  69. }).DialContext,
  70. MaxIdleConns: config.MaxIdleConns,
  71. MaxIdleConnsPerHost: config.MaxIdleConnsPerHost,
  72. MaxConnsPerHost: config.MaxConnsPerHost,
  73. IdleConnTimeout: config.IdleConnTimeout,
  74. TLSHandshakeTimeout: config.TLSHandshakeTimeout,
  75. ResponseHeaderTimeout: config.ResponseHeaderTimeout,
  76. ExpectContinueTimeout: config.ExpectContinueTimeout,
  77. DisableCompression: true,
  78. // #nosec G402
  79. TLSClientConfig: &tls.Config{InsecureSkipVerify: config.InsecureSkipVerify},
  80. }
  81. if config.Proxy != nil && config.Proxy.proxyHost != "" {
  82. transport.Proxy = http.ProxyURL(config.Proxy.Url())
  83. }
  84. return &DefaultTransport{
  85. client: http.Client{
  86. CheckRedirect: func(req *http.Request, via []*http.Request) error {
  87. return http.ErrUseLastResponse
  88. },
  89. Transport: transport,
  90. },
  91. }
  92. }
  93. // newDefaultTranposrtWithHTTPTransport
  94. func newDefaultTranposrtWithHTTPTransport(transport http.RoundTripper) *DefaultTransport {
  95. return &DefaultTransport{
  96. client: http.Client{
  97. CheckRedirect: func(req *http.Request, via []*http.Request) error {
  98. return http.ErrUseLastResponse
  99. },
  100. Transport: transport,
  101. },
  102. }
  103. }
  104. // NewDefaultTransportWithClient crate a DefaultTransport with a http.Client
  105. func NewDefaultTransportWithClient(client http.Client) *DefaultTransport {
  106. return &DefaultTransport{client: client}
  107. }
  108. func (dt *DefaultTransport) RoundTrip(ctx context.Context, req *Request) (*Response, error) {
  109. hr, err := http.NewRequestWithContext(ctx, req.Method, req.URL(), req.Content)
  110. if err != nil {
  111. return nil, newTosClientError(err.Error(), err)
  112. }
  113. if req.ContentLength != nil {
  114. hr.ContentLength = *req.ContentLength
  115. }
  116. for key, values := range req.Header {
  117. hr.Header[key] = values
  118. }
  119. var accessLog *accessLogRequest
  120. if dt.logger != nil {
  121. var trace *httptrace.ClientTrace
  122. trace, accessLog = getClientTrace(GetUnixTimeMs())
  123. ctx = httptrace.WithClientTrace(ctx, trace)
  124. hr = hr.WithContext(ctx)
  125. }
  126. res, err := dt.client.Do(hr)
  127. if accessLog != nil {
  128. accessLog.PrintAccessLog(dt.logger, hr, res)
  129. }
  130. if err != nil {
  131. return nil, newTosClientError(err.Error(), err)
  132. }
  133. return &Response{
  134. StatusCode: res.StatusCode,
  135. ContentLength: res.ContentLength,
  136. Header: res.Header,
  137. Body: res.Body,
  138. }, nil
  139. }
  // TimeoutConn wraps a net.Conn and re-arms I/O deadlines around every call:
  //
  //   - Read: when readTimeout > 0, the read deadline is set to
  //     now+readTimeout before reading and relaxed to now+5*readTimeout after.
  //   - Write: when writeTimeout > 0, the write deadline is set to
  //     now+writeTimeout before writing; afterwards, when readTimeout > 0, the
  //     read deadline is relaxed to now+5*readTimeout.
  //
  // Errors from the deadline setters are ignored (best effort); the result of
  // the underlying Read/Write is what callers see.
  type TimeoutConn struct {
  	net.Conn
  	readTimeout  time.Duration
  	writeTimeout time.Duration
  	// zero is not referenced anywhere in this file.
  	zero time.Time
  }

  // NewTimeoutConn wraps conn with the given per-call read and write timeouts.
  // A non-positive timeout disables the corresponding pre-call deadline.
  func NewTimeoutConn(conn net.Conn, readTimeout, writeTimeout time.Duration) *TimeoutConn {
  	tc := new(TimeoutConn)
  	tc.Conn = conn
  	tc.readTimeout = readTimeout
  	tc.writeTimeout = writeTimeout
  	return tc
  }

  // Read reads from the wrapped connection under a readTimeout deadline.
  func (tc *TimeoutConn) Read(b []byte) (n int, err error) {
  	if tc.readTimeout <= 0 {
  		return tc.Conn.Read(b)
  	}
  	_ = tc.SetReadDeadline(time.Now().Add(tc.readTimeout))
  	n, err = tc.Conn.Read(b)
  	// NOTE(review): the intent of the 5x relaxation is not visible here; it
  	// keeps a longer read deadline armed until the next Read re-arms it.
  	_ = tc.SetReadDeadline(time.Now().Add(5 * tc.readTimeout))
  	return n, err
  }

  // Write writes to the wrapped connection under a writeTimeout deadline and
  // then relaxes the read deadline to 5x readTimeout.
  func (tc *TimeoutConn) Write(b []byte) (n int, err error) {
  	if tc.writeTimeout > 0 {
  		_ = tc.SetWriteDeadline(time.Now().Add(tc.writeTimeout))
  	}
  	n, err = tc.Conn.Write(b)
  	// Presumably gives the peer time to answer what was just written before
  	// a pending read times out — TODO confirm intent.
  	if tc.readTimeout > 0 {
  		_ = tc.SetReadDeadline(time.Now().Add(5 * tc.readTimeout))
  	}
  	return n, err
  }
  175. type TimeoutDialer struct {
  176. net.Dialer
  177. resolver *resolver
  178. ReadTimeout time.Duration
  179. WriteTimeout time.Duration
  180. }
  181. func (d *TimeoutDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  182. if d.resolver != nil {
  183. host, port, err := net.SplitHostPort(address)
  184. if err != nil {
  185. return nil, err
  186. }
  187. ipList, err := d.resolver.GetIpList(ctx, host)
  188. if err != nil {
  189. return nil, err
  190. }
  191. // 随机打乱 IP List
  192. rand.Shuffle(len(ipList), func(i, j int) {
  193. ipList[i], ipList[j] = ipList[j], ipList[i]
  194. })
  195. for _, ip := range ipList {
  196. conn, err := d.Dialer.DialContext(ctx, network, ip+":"+port)
  197. if err == nil {
  198. return NewTimeoutConn(conn, d.ReadTimeout, d.WriteTimeout), nil
  199. } else {
  200. d.resolver.Remove(address, ip)
  201. }
  202. }
  203. }
  204. conn, err := d.Dialer.DialContext(ctx, network, address)
  205. if err != nil {
  206. return nil, err
  207. }
  208. return NewTimeoutConn(conn, d.ReadTimeout, d.WriteTimeout), nil
  209. }