decode.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. package msgpack
  2. import (
  3. "bufio"
  4. "bytes"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "reflect"
  9. "time"
  10. "github.com/vmihailenco/msgpack/codes"
  11. )
  12. const bytesAllocLimit = 1024 * 1024 // 1mb
  13. type bufReader interface {
  14. io.Reader
  15. io.ByteScanner
  16. }
  17. func newBufReader(r io.Reader) bufReader {
  18. if br, ok := r.(bufReader); ok {
  19. return br
  20. }
  21. return bufio.NewReader(r)
  22. }
  23. func makeBuffer() []byte {
  24. return make([]byte, 0, 64)
  25. }
  26. // Unmarshal decodes the MessagePack-encoded data and stores the result
  27. // in the value pointed to by v.
  28. func Unmarshal(data []byte, v interface{}) error {
  29. return NewDecoder(bytes.NewReader(data)).Decode(v)
  30. }
  31. type Decoder struct {
  32. r io.Reader
  33. s io.ByteScanner
  34. buf []byte
  35. extLen int
  36. rec []byte // accumulates read data if not nil
  37. useLoose bool
  38. useJSONTag bool
  39. decodeMapFunc func(*Decoder) (interface{}, error)
  40. }
  41. // NewDecoder returns a new decoder that reads from r.
  42. //
  43. // The decoder introduces its own buffering and may read data from r
  44. // beyond the MessagePack values requested. Buffering can be disabled
  45. // by passing a reader that implements io.ByteScanner interface.
  46. func NewDecoder(r io.Reader) *Decoder {
  47. d := &Decoder{
  48. buf: makeBuffer(),
  49. }
  50. d.resetReader(r)
  51. return d
  52. }
  53. func (d *Decoder) SetDecodeMapFunc(fn func(*Decoder) (interface{}, error)) {
  54. d.decodeMapFunc = fn
  55. }
  56. // UseDecodeInterfaceLoose causes decoder to use DecodeInterfaceLoose
  57. // to decode msgpack value into Go interface{}.
  58. func (d *Decoder) UseDecodeInterfaceLoose(flag bool) *Decoder {
  59. d.useLoose = flag
  60. return d
  61. }
  62. // UseJSONTag causes the Decoder to use json struct tag as fallback option
  63. // if there is no msgpack tag.
  64. func (d *Decoder) UseJSONTag(v bool) *Decoder {
  65. d.useJSONTag = v
  66. return d
  67. }
  68. func (d *Decoder) Reset(r io.Reader) error {
  69. d.resetReader(r)
  70. return nil
  71. }
  72. func (d *Decoder) resetReader(r io.Reader) {
  73. reader := newBufReader(r)
  74. d.r = reader
  75. d.s = reader
  76. }
  77. func (d *Decoder) Decode(v interface{}) error {
  78. var err error
  79. switch v := v.(type) {
  80. case *string:
  81. if v != nil {
  82. *v, err = d.DecodeString()
  83. return err
  84. }
  85. case *[]byte:
  86. if v != nil {
  87. return d.decodeBytesPtr(v)
  88. }
  89. case *int:
  90. if v != nil {
  91. *v, err = d.DecodeInt()
  92. return err
  93. }
  94. case *int8:
  95. if v != nil {
  96. *v, err = d.DecodeInt8()
  97. return err
  98. }
  99. case *int16:
  100. if v != nil {
  101. *v, err = d.DecodeInt16()
  102. return err
  103. }
  104. case *int32:
  105. if v != nil {
  106. *v, err = d.DecodeInt32()
  107. return err
  108. }
  109. case *int64:
  110. if v != nil {
  111. *v, err = d.DecodeInt64()
  112. return err
  113. }
  114. case *uint:
  115. if v != nil {
  116. *v, err = d.DecodeUint()
  117. return err
  118. }
  119. case *uint8:
  120. if v != nil {
  121. *v, err = d.DecodeUint8()
  122. return err
  123. }
  124. case *uint16:
  125. if v != nil {
  126. *v, err = d.DecodeUint16()
  127. return err
  128. }
  129. case *uint32:
  130. if v != nil {
  131. *v, err = d.DecodeUint32()
  132. return err
  133. }
  134. case *uint64:
  135. if v != nil {
  136. *v, err = d.DecodeUint64()
  137. return err
  138. }
  139. case *bool:
  140. if v != nil {
  141. *v, err = d.DecodeBool()
  142. return err
  143. }
  144. case *float32:
  145. if v != nil {
  146. *v, err = d.DecodeFloat32()
  147. return err
  148. }
  149. case *float64:
  150. if v != nil {
  151. *v, err = d.DecodeFloat64()
  152. return err
  153. }
  154. case *[]string:
  155. return d.decodeStringSlicePtr(v)
  156. case *map[string]string:
  157. return d.decodeMapStringStringPtr(v)
  158. case *map[string]interface{}:
  159. return d.decodeMapStringInterfacePtr(v)
  160. case *time.Duration:
  161. if v != nil {
  162. vv, err := d.DecodeInt64()
  163. *v = time.Duration(vv)
  164. return err
  165. }
  166. case *time.Time:
  167. if v != nil {
  168. *v, err = d.DecodeTime()
  169. return err
  170. }
  171. }
  172. vv := reflect.ValueOf(v)
  173. if !vv.IsValid() {
  174. return errors.New("msgpack: Decode(nil)")
  175. }
  176. if vv.Kind() != reflect.Ptr {
  177. return fmt.Errorf("msgpack: Decode(nonsettable %T)", v)
  178. }
  179. vv = vv.Elem()
  180. if !vv.IsValid() {
  181. return fmt.Errorf("msgpack: Decode(nonsettable %T)", v)
  182. }
  183. return d.DecodeValue(vv)
  184. }
  185. func (d *Decoder) DecodeMulti(v ...interface{}) error {
  186. for _, vv := range v {
  187. if err := d.Decode(vv); err != nil {
  188. return err
  189. }
  190. }
  191. return nil
  192. }
  193. func (d *Decoder) decodeInterfaceCond() (interface{}, error) {
  194. if d.useLoose {
  195. return d.DecodeInterfaceLoose()
  196. }
  197. return d.DecodeInterface()
  198. }
  199. func (d *Decoder) DecodeValue(v reflect.Value) error {
  200. decode := getDecoder(v.Type())
  201. return decode(d, v)
  202. }
  203. func (d *Decoder) DecodeNil() error {
  204. c, err := d.readCode()
  205. if err != nil {
  206. return err
  207. }
  208. if c != codes.Nil {
  209. return fmt.Errorf("msgpack: invalid code=%x decoding nil", c)
  210. }
  211. return nil
  212. }
  213. func (d *Decoder) decodeNilValue(v reflect.Value) error {
  214. err := d.DecodeNil()
  215. if v.IsNil() {
  216. return err
  217. }
  218. if v.Kind() == reflect.Ptr {
  219. v = v.Elem()
  220. }
  221. v.Set(reflect.Zero(v.Type()))
  222. return err
  223. }
  224. func (d *Decoder) DecodeBool() (bool, error) {
  225. c, err := d.readCode()
  226. if err != nil {
  227. return false, err
  228. }
  229. return d.bool(c)
  230. }
  231. func (d *Decoder) bool(c codes.Code) (bool, error) {
  232. if c == codes.False {
  233. return false, nil
  234. }
  235. if c == codes.True {
  236. return true, nil
  237. }
  238. return false, fmt.Errorf("msgpack: invalid code=%x decoding bool", c)
  239. }
  240. // DecodeInterface decodes value into interface. It returns following types:
  241. // - nil,
  242. // - bool,
  243. // - int8, int16, int32, int64,
  244. // - uint8, uint16, uint32, uint64,
  245. // - float32 and float64,
  246. // - string,
  247. // - []byte,
  248. // - slices of any of the above,
  249. // - maps of any of the above.
  250. //
  251. // DecodeInterface should be used only when you don't know the type of value
  252. // you are decoding. For example, if you are decoding number it is better to use
  253. // DecodeInt64 for negative numbers and DecodeUint64 for positive numbers.
  254. func (d *Decoder) DecodeInterface() (interface{}, error) {
  255. c, err := d.readCode()
  256. if err != nil {
  257. return nil, err
  258. }
  259. if codes.IsFixedNum(c) {
  260. return int8(c), nil
  261. }
  262. if codes.IsFixedMap(c) {
  263. err = d.s.UnreadByte()
  264. if err != nil {
  265. return nil, err
  266. }
  267. return d.DecodeMap()
  268. }
  269. if codes.IsFixedArray(c) {
  270. return d.decodeSlice(c)
  271. }
  272. if codes.IsFixedString(c) {
  273. return d.string(c)
  274. }
  275. switch c {
  276. case codes.Nil:
  277. return nil, nil
  278. case codes.False, codes.True:
  279. return d.bool(c)
  280. case codes.Float:
  281. return d.float32(c)
  282. case codes.Double:
  283. return d.float64(c)
  284. case codes.Uint8:
  285. return d.uint8()
  286. case codes.Uint16:
  287. return d.uint16()
  288. case codes.Uint32:
  289. return d.uint32()
  290. case codes.Uint64:
  291. return d.uint64()
  292. case codes.Int8:
  293. return d.int8()
  294. case codes.Int16:
  295. return d.int16()
  296. case codes.Int32:
  297. return d.int32()
  298. case codes.Int64:
  299. return d.int64()
  300. case codes.Bin8, codes.Bin16, codes.Bin32:
  301. return d.bytes(c, nil)
  302. case codes.Str8, codes.Str16, codes.Str32:
  303. return d.string(c)
  304. case codes.Array16, codes.Array32:
  305. return d.decodeSlice(c)
  306. case codes.Map16, codes.Map32:
  307. err = d.s.UnreadByte()
  308. if err != nil {
  309. return nil, err
  310. }
  311. return d.DecodeMap()
  312. case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
  313. codes.Ext8, codes.Ext16, codes.Ext32:
  314. return d.extInterface(c)
  315. }
  316. return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
  317. }
  318. // DecodeInterfaceLoose is like DecodeInterface except that:
  319. // - int8, int16, and int32 are converted to int64,
  320. // - uint8, uint16, and uint32 are converted to uint64,
  321. // - float32 is converted to float64.
  322. func (d *Decoder) DecodeInterfaceLoose() (interface{}, error) {
  323. c, err := d.readCode()
  324. if err != nil {
  325. return nil, err
  326. }
  327. if codes.IsFixedNum(c) {
  328. return int64(c), nil
  329. }
  330. if codes.IsFixedMap(c) {
  331. err = d.s.UnreadByte()
  332. if err != nil {
  333. return nil, err
  334. }
  335. return d.DecodeMap()
  336. }
  337. if codes.IsFixedArray(c) {
  338. return d.decodeSlice(c)
  339. }
  340. if codes.IsFixedString(c) {
  341. return d.string(c)
  342. }
  343. switch c {
  344. case codes.Nil:
  345. return nil, nil
  346. case codes.False, codes.True:
  347. return d.bool(c)
  348. case codes.Float, codes.Double:
  349. return d.float64(c)
  350. case codes.Uint8, codes.Uint16, codes.Uint32, codes.Uint64:
  351. return d.uint(c)
  352. case codes.Int8, codes.Int16, codes.Int32, codes.Int64:
  353. return d.int(c)
  354. case codes.Bin8, codes.Bin16, codes.Bin32:
  355. return d.bytes(c, nil)
  356. case codes.Str8, codes.Str16, codes.Str32:
  357. return d.string(c)
  358. case codes.Array16, codes.Array32:
  359. return d.decodeSlice(c)
  360. case codes.Map16, codes.Map32:
  361. err = d.s.UnreadByte()
  362. if err != nil {
  363. return nil, err
  364. }
  365. return d.DecodeMap()
  366. case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
  367. codes.Ext8, codes.Ext16, codes.Ext32:
  368. return d.extInterface(c)
  369. }
  370. return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
  371. }
  372. // Skip skips next value.
  373. func (d *Decoder) Skip() error {
  374. c, err := d.readCode()
  375. if err != nil {
  376. return err
  377. }
  378. if codes.IsFixedNum(c) {
  379. return nil
  380. } else if codes.IsFixedMap(c) {
  381. return d.skipMap(c)
  382. } else if codes.IsFixedArray(c) {
  383. return d.skipSlice(c)
  384. } else if codes.IsFixedString(c) {
  385. return d.skipBytes(c)
  386. }
  387. switch c {
  388. case codes.Nil, codes.False, codes.True:
  389. return nil
  390. case codes.Uint8, codes.Int8:
  391. return d.skipN(1)
  392. case codes.Uint16, codes.Int16:
  393. return d.skipN(2)
  394. case codes.Uint32, codes.Int32, codes.Float:
  395. return d.skipN(4)
  396. case codes.Uint64, codes.Int64, codes.Double:
  397. return d.skipN(8)
  398. case codes.Bin8, codes.Bin16, codes.Bin32:
  399. return d.skipBytes(c)
  400. case codes.Str8, codes.Str16, codes.Str32:
  401. return d.skipBytes(c)
  402. case codes.Array16, codes.Array32:
  403. return d.skipSlice(c)
  404. case codes.Map16, codes.Map32:
  405. return d.skipMap(c)
  406. case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
  407. codes.Ext8, codes.Ext16, codes.Ext32:
  408. return d.skipExt(c)
  409. }
  410. return fmt.Errorf("msgpack: unknown code %x", c)
  411. }
  412. // PeekCode returns the next MessagePack code without advancing the reader.
  413. // Subpackage msgpack/codes contains list of available codes.
  414. func (d *Decoder) PeekCode() (codes.Code, error) {
  415. c, err := d.s.ReadByte()
  416. if err != nil {
  417. return 0, err
  418. }
  419. return codes.Code(c), d.s.UnreadByte()
  420. }
  421. func (d *Decoder) hasNilCode() bool {
  422. code, err := d.PeekCode()
  423. return err == nil && code == codes.Nil
  424. }
  425. func (d *Decoder) readCode() (codes.Code, error) {
  426. d.extLen = 0
  427. c, err := d.s.ReadByte()
  428. if err != nil {
  429. return 0, err
  430. }
  431. if d.rec != nil {
  432. d.rec = append(d.rec, c)
  433. }
  434. return codes.Code(c), nil
  435. }
  436. func (d *Decoder) readFull(b []byte) error {
  437. _, err := io.ReadFull(d.r, b)
  438. if err != nil {
  439. return err
  440. }
  441. if d.rec != nil {
  442. d.rec = append(d.rec, b...)
  443. }
  444. return nil
  445. }
  446. func (d *Decoder) readN(n int) ([]byte, error) {
  447. buf, err := readN(d.r, d.buf, n)
  448. if err != nil {
  449. return nil, err
  450. }
  451. d.buf = buf
  452. if d.rec != nil {
  453. d.rec = append(d.rec, buf...)
  454. }
  455. return buf, nil
  456. }
  457. func readN(r io.Reader, b []byte, n int) ([]byte, error) {
  458. if b == nil {
  459. if n == 0 {
  460. return make([]byte, 0), nil
  461. }
  462. if n <= bytesAllocLimit {
  463. b = make([]byte, n)
  464. } else {
  465. b = make([]byte, bytesAllocLimit)
  466. }
  467. }
  468. if n <= cap(b) {
  469. b = b[:n]
  470. _, err := io.ReadFull(r, b)
  471. return b, err
  472. }
  473. b = b[:cap(b)]
  474. var pos int
  475. for {
  476. alloc := n - len(b)
  477. if alloc > bytesAllocLimit {
  478. alloc = bytesAllocLimit
  479. }
  480. b = append(b, make([]byte, alloc)...)
  481. _, err := io.ReadFull(r, b[pos:])
  482. if err != nil {
  483. return nil, err
  484. }
  485. if len(b) == n {
  486. break
  487. }
  488. pos = len(b)
  489. }
  490. return b, nil
  491. }
  492. func min(a, b int) int {
  493. if a <= b {
  494. return a
  495. }
  496. return b
  497. }