framedec.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "encoding/binary"
  7. "encoding/hex"
  8. "errors"
  9. "io"
  10. "github.com/klauspost/compress/zstd/internal/xxhash"
  11. )
  12. type frameDec struct {
  13. o decoderOptions
  14. crc *xxhash.Digest
  15. WindowSize uint64
  16. // Frame history passed between blocks
  17. history history
  18. rawInput byteBuffer
  19. // Byte buffer that can be reused for small input blocks.
  20. bBuf byteBuf
  21. FrameContentSize uint64
  22. DictionaryID uint32
  23. HasCheckSum bool
  24. SingleSegment bool
  25. }
  // Window size limits shared by the encoder and decoder.
const (
	// MinWindowSize is the minimum Window Size, which is 1 KB.
	MinWindowSize = 1024

	// MaxWindowSize is the maximum encoder window size
	// and the default decoder maximum window size (512 MB).
	MaxWindowSize = 512 << 20
)

// Frame magic numbers, as they appear in the byte stream.
const (
	// frameMagic starts every regular zstd frame.
	frameMagic = "\x28\xb5\x2f\xfd"

	// skippableFrameMagic is the last three bytes of a skippable frame
	// magic; the first byte must have 0x5 in its upper nibble and any
	// value in the lower nibble.
	skippableFrameMagic = "\x2a\x4d\x18"
)
  37. func newFrameDec(o decoderOptions) *frameDec {
  38. if o.maxWindowSize > o.maxDecodedSize {
  39. o.maxWindowSize = o.maxDecodedSize
  40. }
  41. d := frameDec{
  42. o: o,
  43. }
  44. return &d
  45. }
  46. // reset will read the frame header and prepare for block decoding.
  47. // If nothing can be read from the input, io.EOF will be returned.
  48. // Any other error indicated that the stream contained data, but
  49. // there was a problem.
  50. func (d *frameDec) reset(br byteBuffer) error {
  51. d.HasCheckSum = false
  52. d.WindowSize = 0
  53. var signature [4]byte
  54. for {
  55. var err error
  56. // Check if we can read more...
  57. b, err := br.readSmall(1)
  58. switch err {
  59. case io.EOF, io.ErrUnexpectedEOF:
  60. return io.EOF
  61. default:
  62. return err
  63. case nil:
  64. signature[0] = b[0]
  65. }
  66. // Read the rest, don't allow io.ErrUnexpectedEOF
  67. b, err = br.readSmall(3)
  68. switch err {
  69. case io.EOF:
  70. return io.EOF
  71. default:
  72. return err
  73. case nil:
  74. copy(signature[1:], b)
  75. }
  76. if string(signature[1:4]) != skippableFrameMagic || signature[0]&0xf0 != 0x50 {
  77. if debugDecoder {
  78. println("Not skippable", hex.EncodeToString(signature[:]), hex.EncodeToString([]byte(skippableFrameMagic)))
  79. }
  80. // Break if not skippable frame.
  81. break
  82. }
  83. // Read size to skip
  84. b, err = br.readSmall(4)
  85. if err != nil {
  86. if debugDecoder {
  87. println("Reading Frame Size", err)
  88. }
  89. return err
  90. }
  91. n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
  92. println("Skipping frame with", n, "bytes.")
  93. err = br.skipN(int64(n))
  94. if err != nil {
  95. if debugDecoder {
  96. println("Reading discarded frame", err)
  97. }
  98. return err
  99. }
  100. }
  101. if string(signature[:]) != frameMagic {
  102. if debugDecoder {
  103. println("Got magic numbers: ", signature, "want:", []byte(frameMagic))
  104. }
  105. return ErrMagicMismatch
  106. }
  107. // Read Frame_Header_Descriptor
  108. fhd, err := br.readByte()
  109. if err != nil {
  110. if debugDecoder {
  111. println("Reading Frame_Header_Descriptor", err)
  112. }
  113. return err
  114. }
  115. d.SingleSegment = fhd&(1<<5) != 0
  116. if fhd&(1<<3) != 0 {
  117. return errors.New("reserved bit set on frame header")
  118. }
  119. // Read Window_Descriptor
  120. // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
  121. d.WindowSize = 0
  122. if !d.SingleSegment {
  123. wd, err := br.readByte()
  124. if err != nil {
  125. if debugDecoder {
  126. println("Reading Window_Descriptor", err)
  127. }
  128. return err
  129. }
  130. printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
  131. windowLog := 10 + (wd >> 3)
  132. windowBase := uint64(1) << windowLog
  133. windowAdd := (windowBase / 8) * uint64(wd&0x7)
  134. d.WindowSize = windowBase + windowAdd
  135. }
  136. // Read Dictionary_ID
  137. // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id
  138. d.DictionaryID = 0
  139. if size := fhd & 3; size != 0 {
  140. if size == 3 {
  141. size = 4
  142. }
  143. b, err := br.readSmall(int(size))
  144. if err != nil {
  145. println("Reading Dictionary_ID", err)
  146. return err
  147. }
  148. var id uint32
  149. switch len(b) {
  150. case 1:
  151. id = uint32(b[0])
  152. case 2:
  153. id = uint32(b[0]) | (uint32(b[1]) << 8)
  154. case 4:
  155. id = uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
  156. }
  157. if debugDecoder {
  158. println("Dict size", size, "ID:", id)
  159. }
  160. d.DictionaryID = id
  161. }
  162. // Read Frame_Content_Size
  163. // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_content_size
  164. var fcsSize int
  165. v := fhd >> 6
  166. switch v {
  167. case 0:
  168. if d.SingleSegment {
  169. fcsSize = 1
  170. }
  171. default:
  172. fcsSize = 1 << v
  173. }
  174. d.FrameContentSize = fcsUnknown
  175. if fcsSize > 0 {
  176. b, err := br.readSmall(fcsSize)
  177. if err != nil {
  178. println("Reading Frame content", err)
  179. return err
  180. }
  181. switch len(b) {
  182. case 1:
  183. d.FrameContentSize = uint64(b[0])
  184. case 2:
  185. // When FCS_Field_Size is 2, the offset of 256 is added.
  186. d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) + 256
  187. case 4:
  188. d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24)
  189. case 8:
  190. d1 := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
  191. d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24)
  192. d.FrameContentSize = uint64(d1) | (uint64(d2) << 32)
  193. }
  194. if debugDecoder {
  195. println("Read FCS:", d.FrameContentSize)
  196. }
  197. }
  198. // Move this to shared.
  199. d.HasCheckSum = fhd&(1<<2) != 0
  200. if d.HasCheckSum {
  201. if d.crc == nil {
  202. d.crc = xxhash.New()
  203. }
  204. d.crc.Reset()
  205. }
  206. if d.WindowSize > d.o.maxWindowSize {
  207. if debugDecoder {
  208. printf("window size %d > max %d\n", d.WindowSize, d.o.maxWindowSize)
  209. }
  210. return ErrWindowSizeExceeded
  211. }
  212. if d.WindowSize == 0 && d.SingleSegment {
  213. // We may not need window in this case.
  214. d.WindowSize = d.FrameContentSize
  215. if d.WindowSize < MinWindowSize {
  216. d.WindowSize = MinWindowSize
  217. }
  218. if d.WindowSize > d.o.maxDecodedSize {
  219. if debugDecoder {
  220. printf("window size %d > max %d\n", d.WindowSize, d.o.maxWindowSize)
  221. }
  222. return ErrDecoderSizeExceeded
  223. }
  224. }
  225. // The minimum Window_Size is 1 KB.
  226. if d.WindowSize < MinWindowSize {
  227. if debugDecoder {
  228. println("got window size: ", d.WindowSize)
  229. }
  230. return ErrWindowSizeTooSmall
  231. }
  232. d.history.windowSize = int(d.WindowSize)
  233. if !d.o.lowMem || d.history.windowSize < maxBlockSize {
  234. // Alloc 2x window size if not low-mem, or window size below 2MB.
  235. d.history.allocFrameBuffer = d.history.windowSize * 2
  236. } else {
  237. if d.o.lowMem {
  238. // Alloc with 1MB extra.
  239. d.history.allocFrameBuffer = d.history.windowSize + maxBlockSize/2
  240. } else {
  241. // Alloc with 2MB extra.
  242. d.history.allocFrameBuffer = d.history.windowSize + maxBlockSize
  243. }
  244. }
  245. if debugDecoder {
  246. println("Frame: Dict:", d.DictionaryID, "FrameContentSize:", d.FrameContentSize, "singleseg:", d.SingleSegment, "window:", d.WindowSize, "crc:", d.HasCheckSum)
  247. }
  248. // history contains input - maybe we do something
  249. d.rawInput = br
  250. return nil
  251. }
  252. // next will start decoding the next block from stream.
  253. func (d *frameDec) next(block *blockDec) error {
  254. if debugDecoder {
  255. println("decoding new block")
  256. }
  257. err := block.reset(d.rawInput, d.WindowSize)
  258. if err != nil {
  259. println("block error:", err)
  260. // Signal the frame decoder we have a problem.
  261. block.sendErr(err)
  262. return err
  263. }
  264. return nil
  265. }
  266. // checkCRC will check the checksum if the frame has one.
  267. // Will return ErrCRCMismatch if crc check failed, otherwise nil.
  268. func (d *frameDec) checkCRC() error {
  269. if !d.HasCheckSum {
  270. return nil
  271. }
  272. // We can overwrite upper tmp now
  273. buf, err := d.rawInput.readSmall(4)
  274. if err != nil {
  275. println("CRC missing?", err)
  276. return err
  277. }
  278. if d.o.ignoreChecksum {
  279. return nil
  280. }
  281. want := binary.LittleEndian.Uint32(buf[:4])
  282. got := uint32(d.crc.Sum64())
  283. if got != want {
  284. if debugDecoder {
  285. printf("CRC check failed: got %08x, want %08x\n", got, want)
  286. }
  287. return ErrCRCMismatch
  288. }
  289. if debugDecoder {
  290. printf("CRC ok %08x\n", got)
  291. }
  292. return nil
  293. }
  294. // consumeCRC reads the checksum data if the frame has one.
  295. func (d *frameDec) consumeCRC() error {
  296. if d.HasCheckSum {
  297. _, err := d.rawInput.readSmall(4)
  298. if err != nil {
  299. println("CRC missing?", err)
  300. return err
  301. }
  302. }
  303. return nil
  304. }
  305. // runDecoder will run the decoder for the remainder of the frame.
  306. func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
  307. saved := d.history.b
  308. // We use the history for output to avoid copying it.
  309. d.history.b = dst
  310. d.history.ignoreBuffer = len(dst)
  311. // Store input length, so we only check new data.
  312. crcStart := len(dst)
  313. d.history.decoders.maxSyncLen = 0
  314. if d.o.limitToCap {
  315. d.history.decoders.maxSyncLen = uint64(cap(dst) - len(dst))
  316. }
  317. if d.FrameContentSize != fcsUnknown {
  318. if !d.o.limitToCap || d.FrameContentSize+uint64(len(dst)) < d.history.decoders.maxSyncLen {
  319. d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst))
  320. }
  321. if d.history.decoders.maxSyncLen > d.o.maxDecodedSize {
  322. if debugDecoder {
  323. println("maxSyncLen:", d.history.decoders.maxSyncLen, "> maxDecodedSize:", d.o.maxDecodedSize)
  324. }
  325. return dst, ErrDecoderSizeExceeded
  326. }
  327. if debugDecoder {
  328. println("maxSyncLen:", d.history.decoders.maxSyncLen)
  329. }
  330. if !d.o.limitToCap && uint64(cap(dst)) < d.history.decoders.maxSyncLen {
  331. // Alloc for output
  332. dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc)
  333. copy(dst2, dst)
  334. dst = dst2
  335. }
  336. }
  337. var err error
  338. for {
  339. err = dec.reset(d.rawInput, d.WindowSize)
  340. if err != nil {
  341. break
  342. }
  343. if debugDecoder {
  344. println("next block:", dec)
  345. }
  346. err = dec.decodeBuf(&d.history)
  347. if err != nil {
  348. break
  349. }
  350. if uint64(len(d.history.b)-crcStart) > d.o.maxDecodedSize {
  351. println("runDecoder: maxDecodedSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.o.maxDecodedSize)
  352. err = ErrDecoderSizeExceeded
  353. break
  354. }
  355. if d.o.limitToCap && len(d.history.b) > cap(dst) {
  356. println("runDecoder: cap exceeded", uint64(len(d.history.b)), ">", cap(dst))
  357. err = ErrDecoderSizeExceeded
  358. break
  359. }
  360. if uint64(len(d.history.b)-crcStart) > d.FrameContentSize {
  361. println("runDecoder: FrameContentSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.FrameContentSize)
  362. err = ErrFrameSizeExceeded
  363. break
  364. }
  365. if dec.Last {
  366. break
  367. }
  368. if debugDecoder {
  369. println("runDecoder: FrameContentSize", uint64(len(d.history.b)-crcStart), "<=", d.FrameContentSize)
  370. }
  371. }
  372. dst = d.history.b
  373. if err == nil {
  374. if d.FrameContentSize != fcsUnknown && uint64(len(d.history.b)-crcStart) != d.FrameContentSize {
  375. err = ErrFrameSizeMismatch
  376. } else if d.HasCheckSum {
  377. if d.o.ignoreChecksum {
  378. err = d.consumeCRC()
  379. } else {
  380. var n int
  381. n, err = d.crc.Write(dst[crcStart:])
  382. if err == nil {
  383. if n != len(dst)-crcStart {
  384. err = io.ErrShortWrite
  385. } else {
  386. err = d.checkCRC()
  387. }
  388. }
  389. }
  390. }
  391. }
  392. d.history.b = saved
  393. return dst, err
  394. }