1
2
3
4
5
6
7
8
9
10 package main
11
12 import (
13 "bytes"
14 "fmt"
15 "go/ast"
16 "go/format"
17 "go/parser"
18 "go/token"
19 "io/fs"
20 "log"
21 "os"
22 "slices"
23 "strings"
24 )
25
26 var fset = token.NewFileSet()
27
28 var buf bytes.Buffer
29
30
31
32 var concreteNodes []*ast.TypeSpec
33
34
35 var interfaceNodes []*ast.TypeSpec
36
37
38 var mini = map[string]*ast.TypeSpec{}
39
40
41
42 func implementsNode(t ast.Expr) bool {
43 id, ok := t.(*ast.Ident)
44 if !ok {
45 return false
46 }
47 for _, ts := range interfaceNodes {
48 if ts.Name.Name == id.Name {
49 return true
50 }
51 }
52 for _, ts := range concreteNodes {
53 if ts.Name.Name == id.Name {
54 return true
55 }
56 }
57 return false
58 }
59
60 func isMini(t ast.Expr) bool {
61 id, ok := t.(*ast.Ident)
62 return ok && mini[id.Name] != nil
63 }
64
// isNamedType reports whether t is a plain identifier whose name is exactly
// name. Pointers, slices, selectors and other expressions yield false.
func isNamedType(t ast.Expr, name string) bool {
	ident, isIdent := t.(*ast.Ident)
	return isIdent && ident.Name == name
}
73
74 func main() {
75 fmt.Fprintln(&buf, "// Code generated by mknode.go. DO NOT EDIT.")
76 fmt.Fprintln(&buf)
77 fmt.Fprintln(&buf, "package ir")
78 fmt.Fprintln(&buf)
79 fmt.Fprintln(&buf, `import "fmt"`)
80
81 filter := func(file fs.FileInfo) bool {
82 return !strings.HasPrefix(file.Name(), "mknode")
83 }
84 pkgs, err := parser.ParseDir(fset, ".", filter, 0)
85 if err != nil {
86 panic(err)
87 }
88 pkg := pkgs["ir"]
89
90
91
92 for _, f := range pkg.Files {
93 for _, d := range f.Decls {
94 g, ok := d.(*ast.GenDecl)
95 if !ok {
96 continue
97 }
98 for _, s := range g.Specs {
99 t, ok := s.(*ast.TypeSpec)
100 if !ok {
101 continue
102 }
103 if strings.HasPrefix(t.Name.Name, "mini") {
104 mini[t.Name.Name] = t
105
106 if t.Name.Name != "miniNode" {
107 s := t.Type.(*ast.StructType)
108 if !isNamedType(s.Fields.List[0].Type, "miniNode") {
109 panic(fmt.Sprintf("can't find miniNode in %s", t.Name.Name))
110 }
111 }
112 }
113 }
114 }
115 }
116
117
118 for _, f := range pkg.Files {
119 for _, d := range f.Decls {
120 g, ok := d.(*ast.GenDecl)
121 if !ok {
122 continue
123 }
124 for _, s := range g.Specs {
125 t, ok := s.(*ast.TypeSpec)
126 if !ok {
127 continue
128 }
129 if strings.HasPrefix(t.Name.Name, "mini") {
130
131
132
133
134 continue
135 }
136 if isConcreteNode(t) {
137 concreteNodes = append(concreteNodes, t)
138 }
139 if isInterfaceNode(t) {
140 interfaceNodes = append(interfaceNodes, t)
141 }
142 }
143 }
144 }
145
146 slices.SortFunc(concreteNodes, func(a, b *ast.TypeSpec) int {
147 return strings.Compare(a.Name.Name, b.Name.Name)
148 })
149
150 for _, t := range concreteNodes {
151 processType(t)
152 }
153
154 generateHelpers()
155
156
157 out, err := format.Source(buf.Bytes())
158 if err != nil {
159
160 out = buf.Bytes()
161 }
162 err = os.WriteFile("node_gen.go", out, 0666)
163 if err != nil {
164 log.Fatal(err)
165 }
166 }
167
168
169
170 func isConcreteNode(t *ast.TypeSpec) bool {
171 s, ok := t.Type.(*ast.StructType)
172 if !ok {
173 return false
174 }
175 for _, f := range s.Fields.List {
176 if isMini(f.Type) {
177 return true
178 }
179 }
180 return false
181 }
182
183
184
185 func isInterfaceNode(t *ast.TypeSpec) bool {
186 s, ok := t.Type.(*ast.InterfaceType)
187 if !ok {
188 return false
189 }
190 if t.Name.Name == "Node" {
191 return true
192 }
193 if t.Name.Name == "OrigNode" || t.Name.Name == "InitNode" {
194
195
196 return false
197 }
198
199
200
201
202 for _, f := range s.Methods.List {
203 if len(f.Names) != 0 {
204 continue
205 }
206 if isNamedType(f.Type, "Node") {
207 return true
208 }
209 }
210 return false
211 }
212
213 func processType(t *ast.TypeSpec) {
214 name := t.Name.Name
215 fmt.Fprintf(&buf, "\n")
216 fmt.Fprintf(&buf, "func (n *%s) Format(s fmt.State, verb rune) { fmtNode(n, s, verb) }\n", name)
217
218 switch name {
219 case "Name", "Func":
220
221 return
222 }
223
224 s := t.Type.(*ast.StructType)
225 fields := s.Fields.List
226
227
228 for i := 0; i < len(fields); i++ {
229 f := fields[i]
230 if len(f.Names) != 0 {
231 continue
232 }
233 if isMini(f.Type) {
234
235
236
237 ss := mini[f.Type.(*ast.Ident).Name].Type.(*ast.StructType)
238 var f2 []*ast.Field
239 f2 = append(f2, fields[:i]...)
240 f2 = append(f2, ss.Fields.List...)
241 f2 = append(f2, fields[i+1:]...)
242 fields = f2
243 i--
244 continue
245 } else if isNamedType(f.Type, "origNode") {
246
247 copy(fields[i:], fields[i+1:])
248 fields = fields[:len(fields)-1]
249 i--
250 continue
251 } else {
252 panic("unknown embedded field " + fmt.Sprintf("%v", f.Type))
253 }
254 }
255
256 var copyBody strings.Builder
257 var doChildrenBody strings.Builder
258 var doChildrenWithHiddenBody strings.Builder
259 var editChildrenBody strings.Builder
260 var editChildrenWithHiddenBody strings.Builder
261 var hasHidden bool
262 for _, f := range fields {
263 names := f.Names
264 ft := f.Type
265 hidden := false
266 if f.Tag != nil {
267 tag := f.Tag.Value[1 : len(f.Tag.Value)-1]
268 if strings.HasPrefix(tag, "mknode:") {
269 if tag[7:] == "\"-\"" {
270 if !isNamedType(ft, "Node") {
271 continue
272 }
273 hidden = true
274 } else {
275 panic(fmt.Sprintf("unexpected tag value: %s", tag))
276 }
277 }
278 }
279 if isNamedType(ft, "Nodes") {
280
281 ft = &ast.ArrayType{Elt: &ast.Ident{Name: "Node"}}
282 }
283 isSlice := false
284 if a, ok := ft.(*ast.ArrayType); ok && a.Len == nil {
285 isSlice = true
286 ft = a.Elt
287 }
288 isPtr := false
289 if p, ok := ft.(*ast.StarExpr); ok {
290 isPtr = true
291 ft = p.X
292 }
293 if !implementsNode(ft) {
294 continue
295 }
296 for _, name := range names {
297 ptr := ""
298 if isPtr {
299 ptr = "*"
300 }
301 if isSlice {
302 fmt.Fprintf(&doChildrenWithHiddenBody,
303 "if do%ss(n.%s, do) {\nreturn true\n}\n", ft, name)
304 fmt.Fprintf(&editChildrenWithHiddenBody,
305 "edit%ss(n.%s, edit)\n", ft, name)
306 } else {
307 fmt.Fprintf(&doChildrenWithHiddenBody,
308 "if n.%s != nil && do(n.%s) {\nreturn true\n}\n", name, name)
309 fmt.Fprintf(&editChildrenWithHiddenBody,
310 "if n.%s != nil {\nn.%s = edit(n.%s).(%s%s)\n}\n", name, name, name, ptr, ft)
311 }
312 if hidden {
313 hasHidden = true
314 continue
315 }
316 if isSlice {
317 fmt.Fprintf(©Body, "c.%s = copy%ss(c.%s)\n", name, ft, name)
318 fmt.Fprintf(&doChildrenBody,
319 "if do%ss(n.%s, do) {\nreturn true\n}\n", ft, name)
320 fmt.Fprintf(&editChildrenBody,
321 "edit%ss(n.%s, edit)\n", ft, name)
322 } else {
323 fmt.Fprintf(&doChildrenBody,
324 "if n.%s != nil && do(n.%s) {\nreturn true\n}\n", name, name)
325 fmt.Fprintf(&editChildrenBody,
326 "if n.%s != nil {\nn.%s = edit(n.%s).(%s%s)\n}\n", name, name, name, ptr, ft)
327 }
328 }
329 }
330 fmt.Fprintf(&buf, "func (n *%s) copy() Node {\nc := *n\n", name)
331 buf.WriteString(copyBody.String())
332 buf.WriteString("return &c\n}\n")
333 fmt.Fprintf(&buf, "func (n *%s) doChildren(do func(Node) bool) bool {\n", name)
334 buf.WriteString(doChildrenBody.String())
335 buf.WriteString("return false\n}\n")
336 fmt.Fprintf(&buf, "func (n *%s) doChildrenWithHidden(do func(Node) bool) bool {\n", name)
337 if hasHidden {
338 buf.WriteString(doChildrenWithHiddenBody.String())
339 buf.WriteString("return false\n}\n")
340 } else {
341 buf.WriteString("return n.doChildren(do)\n}\n")
342 }
343 fmt.Fprintf(&buf, "func (n *%s) editChildren(edit func(Node) Node) {\n", name)
344 buf.WriteString(editChildrenBody.String())
345 buf.WriteString("}\n")
346 fmt.Fprintf(&buf, "func (n *%s) editChildrenWithHidden(edit func(Node) Node) {\n", name)
347 if hasHidden {
348 buf.WriteString(editChildrenWithHiddenBody.String())
349 } else {
350 buf.WriteString("n.editChildren(edit)\n")
351 }
352 buf.WriteString("}\n")
353 }
354
355 func generateHelpers() {
356 for _, typ := range []string{"CaseClause", "CommClause", "Name", "Node"} {
357 ptr := "*"
358 if typ == "Node" {
359 ptr = ""
360 }
361 fmt.Fprintf(&buf, "\n")
362 fmt.Fprintf(&buf, "func copy%ss(list []%s%s) []%s%s {\n", typ, ptr, typ, ptr, typ)
363 fmt.Fprintf(&buf, "if list == nil { return nil }\n")
364 fmt.Fprintf(&buf, "c := make([]%s%s, len(list))\n", ptr, typ)
365 fmt.Fprintf(&buf, "copy(c, list)\n")
366 fmt.Fprintf(&buf, "return c\n")
367 fmt.Fprintf(&buf, "}\n")
368 fmt.Fprintf(&buf, "func do%ss(list []%s%s, do func(Node) bool) bool {\n", typ, ptr, typ)
369 fmt.Fprintf(&buf, "for _, x := range list {\n")
370 fmt.Fprintf(&buf, "if x != nil && do(x) {\n")
371 fmt.Fprintf(&buf, "return true\n")
372 fmt.Fprintf(&buf, "}\n")
373 fmt.Fprintf(&buf, "}\n")
374 fmt.Fprintf(&buf, "return false\n")
375 fmt.Fprintf(&buf, "}\n")
376 fmt.Fprintf(&buf, "func edit%ss(list []%s%s, edit func(Node) Node) {\n", typ, ptr, typ)
377 fmt.Fprintf(&buf, "for i, x := range list {\n")
378 fmt.Fprintf(&buf, "if x != nil {\n")
379 fmt.Fprintf(&buf, "list[i] = edit(x).(%s%s)\n", ptr, typ)
380 fmt.Fprintf(&buf, "}\n")
381 fmt.Fprintf(&buf, "}\n")
382 fmt.Fprintf(&buf, "}\n")
383 }
384 }
385
View as plain text