seqdec_amd64.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. //go:build amd64 && !appengine && !noasm && gc
  2. // +build amd64,!appengine,!noasm,gc
  3. package zstd
  4. import (
  5. "fmt"
  6. "github.com/klauspost/compress/internal/cpuinfo"
  7. )
  8. type decodeSyncAsmContext struct {
  9. llTable []decSymbol
  10. mlTable []decSymbol
  11. ofTable []decSymbol
  12. llState uint64
  13. mlState uint64
  14. ofState uint64
  15. iteration int
  16. litRemain int
  17. out []byte
  18. outPosition int
  19. literals []byte
  20. litPosition int
  21. history []byte
  22. windowSize int
  23. ll int // set on error (not for all errors, please refer to _generate/gen.go)
  24. ml int // set on error (not for all errors, please refer to _generate/gen.go)
  25. mo int // set on error (not for all errors, please refer to _generate/gen.go)
  26. }
  27. // sequenceDecs_decodeSync_amd64 implements the main loop of sequenceDecs.decodeSync in x86 asm.
  28. //
  29. // Please refer to seqdec_generic.go for the reference implementation.
  30. //
  31. //go:noescape
  32. func sequenceDecs_decodeSync_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
  33. // sequenceDecs_decodeSync_bmi2 implements the main loop of sequenceDecs.decodeSync in x86 asm with BMI2 extensions.
  34. //
  35. //go:noescape
  36. func sequenceDecs_decodeSync_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
  37. // sequenceDecs_decodeSync_safe_amd64 does the same as above, but does not write more than output buffer.
  38. //
  39. //go:noescape
  40. func sequenceDecs_decodeSync_safe_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
  41. // sequenceDecs_decodeSync_safe_bmi2 does the same as above, but does not write more than output buffer.
  42. //
  43. //go:noescape
  44. func sequenceDecs_decodeSync_safe_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
  45. // decode sequences from the stream with the provided history but without a dictionary.
  46. func (s *sequenceDecs) decodeSyncSimple(hist []byte) (bool, error) {
  47. if len(s.dict) > 0 {
  48. return false, nil
  49. }
  50. if s.maxSyncLen == 0 && cap(s.out)-len(s.out) < maxCompressedBlockSize {
  51. return false, nil
  52. }
  53. // FIXME: Using unsafe memory copies leads to rare, random crashes
  54. // with fuzz testing. It is therefore disabled for now.
  55. const useSafe = true
  56. /*
  57. useSafe := false
  58. if s.maxSyncLen == 0 && cap(s.out)-len(s.out) < maxCompressedBlockSizeAlloc {
  59. useSafe = true
  60. }
  61. if s.maxSyncLen > 0 && cap(s.out)-len(s.out)-compressedBlockOverAlloc < int(s.maxSyncLen) {
  62. useSafe = true
  63. }
  64. if cap(s.literals) < len(s.literals)+compressedBlockOverAlloc {
  65. useSafe = true
  66. }
  67. */
  68. br := s.br
  69. maxBlockSize := maxCompressedBlockSize
  70. if s.windowSize < maxBlockSize {
  71. maxBlockSize = s.windowSize
  72. }
  73. ctx := decodeSyncAsmContext{
  74. llTable: s.litLengths.fse.dt[:maxTablesize],
  75. mlTable: s.matchLengths.fse.dt[:maxTablesize],
  76. ofTable: s.offsets.fse.dt[:maxTablesize],
  77. llState: uint64(s.litLengths.state.state),
  78. mlState: uint64(s.matchLengths.state.state),
  79. ofState: uint64(s.offsets.state.state),
  80. iteration: s.nSeqs - 1,
  81. litRemain: len(s.literals),
  82. out: s.out,
  83. outPosition: len(s.out),
  84. literals: s.literals,
  85. windowSize: s.windowSize,
  86. history: hist,
  87. }
  88. s.seqSize = 0
  89. startSize := len(s.out)
  90. var errCode int
  91. if cpuinfo.HasBMI2() {
  92. if useSafe {
  93. errCode = sequenceDecs_decodeSync_safe_bmi2(s, br, &ctx)
  94. } else {
  95. errCode = sequenceDecs_decodeSync_bmi2(s, br, &ctx)
  96. }
  97. } else {
  98. if useSafe {
  99. errCode = sequenceDecs_decodeSync_safe_amd64(s, br, &ctx)
  100. } else {
  101. errCode = sequenceDecs_decodeSync_amd64(s, br, &ctx)
  102. }
  103. }
  104. switch errCode {
  105. case noError:
  106. break
  107. case errorMatchLenOfsMismatch:
  108. return true, fmt.Errorf("zero matchoff and matchlen (%d) > 0", ctx.ml)
  109. case errorMatchLenTooBig:
  110. return true, fmt.Errorf("match len (%d) bigger than max allowed length", ctx.ml)
  111. case errorMatchOffTooBig:
  112. return true, fmt.Errorf("match offset (%d) bigger than current history (%d)",
  113. ctx.mo, ctx.outPosition+len(hist)-startSize)
  114. case errorNotEnoughLiterals:
  115. return true, fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available",
  116. ctx.ll, ctx.litRemain+ctx.ll)
  117. case errorNotEnoughSpace:
  118. size := ctx.outPosition + ctx.ll + ctx.ml
  119. if debugDecoder {
  120. println("msl:", s.maxSyncLen, "cap", cap(s.out), "bef:", startSize, "sz:", size-startSize, "mbs:", maxBlockSize, "outsz:", cap(s.out)-startSize)
  121. }
  122. return true, fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
  123. default:
  124. return true, fmt.Errorf("sequenceDecs_decode returned erronous code %d", errCode)
  125. }
  126. s.seqSize += ctx.litRemain
  127. if s.seqSize > maxBlockSize {
  128. return true, fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
  129. }
  130. err := br.close()
  131. if err != nil {
  132. printf("Closing sequences: %v, %+v\n", err, *br)
  133. return true, err
  134. }
  135. s.literals = s.literals[ctx.litPosition:]
  136. t := ctx.outPosition
  137. s.out = s.out[:t]
  138. // Add final literals
  139. s.out = append(s.out, s.literals...)
  140. if debugDecoder {
  141. t += len(s.literals)
  142. if t != len(s.out) {
  143. panic(fmt.Errorf("length mismatch, want %d, got %d", len(s.out), t))
  144. }
  145. }
  146. return true, nil
  147. }
  148. // --------------------------------------------------------------------------------
  149. type decodeAsmContext struct {
  150. llTable []decSymbol
  151. mlTable []decSymbol
  152. ofTable []decSymbol
  153. llState uint64
  154. mlState uint64
  155. ofState uint64
  156. iteration int
  157. seqs []seqVals
  158. litRemain int
  159. }
  160. const noError = 0
  161. // error reported when mo == 0 && ml > 0
  162. const errorMatchLenOfsMismatch = 1
  163. // error reported when ml > maxMatchLen
  164. const errorMatchLenTooBig = 2
  165. // error reported when mo > available history or mo > s.windowSize
  166. const errorMatchOffTooBig = 3
  167. // error reported when the sum of literal lengths exeeceds the literal buffer size
  168. const errorNotEnoughLiterals = 4
  169. // error reported when capacity of `out` is too small
  170. const errorNotEnoughSpace = 5
  171. // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm.
  172. //
  173. // Please refer to seqdec_generic.go for the reference implementation.
  174. //
  175. //go:noescape
  176. func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
  177. // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm.
  178. //
  179. // Please refer to seqdec_generic.go for the reference implementation.
  180. //
  181. //go:noescape
  182. func sequenceDecs_decode_56_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
  183. // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions.
  184. //
  185. //go:noescape
  186. func sequenceDecs_decode_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
  187. // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions.
  188. //
  189. //go:noescape
  190. func sequenceDecs_decode_56_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
  191. // decode sequences from the stream without the provided history.
  192. func (s *sequenceDecs) decode(seqs []seqVals) error {
  193. br := s.br
  194. maxBlockSize := maxCompressedBlockSize
  195. if s.windowSize < maxBlockSize {
  196. maxBlockSize = s.windowSize
  197. }
  198. ctx := decodeAsmContext{
  199. llTable: s.litLengths.fse.dt[:maxTablesize],
  200. mlTable: s.matchLengths.fse.dt[:maxTablesize],
  201. ofTable: s.offsets.fse.dt[:maxTablesize],
  202. llState: uint64(s.litLengths.state.state),
  203. mlState: uint64(s.matchLengths.state.state),
  204. ofState: uint64(s.offsets.state.state),
  205. seqs: seqs,
  206. iteration: len(seqs) - 1,
  207. litRemain: len(s.literals),
  208. }
  209. s.seqSize = 0
  210. lte56bits := s.maxBits+s.offsets.fse.actualTableLog+s.matchLengths.fse.actualTableLog+s.litLengths.fse.actualTableLog <= 56
  211. var errCode int
  212. if cpuinfo.HasBMI2() {
  213. if lte56bits {
  214. errCode = sequenceDecs_decode_56_bmi2(s, br, &ctx)
  215. } else {
  216. errCode = sequenceDecs_decode_bmi2(s, br, &ctx)
  217. }
  218. } else {
  219. if lte56bits {
  220. errCode = sequenceDecs_decode_56_amd64(s, br, &ctx)
  221. } else {
  222. errCode = sequenceDecs_decode_amd64(s, br, &ctx)
  223. }
  224. }
  225. if errCode != 0 {
  226. i := len(seqs) - ctx.iteration - 1
  227. switch errCode {
  228. case errorMatchLenOfsMismatch:
  229. ml := ctx.seqs[i].ml
  230. return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml)
  231. case errorMatchLenTooBig:
  232. ml := ctx.seqs[i].ml
  233. return fmt.Errorf("match len (%d) bigger than max allowed length", ml)
  234. case errorNotEnoughLiterals:
  235. ll := ctx.seqs[i].ll
  236. return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, ctx.litRemain+ll)
  237. }
  238. return fmt.Errorf("sequenceDecs_decode_amd64 returned erronous code %d", errCode)
  239. }
  240. if ctx.litRemain < 0 {
  241. return fmt.Errorf("literal count is too big: total available %d, total requested %d",
  242. len(s.literals), len(s.literals)-ctx.litRemain)
  243. }
  244. s.seqSize += ctx.litRemain
  245. if s.seqSize > maxBlockSize {
  246. return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
  247. }
  248. err := br.close()
  249. if err != nil {
  250. printf("Closing sequences: %v, %+v\n", err, *br)
  251. }
  252. return err
  253. }
  254. // --------------------------------------------------------------------------------
  255. type executeAsmContext struct {
  256. seqs []seqVals
  257. seqIndex int
  258. out []byte
  259. history []byte
  260. literals []byte
  261. outPosition int
  262. litPosition int
  263. windowSize int
  264. }
  265. // sequenceDecs_executeSimple_amd64 implements the main loop of sequenceDecs.executeSimple in x86 asm.
  266. //
  267. // Returns false if a match offset is too big.
  268. //
  269. // Please refer to seqdec_generic.go for the reference implementation.
  270. //
  271. //go:noescape
  272. func sequenceDecs_executeSimple_amd64(ctx *executeAsmContext) bool
  273. // Same as above, but with safe memcopies
  274. //
  275. //go:noescape
  276. func sequenceDecs_executeSimple_safe_amd64(ctx *executeAsmContext) bool
  277. // executeSimple handles cases when dictionary is not used.
  278. func (s *sequenceDecs) executeSimple(seqs []seqVals, hist []byte) error {
  279. // Ensure we have enough output size...
  280. if len(s.out)+s.seqSize+compressedBlockOverAlloc > cap(s.out) {
  281. addBytes := s.seqSize + len(s.out) + compressedBlockOverAlloc
  282. s.out = append(s.out, make([]byte, addBytes)...)
  283. s.out = s.out[:len(s.out)-addBytes]
  284. }
  285. if debugDecoder {
  286. printf("Execute %d seqs with literals: %d into %d bytes\n", len(seqs), len(s.literals), s.seqSize)
  287. }
  288. var t = len(s.out)
  289. out := s.out[:t+s.seqSize]
  290. ctx := executeAsmContext{
  291. seqs: seqs,
  292. seqIndex: 0,
  293. out: out,
  294. history: hist,
  295. outPosition: t,
  296. litPosition: 0,
  297. literals: s.literals,
  298. windowSize: s.windowSize,
  299. }
  300. var ok bool
  301. if cap(s.literals) < len(s.literals)+compressedBlockOverAlloc {
  302. ok = sequenceDecs_executeSimple_safe_amd64(&ctx)
  303. } else {
  304. ok = sequenceDecs_executeSimple_amd64(&ctx)
  305. }
  306. if !ok {
  307. return fmt.Errorf("match offset (%d) bigger than current history (%d)",
  308. seqs[ctx.seqIndex].mo, ctx.outPosition+len(hist))
  309. }
  310. s.literals = s.literals[ctx.litPosition:]
  311. t = ctx.outPosition
  312. // Add final literals
  313. copy(out[t:], s.literals)
  314. if debugDecoder {
  315. t += len(s.literals)
  316. if t != len(out) {
  317. panic(fmt.Errorf("length mismatch, want %d, got %d, ss: %d", len(out), t, s.seqSize))
  318. }
  319. }
  320. s.out = out
  321. return nil
  322. }