Skip to content

Commit fa8abe3

Browse files
committed
gopls/internal: add code action "extract declarations to new file"
This code action moves selected code sections to a newly created file within the same package. The created filename is chosen as the first {function, type, const, var} name encountered. In addition, import declarations are added or removed as needed. Fixes golang/go#65707
1 parent abe5874 commit fa8abe3

File tree

8 files changed

+920
-10
lines changed

8 files changed

+920
-10
lines changed

gopls/internal/golang/codeaction.go

+11-4
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,11 @@ func fixedByImportFix(fix *imports.ImportFix, diagnostics []protocol.Diagnostic)
181181

182182
// getExtractCodeActions returns any refactor.extract code actions for the selection.
183183
func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) {
184-
if rng.Start == rng.End {
185-
return nil, nil
186-
}
187-
188184
start, end, err := pgf.RangePos(rng)
189185
if err != nil {
190186
return nil, err
191187
}
188+
192189
puri := pgf.URI
193190
var commands []protocol.Command
194191
if _, ok, methodOk, _ := CanExtractFunction(pgf.Tok, start, end, pgf.Src, pgf.File); ok {
@@ -227,6 +224,16 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti
227224
}
228225
commands = append(commands, cmd)
229226
}
227+
if canExtractToNewFile(pgf, start, end) {
228+
cmd, err := command.NewExtractToNewFileCommand(
229+
"Extract declarations to new file",
230+
command.ExtractToNewFileArgs{URI: pgf.URI, Range: rng},
231+
)
232+
if err != nil {
233+
return nil, err
234+
}
235+
commands = append(commands, cmd)
236+
}
230237
var actions []protocol.CodeAction
231238
for i := range commands {
232239
actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorExtract, &commands[i], nil, options))
+322
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package golang
6+
7+
// This file defines the code action "extract to a new file".
8+
9+
// todo: rename file to extract_to_new_file.go after code review
10+
11+
import (
12+
"context"
13+
"errors"
14+
"fmt"
15+
"go/ast"
16+
"go/format"
17+
"go/token"
18+
"go/types"
19+
"os"
20+
"path/filepath"
21+
"strings"
22+
23+
"golang.org/x/tools/gopls/internal/cache"
24+
"golang.org/x/tools/gopls/internal/cache/parsego"
25+
"golang.org/x/tools/gopls/internal/file"
26+
"golang.org/x/tools/gopls/internal/protocol"
27+
"golang.org/x/tools/gopls/internal/util/bug"
28+
"golang.org/x/tools/gopls/internal/util/typesutil"
29+
)
30+
31+
// canExtractToNewFile reports whether the code in the given range can be extracted to a new file.
32+
func canExtractToNewFile(pgf *parsego.File, start, end token.Pos) bool {
33+
_, _, _, ok := selectedToplevelDecls(pgf, start, end)
34+
return ok
35+
}
36+
37+
// findImportEdits finds imports specs that needs to be added to the new file
38+
// or deleted from the old file if the range is extracted to a new file.
39+
//
40+
// TODO: handle dot imports
41+
func findImportEdits(file *ast.File, info *types.Info, start, end token.Pos) (adds []*ast.ImportSpec, deletes []*ast.ImportSpec) {
42+
// make a map from a pkgName to its references
43+
pkgNameReferences := make(map[*types.PkgName][]*ast.Ident)
44+
for ident, use := range info.Uses {
45+
if pkgName, ok := use.(*types.PkgName); ok {
46+
pkgNameReferences[pkgName] = append(pkgNameReferences[pkgName], ident)
47+
}
48+
}
49+
50+
// PkgName referenced in the extracted selection must be
51+
// imported in the new file.
52+
// PkgName only refereced in the extracted selection must be
53+
// deleted from the original file.
54+
for _, spec := range file.Imports {
55+
pkgName, ok := typesutil.ImportedPkgName(info, spec)
56+
if !ok {
57+
continue
58+
}
59+
usedInSelection := false
60+
usedInNonSelection := false
61+
for _, ident := range pkgNameReferences[pkgName] {
62+
if contain(start, end, ident.Pos(), ident.End()) {
63+
usedInSelection = true
64+
} else {
65+
usedInNonSelection = true
66+
}
67+
}
68+
if usedInSelection {
69+
adds = append(adds, spec)
70+
}
71+
if usedInSelection && !usedInNonSelection {
72+
deletes = append(deletes, spec)
73+
}
74+
}
75+
76+
return adds, deletes
77+
}
78+
79+
// ExtractToNewFile moves selected declarations into a new file.
80+
func ExtractToNewFile(
81+
ctx context.Context,
82+
snapshot *cache.Snapshot,
83+
fh file.Handle,
84+
rng protocol.Range,
85+
) (*protocol.WorkspaceEdit, error) {
86+
errorPrefix := "ExtractToNewFile"
87+
88+
pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, fh.URI())
89+
if err != nil {
90+
return nil, err
91+
}
92+
93+
start, end, err := pgf.RangePos(rng)
94+
if err != nil {
95+
return nil, fmt.Errorf("%s: %w", errorPrefix, err)
96+
}
97+
98+
start, end, filename, ok := selectedToplevelDecls(pgf, start, end)
99+
if !ok {
100+
return nil, bug.Errorf("precondition unmet")
101+
}
102+
103+
end = skipWhiteSpaces(pgf, end)
104+
105+
replaceRange, err := pgf.PosRange(start, end)
106+
if err != nil {
107+
return nil, bug.Errorf("findRangeAndFilename returned invalid range: %v", err)
108+
}
109+
110+
adds, deletes := findImportEdits(pgf.File, pkg.GetTypesInfo(), start, end)
111+
112+
var importDeletes []protocol.TextEdit
113+
// For unparenthesised declarations like `import "fmt"` we remove
114+
// the whole declaration because simply removing importSpec leaves
115+
// `import \n`, which does not compile.
116+
// For parenthesised declarations like `import ("fmt"\n "log")`
117+
// we only remove the ImportSpec, because removing the whole declaration
118+
// might remove other ImportsSpecs we don't want to touch.
119+
parenthesisFreeImports := findParenthesisFreeImports(pgf)
120+
for _, importSpec := range deletes {
121+
if decl := parenthesisFreeImports[importSpec]; decl != nil {
122+
importDeletes = append(importDeletes, removeNode(pgf, decl))
123+
} else {
124+
importDeletes = append(importDeletes, removeNode(pgf, importSpec))
125+
}
126+
}
127+
128+
importAdds := ""
129+
if len(adds) > 0 {
130+
importAdds += "import ("
131+
for _, importSpec := range adds {
132+
if importSpec.Name != nil {
133+
importAdds += importSpec.Name.Name + " " + importSpec.Path.Value + "\n"
134+
} else {
135+
importAdds += importSpec.Path.Value + "\n"
136+
}
137+
}
138+
importAdds += ")"
139+
}
140+
141+
newFileURI, err := resolveNewFileURI(ctx, snapshot, pgf.URI.Dir().Path(), filename)
142+
if err != nil {
143+
return nil, fmt.Errorf("%s: %w", errorPrefix, err)
144+
}
145+
146+
// TODO: attempt to duplicate the copyright header, if any.
147+
newFileContent, err := format.Source([]byte(
148+
"package " + pgf.File.Name.Name + "\n" +
149+
importAdds + "\n" +
150+
string(pgf.Src[start-pgf.File.FileStart:end-pgf.File.FileStart]),
151+
))
152+
if err != nil {
153+
return nil, err
154+
}
155+
156+
return &protocol.WorkspaceEdit{
157+
DocumentChanges: []protocol.DocumentChanges{
158+
// original file edits
159+
protocol.TextEditsToDocumentChanges(fh.URI(), fh.Version(), append(
160+
importDeletes,
161+
protocol.TextEdit{
162+
Range: replaceRange,
163+
NewText: "",
164+
},
165+
))[0],
166+
{
167+
CreateFile: &protocol.CreateFile{
168+
Kind: "create",
169+
URI: newFileURI,
170+
},
171+
},
172+
// created file edits
173+
protocol.TextEditsToDocumentChanges(newFileURI, 0, []protocol.TextEdit{
174+
{
175+
Range: protocol.Range{},
176+
NewText: string(newFileContent),
177+
},
178+
})[0],
179+
},
180+
}, nil
181+
}
182+
183+
// resolveNewFileURI checks that basename.go does not exists in dir, otherwise
184+
// select basename.{1,2,3,4,5}.go as filename.
185+
func resolveNewFileURI(ctx context.Context, snapshot *cache.Snapshot, dir string, basename string) (protocol.DocumentURI, error) {
186+
basename = strings.ToLower(basename)
187+
newPath := protocol.URIFromPath(filepath.Join(dir, basename+".go"))
188+
for count := 1; ; count++ {
189+
fh, err := snapshot.ReadFile(ctx, newPath)
190+
if err != nil {
191+
return "", nil
192+
}
193+
if _, err := fh.Content(); errors.Is(err, os.ErrNotExist) {
194+
break
195+
}
196+
if count >= 5 {
197+
return "", fmt.Errorf("resolveNewFileURI: exceeded retry limit")
198+
}
199+
filename := fmt.Sprintf("%s.%d.go", basename, count)
200+
newPath = protocol.URIFromPath(filepath.Join(dir, filename))
201+
}
202+
return newPath, nil
203+
}
204+
205+
// selectedToplevelDecls returns the lexical extent of the top-level
206+
// declarations enclosed by [start, end), along with the name of the
207+
// first declaration. The returned boolean reports whether the selection
208+
// should be offered code action.
209+
func selectedToplevelDecls(pgf *parsego.File, start, end token.Pos) (token.Pos, token.Pos, string, bool) {
210+
// selection cannot intersect a package declaration
211+
if intersect(start, end, pgf.File.Package, pgf.File.Name.End()) {
212+
return 0, 0, "", false
213+
}
214+
firstName := ""
215+
for _, decl := range pgf.File.Decls {
216+
if intersect(start, end, decl.Pos(), decl.End()) {
217+
var id *ast.Ident
218+
switch v := decl.(type) {
219+
case *ast.BadDecl:
220+
return 0, 0, "", false
221+
case *ast.FuncDecl:
222+
// if only selecting keyword "func" or function name, extend selection to the
223+
// whole function
224+
if contain(v.Pos(), v.Name.End(), start, end) {
225+
start, end = v.Pos(), v.End()
226+
}
227+
id = v.Name
228+
case *ast.GenDecl:
229+
// selection cannot intersect an import declaration
230+
if v.Tok == token.IMPORT {
231+
return 0, 0, "", false
232+
}
233+
// if only selecting keyword "type", "const", or "var", extend selection to the
234+
// whole declaration
235+
if v.Tok == token.TYPE && contain(v.Pos(), v.Pos()+4, start, end) ||
236+
v.Tok == token.CONST && contain(v.Pos(), v.Pos()+5, start, end) ||
237+
v.Tok == token.VAR && contain(v.Pos(), v.Pos()+3, start, end) {
238+
start, end = v.Pos(), v.End()
239+
}
240+
if len(v.Specs) > 0 {
241+
switch spec := v.Specs[0].(type) {
242+
case *ast.TypeSpec:
243+
id = spec.Name
244+
case *ast.ValueSpec:
245+
id = spec.Names[0]
246+
}
247+
}
248+
}
249+
// selection cannot partially intersect a node
250+
if !contain(start, end, decl.Pos(), decl.End()) {
251+
return 0, 0, "", false
252+
}
253+
if id != nil && firstName == "" {
254+
firstName = id.Name
255+
}
256+
// extends selection to docs comments
257+
var c *ast.CommentGroup
258+
switch decl := decl.(type) {
259+
case *ast.GenDecl:
260+
c = decl.Doc
261+
case *ast.FuncDecl:
262+
c = decl.Doc
263+
}
264+
if c != nil && c.Pos() < start {
265+
start = c.Pos()
266+
}
267+
}
268+
}
269+
for _, comment := range pgf.File.Comments {
270+
if intersect(start, end, comment.Pos(), comment.End()) {
271+
if !contain(start, end, comment.Pos(), comment.End()) {
272+
// selection cannot partially intersect a comment
273+
return 0, 0, "", false
274+
}
275+
}
276+
}
277+
if firstName == "" {
278+
return 0, 0, "", false
279+
}
280+
return start, end, firstName, true
281+
}
282+
283+
func skipWhiteSpaces(pgf *parsego.File, pos token.Pos) token.Pos {
284+
i := pos
285+
for ; i-pgf.File.FileStart < token.Pos(len(pgf.Src)); i++ {
286+
c := pgf.Src[i-pgf.File.FileStart]
287+
if !(c == ' ' || c == '\t' || c == '\n') {
288+
break
289+
}
290+
}
291+
return i
292+
}
293+
294+
func findParenthesisFreeImports(pgf *parsego.File) map[*ast.ImportSpec]*ast.GenDecl {
295+
decls := make(map[*ast.ImportSpec]*ast.GenDecl)
296+
for _, decl := range pgf.File.Decls {
297+
if g, ok := decl.(*ast.GenDecl); ok {
298+
if !g.Lparen.IsValid() && len(g.Specs) > 0 {
299+
if v, ok := g.Specs[0].(*ast.ImportSpec); ok {
300+
decls[v] = g
301+
}
302+
}
303+
}
304+
}
305+
return decls
306+
}
307+
308+
// removeNode returns a TextEdit that removes the node
309+
func removeNode(pgf *parsego.File, node ast.Node) protocol.TextEdit {
310+
rng, _ := pgf.PosRange(node.Pos(), node.End())
311+
return protocol.TextEdit{Range: rng, NewText: ""}
312+
}
313+
314+
// intersect checks if [a, b) and [c, d) intersect, assuming a <= b and c <= d
315+
func intersect(a, b, c, d token.Pos) bool {
316+
return !(b <= c || d <= a)
317+
}
318+
319+
// contain checks if [a, b) contains [c, d), assuming a <= b and c <= d
320+
func contain(a, b, c, d token.Pos) bool {
321+
return a <= c && d <= b
322+
}

0 commit comments

Comments
 (0)