Source file src/cmd/compile/internal/inline/inlheur/analyze_func_callsites.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package inlheur
     6  
     7  import (
     8  	"cmd/compile/internal/ir"
     9  	"cmd/compile/internal/pgoir"
    10  	"cmd/compile/internal/typecheck"
    11  	"fmt"
    12  	"os"
    13  	"strings"
    14  )
    15  
// callSiteAnalyzer pairs the function being analyzed with a nameFinder
// (see makeCallSiteAnalyzer, which creates one for fn).
type callSiteAnalyzer struct {
	// fn is the function under analysis; UpdateCallsiteTable passes
	// nil here to run name finding without static-value information.
	fn *ir.Func
	*nameFinder
}
    20  
// callSiteTableBuilder holds the state used while walking a region of a
// function's AST to populate a call site table (see computeCallSiteTable).
type callSiteTableBuilder struct {
	// fn is the function containing the region being walked.
	fn *ir.Func
	*nameFinder
	// cstab is the table being populated; addCallSite allocates it if
	// it is nil.
	cstab CallSiteTab
	// ptab maps nodes to their pstate; determinePanicPathBits consults
	// it, treating psCallsPanic as "on a panic path".
	ptab map[ir.Node]pstate
	// nstack is the stack of ancestors of the node being visited. It is
	// seeded with fn, and nodeVisitPre/nodeVisitPost push/pop each node.
	nstack []ir.Node
	// loopNest is the current loop nesting depth; > 0 sets CallSiteInLoop.
	loopNest int
	// isInit records whether fn is the package init function or has a
	// symbol name beginning with "init.".
	isInit bool
}
    30  
    31  func makeCallSiteAnalyzer(fn *ir.Func) *callSiteAnalyzer {
    32  	return &callSiteAnalyzer{
    33  		fn:         fn,
    34  		nameFinder: newNameFinder(fn),
    35  	}
    36  }
    37  
    38  func makeCallSiteTableBuilder(fn *ir.Func, cstab CallSiteTab, ptab map[ir.Node]pstate, loopNestingLevel int, nf *nameFinder) *callSiteTableBuilder {
    39  	isInit := fn.IsPackageInit() || strings.HasPrefix(fn.Sym().Name, "init.")
    40  	return &callSiteTableBuilder{
    41  		fn:         fn,
    42  		cstab:      cstab,
    43  		ptab:       ptab,
    44  		isInit:     isInit,
    45  		loopNest:   loopNestingLevel,
    46  		nstack:     []ir.Node{fn},
    47  		nameFinder: nf,
    48  	}
    49  }
    50  
    51  // computeCallSiteTable builds and returns a table of call sites for
    52  // the specified region in function fn. A region here corresponds to a
    53  // specific subtree within the AST for a function. The main intended
    54  // use cases are for 'region' to be either A) an entire function body,
    55  // or B) an inlined call expression.
    56  func computeCallSiteTable(fn *ir.Func, region ir.Nodes, cstab CallSiteTab, ptab map[ir.Node]pstate, loopNestingLevel int, nf *nameFinder) CallSiteTab {
    57  	cstb := makeCallSiteTableBuilder(fn, cstab, ptab, loopNestingLevel, nf)
    58  	var doNode func(ir.Node) bool
    59  	doNode = func(n ir.Node) bool {
    60  		cstb.nodeVisitPre(n)
    61  		ir.DoChildren(n, doNode)
    62  		cstb.nodeVisitPost(n)
    63  		return false
    64  	}
    65  	for _, n := range region {
    66  		doNode(n)
    67  	}
    68  	return cstb.cstab
    69  }
    70  
    71  func (cstb *callSiteTableBuilder) flagsForNode(call *ir.CallExpr) CSPropBits {
    72  	var r CSPropBits
    73  
    74  	if debugTrace&debugTraceCalls != 0 {
    75  		fmt.Fprintf(os.Stderr, "=-= analyzing call at %s\n",
    76  			fmtFullPos(call.Pos()))
    77  	}
    78  
    79  	// Set a bit if this call is within a loop.
    80  	if cstb.loopNest > 0 {
    81  		r |= CallSiteInLoop
    82  	}
    83  
    84  	// Set a bit if the call is within an init function (either
    85  	// compiler-generated or user-written).
    86  	if cstb.isInit {
    87  		r |= CallSiteInInitFunc
    88  	}
    89  
    90  	// Decide whether to apply the panic path heuristic. Hack: don't
    91  	// apply this heuristic in the function "main.main" (mostly just
    92  	// to avoid annoying users).
    93  	if !isMainMain(cstb.fn) {
    94  		r = cstb.determinePanicPathBits(call, r)
    95  	}
    96  
    97  	return r
    98  }
    99  
   100  // determinePanicPathBits updates the CallSiteOnPanicPath bit within
   101  // "r" if we think this call is on an unconditional path to
   102  // panic/exit. Do this by walking back up the node stack to see if we
   103  // can find either A) an enclosing panic, or B) a statement node that
   104  // we've determined leads to a panic/exit.
   105  func (cstb *callSiteTableBuilder) determinePanicPathBits(call ir.Node, r CSPropBits) CSPropBits {
   106  	cstb.nstack = append(cstb.nstack, call)
   107  	defer func() {
   108  		cstb.nstack = cstb.nstack[:len(cstb.nstack)-1]
   109  	}()
   110  
   111  	for ri := range cstb.nstack[:len(cstb.nstack)-1] {
   112  		i := len(cstb.nstack) - ri - 1
   113  		n := cstb.nstack[i]
   114  		_, isCallExpr := n.(*ir.CallExpr)
   115  		_, isStmt := n.(ir.Stmt)
   116  		if isCallExpr {
   117  			isStmt = false
   118  		}
   119  
   120  		if debugTrace&debugTraceCalls != 0 {
   121  			ps, inps := cstb.ptab[n]
   122  			fmt.Fprintf(os.Stderr, "=-= callpar %d op=%s ps=%s inptab=%v stmt=%v\n", i, n.Op().String(), ps.String(), inps, isStmt)
   123  		}
   124  
   125  		if n.Op() == ir.OPANIC {
   126  			r |= CallSiteOnPanicPath
   127  			break
   128  		}
   129  		if v, ok := cstb.ptab[n]; ok {
   130  			if v == psCallsPanic {
   131  				r |= CallSiteOnPanicPath
   132  				break
   133  			}
   134  			if isStmt {
   135  				break
   136  			}
   137  		}
   138  	}
   139  	return r
   140  }
   141  
   142  // propsForArg returns property bits for a given call argument expression arg.
   143  func (cstb *callSiteTableBuilder) propsForArg(arg ir.Node) ActualExprPropBits {
   144  	if cval := cstb.constValue(arg); cval != nil {
   145  		return ActualExprConstant
   146  	}
   147  	if cstb.isConcreteConvIface(arg) {
   148  		return ActualExprIsConcreteConvIface
   149  	}
   150  	fname := cstb.funcName(arg)
   151  	if fname != nil {
   152  		if fn := fname.Func; fn != nil && typecheck.HaveInlineBody(fn) {
   153  			return ActualExprIsInlinableFunc
   154  		}
   155  		return ActualExprIsFunc
   156  	}
   157  	return 0
   158  }
   159  
   160  // argPropsForCall returns a slice of argument properties for the
   161  // expressions being passed to the callee in the specific call
   162  // expression; these will be stored in the CallSite object for a given
   163  // call and then consulted when scoring. If no arg has any interesting
   164  // properties we try to save some space and return a nil slice.
   165  func (cstb *callSiteTableBuilder) argPropsForCall(ce *ir.CallExpr) []ActualExprPropBits {
   166  	rv := make([]ActualExprPropBits, len(ce.Args))
   167  	somethingInteresting := false
   168  	for idx := range ce.Args {
   169  		argProp := cstb.propsForArg(ce.Args[idx])
   170  		somethingInteresting = somethingInteresting || (argProp != 0)
   171  		rv[idx] = argProp
   172  	}
   173  	if !somethingInteresting {
   174  		return nil
   175  	}
   176  	return rv
   177  }
   178  
   179  func (cstb *callSiteTableBuilder) addCallSite(callee *ir.Func, call *ir.CallExpr) {
   180  	flags := cstb.flagsForNode(call)
   181  	argProps := cstb.argPropsForCall(call)
   182  	if debugTrace&debugTraceCalls != 0 {
   183  		fmt.Fprintf(os.Stderr, "=-= props %+v for call %v\n", argProps, call)
   184  	}
   185  	// FIXME: maybe bulk-allocate these?
   186  	cs := &CallSite{
   187  		Call:     call,
   188  		Callee:   callee,
   189  		Assign:   cstb.containingAssignment(call),
   190  		ArgProps: argProps,
   191  		Flags:    flags,
   192  		ID:       uint(len(cstb.cstab)),
   193  	}
   194  	if _, ok := cstb.cstab[call]; ok {
   195  		fmt.Fprintf(os.Stderr, "*** cstab duplicate entry at: %s\n",
   196  			fmtFullPos(call.Pos()))
   197  		fmt.Fprintf(os.Stderr, "*** call: %+v\n", call)
   198  		panic("bad")
   199  	}
   200  	// Set initial score for callsite to the cost computed
   201  	// by CanInline; this score will be refined later based
   202  	// on heuristics.
   203  	cs.Score = int(callee.Inl.Cost)
   204  
   205  	if cstb.cstab == nil {
   206  		cstb.cstab = make(CallSiteTab)
   207  	}
   208  	cstb.cstab[call] = cs
   209  	if debugTrace&debugTraceCalls != 0 {
   210  		fmt.Fprintf(os.Stderr, "=-= added callsite: caller=%v callee=%v n=%s\n",
   211  			cstb.fn, callee, fmtFullPos(call.Pos()))
   212  	}
   213  }
   214  
   215  func (cstb *callSiteTableBuilder) nodeVisitPre(n ir.Node) {
   216  	switch n.Op() {
   217  	case ir.ORANGE, ir.OFOR:
   218  		if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
   219  			cstb.loopNest++
   220  		}
   221  	case ir.OCALLFUNC:
   222  		ce := n.(*ir.CallExpr)
   223  		callee := pgoir.DirectCallee(ce.Fun)
   224  		if callee != nil && callee.Inl != nil {
   225  			cstb.addCallSite(callee, ce)
   226  		}
   227  	}
   228  	cstb.nstack = append(cstb.nstack, n)
   229  }
   230  
   231  func (cstb *callSiteTableBuilder) nodeVisitPost(n ir.Node) {
   232  	cstb.nstack = cstb.nstack[:len(cstb.nstack)-1]
   233  	switch n.Op() {
   234  	case ir.ORANGE, ir.OFOR:
   235  		if !hasTopLevelLoopBodyReturnOrBreak(loopBody(n)) {
   236  			cstb.loopNest--
   237  		}
   238  	}
   239  }
   240  
   241  func loopBody(n ir.Node) ir.Nodes {
   242  	if forst, ok := n.(*ir.ForStmt); ok {
   243  		return forst.Body
   244  	}
   245  	if rst, ok := n.(*ir.RangeStmt); ok {
   246  		return rst.Body
   247  	}
   248  	return nil
   249  }
   250  
   251  // hasTopLevelLoopBodyReturnOrBreak examines the body of a "for" or
   252  // "range" loop to try to verify that it is a real loop, as opposed to
   253  // a construct that is syntactically loopy but doesn't actually iterate
   254  // multiple times, like:
   255  //
   256  //	for {
   257  //	  blah()
   258  //	  return 1
   259  //	}
   260  //
   261  // [Remark: the pattern above crops up quite a bit in the source code
   262  // for the compiler itself, e.g. the auto-generated rewrite code]
   263  //
   264  // Note that we don't look for GOTO statements here, so it's possible
   265  // we'll get the wrong result for a loop with complicated control
   266  // jumps via gotos.
   267  func hasTopLevelLoopBodyReturnOrBreak(loopBody ir.Nodes) bool {
   268  	for _, n := range loopBody {
   269  		if n.Op() == ir.ORETURN || n.Op() == ir.OBREAK {
   270  			return true
   271  		}
   272  	}
   273  	return false
   274  }
   275  
   276  // containingAssignment returns the top-level assignment statement
   277  // for a statement level function call "n". Examples:
   278  //
   279  //	x := foo()
   280  //	x, y := bar(z, baz())
   281  //	if blah() { ...
   282  //
   283  // Here the top-level assignment statement for the foo() call is the
   284  // statement assigning to "x"; the top-level assignment for "bar()"
   285  // call is the assignment to x,y. For the baz() and blah() calls,
   286  // there is no top level assignment statement.
   287  //
   288  // The unstated goal here is that we want to use the containing
   289  // assignment to establish a connection between a given call and the
   290  // variables to which its results/returns are being assigned.
   291  //
   292  // Note that for the "bar" command above, the front end sometimes
   293  // decomposes this into two assignments, the first one assigning the
   294  // call to a pair of auto-temps, then the second one assigning the
   295  // auto-temps to the user-visible vars. This helper will return the
   296  // second (outer) of these two.
   297  func (cstb *callSiteTableBuilder) containingAssignment(n ir.Node) ir.Node {
   298  	parent := cstb.nstack[len(cstb.nstack)-1]
   299  
   300  	// assignsOnlyAutoTemps returns TRUE of the specified OAS2FUNC
   301  	// node assigns only auto-temps.
   302  	assignsOnlyAutoTemps := func(x ir.Node) bool {
   303  		alst := x.(*ir.AssignListStmt)
   304  		oa2init := alst.Init()
   305  		if len(oa2init) == 0 {
   306  			return false
   307  		}
   308  		for _, v := range oa2init {
   309  			d := v.(*ir.Decl)
   310  			if !ir.IsAutoTmp(d.X) {
   311  				return false
   312  			}
   313  		}
   314  		return true
   315  	}
   316  
   317  	// Simple case: x := foo()
   318  	if parent.Op() == ir.OAS {
   319  		return parent
   320  	}
   321  
   322  	// Multi-return case: x, y := bar()
   323  	if parent.Op() == ir.OAS2FUNC {
   324  		// Hack city: if the result vars are auto-temps, try looking
   325  		// for an outer assignment in the tree. The code shape we're
   326  		// looking for here is:
   327  		//
   328  		// OAS1({x,y},OCONVNOP(OAS2FUNC({auto1,auto2},OCALLFUNC(bar))))
   329  		//
   330  		if assignsOnlyAutoTemps(parent) {
   331  			par2 := cstb.nstack[len(cstb.nstack)-2]
   332  			if par2.Op() == ir.OAS2 {
   333  				return par2
   334  			}
   335  			if par2.Op() == ir.OCONVNOP {
   336  				par3 := cstb.nstack[len(cstb.nstack)-3]
   337  				if par3.Op() == ir.OAS2 {
   338  					return par3
   339  				}
   340  			}
   341  		}
   342  	}
   343  
   344  	return nil
   345  }
   346  
   347  // UpdateCallsiteTable handles updating of callerfn's call site table
   348  // after an inlined has been carried out, e.g. the call at 'n' as been
   349  // turned into the inlined call expression 'ic' within function
   350  // callerfn. The chief thing of interest here is to make sure that any
   351  // call nodes within 'ic' are added to the call site table for
   352  // 'callerfn' and scored appropriately.
   353  func UpdateCallsiteTable(callerfn *ir.Func, n *ir.CallExpr, ic *ir.InlinedCallExpr) {
   354  	enableDebugTraceIfEnv()
   355  	defer disableDebugTrace()
   356  
   357  	funcInlHeur, ok := fpmap[callerfn]
   358  	if !ok {
   359  		// This can happen for compiler-generated wrappers.
   360  		if debugTrace&debugTraceCalls != 0 {
   361  			fmt.Fprintf(os.Stderr, "=-= early exit, no entry for caller fn %v\n", callerfn)
   362  		}
   363  		return
   364  	}
   365  
   366  	if debugTrace&debugTraceCalls != 0 {
   367  		fmt.Fprintf(os.Stderr, "=-= UpdateCallsiteTable(caller=%v, cs=%s)\n",
   368  			callerfn, fmtFullPos(n.Pos()))
   369  	}
   370  
   371  	// Mark the call in question as inlined.
   372  	oldcs, ok := funcInlHeur.cstab[n]
   373  	if !ok {
   374  		// This can happen for compiler-generated wrappers.
   375  		return
   376  	}
   377  	oldcs.aux |= csAuxInlined
   378  
   379  	if debugTrace&debugTraceCalls != 0 {
   380  		fmt.Fprintf(os.Stderr, "=-= marked as inlined: callee=%v %s\n",
   381  			oldcs.Callee, EncodeCallSiteKey(oldcs))
   382  	}
   383  
   384  	// Walk the inlined call region to collect new callsites.
   385  	var icp pstate
   386  	if oldcs.Flags&CallSiteOnPanicPath != 0 {
   387  		icp = psCallsPanic
   388  	}
   389  	var loopNestLevel int
   390  	if oldcs.Flags&CallSiteInLoop != 0 {
   391  		loopNestLevel = 1
   392  	}
   393  	ptab := map[ir.Node]pstate{ic: icp}
   394  	nf := newNameFinder(nil)
   395  	icstab := computeCallSiteTable(callerfn, ic.Body, nil, ptab, loopNestLevel, nf)
   396  
   397  	// Record parent callsite. This is primarily for debug output.
   398  	for _, cs := range icstab {
   399  		cs.parent = oldcs
   400  	}
   401  
   402  	// Score the calls in the inlined body. Note the setting of
   403  	// "doCallResults" to false here: at the moment there isn't any
   404  	// easy way to localize or region-ize the work done by
   405  	// "rescoreBasedOnCallResultUses", which currently does a walk
   406  	// over the entire function to look for uses of a given set of
   407  	// results. Similarly we're passing nil to makeCallSiteAnalyzer,
   408  	// so as to run name finding without the use of static value &
   409  	// friends.
   410  	csa := makeCallSiteAnalyzer(nil)
   411  	const doCallResults = false
   412  	csa.scoreCallsRegion(callerfn, ic.Body, icstab, doCallResults, ic)
   413  }
   414  

View as plain text