Source file src/iter/pull_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package iter_test
     6  
     7  import (
     8  	"fmt"
     9  	. "iter"
    10  	"runtime"
    11  	"testing"
    12  )
    13  
    14  func count(n int) Seq[int] {
    15  	return func(yield func(int) bool) {
    16  		for i := range n {
    17  			if !yield(i) {
    18  				break
    19  			}
    20  		}
    21  	}
    22  }
    23  
    24  func squares(n int) Seq2[int, int64] {
    25  	return func(yield func(int, int64) bool) {
    26  		for i := range n {
    27  			if !yield(i, int64(i)*int64(i)) {
    28  				break
    29  			}
    30  		}
    31  	}
    32  }
    33  
    34  func TestPull(t *testing.T) {
    35  	for end := 0; end <= 3; end++ {
    36  		t.Run(fmt.Sprint(end), func(t *testing.T) {
    37  			ng := stableNumGoroutine()
    38  			wantNG := func(want int) {
    39  				if xg := runtime.NumGoroutine() - ng; xg != want {
    40  					t.Helper()
    41  					t.Errorf("have %d extra goroutines, want %d", xg, want)
    42  				}
    43  			}
    44  			wantNG(0)
    45  			next, stop := Pull(count(3))
    46  			wantNG(1)
    47  			for i := range end {
    48  				v, ok := next()
    49  				if v != i || ok != true {
    50  					t.Fatalf("next() = %d, %v, want %d, %v", v, ok, i, true)
    51  				}
    52  				wantNG(1)
    53  			}
    54  			wantNG(1)
    55  			if end < 3 {
    56  				stop()
    57  				wantNG(0)
    58  			}
    59  			for range 2 {
    60  				v, ok := next()
    61  				if v != 0 || ok != false {
    62  					t.Fatalf("next() = %d, %v, want %d, %v", v, ok, 0, false)
    63  				}
    64  				wantNG(0)
    65  			}
    66  			wantNG(0)
    67  
    68  			stop()
    69  			stop()
    70  			stop()
    71  			wantNG(0)
    72  		})
    73  	}
    74  }
    75  
    76  func TestPull2(t *testing.T) {
    77  	for end := 0; end <= 3; end++ {
    78  		t.Run(fmt.Sprint(end), func(t *testing.T) {
    79  			ng := stableNumGoroutine()
    80  			wantNG := func(want int) {
    81  				if xg := runtime.NumGoroutine() - ng; xg != want {
    82  					t.Helper()
    83  					t.Errorf("have %d extra goroutines, want %d", xg, want)
    84  				}
    85  			}
    86  			wantNG(0)
    87  			next, stop := Pull2(squares(3))
    88  			wantNG(1)
    89  			for i := range end {
    90  				k, v, ok := next()
    91  				if k != i || v != int64(i*i) || ok != true {
    92  					t.Fatalf("next() = %d, %d, %v, want %d, %d, %v", k, v, ok, i, i*i, true)
    93  				}
    94  				wantNG(1)
    95  			}
    96  			wantNG(1)
    97  			if end < 3 {
    98  				stop()
    99  				wantNG(0)
   100  			}
   101  			for range 2 {
   102  				k, v, ok := next()
   103  				if v != 0 || ok != false {
   104  					t.Fatalf("next() = %d, %d, %v, want %d, %d, %v", k, v, ok, 0, 0, false)
   105  				}
   106  				wantNG(0)
   107  			}
   108  			wantNG(0)
   109  
   110  			stop()
   111  			stop()
   112  			stop()
   113  			wantNG(0)
   114  		})
   115  	}
   116  }
   117  
   118  // stableNumGoroutine is like NumGoroutine but tries to ensure stability of
   119  // the value by letting any exiting goroutines finish exiting.
   120  func stableNumGoroutine() int {
   121  	// The idea behind stablizing the value of NumGoroutine is to
   122  	// see the same value enough times in a row in between calls to
   123  	// runtime.Gosched. With GOMAXPROCS=1, we're trying to make sure
   124  	// that other goroutines run, so that they reach a stable point.
   125  	// It's not guaranteed, because it is still possible for a goroutine
   126  	// to Gosched back into itself, so we require NumGoroutine to be
   127  	// the same 100 times in a row. This should be more than enough to
   128  	// ensure all goroutines get a chance to run to completion (or to
   129  	// some block point) for a small group of test goroutines.
   130  	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
   131  
   132  	c := 0
   133  	ng := runtime.NumGoroutine()
   134  	for i := 0; i < 1000; i++ {
   135  		nng := runtime.NumGoroutine()
   136  		if nng == ng {
   137  			c++
   138  		} else {
   139  			c = 0
   140  			ng = nng
   141  		}
   142  		if c >= 100 {
   143  			// The same value 100 times in a row is good enough.
   144  			return ng
   145  		}
   146  		runtime.Gosched()
   147  	}
   148  	panic("failed to stabilize NumGoroutine after 1000 iterations")
   149  }
   150  
   151  func TestPullDoubleNext(t *testing.T) {
   152  	next, _ := Pull(doDoubleNext())
   153  	nextSlot = next
   154  	next()
   155  	if nextSlot != nil {
   156  		t.Fatal("double next did not fail")
   157  	}
   158  }
   159  
   160  var nextSlot func() (int, bool)
   161  
   162  func doDoubleNext() Seq[int] {
   163  	return func(_ func(int) bool) {
   164  		defer func() {
   165  			if recover() != nil {
   166  				nextSlot = nil
   167  			}
   168  		}()
   169  		nextSlot()
   170  	}
   171  }
   172  
   173  func TestPullDoubleNext2(t *testing.T) {
   174  	next, _ := Pull2(doDoubleNext2())
   175  	nextSlot2 = next
   176  	next()
   177  	if nextSlot2 != nil {
   178  		t.Fatal("double next did not fail")
   179  	}
   180  }
   181  
   182  var nextSlot2 func() (int, int, bool)
   183  
   184  func doDoubleNext2() Seq2[int, int] {
   185  	return func(_ func(int, int) bool) {
   186  		defer func() {
   187  			if recover() != nil {
   188  				nextSlot2 = nil
   189  			}
   190  		}()
   191  		nextSlot2()
   192  	}
   193  }
   194  
   195  func TestPullDoubleYield(t *testing.T) {
   196  	next, stop := Pull(storeYield())
   197  	next()
   198  	if yieldSlot == nil {
   199  		t.Fatal("yield failed")
   200  	}
   201  	defer func() {
   202  		if recover() != nil {
   203  			yieldSlot = nil
   204  		}
   205  		stop()
   206  	}()
   207  	yieldSlot(5)
   208  	if yieldSlot != nil {
   209  		t.Fatal("double yield did not fail")
   210  	}
   211  }
   212  
   213  func storeYield() Seq[int] {
   214  	return func(yield func(int) bool) {
   215  		yieldSlot = yield
   216  		if !yield(5) {
   217  			return
   218  		}
   219  	}
   220  }
   221  
   222  var yieldSlot func(int) bool
   223  
   224  func TestPullDoubleYield2(t *testing.T) {
   225  	next, stop := Pull2(storeYield2())
   226  	next()
   227  	if yieldSlot2 == nil {
   228  		t.Fatal("yield failed")
   229  	}
   230  	defer func() {
   231  		if recover() != nil {
   232  			yieldSlot2 = nil
   233  		}
   234  		stop()
   235  	}()
   236  	yieldSlot2(23, 77)
   237  	if yieldSlot2 != nil {
   238  		t.Fatal("double yield did not fail")
   239  	}
   240  }
   241  
   242  func storeYield2() Seq2[int, int] {
   243  	return func(yield func(int, int) bool) {
   244  		yieldSlot2 = yield
   245  		if !yield(23, 77) {
   246  			return
   247  		}
   248  	}
   249  }
   250  
   251  var yieldSlot2 func(int, int) bool
   252  
   253  func TestPullPanic(t *testing.T) {
   254  	t.Run("next", func(t *testing.T) {
   255  		next, stop := Pull(panicSeq())
   256  		if !panicsWith("boom", func() { next() }) {
   257  			t.Fatal("failed to propagate panic on first next")
   258  		}
   259  		// Make sure we don't panic again if we try to call next or stop.
   260  		if _, ok := next(); ok {
   261  			t.Fatal("next returned true after iterator panicked")
   262  		}
   263  		// Calling stop again should be a no-op.
   264  		stop()
   265  	})
   266  	t.Run("stop", func(t *testing.T) {
   267  		next, stop := Pull(panicCleanupSeq())
   268  		x, ok := next()
   269  		if !ok || x != 55 {
   270  			t.Fatalf("expected (55, true) from next, got (%d, %t)", x, ok)
   271  		}
   272  		if !panicsWith("boom", func() { stop() }) {
   273  			t.Fatal("failed to propagate panic on stop")
   274  		}
   275  		// Make sure we don't panic again if we try to call next or stop.
   276  		if _, ok := next(); ok {
   277  			t.Fatal("next returned true after iterator panicked")
   278  		}
   279  		// Calling stop again should be a no-op.
   280  		stop()
   281  	})
   282  }
   283  
   284  func panicSeq() Seq[int] {
   285  	return func(yield func(int) bool) {
   286  		panic("boom")
   287  	}
   288  }
   289  
   290  func panicCleanupSeq() Seq[int] {
   291  	return func(yield func(int) bool) {
   292  		for {
   293  			if !yield(55) {
   294  				panic("boom")
   295  			}
   296  		}
   297  	}
   298  }
   299  
   300  func TestPull2Panic(t *testing.T) {
   301  	t.Run("next", func(t *testing.T) {
   302  		next, stop := Pull2(panicSeq2())
   303  		if !panicsWith("boom", func() { next() }) {
   304  			t.Fatal("failed to propagate panic on first next")
   305  		}
   306  		// Make sure we don't panic again if we try to call next or stop.
   307  		if _, _, ok := next(); ok {
   308  			t.Fatal("next returned true after iterator panicked")
   309  		}
   310  		// Calling stop again should be a no-op.
   311  		stop()
   312  	})
   313  	t.Run("stop", func(t *testing.T) {
   314  		next, stop := Pull2(panicCleanupSeq2())
   315  		x, y, ok := next()
   316  		if !ok || x != 55 || y != 100 {
   317  			t.Fatalf("expected (55, 100, true) from next, got (%d, %d, %t)", x, y, ok)
   318  		}
   319  		if !panicsWith("boom", func() { stop() }) {
   320  			t.Fatal("failed to propagate panic on stop")
   321  		}
   322  		// Make sure we don't panic again if we try to call next or stop.
   323  		if _, _, ok := next(); ok {
   324  			t.Fatal("next returned true after iterator panicked")
   325  		}
   326  		// Calling stop again should be a no-op.
   327  		stop()
   328  	})
   329  }
   330  
   331  func panicSeq2() Seq2[int, int] {
   332  	return func(yield func(int, int) bool) {
   333  		panic("boom")
   334  	}
   335  }
   336  
   337  func panicCleanupSeq2() Seq2[int, int] {
   338  	return func(yield func(int, int) bool) {
   339  		for {
   340  			if !yield(55, 100) {
   341  				panic("boom")
   342  			}
   343  		}
   344  	}
   345  }
   346  
   347  func panicsWith(v any, f func()) (panicked bool) {
   348  	defer func() {
   349  		if r := recover(); r != nil {
   350  			if r != v {
   351  				panic(r)
   352  			}
   353  			panicked = true
   354  		}
   355  	}()
   356  	f()
   357  	return
   358  }
   359  
   360  func TestPullGoexit(t *testing.T) {
   361  	t.Run("next", func(t *testing.T) {
   362  		var next func() (int, bool)
   363  		var stop func()
   364  		if !goexits(t, func() {
   365  			next, stop = Pull(goexitSeq())
   366  			next()
   367  		}) {
   368  			t.Fatal("failed to Goexit from next")
   369  		}
   370  		if x, ok := next(); x != 0 || ok {
   371  			t.Fatal("iterator returned valid value after iterator Goexited")
   372  		}
   373  		stop()
   374  	})
   375  	t.Run("stop", func(t *testing.T) {
   376  		next, stop := Pull(goexitCleanupSeq())
   377  		x, ok := next()
   378  		if !ok || x != 55 {
   379  			t.Fatalf("expected (55, true) from next, got (%d, %t)", x, ok)
   380  		}
   381  		if !goexits(t, func() {
   382  			stop()
   383  		}) {
   384  			t.Fatal("failed to Goexit from stop")
   385  		}
   386  		// Make sure we don't panic again if we try to call next or stop.
   387  		if x, ok := next(); x != 0 || ok {
   388  			t.Fatal("next returned true or non-zero value after iterator Goexited")
   389  		}
   390  		// Calling stop again should be a no-op.
   391  		stop()
   392  	})
   393  }
   394  
   395  func goexitSeq() Seq[int] {
   396  	return func(yield func(int) bool) {
   397  		runtime.Goexit()
   398  	}
   399  }
   400  
   401  func goexitCleanupSeq() Seq[int] {
   402  	return func(yield func(int) bool) {
   403  		for {
   404  			if !yield(55) {
   405  				runtime.Goexit()
   406  			}
   407  		}
   408  	}
   409  }
   410  
   411  func TestPull2Goexit(t *testing.T) {
   412  	t.Run("next", func(t *testing.T) {
   413  		var next func() (int, int, bool)
   414  		var stop func()
   415  		if !goexits(t, func() {
   416  			next, stop = Pull2(goexitSeq2())
   417  			next()
   418  		}) {
   419  			t.Fatal("failed to Goexit from next")
   420  		}
   421  		if x, y, ok := next(); x != 0 || y != 0 || ok {
   422  			t.Fatal("iterator returned valid value after iterator Goexited")
   423  		}
   424  		stop()
   425  	})
   426  	t.Run("stop", func(t *testing.T) {
   427  		next, stop := Pull2(goexitCleanupSeq2())
   428  		x, y, ok := next()
   429  		if !ok || x != 55 || y != 100 {
   430  			t.Fatalf("expected (55, 100, true) from next, got (%d, %d, %t)", x, y, ok)
   431  		}
   432  		if !goexits(t, func() {
   433  			stop()
   434  		}) {
   435  			t.Fatal("failed to Goexit from stop")
   436  		}
   437  		// Make sure we don't panic again if we try to call next or stop.
   438  		if x, y, ok := next(); x != 0 || y != 0 || ok {
   439  			t.Fatal("next returned true or non-zero after iterator Goexited")
   440  		}
   441  		// Calling stop again should be a no-op.
   442  		stop()
   443  	})
   444  }
   445  
   446  func goexitSeq2() Seq2[int, int] {
   447  	return func(yield func(int, int) bool) {
   448  		runtime.Goexit()
   449  	}
   450  }
   451  
   452  func goexitCleanupSeq2() Seq2[int, int] {
   453  	return func(yield func(int, int) bool) {
   454  		for {
   455  			if !yield(55, 100) {
   456  				runtime.Goexit()
   457  			}
   458  		}
   459  	}
   460  }
   461  
   462  func goexits(t *testing.T, f func()) bool {
   463  	t.Helper()
   464  
   465  	exit := make(chan bool)
   466  	go func() {
   467  		cleanExit := false
   468  		defer func() {
   469  			exit <- recover() == nil && !cleanExit
   470  		}()
   471  		f()
   472  		cleanExit = true
   473  	}()
   474  	return <-exit
   475  }
   476  
   477  func TestPullImmediateStop(t *testing.T) {
   478  	next, stop := Pull(panicSeq())
   479  	stop()
   480  	// Make sure we don't panic if we try to call next or stop.
   481  	if _, ok := next(); ok {
   482  		t.Fatal("next returned true after iterator was stopped")
   483  	}
   484  }
   485  
   486  func TestPull2ImmediateStop(t *testing.T) {
   487  	next, stop := Pull2(panicSeq2())
   488  	stop()
   489  	// Make sure we don't panic if we try to call next or stop.
   490  	if _, _, ok := next(); ok {
   491  		t.Fatal("next returned true after iterator was stopped")
   492  	}
   493  }
   494  

View as plain text