Source file src/internal/zstd/zstd.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package zstd provides a decompressor for zstd streams,
     6  // described in RFC 8878. It does not support dictionaries.
     7  package zstd
     8  
     9  import (
    10  	"encoding/binary"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  )
    15  
    16  // fuzzing is a fuzzer hook set to true when fuzzing.
    17  // This is used to reject cases where we don't match zstd.
    18  var fuzzing = false
    19  
// Reader implements [io.Reader] to read a zstd compressed stream.
//
// A Reader decompresses one block at a time into buffer and serves
// Read and ReadByte calls from it. Most of the remaining fields are
// decoding state or scratch storage kept between blocks so that a Reader
// can be reused (see Reset) without reallocating.
type Reader struct {
	// The underlying Reader.
	r io.Reader

	// Whether we have read the frame header.
	// This is of interest when buffer is empty.
	// If true we expect to see a new block.
	// If false the next refill starts by reading a frame header.
	sawFrameHeader bool

	// Whether the current frame expects a checksum.
	// Set from the frame header descriptor.
	hasChecksum bool

	// Whether we have read at least one frame.
	// Until this is true, EOF at a frame boundary is reported as
	// an unexpected EOF rather than as a clean end of stream.
	readOneFrame bool

	// True if the frame size is not known.
	frameSizeUnknown bool

	// The number of uncompressed bytes remaining in the current frame.
	// If frameSizeUnknown is true, this is not valid.
	remainingFrameSize uint64

	// The number of bytes read from r up to the start of the current
	// block, for error reporting.
	blockOffset int64

	// Buffered decompressed data.
	buffer []byte
	// Current read offset in buffer.
	// Data in buffer[off:] has not yet been returned to the caller.
	off int

	// The current repeated offsets.
	// They are reinitialized to 1, 4, 8 at the start of every frame.
	repeatedOffset1 uint32
	repeatedOffset2 uint32
	repeatedOffset3 uint32

	// The current Huffman table for literals. This package only
	// decompresses, so the table describes how literals were encoded.
	// huffmanTableBits is reset to 0 at the start of every frame.
	huffmanTable     []uint16
	huffmanTableBits int

	// The window for back references.
	window window

	// A buffer available to hold a compressed block.
	compressedBuf []byte

	// A buffer for literals.
	literals []byte

	// Sequence decode FSE tables.
	// The entries of seqTables are cleared to nil at the start of every frame.
	seqTables    [3][]fseBaselineEntry
	seqTableBits [3]uint8

	// Buffers for sequence decode FSE tables.
	seqTableBuffers [3][]fseBaselineEntry

	// Scratch space used for small reads, to avoid allocation.
	// Frame headers, block headers and checksums are read into it.
	scratch [16]byte

	// A scratch table for reading an FSE. Only temporarily valid.
	fseScratch []fseEntry

	// For checksum computation.
	// Only reset and updated when hasChecksum is true.
	checksum xxhash64
}
    86  
    87  // NewReader creates a new Reader that decompresses data from the given reader.
    88  func NewReader(input io.Reader) *Reader {
    89  	r := new(Reader)
    90  	r.Reset(input)
    91  	return r
    92  }
    93  
    94  // Reset discards the current state and starts reading a new stream from r.
    95  // This permits reusing a Reader rather than allocating a new one.
    96  func (r *Reader) Reset(input io.Reader) {
    97  	r.r = input
    98  
    99  	// Several fields are preserved to avoid allocation.
   100  	// Others are always set before they are used.
   101  	r.sawFrameHeader = false
   102  	r.hasChecksum = false
   103  	r.readOneFrame = false
   104  	r.frameSizeUnknown = false
   105  	r.remainingFrameSize = 0
   106  	r.blockOffset = 0
   107  	r.buffer = r.buffer[:0]
   108  	r.off = 0
   109  	// repeatedOffset1
   110  	// repeatedOffset2
   111  	// repeatedOffset3
   112  	// huffmanTable
   113  	// huffmanTableBits
   114  	// window
   115  	// compressedBuf
   116  	// literals
   117  	// seqTables
   118  	// seqTableBits
   119  	// seqTableBuffers
   120  	// scratch
   121  	// fseScratch
   122  }
   123  
   124  // Read implements [io.Reader].
   125  func (r *Reader) Read(p []byte) (int, error) {
   126  	if err := r.refillIfNeeded(); err != nil {
   127  		return 0, err
   128  	}
   129  	n := copy(p, r.buffer[r.off:])
   130  	r.off += n
   131  	return n, nil
   132  }
   133  
   134  // ReadByte implements [io.ByteReader].
   135  func (r *Reader) ReadByte() (byte, error) {
   136  	if err := r.refillIfNeeded(); err != nil {
   137  		return 0, err
   138  	}
   139  	ret := r.buffer[r.off]
   140  	r.off++
   141  	return ret, nil
   142  }
   143  
   144  // refillIfNeeded reads the next block if necessary.
   145  func (r *Reader) refillIfNeeded() error {
   146  	for r.off >= len(r.buffer) {
   147  		if err := r.refill(); err != nil {
   148  			return err
   149  		}
   150  		r.off = 0
   151  	}
   152  	return nil
   153  }
   154  
   155  // refill reads and decompresses the next block.
   156  func (r *Reader) refill() error {
   157  	if !r.sawFrameHeader {
   158  		if err := r.readFrameHeader(); err != nil {
   159  			return err
   160  		}
   161  	}
   162  	return r.readBlock()
   163  }
   164  
// readFrameHeader reads the frame header and prepares to read a block.
//
// Skippable frames that precede the header are consumed and ignored.
// On success the per-frame decoding state (repeated offsets, Huffman and
// sequence tables, window, checksum, expected frame size) has been
// initialized and r.blockOffset points just past the header.
//
// It returns io.EOF, unwrapped, when the input ends cleanly at a frame
// boundary after at least one frame has been read. Every other failure
// is returned as a wrapped error carrying the input offset.
func (r *Reader) readFrameHeader() error {
retry:
	// relativeOffset counts bytes consumed since r.blockOffset;
	// it is only folded into r.blockOffset once the header is accepted.
	relativeOffset := 0

	// Read magic number. RFC 3.1.1.
	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
		// We require that the stream contains at least one frame.
		// Otherwise a plain io.EOF is passed through by wrapError
		// and signals the normal end of the stream.
		if err == io.EOF && !r.readOneFrame {
			err = io.ErrUnexpectedEOF
		}
		return r.wrapError(relativeOffset, err)
	}

	if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
		if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
			// This is a skippable frame.
			// Account for the magic number now; skipFrame accounts
			// for the rest of the frame.
			r.blockOffset += int64(relativeOffset) + 4
			if err := r.skipFrame(); err != nil {
				return err
			}
			// A skippable frame counts as a frame for the
			// "at least one frame" requirement above.
			r.readOneFrame = true
			goto retry
		}

		return r.makeError(relativeOffset, "invalid magic number")
	}

	relativeOffset += 4

	// Read Frame_Header_Descriptor. RFC 3.1.1.1.1.
	if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}
	descriptor := r.scratch[0]

	// Bit 5: Single_Segment_Flag.
	singleSegment := descriptor&(1<<5) != 0

	// Bits 6-7 select a Frame_Content_Size field of 1, 2, 4 or 8 bytes.
	// A value of 0 means a 1 byte field only for single segment frames;
	// otherwise the field is absent.
	fcsFieldSize := 1 << (descriptor >> 6)
	if fcsFieldSize == 1 && !singleSegment {
		fcsFieldSize = 0
	}

	// The Window_Descriptor byte is present only when the frame
	// is not a single segment.
	var windowDescriptorSize int
	if singleSegment {
		windowDescriptorSize = 0
	} else {
		windowDescriptorSize = 1
	}

	// Bit 3 is reserved and must be zero.
	if descriptor&(1<<3) != 0 {
		return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
	}

	// Bit 2: a content checksum follows the last block.
	r.hasChecksum = descriptor&(1<<2) != 0
	if r.hasChecksum {
		r.checksum.reset()
	}

	// Dictionary_ID_Flag. RFC 3.1.1.1.1.6.
	// Bits 0-1: a flag of 1, 2 or 3 means a 1, 2 or 4 byte Dictionary_ID.
	dictionaryIdSize := 0
	if dictIdFlag := descriptor & 3; dictIdFlag != 0 {
		dictionaryIdSize = 1 << (dictIdFlag - 1)
	}

	relativeOffset++

	// At most 1 + 4 + 8 = 13 bytes, so the rest of the header
	// always fits in the 16 byte scratch array.
	headerSize := windowDescriptorSize + dictionaryIdSize + fcsFieldSize

	if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}

	// Figure out the maximum amount of data we need to retain
	// for backreferences.
	var windowSize uint64
	if !singleSegment {
		// Window descriptor. RFC 3.1.1.1.2.
		// The top five bits are an exponent and the low three a mantissa:
		// the size is 2^(10+exponent) plus mantissa eighths of that.
		windowDescriptor := r.scratch[0]
		exponent := uint64(windowDescriptor >> 3)
		mantissa := uint64(windowDescriptor & 7)
		windowLog := exponent + 10
		windowBase := uint64(1) << windowLog
		windowAdd := (windowBase / 8) * mantissa
		windowSize = windowBase + windowAdd

		// Default zstd sets limits on the window size.
		// We only enforce them when fuzzing, so that our accept/reject
		// decisions match zstd's.
		if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
			return r.makeError(relativeOffset, "windowSize too large")
		}
	}

	// Dictionary_ID. RFC 3.1.1.1.3.
	if dictionaryIdSize != 0 {
		dictionaryId := r.scratch[windowDescriptorSize : windowDescriptorSize+dictionaryIdSize]
		// Allow only zero Dictionary ID.
		for _, b := range dictionaryId {
			if b != 0 {
				return r.makeError(relativeOffset, "dictionaries are not supported")
			}
		}
	}

	// Frame_Content_Size. RFC 3.1.1.1.4.
	r.frameSizeUnknown = false
	r.remainingFrameSize = 0
	// fb starts at the Frame_Content_Size field, which is the last
	// field of the header.
	fb := r.scratch[windowDescriptorSize+dictionaryIdSize:]
	switch fcsFieldSize {
	case 0:
		r.frameSizeUnknown = true
	case 1:
		r.remainingFrameSize = uint64(fb[0])
	case 2:
		// The 2 byte form stores the size minus 256.
		r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
	case 4:
		r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
	case 8:
		r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
	default:
		// fcsFieldSize is 0 or 1<<n for n in 0..3, as computed above.
		panic("unreachable")
	}

	// RFC 3.1.1.1.2.
	// When Single_Segment_Flag is set, Window_Descriptor is not present.
	// In this case, Window_Size is Frame_Content_Size.
	if singleSegment {
		windowSize = r.remainingFrameSize
	}

	// RFC 8878 3.1.1.1.1.2. permits us to set an 8M max on window size.
	// Larger requests are clamped rather than rejected.
	const maxWindowSize = 8 << 20
	if windowSize > maxWindowSize {
		windowSize = maxWindowSize
	}

	relativeOffset += headerSize

	// The header is accepted: commit the consumed bytes to blockOffset.
	r.sawFrameHeader = true
	r.readOneFrame = true
	r.blockOffset += int64(relativeOffset)

	// Prepare to read blocks from the frame.
	// Every frame starts with fresh repeated offsets, no Huffman table,
	// no sequence tables, and a reset window.
	r.repeatedOffset1 = 1
	r.repeatedOffset2 = 4
	r.repeatedOffset3 = 8
	r.huffmanTableBits = 0
	r.window.reset(int(windowSize))
	r.seqTables[0] = nil
	r.seqTables[1] = nil
	r.seqTables[2] = nil

	return nil
}
   318  
   319  // skipFrame skips a skippable frame. RFC 3.1.2.
   320  func (r *Reader) skipFrame() error {
   321  	relativeOffset := 0
   322  
   323  	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
   324  		return r.wrapNonEOFError(relativeOffset, err)
   325  	}
   326  
   327  	relativeOffset += 4
   328  
   329  	size := binary.LittleEndian.Uint32(r.scratch[:4])
   330  	if size == 0 {
   331  		r.blockOffset += int64(relativeOffset)
   332  		return nil
   333  	}
   334  
   335  	if seeker, ok := r.r.(io.Seeker); ok {
   336  		r.blockOffset += int64(relativeOffset)
   337  		// Implementations of Seeker do not always detect invalid offsets,
   338  		// so check that the new offset is valid by comparing to the end.
   339  		prev, err := seeker.Seek(0, io.SeekCurrent)
   340  		if err != nil {
   341  			return r.wrapError(0, err)
   342  		}
   343  		end, err := seeker.Seek(0, io.SeekEnd)
   344  		if err != nil {
   345  			return r.wrapError(0, err)
   346  		}
   347  		if prev > end-int64(size) {
   348  			r.blockOffset += end - prev
   349  			return r.makeEOFError(0)
   350  		}
   351  
   352  		// The new offset is valid, so seek to it.
   353  		_, err = seeker.Seek(prev+int64(size), io.SeekStart)
   354  		if err != nil {
   355  			return r.wrapError(0, err)
   356  		}
   357  		r.blockOffset += int64(size)
   358  		return nil
   359  	}
   360  
   361  	n, err := io.CopyN(io.Discard, r.r, int64(size))
   362  	relativeOffset += int(n)
   363  	if err != nil {
   364  		return r.wrapNonEOFError(relativeOffset, err)
   365  	}
   366  	r.blockOffset += int64(relativeOffset)
   367  	return nil
   368  }
   369  
// readBlock reads the next block from a frame.
//
// On success r.buffer holds the block's decompressed bytes and
// r.blockOffset has been advanced past the block (and past the frame
// checksum, if this was the last block of a frame with a checksum).
// The caller is responsible for resetting r.off.
//
// After the last block of a frame, r.sawFrameHeader is cleared so that
// the next refill reads a new frame header.
func (r *Reader) readBlock() error {
	// relativeOffset counts bytes consumed since r.blockOffset.
	relativeOffset := 0

	// Read Block_Header. RFC 3.1.1.2.
	if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}

	relativeOffset += 3

	// The header is a 24-bit little-endian value.
	header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16)

	// Bit 0 marks the last block, bits 1-2 are the block type,
	// and the remaining 21 bits are the block size.
	lastBlock := header&1 != 0
	blockType := (header >> 1) & 3
	blockSize := int(header >> 3)

	// Maximum block size is smaller of window size and 128K.
	// We don't record the window size for a single segment frame,
	// so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4.
	// The window size limit is only applied when r.window.size is positive.
	if blockSize > 128<<10 || (r.window.size > 0 && blockSize > r.window.size) {
		return r.makeError(relativeOffset, "block size too large")
	}

	// Handle different block types. RFC 3.1.1.2.2.
	switch blockType {
	case 0:
		// Raw block: blockSize bytes of input are the output.
		r.setBufferSize(blockSize)
		if _, err := io.ReadFull(r.r, r.buffer); err != nil {
			return r.wrapNonEOFError(relativeOffset, err)
		}
		relativeOffset += blockSize
		r.blockOffset += int64(relativeOffset)
	case 1:
		// Run-length block: a single input byte repeated blockSize times.
		r.setBufferSize(blockSize)
		if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
			return r.wrapNonEOFError(relativeOffset, err)
		}
		relativeOffset++
		v := r.scratch[0]
		for i := range r.buffer {
			r.buffer[i] = v
		}
		r.blockOffset += int64(relativeOffset)
	case 2:
		// Compressed block: blockSize is the compressed size.
		// blockOffset is moved past the block header first, before
		// compressedBlock runs, and past the block contents afterwards.
		r.blockOffset += int64(relativeOffset)
		if err := r.compressedBlock(blockSize); err != nil {
			return err
		}
		r.blockOffset += int64(blockSize)
	case 3:
		return r.makeError(relativeOffset, "invalid block type")
	}

	// When the frame header declared a content size, make sure we
	// do not produce more data than it promised.
	if !r.frameSizeUnknown {
		if uint64(len(r.buffer)) > r.remainingFrameSize {
			return r.makeError(relativeOffset, "too many uncompressed bytes in frame")
		}
		r.remainingFrameSize -= uint64(len(r.buffer))
	}

	if r.hasChecksum {
		r.checksum.update(r.buffer)
	}

	if !lastBlock {
		// Later blocks in this frame may refer back to this data.
		// After the last block this is unnecessary, because the next
		// frame header resets the window.
		r.window.save(r.buffer)
	} else {
		// The declared content size must be matched exactly.
		if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
			return r.makeError(relativeOffset, "not enough uncompressed bytes for frame")
		}
		// Check for checksum at end of frame. RFC 3.1.1.
		if r.hasChecksum {
			if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
				return r.wrapNonEOFError(0, err)
			}

			// The stored checksum is four bytes, little-endian; it is
			// compared against the digest converted to uint32.
			inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
			dataChecksum := uint32(r.checksum.digest())
			if inputChecksum != dataChecksum {
				return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
			}

			r.blockOffset += 4
		}
		// The frame is finished; the next refill reads a frame header.
		r.sawFrameHeader = false
	}

	return nil
}
   460  
   461  // setBufferSize sets the decompressed buffer size.
   462  // When this is called the buffer is empty.
   463  func (r *Reader) setBufferSize(size int) {
   464  	if cap(r.buffer) < size {
   465  		need := size - cap(r.buffer)
   466  		r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...)
   467  	}
   468  	r.buffer = r.buffer[:size]
   469  }
   470  
   471  // zstdError is an error while decompressing.
   472  type zstdError struct {
   473  	offset int64
   474  	err    error
   475  }
   476  
   477  func (ze *zstdError) Error() string {
   478  	return fmt.Sprintf("zstd decompression error at %d: %v", ze.offset, ze.err)
   479  }
   480  
   481  func (ze *zstdError) Unwrap() error {
   482  	return ze.err
   483  }
   484  
   485  func (r *Reader) makeEOFError(off int) error {
   486  	return r.wrapError(off, io.ErrUnexpectedEOF)
   487  }
   488  
   489  func (r *Reader) wrapNonEOFError(off int, err error) error {
   490  	if err == io.EOF {
   491  		err = io.ErrUnexpectedEOF
   492  	}
   493  	return r.wrapError(off, err)
   494  }
   495  
   496  func (r *Reader) makeError(off int, msg string) error {
   497  	return r.wrapError(off, errors.New(msg))
   498  }
   499  
   500  func (r *Reader) wrapError(off int, err error) error {
   501  	if err == io.EOF {
   502  		return err
   503  	}
   504  	return &zstdError{r.blockOffset + int64(off), err}
   505  }
   506  

View as plain text