Commit b797704

thepudds authored and timothy-king committed
go/analysis/passes/loopclosure: recursively check last statements in statements like if, switch, and for

In golang/go#16520, there was a suggestion to extend the
current loopclosure check to check more statements.

The current loopclosure flags patterns like:

	for k, v := range seq {
		go/defer func() {
			... k, v ...
		}()
	}

For this CL, the motivating example from golang/go#16520 is:

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		for j := 0; j < 1; j++ {
			wg.Add(1)
			go func() {
				fmt.Printf("%d ", i)
				wg.Done()
			}()
		}
	}
	wg.Wait()

The current loopclosure check does not flag this because of the
inner for loop, and the checker looks only at the last statement
in the outer loop body.

The suggestion is we redefine "last" recursively. For example,
if the last statement is an if, then we examine the last statements
in both of its branches. Or if the last statement is a nested loop,
then we examine the last statement of that loop's body, and so on.

A few years ago, Alan Donovan sent a sketch in CL 184537.

This CL attempts to complete Alan's sketch, as well as integrates
with the ensuing changes from golang/go#55972 to check
errgroup.Group.Go, which with this CL can now be recursively "last".

Updates golang/go#16520
Updates golang/go#55972
Fixes golang/go#30649
Fixes golang/go#32876

Change-Id: If66c6707025c20f32a2a781f6d11c4901f15742a
GitHub-Last-Rev: 04980e0
GitHub-Pull-Request: #415
Reviewed-on: https://go-review.googlesource.com/c/tools/+/452155
Reviewed-by: Tim King <[email protected]>
Run-TryBot: Tim King <[email protected]>
Reviewed-by: Alan Donovan <[email protected]>
gopls-CI: kokoro <[email protected]>
Reviewed-by: Robert Findley <[email protected]>
Run-TryBot: Alan Donovan <[email protected]>
TryBot-Result: Gopher Robot <[email protected]>
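
As an illustration only (not part of this CL), the motivating nested-loop pattern can be made safe by copying the outer loop variable once per iteration, which is the fix the analyzer's documentation suggests. The sketch below is a minimal standalone program; the helper name `collect` and the mutex-protected slice are assumptions made for the example. On toolchains from Go 1.22 onward, in modules that declare go 1.22 or newer, each iteration already gets its own variable, so the copy is redundant there but harmless.

```go
package main

import (
	"fmt"
	"sort"
	"sync"
)

// collect (a hypothetical helper for this example) starts one goroutine per
// inner-loop iteration and records the value of the outer loop variable
// that each goroutine observes.
func collect(outer, inner int) []int {
	var (
		mu   sync.Mutex
		seen []int
		wg   sync.WaitGroup
	)
	for i := 0; i < outer; i++ {
		i := i // new variable per iteration, so each closure captures its own copy
		for j := 0; j < inner; j++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				mu.Lock()
				seen = append(seen, i)
				mu.Unlock()
			}()
		}
	}
	wg.Wait()
	sort.Ints(seen) // goroutine order is not deterministic, so sort before returning
	return seen
}

func main() {
	fmt.Println(collect(3, 2)) // prints [0 0 1 1 2 2]
}
```

Without the `i := i` line, the go statement is the last statement of the inner loop, which is itself the last statement of the outer loop, so this is the shape the recursive definition of "last" is meant to catch.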
1 parent 3b9d20c commit b797704

4 files changed: +317 -77 lines changed

go/analysis/passes/loopclosure/loopclosure.go

Lines changed: 155 additions & 57 deletions
@@ -18,24 +18,60 @@ import (

 const Doc = `check references to loop variables from within nested functions

-This analyzer checks for references to loop variables from within a function
-literal inside the loop body. It checks for patterns where access to a loop
-variable is known to escape the current loop iteration:
- 1. a call to go or defer at the end of the loop body
- 2. a call to golang.org/x/sync/errgroup.Group.Go at the end of the loop body
- 3. a call testing.T.Run where the subtest body invokes t.Parallel()
-
-In the case of (1) and (2), the analyzer only considers references in the last
-statement of the loop body as it is not deep enough to understand the effects
-of subsequent statements which might render the reference benign.
-
-For example:
-
-	for i, v := range s {
-		go func() {
-			println(i, v) // not what you might expect
-		}()
-	}
+This analyzer reports places where a function literal references the
+iteration variable of an enclosing loop, and the loop calls the function
+in such a way (e.g. with go or defer) that it may outlive the loop
+iteration and possibly observe the wrong value of the variable.
+
+In this example, all the deferred functions run after the loop has
+completed, so all observe the final value of v.
+
+	for _, v := range list {
+		defer func() {
+			use(v) // incorrect
+		}()
+	}
+
+One fix is to create a new variable for each iteration of the loop:
+
+	for _, v := range list {
+		v := v // new var per iteration
+		defer func() {
+			use(v) // ok
+		}()
+	}
+
+The next example uses a go statement and has a similar problem.
+In addition, it has a data race because the loop updates v
+concurrent with the goroutines accessing it.
+
+	for _, v := range elem {
+		go func() {
+			use(v) // incorrect, and a data race
+		}()
+	}
+
+A fix is the same as before. The checker also reports problems
+in goroutines started by golang.org/x/sync/errgroup.Group.
+A hard-to-spot variant of this form is common in parallel tests:
+
+	func Test(t *testing.T) {
+		for _, test := range tests {
+			t.Run(test.name, func(t *testing.T) {
+				t.Parallel()
+				use(test) // incorrect, and a data race
+			})
+		}
+	}
+
+The t.Parallel() call causes the rest of the function to execute
+concurrent with the loop.
+
+The analyzer reports references only in the last statement,
+as it is not deep enough to understand the effects of subsequent
+statements that might render the reference benign.
+("Last statement" is defined recursively in compound
+statements such as if, switch, and select.)

 See: https://golang.org/doc/go_faq.html#closures_and_goroutines`


@@ -91,59 +127,121 @@ func run(pass *analysis.Pass) (interface{}, error) {
 		//
 		// For go, defer, and errgroup.Group.Go, we ignore all but the last
 		// statement, because it's hard to prove go isn't followed by wait, or
-		// defer by return.
+		// defer by return. "Last" is defined recursively.
 		//
+		// TODO: consider allowing the "last" go/defer/Go statement to be followed by
+		// N "trivial" statements, possibly under a recursive definition of "trivial"
+		// so that the checker could, for example, conclude that a go statement is
+		// followed by an if statement made of only trivial statements and trivial expressions,
+		// and hence the go statement could still be checked.
+		forEachLastStmt(body.List, func(last ast.Stmt) {
+			var stmts []ast.Stmt
+			switch s := last.(type) {
+			case *ast.GoStmt:
+				stmts = litStmts(s.Call.Fun)
+			case *ast.DeferStmt:
+				stmts = litStmts(s.Call.Fun)
+			case *ast.ExprStmt: // check for errgroup.Group.Go
+				if call, ok := s.X.(*ast.CallExpr); ok {
+					stmts = litStmts(goInvoke(pass.TypesInfo, call))
+				}
+			}
+			for _, stmt := range stmts {
+				reportCaptured(pass, vars, stmt)
+			}
+		})
+
+		// Also check for testing.T.Run (with T.Parallel).
 		// We consider every t.Run statement in the loop body, because there is
-		// no such commonly used mechanism for synchronizing parallel subtests.
+		// no commonly used mechanism for synchronizing parallel subtests.
 		// It is of course theoretically possible to synchronize parallel subtests,
 		// though such a pattern is likely to be exceedingly rare as it would be
 		// fighting against the test runner.
-		lastStmt := len(body.List) - 1
-		for i, s := range body.List {
-			var stmts []ast.Stmt // statements that must be checked for escaping references
+		for _, s := range body.List {
 			switch s := s.(type) {
-			case *ast.GoStmt:
-				if i == lastStmt {
-					stmts = litStmts(s.Call.Fun)
-				}
-
-			case *ast.DeferStmt:
-				if i == lastStmt {
-					stmts = litStmts(s.Call.Fun)
-				}
-
-			case *ast.ExprStmt: // check for errgroup.Group.Go and testing.T.Run (with T.Parallel)
+			case *ast.ExprStmt:
 				if call, ok := s.X.(*ast.CallExpr); ok {
-					if i == lastStmt {
-						stmts = litStmts(goInvoke(pass.TypesInfo, call))
-					}
-					if stmts == nil {
-						stmts = parallelSubtest(pass.TypesInfo, call)
+					for _, stmt := range parallelSubtest(pass.TypesInfo, call) {
+						reportCaptured(pass, vars, stmt)
 					}
+
 				}
 			}
+		}
+	})
+	return nil, nil
+}

-			for _, stmt := range stmts {
-				ast.Inspect(stmt, func(n ast.Node) bool {
-					id, ok := n.(*ast.Ident)
-					if !ok {
-						return true
-					}
-					obj := pass.TypesInfo.Uses[id]
-					if obj == nil {
-						return true
-					}
-					for _, v := range vars {
-						if v == obj {
-							pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name)
-						}
-					}
-					return true
-				})
+// reportCaptured reports a diagnostic stating a loop variable
+// has been captured by a func literal if checkStmt has escaping
+// references to vars. vars is expected to be variables updated by a loop statement,
+// and checkStmt is expected to be a statement from the body of a func literal in the loop.
+func reportCaptured(pass *analysis.Pass, vars []types.Object, checkStmt ast.Stmt) {
+	ast.Inspect(checkStmt, func(n ast.Node) bool {
+		id, ok := n.(*ast.Ident)
+		if !ok {
+			return true
+		}
+		obj := pass.TypesInfo.Uses[id]
+		if obj == nil {
+			return true
+		}
+		for _, v := range vars {
+			if v == obj {
+				pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name)
 			}
 		}
+		return true
 	})
-	return nil, nil
+}
+
+// forEachLastStmt calls onLast on each "last" statement in a list of statements.
+// "Last" is defined recursively so, for example, if the last statement is
+// a switch statement, then each switch case is also visited to examine
+// its last statements.
+func forEachLastStmt(stmts []ast.Stmt, onLast func(last ast.Stmt)) {
+	if len(stmts) == 0 {
+		return
+	}
+
+	s := stmts[len(stmts)-1]
+	switch s := s.(type) {
+	case *ast.IfStmt:
+	loop:
+		for {
+			forEachLastStmt(s.Body.List, onLast)
+			switch e := s.Else.(type) {
+			case *ast.BlockStmt:
+				forEachLastStmt(e.List, onLast)
+				break loop
+			case *ast.IfStmt:
+				s = e
+			case nil:
+				break loop
+			}
+		}
+	case *ast.ForStmt:
+		forEachLastStmt(s.Body.List, onLast)
+	case *ast.RangeStmt:
+		forEachLastStmt(s.Body.List, onLast)
+	case *ast.SwitchStmt:
+		for _, c := range s.Body.List {
+			cc := c.(*ast.CaseClause)
+			forEachLastStmt(cc.Body, onLast)
+		}
+	case *ast.TypeSwitchStmt:
+		for _, c := range s.Body.List {
+			cc := c.(*ast.CaseClause)
+			forEachLastStmt(cc.Body, onLast)
+		}
+	case *ast.SelectStmt:
+		for _, c := range s.Body.List {
+			cc := c.(*ast.CommClause)
+			forEachLastStmt(cc.Body, onLast)
+		}
+	default:
+		onLast(s)
+	}
 }

 // litStmts returns all statements from the function body of a function
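
To see which statements count as recursively "last", the walk can be tried outside the analyzer on a parsed snippet. The following is a simplified sketch under stated assumptions, not the analyzer's code: `lastStmts`, `lastKinds`, and the `src` snippet are names invented for this example, type switches and select are omitted for brevity, and an else-if chain is handled by recursion rather than by the labeled loop used in the CL.

```go
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
)

// src is a made-up snippet: the loop body ends in an if/else whose else
// branch ends in a switch.
const src = `package p

func f(n int) {
	for i := 0; i < n; i++ {
		println("first")
		if i > 0 {
			go println("a")
		} else {
			switch i {
			case 0:
				defer println("b")
			default:
				go println("c")
				println("trailing")
			}
		}
	}
}
`

// lastStmts calls onLast on every recursively "last" statement of stmts.
// It is a reduced copy of the idea in forEachLastStmt.
func lastStmts(stmts []ast.Stmt, onLast func(ast.Stmt)) {
	if len(stmts) == 0 {
		return
	}
	switch s := stmts[len(stmts)-1].(type) {
	case *ast.IfStmt:
		lastStmts(s.Body.List, onLast)
		switch e := s.Else.(type) {
		case *ast.BlockStmt:
			lastStmts(e.List, onLast)
		case *ast.IfStmt: // else-if: treat the nested if as the last statement
			lastStmts([]ast.Stmt{e}, onLast)
		}
	case *ast.ForStmt:
		lastStmts(s.Body.List, onLast)
	case *ast.RangeStmt:
		lastStmts(s.Body.List, onLast)
	case *ast.SwitchStmt:
		for _, c := range s.Body.List {
			lastStmts(c.(*ast.CaseClause).Body, onLast)
		}
	default:
		onLast(s)
	}
}

// lastKinds parses source and returns the dynamic type of each "last"
// statement found in the body of every outermost for statement.
func lastKinds(source string) ([]string, error) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "p.go", source, 0)
	if err != nil {
		return nil, err
	}
	var kinds []string
	ast.Inspect(file, func(n ast.Node) bool {
		loop, ok := n.(*ast.ForStmt)
		if !ok {
			return true
		}
		lastStmts(loop.Body.List, func(last ast.Stmt) {
			kinds = append(kinds, fmt.Sprintf("%T", last))
		})
		return false // nested loops are reached through lastStmts, not Inspect
	})
	return kinds, nil
}

func main() {
	kinds, err := lastKinds(src)
	if err != nil {
		panic(err)
	}
	fmt.Println(kinds) // prints [*ast.GoStmt *ast.DeferStmt *ast.ExprStmt]
}
```

The third result is the trailing println call rather than `go println("c")`: a go statement followed by another statement is not "last", which matches the case in the test data that is deliberately left unreported.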

go/analysis/passes/loopclosure/testdata/src/a/a.go

Lines changed: 106 additions & 0 deletions
@@ -7,6 +7,8 @@
 package testdata

 import (
+	"sync"
+
 	"golang.org/x/sync/errgroup"
 )


@@ -108,6 +110,70 @@ func _() {
 	}
 }

+// Cases that rely on recursively checking for last statements.
+func _() {
+
+	for i := range "outer" {
+		for j := range "inner" {
+			if j < 1 {
+				defer func() {
+					print(i) // want "loop variable i captured by func literal"
+				}()
+			} else if j < 2 {
+				go func() {
+					print(i) // want "loop variable i captured by func literal"
+				}()
+			} else {
+				go func() {
+					print(i)
+				}()
+				println("we don't catch the error above because of this statement")
+			}
+		}
+	}
+
+	for i := 0; i < 10; i++ {
+		for j := 0; j < 10; j++ {
+			if j < 1 {
+				switch j {
+				case 0:
+					defer func() {
+						print(i) // want "loop variable i captured by func literal"
+					}()
+				default:
+					go func() {
+						print(i) // want "loop variable i captured by func literal"
+					}()
+				}
+			} else if j < 2 {
+				var a interface{} = j
+				switch a.(type) {
+				case int:
+					defer func() {
+						print(i) // want "loop variable i captured by func literal"
+					}()
+				default:
+					go func() {
+						print(i) // want "loop variable i captured by func literal"
+					}()
+				}
+			} else {
+				ch := make(chan string)
+				select {
+				case <-ch:
+					defer func() {
+						print(i) // want "loop variable i captured by func literal"
+					}()
+				default:
+					go func() {
+						print(i) // want "loop variable i captured by func literal"
+					}()
+				}
+			}
+		}
+	}
+}
+
 // Group is used to test that loopclosure only matches Group.Go when Group is
 // from the golang.org/x/sync/errgroup package.
 type Group struct{}
@@ -125,6 +191,21 @@ func _() {
 			return nil
 		})
 	}
+
+	for i, v := range s {
+		if i > 0 {
+			g.Go(func() error {
+				print(i) // want "loop variable i captured by func literal"
+				return nil
+			})
+		} else {
+			g.Go(func() error {
+				print(v) // want "loop variable v captured by func literal"
+				return nil
+			})
+		}
+	}
+
 	// Do not match other Group.Go cases
 	g1 := new(Group)
 	for i, v := range s {
@@ -135,3 +216,28 @@ func _() {
 		})
 	}
 }
+
+// Real-world example from #16520, slightly simplified
+func _() {
+	var nodes []interface{}
+
+	critical := new(errgroup.Group)
+	others := sync.WaitGroup{}
+
+	isCritical := func(node interface{}) bool { return false }
+	run := func(node interface{}) error { return nil }
+
+	for _, node := range nodes {
+		if isCritical(node) {
+			critical.Go(func() error {
+				return run(node) // want "loop variable node captured by func literal"
+			})
+		} else {
+			others.Add(1)
+			go func() {
+				_ = run(node) // want "loop variable node captured by func literal"
+				others.Done()
+			}()
+		}
+	}
+}
