unmarshal.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. package rest
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "math"
  9. "net/http"
  10. "reflect"
  11. "strconv"
  12. "strings"
  13. "time"
  14. "github.com/aws/aws-sdk-go/aws"
  15. "github.com/aws/aws-sdk-go/aws/awserr"
  16. "github.com/aws/aws-sdk-go/aws/request"
  17. awsStrings "github.com/aws/aws-sdk-go/internal/strings"
  18. "github.com/aws/aws-sdk-go/private/protocol"
  19. )
  20. // UnmarshalHandler is a named request handler for unmarshaling rest protocol requests
  21. var UnmarshalHandler = request.NamedHandler{Name: "awssdk.rest.Unmarshal", Fn: Unmarshal}
  22. // UnmarshalMetaHandler is a named request handler for unmarshaling rest protocol request metadata
  23. var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.rest.UnmarshalMeta", Fn: UnmarshalMeta}
  24. // Unmarshal unmarshals the REST component of a response in a REST service.
  25. func Unmarshal(r *request.Request) {
  26. if r.DataFilled() {
  27. v := reflect.Indirect(reflect.ValueOf(r.Data))
  28. if err := unmarshalBody(r, v); err != nil {
  29. r.Error = err
  30. }
  31. }
  32. }
  33. // UnmarshalMeta unmarshals the REST metadata of a response in a REST service
  34. func UnmarshalMeta(r *request.Request) {
  35. r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid")
  36. if r.RequestID == "" {
  37. // Alternative version of request id in the header
  38. r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id")
  39. }
  40. if r.DataFilled() {
  41. if err := UnmarshalResponse(r.HTTPResponse, r.Data, aws.BoolValue(r.Config.LowerCaseHeaderMaps)); err != nil {
  42. r.Error = err
  43. }
  44. }
  45. }
  46. // UnmarshalResponse attempts to unmarshal the REST response headers to
  47. // the data type passed in. The type must be a pointer. An error is returned
  48. // with any error unmarshaling the response into the target datatype.
  49. func UnmarshalResponse(resp *http.Response, data interface{}, lowerCaseHeaderMaps bool) error {
  50. v := reflect.Indirect(reflect.ValueOf(data))
  51. return unmarshalLocationElements(resp, v, lowerCaseHeaderMaps)
  52. }
  53. func unmarshalBody(r *request.Request, v reflect.Value) error {
  54. if field, ok := v.Type().FieldByName("_"); ok {
  55. if payloadName := field.Tag.Get("payload"); payloadName != "" {
  56. pfield, _ := v.Type().FieldByName(payloadName)
  57. if ptag := pfield.Tag.Get("type"); ptag != "" && ptag != "structure" {
  58. payload := v.FieldByName(payloadName)
  59. if payload.IsValid() {
  60. switch payload.Interface().(type) {
  61. case []byte:
  62. defer r.HTTPResponse.Body.Close()
  63. b, err := ioutil.ReadAll(r.HTTPResponse.Body)
  64. if err != nil {
  65. return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
  66. }
  67. payload.Set(reflect.ValueOf(b))
  68. case *string:
  69. defer r.HTTPResponse.Body.Close()
  70. b, err := ioutil.ReadAll(r.HTTPResponse.Body)
  71. if err != nil {
  72. return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
  73. }
  74. str := string(b)
  75. payload.Set(reflect.ValueOf(&str))
  76. default:
  77. switch payload.Type().String() {
  78. case "io.ReadCloser":
  79. payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
  80. case "io.ReadSeeker":
  81. b, err := ioutil.ReadAll(r.HTTPResponse.Body)
  82. if err != nil {
  83. return awserr.New(request.ErrCodeSerialization,
  84. "failed to read response body", err)
  85. }
  86. payload.Set(reflect.ValueOf(ioutil.NopCloser(bytes.NewReader(b))))
  87. default:
  88. io.Copy(ioutil.Discard, r.HTTPResponse.Body)
  89. r.HTTPResponse.Body.Close()
  90. return awserr.New(request.ErrCodeSerialization,
  91. "failed to decode REST response",
  92. fmt.Errorf("unknown payload type %s", payload.Type()))
  93. }
  94. }
  95. }
  96. }
  97. }
  98. }
  99. return nil
  100. }
  101. func unmarshalLocationElements(resp *http.Response, v reflect.Value, lowerCaseHeaderMaps bool) error {
  102. for i := 0; i < v.NumField(); i++ {
  103. m, field := v.Field(i), v.Type().Field(i)
  104. if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {
  105. continue
  106. }
  107. if m.IsValid() {
  108. name := field.Tag.Get("locationName")
  109. if name == "" {
  110. name = field.Name
  111. }
  112. switch field.Tag.Get("location") {
  113. case "statusCode":
  114. unmarshalStatusCode(m, resp.StatusCode)
  115. case "header":
  116. err := unmarshalHeader(m, resp.Header.Get(name), field.Tag)
  117. if err != nil {
  118. return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
  119. }
  120. case "headers":
  121. prefix := field.Tag.Get("locationName")
  122. err := unmarshalHeaderMap(m, resp.Header, prefix, lowerCaseHeaderMaps)
  123. if err != nil {
  124. return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
  125. }
  126. }
  127. }
  128. }
  129. return nil
  130. }
  131. func unmarshalStatusCode(v reflect.Value, statusCode int) {
  132. if !v.IsValid() {
  133. return
  134. }
  135. switch v.Interface().(type) {
  136. case *int64:
  137. s := int64(statusCode)
  138. v.Set(reflect.ValueOf(&s))
  139. }
  140. }
  141. func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string, normalize bool) error {
  142. if len(headers) == 0 {
  143. return nil
  144. }
  145. switch r.Interface().(type) {
  146. case map[string]*string: // we only support string map value types
  147. out := map[string]*string{}
  148. for k, v := range headers {
  149. if awsStrings.HasPrefixFold(k, prefix) {
  150. if normalize == true {
  151. k = strings.ToLower(k)
  152. } else {
  153. k = http.CanonicalHeaderKey(k)
  154. }
  155. out[k[len(prefix):]] = &v[0]
  156. }
  157. }
  158. if len(out) != 0 {
  159. r.Set(reflect.ValueOf(out))
  160. }
  161. }
  162. return nil
  163. }
  164. func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) error {
  165. switch tag.Get("type") {
  166. case "jsonvalue":
  167. if len(header) == 0 {
  168. return nil
  169. }
  170. case "blob":
  171. if len(header) == 0 {
  172. return nil
  173. }
  174. default:
  175. if !v.IsValid() || (header == "" && v.Elem().Kind() != reflect.String) {
  176. return nil
  177. }
  178. }
  179. switch v.Interface().(type) {
  180. case *string:
  181. if tag.Get("suppressedJSONValue") == "true" && tag.Get("location") == "header" {
  182. b, err := base64.StdEncoding.DecodeString(header)
  183. if err != nil {
  184. return fmt.Errorf("failed to decode JSONValue, %v", err)
  185. }
  186. header = string(b)
  187. }
  188. v.Set(reflect.ValueOf(&header))
  189. case []byte:
  190. b, err := base64.StdEncoding.DecodeString(header)
  191. if err != nil {
  192. return err
  193. }
  194. v.Set(reflect.ValueOf(b))
  195. case *bool:
  196. b, err := strconv.ParseBool(header)
  197. if err != nil {
  198. return err
  199. }
  200. v.Set(reflect.ValueOf(&b))
  201. case *int64:
  202. i, err := strconv.ParseInt(header, 10, 64)
  203. if err != nil {
  204. return err
  205. }
  206. v.Set(reflect.ValueOf(&i))
  207. case *float64:
  208. var f float64
  209. switch {
  210. case strings.EqualFold(header, floatNaN):
  211. f = math.NaN()
  212. case strings.EqualFold(header, floatInf):
  213. f = math.Inf(1)
  214. case strings.EqualFold(header, floatNegInf):
  215. f = math.Inf(-1)
  216. default:
  217. var err error
  218. f, err = strconv.ParseFloat(header, 64)
  219. if err != nil {
  220. return err
  221. }
  222. }
  223. v.Set(reflect.ValueOf(&f))
  224. case *time.Time:
  225. format := tag.Get("timestampFormat")
  226. if len(format) == 0 {
  227. format = protocol.RFC822TimeFormatName
  228. }
  229. t, err := protocol.ParseTime(format, header)
  230. if err != nil {
  231. return err
  232. }
  233. v.Set(reflect.ValueOf(&t))
  234. case aws.JSONValue:
  235. escaping := protocol.NoEscape
  236. if tag.Get("location") == "header" {
  237. escaping = protocol.Base64Escape
  238. }
  239. m, err := protocol.DecodeJSONValue(header, escaping)
  240. if err != nil {
  241. return err
  242. }
  243. v.Set(reflect.ValueOf(m))
  244. default:
  245. err := fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
  246. return err
  247. }
  248. return nil
  249. }