1
2
3
4
5
6
7 package zstd
8
9 import (
10 "encoding/binary"
11 "errors"
12 "fmt"
13 "io"
14 )
15
16
17
// fuzzing is a package-level switch that is false by default. When it is
// true, readFrameHeader rejects frames that declare a very large window
// (window log above 31 or window size above 1<<27) with an error instead of
// accepting them.
var fuzzing bool
19
20
// Reader decompresses a zstd stream. It provides Read and ReadByte, so it
// can be used as an io.Reader and an io.ByteReader. Create one with
// NewReader, or reuse an existing one with Reset.
type Reader struct {
	// r is the underlying source of compressed data.
	r io.Reader

	// sawFrameHeader reports whether a frame header has been read and its
	// final block has not yet been consumed. When it is false, refill
	// reads a frame header before reading the next block.
	sawFrameHeader bool

	// hasChecksum reports whether the current frame's header descriptor
	// announced a content checksum after the last block.
	hasChecksum bool

	// readOneFrame reports whether at least one frame (regular or
	// skippable) has been seen. Until then, a clean EOF where a magic
	// number is expected is reported as io.ErrUnexpectedEOF.
	readOneFrame bool

	// frameSizeUnknown is true when the frame header carried no
	// Frame_Content_Size field.
	frameSizeUnknown bool

	// remainingFrameSize is the number of uncompressed bytes still
	// expected from the current frame. It is only meaningful when
	// frameSizeUnknown is false.
	remainingFrameSize uint64

	// blockOffset counts the input bytes accounted for so far. wrapError
	// adds it to a relative offset to report where an error occurred.
	blockOffset int64

	// buffer holds the uncompressed data of the most recently read block.
	buffer []byte
	// off is the index in buffer of the next byte to hand to the caller.
	off int

	// Repeated offsets; readFrameHeader resets them to 1, 4 and 8 at the
	// start of each frame. They are not otherwise used in this part of
	// the file — presumably consumed by the sequence decoder; confirm.
	repeatedOffset1 uint32
	repeatedOffset2 uint32
	repeatedOffset3 uint32

	// Huffman decoding table and its bit width. readFrameHeader sets
	// huffmanTableBits to 0 at the start of each frame. The table itself
	// is not referenced in this part of the file — presumably filled in
	// by the literals decoder; confirm.
	huffmanTable     []uint16
	huffmanTableBits int

	// window retains earlier output of the current frame. It is reset to
	// the frame's window size by readFrameHeader, and readBlock saves the
	// output of every non-final block into it.
	window window

	// compressedBuf is not referenced in this part of the file —
	// presumably a reusable buffer for a block's compressed bytes; confirm.
	compressedBuf []byte

	// literals is not referenced in this part of the file — presumably a
	// reusable buffer for decoded literals; confirm.
	literals []byte

	// seqTables is cleared to nil for all three entries by
	// readFrameHeader at the start of each frame. seqTableBits is not
	// referenced in this part of the file — presumably the bit width of
	// the matching table; confirm.
	seqTables    [3][]fseBaselineEntry
	seqTableBits [3]uint8

	// seqTableBuffers is not referenced in this part of the file —
	// presumably backing storage reused for seqTables; confirm.
	seqTableBuffers [3][]fseBaselineEntry

	// scratch is a small fixed buffer used to read frame headers, block
	// headers and checksums without allocating.
	scratch [16]byte

	// fseScratch is not referenced in this part of the file — presumably
	// scratch space for building FSE tables; confirm.
	fseScratch []fseEntry

	// checksum accumulates the hash of the uncompressed data of the
	// current frame when hasChecksum is true.
	checksum xxhash64
}
86
87
88 func NewReader(input io.Reader) *Reader {
89 r := new(Reader)
90 r.Reset(input)
91 return r
92 }
93
94
95
96 func (r *Reader) Reset(input io.Reader) {
97 r.r = input
98
99
100
101 r.sawFrameHeader = false
102 r.hasChecksum = false
103 r.readOneFrame = false
104 r.frameSizeUnknown = false
105 r.remainingFrameSize = 0
106 r.blockOffset = 0
107 r.buffer = r.buffer[:0]
108 r.off = 0
109
110
111
112
113
114
115
116
117
118
119
120
121
122 }
123
124
125 func (r *Reader) Read(p []byte) (int, error) {
126 if err := r.refillIfNeeded(); err != nil {
127 return 0, err
128 }
129 n := copy(p, r.buffer[r.off:])
130 r.off += n
131 return n, nil
132 }
133
134
135 func (r *Reader) ReadByte() (byte, error) {
136 if err := r.refillIfNeeded(); err != nil {
137 return 0, err
138 }
139 ret := r.buffer[r.off]
140 r.off++
141 return ret, nil
142 }
143
144
145 func (r *Reader) refillIfNeeded() error {
146 for r.off >= len(r.buffer) {
147 if err := r.refill(); err != nil {
148 return err
149 }
150 r.off = 0
151 }
152 return nil
153 }
154
155
156 func (r *Reader) refill() error {
157 if !r.sawFrameHeader {
158 if err := r.readFrameHeader(); err != nil {
159 return err
160 }
161 }
162 return r.readBlock()
163 }
164
165
// readFrameHeader reads the header of the next zstd frame from r.r, skipping
// any skippable frames that come before it, and prepares the Reader to decode
// that frame's blocks.
//
// On success it sets sawFrameHeader and readOneFrame, advances blockOffset
// past the header, and resets the per-frame decoding state. If the input ends
// cleanly where a magic number is expected and at least one frame has already
// been seen, it returns io.EOF unwrapped; every other failure is returned as
// a *zstdError.
func (r *Reader) readFrameHeader() error {
retry:
	// Offset of the field being read, relative to r.blockOffset.
	relativeOffset := 0

	// Magic number: 4 bytes, little-endian.
	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
		// A stream must contain at least one frame, so a clean EOF
		// before any frame has been seen is an error.
		if err == io.EOF && !r.readOneFrame {
			err = io.ErrUnexpectedEOF
		}
		return r.wrapError(relativeOffset, err)
	}

	if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
		if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
			// Skippable frame: account for the magic number, skip the
			// frame's contents, and look for another magic number.
			r.blockOffset += int64(relativeOffset) + 4
			if err := r.skipFrame(); err != nil {
				return err
			}
			r.readOneFrame = true
			goto retry
		}

		return r.makeError(relativeOffset, "invalid magic number")
	}

	relativeOffset += 4

	// Frame header descriptor: one byte of flags.
	if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}
	descriptor := r.scratch[0]

	// Bit 5: single-segment flag.
	singleSegment := descriptor&(1<<5) != 0

	// Bits 6-7 select a frame content size field of 1, 2, 4 or 8 bytes.
	// The 1-byte form only exists for single-segment frames; otherwise
	// the field is absent.
	fcsFieldSize := 1 << (descriptor >> 6)
	if fcsFieldSize == 1 && !singleSegment {
		fcsFieldSize = 0
	}

	// The window descriptor byte is present only when the frame is not
	// single-segment.
	var windowDescriptorSize int
	if singleSegment {
		windowDescriptorSize = 0
	} else {
		windowDescriptorSize = 1
	}

	// Bit 3 is reserved and must be zero.
	if descriptor&(1<<3) != 0 {
		return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
	}

	// Bit 2: a content checksum follows the last block.
	r.hasChecksum = descriptor&(1<<2) != 0
	if r.hasChecksum {
		r.checksum.reset()
	}

	// Bits 0-1: dictionary ID flag; values 1, 2, 3 mean a 1-, 2- or
	// 4-byte dictionary ID field.
	dictionaryIdSize := 0
	if dictIdFlag := descriptor & 3; dictIdFlag != 0 {
		dictionaryIdSize = 1 << (dictIdFlag - 1)
	}

	relativeOffset++

	// Read the rest of the header in one go: window descriptor,
	// dictionary ID and frame content size, in that order. At most
	// 1+4+8 bytes, which fits in scratch.
	headerSize := windowDescriptorSize + dictionaryIdSize + fcsFieldSize

	if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}

	// Work out how much earlier output must be retained for this frame.
	var windowSize uint64
	if !singleSegment {
		// Window descriptor: the top 5 bits are an exponent and the low
		// 3 bits a mantissa, giving
		// windowSize = 2^(10+exponent) * (1 + mantissa/8).
		windowDescriptor := r.scratch[0]
		exponent := uint64(windowDescriptor >> 3)
		mantissa := uint64(windowDescriptor & 7)
		windowLog := exponent + 10
		windowBase := uint64(1) << windowLog
		windowAdd := (windowBase / 8) * mantissa
		windowSize = windowBase + windowAdd

		// Only when the fuzzing switch is on are oversized windows
		// rejected outright.
		if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
			return r.makeError(relativeOffset, "windowSize too large")
		}
	}

	// Dictionary ID: only an all-zero ID is accepted.
	if dictionaryIdSize != 0 {
		dictionaryId := r.scratch[windowDescriptorSize : windowDescriptorSize+dictionaryIdSize]

		for _, b := range dictionaryId {
			if b != 0 {
				return r.makeError(relativeOffset, "dictionaries are not supported")
			}
		}
	}

	// Frame content size: the number of uncompressed bytes in the frame,
	// if the header states it.
	r.frameSizeUnknown = false
	r.remainingFrameSize = 0
	fb := r.scratch[windowDescriptorSize+dictionaryIdSize:]
	switch fcsFieldSize {
	case 0:
		r.frameSizeUnknown = true
	case 1:
		r.remainingFrameSize = uint64(fb[0])
	case 2:
		// The 2-byte form stores the size minus 256.
		r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
	case 4:
		r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
	case 8:
		r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
	default:
		// fcsFieldSize is 1<<(0..3), or 0, so no other value can occur.
		panic("unreachable")
	}

	// A single-segment frame has no window descriptor; its window size is
	// the frame content size.
	if singleSegment {
		windowSize = r.remainingFrameSize
	}

	// Clamp the retained window to 8 MiB.
	const maxWindowSize = 8 << 20
	if windowSize > maxWindowSize {
		windowSize = maxWindowSize
	}

	relativeOffset += headerSize

	r.sawFrameHeader = true
	r.readOneFrame = true
	r.blockOffset += int64(relativeOffset)

	// Reset the per-frame decoding state before the first block.
	r.repeatedOffset1 = 1
	r.repeatedOffset2 = 4
	r.repeatedOffset3 = 8
	r.huffmanTableBits = 0
	r.window.reset(int(windowSize))
	r.seqTables[0] = nil
	r.seqTables[1] = nil
	r.seqTables[2] = nil

	return nil
}
318
319
320 func (r *Reader) skipFrame() error {
321 relativeOffset := 0
322
323 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
324 return r.wrapNonEOFError(relativeOffset, err)
325 }
326
327 relativeOffset += 4
328
329 size := binary.LittleEndian.Uint32(r.scratch[:4])
330 if size == 0 {
331 r.blockOffset += int64(relativeOffset)
332 return nil
333 }
334
335 if seeker, ok := r.r.(io.Seeker); ok {
336 r.blockOffset += int64(relativeOffset)
337
338
339 prev, err := seeker.Seek(0, io.SeekCurrent)
340 if err != nil {
341 return r.wrapError(0, err)
342 }
343 end, err := seeker.Seek(0, io.SeekEnd)
344 if err != nil {
345 return r.wrapError(0, err)
346 }
347 if prev > end-int64(size) {
348 r.blockOffset += end - prev
349 return r.makeEOFError(0)
350 }
351
352
353 _, err = seeker.Seek(prev+int64(size), io.SeekStart)
354 if err != nil {
355 return r.wrapError(0, err)
356 }
357 r.blockOffset += int64(size)
358 return nil
359 }
360
361 var skip []byte
362 const chunk = 1 << 20
363 for size >= chunk {
364 if len(skip) == 0 {
365 skip = make([]byte, chunk)
366 }
367 if _, err := io.ReadFull(r.r, skip); err != nil {
368 return r.wrapNonEOFError(relativeOffset, err)
369 }
370 relativeOffset += chunk
371 size -= chunk
372 }
373 if size > 0 {
374 if len(skip) == 0 {
375 skip = make([]byte, size)
376 }
377 if _, err := io.ReadFull(r.r, skip); err != nil {
378 return r.wrapNonEOFError(relativeOffset, err)
379 }
380 relativeOffset += int(size)
381 }
382
383 r.blockOffset += int64(relativeOffset)
384
385 return nil
386 }
387
388
389 func (r *Reader) readBlock() error {
390 relativeOffset := 0
391
392
393 if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil {
394 return r.wrapNonEOFError(relativeOffset, err)
395 }
396
397 relativeOffset += 3
398
399 header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16)
400
401 lastBlock := header&1 != 0
402 blockType := (header >> 1) & 3
403 blockSize := int(header >> 3)
404
405
406
407
408 if blockSize > 128<<10 || (r.window.size > 0 && blockSize > r.window.size) {
409 return r.makeError(relativeOffset, "block size too large")
410 }
411
412
413 switch blockType {
414 case 0:
415 r.setBufferSize(blockSize)
416 if _, err := io.ReadFull(r.r, r.buffer); err != nil {
417 return r.wrapNonEOFError(relativeOffset, err)
418 }
419 relativeOffset += blockSize
420 r.blockOffset += int64(relativeOffset)
421 case 1:
422 r.setBufferSize(blockSize)
423 if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
424 return r.wrapNonEOFError(relativeOffset, err)
425 }
426 relativeOffset++
427 v := r.scratch[0]
428 for i := range r.buffer {
429 r.buffer[i] = v
430 }
431 r.blockOffset += int64(relativeOffset)
432 case 2:
433 r.blockOffset += int64(relativeOffset)
434 if err := r.compressedBlock(blockSize); err != nil {
435 return err
436 }
437 r.blockOffset += int64(blockSize)
438 case 3:
439 return r.makeError(relativeOffset, "invalid block type")
440 }
441
442 if !r.frameSizeUnknown {
443 if uint64(len(r.buffer)) > r.remainingFrameSize {
444 return r.makeError(relativeOffset, "too many uncompressed bytes in frame")
445 }
446 r.remainingFrameSize -= uint64(len(r.buffer))
447 }
448
449 if r.hasChecksum {
450 r.checksum.update(r.buffer)
451 }
452
453 if !lastBlock {
454 r.window.save(r.buffer)
455 } else {
456 if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
457 return r.makeError(relativeOffset, "not enough uncompressed bytes for frame")
458 }
459
460 if r.hasChecksum {
461 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
462 return r.wrapNonEOFError(0, err)
463 }
464
465 inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
466 dataChecksum := uint32(r.checksum.digest())
467 if inputChecksum != dataChecksum {
468 return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
469 }
470
471 r.blockOffset += 4
472 }
473 r.sawFrameHeader = false
474 }
475
476 return nil
477 }
478
479
480
481 func (r *Reader) setBufferSize(size int) {
482 if cap(r.buffer) < size {
483 need := size - cap(r.buffer)
484 r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...)
485 }
486 r.buffer = r.buffer[:size]
487 }
488
489
// zstdError is the error type used for decompression failures. It pairs the
// underlying error with the input offset at which it was detected.
type zstdError struct {
	offset int64 // input offset reported in the message
	err    error // underlying cause
}

// Error formats the offset followed by the underlying error.
func (e *zstdError) Error() string {
	return fmt.Sprintf("zstd decompression error at %d: %v", e.offset, e.err)
}

// Unwrap returns the underlying error so that errors.Is and errors.As can
// see through a zstdError.
func (e *zstdError) Unwrap() error {
	return e.err
}
502
503 func (r *Reader) makeEOFError(off int) error {
504 return r.wrapError(off, io.ErrUnexpectedEOF)
505 }
506
507 func (r *Reader) wrapNonEOFError(off int, err error) error {
508 if err == io.EOF {
509 err = io.ErrUnexpectedEOF
510 }
511 return r.wrapError(off, err)
512 }
513
514 func (r *Reader) makeError(off int, msg string) error {
515 return r.wrapError(off, errors.New(msg))
516 }
517
518 func (r *Reader) wrapError(off int, err error) error {
519 if err == io.EOF {
520 return err
521 }
522 return &zstdError{r.blockOffset + int64(off), err}
523 }
524
View as plain text