Source file src/cmd/compile/internal/inline/interleaved/interleaved.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package interleaved implements the interleaved devirtualization and
     6  // inlining pass.
     7  package interleaved
     8  
     9  import (
    10  	"cmd/compile/internal/base"
    11  	"cmd/compile/internal/devirtualize"
    12  	"cmd/compile/internal/inline"
    13  	"cmd/compile/internal/inline/inlheur"
    14  	"cmd/compile/internal/ir"
    15  	"cmd/compile/internal/pgoir"
    16  	"cmd/compile/internal/typecheck"
    17  	"fmt"
    18  )
    19  
    20  // DevirtualizeAndInlinePackage interleaves devirtualization and inlining on
    21  // all functions within pkg.
// DevirtualizeAndInlinePackage interleaves devirtualization and inlining on
// all functions within pkg.
func DevirtualizeAndInlinePackage(pkg *ir.Package, profile *pgoir.Profile) {
	// At -W 2 and above, dump every function body before the pass runs.
	if base.Flag.W > 1 {
		for _, fn := range typecheck.Target.Funcs {
			s := fmt.Sprintf("\nbefore devirtualize-and-inline %v", fn.Sym())
			ir.DumpList(s, fn.Body)
		}
	}

	// Profile-guided devirtualization runs as a separate bottom-up pass
	// before the interleaved fixpoint below.
	if profile != nil && base.Debug.PGODevirtualize > 0 {
		// TODO(mdempsky): Integrate into DevirtualizeAndInlineFunc below.
		ir.VisitFuncsBottomUp(typecheck.Target.Funcs, func(list []*ir.Func, recursive bool) {
			for _, fn := range list {
				devirtualize.ProfileGuided(fn, profile)
			}
		})
		ir.CurFunc = nil
	}

	if base.Flag.LowerL != 0 {
		inlheur.SetupScoreAdjustments()
	}

	var inlProfile *pgoir.Profile // copy of profile for inlining
	if base.Debug.PGOInline != 0 {
		inlProfile = profile
	}

	// First compute inlinability of all functions in the package.
	inline.CanInlineFuncs(pkg.Funcs, inlProfile)

	// Per-function inlining state, plus a use-count map shared by every
	// inlClosureState so closure call counts are global to the package.
	inlState := make(map[*ir.Func]*inlClosureState)
	calleeUseCounts := make(map[*ir.Func]int)

	var state devirtualize.State

	// Pre-process all the functions, adding parentheses around call sites and starting their "inl state".
	for _, fn := range typecheck.Target.Funcs {
		bigCaller := base.Flag.LowerL != 0 && inline.IsBigFunc(fn)
		if bigCaller && base.Flag.LowerM > 1 {
			fmt.Printf("%v: function %v considered 'big'; reducing max cost of inlinees\n", ir.Line(fn), fn)
		}

		s := &inlClosureState{bigCaller: bigCaller, profile: profile, fn: fn, callSites: make(map[*ir.ParenExpr]bool), useCounts: calleeUseCounts}
		s.parenthesize()
		inlState[fn] = s

		// Do a first pass at counting call sites.
		for i := range s.parens {
			s.resolve(&state, i)
		}
	}

	ir.VisitFuncsBottomUp(typecheck.Target.Funcs, func(list []*ir.Func, recursive bool) {

		anyInlineHeuristics := false

		// inline heuristics, placed here because they have static state and that's what seems to work.
		for _, fn := range list {
			if base.Flag.LowerL != 0 {
				if inlheur.Enabled() && !fn.Wrapper() {
					inlheur.ScoreCalls(fn)
					anyInlineHeuristics = true
				}
				if base.Debug.DumpInlFuncProps != "" && !fn.Wrapper() {
					inlheur.DumpFuncProps(fn, base.Debug.DumpInlFuncProps)
				}
			}
		}

		// Cleanup is deferred so it runs after the whole strongly
		// connected component below has reached its fixed point.
		if anyInlineHeuristics {
			defer inlheur.ScoreCallsCleanup()
		}

		// Iterate to a fixed point over all the functions.
		done := false
		for !done {
			done = true
			for _, fn := range list {
				s := inlState[fn]

				ir.WithFunc(fn, func() {
					l1 := len(s.parens)
					l0 := 0

					// Batch iterations so that newly discovered call sites are
					// resolved in a batch before inlining attempts.
					// Do this to avoid discovering new closure calls 1 at a time
					// which might cause first call to be seen as a single (high-budget)
					// call before the second is observed.
					for {
						for i := l0; i < l1; i++ { // can't use "range parens" here
							paren := s.parens[i]
							if origCall, inlinedCall := s.edit(&state, i); inlinedCall != nil {
								// Update AST and recursively mark nodes.
								paren.X = inlinedCall
								ir.EditChildren(inlinedCall, s.mark) // mark may append to parens
								state.InlinedCall(s.fn, origCall, inlinedCall)
								done = false
							}
						}
						// Advance the window to whatever mark appended,
						// and stop once no new parens were discovered.
						l0, l1 = l1, len(s.parens)
						if l0 == l1 {
							break
						}
						// Resolve the newly discovered sites as a batch
						// before the next round of inlining attempts.
						for i := l0; i < l1; i++ {
							s.resolve(&state, i)
						}

					}

				}) // WithFunc

			}
		}
	})

	ir.CurFunc = nil

	if base.Flag.LowerL != 0 {
		if base.Debug.DumpInlFuncProps != "" {
			inlheur.DumpFuncProps(nil, base.Debug.DumpInlFuncProps)
		}
		if inlheur.Enabled() {
			inline.PostProcessCallSites(inlProfile)
			inlheur.TearDown()
		}
	}

	// remove parentheses
	for _, fn := range typecheck.Target.Funcs {
		inlState[fn].unparenthesize()
	}

}
   156  
   157  // DevirtualizeAndInlineFunc interleaves devirtualization and inlining
   158  // on a single function.
   159  func DevirtualizeAndInlineFunc(fn *ir.Func, profile *pgoir.Profile) {
   160  	ir.WithFunc(fn, func() {
   161  		if base.Flag.LowerL != 0 {
   162  			if inlheur.Enabled() && !fn.Wrapper() {
   163  				inlheur.ScoreCalls(fn)
   164  				defer inlheur.ScoreCallsCleanup()
   165  			}
   166  			if base.Debug.DumpInlFuncProps != "" && !fn.Wrapper() {
   167  				inlheur.DumpFuncProps(fn, base.Debug.DumpInlFuncProps)
   168  			}
   169  		}
   170  
   171  		bigCaller := base.Flag.LowerL != 0 && inline.IsBigFunc(fn)
   172  		if bigCaller && base.Flag.LowerM > 1 {
   173  			fmt.Printf("%v: function %v considered 'big'; reducing max cost of inlinees\n", ir.Line(fn), fn)
   174  		}
   175  
   176  		s := &inlClosureState{bigCaller: bigCaller, profile: profile, fn: fn, callSites: make(map[*ir.ParenExpr]bool), useCounts: make(map[*ir.Func]int)}
   177  		s.parenthesize()
   178  		s.fixpoint()
   179  		s.unparenthesize()
   180  	})
   181  }
   182  
// callSite identifies one wrapped call site: the function it appears in
// and its index into that function's parens slice.
// NOTE(review): this type appears unused within this file — confirm it is
// referenced elsewhere in the package before removing.
type callSite struct {
	fn         *ir.Func
	whichParen int
}
   187  
// inlClosureState carries the per-function bookkeeping for the
// interleaved devirtualize-and-inline pass.
type inlClosureState struct {
	fn        *ir.Func               // function being processed
	profile   *pgoir.Profile         // PGO profile, if any, passed through to inlining
	callSites map[*ir.ParenExpr]bool // callSites[p] == "p appears in parens" (do not append again)
	resolved  []*ir.Func             // for each call in parens, the resolved target of the call
	useCounts map[*ir.Func]int       // shared among all InlClosureStates
	parens    []*ir.ParenExpr        // wrapped call sites; mark may append during iteration
	bigCaller bool                   // fn is "big"; inlinees get a reduced cost budget
}
   197  
   198  // resolve attempts to resolve a call to a potentially inlineable callee
   199  // and updates use counts on the callees.  Returns the call site count
   200  // for that callee.
   201  func (s *inlClosureState) resolve(state *devirtualize.State, i int) (*ir.Func, int) {
   202  	p := s.parens[i]
   203  	if i < len(s.resolved) {
   204  		if callee := s.resolved[i]; callee != nil {
   205  			return callee, s.useCounts[callee]
   206  		}
   207  	}
   208  	n := p.X
   209  	call, ok := n.(*ir.CallExpr)
   210  	if !ok { // previously inlined
   211  		return nil, -1
   212  	}
   213  	devirtualize.StaticCall(state, call)
   214  	if callee := inline.InlineCallTarget(s.fn, call, s.profile); callee != nil {
   215  		for len(s.resolved) <= i {
   216  			s.resolved = append(s.resolved, nil)
   217  		}
   218  		s.resolved[i] = callee
   219  		c := s.useCounts[callee] + 1
   220  		s.useCounts[callee] = c
   221  		return callee, c
   222  	}
   223  	return nil, 0
   224  }
   225  
   226  func (s *inlClosureState) edit(state *devirtualize.State, i int) (*ir.CallExpr, *ir.InlinedCallExpr) {
   227  	n := s.parens[i].X
   228  	call, ok := n.(*ir.CallExpr)
   229  	if !ok {
   230  		return nil, nil
   231  	}
   232  	// This is redundant with earlier calls to
   233  	// resolve, but because things can change it
   234  	// must be re-checked.
   235  	callee, count := s.resolve(state, i)
   236  	if count <= 0 {
   237  		return nil, nil
   238  	}
   239  	if inlCall := inline.TryInlineCall(s.fn, call, s.bigCaller, s.profile, count == 1 && callee.ClosureParent != nil); inlCall != nil {
   240  		return call, inlCall
   241  	}
   242  	return nil, nil
   243  }
   244  
   245  // Mark inserts parentheses, and is called repeatedly.
   246  // These inserted parentheses mark the call sites where
   247  // inlining will be attempted.
// Mark inserts parentheses, and is called repeatedly.
// These inserted parentheses mark the call sites where
// inlining will be attempted.
func (s *inlClosureState) mark(n ir.Node) ir.Node {
	// Consider the expression "f(g())". We want to be able to replace
	// "g()" in-place with its inlined representation. But if we first
	// replace "f(...)" with its inlined representation, then "g()" will
	// instead appear somewhere within this new AST.
	//
	// To mitigate this, each matched node n is wrapped in a ParenExpr,
	// so we can reliably replace n in-place by assigning ParenExpr.X.
	// It's safe to use ParenExpr here, because typecheck already
	// removed them all.

	p, _ := n.(*ir.ParenExpr)
	if p != nil && s.callSites[p] {
		return n // already visited n.X before wrapping
	}

	if p != nil {
		n = p.X // in this case p was copied in from a (marked) inlined function, this is a new unvisited node.
	}

	// Decide whether n itself is a call site before descending; match
	// also flags tail calls as NoInline as a side effect.
	ok := match(n)

	// can't wrap TailCall's child into ParenExpr
	// (the inner ok deliberately shadows the outer one)
	if t, ok := n.(*ir.TailCallStmt); ok {
		ir.EditChildren(t.Call, s.mark)
	} else {
		ir.EditChildren(n, s.mark)
	}

	if ok {
		if p == nil {
			// Wrap n so its inlined replacement can later be spliced
			// in by assigning p.X; record p so it is not re-wrapped.
			p = ir.NewParenExpr(n.Pos(), n)
			p.SetType(n.Type())
			p.SetTypecheck(n.Typecheck())
			s.callSites[p] = true
		}

		s.parens = append(s.parens, p)
		n = p
	} else if p != nil {
		n = p // didn't change anything, restore n
	}
	return n
}
   292  
   293  // parenthesize applies s.mark to all the nodes within
   294  // s.fn to mark calls and simplify rewriting them in place.
   295  func (s *inlClosureState) parenthesize() {
   296  	ir.EditChildren(s.fn, s.mark)
   297  }
   298  
   299  func (s *inlClosureState) unparenthesize() {
   300  	if s == nil {
   301  		return
   302  	}
   303  	if len(s.parens) == 0 {
   304  		return // short circuit
   305  	}
   306  
   307  	var unparen func(ir.Node) ir.Node
   308  	unparen = func(n ir.Node) ir.Node {
   309  		if paren, ok := n.(*ir.ParenExpr); ok {
   310  			n = paren.X
   311  		}
   312  		ir.EditChildren(n, unparen)
   313  		return n
   314  	}
   315  	ir.EditChildren(s.fn, unparen)
   316  }
   317  
   318  // fixpoint repeatedly edits a function until it stabilizes, returning
   319  // whether anything changed in any of the fixpoint iterations.
   320  //
   321  // It applies s.edit(n) to each node n within the parentheses in s.parens.
   322  // If s.edit(n) returns nil, no change is made. Otherwise, the result
   323  // replaces n in fn's body, and fixpoint iterates at least once more.
   324  //
   325  // After an iteration where all edit calls return nil, fixpoint
   326  // returns.
   327  func (s *inlClosureState) fixpoint() bool {
   328  	changed := false
   329  	var state devirtualize.State
   330  	ir.WithFunc(s.fn, func() {
   331  		done := false
   332  		for !done {
   333  			done = true
   334  			for i := 0; i < len(s.parens); i++ { // can't use "range parens" here
   335  				paren := s.parens[i]
   336  				if origCall, inlinedCall := s.edit(&state, i); inlinedCall != nil {
   337  					// Update AST and recursively mark nodes.
   338  					paren.X = inlinedCall
   339  					ir.EditChildren(inlinedCall, s.mark) // mark may append to parens
   340  					state.InlinedCall(s.fn, origCall, inlinedCall)
   341  					done = false
   342  					changed = true
   343  				}
   344  			}
   345  		}
   346  	})
   347  	return changed
   348  }
   349  
   350  func match(n ir.Node) bool {
   351  	switch n := n.(type) {
   352  	case *ir.CallExpr:
   353  		return true
   354  	case *ir.TailCallStmt:
   355  		n.Call.NoInline = true // can't inline yet
   356  	}
   357  	return false
   358  }
   359  

View as plain text