build.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. // Package rest provides RESTful serialization of AWS requests and responses.
  2. package rest
  3. import (
  4. "bytes"
  5. "encoding/base64"
  6. "fmt"
  7. "io"
  8. "math"
  9. "net/http"
  10. "net/url"
  11. "path"
  12. "reflect"
  13. "strconv"
  14. "strings"
  15. "time"
  16. "github.com/aws/aws-sdk-go/aws"
  17. "github.com/aws/aws-sdk-go/aws/awserr"
  18. "github.com/aws/aws-sdk-go/aws/request"
  19. "github.com/aws/aws-sdk-go/private/protocol"
  20. )
  // String spellings used when serializing float64 values that have no
  // decimal representation.
  const (
  	// floatInf is emitted for positive infinity.
  	floatInf = "Infinity"
  	// floatNegInf is emitted for negative infinity.
  	floatNegInf = "-Infinity"
  	// floatNaN is emitted for not-a-number values.
  	floatNaN = "NaN"
  )
  var (
  	// noEscape is indexed by byte value; a true entry means that byte may be
  	// written to an AWS URL path without percent-encoding. It is filled in
  	// by init.
  	noEscape [256]bool

  	// errValueNotSet is the sentinel returned by convertType when the member
  	// has no value; callers treat it as "skip this member", not as a failure.
  	errValueNotSet = fmt.Errorf("value not set")

  	// byteSliceType is the reflect.Type of []byte, used as the conversion
  	// target for members tagged `marshal-as:"blob"`.
  	byteSliceType = reflect.TypeOf([]byte(nil))
  )
  30. func init() {
  31. for i := 0; i < len(noEscape); i++ {
  32. // AWS expects every character except these to be escaped
  33. noEscape[i] = (i >= 'A' && i <= 'Z') ||
  34. (i >= 'a' && i <= 'z') ||
  35. (i >= '0' && i <= '9') ||
  36. i == '-' ||
  37. i == '.' ||
  38. i == '_' ||
  39. i == '~'
  40. }
  41. }
  42. // BuildHandler is a named request handler for building rest protocol requests
  43. var BuildHandler = request.NamedHandler{Name: "awssdk.rest.Build", Fn: Build}
  44. // Build builds the REST component of a service request.
  45. func Build(r *request.Request) {
  46. if r.ParamsFilled() {
  47. v := reflect.ValueOf(r.Params).Elem()
  48. buildLocationElements(r, v, false)
  49. buildBody(r, v)
  50. }
  51. }
  52. // BuildAsGET builds the REST component of a service request with the ability to hoist
  53. // data from the body.
  54. func BuildAsGET(r *request.Request) {
  55. if r.ParamsFilled() {
  56. v := reflect.ValueOf(r.Params).Elem()
  57. buildLocationElements(r, v, true)
  58. buildBody(r, v)
  59. }
  60. }
  61. func buildLocationElements(r *request.Request, v reflect.Value, buildGETQuery bool) {
  62. query := r.HTTPRequest.URL.Query()
  63. // Setup the raw path to match the base path pattern. This is needed
  64. // so that when the path is mutated a custom escaped version can be
  65. // stored in RawPath that will be used by the Go client.
  66. r.HTTPRequest.URL.RawPath = r.HTTPRequest.URL.Path
  67. for i := 0; i < v.NumField(); i++ {
  68. m := v.Field(i)
  69. if n := v.Type().Field(i).Name; n[0:1] == strings.ToLower(n[0:1]) {
  70. continue
  71. }
  72. if m.IsValid() {
  73. field := v.Type().Field(i)
  74. name := field.Tag.Get("locationName")
  75. if name == "" {
  76. name = field.Name
  77. }
  78. if kind := m.Kind(); kind == reflect.Ptr {
  79. m = m.Elem()
  80. } else if kind == reflect.Interface {
  81. if !m.Elem().IsValid() {
  82. continue
  83. }
  84. }
  85. if !m.IsValid() {
  86. continue
  87. }
  88. if field.Tag.Get("ignore") != "" {
  89. continue
  90. }
  91. // Support the ability to customize values to be marshaled as a
  92. // blob even though they were modeled as a string. Required for S3
  93. // API operations like SSECustomerKey is modeled as string but
  94. // required to be base64 encoded in request.
  95. if field.Tag.Get("marshal-as") == "blob" {
  96. m = m.Convert(byteSliceType)
  97. }
  98. var err error
  99. switch field.Tag.Get("location") {
  100. case "headers": // header maps
  101. err = buildHeaderMap(&r.HTTPRequest.Header, m, field.Tag)
  102. case "header":
  103. err = buildHeader(&r.HTTPRequest.Header, m, name, field.Tag)
  104. case "uri":
  105. err = buildURI(r.HTTPRequest.URL, m, name, field.Tag)
  106. case "querystring":
  107. err = buildQueryString(query, m, name, field.Tag)
  108. default:
  109. if buildGETQuery {
  110. err = buildQueryString(query, m, name, field.Tag)
  111. }
  112. }
  113. r.Error = err
  114. }
  115. if r.Error != nil {
  116. return
  117. }
  118. }
  119. r.HTTPRequest.URL.RawQuery = query.Encode()
  120. if !aws.BoolValue(r.Config.DisableRestProtocolURICleaning) {
  121. cleanPath(r.HTTPRequest.URL)
  122. }
  123. }
  124. func buildBody(r *request.Request, v reflect.Value) {
  125. if field, ok := v.Type().FieldByName("_"); ok {
  126. if payloadName := field.Tag.Get("payload"); payloadName != "" {
  127. pfield, _ := v.Type().FieldByName(payloadName)
  128. if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
  129. payload := reflect.Indirect(v.FieldByName(payloadName))
  130. if payload.IsValid() && payload.Interface() != nil {
  131. switch reader := payload.Interface().(type) {
  132. case io.ReadSeeker:
  133. r.SetReaderBody(reader)
  134. case []byte:
  135. r.SetBufferBody(reader)
  136. case string:
  137. r.SetStringBody(reader)
  138. default:
  139. r.Error = awserr.New(request.ErrCodeSerialization,
  140. "failed to encode REST request",
  141. fmt.Errorf("unknown payload type %s", payload.Type()))
  142. }
  143. }
  144. }
  145. }
  146. }
  147. }
  148. func buildHeader(header *http.Header, v reflect.Value, name string, tag reflect.StructTag) error {
  149. str, err := convertType(v, tag)
  150. if err == errValueNotSet {
  151. return nil
  152. } else if err != nil {
  153. return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err)
  154. }
  155. name = strings.TrimSpace(name)
  156. str = strings.TrimSpace(str)
  157. header.Add(name, str)
  158. return nil
  159. }
  160. func buildHeaderMap(header *http.Header, v reflect.Value, tag reflect.StructTag) error {
  161. prefix := tag.Get("locationName")
  162. for _, key := range v.MapKeys() {
  163. str, err := convertType(v.MapIndex(key), tag)
  164. if err == errValueNotSet {
  165. continue
  166. } else if err != nil {
  167. return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err)
  168. }
  169. keyStr := strings.TrimSpace(key.String())
  170. str = strings.TrimSpace(str)
  171. header.Add(prefix+keyStr, str)
  172. }
  173. return nil
  174. }
  175. func buildURI(u *url.URL, v reflect.Value, name string, tag reflect.StructTag) error {
  176. value, err := convertType(v, tag)
  177. if err == errValueNotSet {
  178. return nil
  179. } else if err != nil {
  180. return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err)
  181. }
  182. u.Path = strings.Replace(u.Path, "{"+name+"}", value, -1)
  183. u.Path = strings.Replace(u.Path, "{"+name+"+}", value, -1)
  184. u.RawPath = strings.Replace(u.RawPath, "{"+name+"}", EscapePath(value, true), -1)
  185. u.RawPath = strings.Replace(u.RawPath, "{"+name+"+}", EscapePath(value, false), -1)
  186. return nil
  187. }
  188. func buildQueryString(query url.Values, v reflect.Value, name string, tag reflect.StructTag) error {
  189. switch value := v.Interface().(type) {
  190. case []*string:
  191. for _, item := range value {
  192. query.Add(name, *item)
  193. }
  194. case map[string]*string:
  195. for key, item := range value {
  196. query.Add(key, *item)
  197. }
  198. case map[string][]*string:
  199. for key, items := range value {
  200. for _, item := range items {
  201. query.Add(key, *item)
  202. }
  203. }
  204. default:
  205. str, err := convertType(v, tag)
  206. if err == errValueNotSet {
  207. return nil
  208. } else if err != nil {
  209. return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err)
  210. }
  211. query.Set(name, str)
  212. }
  213. return nil
  214. }
  // cleanPath normalizes u.Path and u.RawPath with path.Clean, which removes
  // duplicate '/' and resolves "." and ".." elements, and then restores the
  // trailing slash that path.Clean strips if the original Path ended in one.
  //
  // Empty values are left empty: path.Clean("") returns ".", which would turn
  // an absent path into a bogus "." segment, and an empty RawPath must stay
  // empty so that net/url falls back to the default encoding of Path.
  func cleanPath(u *url.URL) {
  	hasSlash := strings.HasSuffix(u.Path, "/")

  	if u.Path != "" {
  		u.Path = path.Clean(u.Path)
  	}
  	if u.RawPath != "" {
  		u.RawPath = path.Clean(u.RawPath)
  	}

  	if hasSlash && !strings.HasSuffix(u.Path, "/") {
  		u.Path += "/"
  		if u.RawPath != "" {
  			u.RawPath += "/"
  		}
  	}
  }
  225. // EscapePath escapes part of a URL path in Amazon style
  226. func EscapePath(path string, encodeSep bool) string {
  227. var buf bytes.Buffer
  228. for i := 0; i < len(path); i++ {
  229. c := path[i]
  230. if noEscape[c] || (c == '/' && !encodeSep) {
  231. buf.WriteByte(c)
  232. } else {
  233. fmt.Fprintf(&buf, "%%%02X", c)
  234. }
  235. }
  236. return buf.String()
  237. }
  238. func convertType(v reflect.Value, tag reflect.StructTag) (str string, err error) {
  239. v = reflect.Indirect(v)
  240. if !v.IsValid() {
  241. return "", errValueNotSet
  242. }
  243. switch value := v.Interface().(type) {
  244. case string:
  245. if tag.Get("suppressedJSONValue") == "true" && tag.Get("location") == "header" {
  246. value = base64.StdEncoding.EncodeToString([]byte(value))
  247. }
  248. str = value
  249. case []*string:
  250. if tag.Get("location") != "header" || tag.Get("enum") == "" {
  251. return "", fmt.Errorf("%T is only supported with location header and enum shapes", value)
  252. }
  253. buff := &bytes.Buffer{}
  254. for i, sv := range value {
  255. if sv == nil || len(*sv) == 0 {
  256. continue
  257. }
  258. if i != 0 {
  259. buff.WriteRune(',')
  260. }
  261. item := *sv
  262. if strings.Index(item, `,`) != -1 || strings.Index(item, `"`) != -1 {
  263. item = strconv.Quote(item)
  264. }
  265. buff.WriteString(item)
  266. }
  267. str = string(buff.Bytes())
  268. case []byte:
  269. str = base64.StdEncoding.EncodeToString(value)
  270. case bool:
  271. str = strconv.FormatBool(value)
  272. case int64:
  273. str = strconv.FormatInt(value, 10)
  274. case float64:
  275. switch {
  276. case math.IsNaN(value):
  277. str = floatNaN
  278. case math.IsInf(value, 1):
  279. str = floatInf
  280. case math.IsInf(value, -1):
  281. str = floatNegInf
  282. default:
  283. str = strconv.FormatFloat(value, 'f', -1, 64)
  284. }
  285. case time.Time:
  286. format := tag.Get("timestampFormat")
  287. if len(format) == 0 {
  288. format = protocol.RFC822TimeFormatName
  289. if tag.Get("location") == "querystring" {
  290. format = protocol.ISO8601TimeFormatName
  291. }
  292. }
  293. str = protocol.FormatTime(format, value)
  294. case aws.JSONValue:
  295. if len(value) == 0 {
  296. return "", errValueNotSet
  297. }
  298. escaping := protocol.NoEscape
  299. if tag.Get("location") == "header" {
  300. escaping = protocol.Base64Escape
  301. }
  302. str, err = protocol.EncodeJSONValue(value, escaping)
  303. if err != nil {
  304. return "", fmt.Errorf("unable to encode JSONValue, %v", err)
  305. }
  306. default:
  307. err := fmt.Errorf("unsupported value for param %v (%s)", v.Interface(), v.Type())
  308. return "", err
  309. }
  310. return str, nil
  311. }