Source file src/runtime/testdata/testprogcgo/coro.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.rangefunc && !windows
     6  
     7  package main
     8  
     9  /*
    10  #include <stdint.h> // for uintptr_t
    11  
    12  void go_callback_coro(uintptr_t handle);
    13  
    14  static void call_go(uintptr_t handle) {
    15  	go_callback_coro(handle);
    16  }
    17  */
    18  import "C"
    19  
    20  import (
    21  	"fmt"
    22  	"iter"
    23  	"runtime/cgo"
    24  )
    25  
    26  func init() {
    27  	register("CoroCgoIterCallback", func() {
    28  		println("expect: OK")
    29  		CoroCgo(callerExhaust, iterCallback)
    30  	})
    31  	register("CoroCgoIterCallbackYield", func() {
    32  		println("expect: OS thread locking must match")
    33  		CoroCgo(callerExhaust, iterCallbackYield)
    34  	})
    35  	register("CoroCgoCallback", func() {
    36  		println("expect: OK")
    37  		CoroCgo(callerExhaustCallback, iterSimple)
    38  	})
    39  	register("CoroCgoCallbackIterNested", func() {
    40  		println("expect: OK")
    41  		CoroCgo(callerExhaustCallback, iterNested)
    42  	})
    43  	register("CoroCgoCallbackIterCallback", func() {
    44  		println("expect: OK")
    45  		CoroCgo(callerExhaustCallback, iterCallback)
    46  	})
    47  	register("CoroCgoCallbackIterCallbackYield", func() {
    48  		println("expect: OS thread locking must match")
    49  		CoroCgo(callerExhaustCallback, iterCallbackYield)
    50  	})
    51  	register("CoroCgoCallbackAfterPull", func() {
    52  		println("expect: OS thread locking must match")
    53  		CoroCgo(callerCallbackAfterPull, iterSimple)
    54  	})
    55  	register("CoroCgoStopCallback", func() {
    56  		println("expect: OK")
    57  		CoroCgo(callerStopCallback, iterSimple)
    58  	})
    59  	register("CoroCgoStopCallbackIterNested", func() {
    60  		println("expect: OK")
    61  		CoroCgo(callerStopCallback, iterNested)
    62  	})
    63  }
    64  
// toCall is not referenced anywhere in this file.
// NOTE(review): confirm that no other file in package main uses it; if none
// does, this declaration looks like dead code and can be removed.
var toCall func()
    66  
    67  //export go_callback_coro
    68  func go_callback_coro(handle C.uintptr_t) {
    69  	h := cgo.Handle(handle)
    70  	h.Value().(func())()
    71  	h.Delete()
    72  }
    73  
    74  func callFromC(f func()) {
    75  	C.call_go(C.uintptr_t(cgo.NewHandle(f)))
    76  }
    77  
    78  func CoroCgo(driver func(iter.Seq[int]) error, seq iter.Seq[int]) {
    79  	if err := driver(seq); err != nil {
    80  		println("error:", err.Error())
    81  		return
    82  	}
    83  	println("OK")
    84  }
    85  
    86  func callerExhaust(i iter.Seq[int]) error {
    87  	next, _ := iter.Pull(i)
    88  	for {
    89  		v, ok := next()
    90  		if !ok {
    91  			break
    92  		}
    93  		if v != 5 {
    94  			return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
    95  		}
    96  	}
    97  	return nil
    98  }
    99  
   100  func callerExhaustCallback(i iter.Seq[int]) (err error) {
   101  	callFromC(func() {
   102  		next, _ := iter.Pull(i)
   103  		for {
   104  			v, ok := next()
   105  			if !ok {
   106  				break
   107  			}
   108  			if v != 5 {
   109  				err = fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
   110  			}
   111  		}
   112  	})
   113  	return err
   114  }
   115  
   116  func callerStopCallback(i iter.Seq[int]) (err error) {
   117  	callFromC(func() {
   118  		next, stop := iter.Pull(i)
   119  		v, _ := next()
   120  		stop()
   121  		if v != 5 {
   122  			err = fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
   123  		}
   124  	})
   125  	return err
   126  }
   127  
   128  func callerCallbackAfterPull(i iter.Seq[int]) (err error) {
   129  	next, _ := iter.Pull(i)
   130  	callFromC(func() {
   131  		for {
   132  			v, ok := next()
   133  			if !ok {
   134  				break
   135  			}
   136  			if v != 5 {
   137  				err = fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
   138  			}
   139  		}
   140  	})
   141  	return err
   142  }
   143  
   144  func iterSimple(yield func(int) bool) {
   145  	for range 3 {
   146  		if !yield(5) {
   147  			return
   148  		}
   149  	}
   150  }
   151  
   152  func iterNested(yield func(int) bool) {
   153  	next, stop := iter.Pull(iterSimple)
   154  	for {
   155  		v, ok := next()
   156  		if ok {
   157  			if !yield(v) {
   158  				stop()
   159  			}
   160  		} else {
   161  			return
   162  		}
   163  	}
   164  }
   165  
   166  func iterCallback(yield func(int) bool) {
   167  	for range 3 {
   168  		callFromC(func() {})
   169  		if !yield(5) {
   170  			return
   171  		}
   172  	}
   173  }
   174  
   175  func iterCallbackYield(yield func(int) bool) {
   176  	for range 3 {
   177  		var ok bool
   178  		callFromC(func() {
   179  			ok = yield(5)
   180  		})
   181  		if !ok {
   182  			return
   183  		}
   184  	}
   185  }
   186  

View as plain text