// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package ast import ( "fmt" "iter" ) // A Visitor's Visit method is invoked for each node encountered by [Walk]. // If the result visitor w is not nil, [Walk] visits each of the children // of node with the visitor w, followed by a call of w.Visit(nil). type Visitor interface { Visit(node Node) (w Visitor) } func walkList[N Node](v Visitor, list []N) { for _, node := range list { Walk(v, node) } } // TODO(gri): Investigate if providing a closure to Walk leads to // simpler use (and may help eliminate Inspect in turn). // Walk traverses an AST in depth-first order: It starts by calling // v.Visit(node); node must not be nil. If the visitor w returned by // v.Visit(node) is not nil, Walk is invoked recursively with visitor // w for each of the non-nil children of node, followed by a call of // w.Visit(nil). func Walk(v Visitor, node Node) { if v = v.Visit(node); v == nil { return } // walk children // (the order of the cases matches the order // of the corresponding node types in ast.go) switch n := node.(type) { // Comments and fields case *Comment: // nothing to do case *CommentGroup: walkList(v, n.List) case *Field: if n.Doc != nil { Walk(v, n.Doc) } walkList(v, n.Names) if n.Type != nil { Walk(v, n.Type) } if n.Tag != nil { Walk(v, n.Tag) } if n.Comment != nil { Walk(v, n.Comment) } case *FieldList: walkList(v, n.List) // Expressions case *BadExpr, *Ident, *BasicLit: // nothing to do case *Ellipsis: if n.Elt != nil { Walk(v, n.Elt) } case *FuncLit: Walk(v, n.Type) Walk(v, n.Body) case *CompositeLit: if n.Type != nil { Walk(v, n.Type) } walkList(v, n.Elts) case *ParenExpr: Walk(v, n.X) case *SelectorExpr: Walk(v, n.X) Walk(v, n.Sel) case *IndexExpr: Walk(v, n.X) Walk(v, n.Index) case *IndexListExpr: Walk(v, n.X) walkList(v, n.Indices) case *SliceExpr: Walk(v, n.X) if n.Low != nil { Walk(v, n.Low) } if n.High != nil { Walk(v, n.High) } if n.Max != nil { Walk(v, n.Max) } case *TypeAssertExpr: Walk(v, n.X) if n.Type != nil { Walk(v, n.Type) } case *CallExpr: Walk(v, n.Fun) walkList(v, n.Args) case *StarExpr: Walk(v, n.X) case *UnaryExpr: Walk(v, n.X) case *BinaryExpr: Walk(v, n.X) Walk(v, n.Y) case *KeyValueExpr: Walk(v, n.Key) Walk(v, n.Value) // Types case *ArrayType: if n.Len != nil { Walk(v, n.Len) } Walk(v, n.Elt) case *StructType: Walk(v, n.Fields) case *FuncType: if n.TypeParams != nil { Walk(v, n.TypeParams) } if n.Params != nil { Walk(v, n.Params) } if n.Results != nil { Walk(v, n.Results) } case *InterfaceType: Walk(v, n.Methods) case *MapType: Walk(v, n.Key) Walk(v, n.Value) case *ChanType: Walk(v, n.Value) // Statements case *BadStmt: // nothing to do case *DeclStmt: Walk(v, n.Decl) case *EmptyStmt: // nothing to do case *LabeledStmt: Walk(v, n.Label) Walk(v, n.Stmt) case *ExprStmt: Walk(v, n.X) case *SendStmt: Walk(v, n.Chan) Walk(v, n.Value) case *IncDecStmt: Walk(v, n.X) case *AssignStmt: walkList(v, n.Lhs) walkList(v, n.Rhs) case *GoStmt: Walk(v, n.Call) case *DeferStmt: Walk(v, n.Call) case *ReturnStmt: walkList(v, n.Results) case *BranchStmt: if n.Label != nil { Walk(v, n.Label) } case *BlockStmt: walkList(v, n.List) case *IfStmt: if n.Init != nil { Walk(v, n.Init) } Walk(v, n.Cond) Walk(v, n.Body) if n.Else != nil { Walk(v, n.Else) } case *CaseClause: walkList(v, n.List) walkList(v, n.Body) case *SwitchStmt: if n.Init != nil { Walk(v, n.Init) } if n.Tag != nil { Walk(v, n.Tag) } Walk(v, n.Body) case *TypeSwitchStmt: if n.Init != nil { Walk(v, n.Init) } Walk(v, n.Assign) Walk(v, n.Body) case *CommClause: if n.Comm != nil { Walk(v, n.Comm) } walkList(v, n.Body) case *SelectStmt: Walk(v, n.Body) case *ForStmt: if n.Init != nil { Walk(v, n.Init) } if n.Cond != nil { Walk(v, n.Cond) } if n.Post != nil { Walk(v, n.Post) } Walk(v, n.Body) case *RangeStmt: if n.Key != nil { Walk(v, n.Key) } if n.Value != nil { Walk(v, n.Value) } Walk(v, n.X) Walk(v, n.Body) // Declarations case *ImportSpec: if n.Doc != nil { Walk(v, n.Doc) } if n.Name != nil { Walk(v, n.Name) } Walk(v, n.Path) if n.Comment != nil { Walk(v, n.Comment) } case *ValueSpec: if n.Doc != nil { Walk(v, n.Doc) } walkList(v, n.Names) if n.Type != nil { Walk(v, n.Type) } walkList(v, n.Values) if n.Comment != nil { Walk(v, n.Comment) } case *TypeSpec: if n.Doc != nil { Walk(v, n.Doc) } Walk(v, n.Name) if n.TypeParams != nil { Walk(v, n.TypeParams) } Walk(v, n.Type) if n.Comment != nil { Walk(v, n.Comment) } case *BadDecl: // nothing to do case *GenDecl: if n.Doc != nil { Walk(v, n.Doc) } walkList(v, n.Specs) case *FuncDecl: if n.Doc != nil { Walk(v, n.Doc) } if n.Recv != nil { Walk(v, n.Recv) } Walk(v, n.Name) Walk(v, n.Type) if n.Body != nil { Walk(v, n.Body) } // Files and packages case *File: if n.Doc != nil { Walk(v, n.Doc) } Walk(v, n.Name) walkList(v, n.Decls) // don't walk n.Comments - they have been // visited already through the individual // nodes case *Package: for _, f := range n.Files { Walk(v, f) } default: panic(fmt.Sprintf("ast.Walk: unexpected node type %T", n)) } v.Visit(nil) } type inspector func(Node) bool func (f inspector) Visit(node Node) Visitor { if f(node) { return f } return nil } // Inspect traverses an AST in depth-first order: It starts by calling // f(node); node must not be nil. If f returns true, Inspect invokes f // recursively for each of the non-nil children of node, followed by a // call of f(nil). func Inspect(node Node, f func(Node) bool) { Walk(inspector(f), node) } // Preorder returns an iterator over all the nodes of the syntax tree // beneath (and including) the specified root, in depth-first // preorder. // // For greater control over the traversal of each subtree, use [Inspect]. func Preorder(root Node) iter.Seq[Node] { return func(yield func(Node) bool) { ok := true Inspect(root, func(n Node) bool { if n != nil { // yield must not be called once ok is false. ok = ok && yield(n) } return ok }) } }