1
2
3
4
5 package devirtualize
6
7 import (
8 "cmd/compile/internal/base"
9 "cmd/compile/internal/inline"
10 "cmd/compile/internal/ir"
11 "cmd/compile/internal/logopt"
12 "cmd/compile/internal/pgoir"
13 "cmd/compile/internal/typecheck"
14 "cmd/compile/internal/types"
15 "cmd/internal/obj"
16 "cmd/internal/src"
17 "encoding/json"
18 "fmt"
19 "os"
20 "strings"
21 )
22
23
24
25
// CallStat summarizes a single call site. ProfileGuided emits one
// JSON-encoded CallStat per call on standard output when
// base.Debug.PGODebug >= 3.
type CallStat struct {
	// Pkg is base.Ctxt.Pkgpath at the time the record is built.
	Pkg string

	// Pos is ir.Line of the call expression.
	Pos string

	// Caller is the name passed to constructCallStat; ProfileGuided passes
	// ir.LinkFuncName of the function containing the call.
	Caller string

	// Direct is set for an OCALLFUNC call for which pgoir.DirectCallee
	// finds a callee.
	Direct bool

	// Interface is set for an OCALLINTER call.
	Interface bool

	// Weight is the sum of the weights of all profile edges leaving the
	// caller at this call site's line offset.
	Weight int64

	// Hottest is the destination name of the heaviest of those edges, with
	// ties broken in favor of the alphabetically first name. If there are
	// no such edges and the call is direct, it is the link name of the
	// direct callee.
	Hottest string

	// HottestWeight is the weight of the edge that selected Hottest.
	HottestWeight int64

	// Devirtualized is the link name of the callee that the call was
	// devirtualized to; it is left empty if the call was not rewritten.
	Devirtualized string

	// DevirtualizedWeight is the edge weight reported for Devirtualized.
	DevirtualizedWeight int64
}
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105 func ProfileGuided(fn *ir.Func, p *pgoir.Profile) {
106 ir.CurFunc = fn
107
108 name := ir.LinkFuncName(fn)
109
110 var jsonW *json.Encoder
111 if base.Debug.PGODebug >= 3 {
112 jsonW = json.NewEncoder(os.Stdout)
113 }
114
115 var edit func(n ir.Node) ir.Node
116 edit = func(n ir.Node) ir.Node {
117 if n == nil {
118 return n
119 }
120
121 ir.EditChildren(n, edit)
122
123 call, ok := n.(*ir.CallExpr)
124 if !ok {
125 return n
126 }
127
128 var stat *CallStat
129 if base.Debug.PGODebug >= 3 {
130
131
132
133 stat = constructCallStat(p, fn, name, call)
134 if stat != nil {
135 defer func() {
136 jsonW.Encode(&stat)
137 }()
138 }
139 }
140
141 op := call.Op()
142 if op != ir.OCALLFUNC && op != ir.OCALLINTER {
143 return n
144 }
145
146 if base.Debug.PGODebug >= 2 {
147 fmt.Printf("%v: PGO devirtualize considering call %v\n", ir.Line(call), call)
148 }
149
150 if call.GoDefer {
151 if base.Debug.PGODebug >= 2 {
152 fmt.Printf("%v: can't PGO devirtualize go/defer call %v\n", ir.Line(call), call)
153 }
154 return n
155 }
156
157 var newNode ir.Node
158 var callee *ir.Func
159 var weight int64
160 switch op {
161 case ir.OCALLFUNC:
162 newNode, callee, weight = maybeDevirtualizeFunctionCall(p, fn, call)
163 case ir.OCALLINTER:
164 newNode, callee, weight = maybeDevirtualizeInterfaceCall(p, fn, call)
165 default:
166 panic("unreachable")
167 }
168
169 if newNode == nil {
170 return n
171 }
172
173 if stat != nil {
174 stat.Devirtualized = ir.LinkFuncName(callee)
175 stat.DevirtualizedWeight = weight
176 }
177
178 return newNode
179 }
180
181 ir.EditChildren(fn, edit)
182 }
183
184
185
186
187 func maybeDevirtualizeInterfaceCall(p *pgoir.Profile, fn *ir.Func, call *ir.CallExpr) (ir.Node, *ir.Func, int64) {
188 if base.Debug.PGODevirtualize < 1 {
189 return nil, nil, 0
190 }
191
192
193 callee, weight := findHotConcreteInterfaceCallee(p, fn, call)
194 if callee == nil {
195 return nil, nil, 0
196 }
197
198 ctyp := methodRecvType(callee)
199 if ctyp == nil {
200 return nil, nil, 0
201 }
202
203 if !shouldPGODevirt(callee) {
204 return nil, nil, 0
205 }
206
207 if !base.PGOHash.MatchPosWithInfo(call.Pos(), "devirt", nil) {
208 return nil, nil, 0
209 }
210
211 return rewriteInterfaceCall(call, fn, callee, ctyp), callee, weight
212 }
213
214
215
216
217 func maybeDevirtualizeFunctionCall(p *pgoir.Profile, fn *ir.Func, call *ir.CallExpr) (ir.Node, *ir.Func, int64) {
218 if base.Debug.PGODevirtualize < 2 {
219 return nil, nil, 0
220 }
221
222
223 callee := pgoir.DirectCallee(call.Fun)
224 if callee != nil {
225 return nil, nil, 0
226 }
227
228
229 callee, weight := findHotConcreteFunctionCallee(p, fn, call)
230 if callee == nil {
231 return nil, nil, 0
232 }
233
234
235
236
237 if callee.OClosure != nil {
238 if base.Debug.PGODebug >= 3 {
239 fmt.Printf("callee %s is a closure, skipping\n", ir.FuncName(callee))
240 }
241 return nil, nil, 0
242 }
243
244
245
246 if callee.Sym().Pkg.Path == "runtime" && callee.Sym().Name == "memhash_varlen" {
247 if base.Debug.PGODebug >= 3 {
248 fmt.Printf("callee %s is a closure (runtime.memhash_varlen), skipping\n", ir.FuncName(callee))
249 }
250 return nil, nil, 0
251 }
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279 if callee.Type().Recv() != nil {
280 if base.Debug.PGODebug >= 3 {
281 fmt.Printf("callee %s is a method, skipping\n", ir.FuncName(callee))
282 }
283 return nil, nil, 0
284 }
285
286
287 if !shouldPGODevirt(callee) {
288 return nil, nil, 0
289 }
290
291 if !base.PGOHash.MatchPosWithInfo(call.Pos(), "devirt", nil) {
292 return nil, nil, 0
293 }
294
295 return rewriteFunctionCall(call, fn, callee), callee, weight
296 }
297
298
299
300
301
302
303 func shouldPGODevirt(fn *ir.Func) bool {
304 var reason string
305 if base.Flag.LowerM > 1 || logopt.Enabled() {
306 defer func() {
307 if reason != "" {
308 if base.Flag.LowerM > 1 {
309 fmt.Printf("%v: should not PGO devirtualize %v: %s\n", ir.Line(fn), ir.FuncName(fn), reason)
310 }
311 if logopt.Enabled() {
312 logopt.LogOpt(fn.Pos(), ": should not PGO devirtualize function", "pgoir-devirtualize", ir.FuncName(fn), reason)
313 }
314 }
315 }()
316 }
317
318 reason = inline.InlineImpossible(fn)
319 if reason != "" {
320 return false
321 }
322
323
324
325
326
327
328
329
330
331
332
333 return true
334 }
335
336
337
338
339 func constructCallStat(p *pgoir.Profile, fn *ir.Func, name string, call *ir.CallExpr) *CallStat {
340 switch call.Op() {
341 case ir.OCALLFUNC, ir.OCALLINTER, ir.OCALLMETH:
342 default:
343
344 return nil
345 }
346
347 stat := CallStat{
348 Pkg: base.Ctxt.Pkgpath,
349 Pos: ir.Line(call),
350 Caller: name,
351 }
352
353 offset := pgoir.NodeLineOffset(call, fn)
354
355 hotter := func(e *pgoir.IREdge) bool {
356 if stat.Hottest == "" {
357 return true
358 }
359 if e.Weight != stat.HottestWeight {
360 return e.Weight > stat.HottestWeight
361 }
362
363
364 return e.Dst.Name() < stat.Hottest
365 }
366
367 callerNode := p.WeightedCG.IRNodes[name]
368 if callerNode == nil {
369 return nil
370 }
371
372
373
374
375
376 for _, edge := range callerNode.OutEdges {
377 if edge.CallSiteOffset != offset {
378 continue
379 }
380 stat.Weight += edge.Weight
381 if hotter(edge) {
382 stat.HottestWeight = edge.Weight
383 stat.Hottest = edge.Dst.Name()
384 }
385 }
386
387 switch call.Op() {
388 case ir.OCALLFUNC:
389 stat.Interface = false
390
391 callee := pgoir.DirectCallee(call.Fun)
392 if callee != nil {
393 stat.Direct = true
394 if stat.Hottest == "" {
395 stat.Hottest = ir.LinkFuncName(callee)
396 }
397 } else {
398 stat.Direct = false
399 }
400 case ir.OCALLINTER:
401 stat.Direct = false
402 stat.Interface = true
403 case ir.OCALLMETH:
404 base.FatalfAt(call.Pos(), "OCALLMETH missed by typecheck")
405 }
406
407 return &stat
408 }
409
410
411
412
413
414
415
416 func copyInputs(curfn *ir.Func, pos src.XPos, recvOrFn ir.Node, args []ir.Node, init *ir.Nodes) (ir.Node, []ir.Node) {
417
418
419
420
421
422
423
424 var lhs, rhs []ir.Node
425 newRecvOrFn := typecheck.TempAt(pos, curfn, recvOrFn.Type())
426 lhs = append(lhs, newRecvOrFn)
427 rhs = append(rhs, recvOrFn)
428
429 for _, arg := range args {
430 argvar := typecheck.TempAt(pos, curfn, arg.Type())
431
432 lhs = append(lhs, argvar)
433 rhs = append(rhs, arg)
434 }
435
436 asList := ir.NewAssignListStmt(pos, ir.OAS2, lhs, rhs)
437 init.Append(typecheck.Stmt(asList))
438
439 return newRecvOrFn, lhs[1:]
440 }
441
442
443 func retTemps(curfn *ir.Func, pos src.XPos, call *ir.CallExpr) []ir.Node {
444 sig := call.Fun.Type()
445 var retvars []ir.Node
446 for _, ret := range sig.Results() {
447 retvars = append(retvars, typecheck.TempAt(pos, curfn, ret.Type))
448 }
449 return retvars
450 }
451
452
453
454
455 func condCall(curfn *ir.Func, pos src.XPos, cond ir.Node, thenCall, elseCall *ir.CallExpr, init ir.Nodes) *ir.InlinedCallExpr {
456
457
458 retvars := retTemps(curfn, pos, thenCall)
459
460 var thenBlock, elseBlock ir.Nodes
461 if len(retvars) == 0 {
462 thenBlock.Append(thenCall)
463 elseBlock.Append(elseCall)
464 } else {
465
466 thenRet := append([]ir.Node(nil), retvars...)
467 thenAsList := ir.NewAssignListStmt(pos, ir.OAS2, thenRet, []ir.Node{thenCall})
468 thenBlock.Append(typecheck.Stmt(thenAsList))
469
470 elseRet := append([]ir.Node(nil), retvars...)
471 elseAsList := ir.NewAssignListStmt(pos, ir.OAS2, elseRet, []ir.Node{elseCall})
472 elseBlock.Append(typecheck.Stmt(elseAsList))
473 }
474
475 nif := ir.NewIfStmt(pos, cond, thenBlock, elseBlock)
476 nif.SetInit(init)
477 nif.Likely = true
478
479 body := []ir.Node{typecheck.Stmt(nif)}
480
481
482
483 res := ir.NewInlinedCallExpr(pos, body, retvars)
484 res.SetType(thenCall.Type())
485 res.SetTypecheck(1)
486 return res
487 }
488
489
490
491 func rewriteInterfaceCall(call *ir.CallExpr, curfn, callee *ir.Func, concretetyp *types.Type) ir.Node {
492 if base.Flag.LowerM != 0 {
493 fmt.Printf("%v: PGO devirtualizing interface call %v to %v\n", ir.Line(call), call.Fun, callee)
494 }
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524 sel := call.Fun.(*ir.SelectorExpr)
525 method := sel.Sel
526 pos := call.Pos()
527 init := ir.TakeInit(call)
528
529 recv, args := copyInputs(curfn, pos, sel.X, call.Args.Take(), &init)
530
531
532 argvars := append([]ir.Node(nil), args...)
533 call.Args = argvars
534
535 tmpnode := typecheck.TempAt(base.Pos, curfn, concretetyp)
536 tmpok := typecheck.TempAt(base.Pos, curfn, types.Types[types.TBOOL])
537
538 assert := ir.NewTypeAssertExpr(pos, recv, concretetyp)
539
540 assertAsList := ir.NewAssignListStmt(pos, ir.OAS2, []ir.Node{tmpnode, tmpok}, []ir.Node{typecheck.Expr(assert)})
541 init.Append(typecheck.Stmt(assertAsList))
542
543 concreteCallee := typecheck.XDotMethod(pos, tmpnode, method, true)
544
545 argvars = append([]ir.Node(nil), argvars...)
546 concreteCall := typecheck.Call(pos, concreteCallee, argvars, call.IsDDD).(*ir.CallExpr)
547
548 res := condCall(curfn, pos, tmpok, concreteCall, call, init)
549
550 if base.Debug.PGODebug >= 3 {
551 fmt.Printf("PGO devirtualizing interface call to %+v. After: %+v\n", concretetyp, res)
552 }
553
554 return res
555 }
556
557
558
559 func rewriteFunctionCall(call *ir.CallExpr, curfn, callee *ir.Func) ir.Node {
560 if base.Flag.LowerM != 0 {
561 fmt.Printf("%v: PGO devirtualizing function call %v to %v\n", ir.Line(call), call.Fun, callee)
562 }
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590 pos := call.Pos()
591 init := ir.TakeInit(call)
592
593 fn, args := copyInputs(curfn, pos, call.Fun, call.Args.Take(), &init)
594
595
596 argvars := append([]ir.Node(nil), args...)
597 call.Args = argvars
598
599
600
601 fnIface := typecheck.Expr(ir.NewConvExpr(pos, ir.OCONV, types.Types[types.TINTER], fn))
602 calleeIface := typecheck.Expr(ir.NewConvExpr(pos, ir.OCONV, types.Types[types.TINTER], callee.Nname))
603
604 fnPC := ir.FuncPC(pos, fnIface, obj.ABIInternal)
605 concretePC := ir.FuncPC(pos, calleeIface, obj.ABIInternal)
606
607 pcEq := typecheck.Expr(ir.NewBinaryExpr(base.Pos, ir.OEQ, fnPC, concretePC))
608
609
610
611
612 if callee.OClosure != nil {
613 base.Fatalf("Callee is a closure: %+v", callee)
614 }
615
616
617 argvars = append([]ir.Node(nil), argvars...)
618 concreteCall := typecheck.Call(pos, callee.Nname, argvars, call.IsDDD).(*ir.CallExpr)
619
620 res := condCall(curfn, pos, pcEq, concreteCall, call, init)
621
622 if base.Debug.PGODebug >= 3 {
623 fmt.Printf("PGO devirtualizing function call to %+v. After: %+v\n", ir.FuncName(callee), res)
624 }
625
626 return res
627 }
628
629
630
631 func methodRecvType(fn *ir.Func) *types.Type {
632 recv := fn.Nname.Type().Recv()
633 if recv == nil {
634 return nil
635 }
636 return recv.Type
637 }
638
639
640
641 func interfaceCallRecvTypeAndMethod(call *ir.CallExpr) (*types.Type, *types.Sym) {
642 if call.Op() != ir.OCALLINTER {
643 base.Fatalf("Call isn't OCALLINTER: %+v", call)
644 }
645
646 sel, ok := call.Fun.(*ir.SelectorExpr)
647 if !ok {
648 base.Fatalf("OCALLINTER doesn't contain SelectorExpr: %+v", call)
649 }
650
651 return sel.X.Type(), sel.Sel
652 }
653
654
655
656
657
658 func findHotConcreteCallee(p *pgoir.Profile, caller *ir.Func, call *ir.CallExpr, extraFn func(callerName string, callOffset int, candidate *pgoir.IREdge) bool) (*ir.Func, int64) {
659 callerName := ir.LinkFuncName(caller)
660 callerNode := p.WeightedCG.IRNodes[callerName]
661 callOffset := pgoir.NodeLineOffset(call, caller)
662
663 if callerNode == nil {
664 return nil, 0
665 }
666
667 var hottest *pgoir.IREdge
668
669
670
671
672
673
674
675 hotter := func(e *pgoir.IREdge) bool {
676 if hottest == nil {
677 return true
678 }
679 if e.Weight != hottest.Weight {
680 return e.Weight > hottest.Weight
681 }
682
683
684
685
686
687 if (hottest.Dst.AST == nil) != (e.Dst.AST == nil) {
688 if e.Dst.AST != nil {
689 return true
690 }
691 return false
692 }
693
694
695
696 return e.Dst.Name() < hottest.Dst.Name()
697 }
698
699 for _, e := range callerNode.OutEdges {
700 if e.CallSiteOffset != callOffset {
701 continue
702 }
703
704 if !hotter(e) {
705
706
707
708
709
710 if base.Debug.PGODebug >= 2 {
711 fmt.Printf("%v: edge %s:%d -> %s (weight %d): too cold (hottest %d)\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight, hottest.Weight)
712 }
713 continue
714 }
715
716 if e.Dst.AST == nil {
717
718
719
720
721
722
723
724
725
726 if base.Debug.PGODebug >= 2 {
727 fmt.Printf("%v: edge %s:%d -> %s (weight %d) (missing IR): hottest so far\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight)
728 }
729 hottest = e
730 continue
731 }
732
733 if extraFn != nil && !extraFn(callerName, callOffset, e) {
734 continue
735 }
736
737 if base.Debug.PGODebug >= 2 {
738 fmt.Printf("%v: edge %s:%d -> %s (weight %d): hottest so far\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight)
739 }
740 hottest = e
741 }
742
743 if hottest == nil {
744 if base.Debug.PGODebug >= 2 {
745 fmt.Printf("%v: call %s:%d: no hot callee\n", ir.Line(call), callerName, callOffset)
746 }
747 return nil, 0
748 }
749
750 if base.Debug.PGODebug >= 2 {
751 fmt.Printf("%v: call %s:%d: hottest callee %s (weight %d)\n", ir.Line(call), callerName, callOffset, hottest.Dst.Name(), hottest.Weight)
752 }
753 return hottest.Dst.AST, hottest.Weight
754 }
755
756
757
758 func findHotConcreteInterfaceCallee(p *pgoir.Profile, caller *ir.Func, call *ir.CallExpr) (*ir.Func, int64) {
759 inter, method := interfaceCallRecvTypeAndMethod(call)
760
761 return findHotConcreteCallee(p, caller, call, func(callerName string, callOffset int, e *pgoir.IREdge) bool {
762 ctyp := methodRecvType(e.Dst.AST)
763 if ctyp == nil {
764
765
766 if base.Debug.PGODebug >= 2 {
767 fmt.Printf("%v: edge %s:%d -> %s (weight %d): callee not a method\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight)
768 }
769 return false
770 }
771
772
773
774 if !typecheck.Implements(ctyp, inter) {
775
776
777
778
779
780
781
782
783
784 if base.Debug.PGODebug >= 2 {
785 why := typecheck.ImplementsExplain(ctyp, inter)
786 fmt.Printf("%v: edge %s:%d -> %s (weight %d): %v doesn't implement %v (%s)\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight, ctyp, inter, why)
787 }
788 return false
789 }
790
791
792
793 if !strings.HasSuffix(e.Dst.Name(), "."+method.Name) {
794 if base.Debug.PGODebug >= 2 {
795 fmt.Printf("%v: edge %s:%d -> %s (weight %d): callee is a different method\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight)
796 }
797 return false
798 }
799
800 return true
801 })
802 }
803
804
805
806 func findHotConcreteFunctionCallee(p *pgoir.Profile, caller *ir.Func, call *ir.CallExpr) (*ir.Func, int64) {
807 typ := call.Fun.Type().Underlying()
808
809 return findHotConcreteCallee(p, caller, call, func(callerName string, callOffset int, e *pgoir.IREdge) bool {
810 ctyp := e.Dst.AST.Type().Underlying()
811
812
813
814
815
816
817
818
819 if !types.Identical(typ, ctyp) {
820 if base.Debug.PGODebug >= 2 {
821 fmt.Printf("%v: edge %s:%d -> %s (weight %d): %v doesn't match %v\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight, ctyp, typ)
822 }
823 return false
824 }
825
826 return true
827 })
828 }
829
View as plain text