Skip to content

Commit 6c3e542

Browse files
xieyuschengopherbot
authored andcommitted
gopls/internal/analysis/modernize: preserves comments in minmax
This CL changes the original deletion (from after rh0 to the end of the if stmt) into deletion from the start of assignment to the end of if stmt), add all comments between them and create a new assigment as-is. This change preserves all comments inside if stmt and the comments after the line of assignment and before if stmt, causing comments B,C,D to be preserved and put on the top of min/max function call after fix. - source: lhs0 = rhs0 // A // B if rhs0 < b { // C lhs0 = b // D } - fixed: // A // B // C // D lhs0 = max(rhs0,b) Fixes golang/go#72727 Change-Id: I7c193711aac5834ebb0d5e8ae22c26ae7990c34f Reviewed-on: https://go-review.googlesource.com/c/tools/+/656655 Auto-Submit: Alan Donovan <[email protected]> Reviewed-by: Alan Donovan <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Robert Findley <[email protected]>
1 parent dcc4b8a commit 6c3e542

File tree

5 files changed

+148
-16
lines changed

5 files changed

+148
-16
lines changed

gopls/internal/analysis/modernize/minmax.go

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"go/ast"
1010
"go/token"
1111
"go/types"
12+
"strings"
1213

1314
"golang.org/x/tools/go/analysis"
1415
"golang.org/x/tools/go/analysis/passes/inspect"
@@ -32,7 +33,7 @@ func minmax(pass *analysis.Pass) {
3233

3334
// check is called for all statements of this form:
3435
// if a < b { lhs = rhs }
35-
check := func(curIfStmt cursor.Cursor, compare *ast.BinaryExpr) {
36+
check := func(file *ast.File, curIfStmt cursor.Cursor, compare *ast.BinaryExpr) {
3637
var (
3738
ifStmt = curIfStmt.Node().(*ast.IfStmt)
3839
tassign = ifStmt.Body.List[0].(*ast.AssignStmt)
@@ -44,6 +45,14 @@ func minmax(pass *analysis.Pass) {
4445
sign = isInequality(compare.Op)
4546
)
4647

48+
allComments := func(file *ast.File, start, end token.Pos) string {
49+
var buf strings.Builder
50+
for co := range analysisinternal.Comments(file, start, end) {
51+
_, _ = fmt.Fprintf(&buf, "%s\n", co.Text)
52+
}
53+
return buf.String()
54+
}
55+
4756
if fblock, ok := ifStmt.Else.(*ast.BlockStmt); ok && isAssignBlock(fblock) {
4857
fassign := fblock.List[0].(*ast.AssignStmt)
4958

@@ -85,7 +94,8 @@ func minmax(pass *analysis.Pass) {
8594
// Replace IfStmt with lhs = min(a, b).
8695
Pos: ifStmt.Pos(),
8796
End: ifStmt.End(),
88-
NewText: fmt.Appendf(nil, "%s = %s(%s, %s)",
97+
NewText: fmt.Appendf(nil, "%s%s = %s(%s, %s)",
98+
allComments(file, ifStmt.Pos(), ifStmt.End()),
8999
analysisinternal.Format(pass.Fset, lhs),
90100
sym,
91101
analysisinternal.Format(pass.Fset, a),
@@ -144,10 +154,13 @@ func minmax(pass *analysis.Pass) {
144154
SuggestedFixes: []analysis.SuggestedFix{{
145155
Message: fmt.Sprintf("Replace if/else with %s", sym),
146156
TextEdits: []analysis.TextEdit{{
147-
// Replace rhs0 and IfStmt with min(a, b)
148-
Pos: rhs0.Pos(),
157+
Pos: fassign.Pos(),
149158
End: ifStmt.End(),
150-
NewText: fmt.Appendf(nil, "%s(%s, %s)",
159+
// Replace "x := a; if ... {}" with "x = min(...)", preserving comments.
160+
NewText: fmt.Appendf(nil, "%s %s %s %s(%s, %s)",
161+
allComments(file, fassign.Pos(), ifStmt.End()),
162+
analysisinternal.Format(pass.Fset, lhs),
163+
fassign.Tok.String(),
151164
sym,
152165
analysisinternal.Format(pass.Fset, a),
153166
analysisinternal.Format(pass.Fset, b)),
@@ -161,16 +174,16 @@ func minmax(pass *analysis.Pass) {
161174
// Find all "if a < b { lhs = rhs }" statements.
162175
inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
163176
for curFile := range filesUsing(inspect, pass.TypesInfo, "go1.21") {
177+
astFile := curFile.Node().(*ast.File)
164178
for curIfStmt := range curFile.Preorder((*ast.IfStmt)(nil)) {
165179
ifStmt := curIfStmt.Node().(*ast.IfStmt)
166-
167180
if compare, ok := ifStmt.Cond.(*ast.BinaryExpr); ok &&
168181
ifStmt.Init == nil &&
169182
isInequality(compare.Op) != 0 &&
170183
isAssignBlock(ifStmt.Body) {
171184

172185
// Have: if a < b { lhs = rhs }
173-
check(curIfStmt, compare)
186+
check(astFile, curIfStmt, compare)
174187
}
175188
}
176189
}

gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package minmax
22

33
func ifmin(a, b int) {
4-
x := a
4+
x := a // A
5+
// B
56
if a < b { // want "if statement can be modernized using max"
6-
x = b
7+
// C
8+
x = b // D
9+
// E
710
}
811
print(x)
912
}
@@ -33,20 +36,30 @@ func ifmaxvariant(a, b int) {
3336
}
3437

3538
func ifelsemin(a, b int) {
36-
var x int
39+
var x int // A
40+
// B
3741
if a <= b { // want "if/else statement can be modernized using min"
38-
x = a
42+
// C
43+
x = a // D
44+
// E
3945
} else {
40-
x = b
46+
// F
47+
x = b // G
48+
// H
4149
}
4250
print(x)
4351
}
4452

4553
func ifelsemax(a, b int) {
46-
var x int
54+
// A
55+
var x int // B
56+
// C
4757
if a >= b { // want "if/else statement can be modernized using max"
48-
x = a
58+
// D
59+
x = a // E
60+
// F
4961
} else {
62+
// G
5063
x = b
5164
}
5265
print(x)
@@ -115,3 +128,11 @@ func nopeHasElseBlock(x int) int {
115128
}
116129
return y
117130
}
131+
132+
func fix72727(a, b int) {
133+
o := a - 42
134+
// some important comment. DO NOT REMOVE.
135+
if o < b { // want "if statement can be modernized using max"
136+
o = b
137+
}
138+
}

gopls/internal/analysis/modernize/testdata/src/minmax/minmax.go.golden

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,57 @@
11
package minmax
22

33
func ifmin(a, b int) {
4+
// A
5+
// B
6+
// want "if statement can be modernized using max"
7+
// C
8+
// D
9+
// E
410
x := max(a, b)
511
print(x)
612
}
713

814
func ifmax(a, b int) {
15+
// want "if statement can be modernized using min"
916
x := min(a, b)
1017
print(x)
1118
}
1219

1320
func ifminvariant(a, b int) {
21+
// want "if statement can be modernized using min"
1422
x := min(a, b)
1523
print(x)
1624
}
1725

1826
func ifmaxvariant(a, b int) {
27+
// want "if statement can be modernized using min"
1928
x := min(a, b)
2029
print(x)
2130
}
2231

2332
func ifelsemin(a, b int) {
24-
var x int
33+
var x int // A
34+
// B
35+
// want "if/else statement can be modernized using min"
36+
// C
37+
// D
38+
// E
39+
// F
40+
// G
41+
// H
2542
x = min(a, b)
2643
print(x)
2744
}
2845

2946
func ifelsemax(a, b int) {
30-
var x int
47+
// A
48+
var x int // B
49+
// C
50+
// want "if/else statement can be modernized using max"
51+
// D
52+
// E
53+
// F
54+
// G
3155
x = max(a, b)
3256
print(x)
3357
}
@@ -55,6 +79,7 @@ func nopeIfStmtHasInitStmt() {
5579
// Regression test for a bug: fix was "y := max(x, y)".
5680
func oops() {
5781
x := 1
82+
// want "if statement can be modernized using max"
5883
y := max(x, 2)
5984
print(y)
6085
}
@@ -92,3 +117,9 @@ func nopeHasElseBlock(x int) int {
92117
}
93118
return y
94119
}
120+
121+
func fix72727(a, b int) {
122+
// some important comment. DO NOT REMOVE.
123+
// want "if statement can be modernized using max"
124+
o := max(a-42, b)
125+
}

internal/analysisinternal/analysis.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"go/scanner"
1616
"go/token"
1717
"go/types"
18+
"iter"
1819
pathpkg "path"
1920
"slices"
2021
"strings"
@@ -608,3 +609,24 @@ Outer:
608609
}
609610
return []analysis.TextEdit{edit}
610611
}
612+
613+
// Comments returns an iterator over the comments overlapping the specified interval.
614+
func Comments(file *ast.File, start, end token.Pos) iter.Seq[*ast.Comment] {
615+
// TODO(adonovan): optimize use binary O(log n) instead of linear O(n) search.
616+
return func(yield func(*ast.Comment) bool) {
617+
for _, cg := range file.Comments {
618+
for _, co := range cg.List {
619+
if co.Pos() > end {
620+
return
621+
}
622+
if co.End() < start {
623+
continue
624+
}
625+
626+
if !yield(co) {
627+
return
628+
}
629+
}
630+
}
631+
}
632+
}

internal/analysisinternal/analysis_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"go/ast"
99
"go/parser"
1010
"go/token"
11+
"slices"
1112
"testing"
1213

1314
"golang.org/x/tools/go/ast/inspector"
@@ -253,3 +254,47 @@ func TestDeleteStmt(t *testing.T) {
253254

254255
}
255256
}
257+
258+
func TestComments(t *testing.T) {
259+
src := `
260+
package main
261+
262+
// A
263+
func fn() { }`
264+
var fset token.FileSet
265+
f, err := parser.ParseFile(&fset, "", []byte(src), parser.ParseComments|parser.AllErrors)
266+
if err != nil {
267+
t.Fatal(err)
268+
}
269+
270+
commentA := f.Comments[0].List[0]
271+
commentAMidPos := (commentA.Pos() + commentA.End()) / 2
272+
273+
want := []*ast.Comment{commentA}
274+
testCases := []struct {
275+
name string
276+
start, end token.Pos
277+
want []*ast.Comment
278+
}{
279+
{name: "comment totally overlaps with given interval", start: f.Pos(), end: f.End(), want: want},
280+
{name: "interval from file start to mid of comment A", start: f.Pos(), end: commentAMidPos, want: want},
281+
{name: "interval from mid of comment A to file end", start: commentAMidPos, end: commentA.End(), want: want},
282+
{name: "interval from start of comment A to mid of comment A", start: commentA.Pos(), end: commentAMidPos, want: want},
283+
{name: "interval from mid of comment A to comment A end", start: commentAMidPos, end: commentA.End(), want: want},
284+
{name: "interval at the start of comment A", start: commentA.Pos(), end: commentA.Pos(), want: want},
285+
{name: "interval at the end of comment A", start: commentA.End(), end: commentA.End(), want: want},
286+
{name: "interval from file start to the front of comment A start", start: f.Pos(), end: commentA.Pos() - 1, want: nil},
287+
{name: "interval from the position after end of comment A to file end", start: commentA.End() + 1, end: f.End(), want: nil},
288+
}
289+
for _, tc := range testCases {
290+
t.Run(tc.name, func(t *testing.T) {
291+
var got []*ast.Comment
292+
for co := range Comments(f, tc.start, tc.end) {
293+
got = append(got, co)
294+
}
295+
if !slices.Equal(got, tc.want) {
296+
t.Errorf("%s: got %v, want %v", tc.name, got, tc.want)
297+
}
298+
})
299+
}
300+
}

0 commit comments

Comments
 (0)