// Copyright 2023 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package zstd import ( "io" ) // debug can be set in the source to print debug info using println. const debug = false // compressedBlock decompresses a compressed block, storing the decompressed // data in r.buffer. The blockSize argument is the compressed size. // RFC 3.1.1.3. func (r *Reader) compressedBlock(blockSize int) error { if len(r.compressedBuf) >= blockSize { r.compressedBuf = r.compressedBuf[:blockSize] } else { // We know that blockSize <= 128K, // so this won't allocate an enormous amount. need := blockSize - len(r.compressedBuf) r.compressedBuf = append(r.compressedBuf, make([]byte, need)...) } if _, err := io.ReadFull(r.r, r.compressedBuf); err != nil { return r.wrapNonEOFError(0, err) } data := block(r.compressedBuf) off := 0 r.buffer = r.buffer[:0] litoff, litbuf, err := r.readLiterals(data, off, r.literals[:0]) if err != nil { return err } r.literals = litbuf off = litoff seqCount, off, err := r.initSeqs(data, off) if err != nil { return err } if seqCount == 0 { // No sequences, just literals. if off < len(data) { return r.makeError(off, "extraneous data after no sequences") } r.buffer = append(r.buffer, litbuf...) return nil } return r.execSeqs(data, off, litbuf, seqCount) } // seqCode is the kind of sequence codes we have to handle. type seqCode int const ( seqLiteral seqCode = iota seqOffset seqMatch ) // seqCodeInfoData is the information needed to set up seqTables and // seqTableBits for a particular kind of sequence code. type seqCodeInfoData struct { predefTable []fseBaselineEntry // predefined FSE predefTableBits int // number of bits in predefTable maxSym int // max symbol value in FSE maxBits int // max bits for FSE // toBaseline converts from an FSE table to an FSE baseline table. toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error } // seqCodeInfo is the seqCodeInfoData for each kind of sequence code. var seqCodeInfo = [3]seqCodeInfoData{ seqLiteral: { predefTable: predefinedLiteralTable[:], predefTableBits: 6, maxSym: 35, maxBits: 9, toBaseline: (*Reader).makeLiteralBaselineFSE, }, seqOffset: { predefTable: predefinedOffsetTable[:], predefTableBits: 5, maxSym: 31, maxBits: 8, toBaseline: (*Reader).makeOffsetBaselineFSE, }, seqMatch: { predefTable: predefinedMatchTable[:], predefTableBits: 6, maxSym: 52, maxBits: 9, toBaseline: (*Reader).makeMatchBaselineFSE, }, } // initSeqs reads the Sequences_Section_Header and sets up the FSE // tables used to read the sequence codes. It returns the number of // sequences and the new offset. RFC 3.1.1.3.2.1. func (r *Reader) initSeqs(data block, off int) (int, int, error) { if off >= len(data) { return 0, 0, r.makeEOFError(off) } seqHdr := data[off] off++ if seqHdr == 0 { return 0, off, nil } var seqCount int if seqHdr < 128 { seqCount = int(seqHdr) } else if seqHdr < 255 { if off >= len(data) { return 0, 0, r.makeEOFError(off) } seqCount = ((int(seqHdr) - 128) << 8) + int(data[off]) off++ } else { if off+1 >= len(data) { return 0, 0, r.makeEOFError(off) } seqCount = int(data[off]) + (int(data[off+1]) << 8) + 0x7f00 off += 2 } // Read the Symbol_Compression_Modes byte. if off >= len(data) { return 0, 0, r.makeEOFError(off) } symMode := data[off] if symMode&3 != 0 { return 0, 0, r.makeError(off, "invalid symbol compression mode") } off++ // Set up the FSE tables used to decode the sequence codes. var err error off, err = r.setSeqTable(data, off, seqLiteral, (symMode>>6)&3) if err != nil { return 0, 0, err } off, err = r.setSeqTable(data, off, seqOffset, (symMode>>4)&3) if err != nil { return 0, 0, err } off, err = r.setSeqTable(data, off, seqMatch, (symMode>>2)&3) if err != nil { return 0, 0, err } return seqCount, off, nil } // setSeqTable uses the Compression_Mode in mode to set up r.seqTables and // r.seqTableBits for kind. We store these in the Reader because one of // the modes simply reuses the value from the last block in the frame. func (r *Reader) setSeqTable(data block, off int, kind seqCode, mode byte) (int, error) { info := &seqCodeInfo[kind] switch mode { case 0: // Predefined_Mode r.seqTables[kind] = info.predefTable r.seqTableBits[kind] = uint8(info.predefTableBits) return off, nil case 1: // RLE_Mode if off >= len(data) { return 0, r.makeEOFError(off) } rle := data[off] off++ // Build a simple baseline table that always returns rle. entry := []fseEntry{ { sym: rle, bits: 0, base: 0, }, } if cap(r.seqTableBuffers[kind]) == 0 { r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1< 128<<10 { return rbr.makeError("uncompressed size too big") } ptoffset := &r.seqTables[seqOffset][offsetState] ptmatch := &r.seqTables[seqMatch][matchState] ptliteral := &r.seqTables[seqLiteral][literalState] add, err := rbr.val(ptoffset.basebits) if err != nil { return err } offset := ptoffset.baseline + add add, err = rbr.val(ptmatch.basebits) if err != nil { return err } match := ptmatch.baseline + add add, err = rbr.val(ptliteral.basebits) if err != nil { return err } literal := ptliteral.baseline + add // Handle repeat offsets. RFC 3.1.1.5. // See the comment in makeOffsetBaselineFSE. if ptoffset.basebits > 1 { r.repeatedOffset3 = r.repeatedOffset2 r.repeatedOffset2 = r.repeatedOffset1 r.repeatedOffset1 = offset } else { if literal == 0 { offset++ } switch offset { case 1: offset = r.repeatedOffset1 case 2: offset = r.repeatedOffset2 r.repeatedOffset2 = r.repeatedOffset1 r.repeatedOffset1 = offset case 3: offset = r.repeatedOffset3 r.repeatedOffset3 = r.repeatedOffset2 r.repeatedOffset2 = r.repeatedOffset1 r.repeatedOffset1 = offset case 4: offset = r.repeatedOffset1 - 1 r.repeatedOffset3 = r.repeatedOffset2 r.repeatedOffset2 = r.repeatedOffset1 r.repeatedOffset1 = offset } } seq++ if seq < seqCount { // Update the states. add, err = rbr.val(ptliteral.bits) if err != nil { return err } literalState = uint32(ptliteral.base) + add add, err = rbr.val(ptmatch.bits) if err != nil { return err } matchState = uint32(ptmatch.base) + add add, err = rbr.val(ptoffset.bits) if err != nil { return err } offsetState = uint32(ptoffset.base) + add } // The next sequence is now in literal, offset, match. if debug { println("literal", literal, "offset", offset, "match", match) } // Copy literal bytes from litbuf. if literal > uint32(len(litbuf)) { return rbr.makeError("literal byte overflow") } if literal > 0 { r.buffer = append(r.buffer, litbuf[:literal]...) litbuf = litbuf[literal:] } if match > 0 { if err := r.copyFromWindow(&rbr, offset, match); err != nil { return err } } } r.buffer = append(r.buffer, litbuf...) if rbr.cnt != 0 { return r.makeError(off, "extraneous data after sequences") } return nil } // Copy match bytes from the decoded output, or the window, at offset. func (r *Reader) copyFromWindow(rbr *reverseBitReader, offset, match uint32) error { if offset == 0 { return rbr.makeError("invalid zero offset") } // Offset may point into the buffer or the window and // match may extend past the end of the initial buffer. // |--r.window--|--r.buffer--| // |<-----offset------| // |------match----------->| bufferOffset := uint32(0) lenBlock := uint32(len(r.buffer)) if lenBlock < offset { lenWindow := r.window.len() copy := offset - lenBlock if copy > lenWindow { return rbr.makeError("offset past window") } windowOffset := lenWindow - copy if copy > match { copy = match } r.buffer = r.window.appendTo(r.buffer, windowOffset, windowOffset+copy) match -= copy } else { bufferOffset = lenBlock - offset } // We are being asked to copy data that we are adding to the // buffer in the same copy. for match > 0 { copy := uint32(len(r.buffer)) - bufferOffset if copy > match { copy = match } r.buffer = append(r.buffer, r.buffer[bufferOffset:bufferOffset+copy]...) match -= copy } return nil }