1
2
3
4
5 package astutil
6
7 import (
8 "fmt"
9 "go/ast"
10 "reflect"
11 "sort"
12 )
13
14
15
16
17
18
19
// An ApplyFunc is invoked by Apply for each visited node, including
// absent (nil) children, with a Cursor that describes the current node
// and offers operations (Replace, Delete, InsertBefore, InsertAfter) on it.
//
// The returned bool controls the traversal: when returned by the pre
// callback, false skips the node's children and its post call; when
// returned by the post callback, false stops the whole traversal.
// See Apply for details.
type ApplyFunc func(*Cursor) bool
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42 func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
43 parent := &struct{ ast.Node }{root}
44 defer func() {
45 if r := recover(); r != nil && r != abort {
46 panic(r)
47 }
48 result = parent.Node
49 }()
50 a := &application{pre: pre, post: post}
51 a.apply(parent, "Node", nil, root)
52 return
53 }
54
// abort is the sentinel panic value used to end Apply early: apply panics
// with it when the post callback returns false, and Apply's deferred
// recover treats it as normal termination. Its pointer identity keeps it
// distinct from unrelated panic values, which Apply re-raises.
var abort = new(int)
56
57
58
59
60
61
62
63
64
65
66
67
68
69
// A Cursor describes a node encountered during Apply.
// Information about the node and its parent is available
// from the Node, Parent, Name, and Index methods.
//
// If p is the current parent node c.Parent() and f is the field of p
// named c.Name(), then at the time a callback is invoked:
//
//	p.f            == c.Node()  if c.Index() <  0
//	p.f[c.Index()] == c.Node()  if c.Index() >= 0
//
// (a nil child is reported as a nil Node). For a file of an
// *ast.Package, c.Name() is the file name and the node is
// p.Files[c.Name()].
//
// The methods Replace, Delete, InsertBefore, and InsertAfter
// can be used to change the AST without disrupting Apply.
type Cursor struct {
	parent ast.Node
	name   string
	iter   *iterator // non-nil only while visiting an element of a slice field
	node   ast.Node
}
76
77
78 func (c *Cursor) Node() ast.Node { return c.node }
79
80
81 func (c *Cursor) Parent() ast.Node { return c.parent }
82
83
84
85
86 func (c *Cursor) Name() string { return c.name }
87
88
89
90
91
92 func (c *Cursor) Index() int {
93 if c.iter != nil {
94 return c.iter.index
95 }
96 return -1
97 }
98
99
100 func (c *Cursor) field() reflect.Value {
101 return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name)
102 }
103
104
105
106 func (c *Cursor) Replace(n ast.Node) {
107 if _, ok := c.node.(*ast.File); ok {
108 file, ok := n.(*ast.File)
109 if !ok {
110 panic("attempt to replace *ast.File with non-*ast.File")
111 }
112 c.parent.(*ast.Package).Files[c.name] = file
113 return
114 }
115
116 v := c.field()
117 if i := c.Index(); i >= 0 {
118 v = v.Index(i)
119 }
120 v.Set(reflect.ValueOf(n))
121 }
122
123
124
125
126
127 func (c *Cursor) Delete() {
128 if _, ok := c.node.(*ast.File); ok {
129 delete(c.parent.(*ast.Package).Files, c.name)
130 return
131 }
132
133 i := c.Index()
134 if i < 0 {
135 panic("Delete node not contained in slice")
136 }
137 v := c.field()
138 l := v.Len()
139 reflect.Copy(v.Slice(i, l), v.Slice(i+1, l))
140 v.Index(l - 1).Set(reflect.Zero(v.Type().Elem()))
141 v.SetLen(l - 1)
142 c.iter.step--
143 }
144
145
146
147
148 func (c *Cursor) InsertAfter(n ast.Node) {
149 i := c.Index()
150 if i < 0 {
151 panic("InsertAfter node not contained in slice")
152 }
153 v := c.field()
154 v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
155 l := v.Len()
156 reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
157 v.Index(i + 1).Set(reflect.ValueOf(n))
158 c.iter.step++
159 }
160
161
162
163
164 func (c *Cursor) InsertBefore(n ast.Node) {
165 i := c.Index()
166 if i < 0 {
167 panic("InsertBefore node not contained in slice")
168 }
169 v := c.field()
170 v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
171 l := v.Len()
172 reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
173 v.Index(i).Set(reflect.ValueOf(n))
174 c.iter.index++
175 }
176
177
// application holds the state shared by a single Apply traversal so it
// can be passed around through one pointer.
type application struct {
	// pre and post are the user callbacks given to Apply (either may be nil).
	pre, post ApplyFunc
	// cursor is one Cursor reused for every visited node; apply saves and
	// restores it around each visit (presumably to avoid allocating a
	// Cursor per node).
	cursor Cursor
	// iter is likewise reused by applyList, which saves and restores it
	// around each slice it walks.
	iter iterator
}
183
// apply visits node n, which is stored in field name of parent (in element
// iter.index of that slice field when iter is non-nil). It calls a.pre
// before and a.post after recursively visiting n's children.
//
// apply is also called for absent children, with n == nil; the callbacks
// still run in that case. If a.post returns false, apply panics with
// abort, which Apply recovers to stop the traversal.
func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
	// Convert a typed nil pointer (e.g. a nil *ast.BlockStmt held in the
	// ast.Node interface) into an untyped nil, so callbacks see a nil Node
	// and the type switch below takes the nil case.
	if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() {
		n = nil
	}

	// a.cursor is shared by the whole traversal: save the enclosing node's
	// cursor state, point the cursor at n, and restore the saved state
	// before returning.
	saved := a.cursor
	a.cursor.parent = parent
	a.cursor.name = name
	a.cursor.iter = iter
	a.cursor.node = n

	// pre returning false prunes n: its children and its post call are skipped.
	if a.pre != nil && !a.pre(&a.cursor) {
		a.cursor = saved
		return
	}

	// Walk the children. Each case visits child fields in the order they are
	// declared in the corresponding go/ast struct; this order determines the
	// order of pre/post callbacks, so do not rearrange it. The name strings
	// must match the go/ast field names exactly (except for package files),
	// because Cursor.field and applyList look fields up via reflection.
	switch n := n.(type) {
	case nil:
		// nothing to do

	// Comments and fields
	case *ast.Comment:
		// nothing to do

	case *ast.CommentGroup:
		// NOTE(review): n cannot be nil here (typed nils were converted above
		// and match case nil), so this check appears redundant.
		if n != nil {
			a.applyList(n, "List")
		}

	case *ast.Field:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Names")
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Tag", nil, n.Tag)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.FieldList:
		a.applyList(n, "List")

	// Expressions
	case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
		// nothing to do

	case *ast.Ellipsis:
		a.apply(n, "Elt", nil, n.Elt)

	case *ast.FuncLit:
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Body", nil, n.Body)

	case *ast.CompositeLit:
		a.apply(n, "Type", nil, n.Type)
		a.applyList(n, "Elts")

	case *ast.ParenExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.SelectorExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Sel", nil, n.Sel)

	case *ast.IndexExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Index", nil, n.Index)

	case *ast.IndexListExpr:
		a.apply(n, "X", nil, n.X)
		a.applyList(n, "Indices")

	case *ast.SliceExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Low", nil, n.Low)
		a.apply(n, "High", nil, n.High)
		a.apply(n, "Max", nil, n.Max)

	case *ast.TypeAssertExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Type", nil, n.Type)

	case *ast.CallExpr:
		a.apply(n, "Fun", nil, n.Fun)
		a.applyList(n, "Args")

	case *ast.StarExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.UnaryExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.BinaryExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Y", nil, n.Y)

	case *ast.KeyValueExpr:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)

	// Types
	case *ast.ArrayType:
		a.apply(n, "Len", nil, n.Len)
		a.apply(n, "Elt", nil, n.Elt)

	case *ast.StructType:
		a.apply(n, "Fields", nil, n.Fields)

	case *ast.FuncType:
		// Unlike other optional children, a nil TypeParams is not visited at
		// all, so pre/post are not called for it.
		if tparams := n.TypeParams; tparams != nil {
			a.apply(n, "TypeParams", nil, tparams)
		}
		a.apply(n, "Params", nil, n.Params)
		a.apply(n, "Results", nil, n.Results)

	case *ast.InterfaceType:
		a.apply(n, "Methods", nil, n.Methods)

	case *ast.MapType:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)

	case *ast.ChanType:
		a.apply(n, "Value", nil, n.Value)

	// Statements
	case *ast.BadStmt:
		// nothing to do

	case *ast.DeclStmt:
		a.apply(n, "Decl", nil, n.Decl)

	case *ast.EmptyStmt:
		// nothing to do

	case *ast.LabeledStmt:
		a.apply(n, "Label", nil, n.Label)
		a.apply(n, "Stmt", nil, n.Stmt)

	case *ast.ExprStmt:
		a.apply(n, "X", nil, n.X)

	case *ast.SendStmt:
		a.apply(n, "Chan", nil, n.Chan)
		a.apply(n, "Value", nil, n.Value)

	case *ast.IncDecStmt:
		a.apply(n, "X", nil, n.X)

	case *ast.AssignStmt:
		a.applyList(n, "Lhs")
		a.applyList(n, "Rhs")

	case *ast.GoStmt:
		a.apply(n, "Call", nil, n.Call)

	case *ast.DeferStmt:
		a.apply(n, "Call", nil, n.Call)

	case *ast.ReturnStmt:
		a.applyList(n, "Results")

	case *ast.BranchStmt:
		a.apply(n, "Label", nil, n.Label)

	case *ast.BlockStmt:
		a.applyList(n, "List")

	case *ast.IfStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Cond", nil, n.Cond)
		a.apply(n, "Body", nil, n.Body)
		a.apply(n, "Else", nil, n.Else)

	case *ast.CaseClause:
		a.applyList(n, "List")
		a.applyList(n, "Body")

	case *ast.SwitchStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Tag", nil, n.Tag)
		a.apply(n, "Body", nil, n.Body)

	case *ast.TypeSwitchStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Assign", nil, n.Assign)
		a.apply(n, "Body", nil, n.Body)

	case *ast.CommClause:
		a.apply(n, "Comm", nil, n.Comm)
		a.applyList(n, "Body")

	case *ast.SelectStmt:
		a.apply(n, "Body", nil, n.Body)

	case *ast.ForStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Cond", nil, n.Cond)
		a.apply(n, "Post", nil, n.Post)
		a.apply(n, "Body", nil, n.Body)

	case *ast.RangeStmt:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Body", nil, n.Body)

	// Declarations
	case *ast.ImportSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Path", nil, n.Path)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.ValueSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Names")
		a.apply(n, "Type", nil, n.Type)
		a.applyList(n, "Values")
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.TypeSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		// As for FuncType, a nil TypeParams is skipped entirely.
		if tparams := n.TypeParams; tparams != nil {
			a.apply(n, "TypeParams", nil, tparams)
		}
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.BadDecl:
		// nothing to do

	case *ast.GenDecl:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Specs")

	case *ast.FuncDecl:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Recv", nil, n.Recv)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Body", nil, n.Body)

	// Files and packages
	case *ast.File:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.applyList(n, "Decls")
		// n.Imports and n.Comments are not visited here.

	case *ast.Package:
		// Visit files in sorted filename order for reproducible results,
		// since map iteration order is unspecified. The file name is passed
		// as the cursor name; Cursor.Replace/Delete use it as the map key.
		var names []string
		for name := range n.Files {
			names = append(names, name)
		}
		sort.Strings(names)
		for _, name := range names {
			a.apply(n, name, nil, n.Files[name])
		}

	default:
		panic(fmt.Sprintf("Apply: unexpected node type %T", n))
	}

	// post returning false ends the whole traversal via the abort panic,
	// which Apply recovers. a.cursor is not restored on that path because
	// no further callbacks run.
	if a.post != nil && !a.post(&a.cursor) {
		panic(abort)
	}

	a.cursor = saved
}
458
459
// An iterator tracks applyList's position within a slice of nodes.
// index is the position of the element being visited. step is how far
// applyList advances after the visit: it is reset to 1 before each visit,
// and Delete and InsertAfter adjust it. InsertBefore adjusts index instead.
type iterator struct {
	index, step int
}
463
464 func (a *application) applyList(parent ast.Node, name string) {
465
466 saved := a.iter
467 a.iter.index = 0
468 for {
469
470 v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
471 if a.iter.index >= v.Len() {
472 break
473 }
474
475
476 var x ast.Node
477 if e := v.Index(a.iter.index); e.IsValid() {
478 x = e.Interface().(ast.Node)
479 }
480
481 a.iter.step = 1
482 a.apply(parent, name, &a.iter, x)
483 a.iter.index += a.iter.step
484 }
485 a.iter = saved
486 }
487
View as plain text