Skip to content

Commit e9f7be9

Browse files
adonovangopherbot
authored andcommitted
internal/astutil/cursor: add Cursor.Child(Node) Cursor
This method returns the cursor for a direct child, more efficiently than FindNode. Also, add edge.Kind.Get method. + tests Change-Id: I1176ac55713ef0c06b02a1e3a9c64f530caa9a09 Reviewed-on: https://go-review.googlesource.com/c/tools/+/642936 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Robert Findley <[email protected]> Commit-Queue: Alan Donovan <[email protected]> Auto-Submit: Alan Donovan <[email protected]>
1 parent f912a4f commit e9f7be9

File tree

3 files changed

+61
-4
lines changed

3 files changed

+61
-4
lines changed

internal/astutil/cursor/cursor.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package cursor
1616

1717
import (
18+
"fmt"
1819
"go/ast"
1920
"go/token"
2021
"iter"
@@ -227,6 +228,34 @@ func (c Cursor) Edge() (edge.Kind, int) {
227228
return unpackEdgeKindAndIndex(events[pop].parent)
228229
}
229230

231+
// Child returns the cursor for n, which must be a direct child of c's Node.
232+
//
233+
// Child must not be called on the Root node (whose [Cursor.Node] returns nil).
234+
func (c Cursor) Child(n ast.Node) Cursor {
235+
if c.index < 0 {
236+
panic("Cursor.Child called on Root node")
237+
}
238+
239+
if false {
240+
// reference implementation
241+
for child := range c.Children() {
242+
if child.Node() == n {
243+
return child
244+
}
245+
}
246+
247+
} else {
248+
// optimized implementation
249+
events := c.events()
250+
for i := c.index + 1; events[i].index > i; i = events[i].index + 1 {
251+
if events[i].node == n {
252+
return Cursor{c.in, i}
253+
}
254+
}
255+
}
256+
panic(fmt.Sprintf("Child(%T): not a child of %v", n, c))
257+
}
258+
230259
// NextSibling returns the cursor for the next sibling node in the same list
231260
// (for example, of files, decls, specs, statements, fields, or expressions) as
232261
// the current node. It returns (zero, false) if the node is the last node in

internal/astutil/cursor/cursor_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,12 @@ func TestCursor_Edge(t *testing.T) {
360360
e.NodeType(), parent.Node())
361361
}
362362

363+
// Check consistency of c.Edge.Get(c.Parent().Node()) == c.Node().
364+
if got := e.Get(parent.Node(), idx); got != cur.Node() {
365+
t.Errorf("cur=%v@%s: %s.Get(cur.Parent().Node(), %d) = %T@%s, want cur.Node()",
366+
cur, netFset.Position(cur.Node().Pos()), e, idx, got, netFset.Position(got.Pos()))
367+
}
368+
363369
// Check that reflection on the parent finds the current node.
364370
fv := reflect.ValueOf(parent.Node()).Elem().FieldByName(e.FieldName())
365371
if idx >= 0 {
@@ -373,6 +379,11 @@ func TestCursor_Edge(t *testing.T) {
373379
t.Errorf("%v.Edge = (%v, %d); FieldName/Index reflection gave %T@%s, not original node",
374380
cur, e, idx, got, netFset.Position(got.Pos()))
375381
}
382+
383+
// Check that Cursor.Child is the reverse of Parent.
384+
if cur.Parent().Child(cur.Node()) != cur {
385+
t.Errorf("Cursor.Parent.Child = %v, want %v", cur.Parent().Child(cur.Node()), cur)
386+
}
376387
}
377388
}
378389

internal/astutil/edge/edge.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,34 @@ func (k Kind) String() string {
2121
return "<invalid>"
2222
}
2323
info := fieldInfos[k]
24-
return fmt.Sprintf("%v.%s", info.nodeType.Elem().Name(), info.fieldName)
24+
return fmt.Sprintf("%v.%s", info.nodeType.Elem().Name(), info.name)
2525
}
2626

2727
// NodeType returns the pointer-to-struct type of the ast.Node implementation.
2828
func (k Kind) NodeType() reflect.Type { return fieldInfos[k].nodeType }
2929

3030
// FieldName returns the name of the field.
31-
func (k Kind) FieldName() string { return fieldInfos[k].fieldName }
31+
func (k Kind) FieldName() string { return fieldInfos[k].name }
3232

3333
// FieldType returns the declared type of the field.
3434
func (k Kind) FieldType() reflect.Type { return fieldInfos[k].fieldType }
3535

36+
// Get returns the direct child of n identified by (k, idx).
37+
// n's type must match k.NodeType().
38+
// idx must be a valid slice index, or -1 for a non-slice.
39+
func (k Kind) Get(n ast.Node, idx int) ast.Node {
40+
if k.NodeType() != reflect.TypeOf(n) {
41+
panic(fmt.Sprintf("%v.Get(%T): invalid node type", k, n))
42+
}
43+
v := reflect.ValueOf(n).Elem().Field(fieldInfos[k].index)
44+
if idx != -1 {
45+
v = v.Index(idx) // asserts valid index
46+
} else {
47+
// (The type assertion below asserts that v is not a slice.)
48+
}
49+
return v.Interface().(ast.Node) // may be nil
50+
}
51+
3652
const (
3753
Invalid Kind = iota // for nodes at the root of the traversal
3854

@@ -156,7 +172,8 @@ var _ = [1 << 7]struct{}{}[maxKind]
156172

157173
type fieldInfo struct {
158174
nodeType reflect.Type // pointer-to-struct type of ast.Node implementation
159-
fieldName string
175+
name string
176+
index int
160177
fieldType reflect.Type
161178
}
162179

@@ -166,7 +183,7 @@ func info[N ast.Node](fieldName string) fieldInfo {
166183
if !ok {
167184
panic(fieldName)
168185
}
169-
return fieldInfo{nodePtrType, fieldName, f.Type}
186+
return fieldInfo{nodePtrType, fieldName, f.Index[0], f.Type}
170187
}
171188

172189
var fieldInfos = [...]fieldInfo{

0 commit comments

Comments
 (0)