1
2
3
4
5 package devirtualize
6
7 import (
8 "cmd/compile/internal/base"
9 "cmd/compile/internal/inline"
10 "cmd/compile/internal/ir"
11 "cmd/compile/internal/logopt"
12 "cmd/compile/internal/pgoir"
13 "cmd/compile/internal/typecheck"
14 "cmd/compile/internal/types"
15 "cmd/internal/obj"
16 "cmd/internal/src"
17 "encoding/json"
18 "fmt"
19 "os"
20 "strings"
21 )
22
23
24
25
// CallStat summarizes one call site for PGO devirtualization debugging.
// ProfileGuided JSON-encodes one of these to stdout for each call site it
// reports on when base.Debug.PGODebug >= 3.
type CallStat struct {
	// Pkg is the path of the package being compiled, Pos is the source
	// position of the call, and Caller is the linker name of the function
	// containing the call.
	Pkg, Pos, Caller string

	// Direct is true for an OCALLFUNC whose callee is statically known;
	// Interface is true for an interface method call (OCALLINTER).
	Direct, Interface bool

	// Weight is the sum of the weights of all profile edges at this call
	// site.
	Weight int64

	// Hottest is the name of the heaviest profiled callee at this call site
	// (ties go to the lexicographically smaller name), and HottestWeight is
	// its edge weight. For a direct call with no matching profile edges,
	// Hottest is the static callee's linker name and HottestWeight is 0.
	Hottest       string
	HottestWeight int64

	// Devirtualized is the linker name of the callee the call was rewritten
	// to, and DevirtualizedWeight is that edge's weight. Both stay at their
	// zero values when the call was not devirtualized.
	Devirtualized       string
	DevirtualizedWeight int64
}
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105 func ProfileGuided(fn *ir.Func, p *pgoir.Profile) {
106 ir.CurFunc = fn
107
108 name := ir.LinkFuncName(fn)
109
110 var jsonW *json.Encoder
111 if base.Debug.PGODebug >= 3 {
112 jsonW = json.NewEncoder(os.Stdout)
113 }
114
115 var edit func(n ir.Node) ir.Node
116 edit = func(n ir.Node) ir.Node {
117 if n == nil {
118 return n
119 }
120
121 ir.EditChildren(n, edit)
122
123 call, ok := n.(*ir.CallExpr)
124 if !ok {
125 return n
126 }
127
128 var stat *CallStat
129 if base.Debug.PGODebug >= 3 {
130
131
132
133 stat = constructCallStat(p, fn, name, call)
134 if stat != nil {
135 defer func() {
136 jsonW.Encode(&stat)
137 }()
138 }
139 }
140
141 op := call.Op()
142 if op != ir.OCALLFUNC && op != ir.OCALLINTER {
143 return n
144 }
145
146 if base.Debug.PGODebug >= 2 {
147 fmt.Printf("%v: PGO devirtualize considering call %v\n", ir.Line(call), call)
148 }
149
150 if call.GoDefer {
151 if base.Debug.PGODebug >= 2 {
152 fmt.Printf("%v: can't PGO devirtualize go/defer call %v\n", ir.Line(call), call)
153 }
154 return n
155 }
156
157 var newNode ir.Node
158 var callee *ir.Func
159 var weight int64
160 switch op {
161 case ir.OCALLFUNC:
162 newNode, callee, weight = maybeDevirtualizeFunctionCall(p, fn, call)
163 case ir.OCALLINTER:
164 newNode, callee, weight = maybeDevirtualizeInterfaceCall(p, fn, call)
165 default:
166 panic("unreachable")
167 }
168
169 if newNode == nil {
170 return n
171 }
172
173 if stat != nil {
174 stat.Devirtualized = ir.LinkFuncName(callee)
175 stat.DevirtualizedWeight = weight
176 }
177
178 return newNode
179 }
180
181 ir.EditChildren(fn, edit)
182 }
183
184
185
186
187 func maybeDevirtualizeInterfaceCall(p *pgoir.Profile, fn *ir.Func, call *ir.CallExpr) (ir.Node, *ir.Func, int64) {
188 if base.Debug.PGODevirtualize < 1 {
189 return nil, nil, 0
190 }
191
192
193 callee, weight := findHotConcreteInterfaceCallee(p, fn, call)
194 if callee == nil {
195 return nil, nil, 0
196 }
197
198 ctyp := methodRecvType(callee)
199 if ctyp == nil {
200 return nil, nil, 0
201 }
202
203 if !shouldPGODevirt(callee) {
204 return nil, nil, 0
205 }
206
207 if !base.PGOHash.MatchPosWithInfo(call.Pos(), "devirt", nil) {
208 return nil, nil, 0
209 }
210
211 return rewriteInterfaceCall(call, fn, callee, ctyp), callee, weight
212 }
213
214
215
216
217 func maybeDevirtualizeFunctionCall(p *pgoir.Profile, fn *ir.Func, call *ir.CallExpr) (ir.Node, *ir.Func, int64) {
218 if base.Debug.PGODevirtualize < 2 {
219 return nil, nil, 0
220 }
221
222
223 callee := pgoir.DirectCallee(call.Fun)
224 if callee != nil {
225 return nil, nil, 0
226 }
227
228
229 callee, weight := findHotConcreteFunctionCallee(p, fn, call)
230 if callee == nil {
231 return nil, nil, 0
232 }
233
234
235
236
237 if callee.OClosure != nil {
238 if base.Debug.PGODebug >= 3 {
239 fmt.Printf("callee %s is a closure, skipping\n", ir.FuncName(callee))
240 }
241 return nil, nil, 0
242 }
243
244
245
246
247 if callee.Sym().Pkg.Path == "runtime" && callee.Sym().Name == "memhash_varlen" {
248 if base.Debug.PGODebug >= 3 {
249 fmt.Printf("callee %s is a closure (runtime.memhash_varlen), skipping\n", ir.FuncName(callee))
250 }
251 return nil, nil, 0
252 }
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280 if callee.Type().Recv() != nil {
281 if base.Debug.PGODebug >= 3 {
282 fmt.Printf("callee %s is a method, skipping\n", ir.FuncName(callee))
283 }
284 return nil, nil, 0
285 }
286
287
288 if !shouldPGODevirt(callee) {
289 return nil, nil, 0
290 }
291
292 if !base.PGOHash.MatchPosWithInfo(call.Pos(), "devirt", nil) {
293 return nil, nil, 0
294 }
295
296 return rewriteFunctionCall(call, fn, callee), callee, weight
297 }
298
299
300
301
302
303
304 func shouldPGODevirt(fn *ir.Func) bool {
305 var reason string
306 if base.Flag.LowerM > 1 || logopt.Enabled() {
307 defer func() {
308 if reason != "" {
309 if base.Flag.LowerM > 1 {
310 fmt.Printf("%v: should not PGO devirtualize %v: %s\n", ir.Line(fn), ir.FuncName(fn), reason)
311 }
312 if logopt.Enabled() {
313 logopt.LogOpt(fn.Pos(), ": should not PGO devirtualize function", "pgoir-devirtualize", ir.FuncName(fn), reason)
314 }
315 }
316 }()
317 }
318
319 reason = inline.InlineImpossible(fn)
320 if reason != "" {
321 return false
322 }
323
324
325
326
327
328
329
330
331
332
333
334 return true
335 }
336
337
338
339
340 func constructCallStat(p *pgoir.Profile, fn *ir.Func, name string, call *ir.CallExpr) *CallStat {
341 switch call.Op() {
342 case ir.OCALLFUNC, ir.OCALLINTER, ir.OCALLMETH:
343 default:
344
345 return nil
346 }
347
348 stat := CallStat{
349 Pkg: base.Ctxt.Pkgpath,
350 Pos: ir.Line(call),
351 Caller: name,
352 }
353
354 offset := pgoir.NodeLineOffset(call, fn)
355
356 hotter := func(e *pgoir.IREdge) bool {
357 if stat.Hottest == "" {
358 return true
359 }
360 if e.Weight != stat.HottestWeight {
361 return e.Weight > stat.HottestWeight
362 }
363
364
365 return e.Dst.Name() < stat.Hottest
366 }
367
368 callerNode := p.WeightedCG.IRNodes[name]
369 if callerNode == nil {
370 return nil
371 }
372
373
374
375
376
377 for _, edge := range callerNode.OutEdges {
378 if edge.CallSiteOffset != offset {
379 continue
380 }
381 stat.Weight += edge.Weight
382 if hotter(edge) {
383 stat.HottestWeight = edge.Weight
384 stat.Hottest = edge.Dst.Name()
385 }
386 }
387
388 switch call.Op() {
389 case ir.OCALLFUNC:
390 stat.Interface = false
391
392 callee := pgoir.DirectCallee(call.Fun)
393 if callee != nil {
394 stat.Direct = true
395 if stat.Hottest == "" {
396 stat.Hottest = ir.LinkFuncName(callee)
397 }
398 } else {
399 stat.Direct = false
400 }
401 case ir.OCALLINTER:
402 stat.Direct = false
403 stat.Interface = true
404 case ir.OCALLMETH:
405 base.FatalfAt(call.Pos(), "OCALLMETH missed by typecheck")
406 }
407
408 return &stat
409 }
410
411
412
413
414
415
416
417 func copyInputs(curfn *ir.Func, pos src.XPos, recvOrFn ir.Node, args []ir.Node, init *ir.Nodes) (ir.Node, []ir.Node) {
418
419
420
421
422
423
424
425 var lhs, rhs []ir.Node
426 newRecvOrFn := typecheck.TempAt(pos, curfn, recvOrFn.Type())
427 lhs = append(lhs, newRecvOrFn)
428 rhs = append(rhs, recvOrFn)
429
430 for _, arg := range args {
431 argvar := typecheck.TempAt(pos, curfn, arg.Type())
432
433 lhs = append(lhs, argvar)
434 rhs = append(rhs, arg)
435 }
436
437 asList := ir.NewAssignListStmt(pos, ir.OAS2, lhs, rhs)
438 init.Append(typecheck.Stmt(asList))
439
440 return newRecvOrFn, lhs[1:]
441 }
442
443
444 func retTemps(curfn *ir.Func, pos src.XPos, call *ir.CallExpr) []ir.Node {
445 sig := call.Fun.Type()
446 var retvars []ir.Node
447 for _, ret := range sig.Results() {
448 retvars = append(retvars, typecheck.TempAt(pos, curfn, ret.Type))
449 }
450 return retvars
451 }
452
453
454
455
456 func condCall(curfn *ir.Func, pos src.XPos, cond ir.Node, thenCall, elseCall *ir.CallExpr, init ir.Nodes) *ir.InlinedCallExpr {
457
458
459 retvars := retTemps(curfn, pos, thenCall)
460
461 var thenBlock, elseBlock ir.Nodes
462 if len(retvars) == 0 {
463 thenBlock.Append(thenCall)
464 elseBlock.Append(elseCall)
465 } else {
466
467 thenRet := append([]ir.Node(nil), retvars...)
468 thenAsList := ir.NewAssignListStmt(pos, ir.OAS2, thenRet, []ir.Node{thenCall})
469 thenBlock.Append(typecheck.Stmt(thenAsList))
470
471 elseRet := append([]ir.Node(nil), retvars...)
472 elseAsList := ir.NewAssignListStmt(pos, ir.OAS2, elseRet, []ir.Node{elseCall})
473 elseBlock.Append(typecheck.Stmt(elseAsList))
474 }
475
476 nif := ir.NewIfStmt(pos, cond, thenBlock, elseBlock)
477 nif.SetInit(init)
478 nif.Likely = true
479
480 body := []ir.Node{typecheck.Stmt(nif)}
481
482
483
484 res := ir.NewInlinedCallExpr(pos, body, retvars)
485 res.SetType(thenCall.Type())
486 res.SetTypecheck(1)
487 return res
488 }
489
490
491
492 func rewriteInterfaceCall(call *ir.CallExpr, curfn, callee *ir.Func, concretetyp *types.Type) ir.Node {
493 if base.Flag.LowerM != 0 {
494 fmt.Printf("%v: PGO devirtualizing interface call %v to %v\n", ir.Line(call), call.Fun, callee)
495 }
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525 sel := call.Fun.(*ir.SelectorExpr)
526 method := sel.Sel
527 pos := call.Pos()
528 init := ir.TakeInit(call)
529
530 recv, args := copyInputs(curfn, pos, sel.X, call.Args.Take(), &init)
531
532
533 argvars := append([]ir.Node(nil), args...)
534 call.Args = argvars
535
536 tmpnode := typecheck.TempAt(base.Pos, curfn, concretetyp)
537 tmpok := typecheck.TempAt(base.Pos, curfn, types.Types[types.TBOOL])
538
539 assert := ir.NewTypeAssertExpr(pos, recv, concretetyp)
540
541 assertAsList := ir.NewAssignListStmt(pos, ir.OAS2, []ir.Node{tmpnode, tmpok}, []ir.Node{typecheck.Expr(assert)})
542 init.Append(typecheck.Stmt(assertAsList))
543
544 concreteCallee := typecheck.XDotMethod(pos, tmpnode, method, true)
545
546 argvars = append([]ir.Node(nil), argvars...)
547 concreteCall := typecheck.Call(pos, concreteCallee, argvars, call.IsDDD).(*ir.CallExpr)
548
549 res := condCall(curfn, pos, tmpok, concreteCall, call, init)
550
551 if base.Debug.PGODebug >= 3 {
552 fmt.Printf("PGO devirtualizing interface call to %+v. After: %+v\n", concretetyp, res)
553 }
554
555 return res
556 }
557
558
559
560 func rewriteFunctionCall(call *ir.CallExpr, curfn, callee *ir.Func) ir.Node {
561 if base.Flag.LowerM != 0 {
562 fmt.Printf("%v: PGO devirtualizing function call %v to %v\n", ir.Line(call), call.Fun, callee)
563 }
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591 pos := call.Pos()
592 init := ir.TakeInit(call)
593
594 fn, args := copyInputs(curfn, pos, call.Fun, call.Args.Take(), &init)
595
596
597 argvars := append([]ir.Node(nil), args...)
598 call.Args = argvars
599
600
601
602 fnIface := typecheck.Expr(ir.NewConvExpr(pos, ir.OCONV, types.Types[types.TINTER], fn))
603 calleeIface := typecheck.Expr(ir.NewConvExpr(pos, ir.OCONV, types.Types[types.TINTER], callee.Nname))
604
605 fnPC := ir.FuncPC(pos, fnIface, obj.ABIInternal)
606 concretePC := ir.FuncPC(pos, calleeIface, obj.ABIInternal)
607
608 pcEq := typecheck.Expr(ir.NewBinaryExpr(base.Pos, ir.OEQ, fnPC, concretePC))
609
610
611
612
613 if callee.OClosure != nil {
614 base.Fatalf("Callee is a closure: %+v", callee)
615 }
616
617
618 argvars = append([]ir.Node(nil), argvars...)
619 concreteCall := typecheck.Call(pos, callee.Nname, argvars, call.IsDDD).(*ir.CallExpr)
620
621 res := condCall(curfn, pos, pcEq, concreteCall, call, init)
622
623 if base.Debug.PGODebug >= 3 {
624 fmt.Printf("PGO devirtualizing function call to %+v. After: %+v\n", ir.FuncName(callee), res)
625 }
626
627 return res
628 }
629
630
631
632 func methodRecvType(fn *ir.Func) *types.Type {
633 recv := fn.Nname.Type().Recv()
634 if recv == nil {
635 return nil
636 }
637 return recv.Type
638 }
639
640
641
642 func interfaceCallRecvTypeAndMethod(call *ir.CallExpr) (*types.Type, *types.Sym) {
643 if call.Op() != ir.OCALLINTER {
644 base.Fatalf("Call isn't OCALLINTER: %+v", call)
645 }
646
647 sel, ok := call.Fun.(*ir.SelectorExpr)
648 if !ok {
649 base.Fatalf("OCALLINTER doesn't contain SelectorExpr: %+v", call)
650 }
651
652 return sel.X.Type(), sel.Sel
653 }
654
655
656
657
658
659 func findHotConcreteCallee(p *pgoir.Profile, caller *ir.Func, call *ir.CallExpr, extraFn func(callerName string, callOffset int, candidate *pgoir.IREdge) bool) (*ir.Func, int64) {
660 callerName := ir.LinkFuncName(caller)
661 callerNode := p.WeightedCG.IRNodes[callerName]
662 callOffset := pgoir.NodeLineOffset(call, caller)
663
664 if callerNode == nil {
665 return nil, 0
666 }
667
668 var hottest *pgoir.IREdge
669
670
671
672
673
674
675
676 hotter := func(e *pgoir.IREdge) bool {
677 if hottest == nil {
678 return true
679 }
680 if e.Weight != hottest.Weight {
681 return e.Weight > hottest.Weight
682 }
683
684
685
686
687
688 if (hottest.Dst.AST == nil) != (e.Dst.AST == nil) {
689 if e.Dst.AST != nil {
690 return true
691 }
692 return false
693 }
694
695
696
697 return e.Dst.Name() < hottest.Dst.Name()
698 }
699
700 for _, e := range callerNode.OutEdges {
701 if e.CallSiteOffset != callOffset {
702 continue
703 }
704
705 if !hotter(e) {
706
707
708
709
710
711 if base.Debug.PGODebug >= 2 {
712 fmt.Printf("%v: edge %s:%d -> %s (weight %d): too cold (hottest %d)\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight, hottest.Weight)
713 }
714 continue
715 }
716
717 if e.Dst.AST == nil {
718
719
720
721
722
723
724
725
726
727 if base.Debug.PGODebug >= 2 {
728 fmt.Printf("%v: edge %s:%d -> %s (weight %d) (missing IR): hottest so far\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight)
729 }
730 hottest = e
731 continue
732 }
733
734 if extraFn != nil && !extraFn(callerName, callOffset, e) {
735 continue
736 }
737
738 if base.Debug.PGODebug >= 2 {
739 fmt.Printf("%v: edge %s:%d -> %s (weight %d): hottest so far\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight)
740 }
741 hottest = e
742 }
743
744 if hottest == nil {
745 if base.Debug.PGODebug >= 2 {
746 fmt.Printf("%v: call %s:%d: no hot callee\n", ir.Line(call), callerName, callOffset)
747 }
748 return nil, 0
749 }
750
751 if base.Debug.PGODebug >= 2 {
752 fmt.Printf("%v: call %s:%d: hottest callee %s (weight %d)\n", ir.Line(call), callerName, callOffset, hottest.Dst.Name(), hottest.Weight)
753 }
754 return hottest.Dst.AST, hottest.Weight
755 }
756
757
758
759 func findHotConcreteInterfaceCallee(p *pgoir.Profile, caller *ir.Func, call *ir.CallExpr) (*ir.Func, int64) {
760 inter, method := interfaceCallRecvTypeAndMethod(call)
761
762 return findHotConcreteCallee(p, caller, call, func(callerName string, callOffset int, e *pgoir.IREdge) bool {
763 ctyp := methodRecvType(e.Dst.AST)
764 if ctyp == nil {
765
766
767 if base.Debug.PGODebug >= 2 {
768 fmt.Printf("%v: edge %s:%d -> %s (weight %d): callee not a method\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight)
769 }
770 return false
771 }
772
773
774
775 if !typecheck.Implements(ctyp, inter) {
776
777
778
779
780
781
782
783
784
785 if base.Debug.PGODebug >= 2 {
786 why := typecheck.ImplementsExplain(ctyp, inter)
787 fmt.Printf("%v: edge %s:%d -> %s (weight %d): %v doesn't implement %v (%s)\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight, ctyp, inter, why)
788 }
789 return false
790 }
791
792
793
794 if !strings.HasSuffix(e.Dst.Name(), "."+method.Name) {
795 if base.Debug.PGODebug >= 2 {
796 fmt.Printf("%v: edge %s:%d -> %s (weight %d): callee is a different method\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight)
797 }
798 return false
799 }
800
801 return true
802 })
803 }
804
805
806
807 func findHotConcreteFunctionCallee(p *pgoir.Profile, caller *ir.Func, call *ir.CallExpr) (*ir.Func, int64) {
808 typ := call.Fun.Type().Underlying()
809
810 return findHotConcreteCallee(p, caller, call, func(callerName string, callOffset int, e *pgoir.IREdge) bool {
811 ctyp := e.Dst.AST.Type().Underlying()
812
813
814
815
816
817
818
819
820 if !types.Identical(typ, ctyp) {
821 if base.Debug.PGODebug >= 2 {
822 fmt.Printf("%v: edge %s:%d -> %s (weight %d): %v doesn't match %v\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight, ctyp, typ)
823 }
824 return false
825 }
826
827 return true
828 })
829 }
830
View as plain text