Source file src/compress/flate/deflate_test.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package flate
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"internal/testenv"
    12  	"io"
    13  	"math/rand"
    14  	"os"
    15  	"reflect"
    16  	"runtime/debug"
    17  	"sync"
    18  	"testing"
    19  )
    20  
    21  type deflateTest struct {
    22  	in    []byte
    23  	level int
    24  	out   []byte
    25  }
    26  
    27  type deflateInflateTest struct {
    28  	in []byte
    29  }
    30  
    31  type reverseBitsTest struct {
    32  	in       uint16
    33  	bitCount uint8
    34  	out      uint16
    35  }
    36  
// deflateTests pins the exact compressed output for small inputs at several
// levels. Every expected output ends in 0, 0, 255, 255; presumably this is
// the empty stored block written when the stream is flushed/closed (verify
// against deflate.go). The level-0 entries look like RFC 1951 stored blocks:
// a header byte, LEN and its one's complement NLEN as little-endian uint16s,
// then the raw input bytes.
var deflateTests = []*deflateTest{
	{[]byte{}, 0, []byte{1, 0, 0, 255, 255}},
	{[]byte{0x11}, -1, []byte{18, 4, 4, 0, 0, 255, 255}},
	{[]byte{0x11}, DefaultCompression, []byte{18, 4, 4, 0, 0, 255, 255}},
	{[]byte{0x11}, 4, []byte{18, 4, 4, 0, 0, 255, 255}},

	{[]byte{0x11}, 0, []byte{0, 1, 0, 254, 255, 17, 1, 0, 0, 255, 255}},
	{[]byte{0x11, 0x12}, 0, []byte{0, 2, 0, 253, 255, 17, 18, 1, 0, 0, 255, 255}},
	{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}, 0,
		[]byte{0, 8, 0, 247, 255, 17, 17, 17, 17, 17, 17, 17, 17, 1, 0, 0, 255, 255},
	},
	{[]byte{}, 2, []byte{1, 0, 0, 255, 255}},
	{[]byte{0x11}, 2, []byte{18, 4, 4, 0, 0, 255, 255}},
	{[]byte{0x11, 0x12}, 2, []byte{18, 20, 2, 4, 0, 0, 255, 255}},
	{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}, 2, []byte{18, 132, 2, 64, 0, 0, 0, 255, 255}},
	{[]byte{}, 9, []byte{1, 0, 0, 255, 255}},
	{[]byte{0x11}, 9, []byte{18, 4, 4, 0, 0, 255, 255}},
	{[]byte{0x11, 0x12}, 9, []byte{18, 20, 2, 4, 0, 0, 255, 255}},
	{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}, 9, []byte{18, 132, 2, 64, 0, 0, 0, 255, 255}},
}

// deflateInflateTests lists inputs that TestDeflateInflate round-trips
// through every compression level; the output must equal the input.
var deflateInflateTests = []*deflateInflateTest{
	{[]byte{}},
	{[]byte{0x11}},
	{[]byte{0x11, 0x12}},
	{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}},
	{[]byte{0x11, 0x10, 0x13, 0x41, 0x21, 0x21, 0x41, 0x13, 0x87, 0x78, 0x13}},
	{largeDataChunk()},
}

// reverseBitsTests: {in, bitCount, out} means reversing the low bitCount
// bits of in yields out (e.g. 29 = 0b11101 reversed over 5 bits is
// 0b10111 = 23).
var reverseBitsTests = []*reverseBitsTest{
	{1, 1, 1},
	{1, 2, 2},
	{1, 3, 4},
	{1, 4, 8},
	{1, 5, 16},
	{17, 5, 17},
	{257, 9, 257},
	{29, 5, 23},
}
    77  
    78  func largeDataChunk() []byte {
    79  	result := make([]byte, 100000)
    80  	for i := range result {
    81  		result[i] = byte(i * i & 0xFF)
    82  	}
    83  	return result
    84  }
    85  
    86  func TestBulkHash4(t *testing.T) {
    87  	for _, x := range deflateTests {
    88  		y := x.out
    89  		if len(y) < minMatchLength {
    90  			continue
    91  		}
    92  		y = append(y, y...)
    93  		for j := 4; j < len(y); j++ {
    94  			y := y[:j]
    95  			dst := make([]uint32, len(y)-minMatchLength+1)
    96  			for i := range dst {
    97  				dst[i] = uint32(i + 100)
    98  			}
    99  			bulkHash4(y, dst)
   100  			for i, got := range dst {
   101  				want := hash4(y[i:])
   102  				if got != want && got == uint32(i)+100 {
   103  					t.Errorf("Len:%d Index:%d, want 0x%08x but not modified", len(y), i, want)
   104  				} else if got != want {
   105  					t.Errorf("Len:%d Index:%d, got 0x%08x want:0x%08x", len(y), i, got, want)
   106  				}
   107  			}
   108  		}
   109  	}
   110  }
   111  
   112  func TestDeflate(t *testing.T) {
   113  	for _, h := range deflateTests {
   114  		var buf bytes.Buffer
   115  		w, err := NewWriter(&buf, h.level)
   116  		if err != nil {
   117  			t.Errorf("NewWriter: %v", err)
   118  			continue
   119  		}
   120  		w.Write(h.in)
   121  		w.Close()
   122  		if !bytes.Equal(buf.Bytes(), h.out) {
   123  			t.Errorf("Deflate(%d, %x) = \n%#v, want \n%#v", h.level, h.in, buf.Bytes(), h.out)
   124  		}
   125  	}
   126  }
   127  
   128  func TestWriterClose(t *testing.T) {
   129  	b := new(bytes.Buffer)
   130  	zw, err := NewWriter(b, 6)
   131  	if err != nil {
   132  		t.Fatalf("NewWriter: %v", err)
   133  	}
   134  
   135  	if c, err := zw.Write([]byte("Test")); err != nil || c != 4 {
   136  		t.Fatalf("Write to not closed writer: %s, %d", err, c)
   137  	}
   138  
   139  	if err := zw.Close(); err != nil {
   140  		t.Fatalf("Close: %v", err)
   141  	}
   142  
   143  	afterClose := b.Len()
   144  
   145  	if c, err := zw.Write([]byte("Test")); err == nil || c != 0 {
   146  		t.Fatalf("Write to closed writer: %v, %d", err, c)
   147  	}
   148  
   149  	if err := zw.Flush(); err == nil {
   150  		t.Fatalf("Flush to closed writer: %s", err)
   151  	}
   152  
   153  	if err := zw.Close(); err != nil {
   154  		t.Fatalf("Close: %v", err)
   155  	}
   156  
   157  	if afterClose != b.Len() {
   158  		t.Fatalf("Writer wrote data after close. After close: %d. After writes on closed stream: %d", afterClose, b.Len())
   159  	}
   160  }
   161  
   162  // A sparseReader returns a stream consisting of 0s followed by 1<<16 1s.
   163  // This tests missing hash references in a very large input.
   164  type sparseReader struct {
   165  	l   int64
   166  	cur int64
   167  }
   168  
   169  func (r *sparseReader) Read(b []byte) (n int, err error) {
   170  	if r.cur >= r.l {
   171  		return 0, io.EOF
   172  	}
   173  	n = len(b)
   174  	cur := r.cur + int64(n)
   175  	if cur > r.l {
   176  		n -= int(cur - r.l)
   177  		cur = r.l
   178  	}
   179  	for i := range b[0:n] {
   180  		if r.cur+int64(i) >= r.l-1<<16 {
   181  			b[i] = 1
   182  		} else {
   183  			b[i] = 0
   184  		}
   185  	}
   186  	r.cur = cur
   187  	return
   188  }
   189  
   190  func TestVeryLongSparseChunk(t *testing.T) {
   191  	if testing.Short() {
   192  		t.Skip("skipping sparse chunk during short test")
   193  	}
   194  	w, err := NewWriter(io.Discard, 1)
   195  	if err != nil {
   196  		t.Errorf("NewWriter: %v", err)
   197  		return
   198  	}
   199  	if _, err = io.Copy(w, &sparseReader{l: 23e8}); err != nil {
   200  		t.Errorf("Compress failed: %v", err)
   201  		return
   202  	}
   203  }
   204  
   205  type syncBuffer struct {
   206  	buf    bytes.Buffer
   207  	mu     sync.RWMutex
   208  	closed bool
   209  	ready  chan bool
   210  }
   211  
   212  func newSyncBuffer() *syncBuffer {
   213  	return &syncBuffer{ready: make(chan bool, 1)}
   214  }
   215  
   216  func (b *syncBuffer) Read(p []byte) (n int, err error) {
   217  	for {
   218  		b.mu.RLock()
   219  		n, err = b.buf.Read(p)
   220  		b.mu.RUnlock()
   221  		if n > 0 || b.closed {
   222  			return
   223  		}
   224  		<-b.ready
   225  	}
   226  }
   227  
   228  func (b *syncBuffer) signal() {
   229  	select {
   230  	case b.ready <- true:
   231  	default:
   232  	}
   233  }
   234  
   235  func (b *syncBuffer) Write(p []byte) (n int, err error) {
   236  	n, err = b.buf.Write(p)
   237  	b.signal()
   238  	return
   239  }
   240  
   241  func (b *syncBuffer) WriteMode() {
   242  	b.mu.Lock()
   243  }
   244  
   245  func (b *syncBuffer) ReadMode() {
   246  	b.mu.Unlock()
   247  	b.signal()
   248  }
   249  
   250  func (b *syncBuffer) Close() error {
   251  	b.closed = true
   252  	b.signal()
   253  	return nil
   254  }
   255  
   256  func testSync(t *testing.T, level int, input []byte, name string) {
   257  	if len(input) == 0 {
   258  		return
   259  	}
   260  
   261  	t.Logf("--testSync %d, %d, %s", level, len(input), name)
   262  	buf := newSyncBuffer()
   263  	buf1 := new(bytes.Buffer)
   264  	buf.WriteMode()
   265  	w, err := NewWriter(io.MultiWriter(buf, buf1), level)
   266  	if err != nil {
   267  		t.Errorf("NewWriter: %v", err)
   268  		return
   269  	}
   270  	r := NewReader(buf)
   271  
   272  	// Write half the input and read back.
   273  	for i := 0; i < 2; i++ {
   274  		var lo, hi int
   275  		if i == 0 {
   276  			lo, hi = 0, (len(input)+1)/2
   277  		} else {
   278  			lo, hi = (len(input)+1)/2, len(input)
   279  		}
   280  		t.Logf("#%d: write %d-%d", i, lo, hi)
   281  		if _, err := w.Write(input[lo:hi]); err != nil {
   282  			t.Errorf("testSync: write: %v", err)
   283  			return
   284  		}
   285  		if i == 0 {
   286  			if err := w.Flush(); err != nil {
   287  				t.Errorf("testSync: flush: %v", err)
   288  				return
   289  			}
   290  		} else {
   291  			if err := w.Close(); err != nil {
   292  				t.Errorf("testSync: close: %v", err)
   293  			}
   294  		}
   295  		buf.ReadMode()
   296  		out := make([]byte, hi-lo+1)
   297  		m, err := io.ReadAtLeast(r, out, hi-lo)
   298  		t.Logf("#%d: read %d", i, m)
   299  		if m != hi-lo || err != nil {
   300  			t.Errorf("testSync/%d (%d, %d, %s): read %d: %d, %v (%d left)", i, level, len(input), name, hi-lo, m, err, buf.buf.Len())
   301  			return
   302  		}
   303  		if !bytes.Equal(input[lo:hi], out[:hi-lo]) {
   304  			t.Errorf("testSync/%d: read wrong bytes: %x vs %x", i, input[lo:hi], out[:hi-lo])
   305  			return
   306  		}
   307  		// This test originally checked that after reading
   308  		// the first half of the input, there was nothing left
   309  		// in the read buffer (buf.buf.Len() != 0) but that is
   310  		// not necessarily the case: the write Flush may emit
   311  		// some extra framing bits that are not necessary
   312  		// to process to obtain the first half of the uncompressed
   313  		// data. The test ran correctly most of the time, because
   314  		// the background goroutine had usually read even
   315  		// those extra bits by now, but it's not a useful thing to
   316  		// check.
   317  		buf.WriteMode()
   318  	}
   319  	buf.ReadMode()
   320  	out := make([]byte, 10)
   321  	if n, err := r.Read(out); n > 0 || err != io.EOF {
   322  		t.Errorf("testSync (%d, %d, %s): final Read: %d, %v (hex: %x)", level, len(input), name, n, err, out[0:n])
   323  	}
   324  	if buf.buf.Len() != 0 {
   325  		t.Errorf("testSync (%d, %d, %s): extra data at end", level, len(input), name)
   326  	}
   327  	r.Close()
   328  
   329  	// stream should work for ordinary reader too
   330  	r = NewReader(buf1)
   331  	out, err = io.ReadAll(r)
   332  	if err != nil {
   333  		t.Errorf("testSync: read: %s", err)
   334  		return
   335  	}
   336  	r.Close()
   337  	if !bytes.Equal(input, out) {
   338  		t.Errorf("testSync: decompress(compress(data)) != data: level=%d input=%s", level, name)
   339  	}
   340  }
   341  
   342  func testToFromWithLevelAndLimit(t *testing.T, level int, input []byte, name string, limit int) {
   343  	var buffer bytes.Buffer
   344  	w, err := NewWriter(&buffer, level)
   345  	if err != nil {
   346  		t.Errorf("NewWriter: %v", err)
   347  		return
   348  	}
   349  	w.Write(input)
   350  	w.Close()
   351  	if limit > 0 && buffer.Len() > limit {
   352  		t.Errorf("level: %d, len(compress(data)) = %d > limit = %d", level, buffer.Len(), limit)
   353  		return
   354  	}
   355  	if limit > 0 {
   356  		t.Logf("level: %d, size:%.2f%%, %d b\n", level, float64(buffer.Len()*100)/float64(limit), buffer.Len())
   357  	}
   358  	r := NewReader(&buffer)
   359  	out, err := io.ReadAll(r)
   360  	if err != nil {
   361  		t.Errorf("read: %s", err)
   362  		return
   363  	}
   364  	r.Close()
   365  	if !bytes.Equal(input, out) {
   366  		t.Errorf("decompress(compress(data)) != data: level=%d input=%s", level, name)
   367  		return
   368  	}
   369  	testSync(t, level, input, name)
   370  }
   371  
   372  func testToFromWithLimit(t *testing.T, input []byte, name string, limit [11]int) {
   373  	for i := 0; i < 10; i++ {
   374  		testToFromWithLevelAndLimit(t, i, input, name, limit[i])
   375  	}
   376  	// Test HuffmanCompression
   377  	testToFromWithLevelAndLimit(t, -2, input, name, limit[10])
   378  }
   379  
   380  func TestDeflateInflate(t *testing.T) {
   381  	t.Parallel()
   382  	for i, h := range deflateInflateTests {
   383  		if testing.Short() && len(h.in) > 10000 {
   384  			continue
   385  		}
   386  		testToFromWithLimit(t, h.in, fmt.Sprintf("#%d", i), [11]int{})
   387  	}
   388  }
   389  
   390  func TestReverseBits(t *testing.T) {
   391  	for _, h := range reverseBitsTests {
   392  		if v := reverseBits(h.in, h.bitCount); v != h.out {
   393  			t.Errorf("reverseBits(%v,%v) = %v, want %v",
   394  				h.in, h.bitCount, v, h.out)
   395  		}
   396  	}
   397  }
   398  
// deflateInflateStringTest names a test data file (path relative to this
// package's directory) and a label. It also gives the compressed-size limits
// for testToFromWithLimit: limit[0..9] apply to levels 0-9 and limit[10] to
// level -2 (Huffman-only).
type deflateInflateStringTest struct {
	filename string
	label    string
	limit    [11]int
}

// deflateInflateStringTests are the text files round-tripped by
// TestDeflateInflateString, with their recorded size limits.
var deflateInflateStringTests = []deflateInflateStringTest{
	{
		"../testdata/e.txt",
		"2.718281828...",
		[...]int{100018, 50650, 50960, 51150, 50930, 50790, 50790, 50790, 50790, 50790, 43683},
	},
	{
		"../../testdata/Isaac.Newton-Opticks.txt",
		"Isaac.Newton-Opticks",
		[...]int{567248, 218338, 198211, 193152, 181100, 175427, 175427, 173597, 173422, 173422, 325240},
	},
}
   417  
   418  func TestDeflateInflateString(t *testing.T) {
   419  	t.Parallel()
   420  	if testing.Short() && testenv.Builder() == "" {
   421  		t.Skip("skipping in short mode")
   422  	}
   423  	for _, test := range deflateInflateStringTests {
   424  		gold, err := os.ReadFile(test.filename)
   425  		if err != nil {
   426  			t.Error(err)
   427  		}
   428  		testToFromWithLimit(t, gold, test.label, test.limit)
   429  		if testing.Short() {
   430  			break
   431  		}
   432  	}
   433  }
   434  
   435  func TestReaderDict(t *testing.T) {
   436  	const (
   437  		dict = "hello world"
   438  		text = "hello again world"
   439  	)
   440  	var b bytes.Buffer
   441  	w, err := NewWriter(&b, 5)
   442  	if err != nil {
   443  		t.Fatalf("NewWriter: %v", err)
   444  	}
   445  	w.Write([]byte(dict))
   446  	w.Flush()
   447  	b.Reset()
   448  	w.Write([]byte(text))
   449  	w.Close()
   450  
   451  	r := NewReaderDict(&b, []byte(dict))
   452  	data, err := io.ReadAll(r)
   453  	if err != nil {
   454  		t.Fatal(err)
   455  	}
   456  	if string(data) != "hello again world" {
   457  		t.Fatalf("read returned %q want %q", string(data), text)
   458  	}
   459  }
   460  
   461  func TestWriterDict(t *testing.T) {
   462  	const (
   463  		dict = "hello world"
   464  		text = "hello again world"
   465  	)
   466  	var b bytes.Buffer
   467  	w, err := NewWriter(&b, 5)
   468  	if err != nil {
   469  		t.Fatalf("NewWriter: %v", err)
   470  	}
   471  	w.Write([]byte(dict))
   472  	w.Flush()
   473  	b.Reset()
   474  	w.Write([]byte(text))
   475  	w.Close()
   476  
   477  	var b1 bytes.Buffer
   478  	w, _ = NewWriterDict(&b1, 5, []byte(dict))
   479  	w.Write([]byte(text))
   480  	w.Close()
   481  
   482  	if !bytes.Equal(b1.Bytes(), b.Bytes()) {
   483  		t.Fatalf("writer wrote %q want %q", b1.Bytes(), b.Bytes())
   484  	}
   485  }
   486  
   487  // See https://golang.org/issue/2508
   488  func TestRegression2508(t *testing.T) {
   489  	if testing.Short() {
   490  		t.Logf("test disabled with -short")
   491  		return
   492  	}
   493  	w, err := NewWriter(io.Discard, 1)
   494  	if err != nil {
   495  		t.Fatalf("NewWriter: %v", err)
   496  	}
   497  	buf := make([]byte, 1024)
   498  	for i := 0; i < 131072; i++ {
   499  		if _, err := w.Write(buf); err != nil {
   500  			t.Fatalf("writer failed: %v", err)
   501  		}
   502  	}
   503  	w.Close()
   504  }
   505  
   506  func TestWriterReset(t *testing.T) {
   507  	t.Parallel()
   508  	for level := 0; level <= 9; level++ {
   509  		if testing.Short() && level > 1 {
   510  			break
   511  		}
   512  		w, err := NewWriter(io.Discard, level)
   513  		if err != nil {
   514  			t.Fatalf("NewWriter: %v", err)
   515  		}
   516  		buf := []byte("hello world")
   517  		n := 1024
   518  		if testing.Short() {
   519  			n = 10
   520  		}
   521  		for i := 0; i < n; i++ {
   522  			w.Write(buf)
   523  		}
   524  		w.Reset(io.Discard)
   525  
   526  		wref, err := NewWriter(io.Discard, level)
   527  		if err != nil {
   528  			t.Fatalf("NewWriter: %v", err)
   529  		}
   530  
   531  		// DeepEqual doesn't compare functions.
   532  		w.d.fill, wref.d.fill = nil, nil
   533  		w.d.step, wref.d.step = nil, nil
   534  		w.d.bulkHasher, wref.d.bulkHasher = nil, nil
   535  		w.d.bestSpeed, wref.d.bestSpeed = nil, nil
   536  		// hashMatch is always overwritten when used.
   537  		copy(w.d.hashMatch[:], wref.d.hashMatch[:])
   538  		if len(w.d.tokens) != 0 {
   539  			t.Errorf("level %d Writer not reset after Reset. %d tokens were present", level, len(w.d.tokens))
   540  		}
   541  		// As long as the length is 0, we don't care about the content.
   542  		w.d.tokens = wref.d.tokens
   543  
   544  		// We don't care if there are values in the window, as long as it is at d.index is 0
   545  		w.d.window = wref.d.window
   546  		if !reflect.DeepEqual(w, wref) {
   547  			t.Errorf("level %d Writer not reset after Reset", level)
   548  		}
   549  	}
   550  
   551  	levels := []int{0, 1, 2, 5, 9}
   552  	for _, level := range levels {
   553  		t.Run(fmt.Sprint(level), func(t *testing.T) {
   554  			testResetOutput(t, level, nil)
   555  		})
   556  	}
   557  
   558  	t.Run("dict", func(t *testing.T) {
   559  		for _, level := range levels {
   560  			t.Run(fmt.Sprint(level), func(t *testing.T) {
   561  				testResetOutput(t, level, nil)
   562  			})
   563  		}
   564  	})
   565  }
   566  
   567  func testResetOutput(t *testing.T, level int, dict []byte) {
   568  	writeData := func(w *Writer) {
   569  		msg := []byte("now is the time for all good gophers")
   570  		w.Write(msg)
   571  		w.Flush()
   572  
   573  		hello := []byte("hello world")
   574  		for i := 0; i < 1024; i++ {
   575  			w.Write(hello)
   576  		}
   577  
   578  		fill := bytes.Repeat([]byte("x"), 65000)
   579  		w.Write(fill)
   580  	}
   581  
   582  	buf := new(bytes.Buffer)
   583  	var w *Writer
   584  	var err error
   585  	if dict == nil {
   586  		w, err = NewWriter(buf, level)
   587  	} else {
   588  		w, err = NewWriterDict(buf, level, dict)
   589  	}
   590  	if err != nil {
   591  		t.Fatalf("NewWriter: %v", err)
   592  	}
   593  
   594  	writeData(w)
   595  	w.Close()
   596  	out1 := buf.Bytes()
   597  
   598  	buf2 := new(bytes.Buffer)
   599  	w.Reset(buf2)
   600  	writeData(w)
   601  	w.Close()
   602  	out2 := buf2.Bytes()
   603  
   604  	if len(out1) != len(out2) {
   605  		t.Errorf("got %d, expected %d bytes", len(out2), len(out1))
   606  		return
   607  	}
   608  	if !bytes.Equal(out1, out2) {
   609  		mm := 0
   610  		for i, b := range out1[:len(out2)] {
   611  			if b != out2[i] {
   612  				t.Errorf("mismatch index %d: %#02x, expected %#02x", i, out2[i], b)
   613  			}
   614  			mm++
   615  			if mm == 10 {
   616  				t.Fatal("Stopping")
   617  			}
   618  		}
   619  	}
   620  	t.Logf("got %d bytes", len(out1))
   621  }
   622  
// TestBestSpeed tests that round-tripping through deflate and then inflate
// recovers the original input. The Write sizes are near the thresholds in the
// compressor.encSpeed method (0, 16, 128), as well as near maxStoreBlockSize
// (65535).
//
// Each test case is a sequence of Write sizes, each writing a prefix of
// abcabc. The first size is replaced in turn by every firstN value, and each
// sequence is run both without and with a Flush after every Write. In short
// mode only the first three test cases run.
func TestBestSpeed(t *testing.T) {
	t.Parallel()
	// abcabc is the byte ramp 0..127 repeated to a total of 131072 bytes.
	abc := make([]byte, 128)
	for i := range abc {
		abc[i] = byte(i)
	}
	abcabc := bytes.Repeat(abc, 131072/len(abc))
	// want accumulates everything written in one run; it is truncated, not
	// reallocated, at the start of each run.
	var want []byte

	testCases := [][]int{
		{65536, 0},
		{65536, 1},
		{65536, 1, 256},
		{65536, 1, 65536},
		{65536, 14},
		{65536, 15},
		{65536, 16},
		{65536, 16, 256},
		{65536, 16, 65536},
		{65536, 127},
		{65536, 128},
		{65536, 128, 256},
		{65536, 128, 65536},
		{65536, 129},
		{65536, 65536, 256},
		{65536, 65536, 65536},
	}

	for i, tc := range testCases {
		if i >= 3 && testing.Short() {
			break
		}
		for _, firstN := range []int{1, 65534, 65535, 65536, 65537, 131072} {
			// Overwrite the leading Write size of this case in place.
			tc[0] = firstN
		outer:
			for _, flush := range []bool{false, true} {
				buf := new(bytes.Buffer)
				want = want[:0]

				w, err := NewWriter(buf, BestSpeed)
				if err != nil {
					t.Errorf("i=%d, firstN=%d, flush=%t: NewWriter: %v", i, firstN, flush, err)
					continue
				}
				for _, n := range tc {
					want = append(want, abcabc[:n]...)
					if _, err := w.Write(abcabc[:n]); err != nil {
						t.Errorf("i=%d, firstN=%d, flush=%t: Write: %v", i, firstN, flush, err)
						continue outer
					}
					if !flush {
						continue
					}
					if err := w.Flush(); err != nil {
						t.Errorf("i=%d, firstN=%d, flush=%t: Flush: %v", i, firstN, flush, err)
						continue outer
					}
				}
				if err := w.Close(); err != nil {
					t.Errorf("i=%d, firstN=%d, flush=%t: Close: %v", i, firstN, flush, err)
					continue
				}

				r := NewReader(buf)
				got, err := io.ReadAll(r)
				if err != nil {
					t.Errorf("i=%d, firstN=%d, flush=%t: ReadAll: %v", i, firstN, flush, err)
					continue
				}
				r.Close()

				if !bytes.Equal(got, want) {
					t.Errorf("i=%d, firstN=%d, flush=%t: corruption during deflate-then-inflate", i, firstN, flush)
					continue
				}
			}
		}
	}
}
   706  
   707  var errIO = errors.New("IO error")
   708  
   709  // failWriter fails with errIO exactly at the nth call to Write.
   710  type failWriter struct{ n int }
   711  
   712  func (w *failWriter) Write(b []byte) (int, error) {
   713  	w.n--
   714  	if w.n == -1 {
   715  		return 0, errIO
   716  	}
   717  	return len(b), nil
   718  }
   719  
   720  func TestWriterPersistentWriteError(t *testing.T) {
   721  	t.Parallel()
   722  	d, err := os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt")
   723  	if err != nil {
   724  		t.Fatalf("ReadFile: %v", err)
   725  	}
   726  	d = d[:10000] // Keep this test short
   727  
   728  	zw, err := NewWriter(nil, DefaultCompression)
   729  	if err != nil {
   730  		t.Fatalf("NewWriter: %v", err)
   731  	}
   732  
   733  	// Sweep over the threshold at which an error is returned.
   734  	// The variable i makes it such that the ith call to failWriter.Write will
   735  	// return errIO. Since failWriter errors are not persistent, we must ensure
   736  	// that flate.Writer errors are persistent.
   737  	for i := 0; i < 1000; i++ {
   738  		fw := &failWriter{i}
   739  		zw.Reset(fw)
   740  
   741  		_, werr := zw.Write(d)
   742  		cerr := zw.Close()
   743  		ferr := zw.Flush()
   744  		if werr != errIO && werr != nil {
   745  			t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO)
   746  		}
   747  		if cerr != errIO && fw.n < 0 {
   748  			t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO)
   749  		}
   750  		if ferr != errIO && fw.n < 0 {
   751  			t.Errorf("test %d, mismatching Flush error: got %v, want %v", i, ferr, errIO)
   752  		}
   753  		if fw.n >= 0 {
   754  			// At this point, the failure threshold was sufficiently high enough
   755  			// that we wrote the whole stream without any errors.
   756  			return
   757  		}
   758  	}
   759  }
   760  func TestWriterPersistentFlushError(t *testing.T) {
   761  	zw, err := NewWriter(&failWriter{0}, DefaultCompression)
   762  	if err != nil {
   763  		t.Fatalf("NewWriter: %v", err)
   764  	}
   765  	flushErr := zw.Flush()
   766  	closeErr := zw.Close()
   767  	_, writeErr := zw.Write([]byte("Test"))
   768  	checkErrors([]error{closeErr, flushErr, writeErr}, errIO, t)
   769  }
   770  
   771  func TestWriterPersistentCloseError(t *testing.T) {
   772  	// If underlying writer return error on closing stream we should persistent this error across all writer calls.
   773  	zw, err := NewWriter(&failWriter{0}, DefaultCompression)
   774  	if err != nil {
   775  		t.Fatalf("NewWriter: %v", err)
   776  	}
   777  	closeErr := zw.Close()
   778  	flushErr := zw.Flush()
   779  	_, writeErr := zw.Write([]byte("Test"))
   780  	checkErrors([]error{closeErr, flushErr, writeErr}, errIO, t)
   781  
   782  	// After closing writer we should persistent "write after close" error across Flush and Write calls, but return nil
   783  	// on next Close calls.
   784  	var b bytes.Buffer
   785  	zw.Reset(&b)
   786  	err = zw.Close()
   787  	if err != nil {
   788  		t.Fatalf("First call to close returned error: %s", err)
   789  	}
   790  	err = zw.Close()
   791  	if err != nil {
   792  		t.Fatalf("Second call to close returned error: %s", err)
   793  	}
   794  
   795  	flushErr = zw.Flush()
   796  	_, writeErr = zw.Write([]byte("Test"))
   797  	checkErrors([]error{flushErr, writeErr}, errWriterClosed, t)
   798  }
   799  
   800  func checkErrors(got []error, want error, t *testing.T) {
   801  	t.Helper()
   802  	for _, err := range got {
   803  		if err != want {
   804  			t.Errorf("Error doesn't match\nWant: %s\nGot: %s", want, got)
   805  		}
   806  	}
   807  }
   808  
// TestBestSpeedMatch checks deflateFast.matchLen on hand-built cases. Each
// case calls matchLen(s, t, current) with the encoder's prev set to previous
// and compares the result with want. Negative t values appear to index into
// previous from its end (t = -len(previous) is its first byte), with the
// match continuing into current; e.g. the first case matches 0,1,2 from
// previous and then 3,4,5 from current for a length of 6. (Interpretation
// inferred from the cases; confirm against matchLen.) Several cases show the
// result capped at maxMatchLength-4.
func TestBestSpeedMatch(t *testing.T) {
	t.Parallel()
	cases := []struct {
		previous, current []byte
		t, s, want        int32
	}{{
		previous: []byte{0, 0, 0, 1, 2},
		current:  []byte{3, 4, 5, 0, 1, 2, 3, 4, 5},
		t:        -3,
		s:        3,
		want:     6,
	}, {
		previous: []byte{0, 0, 0, 1, 2},
		current:  []byte{2, 4, 5, 0, 1, 2, 3, 4, 5},
		t:        -3,
		s:        3,
		want:     3,
	}, {
		previous: []byte{0, 0, 0, 1, 1},
		current:  []byte{3, 4, 5, 0, 1, 2, 3, 4, 5},
		t:        -3,
		s:        3,
		want:     2,
	}, {
		previous: []byte{0, 0, 0, 1, 2},
		current:  []byte{2, 2, 2, 2, 1, 2, 3, 4, 5},
		t:        -1,
		s:        0,
		want:     4,
	}, {
		previous: []byte{0, 0, 0, 1, 2, 3, 4, 5, 2, 2},
		current:  []byte{2, 2, 2, 2, 1, 2, 3, 4, 5},
		t:        -7,
		s:        4,
		want:     5,
	}, {
		previous: []byte{9, 9, 9, 9, 9},
		current:  []byte{2, 2, 2, 2, 1, 2, 3, 4, 5},
		t:        -1,
		s:        0,
		want:     0,
	}, {
		previous: []byte{9, 9, 9, 9, 9},
		current:  []byte{9, 2, 2, 2, 1, 2, 3, 4, 5},
		t:        0,
		s:        1,
		want:     0,
	}, {
		previous: []byte{},
		current:  []byte{9, 2, 2, 2, 1, 2, 3, 4, 5},
		t:        -5,
		s:        1,
		want:     0,
	}, {
		previous: []byte{},
		current:  []byte{9, 2, 2, 2, 1, 2, 3, 4, 5},
		t:        -1,
		s:        1,
		want:     0,
	}, {
		previous: []byte{},
		current:  []byte{2, 2, 2, 2, 1, 2, 3, 4, 5},
		t:        0,
		s:        1,
		want:     3,
	}, {
		previous: []byte{3, 4, 5},
		current:  []byte{3, 4, 5},
		t:        -3,
		s:        0,
		want:     3,
	}, {
		previous: make([]byte, 1000),
		current:  make([]byte, 1000),
		t:        -1000,
		s:        0,
		want:     maxMatchLength - 4,
	}, {
		previous: make([]byte, 200),
		current:  make([]byte, 500),
		t:        -200,
		s:        0,
		want:     maxMatchLength - 4,
	}, {
		previous: make([]byte, 200),
		current:  make([]byte, 500),
		t:        0,
		s:        1,
		want:     maxMatchLength - 4,
	}, {
		previous: make([]byte, maxMatchLength-4),
		current:  make([]byte, 500),
		t:        -(maxMatchLength - 4),
		s:        0,
		want:     maxMatchLength - 4,
	}, {
		previous: make([]byte, 200),
		current:  make([]byte, 500),
		t:        -200,
		s:        400,
		want:     100,
	}, {
		previous: make([]byte, 10),
		current:  make([]byte, 500),
		t:        200,
		s:        400,
		want:     100,
	}}
	for i, c := range cases {
		e := deflateFast{prev: c.previous}
		got := e.matchLen(c.s, c.t, c.current)
		if got != c.want {
			t.Errorf("Test %d: match length, want %d, got %d", i, c.want, got)
		}
	}
}
   925  
// TestBestSpeedMaxMatchOffset round-trips, at BestSpeed, inputs containing two
// copies of abc placed maxMatchOffset±5 bytes apart, so that any match the
// encoder emits lies right at the edge of the maximum offset. It checks that
// decompression reproduces the input exactly.
func TestBestSpeedMaxMatchOffset(t *testing.T) {
	t.Parallel()
	const abc, xyz = "abcdefgh", "stuvwxyz"
	// matchBefore selects whether the bytes just before the second abc are
	// zeros (true) or xyz (false); extra controls how many zero bytes
	// follow it.
	for _, matchBefore := range []bool{false, true} {
		for _, extra := range []int{0, inputMargin - 1, inputMargin, inputMargin + 1, 2 * inputMargin} {
			for offsetAdj := -5; offsetAdj <= +5; offsetAdj++ {
				report := func(desc string, err error) {
					t.Errorf("matchBefore=%t, extra=%d, offsetAdj=%d: %s%v",
						matchBefore, extra, offsetAdj, desc, err)
				}

				offset := maxMatchOffset + offsetAdj

				// Make src to be a []byte of the form
				//	"%s%s%s%s%s" % (abc, zeros0, xyzMaybe, abc, zeros1)
				// where:
				//	zeros0 is approximately maxMatchOffset zeros.
				//	xyzMaybe is either xyz or the empty string.
				//	zeros1 is between 0 and 30 zeros.
				// The difference between the two abc's will be offset, which
				// is maxMatchOffset plus or minus a small adjustment.
				src := make([]byte, offset+len(abc)+extra)
				copy(src, abc)
				if !matchBefore {
					copy(src[offset-len(xyz):], xyz)
				}
				copy(src[offset:], abc)

				buf := new(bytes.Buffer)
				w, err := NewWriter(buf, BestSpeed)
				if err != nil {
					report("NewWriter: ", err)
					continue
				}
				if _, err := w.Write(src); err != nil {
					report("Write: ", err)
					continue
				}
				if err := w.Close(); err != nil {
					report("Writer.Close: ", err)
					continue
				}

				r := NewReader(buf)
				dst, err := io.ReadAll(r)
				r.Close()
				if err != nil {
					report("ReadAll: ", err)
					continue
				}

				if !bytes.Equal(dst, src) {
					report("", fmt.Errorf("bytes differ after round-tripping"))
					continue
				}
			}
		}
	}
}
   985  
// TestBestSpeedShiftOffsets drives a single deflateFast encoder through a
// fixed sequence of encode calls. The calls are order-dependent because the
// encoder keeps state (cur and its match table) between them.
func TestBestSpeedShiftOffsets(t *testing.T) {
	// Test if shiftoffsets properly preserves matches and resets out-of-range matches
	// seen in https://github.com/golang/go/issues/4142
	enc := newDeflateFast()

	// testData must not contain internal matches; 32 pseudo-random bytes
	// from a fixed seed are used so that only cross-block matches occur.
	testData := make([]byte, 32)
	rng := rand.New(rand.NewSource(0))
	for i := range testData {
		testData[i] = byte(rng.Uint32())
	}

	// Encode the testdata with clean state.
	// Second part should pick up matches from the first block.
	wantFirstTokens := len(enc.encode(nil, testData))
	wantSecondTokens := len(enc.encode(nil, testData))

	if wantFirstTokens <= wantSecondTokens {
		t.Fatalf("test needs matches between inputs to be generated")
	}
	// Forward the current indicator to before wraparound.
	enc.cur = bufferReset - int32(len(testData))

	// Part 1 before wrap, should match clean state.
	got := len(enc.encode(nil, testData))
	if wantFirstTokens != got {
		t.Errorf("got %d, want %d tokens", got, wantFirstTokens)
	}

	// Verify we are about to wrap.
	if enc.cur != bufferReset {
		t.Errorf("got %d, want e.cur to be at bufferReset (%d)", enc.cur, bufferReset)
	}

	// Part 2 should match clean state as well even if wrapped.
	got = len(enc.encode(nil, testData))
	if wantSecondTokens != got {
		t.Errorf("got %d, want %d token", got, wantSecondTokens)
	}

	// Verify that we wrapped.
	if enc.cur >= bufferReset {
		t.Errorf("want e.cur to be < bufferReset (%d), got %d", bufferReset, enc.cur)
	}

	// Forward the current buffer, leaving the matches at the bottom.
	enc.cur = bufferReset
	enc.shiftOffsets()

	// Ensure that no matches were picked up: the previous block's entries
	// should now be out of range, so the token count equals a clean encode.
	got = len(enc.encode(nil, testData))
	if wantFirstTokens != got {
		t.Errorf("got %d, want %d tokens", got, wantFirstTokens)
	}
}
  1041  
  1042  func TestMaxStackSize(t *testing.T) {
  1043  	// This test must not run in parallel with other tests as debug.SetMaxStack
  1044  	// affects all goroutines.
  1045  	n := debug.SetMaxStack(1 << 16)
  1046  	defer debug.SetMaxStack(n)
  1047  
  1048  	var wg sync.WaitGroup
  1049  	defer wg.Wait()
  1050  
  1051  	b := make([]byte, 1<<20)
  1052  	for level := HuffmanOnly; level <= BestCompression; level++ {
  1053  		// Run in separate goroutine to increase probability of stack regrowth.
  1054  		wg.Add(1)
  1055  		go func(level int) {
  1056  			defer wg.Done()
  1057  			zw, err := NewWriter(io.Discard, level)
  1058  			if err != nil {
  1059  				t.Errorf("level %d, NewWriter() = %v, want nil", level, err)
  1060  			}
  1061  			if n, err := zw.Write(b); n != len(b) || err != nil {
  1062  				t.Errorf("level %d, Write() = (%d, %v), want (%d, nil)", level, n, err, len(b))
  1063  			}
  1064  			if err := zw.Close(); err != nil {
  1065  				t.Errorf("level %d, Close() = %v, want nil", level, err)
  1066  			}
  1067  			zw.Reset(io.Discard)
  1068  		}(level)
  1069  	}
  1070  }
  1071  

View as plain text