Source file src/runtime/testdata/testprog/coro.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.rangefunc
     6  
     7  package main
     8  
     9  import (
    10  	"fmt"
    11  	"iter"
    12  	"runtime"
    13  )
    14  
    15  func init() {
    16  	register("CoroLockOSThreadIterLock", func() {
    17  		println("expect: OK")
    18  		CoroLockOSThread(callerExhaust, iterLock)
    19  	})
    20  	register("CoroLockOSThreadIterLockYield", func() {
    21  		println("expect: OS thread locking must match")
    22  		CoroLockOSThread(callerExhaust, iterLockYield)
    23  	})
    24  	register("CoroLockOSThreadLock", func() {
    25  		println("expect: OK")
    26  		CoroLockOSThread(callerExhaustLocked, iterSimple)
    27  	})
    28  	register("CoroLockOSThreadLockIterNested", func() {
    29  		println("expect: OK")
    30  		CoroLockOSThread(callerExhaustLocked, iterNested)
    31  	})
    32  	register("CoroLockOSThreadLockIterLock", func() {
    33  		println("expect: OK")
    34  		CoroLockOSThread(callerExhaustLocked, iterLock)
    35  	})
    36  	register("CoroLockOSThreadLockIterLockYield", func() {
    37  		println("expect: OS thread locking must match")
    38  		CoroLockOSThread(callerExhaustLocked, iterLockYield)
    39  	})
    40  	register("CoroLockOSThreadLockIterYieldNewG", func() {
    41  		println("expect: OS thread locking must match")
    42  		CoroLockOSThread(callerExhaustLocked, iterYieldNewG)
    43  	})
    44  	register("CoroLockOSThreadLockAfterPull", func() {
    45  		println("expect: OS thread locking must match")
    46  		CoroLockOSThread(callerLockAfterPull, iterSimple)
    47  	})
    48  	register("CoroLockOSThreadStopLocked", func() {
    49  		println("expect: OK")
    50  		CoroLockOSThread(callerStopLocked, iterSimple)
    51  	})
    52  	register("CoroLockOSThreadStopLockedIterNested", func() {
    53  		println("expect: OK")
    54  		CoroLockOSThread(callerStopLocked, iterNested)
    55  	})
    56  }
    57  
    58  func CoroLockOSThread(driver func(iter.Seq[int]) error, seq iter.Seq[int]) {
    59  	if err := driver(seq); err != nil {
    60  		println("error:", err.Error())
    61  		return
    62  	}
    63  	println("OK")
    64  }
    65  
    66  func callerExhaust(i iter.Seq[int]) error {
    67  	next, _ := iter.Pull(i)
    68  	for {
    69  		v, ok := next()
    70  		if !ok {
    71  			break
    72  		}
    73  		if v != 5 {
    74  			return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
    75  		}
    76  	}
    77  	return nil
    78  }
    79  
    80  func callerExhaustLocked(i iter.Seq[int]) error {
    81  	runtime.LockOSThread()
    82  	next, _ := iter.Pull(i)
    83  	for {
    84  		v, ok := next()
    85  		if !ok {
    86  			break
    87  		}
    88  		if v != 5 {
    89  			return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
    90  		}
    91  	}
    92  	runtime.UnlockOSThread()
    93  	return nil
    94  }
    95  
    96  func callerLockAfterPull(i iter.Seq[int]) error {
    97  	n := 0
    98  	next, _ := iter.Pull(i)
    99  	for {
   100  		runtime.LockOSThread()
   101  		n++
   102  		v, ok := next()
   103  		if !ok {
   104  			break
   105  		}
   106  		if v != 5 {
   107  			return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
   108  		}
   109  	}
   110  	for range n {
   111  		runtime.UnlockOSThread()
   112  	}
   113  	return nil
   114  }
   115  
   116  func callerStopLocked(i iter.Seq[int]) error {
   117  	runtime.LockOSThread()
   118  	next, stop := iter.Pull(i)
   119  	v, _ := next()
   120  	stop()
   121  	if v != 5 {
   122  		return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
   123  	}
   124  	runtime.UnlockOSThread()
   125  	return nil
   126  }
   127  
   128  func iterSimple(yield func(int) bool) {
   129  	for range 3 {
   130  		if !yield(5) {
   131  			return
   132  		}
   133  	}
   134  }
   135  
   136  func iterNested(yield func(int) bool) {
   137  	next, stop := iter.Pull(iterSimple)
   138  	for {
   139  		v, ok := next()
   140  		if ok {
   141  			if !yield(v) {
   142  				stop()
   143  			}
   144  		} else {
   145  			return
   146  		}
   147  	}
   148  }
   149  
   150  func iterLock(yield func(int) bool) {
   151  	for range 3 {
   152  		runtime.LockOSThread()
   153  		runtime.UnlockOSThread()
   154  
   155  		if !yield(5) {
   156  			return
   157  		}
   158  	}
   159  }
   160  
   161  func iterLockYield(yield func(int) bool) {
   162  	for range 3 {
   163  		runtime.LockOSThread()
   164  		ok := yield(5)
   165  		runtime.UnlockOSThread()
   166  		if !ok {
   167  			return
   168  		}
   169  	}
   170  }
   171  
   172  func iterYieldNewG(yield func(int) bool) {
   173  	for range 3 {
   174  		done := make(chan struct{})
   175  		var ok bool
   176  		go func() {
   177  			ok = yield(5)
   178  			done <- struct{}{}
   179  		}()
   180  		<-done
   181  		if !ok {
   182  			return
   183  		}
   184  	}
   185  }
   186  

View as plain text