client.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. package base
  import (
  	"context"
  	"encoding/base64"
  	"encoding/json"
  	"errors"
  	"fmt"
  	"io"
  	"io/ioutil"
  	"net/http"
  	"net/url"
  	"os"
  	"path/filepath"
  	"strings"
  	"time"
  )
  // Names of the environment variables from which NewClient reads the
  // access key ID and secret access key.
  const (
  	accessKey = "VOLC_ACCESSKEY"
  	secretKey = "VOLC_SECRETKEY"
  )

  // defaultScheme is the URL scheme NewClient assigns when the supplied
  // ServiceInfo leaves Scheme empty.
  const defaultScheme = "http"
  // _GlobalClient is the single *http.Client that NewClient hands to every
  // Client, so idle connections are pooled across all of them.
  //
  // It is initialized declaratively rather than from init() so there is no
  // hidden initialization-order side effect. It sets no client-level Timeout;
  // per-request deadlines are applied through the request context in
  // makeRequest.
  var _GlobalClient = &http.Client{
  	Transport: &http.Transport{
  		MaxIdleConns:        1000,
  		MaxIdleConnsPerHost: 100,
  		IdleConnTimeout:     10 * time.Second,
  	},
  }
  31. // Client 基础客户端
  32. type Client struct {
  33. Client *http.Client
  34. SdkVersion string
  35. ServiceInfo *ServiceInfo
  36. ApiInfoList map[string]*ApiInfo
  37. }
  38. // NewClient 生成一个客户端
  39. func NewClient(info *ServiceInfo, apiInfoList map[string]*ApiInfo) *Client {
  40. client := &Client{Client: _GlobalClient, ServiceInfo: info.Clone(), ApiInfoList: apiInfoList}
  41. if client.ServiceInfo.Scheme == "" {
  42. client.ServiceInfo.Scheme = defaultScheme
  43. }
  44. if os.Getenv(accessKey) != "" && os.Getenv(secretKey) != "" {
  45. client.ServiceInfo.Credentials.AccessKeyID = os.Getenv(accessKey)
  46. client.ServiceInfo.Credentials.SecretAccessKey = os.Getenv(secretKey)
  47. } else if _, err := os.Stat(os.Getenv("HOME") + "/.volc/config"); err == nil {
  48. if content, err := ioutil.ReadFile(os.Getenv("HOME") + "/.volc/config"); err == nil {
  49. m := make(map[string]string)
  50. json.Unmarshal(content, &m)
  51. if accessKey, ok := m["ak"]; ok {
  52. client.ServiceInfo.Credentials.AccessKeyID = accessKey
  53. }
  54. if secretKey, ok := m["sk"]; ok {
  55. client.ServiceInfo.Credentials.SecretAccessKey = secretKey
  56. }
  57. }
  58. }
  59. content, err := ioutil.ReadFile("VERSION")
  60. if err == nil {
  61. client.SdkVersion = strings.TrimSpace(string(content))
  62. client.ServiceInfo.Header.Set("User-Agent", strings.Join([]string{"volc-sdk-golang", client.SdkVersion}, "/"))
  63. }
  64. return client
  65. }
  66. func (serviceInfo *ServiceInfo) Clone() *ServiceInfo {
  67. ret := new(ServiceInfo)
  68. //base info
  69. ret.Timeout = serviceInfo.Timeout
  70. ret.Host = serviceInfo.Host
  71. ret.Scheme = serviceInfo.Scheme
  72. //credential
  73. ret.Credentials = serviceInfo.Credentials.Clone()
  74. // header
  75. ret.Header = serviceInfo.Header.Clone()
  76. return ret
  77. }
  78. func (cred Credentials) Clone() Credentials {
  79. return Credentials{
  80. Service: cred.Service,
  81. Region: cred.Region,
  82. SecretAccessKey: cred.SecretAccessKey,
  83. AccessKeyID: cred.AccessKeyID,
  84. SessionToken: cred.SessionToken,
  85. }
  86. }
  87. // SetAccessKey 设置AK
  88. func (client *Client) SetAccessKey(ak string) {
  89. if ak != "" {
  90. client.ServiceInfo.Credentials.AccessKeyID = ak
  91. }
  92. }
  93. // SetSecretKey 设置SK
  94. func (client *Client) SetSecretKey(sk string) {
  95. if sk != "" {
  96. client.ServiceInfo.Credentials.SecretAccessKey = sk
  97. }
  98. }
  99. // SetSessionToken
  100. func (client *Client) SetSessionToken(token string) {
  101. if token != "" {
  102. client.ServiceInfo.Credentials.SessionToken = token
  103. }
  104. }
  105. // SetHost 设置Host
  106. func (client *Client) SetHost(host string) {
  107. if host != "" {
  108. client.ServiceInfo.Host = host
  109. }
  110. }
  111. func (client *Client) SetScheme(scheme string) {
  112. if scheme != "" {
  113. client.ServiceInfo.Scheme = scheme
  114. }
  115. }
  116. // SetCredential 设置Credentials
  117. func (client *Client) SetCredential(c Credentials) {
  118. if c.AccessKeyID != "" {
  119. client.ServiceInfo.Credentials.AccessKeyID = c.AccessKeyID
  120. }
  121. if c.SecretAccessKey != "" {
  122. client.ServiceInfo.Credentials.SecretAccessKey = c.SecretAccessKey
  123. }
  124. if c.Region != "" {
  125. client.ServiceInfo.Credentials.Region = c.Region
  126. }
  127. if c.SessionToken != "" {
  128. client.ServiceInfo.Credentials.SessionToken = c.SessionToken
  129. }
  130. }
  131. func (client *Client) SetTimeout(timeout time.Duration) {
  132. if timeout > 0 {
  133. client.ServiceInfo.Timeout = timeout
  134. }
  135. }
  136. // GetSignUrl 获取签名字符串
  137. func (client *Client) GetSignUrl(api string, query url.Values) (string, error) {
  138. apiInfo := client.ApiInfoList[api]
  139. if apiInfo == nil {
  140. return "", errors.New("相关api不存在")
  141. }
  142. query = mergeQuery(query, apiInfo.Query)
  143. u := url.URL{
  144. Scheme: client.ServiceInfo.Scheme,
  145. Host: client.ServiceInfo.Host,
  146. Path: apiInfo.Path,
  147. RawQuery: query.Encode(),
  148. }
  149. req, err := http.NewRequest(strings.ToUpper(apiInfo.Method), u.String(), nil)
  150. if err != nil {
  151. return "", errors.New("构建request失败")
  152. }
  153. return client.ServiceInfo.Credentials.SignUrl(req), nil
  154. }
  155. // SignSts2 生成sts信息
  156. func (client *Client) SignSts2(inlinePolicy *Policy, expire time.Duration) (*SecurityToken2, error) {
  157. var err error
  158. sts := new(SecurityToken2)
  159. if sts.AccessKeyID, sts.SecretAccessKey, err = createTempAKSK(); err != nil {
  160. return nil, err
  161. }
  162. if expire < time.Minute {
  163. expire = time.Minute
  164. }
  165. now := time.Now()
  166. expireTime := now.Add(expire)
  167. sts.CurrentTime = now.Format(time.RFC3339)
  168. sts.ExpiredTime = expireTime.Format(time.RFC3339)
  169. innerToken, err := createInnerToken(client.ServiceInfo.Credentials, sts, inlinePolicy, expireTime.Unix())
  170. if err != nil {
  171. return nil, err
  172. }
  173. b, _ := json.Marshal(innerToken)
  174. sts.SessionToken = "STS2" + base64.StdEncoding.EncodeToString(b)
  175. return sts, nil
  176. }
  177. // Query 发起Get的query请求
  178. func (client *Client) Query(api string, query url.Values) ([]byte, int, error) {
  179. return client.requestWithContentType(api, query, "", "")
  180. }
  181. // Json 发起Json的post请求
  182. func (client *Client) Json(api string, query url.Values, body string) ([]byte, int, error) {
  183. return client.requestWithContentType(api, query, body, "application/json")
  184. }
  185. // PostWithContentType 发起自定义 Content-Type 的 post 请求,Content-Type 不可以为空
  186. func (client *Client) PostWithContentType(api string, query url.Values, body string, ct string) ([]byte, int, error) {
  187. return client.requestWithContentType(api, query, body, ct)
  188. }
  189. func (client *Client) requestWithContentType(api string, query url.Values, body string, ct string) ([]byte, int, error) {
  190. apiInfo := client.ApiInfoList[api]
  191. if apiInfo == nil {
  192. return []byte(""), 500, errors.New("相关api不存在")
  193. }
  194. timeout := getTimeout(client.ServiceInfo.Timeout, apiInfo.Timeout)
  195. header := mergeHeader(client.ServiceInfo.Header, apiInfo.Header)
  196. query = mergeQuery(query, apiInfo.Query)
  197. u := url.URL{
  198. Scheme: client.ServiceInfo.Scheme,
  199. Host: client.ServiceInfo.Host,
  200. Path: apiInfo.Path,
  201. RawQuery: query.Encode(),
  202. }
  203. var requestBody io.Reader
  204. if body != "" {
  205. requestBody = strings.NewReader(body)
  206. }
  207. req, err := http.NewRequest(strings.ToUpper(apiInfo.Method), u.String(), requestBody)
  208. if err != nil {
  209. return []byte(""), 500, errors.New("构建request失败")
  210. }
  211. req.Header = header
  212. if ct != "" {
  213. req.Header.Set("Content-Type", ct)
  214. }
  215. return client.makeRequest(api, req, timeout)
  216. }
  217. // Post 发起Post请求
  218. func (client *Client) Post(api string, query url.Values, form url.Values) ([]byte, int, error) {
  219. apiInfo := client.ApiInfoList[api]
  220. form = mergeQuery(form, apiInfo.Form)
  221. return client.requestWithContentType(api, query, form.Encode(), "application/x-www-form-urlencoded")
  222. }
  223. func (client *Client) makeRequest(api string, req *http.Request, timeout time.Duration) ([]byte, int, error) {
  224. req = client.ServiceInfo.Credentials.Sign(req)
  225. ctx, cancel := context.WithTimeout(context.Background(), timeout)
  226. defer cancel()
  227. req = req.WithContext(ctx)
  228. resp, err := client.Client.Do(req)
  229. if err != nil {
  230. return []byte(""), 500, err
  231. }
  232. defer resp.Body.Close()
  233. body, err := ioutil.ReadAll(resp.Body)
  234. if err != nil {
  235. return []byte(""), resp.StatusCode, err
  236. }
  237. if resp.StatusCode < 200 || resp.StatusCode > 299 {
  238. return body, resp.StatusCode, fmt.Errorf("api %s http code %d body %s", api, resp.StatusCode, string(body))
  239. }
  240. return body, resp.StatusCode, nil
  241. }