// Package lzx implements a decompressor for the WIM variant of the
  2. // LZX compression algorithm.
  3. //
  4. // The LZX algorithm is an earlier variant of LZX DELTA, which is documented
  5. // at https://msdn.microsoft.com/en-us/library/cc483133(v=exchg.80).aspx.
  6. package lzx
  7. import (
  8. "bytes"
  9. "encoding/binary"
  10. "errors"
  11. "io"
  12. )
  13. const (
  14. maincodecount = 496
  15. maincodesplit = 256
  16. lencodecount = 249
  17. lenshift = 9
  18. codemask = 0x1ff
  19. tablebits = 9
  20. tablesize = 1 << tablebits
  21. maxBlockSize = 32768
  22. windowSize = 32768
  23. maxTreePathLen = 16
  24. e8filesize = 12000000
  25. maxe8offset = 0x3fffffff
  26. verbatimBlock = 1
  27. alignedOffsetBlock = 2
  28. uncompressedBlock = 3
  29. )
  30. var footerBits = [...]byte{
  31. 0, 0, 0, 0, 1, 1, 2, 2,
  32. 3, 3, 4, 4, 5, 5, 6, 6,
  33. 7, 7, 8, 8, 9, 9, 10, 10,
  34. 11, 11, 12, 12, 13, 13, 14,
  35. }
  36. var basePosition = [...]uint16{
  37. 0, 1, 2, 3, 4, 6, 8, 12,
  38. 16, 24, 32, 48, 64, 96, 128, 192,
  39. 256, 384, 512, 768, 1024, 1536, 2048, 3072,
  40. 4096, 6144, 8192, 12288, 16384, 24576, 32768,
  41. }
  42. var (
  43. errCorrupt = errors.New("LZX data corrupt")
  44. )
  45. // Reader is an interface used by the decompressor to access
  46. // the input stream. If the provided io.Reader does not implement
  47. // Reader, then a bufio.Reader is used.
  48. type Reader interface {
  49. io.Reader
  50. io.ByteReader
  51. }
  52. type decompressor struct {
  53. r io.Reader
  54. err error
  55. unaligned bool
  56. nbits byte
  57. c uint32
  58. lru [3]uint16
  59. uncompressed int
  60. windowReader *bytes.Reader
  61. mainlens [maincodecount]byte
  62. lenlens [lencodecount]byte
  63. window [windowSize]byte
  64. b []byte
  65. bv int
  66. bo int
  67. }
  68. //go:noinline
  69. func (f *decompressor) fail(err error) {
  70. if f.err == nil {
  71. f.err = err
  72. }
  73. f.bo = 0
  74. f.bv = 0
  75. }
  76. func (f *decompressor) ensureAtLeast(n int) error {
  77. if f.bv-f.bo >= n {
  78. return nil
  79. }
  80. if f.err != nil {
  81. return f.err
  82. }
  83. if f.bv != f.bo {
  84. copy(f.b[:f.bv-f.bo], f.b[f.bo:f.bv])
  85. }
  86. n, err := io.ReadAtLeast(f.r, f.b[f.bv-f.bo:], n)
  87. if err != nil {
  88. if err == io.EOF { //nolint:errorlint
  89. err = io.ErrUnexpectedEOF
  90. } else {
  91. f.fail(err)
  92. }
  93. return err
  94. }
  95. f.bv = f.bv - f.bo + n
  96. f.bo = 0
  97. return nil
  98. }
  99. // feed retrieves another 16-bit word from the stream and consumes
  100. // it into f.c. It returns false if there are no more bytes available.
  101. // Otherwise, on error, it sets f.err.
  102. func (f *decompressor) feed() bool {
  103. err := f.ensureAtLeast(2)
  104. if err == io.ErrUnexpectedEOF { //nolint:errorlint // returns io.ErrUnexpectedEOF by contract
  105. return false
  106. }
  107. f.c |= (uint32(f.b[f.bo+1])<<8 | uint32(f.b[f.bo])) << (16 - f.nbits)
  108. f.nbits += 16
  109. f.bo += 2
  110. return true
  111. }
  112. // getBits retrieves the next n bits from the byte stream. n
  113. // must be <= 16. It sets f.err on error.
  114. func (f *decompressor) getBits(n byte) uint16 {
  115. if f.nbits < n {
  116. if !f.feed() {
  117. f.fail(io.ErrUnexpectedEOF)
  118. }
  119. }
  120. c := uint16(f.c >> (32 - n))
  121. f.c <<= n
  122. f.nbits -= n
  123. return c
  124. }
  125. type huffman struct {
  126. extra [][]uint16
  127. maxbits byte
  128. table [tablesize]uint16
  129. }
  130. // buildTable builds a huffman decoding table from a slice of code lengths,
  131. // one per code, in order. Each code length must be <= maxTreePathLen.
  132. // See https://en.wikipedia.org/wiki/Canonical_Huffman_code.
  133. func buildTable(codelens []byte) *huffman {
  134. // Determine the number of codes of each length, and the
  135. // maximum length.
  136. var count [maxTreePathLen + 1]uint
  137. var max byte
  138. for _, cl := range codelens {
  139. count[cl]++
  140. if max < cl {
  141. max = cl
  142. }
  143. }
  144. if max == 0 {
  145. return &huffman{}
  146. }
  147. // Determine the first code of each length.
  148. var first [maxTreePathLen + 1]uint
  149. code := uint(0)
  150. for i := byte(1); i <= max; i++ {
  151. code <<= 1
  152. first[i] = code
  153. code += count[i]
  154. }
  155. if code != 1<<max {
  156. return nil
  157. }
  158. // Build a table for code lookup. For code sizes < max,
  159. // put all possible suffixes for the code into the table, too.
  160. // For max > tablebits, split long codes into additional tables
  161. // of suffixes of max-tablebits length.
  162. h := &huffman{maxbits: max}
  163. if max > tablebits {
  164. core := first[tablebits+1] / 2 // Number of codes that fit without extra tables
  165. nextra := 1<<tablebits - core // Number of extra entries
  166. h.extra = make([][]uint16, nextra)
  167. for code := core; code < 1<<tablebits; code++ {
  168. h.table[code] = uint16(code - core)
  169. h.extra[code-core] = make([]uint16, 1<<(max-tablebits))
  170. }
  171. }
  172. for i, cl := range codelens {
  173. if cl != 0 {
  174. code := first[cl]
  175. first[cl]++
  176. v := uint16(cl)<<lenshift | uint16(i)
  177. if cl <= tablebits {
  178. extendedCode := code << (tablebits - cl)
  179. for j := uint(0); j < 1<<(tablebits-cl); j++ {
  180. h.table[extendedCode+j] = v
  181. }
  182. } else {
  183. prefix := code >> (cl - tablebits)
  184. suffix := code & (1<<(cl-tablebits) - 1)
  185. extendedCode := suffix << (max - cl)
  186. for j := uint(0); j < 1<<(max-cl); j++ {
  187. h.extra[h.table[prefix]][extendedCode+j] = v
  188. }
  189. }
  190. }
  191. }
  192. return h
  193. }
  194. // getCode retrieves the next code using the provided
  195. // huffman tree. It sets f.err on error.
  196. func (f *decompressor) getCode(h *huffman) uint16 {
  197. if h.maxbits > 0 {
  198. if f.nbits < maxTreePathLen {
  199. f.feed()
  200. }
  201. // For codes with length < tablebits, it doesn't matter
  202. // what the remainder of the bits used for table lookup
  203. // are, since entries with all possible suffixes were
  204. // added to the table.
  205. c := h.table[f.c>>(32-tablebits)]
  206. if !(c >= 1<<lenshift) {
  207. // The code is not in c.
  208. c = h.extra[c][f.c<<tablebits>>(32-(h.maxbits-tablebits))]
  209. }
  210. n := byte(c >> lenshift)
  211. if f.nbits >= n {
  212. // Only consume the length of the code, not the maximum
  213. // code length.
  214. f.c <<= n
  215. f.nbits -= n
  216. return c & codemask
  217. }
  218. f.fail(io.ErrUnexpectedEOF)
  219. return 0
  220. }
  221. // This is an empty tree. It should not be used.
  222. f.fail(errCorrupt)
  223. return 0
  224. }
  225. // readTree updates the huffman tree path lengths in lens by
  226. // reading and decoding lengths from the byte stream. lens
  227. // should be prepopulated with the previous block's tree's path
  228. // lengths. For the first block, lens should be zero.
  229. func (f *decompressor) readTree(lens []byte) error {
  230. // Get the pre-tree for the main tree.
  231. var pretreeLen [20]byte
  232. for i := range pretreeLen {
  233. pretreeLen[i] = byte(f.getBits(4))
  234. }
  235. if f.err != nil {
  236. return f.err
  237. }
  238. h := buildTable(pretreeLen[:])
  239. // The lengths are encoded as a series of huffman codes
  240. // encoded by the pre-tree.
  241. for i := 0; i < len(lens); {
  242. c := byte(f.getCode(h))
  243. if f.err != nil {
  244. return f.err
  245. }
  246. switch {
  247. case c <= 16: // length is delta from previous length
  248. lens[i] = (lens[i] + 17 - c) % 17
  249. i++
  250. case c == 17: // next n + 4 lengths are zero
  251. zeroes := int(f.getBits(4)) + 4
  252. if i+zeroes > len(lens) {
  253. return errCorrupt
  254. }
  255. for j := 0; j < zeroes; j++ {
  256. lens[i+j] = 0
  257. }
  258. i += zeroes
  259. case c == 18: // next n + 20 lengths are zero
  260. zeroes := int(f.getBits(5)) + 20
  261. if i+zeroes > len(lens) {
  262. return errCorrupt
  263. }
  264. for j := 0; j < zeroes; j++ {
  265. lens[i+j] = 0
  266. }
  267. i += zeroes
  268. case c == 19: // next n + 4 lengths all have the same value
  269. same := int(f.getBits(1)) + 4
  270. if i+same > len(lens) {
  271. return errCorrupt
  272. }
  273. c = byte(f.getCode(h))
  274. if c > 16 {
  275. return errCorrupt
  276. }
  277. l := (lens[i] + 17 - c) % 17
  278. for j := 0; j < same; j++ {
  279. lens[i+j] = l
  280. }
  281. i += same
  282. default:
  283. return errCorrupt
  284. }
  285. }
  286. if f.err != nil {
  287. return f.err
  288. }
  289. return nil
  290. }
  291. func (f *decompressor) readBlockHeader() (byte, uint16, error) {
  292. // If the previous block was an unaligned uncompressed block, restore
  293. // 2-byte alignment.
  294. if f.unaligned {
  295. err := f.ensureAtLeast(1)
  296. if err != nil {
  297. return 0, 0, err
  298. }
  299. f.bo++
  300. f.unaligned = false
  301. }
  302. blockType := f.getBits(3)
  303. full := f.getBits(1)
  304. var blockSize uint16
  305. if full != 0 {
  306. blockSize = maxBlockSize
  307. } else {
  308. blockSize = f.getBits(16)
  309. if blockSize > maxBlockSize {
  310. return 0, 0, errCorrupt
  311. }
  312. }
  313. if f.err != nil {
  314. return 0, 0, f.err
  315. }
  316. switch blockType {
  317. case verbatimBlock, alignedOffsetBlock:
  318. // The caller will read the huffman trees.
  319. case uncompressedBlock:
  320. if f.nbits > 16 {
  321. panic("impossible: more than one 16-bit word remains")
  322. }
  323. // Drop the remaining bits in the current 16-bit word
  324. // If there are no bits left, discard a full 16-bit word.
  325. n := f.nbits
  326. if n == 0 {
  327. n = 16
  328. }
  329. f.getBits(n)
  330. // Read the LRU values for the next block.
  331. err := f.ensureAtLeast(12)
  332. if err != nil {
  333. return 0, 0, err
  334. }
  335. f.lru[0] = uint16(binary.LittleEndian.Uint32(f.b[f.bo : f.bo+4]))
  336. f.lru[1] = uint16(binary.LittleEndian.Uint32(f.b[f.bo+4 : f.bo+8]))
  337. f.lru[2] = uint16(binary.LittleEndian.Uint32(f.b[f.bo+8 : f.bo+12]))
  338. f.bo += 12
  339. default:
  340. return 0, 0, errCorrupt
  341. }
  342. return byte(blockType), blockSize, nil
  343. }
  344. // readTrees reads the two or three huffman trees for the current block.
  345. // readAligned specifies whether to read the aligned offset tree.
  346. func (f *decompressor) readTrees(readAligned bool) (main *huffman, length *huffman, aligned *huffman, err error) {
  347. // Aligned offset blocks start with a small aligned offset tree.
  348. if readAligned {
  349. var alignedLen [8]byte
  350. for i := range alignedLen {
  351. alignedLen[i] = byte(f.getBits(3))
  352. }
  353. aligned = buildTable(alignedLen[:])
  354. if aligned == nil {
  355. return main, length, aligned, errors.New("corrupt")
  356. }
  357. }
  358. // The main tree is encoded in two parts.
  359. err = f.readTree(f.mainlens[:maincodesplit])
  360. if err != nil {
  361. return main, length, aligned, err
  362. }
  363. err = f.readTree(f.mainlens[maincodesplit:])
  364. if err != nil {
  365. return main, length, aligned, err
  366. }
  367. main = buildTable(f.mainlens[:])
  368. if main == nil {
  369. return main, length, aligned, errors.New("corrupt")
  370. }
  371. // The length tree is encoding in a single part.
  372. err = f.readTree(f.lenlens[:])
  373. if err != nil {
  374. return main, length, aligned, err
  375. }
  376. length = buildTable(f.lenlens[:])
  377. if length == nil {
  378. return main, length, aligned, errors.New("corrupt")
  379. }
  380. return main, length, aligned, f.err
  381. }
  382. // readCompressedBlock decodes a compressed block, writing into the window
  383. // starting at start and ending at end, and using the provided huffman trees.
  384. func (f *decompressor) readCompressedBlock(start, end uint16, hmain, hlength, haligned *huffman) (int, error) {
  385. i := start
  386. for i < end {
  387. main := f.getCode(hmain)
  388. if f.err != nil {
  389. break
  390. }
  391. if main < 256 {
  392. // Literal byte.
  393. f.window[i] = byte(main)
  394. i++
  395. continue
  396. }
  397. // This is a match backward in the window. Determine
  398. // the offset and dlength.
  399. matchlen := (main - 256) % 8
  400. slot := (main - 256) / 8
  401. // The length is either the low bits of the code,
  402. // or if this is 7, is encoded with the length tree.
  403. if matchlen == 7 {
  404. matchlen += f.getCode(hlength)
  405. }
  406. matchlen += 2
  407. var matchoffset uint16
  408. if slot < 3 { //nolint:nestif // todo: simplify nested complexity
  409. // The offset is one of the LRU values.
  410. matchoffset = f.lru[slot]
  411. f.lru[slot] = f.lru[0]
  412. f.lru[0] = matchoffset
  413. } else {
  414. // The offset is encoded as a combination of the
  415. // slot and more bits from the bit stream.
  416. offsetbits := footerBits[slot]
  417. var verbatimbits, alignedbits uint16
  418. if offsetbits > 0 {
  419. if haligned != nil && offsetbits >= 3 {
  420. // This is an aligned offset block. Combine
  421. // the bits written verbatim with the aligned
  422. // offset tree code.
  423. verbatimbits = f.getBits(offsetbits-3) * 8
  424. alignedbits = f.getCode(haligned)
  425. } else {
  426. // There are no aligned offset bits to read,
  427. // only verbatim bits.
  428. verbatimbits = f.getBits(offsetbits)
  429. alignedbits = 0
  430. }
  431. }
  432. matchoffset = basePosition[slot] + verbatimbits + alignedbits - 2
  433. // Update the LRU cache.
  434. f.lru[2] = f.lru[1]
  435. f.lru[1] = f.lru[0]
  436. f.lru[0] = matchoffset
  437. }
  438. if !(matchoffset <= i && matchlen <= end-i) {
  439. f.fail(errCorrupt)
  440. break
  441. }
  442. copyend := i + matchlen
  443. for ; i < copyend; i++ {
  444. f.window[i] = f.window[i-matchoffset]
  445. }
  446. }
  447. return int(i - start), f.err
  448. }
  449. // readBlock decodes the current block and returns the number of uncompressed bytes.
  450. func (f *decompressor) readBlock(start uint16) (int, error) {
  451. blockType, size, err := f.readBlockHeader()
  452. if err != nil {
  453. return 0, err
  454. }
  455. if blockType == uncompressedBlock {
  456. if size%2 == 1 {
  457. // Remember to realign the byte stream at the next block.
  458. f.unaligned = true
  459. }
  460. copied := 0
  461. if f.bo < f.bv {
  462. copied = int(size)
  463. s := int(start)
  464. if copied > f.bv-f.bo {
  465. copied = f.bv - f.bo
  466. }
  467. copy(f.window[s:s+copied], f.b[f.bo:f.bo+copied])
  468. f.bo += copied
  469. }
  470. n, err := io.ReadFull(f.r, f.window[start+uint16(copied):start+size])
  471. return copied + n, err
  472. }
  473. hmain, hlength, haligned, err := f.readTrees(blockType == alignedOffsetBlock)
  474. if err != nil {
  475. return 0, err
  476. }
  477. return f.readCompressedBlock(start, start+size, hmain, hlength, haligned)
  478. }
  479. // decodeE8 reverses the 0xe8 x86 instruction encoding that was performed
  480. // to the uncompressed data before it was compressed.
  481. func decodeE8(b []byte, off int64) {
  482. if off > maxe8offset || len(b) < 10 {
  483. return
  484. }
  485. for i := 0; i < len(b)-10; i++ {
  486. if b[i] == 0xe8 {
  487. currentPtr := int32(off) + int32(i)
  488. abs := int32(binary.LittleEndian.Uint32(b[i+1 : i+5]))
  489. if abs >= -currentPtr && abs < e8filesize {
  490. var rel int32
  491. if abs >= 0 {
  492. rel = abs - currentPtr
  493. } else {
  494. rel = abs + e8filesize
  495. }
  496. binary.LittleEndian.PutUint32(b[i+1:i+5], uint32(rel))
  497. }
  498. i += 4
  499. }
  500. }
  501. }
  502. func (f *decompressor) Read(b []byte) (int, error) {
  503. // Read and uncompress everything.
  504. if f.windowReader == nil {
  505. n := 0
  506. for n < f.uncompressed {
  507. k, err := f.readBlock(uint16(n))
  508. if err != nil {
  509. return 0, err
  510. }
  511. n += k
  512. }
  513. decodeE8(f.window[:f.uncompressed], 0)
  514. f.windowReader = bytes.NewReader(f.window[:f.uncompressed])
  515. }
  516. // Just read directly from the window.
  517. return f.windowReader.Read(b)
  518. }
  519. func (*decompressor) Close() error {
  520. return nil
  521. }
  522. // NewReader returns a new io.ReadCloser that decompresses a
  523. // WIM LZX stream until uncompressedSize bytes have been returned.
  524. func NewReader(r io.Reader, uncompressedSize int) (io.ReadCloser, error) {
  525. if uncompressedSize > windowSize {
  526. return nil, errors.New("uncompressed size is limited to 32KB")
  527. }
  528. f := &decompressor{
  529. lru: [3]uint16{1, 1, 1},
  530. uncompressed: uncompressedSize,
  531. b: make([]byte, 4096),
  532. r: r,
  533. }
  534. return f, nil
  535. }