// Copyright 2022 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. //go:build ignore // Note: this program must be run in this directory. // go run mknode.go package main import ( "bytes" "fmt" "go/ast" "go/format" "go/parser" "go/token" "io/fs" "log" "os" "slices" "strings" ) var fset = token.NewFileSet() var buf bytes.Buffer // concreteNodes contains all concrete types in the package that implement Node // (except for the mini* types). var concreteNodes []*ast.TypeSpec // interfaceNodes contains all interface types in the package that implement Node. var interfaceNodes []*ast.TypeSpec // mini contains the embeddable mini types (miniNode, miniExpr, and miniStmt). var mini = map[string]*ast.TypeSpec{} // implementsNode reports whether the type t is one which represents a Node // in the AST. func implementsNode(t ast.Expr) bool { id, ok := t.(*ast.Ident) if !ok { return false // only named types } for _, ts := range interfaceNodes { if ts.Name.Name == id.Name { return true } } for _, ts := range concreteNodes { if ts.Name.Name == id.Name { return true } } return false } func isMini(t ast.Expr) bool { id, ok := t.(*ast.Ident) return ok && mini[id.Name] != nil } func isNamedType(t ast.Expr, name string) bool { if id, ok := t.(*ast.Ident); ok { if id.Name == name { return true } } return false } func main() { fmt.Fprintln(&buf, "// Code generated by mknode.go. DO NOT EDIT.") fmt.Fprintln(&buf) fmt.Fprintln(&buf, "package ir") fmt.Fprintln(&buf) fmt.Fprintln(&buf, `import "fmt"`) filter := func(file fs.FileInfo) bool { return !strings.HasPrefix(file.Name(), "mknode") } pkgs, err := parser.ParseDir(fset, ".", filter, 0) if err != nil { panic(err) } pkg := pkgs["ir"] // Find all the mini types. These let us determine which // concrete types implement Node, so we need to find them first. for _, f := range pkg.Files { for _, d := range f.Decls { g, ok := d.(*ast.GenDecl) if !ok { continue } for _, s := range g.Specs { t, ok := s.(*ast.TypeSpec) if !ok { continue } if strings.HasPrefix(t.Name.Name, "mini") { mini[t.Name.Name] = t // Double-check that it is or embeds miniNode. if t.Name.Name != "miniNode" { s := t.Type.(*ast.StructType) if !isNamedType(s.Fields.List[0].Type, "miniNode") { panic(fmt.Sprintf("can't find miniNode in %s", t.Name.Name)) } } } } } } // Find all the declarations of concrete types that implement Node. for _, f := range pkg.Files { for _, d := range f.Decls { g, ok := d.(*ast.GenDecl) if !ok { continue } for _, s := range g.Specs { t, ok := s.(*ast.TypeSpec) if !ok { continue } if strings.HasPrefix(t.Name.Name, "mini") { // We don't treat the mini types as // concrete implementations of Node // (even though they are) because // we only use them by embedding them. continue } if isConcreteNode(t) { concreteNodes = append(concreteNodes, t) } if isInterfaceNode(t) { interfaceNodes = append(interfaceNodes, t) } } } } // Sort for deterministic output. slices.SortFunc(concreteNodes, func(a, b *ast.TypeSpec) int { return strings.Compare(a.Name.Name, b.Name.Name) }) // Generate code for each concrete type. for _, t := range concreteNodes { processType(t) } // Add some helpers. generateHelpers() // Format and write output. out, err := format.Source(buf.Bytes()) if err != nil { // write out mangled source so we can see the bug. out = buf.Bytes() } err = os.WriteFile("node_gen.go", out, 0666) if err != nil { log.Fatal(err) } } // isConcreteNode reports whether the type t is a concrete type // implementing Node. func isConcreteNode(t *ast.TypeSpec) bool { s, ok := t.Type.(*ast.StructType) if !ok { return false } for _, f := range s.Fields.List { if isMini(f.Type) { return true } } return false } // isInterfaceNode reports whether the type t is an interface type // implementing Node (including Node itself). func isInterfaceNode(t *ast.TypeSpec) bool { s, ok := t.Type.(*ast.InterfaceType) if !ok { return false } if t.Name.Name == "Node" { return true } if t.Name.Name == "OrigNode" || t.Name.Name == "InitNode" { // These we exempt from consideration (fields of // this type don't need to be walked or copied). return false } // Look for embedded Node type. // Note that this doesn't handle multi-level embedding, but // we have none of that at the moment. for _, f := range s.Methods.List { if len(f.Names) != 0 { continue } if isNamedType(f.Type, "Node") { return true } } return false } func processType(t *ast.TypeSpec) { name := t.Name.Name fmt.Fprintf(&buf, "\n") fmt.Fprintf(&buf, "func (n *%s) Format(s fmt.State, verb rune) { fmtNode(n, s, verb) }\n", name) switch name { case "Name", "Func": // Too specialized to automate. return } s := t.Type.(*ast.StructType) fields := s.Fields.List // Expand any embedded fields. for i := 0; i < len(fields); i++ { f := fields[i] if len(f.Names) != 0 { continue // not embedded } if isMini(f.Type) { // Insert the fields of the embedded type into the main type. // (It would be easier just to append, but inserting in place // matches the old mknode behavior.) ss := mini[f.Type.(*ast.Ident).Name].Type.(*ast.StructType) var f2 []*ast.Field f2 = append(f2, fields[:i]...) f2 = append(f2, ss.Fields.List...) f2 = append(f2, fields[i+1:]...) fields = f2 i-- continue } else if isNamedType(f.Type, "origNode") { // Ignore this field copy(fields[i:], fields[i+1:]) fields = fields[:len(fields)-1] i-- continue } else { panic("unknown embedded field " + fmt.Sprintf("%v", f.Type)) } } // Process fields. var copyBody strings.Builder var doChildrenBody strings.Builder var doChildrenWithHiddenBody strings.Builder var editChildrenBody strings.Builder var editChildrenWithHiddenBody strings.Builder for _, f := range fields { names := f.Names ft := f.Type hidden := false if f.Tag != nil { tag := f.Tag.Value[1 : len(f.Tag.Value)-1] if strings.HasPrefix(tag, "mknode:") { if tag[7:] == "\"-\"" { if !isNamedType(ft, "Node") { continue } hidden = true } else { panic(fmt.Sprintf("unexpected tag value: %s", tag)) } } } if isNamedType(ft, "Nodes") { // Nodes == []Node ft = &ast.ArrayType{Elt: &ast.Ident{Name: "Node"}} } isSlice := false if a, ok := ft.(*ast.ArrayType); ok && a.Len == nil { isSlice = true ft = a.Elt } isPtr := false if p, ok := ft.(*ast.StarExpr); ok { isPtr = true ft = p.X } if !implementsNode(ft) { continue } for _, name := range names { ptr := "" if isPtr { ptr = "*" } if isSlice { fmt.Fprintf(&doChildrenWithHiddenBody, "if do%ss(n.%s, do) {\nreturn true\n}\n", ft, name) fmt.Fprintf(&editChildrenWithHiddenBody, "edit%ss(n.%s, edit)\n", ft, name) } else { fmt.Fprintf(&doChildrenWithHiddenBody, "if n.%s != nil && do(n.%s) {\nreturn true\n}\n", name, name) fmt.Fprintf(&editChildrenWithHiddenBody, "if n.%s != nil {\nn.%s = edit(n.%s).(%s%s)\n}\n", name, name, name, ptr, ft) } if hidden { continue } if isSlice { fmt.Fprintf(©Body, "c.%s = copy%ss(c.%s)\n", name, ft, name) fmt.Fprintf(&doChildrenBody, "if do%ss(n.%s, do) {\nreturn true\n}\n", ft, name) fmt.Fprintf(&editChildrenBody, "edit%ss(n.%s, edit)\n", ft, name) } else { fmt.Fprintf(&doChildrenBody, "if n.%s != nil && do(n.%s) {\nreturn true\n}\n", name, name) fmt.Fprintf(&editChildrenBody, "if n.%s != nil {\nn.%s = edit(n.%s).(%s%s)\n}\n", name, name, name, ptr, ft) } } } fmt.Fprintf(&buf, "func (n *%s) copy() Node {\nc := *n\n", name) buf.WriteString(copyBody.String()) fmt.Fprintf(&buf, "return &c\n}\n") fmt.Fprintf(&buf, "func (n *%s) doChildren(do func(Node) bool) bool {\n", name) buf.WriteString(doChildrenBody.String()) fmt.Fprintf(&buf, "return false\n}\n") fmt.Fprintf(&buf, "func (n *%s) doChildrenWithHidden(do func(Node) bool) bool {\n", name) buf.WriteString(doChildrenWithHiddenBody.String()) fmt.Fprintf(&buf, "return false\n}\n") fmt.Fprintf(&buf, "func (n *%s) editChildren(edit func(Node) Node) {\n", name) buf.WriteString(editChildrenBody.String()) fmt.Fprintf(&buf, "}\n") fmt.Fprintf(&buf, "func (n *%s) editChildrenWithHidden(edit func(Node) Node) {\n", name) buf.WriteString(editChildrenWithHiddenBody.String()) fmt.Fprintf(&buf, "}\n") } func generateHelpers() { for _, typ := range []string{"CaseClause", "CommClause", "Name", "Node"} { ptr := "*" if typ == "Node" { ptr = "" // interfaces don't need * } fmt.Fprintf(&buf, "\n") fmt.Fprintf(&buf, "func copy%ss(list []%s%s) []%s%s {\n", typ, ptr, typ, ptr, typ) fmt.Fprintf(&buf, "if list == nil { return nil }\n") fmt.Fprintf(&buf, "c := make([]%s%s, len(list))\n", ptr, typ) fmt.Fprintf(&buf, "copy(c, list)\n") fmt.Fprintf(&buf, "return c\n") fmt.Fprintf(&buf, "}\n") fmt.Fprintf(&buf, "func do%ss(list []%s%s, do func(Node) bool) bool {\n", typ, ptr, typ) fmt.Fprintf(&buf, "for _, x := range list {\n") fmt.Fprintf(&buf, "if x != nil && do(x) {\n") fmt.Fprintf(&buf, "return true\n") fmt.Fprintf(&buf, "}\n") fmt.Fprintf(&buf, "}\n") fmt.Fprintf(&buf, "return false\n") fmt.Fprintf(&buf, "}\n") fmt.Fprintf(&buf, "func edit%ss(list []%s%s, edit func(Node) Node) {\n", typ, ptr, typ) fmt.Fprintf(&buf, "for i, x := range list {\n") fmt.Fprintf(&buf, "if x != nil {\n") fmt.Fprintf(&buf, "list[i] = edit(x).(%s%s)\n", ptr, typ) fmt.Fprintf(&buf, "}\n") fmt.Fprintf(&buf, "}\n") fmt.Fprintf(&buf, "}\n") } }