1
2
3
4
5
6
7
8
9
10 package main
11
12 import (
13 "bytes"
14 "fmt"
15 "go/ast"
16 "go/format"
17 "go/parser"
18 "go/token"
19 "io/fs"
20 "log"
21 "os"
22 "slices"
23 "strings"
24 )
25
// Generator state shared by main, processType and generateHelpers.
var (
	// fset records source positions for the files parsed by main.
	fset = token.NewFileSet()

	// buf accumulates the generated Go source; main formats it and writes
	// it to node_gen.go.
	buf bytes.Buffer

	// concreteNodes holds the struct types that have a field of a mini*
	// type; main sorts them by name before generating methods.
	concreteNodes []*ast.TypeSpec

	// interfaceNodes holds Node and the interface types that embed Node
	// (OrigNode and InitNode are excluded by isInterfaceNode).
	interfaceNodes []*ast.TypeSpec

	// mini maps the name of every type whose name starts with "mini" to
	// its declaration.
	mini = make(map[string]*ast.TypeSpec)
)
39
40
41
42 func implementsNode(t ast.Expr) bool {
43 id, ok := t.(*ast.Ident)
44 if !ok {
45 return false
46 }
47 for _, ts := range interfaceNodes {
48 if ts.Name.Name == id.Name {
49 return true
50 }
51 }
52 for _, ts := range concreteNodes {
53 if ts.Name.Name == id.Name {
54 return true
55 }
56 }
57 return false
58 }
59
60 func isMini(t ast.Expr) bool {
61 id, ok := t.(*ast.Ident)
62 return ok && mini[id.Name] != nil
63 }
64
// isNamedType reports whether t is exactly the bare identifier name.
// Pointers, slices, qualified names and other composite type expressions
// never match, even if they mention name.
func isNamedType(t ast.Expr, name string) bool {
	id, ok := t.(*ast.Ident)
	return ok && id.Name == name
}
73
74 func main() {
75 fmt.Fprintln(&buf, "// Code generated by mknode.go. DO NOT EDIT.")
76 fmt.Fprintln(&buf)
77 fmt.Fprintln(&buf, "package ir")
78 fmt.Fprintln(&buf)
79 fmt.Fprintln(&buf, `import "fmt"`)
80
81 filter := func(file fs.FileInfo) bool {
82 return !strings.HasPrefix(file.Name(), "mknode")
83 }
84 pkgs, err := parser.ParseDir(fset, ".", filter, 0)
85 if err != nil {
86 panic(err)
87 }
88 pkg := pkgs["ir"]
89
90
91
92 for _, f := range pkg.Files {
93 for _, d := range f.Decls {
94 g, ok := d.(*ast.GenDecl)
95 if !ok {
96 continue
97 }
98 for _, s := range g.Specs {
99 t, ok := s.(*ast.TypeSpec)
100 if !ok {
101 continue
102 }
103 if strings.HasPrefix(t.Name.Name, "mini") {
104 mini[t.Name.Name] = t
105
106 if t.Name.Name != "miniNode" {
107 s := t.Type.(*ast.StructType)
108 if !isNamedType(s.Fields.List[0].Type, "miniNode") {
109 panic(fmt.Sprintf("can't find miniNode in %s", t.Name.Name))
110 }
111 }
112 }
113 }
114 }
115 }
116
117
118 for _, f := range pkg.Files {
119 for _, d := range f.Decls {
120 g, ok := d.(*ast.GenDecl)
121 if !ok {
122 continue
123 }
124 for _, s := range g.Specs {
125 t, ok := s.(*ast.TypeSpec)
126 if !ok {
127 continue
128 }
129 if strings.HasPrefix(t.Name.Name, "mini") {
130
131
132
133
134 continue
135 }
136 if isConcreteNode(t) {
137 concreteNodes = append(concreteNodes, t)
138 }
139 if isInterfaceNode(t) {
140 interfaceNodes = append(interfaceNodes, t)
141 }
142 }
143 }
144 }
145
146 slices.SortFunc(concreteNodes, func(a, b *ast.TypeSpec) int {
147 return strings.Compare(a.Name.Name, b.Name.Name)
148 })
149
150 for _, t := range concreteNodes {
151 processType(t)
152 }
153
154 generateHelpers()
155
156
157 out, err := format.Source(buf.Bytes())
158 if err != nil {
159
160 out = buf.Bytes()
161 }
162 err = os.WriteFile("node_gen.go", out, 0666)
163 if err != nil {
164 log.Fatal(err)
165 }
166 }
167
168
169
170 func isConcreteNode(t *ast.TypeSpec) bool {
171 s, ok := t.Type.(*ast.StructType)
172 if !ok {
173 return false
174 }
175 for _, f := range s.Fields.List {
176 if isMini(f.Type) {
177 return true
178 }
179 }
180 return false
181 }
182
183
184
185 func isInterfaceNode(t *ast.TypeSpec) bool {
186 s, ok := t.Type.(*ast.InterfaceType)
187 if !ok {
188 return false
189 }
190 if t.Name.Name == "Node" {
191 return true
192 }
193 if t.Name.Name == "OrigNode" || t.Name.Name == "InitNode" {
194
195
196 return false
197 }
198
199
200
201
202 for _, f := range s.Methods.List {
203 if len(f.Names) != 0 {
204 continue
205 }
206 if isNamedType(f.Type, "Node") {
207 return true
208 }
209 }
210 return false
211 }
212
213 func processType(t *ast.TypeSpec) {
214 name := t.Name.Name
215 fmt.Fprintf(&buf, "\n")
216 fmt.Fprintf(&buf, "func (n *%s) Format(s fmt.State, verb rune) { fmtNode(n, s, verb) }\n", name)
217
218 switch name {
219 case "Name", "Func":
220
221 return
222 }
223
224 s := t.Type.(*ast.StructType)
225 fields := s.Fields.List
226
227
228 for i := 0; i < len(fields); i++ {
229 f := fields[i]
230 if len(f.Names) != 0 {
231 continue
232 }
233 if isMini(f.Type) {
234
235
236
237 ss := mini[f.Type.(*ast.Ident).Name].Type.(*ast.StructType)
238 var f2 []*ast.Field
239 f2 = append(f2, fields[:i]...)
240 f2 = append(f2, ss.Fields.List...)
241 f2 = append(f2, fields[i+1:]...)
242 fields = f2
243 i--
244 continue
245 } else if isNamedType(f.Type, "origNode") {
246
247 copy(fields[i:], fields[i+1:])
248 fields = fields[:len(fields)-1]
249 i--
250 continue
251 } else {
252 panic("unknown embedded field " + fmt.Sprintf("%v", f.Type))
253 }
254 }
255
256 var copyBody strings.Builder
257 var doChildrenBody strings.Builder
258 var doChildrenWithHiddenBody strings.Builder
259 var editChildrenBody strings.Builder
260 var editChildrenWithHiddenBody strings.Builder
261 for _, f := range fields {
262 names := f.Names
263 ft := f.Type
264 hidden := false
265 if f.Tag != nil {
266 tag := f.Tag.Value[1 : len(f.Tag.Value)-1]
267 if strings.HasPrefix(tag, "mknode:") {
268 if tag[7:] == "\"-\"" {
269 if !isNamedType(ft, "Node") {
270 continue
271 }
272 hidden = true
273 } else {
274 panic(fmt.Sprintf("unexpected tag value: %s", tag))
275 }
276 }
277 }
278 if isNamedType(ft, "Nodes") {
279
280 ft = &ast.ArrayType{Elt: &ast.Ident{Name: "Node"}}
281 }
282 isSlice := false
283 if a, ok := ft.(*ast.ArrayType); ok && a.Len == nil {
284 isSlice = true
285 ft = a.Elt
286 }
287 isPtr := false
288 if p, ok := ft.(*ast.StarExpr); ok {
289 isPtr = true
290 ft = p.X
291 }
292 if !implementsNode(ft) {
293 continue
294 }
295 for _, name := range names {
296 ptr := ""
297 if isPtr {
298 ptr = "*"
299 }
300 if isSlice {
301 fmt.Fprintf(&doChildrenWithHiddenBody,
302 "if do%ss(n.%s, do) {\nreturn true\n}\n", ft, name)
303 fmt.Fprintf(&editChildrenWithHiddenBody,
304 "edit%ss(n.%s, edit)\n", ft, name)
305 } else {
306 fmt.Fprintf(&doChildrenWithHiddenBody,
307 "if n.%s != nil && do(n.%s) {\nreturn true\n}\n", name, name)
308 fmt.Fprintf(&editChildrenWithHiddenBody,
309 "if n.%s != nil {\nn.%s = edit(n.%s).(%s%s)\n}\n", name, name, name, ptr, ft)
310 }
311 if hidden {
312 continue
313 }
314 if isSlice {
315 fmt.Fprintf(©Body, "c.%s = copy%ss(c.%s)\n", name, ft, name)
316 fmt.Fprintf(&doChildrenBody,
317 "if do%ss(n.%s, do) {\nreturn true\n}\n", ft, name)
318 fmt.Fprintf(&editChildrenBody,
319 "edit%ss(n.%s, edit)\n", ft, name)
320 } else {
321 fmt.Fprintf(&doChildrenBody,
322 "if n.%s != nil && do(n.%s) {\nreturn true\n}\n", name, name)
323 fmt.Fprintf(&editChildrenBody,
324 "if n.%s != nil {\nn.%s = edit(n.%s).(%s%s)\n}\n", name, name, name, ptr, ft)
325 }
326 }
327 }
328 fmt.Fprintf(&buf, "func (n *%s) copy() Node {\nc := *n\n", name)
329 buf.WriteString(copyBody.String())
330 fmt.Fprintf(&buf, "return &c\n}\n")
331 fmt.Fprintf(&buf, "func (n *%s) doChildren(do func(Node) bool) bool {\n", name)
332 buf.WriteString(doChildrenBody.String())
333 fmt.Fprintf(&buf, "return false\n}\n")
334 fmt.Fprintf(&buf, "func (n *%s) doChildrenWithHidden(do func(Node) bool) bool {\n", name)
335 buf.WriteString(doChildrenWithHiddenBody.String())
336 fmt.Fprintf(&buf, "return false\n}\n")
337 fmt.Fprintf(&buf, "func (n *%s) editChildren(edit func(Node) Node) {\n", name)
338 buf.WriteString(editChildrenBody.String())
339 fmt.Fprintf(&buf, "}\n")
340 fmt.Fprintf(&buf, "func (n *%s) editChildrenWithHidden(edit func(Node) Node) {\n", name)
341 buf.WriteString(editChildrenWithHiddenBody.String())
342 fmt.Fprintf(&buf, "}\n")
343 }
344
345 func generateHelpers() {
346 for _, typ := range []string{"CaseClause", "CommClause", "Name", "Node"} {
347 ptr := "*"
348 if typ == "Node" {
349 ptr = ""
350 }
351 fmt.Fprintf(&buf, "\n")
352 fmt.Fprintf(&buf, "func copy%ss(list []%s%s) []%s%s {\n", typ, ptr, typ, ptr, typ)
353 fmt.Fprintf(&buf, "if list == nil { return nil }\n")
354 fmt.Fprintf(&buf, "c := make([]%s%s, len(list))\n", ptr, typ)
355 fmt.Fprintf(&buf, "copy(c, list)\n")
356 fmt.Fprintf(&buf, "return c\n")
357 fmt.Fprintf(&buf, "}\n")
358 fmt.Fprintf(&buf, "func do%ss(list []%s%s, do func(Node) bool) bool {\n", typ, ptr, typ)
359 fmt.Fprintf(&buf, "for _, x := range list {\n")
360 fmt.Fprintf(&buf, "if x != nil && do(x) {\n")
361 fmt.Fprintf(&buf, "return true\n")
362 fmt.Fprintf(&buf, "}\n")
363 fmt.Fprintf(&buf, "}\n")
364 fmt.Fprintf(&buf, "return false\n")
365 fmt.Fprintf(&buf, "}\n")
366 fmt.Fprintf(&buf, "func edit%ss(list []%s%s, edit func(Node) Node) {\n", typ, ptr, typ)
367 fmt.Fprintf(&buf, "for i, x := range list {\n")
368 fmt.Fprintf(&buf, "if x != nil {\n")
369 fmt.Fprintf(&buf, "list[i] = edit(x).(%s%s)\n", ptr, typ)
370 fmt.Fprintf(&buf, "}\n")
371 fmt.Fprintf(&buf, "}\n")
372 fmt.Fprintf(&buf, "}\n")
373 }
374 }
375
View as plain text