Source file src/cmd/compile/internal/bloop/bloop.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package bloop
     6  
     7  // This file contains support routines for keeping
     8  // statements alive
     9  // in such loops (example):
    10  //
    11  //	for b.Loop() {
    12  //		var a, b int
    13  //		a = 5
    14  //		b = 6
    15  //		f(a, b)
    16  //	}
    17  //
    18  // The results of a, b and f(a, b) will be kept alive.
    19  //
    20  // Formally, the lhs (if they are [ir.Name]-s) of
    21  // [ir.AssignStmt], [ir.AssignListStmt],
    22  // [ir.AssignOpStmt], and the results of [ir.CallExpr]
    23  // or its args if it doesn't return a value will be kept
    24  // alive.
    25  //
    26  // The keep alive logic is implemented with as wrapping a
    27  // runtime.KeepAlive around the Name.
    28  //
    29  // TODO: currently this is implemented with KeepAlive
    30  // because it will prevent DSE and DCE which is probably
    31  // what we want right now. And KeepAlive takes an ssa
    32  // value instead of a symbol, which is easier to manage.
    33  // But since KeepAlive's context was mainly in the runtime
    34  // and GC, should we implement a new intrinsic that lowers
    35  // to OpVarLive? Peeling out the symbols is a bit tricky
    36  // and also VarLive seems to assume that there exists a
    37  // VarDef on the same symbol that dominates it.
    38  
    39  import (
    40  	"cmd/compile/internal/base"
    41  	"cmd/compile/internal/ir"
    42  	"cmd/compile/internal/reflectdata"
    43  	"cmd/compile/internal/typecheck"
    44  	"cmd/compile/internal/types"
    45  	"fmt"
    46  )
    47  
    48  // getNameFromNode tries to iteratively peel down the node to
    49  // get the name.
    50  func getNameFromNode(n ir.Node) *ir.Name {
    51  	var ret *ir.Name
    52  	if n.Op() == ir.ONAME {
    53  		ret = n.(*ir.Name)
    54  	} else {
    55  		// avoid infinite recursion on circular referencing nodes.
    56  		seen := map[ir.Node]bool{n: true}
    57  		var findName func(ir.Node) bool
    58  		findName = func(a ir.Node) bool {
    59  			if a.Op() == ir.ONAME {
    60  				ret = a.(*ir.Name)
    61  				return true
    62  			}
    63  			if !seen[a] {
    64  				seen[a] = true
    65  				return ir.DoChildren(a, findName)
    66  			}
    67  			return false
    68  		}
    69  		ir.DoChildren(n, findName)
    70  	}
    71  	return ret
    72  }
    73  
    74  // keepAliveAt returns a statement that is either curNode, or a
    75  // block containing curNode followed by a call to runtime.keepAlive for each
    76  // ONAME in ns. These calls ensure that names in ns will be live until
    77  // after curNode's execution.
    78  func keepAliveAt(ns []*ir.Name, curNode ir.Node) ir.Node {
    79  	if len(ns) == 0 {
    80  		return curNode
    81  	}
    82  
    83  	pos := curNode.Pos()
    84  	calls := []ir.Node{curNode}
    85  	for _, n := range ns {
    86  		if n == nil {
    87  			continue
    88  		}
    89  		if n.Sym() == nil {
    90  			continue
    91  		}
    92  		if n.Sym().IsBlank() {
    93  			continue
    94  		}
    95  		arg := ir.NewConvExpr(pos, ir.OCONV, types.Types[types.TINTER], n)
    96  		if !n.Type().IsInterface() {
    97  			srcRType0 := reflectdata.TypePtrAt(pos, n.Type())
    98  			arg.TypeWord = srcRType0
    99  			arg.SrcRType = srcRType0
   100  		}
   101  		callExpr := typecheck.Call(pos,
   102  			typecheck.LookupRuntime("KeepAlive"),
   103  			[]ir.Node{arg}, false).(*ir.CallExpr)
   104  		callExpr.IsCompilerVarLive = true
   105  		callExpr.NoInline = true
   106  		calls = append(calls, callExpr)
   107  	}
   108  
   109  	return ir.NewBlockStmt(pos, calls)
   110  }
   111  
   112  func debugName(name *ir.Name, line string) {
   113  	if base.Flag.LowerM > 0 {
   114  		if name.Linksym() != nil {
   115  			fmt.Printf("%v: %s will be kept alive\n", line, name.Linksym().Name)
   116  		} else {
   117  			fmt.Printf("%v: expr will be kept alive\n", line)
   118  		}
   119  	}
   120  }
   121  
   122  // preserveStmt transforms stmt so that any names defined/assigned within it
   123  // are used after stmt's execution, preventing their dead code elimination
   124  // and dead store elimination. The return value is the transformed statement.
   125  func preserveStmt(curFn *ir.Func, stmt ir.Node) (ret ir.Node) {
   126  	ret = stmt
   127  	switch n := stmt.(type) {
   128  	case *ir.AssignStmt:
   129  		// Peel down struct and slice indexing to get the names
   130  		name := getNameFromNode(n.X)
   131  		if name != nil {
   132  			debugName(name, ir.Line(stmt))
   133  			ret = keepAliveAt([]*ir.Name{name}, n)
   134  		}
   135  	case *ir.AssignListStmt:
   136  		names := []*ir.Name{}
   137  		for _, lhs := range n.Lhs {
   138  			name := getNameFromNode(lhs)
   139  			if name != nil {
   140  				debugName(name, ir.Line(stmt))
   141  				names = append(names, name)
   142  			}
   143  		}
   144  		ret = keepAliveAt(names, n)
   145  	case *ir.AssignOpStmt:
   146  		name := getNameFromNode(n.X)
   147  		if name != nil {
   148  			debugName(name, ir.Line(stmt))
   149  			ret = keepAliveAt([]*ir.Name{name}, n)
   150  		}
   151  	case *ir.CallExpr:
   152  		names := []*ir.Name{}
   153  		curNode := stmt
   154  		if n.Fun != nil && n.Fun.Type() != nil && n.Fun.Type().NumResults() != 0 {
   155  			// This function's results are not assigned, assign them to
   156  			// auto tmps and then keepAliveAt these autos.
   157  			// Note: markStmt assumes the context that it's called - this CallExpr is
   158  			// not within another OAS2, which is guaranteed by the case above.
   159  			results := n.Fun.Type().Results()
   160  			lhs := make([]ir.Node, len(results))
   161  			for i, res := range results {
   162  				tmp := typecheck.TempAt(n.Pos(), curFn, res.Type)
   163  				lhs[i] = tmp
   164  				names = append(names, tmp)
   165  			}
   166  
   167  			// Create an assignment statement.
   168  			assign := typecheck.AssignExpr(
   169  				ir.NewAssignListStmt(n.Pos(), ir.OAS2, lhs,
   170  					[]ir.Node{n})).(*ir.AssignListStmt)
   171  			assign.Def = true
   172  			curNode = assign
   173  			plural := ""
   174  			if len(results) > 1 {
   175  				plural = "s"
   176  			}
   177  			if base.Flag.LowerM > 0 {
   178  				fmt.Printf("%v: function result%s will be kept alive\n", ir.Line(stmt), plural)
   179  			}
   180  		} else {
   181  			// This function probably doesn't return anything, keep its args alive.
   182  			argTmps := []ir.Node{}
   183  			for i, a := range n.Args {
   184  				if name := getNameFromNode(a); name != nil {
   185  					// If they are name, keep them alive directly.
   186  					debugName(name, ir.Line(stmt))
   187  					names = append(names, name)
   188  				} else if a.Op() == ir.OSLICELIT {
   189  					// variadic args are encoded as slice literal.
   190  					s := a.(*ir.CompLitExpr)
   191  					ns := []*ir.Name{}
   192  					for i, n := range s.List {
   193  						if name := getNameFromNode(n); name != nil {
   194  							debugName(name, ir.Line(a))
   195  							ns = append(ns, name)
   196  						} else {
   197  							// We need a temporary to save this arg.
   198  							tmp := typecheck.TempAt(n.Pos(), curFn, n.Type())
   199  							argTmps = append(argTmps, typecheck.AssignExpr(ir.NewAssignStmt(n.Pos(), tmp, n)))
   200  							names = append(names, tmp)
   201  							s.List[i] = tmp
   202  							if base.Flag.LowerM > 0 {
   203  								fmt.Printf("%v: function arg will be kept alive\n", ir.Line(n))
   204  							}
   205  						}
   206  					}
   207  					names = append(names, ns...)
   208  				} else {
   209  					// expressions, we need to assign them to temps and change the original arg to reference
   210  					// them.
   211  					tmp := typecheck.TempAt(n.Pos(), curFn, a.Type())
   212  					argTmps = append(argTmps, typecheck.AssignExpr(ir.NewAssignStmt(n.Pos(), tmp, a)))
   213  					names = append(names, tmp)
   214  					n.Args[i] = tmp
   215  					if base.Flag.LowerM > 0 {
   216  						fmt.Printf("%v: function arg will be kept alive\n", ir.Line(stmt))
   217  					}
   218  				}
   219  			}
   220  			if len(argTmps) > 0 {
   221  				argTmps = append(argTmps, n)
   222  				curNode = ir.NewBlockStmt(n.Pos(), argTmps)
   223  			}
   224  		}
   225  		ret = keepAliveAt(names, curNode)
   226  	}
   227  	return
   228  }
   229  
   230  func preserveStmts(curFn *ir.Func, list ir.Nodes) {
   231  	for i := range list {
   232  		list[i] = preserveStmt(curFn, list[i])
   233  	}
   234  }
   235  
   236  // isTestingBLoop returns true if it matches the node as a
   237  // testing.(*B).Loop. See issue #61515.
   238  func isTestingBLoop(t ir.Node) bool {
   239  	if t.Op() != ir.OFOR {
   240  		return false
   241  	}
   242  	nFor, ok := t.(*ir.ForStmt)
   243  	if !ok || nFor.Cond == nil || nFor.Cond.Op() != ir.OCALLFUNC {
   244  		return false
   245  	}
   246  	n, ok := nFor.Cond.(*ir.CallExpr)
   247  	if !ok || n.Fun == nil || n.Fun.Op() != ir.OMETHEXPR {
   248  		return false
   249  	}
   250  	name := ir.MethodExprName(n.Fun)
   251  	if name == nil {
   252  		return false
   253  	}
   254  	if fSym := name.Sym(); fSym != nil && name.Class == ir.PFUNC && fSym.Pkg != nil &&
   255  		fSym.Name == "(*B).Loop" && fSym.Pkg.Path == "testing" {
   256  		// Attempting to match a function call to testing.(*B).Loop
   257  		return true
   258  	}
   259  	return false
   260  }
   261  
   262  type editor struct {
   263  	inBloop bool
   264  	curFn   *ir.Func
   265  }
   266  
   267  func (e editor) edit(n ir.Node) ir.Node {
   268  	e.inBloop = isTestingBLoop(n) || e.inBloop
   269  	// It's in bloop, mark the stmts with bodies.
   270  	ir.EditChildren(n, e.edit)
   271  	if e.inBloop {
   272  		switch n := n.(type) {
   273  		case *ir.ForStmt:
   274  			preserveStmts(e.curFn, n.Body)
   275  		case *ir.IfStmt:
   276  			preserveStmts(e.curFn, n.Body)
   277  			preserveStmts(e.curFn, n.Else)
   278  		case *ir.BlockStmt:
   279  			preserveStmts(e.curFn, n.List)
   280  		case *ir.CaseClause:
   281  			preserveStmts(e.curFn, n.List)
   282  			preserveStmts(e.curFn, n.Body)
   283  		case *ir.CommClause:
   284  			preserveStmts(e.curFn, n.Body)
   285  		}
   286  	}
   287  	return n
   288  }
   289  
   290  // BloopWalk performs a walk on all functions in the package
   291  // if it imports testing and wrap the results of all qualified
   292  // statements in a runtime.KeepAlive intrinsic call. See package
   293  // doc for more details.
   294  //
   295  //	for b.Loop() {...}
   296  //
   297  // loop's body.
   298  func BloopWalk(pkg *ir.Package) {
   299  	hasTesting := false
   300  	for _, i := range pkg.Imports {
   301  		if i.Path == "testing" {
   302  			hasTesting = true
   303  			break
   304  		}
   305  	}
   306  	if !hasTesting {
   307  		return
   308  	}
   309  	for _, fn := range pkg.Funcs {
   310  		e := editor{false, fn}
   311  		ir.EditChildren(fn, e.edit)
   312  	}
   313  }
   314  

View as plain text