Source file
src/cmd/fix/typecheck.go
1
2
3
4
5 package main
6
7 import (
8 "fmt"
9 "go/ast"
10 "go/parser"
11 "go/token"
12 "maps"
13 "os"
14 "os/exec"
15 "path/filepath"
16 "reflect"
17 "runtime"
18 "strings"
19 )
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
// mkType returns the typeof encoding of the type t: t prefixed with
// "type ", which marks the string as naming a type rather than the type
// of a value.
func mkType(t string) string {
	const typePrefix = "type "
	return typePrefix + t
}
61
// getType undoes mkType: it returns the type named by the encoding t,
// or "" if t does not start with "type ".
func getType(t string) string {
	if name, ok := strings.CutPrefix(t, "type "); ok {
		return name
	}
	return ""
}
68
// isType reports whether t is a type encoding as produced by mkType,
// i.e. whether it begins with "type ".
func isType(t string) bool {
	const prefix = "type "
	return len(t) >= len(prefix) && t[:len(prefix)] == prefix
}
72
73
74
75
76
77
78 type TypeConfig struct {
79 Type map[string]*Type
80 Var map[string]string
81 Func map[string]string
82
83
84
85
86 External map[string]string
87 }
88
89
90
91 func (cfg *TypeConfig) typeof(name string) string {
92 if cfg.Var != nil {
93 if t := cfg.Var[name]; t != "" {
94 return t
95 }
96 }
97 if cfg.Func != nil {
98 if t := cfg.Func[name]; t != "" {
99 return "func()" + t
100 }
101 }
102 return ""
103 }
104
105
106
107
// Type describes a named type for the purposes of typecheck.
// All types are written as type strings in gofmt form.
type Type struct {
	// Field and Method map a field or method name to its type.
	Field, Method map[string]string

	// Embed lists the names (keys of TypeConfig.Type) of embedded types.
	// (*Type).dot searches them in order when a name is not found directly.
	Embed []string

	// Def, if set, is the underlying definition (for example an array,
	// pointer or map type) that the type name expands to when the walk
	// indexes, ranges over, or dereferences a value of this type.
	Def string
}
114
115
116
117 func (typ *Type) dot(cfg *TypeConfig, name string) string {
118 if typ.Field != nil {
119 if t := typ.Field[name]; t != "" {
120 return t
121 }
122 }
123 if typ.Method != nil {
124 if t := typ.Method[name]; t != "" {
125 return t
126 }
127 }
128
129 for _, e := range typ.Embed {
130 etyp := cfg.Type[e]
131 if etyp != nil {
132 if t := etyp.dot(cfg, name); t != "" {
133 return t
134 }
135 }
136 }
137
138 return ""
139 }
140
141
142
143
144
145
146 func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[any]string, assign map[string][]any) {
147 typeof = make(map[any]string)
148 assign = make(map[string][]any)
149 cfg1 := &TypeConfig{}
150 *cfg1 = *cfg
151 copied := false
152
153
154 cfg.External = map[string]string{}
155 cfg1.External = cfg.External
156 if imports(f, "C") {
157
158
159
160 err := func() error {
161 txt, err := gofmtFile(f)
162 if err != nil {
163 return err
164 }
165 dir, err := os.MkdirTemp(os.TempDir(), "fix_cgo_typecheck")
166 if err != nil {
167 return err
168 }
169 defer os.RemoveAll(dir)
170 err = os.WriteFile(filepath.Join(dir, "in.go"), txt, 0600)
171 if err != nil {
172 return err
173 }
174 goCmd := "go"
175 if goroot := runtime.GOROOT(); goroot != "" {
176 goCmd = filepath.Join(goroot, "bin", "go")
177 }
178 cmd := exec.Command(goCmd, "tool", "cgo", "-objdir", dir, "-srcdir", dir, "in.go")
179 if reportCgoError != nil {
180
181
182 cmd.Stderr = os.Stderr
183 }
184 err = cmd.Run()
185 if err != nil {
186 return err
187 }
188 out, err := os.ReadFile(filepath.Join(dir, "_cgo_gotypes.go"))
189 if err != nil {
190 return err
191 }
192 cgo, err := parser.ParseFile(token.NewFileSet(), "cgo.go", out, 0)
193 if err != nil {
194 return err
195 }
196 for _, decl := range cgo.Decls {
197 fn, ok := decl.(*ast.FuncDecl)
198 if !ok {
199 continue
200 }
201 if strings.HasPrefix(fn.Name.Name, "_Cfunc_") {
202 var params, results []string
203 for _, p := range fn.Type.Params.List {
204 t := gofmt(p.Type)
205 t = strings.ReplaceAll(t, "_Ctype_", "C.")
206 params = append(params, t)
207 }
208 for _, r := range fn.Type.Results.List {
209 t := gofmt(r.Type)
210 t = strings.ReplaceAll(t, "_Ctype_", "C.")
211 results = append(results, t)
212 }
213 cfg.External["C."+fn.Name.Name[7:]] = joinFunc(params, results)
214 }
215 }
216 return nil
217 }()
218 if err != nil {
219 if reportCgoError == nil {
220 fmt.Fprintf(os.Stderr, "go fix: warning: no cgo types: %s\n", err)
221 } else {
222 reportCgoError(err)
223 }
224 }
225 }
226
227
228 for _, decl := range f.Decls {
229 fn, ok := decl.(*ast.FuncDecl)
230 if !ok {
231 continue
232 }
233 typecheck1(cfg, fn.Type, typeof, assign)
234 t := typeof[fn.Type]
235 if fn.Recv != nil {
236
237 rcvr := typeof[fn.Recv]
238 if !isType(rcvr) {
239 if len(fn.Recv.List) != 1 {
240 continue
241 }
242 rcvr = mkType(gofmt(fn.Recv.List[0].Type))
243 typeof[fn.Recv.List[0].Type] = rcvr
244 }
245 rcvr = getType(rcvr)
246 if rcvr != "" && rcvr[0] == '*' {
247 rcvr = rcvr[1:]
248 }
249 typeof[rcvr+"."+fn.Name.Name] = t
250 } else {
251 if isType(t) {
252 t = getType(t)
253 } else {
254 t = gofmt(fn.Type)
255 }
256 typeof[fn.Name] = t
257
258
259 typeof[fn.Name.Obj] = t
260 }
261 }
262
263
264 for _, decl := range f.Decls {
265 d, ok := decl.(*ast.GenDecl)
266 if ok {
267 for _, s := range d.Specs {
268 switch s := s.(type) {
269 case *ast.TypeSpec:
270 if cfg1.Type[s.Name.Name] != nil {
271 break
272 }
273 if !copied {
274 copied = true
275
276 cfg1.Type = maps.Clone(cfg.Type)
277 if cfg1.Type == nil {
278 cfg1.Type = make(map[string]*Type)
279 }
280 }
281 t := &Type{Field: map[string]string{}}
282 cfg1.Type[s.Name.Name] = t
283 switch st := s.Type.(type) {
284 case *ast.StructType:
285 for _, f := range st.Fields.List {
286 for _, n := range f.Names {
287 t.Field[n.Name] = gofmt(f.Type)
288 }
289 }
290 case *ast.ArrayType, *ast.StarExpr, *ast.MapType:
291 t.Def = gofmt(st)
292 }
293 }
294 }
295 }
296 }
297
298 typecheck1(cfg1, f, typeof, assign)
299 return typeof, assign
300 }
301
302
303
// reportCgoError, if non-nil, is called by typecheck when running cgo to
// obtain C function types fails, instead of printing a warning to
// os.Stderr. While it is set, cgo's own error output is also forwarded to
// os.Stderr.
var reportCgoError func(error)
305
// makeExprList returns the identifiers in a as a list of expressions.
// An empty input yields nil.
func makeExprList(a []*ast.Ident) []ast.Expr {
	if len(a) == 0 {
		return nil
	}
	exprs := make([]ast.Expr, len(a))
	for i, id := range a {
		exprs[i] = id
	}
	return exprs
}
313
314
315
316
317 func typecheck1(cfg *TypeConfig, f any, typeof map[any]string, assign map[string][]any) {
318
319
320 set := func(n ast.Expr, typ string, isDecl bool) {
321 if typeof[n] != "" || typ == "" {
322 if typeof[n] != typ {
323 assign[typ] = append(assign[typ], n)
324 }
325 return
326 }
327 typeof[n] = typ
328
329
330
331
332
333
334
335 if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") {
336 typeof[id.Obj] = typ
337 }
338 }
339
340
341
342
343 typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) {
344 if len(lhs) > 1 && len(rhs) == 1 {
345 if _, ok := rhs[0].(*ast.CallExpr); ok {
346 t := split(typeof[rhs[0]])
347
348 for i := 0; i < len(lhs) && i < len(t); i++ {
349 set(lhs[i], t[i], isDecl)
350 }
351 return
352 }
353 }
354 if len(lhs) == 1 && len(rhs) == 2 {
355
356 rhs = rhs[:1]
357 } else if len(lhs) == 2 && len(rhs) == 1 {
358
359 lhs = lhs[:1]
360 }
361
362
363 for i := 0; i < len(lhs) && i < len(rhs); i++ {
364 x, y := lhs[i], rhs[i]
365 if typeof[y] != "" {
366 set(x, typeof[y], isDecl)
367 } else {
368 set(y, typeof[x], false)
369 }
370 }
371 }
372
373 expand := func(s string) string {
374 typ := cfg.Type[s]
375 if typ != nil && typ.Def != "" {
376 return typ.Def
377 }
378 return s
379 }
380
381
382
383
384
385
386
387 var curfn []*ast.FuncType
388
389 before := func(n any) {
390
391 switch n := n.(type) {
392 case *ast.FuncDecl:
393 curfn = append(curfn, n.Type)
394 case *ast.FuncLit:
395 curfn = append(curfn, n.Type)
396 }
397 }
398
399
400 after := func(n any) {
401 if n == nil {
402 return
403 }
404 if false && reflect.TypeOf(n).Kind() == reflect.Pointer {
405 defer func() {
406 if t := typeof[n]; t != "" {
407 pos := fset.Position(n.(ast.Node).Pos())
408 fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t)
409 }
410 }()
411 }
412
413 switch n := n.(type) {
414 case *ast.FuncDecl, *ast.FuncLit:
415
416 curfn = curfn[:len(curfn)-1]
417
418 case *ast.FuncType:
419 typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results])))
420
421 case *ast.FieldList:
422
423 t := ""
424 for _, field := range n.List {
425 if t != "" {
426 t += ", "
427 }
428 t += typeof[field]
429 }
430 typeof[n] = t
431
432 case *ast.Field:
433
434 all := ""
435 t := typeof[n.Type]
436 if !isType(t) {
437
438
439 t = mkType(gofmt(n.Type))
440 typeof[n.Type] = t
441 }
442 t = getType(t)
443 if len(n.Names) == 0 {
444 all = t
445 } else {
446 for _, id := range n.Names {
447 if all != "" {
448 all += ", "
449 }
450 all += t
451 typeof[id.Obj] = t
452 typeof[id] = t
453 }
454 }
455 typeof[n] = all
456
457 case *ast.ValueSpec:
458
459 if n.Type != nil {
460 t := typeof[n.Type]
461 if !isType(t) {
462 t = mkType(gofmt(n.Type))
463 typeof[n.Type] = t
464 }
465 t = getType(t)
466 for _, id := range n.Names {
467 set(id, t, true)
468 }
469 }
470
471 typecheckAssign(makeExprList(n.Names), n.Values, true)
472
473 case *ast.AssignStmt:
474 typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE)
475
476 case *ast.Ident:
477
478 if t := typeof[n.Obj]; t != "" {
479 typeof[n] = t
480 }
481
482 case *ast.SelectorExpr:
483
484 name := n.Sel.Name
485 if t := typeof[n.X]; t != "" {
486 t = strings.TrimPrefix(t, "*")
487 if typ := cfg.Type[t]; typ != nil {
488 if t := typ.dot(cfg, name); t != "" {
489 typeof[n] = t
490 return
491 }
492 }
493 tt := typeof[t+"."+name]
494 if isType(tt) {
495 typeof[n] = getType(tt)
496 return
497 }
498 }
499
500 if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil {
501 str := x.Name + "." + name
502 if cfg.Type[str] != nil {
503 typeof[n] = mkType(str)
504 return
505 }
506 if t := cfg.typeof(x.Name + "." + name); t != "" {
507 typeof[n] = t
508 return
509 }
510 }
511
512 case *ast.CallExpr:
513
514 if isTopName(n.Fun, "make") && len(n.Args) >= 1 {
515 typeof[n] = gofmt(n.Args[0])
516 return
517 }
518
519 if isTopName(n.Fun, "new") && len(n.Args) == 1 {
520 typeof[n] = "*" + gofmt(n.Args[0])
521 return
522 }
523
524 t := typeof[n.Fun]
525 if t == "" {
526 t = cfg.External[gofmt(n.Fun)]
527 }
528 in, out := splitFunc(t)
529 if in == nil && out == nil {
530 return
531 }
532 typeof[n] = join(out)
533 for i, arg := range n.Args {
534 if i >= len(in) {
535 break
536 }
537 if typeof[arg] == "" {
538 typeof[arg] = in[i]
539 }
540 }
541
542 case *ast.TypeAssertExpr:
543
544 if n.Type == nil {
545 typeof[n] = typeof[n.X]
546 return
547 }
548
549 if t := typeof[n.Type]; isType(t) {
550 typeof[n] = getType(t)
551 } else {
552 typeof[n] = gofmt(n.Type)
553 }
554
555 case *ast.SliceExpr:
556
557 typeof[n] = typeof[n.X]
558
559 case *ast.IndexExpr:
560
561 t := expand(typeof[n.X])
562 if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") {
563
564
565 if _, elem, ok := strings.Cut(t, "]"); ok {
566 typeof[n] = elem
567 }
568 }
569
570 case *ast.StarExpr:
571
572
573
574 t := expand(typeof[n.X])
575 if isType(t) {
576 typeof[n] = "type *" + getType(t)
577 } else if strings.HasPrefix(t, "*") {
578 typeof[n] = t[len("*"):]
579 }
580
581 case *ast.UnaryExpr:
582
583 t := typeof[n.X]
584 if t != "" && n.Op == token.AND {
585 typeof[n] = "*" + t
586 }
587
588 case *ast.CompositeLit:
589
590 typeof[n] = gofmt(n.Type)
591
592
593 t := expand(typeof[n])
594 if strings.HasPrefix(t, "[") {
595
596 if _, et, ok := strings.Cut(t, "]"); ok {
597 for _, e := range n.Elts {
598 if kv, ok := e.(*ast.KeyValueExpr); ok {
599 e = kv.Value
600 }
601 if typeof[e] == "" {
602 typeof[e] = et
603 }
604 }
605 }
606 }
607 if strings.HasPrefix(t, "map[") {
608
609 if kt, vt, ok := strings.Cut(t[len("map["):], "]"); ok {
610 for _, e := range n.Elts {
611 if kv, ok := e.(*ast.KeyValueExpr); ok {
612 if typeof[kv.Key] == "" {
613 typeof[kv.Key] = kt
614 }
615 if typeof[kv.Value] == "" {
616 typeof[kv.Value] = vt
617 }
618 }
619 }
620 }
621 }
622 if typ := cfg.Type[t]; typ != nil && len(typ.Field) > 0 {
623 for _, e := range n.Elts {
624 if kv, ok := e.(*ast.KeyValueExpr); ok {
625 if ft := typ.Field[fmt.Sprintf("%s", kv.Key)]; ft != "" {
626 if typeof[kv.Value] == "" {
627 typeof[kv.Value] = ft
628 }
629 }
630 }
631 }
632 }
633
634 case *ast.ParenExpr:
635
636 typeof[n] = typeof[n.X]
637
638 case *ast.RangeStmt:
639 t := expand(typeof[n.X])
640 if t == "" {
641 return
642 }
643 var key, value string
644 if t == "string" {
645 key, value = "int", "rune"
646 } else if strings.HasPrefix(t, "[") {
647 key = "int"
648 _, value, _ = strings.Cut(t, "]")
649 } else if strings.HasPrefix(t, "map[") {
650 if k, v, ok := strings.Cut(t[len("map["):], "]"); ok {
651 key, value = k, v
652 }
653 }
654 changed := false
655 if n.Key != nil && key != "" {
656 changed = true
657 set(n.Key, key, n.Tok == token.DEFINE)
658 }
659 if n.Value != nil && value != "" {
660 changed = true
661 set(n.Value, value, n.Tok == token.DEFINE)
662 }
663
664
665 if changed {
666 typecheck1(cfg, n.Body, typeof, assign)
667 }
668
669 case *ast.TypeSwitchStmt:
670
671
672
673
674 as, ok := n.Assign.(*ast.AssignStmt)
675 if !ok {
676 return
677 }
678 varx, ok := as.Lhs[0].(*ast.Ident)
679 if !ok {
680 return
681 }
682 t := typeof[varx]
683 for _, cas := range n.Body.List {
684 cas := cas.(*ast.CaseClause)
685 if len(cas.List) == 1 {
686
687
688 if tt := typeof[cas.List[0]]; isType(tt) {
689 tt = getType(tt)
690 typeof[varx] = tt
691 typeof[varx.Obj] = tt
692 typecheck1(cfg, cas.Body, typeof, assign)
693 }
694 }
695 }
696
697 typeof[varx] = t
698 typeof[varx.Obj] = t
699
700 case *ast.ReturnStmt:
701 if len(curfn) == 0 {
702
703 return
704 }
705 f := curfn[len(curfn)-1]
706 res := n.Results
707 if f.Results != nil {
708 t := split(typeof[f.Results])
709 for i := 0; i < len(res) && i < len(t); i++ {
710 set(res[i], t[i], false)
711 }
712 }
713
714 case *ast.BinaryExpr:
715
716 switch n.Op {
717 case token.EQL, token.NEQ:
718 if typeof[n.X] != "" && typeof[n.Y] == "" {
719 typeof[n.Y] = typeof[n.X]
720 }
721 if typeof[n.X] == "" && typeof[n.Y] != "" {
722 typeof[n.X] = typeof[n.Y]
723 }
724 }
725 }
726 }
727 walkBeforeAfter(f, before, after)
728 }
729
730
731
732
733
734
735
736 func splitFunc(s string) (in, out []string) {
737 if !strings.HasPrefix(s, "func(") {
738 return nil, nil
739 }
740
741 i := len("func(")
742 nparen := 0
743 for j := i; j < len(s); j++ {
744 switch s[j] {
745 case '(':
746 nparen++
747 case ')':
748 nparen--
749 if nparen < 0 {
750
751 out := strings.TrimSpace(s[j+1:])
752 if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' {
753 out = out[1 : len(out)-1]
754 }
755 return split(s[i:j]), split(out)
756 }
757 }
758 }
759 return nil, nil
760 }
761
762
763 func joinFunc(in, out []string) string {
764 outs := ""
765 if len(out) == 1 {
766 outs = " " + out[0]
767 } else if len(out) > 1 {
768 outs = " (" + join(out) + ")"
769 }
770 return "func(" + join(in) + ")" + outs
771 }
772
773
// split splits a comma-separated list of types, such as the parameter or
// result list of a func type string, into its elements. Commas nested
// inside (), [] or {} (parameter lists, generic instantiations, struct and
// interface literals) do not separate elements. Spaces at the start of an
// element are dropped, and empty elements are skipped.
//
// It returns a non-nil empty slice when s has no elements, and nil if the
// brackets in s are unbalanced. Callers such as splitFunc's users rely on
// the nil/empty distinction.
func split(s string) []string {
	out := []string{}
	i := 0     // start of the current element
	depth := 0 // combined nesting depth of (), [] and {}
	for j := 0; j < len(s); j++ {
		switch s[j] {
		case ' ':
			if i == j {
				i++
			}
		case '(', '[', '{':
			depth++
		case ')', ']', '}':
			depth--
			if depth < 0 {
				// Unbalanced closing bracket.
				return nil
			}
		case ',':
			if depth == 0 {
				if i < j {
					out = append(out, s[i:j])
				}
				i = j + 1
			}
		}
	}
	if depth != 0 {
		// Unclosed bracket.
		return nil
	}
	if i < len(s) {
		out = append(out, s[i:])
	}
	return out
}
810
811
// join is the inverse of split: it joins a list of types with ", ".
func join(x []string) string {
	var b strings.Builder
	for i, s := range x {
		if i > 0 {
			b.WriteString(", ")
		}
		b.WriteString(s)
	}
	return b.String()
}
815
View as plain text