Source file src/cmd/compile/internal/inline/interleaved/interleaved.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package interleaved implements the interleaved devirtualization and
     6  // inlining pass.
     7  package interleaved
     8  
     9  import (
    10  	"cmd/compile/internal/base"
    11  	"cmd/compile/internal/devirtualize"
    12  	"cmd/compile/internal/inline"
    13  	"cmd/compile/internal/inline/inlheur"
    14  	"cmd/compile/internal/ir"
    15  	"cmd/compile/internal/pgoir"
    16  	"cmd/compile/internal/typecheck"
    17  	"fmt"
    18  )
    19  
    20  // DevirtualizeAndInlinePackage interleaves devirtualization and inlining on
    21  // all functions within pkg.
    22  func DevirtualizeAndInlinePackage(pkg *ir.Package, profile *pgoir.Profile) {
    23  	if profile != nil && base.Debug.PGODevirtualize > 0 {
    24  		// TODO(mdempsky): Integrate into DevirtualizeAndInlineFunc below.
    25  		ir.VisitFuncsBottomUp(typecheck.Target.Funcs, func(list []*ir.Func, recursive bool) {
    26  			for _, fn := range list {
    27  				devirtualize.ProfileGuided(fn, profile)
    28  			}
    29  		})
    30  		ir.CurFunc = nil
    31  	}
    32  
    33  	if base.Flag.LowerL != 0 {
    34  		inlheur.SetupScoreAdjustments()
    35  	}
    36  
    37  	var inlProfile *pgoir.Profile // copy of profile for inlining
    38  	if base.Debug.PGOInline != 0 {
    39  		inlProfile = profile
    40  	}
    41  
    42  	// First compute inlinability of all functions in the package.
    43  	inline.CanInlineFuncs(pkg.Funcs, inlProfile)
    44  
    45  	inlState := make(map[*ir.Func]*inlClosureState)
    46  	calleeUseCounts := make(map[*ir.Func]int)
    47  
    48  	var state devirtualize.State
    49  
    50  	// Pre-process all the functions, adding parentheses around call sites and starting their "inl state".
    51  	for _, fn := range typecheck.Target.Funcs {
    52  		bigCaller := base.Flag.LowerL != 0 && inline.IsBigFunc(fn)
    53  		if bigCaller && base.Flag.LowerM > 1 {
    54  			fmt.Printf("%v: function %v considered 'big'; reducing max cost of inlinees\n", ir.Line(fn), fn)
    55  		}
    56  
    57  		s := &inlClosureState{bigCaller: bigCaller, profile: profile, fn: fn, callSites: make(map[*ir.ParenExpr]bool), useCounts: calleeUseCounts}
    58  		s.parenthesize()
    59  		inlState[fn] = s
    60  
    61  		// Do a first pass at counting call sites.
    62  		for i := range s.parens {
    63  			s.resolve(&state, i)
    64  		}
    65  	}
    66  
    67  	ir.VisitFuncsBottomUp(typecheck.Target.Funcs, func(list []*ir.Func, recursive bool) {
    68  
    69  		anyInlineHeuristics := false
    70  
    71  		// inline heuristics, placed here because they have static state and that's what seems to work.
    72  		for _, fn := range list {
    73  			if base.Flag.LowerL != 0 {
    74  				if inlheur.Enabled() && !fn.Wrapper() {
    75  					inlheur.ScoreCalls(fn)
    76  					anyInlineHeuristics = true
    77  				}
    78  				if base.Debug.DumpInlFuncProps != "" && !fn.Wrapper() {
    79  					inlheur.DumpFuncProps(fn, base.Debug.DumpInlFuncProps)
    80  				}
    81  			}
    82  		}
    83  
    84  		if anyInlineHeuristics {
    85  			defer inlheur.ScoreCallsCleanup()
    86  		}
    87  
    88  		// Iterate to a fixed point over all the functions.
    89  		done := false
    90  		for !done {
    91  			done = true
    92  			for _, fn := range list {
    93  				s := inlState[fn]
    94  
    95  				ir.WithFunc(fn, func() {
    96  					l1 := len(s.parens)
    97  					l0 := 0
    98  
    99  					// Batch iterations so that newly discovered call sites are
   100  					// resolved in a batch before inlining attempts.
   101  					// Do this to avoid discovering new closure calls 1 at a time
   102  					// which might cause first call to be seen as a single (high-budget)
   103  					// call before the second is observed.
   104  					for {
   105  						for i := l0; i < l1; i++ { // can't use "range parens" here
   106  							paren := s.parens[i]
   107  							if origCall, inlinedCall := s.edit(&state, i); inlinedCall != nil {
   108  								// Update AST and recursively mark nodes.
   109  								paren.X = inlinedCall
   110  								ir.EditChildren(inlinedCall, s.mark) // mark may append to parens
   111  								state.InlinedCall(s.fn, origCall, inlinedCall)
   112  								done = false
   113  							}
   114  						}
   115  						l0, l1 = l1, len(s.parens)
   116  						if l0 == l1 {
   117  							break
   118  						}
   119  						for i := l0; i < l1; i++ {
   120  							s.resolve(&state, i)
   121  						}
   122  
   123  					}
   124  
   125  				}) // WithFunc
   126  
   127  			}
   128  		}
   129  	})
   130  
   131  	ir.CurFunc = nil
   132  
   133  	if base.Flag.LowerL != 0 {
   134  		if base.Debug.DumpInlFuncProps != "" {
   135  			inlheur.DumpFuncProps(nil, base.Debug.DumpInlFuncProps)
   136  		}
   137  		if inlheur.Enabled() {
   138  			inline.PostProcessCallSites(inlProfile)
   139  			inlheur.TearDown()
   140  		}
   141  	}
   142  
   143  	// remove parentheses
   144  	for _, fn := range typecheck.Target.Funcs {
   145  		inlState[fn].unparenthesize()
   146  	}
   147  
   148  }
   149  
   150  // DevirtualizeAndInlineFunc interleaves devirtualization and inlining
   151  // on a single function.
   152  func DevirtualizeAndInlineFunc(fn *ir.Func, profile *pgoir.Profile) {
   153  	ir.WithFunc(fn, func() {
   154  		if base.Flag.LowerL != 0 {
   155  			if inlheur.Enabled() && !fn.Wrapper() {
   156  				inlheur.ScoreCalls(fn)
   157  				defer inlheur.ScoreCallsCleanup()
   158  			}
   159  			if base.Debug.DumpInlFuncProps != "" && !fn.Wrapper() {
   160  				inlheur.DumpFuncProps(fn, base.Debug.DumpInlFuncProps)
   161  			}
   162  		}
   163  
   164  		bigCaller := base.Flag.LowerL != 0 && inline.IsBigFunc(fn)
   165  		if bigCaller && base.Flag.LowerM > 1 {
   166  			fmt.Printf("%v: function %v considered 'big'; reducing max cost of inlinees\n", ir.Line(fn), fn)
   167  		}
   168  
   169  		s := &inlClosureState{bigCaller: bigCaller, profile: profile, fn: fn, callSites: make(map[*ir.ParenExpr]bool), useCounts: make(map[*ir.Func]int)}
   170  		s.parenthesize()
   171  		s.fixpoint()
   172  		s.unparenthesize()
   173  	})
   174  }
   175  
// callSite pairs a function with an integer index.
//
// NOTE(review): nothing in this file refers to this type; confirm whether
// it is still used elsewhere in the package.
type callSite struct {
	fn         *ir.Func // function associated with the call site
	whichParen int      // presumably an index into inlClosureState.parens — TODO confirm
}
   180  
// inlClosureState is the per-function bookkeeping of the interleaved
// devirtualization and inlining pass: the ParenExprs wrapped around the
// call sites of fn, the callee resolved for each of them, and how often
// each callee is used.
type inlClosureState struct {
	fn        *ir.Func               // function whose call sites are being processed
	profile   *pgoir.Profile         // profile forwarded to inline.InlineCallTarget and inline.TryInlineCall
	callSites map[*ir.ParenExpr]bool // callSites[p] == "p appears in parens" (do not append again)
	resolved  []*ir.Func             // for each call in parens, the resolved target of the call; may be shorter than parens, nil means unresolved
	// useCounts maps a callee to the number of call sites resolved to it.
	// DevirtualizeAndInlinePackage shares one map among all states;
	// DevirtualizeAndInlineFunc gives its state a private map.
	useCounts map[*ir.Func]int
	parens    []*ir.ParenExpr // wrappers around call sites, in the order mark appended them
	bigCaller bool            // true when inlining is enabled and inline.IsBigFunc(fn) reported true
}
   190  
   191  // resolve attempts to resolve a call to a potentially inlineable callee
   192  // and updates use counts on the callees.  Returns the call site count
   193  // for that callee.
   194  func (s *inlClosureState) resolve(state *devirtualize.State, i int) (*ir.Func, int) {
   195  	p := s.parens[i]
   196  	if i < len(s.resolved) {
   197  		if callee := s.resolved[i]; callee != nil {
   198  			return callee, s.useCounts[callee]
   199  		}
   200  	}
   201  	n := p.X
   202  	call, ok := n.(*ir.CallExpr)
   203  	if !ok { // previously inlined
   204  		return nil, -1
   205  	}
   206  	devirtualize.StaticCall(state, call)
   207  	if callee := inline.InlineCallTarget(s.fn, call, s.profile); callee != nil {
   208  		for len(s.resolved) <= i {
   209  			s.resolved = append(s.resolved, nil)
   210  		}
   211  		s.resolved[i] = callee
   212  		c := s.useCounts[callee] + 1
   213  		s.useCounts[callee] = c
   214  		return callee, c
   215  	}
   216  	return nil, 0
   217  }
   218  
   219  func (s *inlClosureState) edit(state *devirtualize.State, i int) (*ir.CallExpr, *ir.InlinedCallExpr) {
   220  	n := s.parens[i].X
   221  	call, ok := n.(*ir.CallExpr)
   222  	if !ok {
   223  		return nil, nil
   224  	}
   225  	// This is redundant with earlier calls to
   226  	// resolve, but because things can change it
   227  	// must be re-checked.
   228  	callee, count := s.resolve(state, i)
   229  	if count <= 0 {
   230  		return nil, nil
   231  	}
   232  	if inlCall := inline.TryInlineCall(s.fn, call, s.bigCaller, s.profile, count == 1 && callee.ClosureParent != nil); inlCall != nil {
   233  		return call, inlCall
   234  	}
   235  	return nil, nil
   236  }
   237  
   238  // Mark inserts parentheses, and is called repeatedly.
   239  // These inserted parentheses mark the call sites where
   240  // inlining will be attempted.
   241  func (s *inlClosureState) mark(n ir.Node) ir.Node {
   242  	// Consider the expression "f(g())". We want to be able to replace
   243  	// "g()" in-place with its inlined representation. But if we first
   244  	// replace "f(...)" with its inlined representation, then "g()" will
   245  	// instead appear somewhere within this new AST.
   246  	//
   247  	// To mitigate this, each matched node n is wrapped in a ParenExpr,
   248  	// so we can reliably replace n in-place by assigning ParenExpr.X.
   249  	// It's safe to use ParenExpr here, because typecheck already
   250  	// removed them all.
   251  
   252  	p, _ := n.(*ir.ParenExpr)
   253  	if p != nil && s.callSites[p] {
   254  		return n // already visited n.X before wrapping
   255  	}
   256  
   257  	if isTestingBLoop(n) {
   258  		// No inlining nor devirtualization performed on b.Loop body
   259  		if base.Flag.LowerM > 0 {
   260  			fmt.Printf("%v: skip inlining within testing.B.loop for %v\n", ir.Line(n), n)
   261  		}
   262  		// We still want to explore inlining opportunities in other parts of ForStmt.
   263  		nFor, _ := n.(*ir.ForStmt)
   264  		nForInit := nFor.Init()
   265  		for i, x := range nForInit {
   266  			if x != nil {
   267  				nForInit[i] = s.mark(x)
   268  			}
   269  		}
   270  		if nFor.Cond != nil {
   271  			nFor.Cond = s.mark(nFor.Cond)
   272  		}
   273  		if nFor.Post != nil {
   274  			nFor.Post = s.mark(nFor.Post)
   275  		}
   276  		return n
   277  	}
   278  
   279  	if p != nil {
   280  		n = p.X // in this case p was copied in from a (marked) inlined function, this is a new unvisited node.
   281  	}
   282  
   283  	ok := match(n)
   284  
   285  	// can't wrap TailCall's child into ParenExpr
   286  	if t, ok := n.(*ir.TailCallStmt); ok {
   287  		ir.EditChildren(t.Call, s.mark)
   288  	} else {
   289  		ir.EditChildren(n, s.mark)
   290  	}
   291  
   292  	if ok {
   293  		if p == nil {
   294  			p = ir.NewParenExpr(n.Pos(), n)
   295  			p.SetType(n.Type())
   296  			p.SetTypecheck(n.Typecheck())
   297  			s.callSites[p] = true
   298  		}
   299  
   300  		s.parens = append(s.parens, p)
   301  		n = p
   302  	} else if p != nil {
   303  		n = p // didn't change anything, restore n
   304  	}
   305  	return n
   306  }
   307  
   308  // parenthesize applies s.mark to all the nodes within
   309  // s.fn to mark calls and simplify rewriting them in place.
   310  func (s *inlClosureState) parenthesize() {
   311  	ir.EditChildren(s.fn, s.mark)
   312  }
   313  
   314  func (s *inlClosureState) unparenthesize() {
   315  	if s == nil {
   316  		return
   317  	}
   318  	if len(s.parens) == 0 {
   319  		return // short circuit
   320  	}
   321  
   322  	var unparen func(ir.Node) ir.Node
   323  	unparen = func(n ir.Node) ir.Node {
   324  		if paren, ok := n.(*ir.ParenExpr); ok {
   325  			n = paren.X
   326  		}
   327  		ir.EditChildren(n, unparen)
   328  		return n
   329  	}
   330  	ir.EditChildren(s.fn, unparen)
   331  }
   332  
   333  // fixpoint repeatedly edits a function until it stabilizes, returning
   334  // whether anything changed in any of the fixpoint iterations.
   335  //
   336  // It applies s.edit(n) to each node n within the parentheses in s.parens.
   337  // If s.edit(n) returns nil, no change is made. Otherwise, the result
   338  // replaces n in fn's body, and fixpoint iterates at least once more.
   339  //
   340  // After an iteration where all edit calls return nil, fixpoint
   341  // returns.
   342  func (s *inlClosureState) fixpoint() bool {
   343  	changed := false
   344  	var state devirtualize.State
   345  	ir.WithFunc(s.fn, func() {
   346  		done := false
   347  		for !done {
   348  			done = true
   349  			for i := 0; i < len(s.parens); i++ { // can't use "range parens" here
   350  				paren := s.parens[i]
   351  				if origCall, inlinedCall := s.edit(&state, i); inlinedCall != nil {
   352  					// Update AST and recursively mark nodes.
   353  					paren.X = inlinedCall
   354  					ir.EditChildren(inlinedCall, s.mark) // mark may append to parens
   355  					state.InlinedCall(s.fn, origCall, inlinedCall)
   356  					done = false
   357  					changed = true
   358  				}
   359  			}
   360  		}
   361  	})
   362  	return changed
   363  }
   364  
   365  func match(n ir.Node) bool {
   366  	switch n := n.(type) {
   367  	case *ir.CallExpr:
   368  		return true
   369  	case *ir.TailCallStmt:
   370  		n.Call.NoInline = true // can't inline yet
   371  	}
   372  	return false
   373  }
   374  
   375  // isTestingBLoop returns true if it matches the node as a
   376  // testing.(*B).Loop. See issue #61515.
   377  func isTestingBLoop(t ir.Node) bool {
   378  	if t.Op() != ir.OFOR {
   379  		return false
   380  	}
   381  	nFor, ok := t.(*ir.ForStmt)
   382  	if !ok || nFor.Cond == nil || nFor.Cond.Op() != ir.OCALLFUNC {
   383  		return false
   384  	}
   385  	n, ok := nFor.Cond.(*ir.CallExpr)
   386  	if !ok || n.Fun == nil || n.Fun.Op() != ir.OMETHEXPR {
   387  		return false
   388  	}
   389  	name := ir.MethodExprName(n.Fun)
   390  	if name == nil {
   391  		return false
   392  	}
   393  	if fSym := name.Sym(); fSym != nil && name.Class == ir.PFUNC && fSym.Pkg != nil &&
   394  		fSym.Name == "(*B).Loop" && fSym.Pkg.Path == "testing" {
   395  		// Attempting to match a function call to testing.(*B).Loop
   396  		return true
   397  	}
   398  	return false
   399  }
   400  

View as plain text