| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306 |
- package cos
- import (
- "crypto/hmac"
- "crypto/sha1"
- "fmt"
- "hash"
- "net/http"
- "net/url"
- "sort"
- "strings"
- "sync"
- "time"
- )
const (
	// sha1SignAlgorithm is the q-sign-algorithm value emitted in the
	// Authorization header and the algorithm name passed to the HMAC helper.
	sha1SignAlgorithm = "sha1"

	// privateHeaderPrefix marks COS-specific headers; any header carrying it
	// is signed (see isSignHeader).
	privateHeaderPrefix = "x-cos-"

	// defaultAuthExpire is the signature lifetime used when
	// AuthorizationTransport.Expire is left at zero.
	defaultAuthExpire = time.Hour
)
// needSignHeaders lists the lower-case request headers that are included in
// the signature when present on a request. Headers carrying the x-cos- prefix
// are signed even if they are not listed here (see isSignHeader).
var needSignHeaders = map[string]bool{
	// General HTTP request headers.
	"host":                true,
	"range":               true,
	"cache-control":       true,
	"content-disposition": true,
	"content-encoding":    true,
	"content-length":      true,
	"content-md5":         true,
	"content-type":        true,
	"expect":              true,
	"expires":             true,
	"if-modified-since":   true,

	// Response-header overrides.
	"response-content-type":        true,
	"response-content-language":    true,
	"response-content-disposition": true,
	"response-content-encoding":    true,
	"response-cache-control":       true,
	"response-expires":             true,

	// CORS preflight headers.
	"origin":                         true,
	"access-control-request-method":  true,
	"access-control-request-headers": true,

	// COS-specific headers (also covered by the prefix rule).
	"x-cos-acl":                true,
	"x-cos-grant-read":         true,
	"x-cos-grant-write":        true,
	"x-cos-grant-full-control": true,
	"x-cos-content-sha1":       true,
	"x-cos-storage-class":      true,
	"x-cos-object-type":        true,
}
- func safeURLEncode(s string) string {
- s = encodeURIComponent(s)
- s = strings.Replace(s, "!", "%21", -1)
- s = strings.Replace(s, "'", "%27", -1)
- s = strings.Replace(s, "(", "%28", -1)
- s = strings.Replace(s, ")", "%29", -1)
- s = strings.Replace(s, "*", "%2A", -1)
- return s
- }
- type valuesSignMap map[string][]string
- func (vs valuesSignMap) Add(key, value string) {
- key = strings.ToLower(key)
- vs[key] = append(vs[key], value)
- }
- func (vs valuesSignMap) Encode() string {
- var keys []string
- for k := range vs {
- keys = append(keys, k)
- }
- sort.Strings(keys)
- var pairs []string
- for _, k := range keys {
- items := vs[k]
- sort.Strings(items)
- for _, val := range items {
- pairs = append(
- pairs,
- fmt.Sprintf("%s=%s", safeURLEncode(k), safeURLEncode(val)))
- }
- }
- return strings.Join(pairs, "&")
- }
// AuthTime holds the two validity windows carried by a signature: the
// q-sign-time window (SignStartTime..SignEndTime) and the q-key-time window
// (KeyStartTime..KeyEndTime). Only whole Unix seconds are used when the
// windows are formatted.
type AuthTime struct {
	SignStartTime time.Time
	SignEndTime   time.Time
	KeyStartTime  time.Time
	KeyEndTime    time.Time
}

// NewAuthTime returns an AuthTime whose sign and key windows both start at
// the current time and end expire later.
func NewAuthTime(expire time.Duration) *AuthTime {
	now := time.Now()
	deadline := now.Add(expire)
	return &AuthTime{
		SignStartTime: now,
		SignEndTime:   deadline,
		KeyStartTime:  now,
		KeyEndTime:    deadline,
	}
}

// signString renders the sign window as "<start>;<end>" in Unix seconds,
// the value used for q-sign-time.
func (a *AuthTime) signString() string {
	start, end := a.SignStartTime.Unix(), a.SignEndTime.Unix()
	return fmt.Sprintf("%d;%d", start, end)
}

// keyString renders the key window as "<start>;<end>" in Unix seconds,
// the value used for q-key-time.
func (a *AuthTime) keyString() string {
	start, end := a.KeyStartTime.Unix(), a.KeyEndTime.Unix()
	return fmt.Sprintf("%d;%d", start, end)
}
- // newAuthorization 通过一系列步骤生成最终需要的 Authorization 字符串
- func newAuthorization(secretID, secretKey string, req *http.Request, authTime *AuthTime) string {
- signTime := authTime.signString()
- keyTime := authTime.keyString()
- signKey := calSignKey(secretKey, keyTime)
- req.Header.Set("Host", req.Host)
- formatHeaders := *new(string)
- signedHeaderList := *new([]string)
- formatHeaders, signedHeaderList = genFormatHeaders(req.Header)
- formatParameters, signedParameterList := genFormatParameters(req.URL.Query())
- formatString := genFormatString(req.Method, *req.URL, formatParameters, formatHeaders)
- stringToSign := calStringToSign(sha1SignAlgorithm, keyTime, formatString)
- signature := calSignature(signKey, stringToSign)
- return genAuthorization(
- secretID, signTime, keyTime, signature, signedHeaderList,
- signedParameterList,
- )
- }
- // AddAuthorizationHeader 给 req 增加签名信息
- func AddAuthorizationHeader(secretID, secretKey string, sessionToken string, req *http.Request, authTime *AuthTime) {
- if secretID == "" {
- return
- }
- auth := newAuthorization(secretID, secretKey, req,
- authTime,
- )
- if len(sessionToken) > 0 {
- req.Header.Set("x-cos-security-token", sessionToken)
- }
- req.Header.Set("Authorization", auth)
- }
- // calSignKey 计算 SignKey
- func calSignKey(secretKey, keyTime string) string {
- digest := calHMACDigest(secretKey, keyTime, sha1SignAlgorithm)
- return fmt.Sprintf("%x", digest)
- }
// calStringToSign builds StringToSign as
// "<signAlgorithm>\n<signTime>\n<hex SHA-1 of formatString>\n".
func calStringToSign(signAlgorithm, signTime, formatString string) string {
	digest := sha1.Sum([]byte(formatString))
	hexDigest := fmt.Sprintf("%x", digest[:])
	return signAlgorithm + "\n" + signTime + "\n" + hexDigest + "\n"
}
- // calSignature 计算 Signature
- func calSignature(signKey, stringToSign string) string {
- digest := calHMACDigest(signKey, stringToSign, sha1SignAlgorithm)
- return fmt.Sprintf("%x", digest)
- }
- // genAuthorization 生成 Authorization
- func genAuthorization(secretID, signTime, keyTime, signature string, signedHeaderList, signedParameterList []string) string {
- return strings.Join([]string{
- "q-sign-algorithm=" + sha1SignAlgorithm,
- "q-ak=" + secretID,
- "q-sign-time=" + signTime,
- "q-key-time=" + keyTime,
- "q-header-list=" + strings.Join(signedHeaderList, ";"),
- "q-url-param-list=" + strings.Join(signedParameterList, ";"),
- "q-signature=" + signature,
- }, "&")
- }
// genFormatString builds FormatString: the lower-cased method, the URL path
// (uri.Path, i.e. the decoded form), the canonical parameters and the
// canonical headers, each terminated by "\n".
func genFormatString(method string, uri url.URL, formatParameters, formatHeaders string) string {
	var b strings.Builder
	for _, part := range []string{strings.ToLower(method), uri.Path, formatParameters, formatHeaders} {
		b.WriteString(part)
		b.WriteByte('\n')
	}
	return b.String()
}
- // genFormatParameters 生成 FormatParameters 和 SignedParameterList
- // instead of the url.Values{}
- func genFormatParameters(parameters url.Values) (formatParameters string, signedParameterList []string) {
- ps := valuesSignMap{}
- for key, values := range parameters {
- key = strings.ToLower(key)
- for _, value := range values {
- ps.Add(key, value)
- signedParameterList = append(signedParameterList, key)
- }
- }
- //formatParameters = strings.ToLower(ps.Encode())
- formatParameters = ps.Encode()
- sort.Strings(signedParameterList)
- return
- }
- // genFormatHeaders 生成 FormatHeaders 和 SignedHeaderList
- func genFormatHeaders(headers http.Header) (formatHeaders string, signedHeaderList []string) {
- hs := valuesSignMap{}
- for key, values := range headers {
- key = strings.ToLower(key)
- for _, value := range values {
- if isSignHeader(key) {
- hs.Add(key, value)
- signedHeaderList = append(signedHeaderList, key)
- }
- }
- }
- formatHeaders = hs.Encode()
- sort.Strings(signedHeaderList)
- return
- }
// calHMACDigest returns the raw HMAC of msg keyed with key.
//
// signMethod names the hash algorithm, but only SHA-1 is implemented: "sha1"
// and any unrecognised method both produce HMAC-SHA1.
func calHMACDigest(key, msg, signMethod string) []byte {
	var newHash func() hash.Hash = sha1.New
	mac := hmac.New(newHash, []byte(key))
	mac.Write([]byte(msg))
	return mac.Sum(nil)
}
- func isSignHeader(key string) bool {
- for k, v := range needSignHeaders {
- if key == k && v {
- return true
- }
- }
- return strings.HasPrefix(key, privateHeaderPrefix)
- }
- // AuthorizationTransport 给请求增加 Authorization header
- type AuthorizationTransport struct {
- SecretID string
- SecretKey string
- SessionToken string
- rwLocker sync.RWMutex
- // 签名多久过期
- Expire time.Duration
- Transport http.RoundTripper
- }
- // SetCredential update the SecretID(ak), SercretKey(sk), sessiontoken
- func (t *AuthorizationTransport) SetCredential(ak, sk, token string) {
- t.rwLocker.Lock()
- defer t.rwLocker.Unlock()
- t.SecretID = ak
- t.SecretKey = sk
- t.SessionToken = token
- }
- // GetCredential get the ak, sk, token
- func (t *AuthorizationTransport) GetCredential() (string, string, string) {
- t.rwLocker.RLock()
- defer t.rwLocker.RUnlock()
- return t.SecretID, t.SecretKey, t.SessionToken
- }
- // RoundTrip implements the RoundTripper interface.
- func (t *AuthorizationTransport) RoundTrip(req *http.Request) (*http.Response, error) {
- req = cloneRequest(req) // per RoundTrip contract
- if t.Expire == time.Duration(0) {
- t.Expire = defaultAuthExpire
- }
- ak, sk, token := t.GetCredential()
- // 增加 Authorization header
- authTime := NewAuthTime(t.Expire)
- AddAuthorizationHeader(ak, sk, token, req, authTime)
- resp, err := t.transport().RoundTrip(req)
- return resp, err
- }
- func (t *AuthorizationTransport) transport() http.RoundTripper {
- if t.Transport != nil {
- return t.Transport
- }
- return http.DefaultTransport
- }
|