Source file src/cmd/compile/internal/inline/inlheur/analyze_func_params.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package inlheur
     6  
     7  import (
     8  	"cmd/compile/internal/ir"
     9  	"fmt"
    10  	"os"
    11  )
    12  
// paramsAnalyzer holds state information for the phase that computes
// flags for a Go function's parameters, for use in inline heuristics.
// Note that the params slice below includes entries for blanks.
type paramsAnalyzer struct {
	// fname is the name of the function being analyzed.
	fname string

	// values holds the flag bits computed so far for each param. It is
	// indexed in parallel with params and handed to FuncProps by
	// setResults.
	values []ParamPropBits

	// params holds the function's params, including the receiver if
	// there is one (see getParams). An entry may be nil (presumably
	// for blank params); nil entries are never marked as candidates.
	params []*ir.Name

	// top[i] is true while param i is still a candidate for analysis
	// and no flag has been recorded for it yet; it is cleared as soon
	// as a flag is recorded.
	top []bool

	// condLevelTracker tracks the conditional nesting level during
	// the walk; its condLevel decides "must" vs "may" flags.
	*condLevelTracker

	// nameFinder presumably supplies the staticValue and funcName
	// helpers used by this analyzer.
	*nameFinder
}
    24  
    25  // getParams returns an *ir.Name slice containing all params for the
    26  // function (plus rcvr as well if applicable).
    27  func getParams(fn *ir.Func) []*ir.Name {
    28  	sig := fn.Type()
    29  	numParams := sig.NumRecvs() + sig.NumParams()
    30  	return fn.Dcl[:numParams]
    31  }
    32  
    33  // addParamsAnalyzer creates a new paramsAnalyzer helper object for
    34  // the function fn, appends it to the analyzers list, and returns the
    35  // new list. If the function in question doesn't have any interesting
    36  // parameters then the analyzer list is returned unchanged, and the
    37  // params flags in "fp" are updated accordingly.
    38  func addParamsAnalyzer(fn *ir.Func, analyzers []propAnalyzer, fp *FuncProps, nf *nameFinder) []propAnalyzer {
    39  	pa, props := makeParamsAnalyzer(fn, nf)
    40  	if pa != nil {
    41  		analyzers = append(analyzers, pa)
    42  	} else {
    43  		fp.ParamFlags = props
    44  	}
    45  	return analyzers
    46  }
    47  
    48  // makeParamsAnalyzer creates a new helper object to analyze parameters
    49  // of function fn. If the function doesn't have any interesting
    50  // params, a nil helper is returned along with a set of default param
    51  // flags for the func.
    52  func makeParamsAnalyzer(fn *ir.Func, nf *nameFinder) (*paramsAnalyzer, []ParamPropBits) {
    53  	params := getParams(fn) // includes receiver if applicable
    54  	if len(params) == 0 {
    55  		return nil, nil
    56  	}
    57  	vals := make([]ParamPropBits, len(params))
    58  	if fn.Inl == nil {
    59  		return nil, vals
    60  	}
    61  	top := make([]bool, len(params))
    62  	interestingToAnalyze := false
    63  	for i, pn := range params {
    64  		if pn == nil {
    65  			continue
    66  		}
    67  		pt := pn.Type()
    68  		if !pt.IsScalar() && !pt.HasNil() {
    69  			// existing properties not applicable here (for things
    70  			// like structs, arrays, slices, etc).
    71  			continue
    72  		}
    73  		// If param is reassigned, skip it.
    74  		if ir.Reassigned(pn) {
    75  			continue
    76  		}
    77  		top[i] = true
    78  		interestingToAnalyze = true
    79  	}
    80  	if !interestingToAnalyze {
    81  		return nil, vals
    82  	}
    83  
    84  	if debugTrace&debugTraceParams != 0 {
    85  		fmt.Fprintf(os.Stderr, "=-= param analysis of func %v:\n",
    86  			fn.Sym().Name)
    87  		for i := range vals {
    88  			n := "_"
    89  			if params[i] != nil {
    90  				n = params[i].Sym().String()
    91  			}
    92  			fmt.Fprintf(os.Stderr, "=-=  %d: %q %s top=%v\n",
    93  				i, n, vals[i].String(), top[i])
    94  		}
    95  	}
    96  	pa := &paramsAnalyzer{
    97  		fname:            fn.Sym().Name,
    98  		values:           vals,
    99  		params:           params,
   100  		top:              top,
   101  		condLevelTracker: new(condLevelTracker),
   102  		nameFinder:       nf,
   103  	}
   104  	return pa, nil
   105  }
   106  
// setResults publishes the per-parameter flags computed by the
// analyzer by storing them in funcProps.ParamFlags. The slice is
// shared with the analyzer, not copied.
func (pa *paramsAnalyzer) setResults(funcProps *FuncProps) {
	funcProps.ParamFlags = pa.values
}
   110  
   111  func (pa *paramsAnalyzer) findParamIdx(n *ir.Name) int {
   112  	if n == nil {
   113  		panic("bad")
   114  	}
   115  	for i := range pa.params {
   116  		if pa.params[i] == n {
   117  			return i
   118  		}
   119  	}
   120  	return -1
   121  }
   122  
   123  type testfType func(x ir.Node, param *ir.Name, idx int) (bool, bool)
   124  
   125  // paramsAnalyzer invokes function 'testf' on the specified expression
   126  // 'x' for each parameter, and if the result is TRUE, or's 'flag' into
   127  // the flags for that param.
   128  func (pa *paramsAnalyzer) checkParams(x ir.Node, flag ParamPropBits, mayflag ParamPropBits, testf testfType) {
   129  	for idx, p := range pa.params {
   130  		if !pa.top[idx] && pa.values[idx] == ParamNoInfo {
   131  			continue
   132  		}
   133  		result, may := testf(x, p, idx)
   134  		if debugTrace&debugTraceParams != 0 {
   135  			fmt.Fprintf(os.Stderr, "=-= test expr %v param %s result=%v flag=%s\n", x, p.Sym().Name, result, flag.String())
   136  		}
   137  		if result {
   138  			v := flag
   139  			if pa.condLevel != 0 || may {
   140  				v = mayflag
   141  			}
   142  			pa.values[idx] |= v
   143  			pa.top[idx] = false
   144  		}
   145  	}
   146  }
   147  
   148  // foldCheckParams checks expression 'x' (an 'if' condition or
   149  // 'switch' stmt expr) to see if the expr would fold away if a
   150  // specific parameter had a constant value.
   151  func (pa *paramsAnalyzer) foldCheckParams(x ir.Node) {
   152  	pa.checkParams(x, ParamFeedsIfOrSwitch, ParamMayFeedIfOrSwitch,
   153  		func(x ir.Node, p *ir.Name, idx int) (bool, bool) {
   154  			return ShouldFoldIfNameConstant(x, []*ir.Name{p}), false
   155  		})
   156  }
   157  
   158  // callCheckParams examines the target of call expression 'ce' to see
   159  // if it is making a call to the value passed in for some parameter.
   160  func (pa *paramsAnalyzer) callCheckParams(ce *ir.CallExpr) {
   161  	switch ce.Op() {
   162  	case ir.OCALLINTER:
   163  		if ce.Op() != ir.OCALLINTER {
   164  			return
   165  		}
   166  		sel := ce.Fun.(*ir.SelectorExpr)
   167  		r := pa.staticValue(sel.X)
   168  		if r.Op() != ir.ONAME {
   169  			return
   170  		}
   171  		name := r.(*ir.Name)
   172  		if name.Class != ir.PPARAM {
   173  			return
   174  		}
   175  		pa.checkParams(r, ParamFeedsInterfaceMethodCall,
   176  			ParamMayFeedInterfaceMethodCall,
   177  			func(x ir.Node, p *ir.Name, idx int) (bool, bool) {
   178  				name := x.(*ir.Name)
   179  				return name == p, false
   180  			})
   181  	case ir.OCALLFUNC:
   182  		if ce.Fun.Op() != ir.ONAME {
   183  			return
   184  		}
   185  		called := ir.StaticValue(ce.Fun)
   186  		if called.Op() != ir.ONAME {
   187  			return
   188  		}
   189  		name := called.(*ir.Name)
   190  		if name.Class == ir.PPARAM {
   191  			pa.checkParams(called, ParamFeedsIndirectCall,
   192  				ParamMayFeedIndirectCall,
   193  				func(x ir.Node, p *ir.Name, idx int) (bool, bool) {
   194  					name := x.(*ir.Name)
   195  					return name == p, false
   196  				})
   197  		} else {
   198  			cname := pa.funcName(called)
   199  			if cname != nil {
   200  				pa.deriveFlagsFromCallee(ce, cname.Func)
   201  			}
   202  		}
   203  	}
   204  }
   205  
   206  // deriveFlagsFromCallee tries to derive flags for the current
   207  // function based on a call this function makes to some other
   208  // function. Example:
   209  //
   210  //	/* Simple */                /* Derived from callee */
   211  //	func foo(f func(int)) {     func foo(f func(int)) {
   212  //	  f(2)                        bar(32, f)
   213  //	}                           }
   214  //	                            func bar(x int, f func()) {
   215  //	                              f(x)
   216  //	                            }
   217  //
   218  // Here we can set the "param feeds indirect call" flag for
   219  // foo's param 'f' since we know that bar has that flag set for
   220  // its second param, and we're passing that param a function.
   221  func (pa *paramsAnalyzer) deriveFlagsFromCallee(ce *ir.CallExpr, callee *ir.Func) {
   222  	calleeProps := propsForFunc(callee)
   223  	if calleeProps == nil {
   224  		return
   225  	}
   226  	if debugTrace&debugTraceParams != 0 {
   227  		fmt.Fprintf(os.Stderr, "=-= callee props for %v:\n%s",
   228  			callee.Sym().Name, calleeProps.String())
   229  	}
   230  
   231  	must := []ParamPropBits{ParamFeedsInterfaceMethodCall, ParamFeedsIndirectCall, ParamFeedsIfOrSwitch}
   232  	may := []ParamPropBits{ParamMayFeedInterfaceMethodCall, ParamMayFeedIndirectCall, ParamMayFeedIfOrSwitch}
   233  
   234  	for pidx, arg := range ce.Args {
   235  		// Does the callee param have any interesting properties?
   236  		// If not we can skip this one.
   237  		pflag := calleeProps.ParamFlags[pidx]
   238  		if pflag == 0 {
   239  			continue
   240  		}
   241  		// See if one of the caller's parameters is flowing unmodified
   242  		// into this actual expression.
   243  		r := pa.staticValue(arg)
   244  		if r.Op() != ir.ONAME {
   245  			return
   246  		}
   247  		name := r.(*ir.Name)
   248  		if name.Class != ir.PPARAM {
   249  			return
   250  		}
   251  		callerParamIdx := pa.findParamIdx(name)
   252  		// note that callerParamIdx may return -1 in the case where
   253  		// the param belongs not to the current closure func we're
   254  		// analyzing but to an outer enclosing func.
   255  		if callerParamIdx == -1 {
   256  			return
   257  		}
   258  		if pa.params[callerParamIdx] == nil {
   259  			panic("something went wrong")
   260  		}
   261  		if !pa.top[callerParamIdx] &&
   262  			pa.values[callerParamIdx] == ParamNoInfo {
   263  			continue
   264  		}
   265  		if debugTrace&debugTraceParams != 0 {
   266  			fmt.Fprintf(os.Stderr, "=-= pflag for arg %d is %s\n",
   267  				pidx, pflag.String())
   268  		}
   269  		for i := range must {
   270  			mayv := may[i]
   271  			mustv := must[i]
   272  			if pflag&mustv != 0 && pa.condLevel == 0 {
   273  				pa.values[callerParamIdx] |= mustv
   274  			} else if pflag&(mustv|mayv) != 0 {
   275  				pa.values[callerParamIdx] |= mayv
   276  			}
   277  		}
   278  		pa.top[callerParamIdx] = false
   279  	}
   280  }
   281  
   282  func (pa *paramsAnalyzer) nodeVisitPost(n ir.Node) {
   283  	if len(pa.values) == 0 {
   284  		return
   285  	}
   286  	pa.condLevelTracker.post(n)
   287  	switch n.Op() {
   288  	case ir.OCALLFUNC:
   289  		ce := n.(*ir.CallExpr)
   290  		pa.callCheckParams(ce)
   291  	case ir.OCALLINTER:
   292  		ce := n.(*ir.CallExpr)
   293  		pa.callCheckParams(ce)
   294  	case ir.OIF:
   295  		ifst := n.(*ir.IfStmt)
   296  		pa.foldCheckParams(ifst.Cond)
   297  	case ir.OSWITCH:
   298  		swst := n.(*ir.SwitchStmt)
   299  		if swst.Tag != nil {
   300  			pa.foldCheckParams(swst.Tag)
   301  		}
   302  	}
   303  }
   304  
   305  func (pa *paramsAnalyzer) nodeVisitPre(n ir.Node) {
   306  	if len(pa.values) == 0 {
   307  		return
   308  	}
   309  	pa.condLevelTracker.pre(n)
   310  }
   311  
// condLevelTracker helps keep track very roughly of "level of conditional
// nesting", i.e. how many "if" statements you have to go through to
// get to the point where a given stmt executes. Example:
//
//	                      cond nesting level
//	func foo() {
//	 G = 1                   0
//	 if x < 10 {             0
//	  if y < 10 {            1
//	   G = 0                 2
//	  }
//	 }
//	}
//
// The intent here is to provide some sort of very abstract relative
// hotness metric, e.g. "G = 1" above is expected to be executed more
// often than "G = 0" (in the aggregate, across large numbers of
// functions).
//
// "switch" statements count the same way as "if" statements. "for"
// and "range" loops lower the level by one while their contents are
// visited (see the pre and post methods). As a result, the level can
// become negative, e.g. inside a loop that is not nested in any "if".
type condLevelTracker struct {
	// condLevel is the current nesting level; a new tracker starts at 0.
	condLevel int
}
   333  
   334  func (c *condLevelTracker) pre(n ir.Node) {
   335  	// Increment level of "conditional testing" if we see
   336  	// an "if" or switch statement, and decrement if in
   337  	// a loop.
   338  	switch n.Op() {
   339  	case ir.OIF, ir.OSWITCH:
   340  		c.condLevel++
   341  	case ir.OFOR, ir.ORANGE:
   342  		c.condLevel--
   343  	}
   344  }
   345  
   346  func (c *condLevelTracker) post(n ir.Node) {
   347  	switch n.Op() {
   348  	case ir.OFOR, ir.ORANGE:
   349  		c.condLevel++
   350  	case ir.OIF:
   351  		c.condLevel--
   352  	case ir.OSWITCH:
   353  		c.condLevel--
   354  	}
   355  }
   356  

View as plain text