1
2
3
4
5
6
7 package zstd
8
9 import (
10 "encoding/binary"
11 "errors"
12 "fmt"
13 "io"
14 )
15
16
17
// fuzzing switches on an extra sanity limit on the window size while a
// frame header is parsed. It defaults to false; presumably only fuzz
// tests turn it on (nothing in this file assigns it).
var fuzzing bool
19
20
21 type Reader struct {
22
23 r io.Reader
24
25
26
27
28 sawFrameHeader bool
29
30
31 hasChecksum bool
32
33
34 readOneFrame bool
35
36
37 frameSizeUnknown bool
38
39
40
41 remainingFrameSize uint64
42
43
44
45 blockOffset int64
46
47
48 buffer []byte
49
50 off int
51
52
53 repeatedOffset1 uint32
54 repeatedOffset2 uint32
55 repeatedOffset3 uint32
56
57
58 huffmanTable []uint16
59 huffmanTableBits int
60
61
62 window window
63
64
65 compressedBuf []byte
66
67
68 literals []byte
69
70
71 seqTables [3][]fseBaselineEntry
72 seqTableBits [3]uint8
73
74
75 seqTableBuffers [3][]fseBaselineEntry
76
77
78 scratch [16]byte
79
80
81 fseScratch []fseEntry
82
83
84 checksum xxhash64
85 }
86
87
88 func NewReader(input io.Reader) *Reader {
89 r := new(Reader)
90 r.Reset(input)
91 return r
92 }
93
94
95
96 func (r *Reader) Reset(input io.Reader) {
97 r.r = input
98
99
100
101 r.sawFrameHeader = false
102 r.hasChecksum = false
103 r.readOneFrame = false
104 r.frameSizeUnknown = false
105 r.remainingFrameSize = 0
106 r.blockOffset = 0
107 r.buffer = r.buffer[:0]
108 r.off = 0
109
110
111
112
113
114
115
116
117
118
119
120
121
122 }
123
124
125 func (r *Reader) Read(p []byte) (int, error) {
126 if err := r.refillIfNeeded(); err != nil {
127 return 0, err
128 }
129 n := copy(p, r.buffer[r.off:])
130 r.off += n
131 return n, nil
132 }
133
134
135 func (r *Reader) ReadByte() (byte, error) {
136 if err := r.refillIfNeeded(); err != nil {
137 return 0, err
138 }
139 ret := r.buffer[r.off]
140 r.off++
141 return ret, nil
142 }
143
144
145 func (r *Reader) refillIfNeeded() error {
146 for r.off >= len(r.buffer) {
147 if err := r.refill(); err != nil {
148 return err
149 }
150 r.off = 0
151 }
152 return nil
153 }
154
155
156 func (r *Reader) refill() error {
157 if !r.sawFrameHeader {
158 if err := r.readFrameHeader(); err != nil {
159 return err
160 }
161 }
162 return r.readBlock()
163 }
164
165
166 func (r *Reader) readFrameHeader() error {
167 retry:
168 relativeOffset := 0
169
170
171 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
172
173 if err == io.EOF && !r.readOneFrame {
174 err = io.ErrUnexpectedEOF
175 }
176 return r.wrapError(relativeOffset, err)
177 }
178
179 if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
180 if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
181
182 r.blockOffset += int64(relativeOffset) + 4
183 if err := r.skipFrame(); err != nil {
184 return err
185 }
186 r.readOneFrame = true
187 goto retry
188 }
189
190 return r.makeError(relativeOffset, "invalid magic number")
191 }
192
193 relativeOffset += 4
194
195
196 if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
197 return r.wrapNonEOFError(relativeOffset, err)
198 }
199 descriptor := r.scratch[0]
200
201 singleSegment := descriptor&(1<<5) != 0
202
203 fcsFieldSize := 1 << (descriptor >> 6)
204 if fcsFieldSize == 1 && !singleSegment {
205 fcsFieldSize = 0
206 }
207
208 var windowDescriptorSize int
209 if singleSegment {
210 windowDescriptorSize = 0
211 } else {
212 windowDescriptorSize = 1
213 }
214
215 if descriptor&(1<<3) != 0 {
216 return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
217 }
218
219 r.hasChecksum = descriptor&(1<<2) != 0
220 if r.hasChecksum {
221 r.checksum.reset()
222 }
223
224
225 dictionaryIdSize := 0
226 if dictIdFlag := descriptor & 3; dictIdFlag != 0 {
227 dictionaryIdSize = 1 << (dictIdFlag - 1)
228 }
229
230 relativeOffset++
231
232 headerSize := windowDescriptorSize + dictionaryIdSize + fcsFieldSize
233
234 if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
235 return r.wrapNonEOFError(relativeOffset, err)
236 }
237
238
239
240 var windowSize uint64
241 if !singleSegment {
242
243 windowDescriptor := r.scratch[0]
244 exponent := uint64(windowDescriptor >> 3)
245 mantissa := uint64(windowDescriptor & 7)
246 windowLog := exponent + 10
247 windowBase := uint64(1) << windowLog
248 windowAdd := (windowBase / 8) * mantissa
249 windowSize = windowBase + windowAdd
250
251
252 if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
253 return r.makeError(relativeOffset, "windowSize too large")
254 }
255 }
256
257
258 if dictionaryIdSize != 0 {
259 dictionaryId := r.scratch[windowDescriptorSize : windowDescriptorSize+dictionaryIdSize]
260
261 for _, b := range dictionaryId {
262 if b != 0 {
263 return r.makeError(relativeOffset, "dictionaries are not supported")
264 }
265 }
266 }
267
268
269 r.frameSizeUnknown = false
270 r.remainingFrameSize = 0
271 fb := r.scratch[windowDescriptorSize+dictionaryIdSize:]
272 switch fcsFieldSize {
273 case 0:
274 r.frameSizeUnknown = true
275 case 1:
276 r.remainingFrameSize = uint64(fb[0])
277 case 2:
278 r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
279 case 4:
280 r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
281 case 8:
282 r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
283 default:
284 panic("unreachable")
285 }
286
287
288
289
290 if singleSegment {
291 windowSize = r.remainingFrameSize
292 }
293
294
295 const maxWindowSize = 8 << 20
296 if windowSize > maxWindowSize {
297 windowSize = maxWindowSize
298 }
299
300 relativeOffset += headerSize
301
302 r.sawFrameHeader = true
303 r.readOneFrame = true
304 r.blockOffset += int64(relativeOffset)
305
306
307 r.repeatedOffset1 = 1
308 r.repeatedOffset2 = 4
309 r.repeatedOffset3 = 8
310 r.huffmanTableBits = 0
311 r.window.reset(int(windowSize))
312 r.seqTables[0] = nil
313 r.seqTables[1] = nil
314 r.seqTables[2] = nil
315
316 return nil
317 }
318
319
320 func (r *Reader) skipFrame() error {
321 relativeOffset := 0
322
323 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
324 return r.wrapNonEOFError(relativeOffset, err)
325 }
326
327 relativeOffset += 4
328
329 size := binary.LittleEndian.Uint32(r.scratch[:4])
330 if size == 0 {
331 r.blockOffset += int64(relativeOffset)
332 return nil
333 }
334
335 if seeker, ok := r.r.(io.Seeker); ok {
336 r.blockOffset += int64(relativeOffset)
337
338
339 prev, err := seeker.Seek(0, io.SeekCurrent)
340 if err != nil {
341 return r.wrapError(0, err)
342 }
343 end, err := seeker.Seek(0, io.SeekEnd)
344 if err != nil {
345 return r.wrapError(0, err)
346 }
347 if prev > end-int64(size) {
348 r.blockOffset += end - prev
349 return r.makeEOFError(0)
350 }
351
352
353 _, err = seeker.Seek(prev+int64(size), io.SeekStart)
354 if err != nil {
355 return r.wrapError(0, err)
356 }
357 r.blockOffset += int64(size)
358 return nil
359 }
360
361 n, err := io.CopyN(io.Discard, r.r, int64(size))
362 relativeOffset += int(n)
363 if err != nil {
364 return r.wrapNonEOFError(relativeOffset, err)
365 }
366 r.blockOffset += int64(relativeOffset)
367 return nil
368 }
369
370
371 func (r *Reader) readBlock() error {
372 relativeOffset := 0
373
374
375 if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil {
376 return r.wrapNonEOFError(relativeOffset, err)
377 }
378
379 relativeOffset += 3
380
381 header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16)
382
383 lastBlock := header&1 != 0
384 blockType := (header >> 1) & 3
385 blockSize := int(header >> 3)
386
387
388
389
390 if blockSize > 128<<10 || (r.window.size > 0 && blockSize > r.window.size) {
391 return r.makeError(relativeOffset, "block size too large")
392 }
393
394
395 switch blockType {
396 case 0:
397 r.setBufferSize(blockSize)
398 if _, err := io.ReadFull(r.r, r.buffer); err != nil {
399 return r.wrapNonEOFError(relativeOffset, err)
400 }
401 relativeOffset += blockSize
402 r.blockOffset += int64(relativeOffset)
403 case 1:
404 r.setBufferSize(blockSize)
405 if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
406 return r.wrapNonEOFError(relativeOffset, err)
407 }
408 relativeOffset++
409 v := r.scratch[0]
410 for i := range r.buffer {
411 r.buffer[i] = v
412 }
413 r.blockOffset += int64(relativeOffset)
414 case 2:
415 r.blockOffset += int64(relativeOffset)
416 if err := r.compressedBlock(blockSize); err != nil {
417 return err
418 }
419 r.blockOffset += int64(blockSize)
420 case 3:
421 return r.makeError(relativeOffset, "invalid block type")
422 }
423
424 if !r.frameSizeUnknown {
425 if uint64(len(r.buffer)) > r.remainingFrameSize {
426 return r.makeError(relativeOffset, "too many uncompressed bytes in frame")
427 }
428 r.remainingFrameSize -= uint64(len(r.buffer))
429 }
430
431 if r.hasChecksum {
432 r.checksum.update(r.buffer)
433 }
434
435 if !lastBlock {
436 r.window.save(r.buffer)
437 } else {
438 if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
439 return r.makeError(relativeOffset, "not enough uncompressed bytes for frame")
440 }
441
442 if r.hasChecksum {
443 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
444 return r.wrapNonEOFError(0, err)
445 }
446
447 inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
448 dataChecksum := uint32(r.checksum.digest())
449 if inputChecksum != dataChecksum {
450 return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
451 }
452
453 r.blockOffset += 4
454 }
455 r.sawFrameHeader = false
456 }
457
458 return nil
459 }
460
461
462
463 func (r *Reader) setBufferSize(size int) {
464 if cap(r.buffer) < size {
465 need := size - cap(r.buffer)
466 r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...)
467 }
468 r.buffer = r.buffer[:size]
469 }
470
471
// zstdError is a decompression failure annotated with the input offset
// at which it was detected.
type zstdError struct {
	offset int64 // position in the compressed input
	err    error // underlying cause
}

// Error formats the offset together with the underlying cause.
func (e *zstdError) Error() string {
	return fmt.Sprintf("zstd decompression error at %d: %v", e.offset, e.err)
}

// Unwrap returns the underlying cause so that errors.Is and errors.As
// can inspect it.
func (e *zstdError) Unwrap() error {
	return e.err
}
484
485 func (r *Reader) makeEOFError(off int) error {
486 return r.wrapError(off, io.ErrUnexpectedEOF)
487 }
488
489 func (r *Reader) wrapNonEOFError(off int, err error) error {
490 if err == io.EOF {
491 err = io.ErrUnexpectedEOF
492 }
493 return r.wrapError(off, err)
494 }
495
496 func (r *Reader) makeError(off int, msg string) error {
497 return r.wrapError(off, errors.New(msg))
498 }
499
500 func (r *Reader) wrapError(off int, err error) error {
501 if err == io.EOF {
502 return err
503 }
504 return &zstdError{r.blockOffset + int64(off), err}
505 }
506
View as plain text