encoder.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676
  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "crypto/rand"
  7. "fmt"
  8. "io"
  9. "math"
  10. rdebug "runtime/debug"
  11. "sync"
  12. "github.com/klauspost/compress/zstd/internal/xxhash"
  13. )
  14. // Encoder provides encoding to Zstandard.
  15. // An Encoder can be used for either compressing a stream via the
  16. // io.WriteCloser interface supported by the Encoder or as multiple independent
  17. // tasks via the EncodeAll function.
  18. // Smaller encodes are encouraged to use the EncodeAll function.
  19. // Use NewWriter to create a new instance.
  20. type Encoder struct {
  21. o encoderOptions
  22. encoders chan encoder
  23. state encoderState
  24. init sync.Once
  25. }
  26. type encoder interface {
  27. Encode(blk *blockEnc, src []byte)
  28. EncodeNoHist(blk *blockEnc, src []byte)
  29. Block() *blockEnc
  30. CRC() *xxhash.Digest
  31. AppendCRC([]byte) []byte
  32. WindowSize(size int64) int32
  33. UseBlock(*blockEnc)
  34. Reset(d *dict, singleBlock bool)
  35. }
  36. type encoderState struct {
  37. w io.Writer
  38. filling []byte
  39. current []byte
  40. previous []byte
  41. encoder encoder
  42. writing *blockEnc
  43. err error
  44. writeErr error
  45. nWritten int64
  46. nInput int64
  47. frameContentSize int64
  48. headerWritten bool
  49. eofWritten bool
  50. fullFrameWritten bool
  51. // This waitgroup indicates an encode is running.
  52. wg sync.WaitGroup
  53. // This waitgroup indicates we have a block encoding/writing.
  54. wWg sync.WaitGroup
  55. }
  56. // NewWriter will create a new Zstandard encoder.
  57. // If the encoder will be used for encoding blocks a nil writer can be used.
  58. func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
  59. initPredefined()
  60. var e Encoder
  61. e.o.setDefault()
  62. for _, o := range opts {
  63. err := o(&e.o)
  64. if err != nil {
  65. return nil, err
  66. }
  67. }
  68. if w != nil {
  69. e.Reset(w)
  70. }
  71. return &e, nil
  72. }
  73. func (e *Encoder) initialize() {
  74. if e.o.concurrent == 0 {
  75. e.o.setDefault()
  76. }
  77. e.encoders = make(chan encoder, e.o.concurrent)
  78. for i := 0; i < e.o.concurrent; i++ {
  79. enc := e.o.encoder()
  80. e.encoders <- enc
  81. }
  82. }
  83. // Reset will re-initialize the writer and new writes will encode to the supplied writer
  84. // as a new, independent stream.
  85. func (e *Encoder) Reset(w io.Writer) {
  86. s := &e.state
  87. s.wg.Wait()
  88. s.wWg.Wait()
  89. if cap(s.filling) == 0 {
  90. s.filling = make([]byte, 0, e.o.blockSize)
  91. }
  92. if e.o.concurrent > 1 {
  93. if cap(s.current) == 0 {
  94. s.current = make([]byte, 0, e.o.blockSize)
  95. }
  96. if cap(s.previous) == 0 {
  97. s.previous = make([]byte, 0, e.o.blockSize)
  98. }
  99. s.current = s.current[:0]
  100. s.previous = s.previous[:0]
  101. if s.writing == nil {
  102. s.writing = &blockEnc{lowMem: e.o.lowMem}
  103. s.writing.init()
  104. }
  105. s.writing.initNewEncode()
  106. }
  107. if s.encoder == nil {
  108. s.encoder = e.o.encoder()
  109. }
  110. s.filling = s.filling[:0]
  111. s.encoder.Reset(e.o.dict, false)
  112. s.headerWritten = false
  113. s.eofWritten = false
  114. s.fullFrameWritten = false
  115. s.w = w
  116. s.err = nil
  117. s.nWritten = 0
  118. s.nInput = 0
  119. s.writeErr = nil
  120. s.frameContentSize = 0
  121. }
  122. // ResetContentSize will reset and set a content size for the next stream.
  123. // If the bytes written does not match the size given an error will be returned
  124. // when calling Close().
  125. // This is removed when Reset is called.
  126. // Sizes <= 0 results in no content size set.
  127. func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
  128. e.Reset(w)
  129. if size >= 0 {
  130. e.state.frameContentSize = size
  131. }
  132. }
  133. // Write data to the encoder.
  134. // Input data will be buffered and as the buffer fills up
  135. // content will be compressed and written to the output.
  136. // When done writing, use Close to flush the remaining output
  137. // and write CRC if requested.
  138. func (e *Encoder) Write(p []byte) (n int, err error) {
  139. s := &e.state
  140. for len(p) > 0 {
  141. if len(p)+len(s.filling) < e.o.blockSize {
  142. if e.o.crc {
  143. _, _ = s.encoder.CRC().Write(p)
  144. }
  145. s.filling = append(s.filling, p...)
  146. return n + len(p), nil
  147. }
  148. add := p
  149. if len(p)+len(s.filling) > e.o.blockSize {
  150. add = add[:e.o.blockSize-len(s.filling)]
  151. }
  152. if e.o.crc {
  153. _, _ = s.encoder.CRC().Write(add)
  154. }
  155. s.filling = append(s.filling, add...)
  156. p = p[len(add):]
  157. n += len(add)
  158. if len(s.filling) < e.o.blockSize {
  159. return n, nil
  160. }
  161. err := e.nextBlock(false)
  162. if err != nil {
  163. return n, err
  164. }
  165. if debugAsserts && len(s.filling) > 0 {
  166. panic(len(s.filling))
  167. }
  168. }
  169. return n, nil
  170. }
  // nextBlock will synchronize and start compressing input in e.state.filling.
  // If an error has occurred during encoding it will be returned.
  //
  // final marks the last block of the stream. On the very first call it
  // decides how the frame starts:
  //   - an empty final call without WithZeroFrames writes nothing;
  //   - a final call with data encodes the whole frame at once via EncodeAll;
  //   - otherwise the frame header is written to w.
  //
  // With e.o.concurrent == 1 the block is encoded and written synchronously.
  // Otherwise the block is encoded in a goroutine (tracked by s.wg), which in
  // turn starts a goroutine (tracked by s.wWg) that entropy-codes and writes
  // the block. Errors from those goroutines surface through s.err and
  // s.writeErr on later calls.
  func (e *Encoder) nextBlock(final bool) error {
  	s := &e.state
  	// Wait for current block.
  	s.wg.Wait()
  	if s.err != nil {
  		return s.err
  	}
  	if len(s.filling) > e.o.blockSize {
  		return fmt.Errorf("block > maxStoreBlockSize")
  	}
  	if !s.headerWritten {
  		// Empty stream being closed: emit nothing at all unless zero-length
  		// frames were requested (e.o.fullZero).
  		if final && len(s.filling) == 0 && !e.o.fullZero {
  			s.headerWritten = true
  			s.fullFrameWritten = true
  			s.eofWritten = true
  			return nil
  		}
  		// If we have a single block encode, do a sync compression.
  		// EncodeAll produces the complete frame (header, block, checksum and
  		// padding), so Close has nothing left to append (fullFrameWritten).
  		if final && len(s.filling) > 0 {
  			s.current = e.EncodeAll(s.filling, s.current[:0])
  			var n2 int
  			n2, s.err = s.w.Write(s.current)
  			if s.err != nil {
  				return s.err
  			}
  			s.nWritten += int64(n2)
  			s.nInput += int64(len(s.filling))
  			s.current = s.current[:0]
  			s.filling = s.filling[:0]
  			s.headerWritten = true
  			s.fullFrameWritten = true
  			s.eofWritten = true
  			return nil
  		}

  		// Multi-block stream: write the frame header now. The content size
  		// comes from ResetContentSize (0 when unset).
  		var tmp [maxHeaderSize]byte
  		fh := frameHeader{
  			ContentSize:   uint64(s.frameContentSize),
  			WindowSize:    uint32(s.encoder.WindowSize(s.frameContentSize)),
  			SingleSegment: false,
  			Checksum:      e.o.crc,
  			DictID:        e.o.dict.ID(),
  		}

  		dst, err := fh.appendTo(tmp[:0])
  		if err != nil {
  			return err
  		}
  		s.headerWritten = true
  		// Ensure no background write is in flight before writing to w.
  		s.wWg.Wait()
  		var n2 int
  		n2, s.err = s.w.Write(dst)
  		if s.err != nil {
  			return s.err
  		}
  		s.nWritten += int64(n2)
  	}
  	if s.eofWritten {
  		// Ensure we only write it once.
  		final = false
  	}

  	if len(s.filling) == 0 {
  		// Final block, but no data.
  		// Terminate the frame with an empty raw block marked as last.
  		if final {
  			enc := s.encoder
  			blk := enc.Block()
  			blk.reset(nil)
  			blk.last = true
  			blk.encodeRaw(nil)
  			s.wWg.Wait()
  			_, s.err = s.w.Write(blk.output)
  			s.nWritten += int64(len(blk.output))
  			s.eofWritten = true
  		}
  		return s.err
  	}

  	// SYNC:
  	// Single-goroutine mode: encode and write the block before returning.
  	if e.o.concurrent == 1 {
  		src := s.filling
  		s.nInput += int64(len(s.filling))
  		if debugEncoder {
  			println("Adding sync block,", len(src), "bytes, final:", final)
  		}
  		enc := s.encoder
  		blk := enc.Block()
  		blk.reset(nil)
  		enc.Encode(blk, src)
  		blk.last = final
  		if final {
  			s.eofWritten = true
  		}

  		err := errIncompressible
  		// If we got the exact same number of literals as input,
  		// assume the literals cannot be compressed.
  		// (The shortcut only applies to full-size blocks.)
  		if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
  			err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
  		}
  		switch err {
  		case errIncompressible:
  			if debugEncoder {
  				println("Storing incompressible block as raw")
  			}
  			blk.encodeRaw(src)
  			// In fast mode, we do not transfer offsets, so we don't have to deal with changing them.
  			// NOTE(review): the original comment was truncated; confirm the intended wording.
  		case nil:
  		default:
  			s.err = err
  			return err
  		}
  		_, s.err = s.w.Write(blk.output)
  		s.nWritten += int64(len(blk.output))
  		s.filling = s.filling[:0]
  		return s.err
  	}

  	// Move blocks forward.
  	// The old filling becomes current and is encoded in the background.
  	// The old current becomes previous, presumably because its write may still
  	// be running. The old previous is recycled as the new (empty) filling.
  	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
  	s.nInput += int64(len(s.current))
  	s.wg.Add(1)
  	go func(src []byte) {
  		if debugEncoder {
  			println("Adding block,", len(src), "bytes, final:", final)
  		}
  		defer func() {
  			// Convert a panic into a sticky stream error instead of crashing.
  			if r := recover(); r != nil {
  				s.err = fmt.Errorf("panic while encoding: %v", r)
  				rdebug.PrintStack()
  			}
  			s.wg.Done()
  		}()
  		enc := s.encoder
  		blk := enc.Block()
  		enc.Encode(blk, src)
  		blk.last = final
  		if final {
  			s.eofWritten = true
  		}
  		// Wait for pending writes.
  		// Blocks must reach w in order, so the previous write has to finish.
  		s.wWg.Wait()
  		if s.writeErr != nil {
  			s.err = s.writeErr
  			return
  		}
  		// Transfer encoders from previous write block.
  		blk.swapEncoders(s.writing)
  		// Transfer recent offsets to next.
  		enc.UseBlock(s.writing)
  		s.writing = blk
  		s.wWg.Add(1)
  		go func() {
  			defer func() {
  				if r := recover(); r != nil {
  					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
  					rdebug.PrintStack()
  				}
  				s.wWg.Done()
  			}()
  			err := errIncompressible
  			// If we got the exact same number of literals as input,
  			// assume the literals cannot be compressed.
  			// (The shortcut only applies to full-size blocks.)
  			if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
  				err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
  			}
  			switch err {
  			case errIncompressible:
  				if debugEncoder {
  					println("Storing incompressible block as raw")
  				}
  				blk.encodeRaw(src)
  				// In fast mode, we do not transfer offsets, so we don't have to deal with changing them.
  				// NOTE(review): the original comment was truncated; confirm the intended wording.
  			case nil:
  			default:
  				s.writeErr = err
  				return
  			}
  			_, s.writeErr = s.w.Write(blk.output)
  			s.nWritten += int64(len(blk.output))
  		}()
  	}(s.current)
  	return nil
  }
  351. // ReadFrom reads data from r until EOF or error.
  352. // The return value n is the number of bytes read.
  353. // Any error except io.EOF encountered during the read is also returned.
  354. //
  355. // The Copy function uses ReaderFrom if available.
  356. func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
  357. if debugEncoder {
  358. println("Using ReadFrom")
  359. }
  360. // Flush any current writes.
  361. if len(e.state.filling) > 0 {
  362. if err := e.nextBlock(false); err != nil {
  363. return 0, err
  364. }
  365. }
  366. e.state.filling = e.state.filling[:e.o.blockSize]
  367. src := e.state.filling
  368. for {
  369. n2, err := r.Read(src)
  370. if e.o.crc {
  371. _, _ = e.state.encoder.CRC().Write(src[:n2])
  372. }
  373. // src is now the unfilled part...
  374. src = src[n2:]
  375. n += int64(n2)
  376. switch err {
  377. case io.EOF:
  378. e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
  379. if debugEncoder {
  380. println("ReadFrom: got EOF final block:", len(e.state.filling))
  381. }
  382. return n, nil
  383. case nil:
  384. default:
  385. if debugEncoder {
  386. println("ReadFrom: got error:", err)
  387. }
  388. e.state.err = err
  389. return n, err
  390. }
  391. if len(src) > 0 {
  392. if debugEncoder {
  393. println("ReadFrom: got space left in source:", len(src))
  394. }
  395. continue
  396. }
  397. err = e.nextBlock(false)
  398. if err != nil {
  399. return n, err
  400. }
  401. e.state.filling = e.state.filling[:e.o.blockSize]
  402. src = e.state.filling
  403. }
  404. }
  405. // Flush will send the currently written data to output
  406. // and block until everything has been written.
  407. // This should only be used on rare occasions where pushing the currently queued data is critical.
  408. func (e *Encoder) Flush() error {
  409. s := &e.state
  410. if len(s.filling) > 0 {
  411. err := e.nextBlock(false)
  412. if err != nil {
  413. return err
  414. }
  415. }
  416. s.wg.Wait()
  417. s.wWg.Wait()
  418. if s.err != nil {
  419. return s.err
  420. }
  421. return s.writeErr
  422. }
  423. // Close will flush the final output and close the stream.
  424. // The function will block until everything has been written.
  425. // The Encoder can still be re-used after calling this.
  426. func (e *Encoder) Close() error {
  427. s := &e.state
  428. if s.encoder == nil {
  429. return nil
  430. }
  431. err := e.nextBlock(true)
  432. if err != nil {
  433. return err
  434. }
  435. if s.frameContentSize > 0 {
  436. if s.nInput != s.frameContentSize {
  437. return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
  438. }
  439. }
  440. if e.state.fullFrameWritten {
  441. return s.err
  442. }
  443. s.wg.Wait()
  444. s.wWg.Wait()
  445. if s.err != nil {
  446. return s.err
  447. }
  448. if s.writeErr != nil {
  449. return s.writeErr
  450. }
  451. // Write CRC
  452. if e.o.crc && s.err == nil {
  453. // heap alloc.
  454. var tmp [4]byte
  455. _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
  456. s.nWritten += 4
  457. }
  458. // Add padding with content from crypto/rand.Reader
  459. if s.err == nil && e.o.pad > 0 {
  460. add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
  461. frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
  462. if err != nil {
  463. return err
  464. }
  465. _, s.err = s.w.Write(frame)
  466. }
  467. return s.err
  468. }
  // EncodeAll will encode all input in src and append it to dst.
  // This function can be called concurrently, but each call will only run on a single goroutine.
  // If empty input is given, nothing is returned, unless WithZeroFrames is specified.
  // Encoded blocks can be concatenated and the result will be the combined input stream.
  // Data compressed with EncodeAll can be decoded with the Decoder,
  // using either a stream or DecodeAll.
  //
  // Each call borrows one encoder from e.encoders (created lazily on first use)
  // and returns it when done. Internal encoding errors cause a panic.
  func (e *Encoder) EncodeAll(src, dst []byte) []byte {
  	if len(src) == 0 {
  		if e.o.fullZero {
  			// Add frame header.
  			fh := frameHeader{
  				ContentSize:   0,
  				WindowSize:    MinWindowSize,
  				SingleSegment: true,
  				// Adding a checksum would be a waste of space.
  				Checksum: false,
  				DictID:   0,
  			}
  			dst, _ = fh.appendTo(dst)

  			// Write raw block as last one only.
  			var blk blockHeader
  			blk.setSize(0)
  			blk.setType(blockTypeRaw)
  			blk.setLast(true)
  			dst = blk.appendTo(dst)
  		}
  		return dst
  	}
  	e.init.Do(e.initialize)
  	enc := <-e.encoders
  	defer func() {
  		// Release encoder reference to last block.
  		// If a non-single block is needed the encoder will reset again.
  		// Return enc to the pool so another call can borrow it.
  		e.encoders <- enc
  	}()
  	// Use single segments when above minimum window and below window size.
  	single := len(src) <= e.o.windowSize && len(src) > MinWindowSize
  	if e.o.single != nil {
  		single = *e.o.single
  	}
  	fh := frameHeader{
  		ContentSize:   uint64(len(src)),
  		WindowSize:    uint32(enc.WindowSize(int64(len(src)))),
  		SingleSegment: single,
  		Checksum:      e.o.crc,
  		DictID:        e.o.dict.ID(),
  	}

  	// If less than 1MB, allocate a buffer up front.
  	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
  		dst = make([]byte, 0, len(src))
  	}
  	dst, err := fh.appendTo(dst)
  	if err != nil {
  		panic(err)
  	}

  	// If we can do everything in one block, prefer that.
  	if len(src) <= e.o.blockSize {
  		enc.Reset(e.o.dict, true)
  		// Slightly faster with no history and everything in one block.
  		if e.o.crc {
  			_, _ = enc.CRC().Write(src)
  		}
  		blk := enc.Block()
  		blk.last = true
  		if e.o.dict == nil {
  			enc.EncodeNoHist(blk, src)
  		} else {
  			enc.Encode(blk, src)
  		}

  		// If we got the exact same number of literals as input,
  		// assume the literals cannot be compressed.
  		err := errIncompressible
  		// Save the block's own output buffer; it is redirected to dst below
  		// and restored afterwards.
  		oldout := blk.output
  		if len(blk.literals) != len(src) || len(src) != e.o.blockSize {
  			// Output directly to dst
  			blk.output = dst
  			err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
  		}

  		switch err {
  		case errIncompressible:
  			if debugEncoder {
  				println("Storing incompressible block as raw")
  			}
  			dst = blk.encodeRawTo(dst, src)
  		case nil:
  			dst = blk.output
  		default:
  			panic(err)
  		}
  		blk.output = oldout
  	} else {
  		enc.Reset(e.o.dict, false)
  		blk := enc.Block()
  		for len(src) > 0 {
  			todo := src
  			if len(todo) > e.o.blockSize {
  				todo = todo[:e.o.blockSize]
  			}
  			src = src[len(todo):]
  			if e.o.crc {
  				_, _ = enc.CRC().Write(todo)
  			}
  			// Save offset state so it can be restored below if the block ends
  			// up stored raw (presumably because raw blocks do not advance it).
  			blk.pushOffsets()
  			enc.Encode(blk, todo)
  			if len(src) == 0 {
  				blk.last = true
  			}
  			err := errIncompressible
  			// If we got the exact same number of literals as input,
  			// assume the literals cannot be compressed.
  			if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
  				err = blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
  			}

  			switch err {
  			case errIncompressible:
  				if debugEncoder {
  					println("Storing incompressible block as raw")
  				}
  				dst = blk.encodeRawTo(dst, todo)
  				blk.popOffsets()
  			case nil:
  				dst = append(dst, blk.output...)
  			default:
  				panic(err)
  			}
  			blk.reset(nil)
  		}
  	}
  	if e.o.crc {
  		dst = enc.AppendCRC(dst)
  	}
  	// Add padding with content from crypto/rand.Reader
  	if e.o.pad > 0 {
  		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
  		dst, err = skippableFrame(dst, add, rand.Reader)
  		if err != nil {
  			panic(err)
  		}
  	}
  	return dst
  }
  610. // MaxEncodedSize returns the expected maximum
  611. // size of an encoded block or stream.
  612. func (e *Encoder) MaxEncodedSize(size int) int {
  613. frameHeader := 4 + 2 // magic + frame header & window descriptor
  614. if e.o.dict != nil {
  615. frameHeader += 4
  616. }
  617. // Frame content size:
  618. if size < 256 {
  619. frameHeader++
  620. } else if size < 65536+256 {
  621. frameHeader += 2
  622. } else if size < math.MaxInt32 {
  623. frameHeader += 4
  624. } else {
  625. frameHeader += 8
  626. }
  627. // Final crc
  628. if e.o.crc {
  629. frameHeader += 4
  630. }
  631. // Max overhead is 3 bytes/block.
  632. // There cannot be 0 blocks.
  633. blocks := (size + e.o.blockSize) / e.o.blockSize
  634. // Combine, add padding.
  635. maxSz := frameHeader + 3*blocks + size
  636. if e.o.pad > 1 {
  637. maxSz += calcSkippableFrame(int64(maxSz), int64(e.o.pad))
  638. }
  639. return maxSz
  640. }