Skip to content

Commit f6057e6

Browse files
author
nieml
committed
feature: 如果生成时文件已存在,则只生成新加的部分
1 parent 1e73a68 commit f6057e6

File tree

7 files changed

+654
-323
lines changed

7 files changed

+654
-323
lines changed

generator.go

Lines changed: 357 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,357 @@
1+
package main
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"go/format"
7+
"go/token"
8+
"log"
9+
"os"
10+
"path"
11+
"path/filepath"
12+
"sort"
13+
"strconv"
14+
"strings"
15+
16+
"github.com/ssoor/implgen/model"
17+
)
18+
19+
type generator struct {
20+
buf bytes.Buffer
21+
head bool
22+
dstFileName string
23+
indent string
24+
mockNames map[string]string // may be empty
25+
filename string // may be empty
26+
srcPackage, srcInterfaces string // may be empty
27+
copyrightHeader string
28+
29+
packageMap map[string]string // map from import path to package name
30+
}
31+
32+
func (g *generator) p(format string, args ...interface{}) {
33+
fmt.Fprintf(&g.buf, g.indent+format+"\n", args...)
34+
}
35+
func (g *generator) pf(format string, args ...interface{}) {
36+
fmt.Fprintf(&g.buf, g.indent+format, args...)
37+
}
38+
39+
func (g *generator) in() {
40+
g.indent += "\t"
41+
}
42+
43+
func (g *generator) out() {
44+
if len(g.indent) > 0 {
45+
g.indent = g.indent[0 : len(g.indent)-1]
46+
}
47+
}
48+
49+
func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPackagePath string) error {
50+
dstPkg, err := sourceMode(g.dstFileName)
51+
52+
if err != nil {
53+
g.head = true
54+
g.generateHead(pkg, outputPkgName, outputPackagePath)
55+
} else {
56+
namesMap := make(map[string]*model.Struct)
57+
for _, sn := range dstPkg.StructNames {
58+
namesMap[sn.Name] = sn
59+
}
60+
61+
newInterfaces := make([]*model.Interface, 0)
62+
for _, intf := range pkg.Interfaces {
63+
sn, exist := namesMap[g.mockName(intf.Name)]
64+
if exist {
65+
newMethods := make([]*model.Method, 0)
66+
for _, m := range intf.Methods {
67+
if _, exist = sn.Methods[m.Name]; exist {
68+
continue
69+
}
70+
newMethods = append(newMethods, m)
71+
}
72+
73+
if 0 != len(newMethods) {
74+
intf.Methods = newMethods
75+
mockType := g.mockName(intf.Name)
76+
g.GenerateMockMethods(mockType, intf, outputPackagePath)
77+
}
78+
} else {
79+
newInterfaces = append(newInterfaces, intf)
80+
}
81+
}
82+
83+
pkg.Interfaces = newInterfaces
84+
fmt.Printf("%+v-%+v-%+v\n", dstPkg.Interfaces, namesMap, pkg.Interfaces)
85+
}
86+
87+
return g.generate(pkg, outputPkgName, outputPackagePath)
88+
}
89+
90+
func (g *generator) generateHead(pkg *model.Package, outputPkgName string, outputPackagePath string) {
91+
if outputPkgName != pkg.Name && *selfPackage == "" {
92+
// reset outputPackagePath if it's not passed in through -self_package
93+
outputPackagePath = ""
94+
}
95+
96+
if g.copyrightHeader != "" {
97+
lines := strings.Split(g.copyrightHeader, "\n")
98+
for _, line := range lines {
99+
g.p("// %s", line)
100+
}
101+
g.p("")
102+
}
103+
104+
g.p("// Code generated by ImplGen.")
105+
if g.filename != "" {
106+
g.p("// Source: %v", g.filename)
107+
} else {
108+
g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces)
109+
}
110+
g.p("")
111+
112+
// Get all required imports, and generate unique names for them all.
113+
im := pkg.Imports()
114+
115+
// Only import reflect if it's used. We only use reflect in mocked methods
116+
// so only import if any of the mocked interfaces have methods.
117+
for _, intf := range pkg.Interfaces {
118+
if len(intf.Methods) > 0 {
119+
break
120+
}
121+
}
122+
123+
// Sort keys to make import alias generation predictable
124+
sortedPaths := make([]string, len(im))
125+
x := 0
126+
for pth := range im {
127+
sortedPaths[x] = pth
128+
x++
129+
}
130+
sort.Strings(sortedPaths)
131+
132+
packagesName := createPackageMap(sortedPaths)
133+
134+
g.packageMap = make(map[string]string, len(im))
135+
localNames := make(map[string]bool, len(im))
136+
for _, pth := range sortedPaths {
137+
base, ok := packagesName[pth]
138+
if !ok {
139+
base = sanitize(path.Base(pth))
140+
}
141+
142+
// Local names for an imported package can usually be the basename of the import path.
143+
// A couple of situations don't permit that, such as duplicate local names
144+
// (e.g. importing "html/template" and "text/template"), or where the basename is
145+
// a keyword (e.g. "foo/case").
146+
// try base0, base1, ...
147+
pkgName := base
148+
i := 0
149+
for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() {
150+
pkgName = base + strconv.Itoa(i)
151+
i++
152+
}
153+
154+
// Avoid importing package if source pkg == output pkg
155+
if pth == pkg.PkgPath && outputPkgName == pkg.Name {
156+
continue
157+
}
158+
159+
g.packageMap[pth] = pkgName
160+
localNames[pkgName] = true
161+
}
162+
163+
if *writePkgComment {
164+
g.p("// Package %v is a generated ImplGen package.", outputPkgName)
165+
}
166+
g.p("package %v", outputPkgName)
167+
g.p("")
168+
g.p("import (")
169+
g.in()
170+
for pkgPath, pkgName := range g.packageMap {
171+
if pkgPath == outputPackagePath {
172+
continue
173+
}
174+
g.p("%v %q", pkgName, pkgPath)
175+
}
176+
for _, pkgPath := range pkg.DotImports {
177+
g.p(". %q", pkgPath)
178+
}
179+
g.out()
180+
g.p(")")
181+
}
182+
183+
func (g *generator) generate(pkg *model.Package, outputPkgName string, outputPackagePath string) error {
184+
for _, intf := range pkg.Interfaces {
185+
if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil {
186+
return err
187+
}
188+
}
189+
190+
return nil
191+
}
192+
193+
// The name of the mock type to use for the given interface identifier.
194+
func (g *generator) mockName(typeName string) string {
195+
if mockName, ok := g.mockNames[typeName]; ok {
196+
return mockName
197+
}
198+
199+
suffix := "Interface"
200+
if suffix == typeName[len(typeName)-len(suffix):] {
201+
return typeName[:len(typeName)-len(suffix)]
202+
}
203+
204+
return typeName
205+
}
206+
207+
func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error {
208+
mockType := g.mockName(intf.Name)
209+
210+
g.p("")
211+
for _, doc := range intf.Doc {
212+
g.p("%v", doc)
213+
}
214+
if 0 == len(intf.Comment) {
215+
g.p("type %v struct {", mockType)
216+
} else {
217+
g.p("type %v struct { // %v", mockType, intf.Comment)
218+
}
219+
g.in()
220+
g.out()
221+
g.p("}")
222+
g.p("")
223+
224+
// TODO: Re-enable this if we can import the interface reliably.
225+
// g.p("// Verify that the mock satisfies the interface at compile time.")
226+
// g.p("var _ %v = (*%v)(nil)", typeName, mockType)
227+
// g.p("")
228+
229+
for _, doc := range intf.Doc {
230+
g.p("%v", doc)
231+
}
232+
if 0 == len(intf.Comment) {
233+
g.p("func New%v() *%v {", mockType, mockType)
234+
} else {
235+
g.p("func New%v() *%v { // %v", mockType, mockType, intf.Comment)
236+
}
237+
238+
g.in()
239+
g.p("interfaceImpl := &%v{}", mockType)
240+
g.p("")
241+
g.p("// TODO: ...")
242+
g.p("")
243+
g.p("return interfaceImpl")
244+
g.out()
245+
g.p("}")
246+
g.p("")
247+
248+
g.GenerateMockMethods(mockType, intf, outputPackagePath)
249+
250+
return nil
251+
}
252+
253+
func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) {
254+
for _, m := range intf.Methods {
255+
g.p("")
256+
_ = g.GenerateMockMethod(mockType, m, pkgOverride)
257+
}
258+
}
259+
260+
// GenerateMockMethod generates a mock method implementation.
261+
// If non-empty, pkgOverride is the package in which unqualified types reside.
262+
func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error {
263+
argNames := g.getArgNames(m)
264+
argTypes := g.getArgTypes(m, pkgOverride)
265+
argString := makeArgString(argNames, argTypes)
266+
267+
rets := make([]string, len(m.Out))
268+
for i, p := range m.Out {
269+
rets[i] = p.Type.String(g.packageMap, pkgOverride)
270+
}
271+
retString := strings.Join(rets, ", ")
272+
if len(rets) > 1 {
273+
retString = "(" + retString + ")"
274+
}
275+
if retString != "" {
276+
retString = " " + retString
277+
}
278+
279+
ia := newIdentifierAllocator(argNames)
280+
idRecv := ia.allocateIdentifier("m")
281+
282+
for _, doc := range m.Doc {
283+
g.p("%v", doc)
284+
}
285+
if 0 == len(m.Comment) {
286+
g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString)
287+
} else {
288+
g.pf("func (%v *%v) %v(%v)%v { // %v", idRecv, mockType, m.Name, argString, retString, m.Comment)
289+
}
290+
291+
g.in()
292+
293+
g.p("panic(\"*%v.%v(%v)%v Not implemented\")", mockType, m.Name, argString, retString)
294+
g.out()
295+
g.p("}")
296+
return nil
297+
}
298+
299+
func (g *generator) getArgNames(m *model.Method) []string {
300+
argNames := make([]string, len(m.In))
301+
for i, p := range m.In {
302+
name := p.Name
303+
if name == "" || name == "_" {
304+
name = fmt.Sprintf("arg%d", i)
305+
}
306+
argNames[i] = name
307+
}
308+
if m.Variadic != nil {
309+
name := m.Variadic.Name
310+
if name == "" {
311+
name = fmt.Sprintf("arg%d", len(m.In))
312+
}
313+
argNames = append(argNames, name)
314+
}
315+
return argNames
316+
}
317+
318+
func (g *generator) getArgTypes(m *model.Method, pkgOverride string) []string {
319+
argTypes := make([]string, len(m.In))
320+
for i, p := range m.In {
321+
argTypes[i] = p.Type.String(g.packageMap, pkgOverride)
322+
}
323+
if m.Variadic != nil {
324+
argTypes = append(argTypes, "..."+m.Variadic.Type.String(g.packageMap, pkgOverride))
325+
}
326+
return argTypes
327+
}
328+
329+
// Output returns the generator's output, formatted in the standard Go style.
330+
func (g *generator) Output() (n int, err error) {
331+
src, err := format.Source(g.buf.Bytes())
332+
if err != nil {
333+
log.Fatalf("Failed to format generated source code: %s\n%s", err, g.buf.String())
334+
}
335+
336+
dst := os.Stdout
337+
if len(g.dstFileName) > 0 {
338+
if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil {
339+
log.Fatalf("Unable to create directory: %v", err)
340+
}
341+
var f *os.File
342+
var err error
343+
if g.head {
344+
f, err = os.Create(*destination)
345+
} else {
346+
f, err = os.OpenFile(*destination, os.O_RDWR|os.O_APPEND, 0666)
347+
}
348+
349+
if err != nil {
350+
log.Fatalf("Failed opening destination file: %v", err)
351+
}
352+
defer dst.Close()
353+
dst = f
354+
}
355+
356+
return dst.Write(src)
357+
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module github.com/ssoor/implgen
33
go 1.13
44

55
require (
6+
github.com/gobuffalo/packr/v2 v2.8.0
67
github.com/golang/mock v1.4.3
78
golang.org/x/mod v0.3.0
89
golang.org/x/tools v0.0.0-20200612220849-54c614fe050c

0 commit comments

Comments
 (0)