|
| 1 | +package main |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "fmt" |
| 6 | + "go/format" |
| 7 | + "go/token" |
| 8 | + "log" |
| 9 | + "os" |
| 10 | + "path" |
| 11 | + "path/filepath" |
| 12 | + "sort" |
| 13 | + "strconv" |
| 14 | + "strings" |
| 15 | + |
| 16 | + "github.com/ssoor/implgen/model" |
| 17 | +) |
| 18 | + |
| 19 | +type generator struct { |
| 20 | + buf bytes.Buffer |
| 21 | + head bool |
| 22 | + dstFileName string |
| 23 | + indent string |
| 24 | + mockNames map[string]string // may be empty |
| 25 | + filename string // may be empty |
| 26 | + srcPackage, srcInterfaces string // may be empty |
| 27 | + copyrightHeader string |
| 28 | + |
| 29 | + packageMap map[string]string // map from import path to package name |
| 30 | +} |
| 31 | + |
| 32 | +func (g *generator) p(format string, args ...interface{}) { |
| 33 | + fmt.Fprintf(&g.buf, g.indent+format+"\n", args...) |
| 34 | +} |
| 35 | +func (g *generator) pf(format string, args ...interface{}) { |
| 36 | + fmt.Fprintf(&g.buf, g.indent+format, args...) |
| 37 | +} |
| 38 | + |
| 39 | +func (g *generator) in() { |
| 40 | + g.indent += "\t" |
| 41 | +} |
| 42 | + |
| 43 | +func (g *generator) out() { |
| 44 | + if len(g.indent) > 0 { |
| 45 | + g.indent = g.indent[0 : len(g.indent)-1] |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPackagePath string) error { |
| 50 | + dstPkg, err := sourceMode(g.dstFileName) |
| 51 | + |
| 52 | + if err != nil { |
| 53 | + g.head = true |
| 54 | + g.generateHead(pkg, outputPkgName, outputPackagePath) |
| 55 | + } else { |
| 56 | + namesMap := make(map[string]*model.Struct) |
| 57 | + for _, sn := range dstPkg.StructNames { |
| 58 | + namesMap[sn.Name] = sn |
| 59 | + } |
| 60 | + |
| 61 | + newInterfaces := make([]*model.Interface, 0) |
| 62 | + for _, intf := range pkg.Interfaces { |
| 63 | + sn, exist := namesMap[g.mockName(intf.Name)] |
| 64 | + if exist { |
| 65 | + newMethods := make([]*model.Method, 0) |
| 66 | + for _, m := range intf.Methods { |
| 67 | + if _, exist = sn.Methods[m.Name]; exist { |
| 68 | + continue |
| 69 | + } |
| 70 | + newMethods = append(newMethods, m) |
| 71 | + } |
| 72 | + |
| 73 | + if 0 != len(newMethods) { |
| 74 | + intf.Methods = newMethods |
| 75 | + mockType := g.mockName(intf.Name) |
| 76 | + g.GenerateMockMethods(mockType, intf, outputPackagePath) |
| 77 | + } |
| 78 | + } else { |
| 79 | + newInterfaces = append(newInterfaces, intf) |
| 80 | + } |
| 81 | + } |
| 82 | + |
| 83 | + pkg.Interfaces = newInterfaces |
| 84 | + fmt.Printf("%+v-%+v-%+v\n", dstPkg.Interfaces, namesMap, pkg.Interfaces) |
| 85 | + } |
| 86 | + |
| 87 | + return g.generate(pkg, outputPkgName, outputPackagePath) |
| 88 | +} |
| 89 | + |
| 90 | +func (g *generator) generateHead(pkg *model.Package, outputPkgName string, outputPackagePath string) { |
| 91 | + if outputPkgName != pkg.Name && *selfPackage == "" { |
| 92 | + // reset outputPackagePath if it's not passed in through -self_package |
| 93 | + outputPackagePath = "" |
| 94 | + } |
| 95 | + |
| 96 | + if g.copyrightHeader != "" { |
| 97 | + lines := strings.Split(g.copyrightHeader, "\n") |
| 98 | + for _, line := range lines { |
| 99 | + g.p("// %s", line) |
| 100 | + } |
| 101 | + g.p("") |
| 102 | + } |
| 103 | + |
| 104 | + g.p("// Code generated by ImplGen.") |
| 105 | + if g.filename != "" { |
| 106 | + g.p("// Source: %v", g.filename) |
| 107 | + } else { |
| 108 | + g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces) |
| 109 | + } |
| 110 | + g.p("") |
| 111 | + |
| 112 | + // Get all required imports, and generate unique names for them all. |
| 113 | + im := pkg.Imports() |
| 114 | + |
| 115 | + // Only import reflect if it's used. We only use reflect in mocked methods |
| 116 | + // so only import if any of the mocked interfaces have methods. |
| 117 | + for _, intf := range pkg.Interfaces { |
| 118 | + if len(intf.Methods) > 0 { |
| 119 | + break |
| 120 | + } |
| 121 | + } |
| 122 | + |
| 123 | + // Sort keys to make import alias generation predictable |
| 124 | + sortedPaths := make([]string, len(im)) |
| 125 | + x := 0 |
| 126 | + for pth := range im { |
| 127 | + sortedPaths[x] = pth |
| 128 | + x++ |
| 129 | + } |
| 130 | + sort.Strings(sortedPaths) |
| 131 | + |
| 132 | + packagesName := createPackageMap(sortedPaths) |
| 133 | + |
| 134 | + g.packageMap = make(map[string]string, len(im)) |
| 135 | + localNames := make(map[string]bool, len(im)) |
| 136 | + for _, pth := range sortedPaths { |
| 137 | + base, ok := packagesName[pth] |
| 138 | + if !ok { |
| 139 | + base = sanitize(path.Base(pth)) |
| 140 | + } |
| 141 | + |
| 142 | + // Local names for an imported package can usually be the basename of the import path. |
| 143 | + // A couple of situations don't permit that, such as duplicate local names |
| 144 | + // (e.g. importing "html/template" and "text/template"), or where the basename is |
| 145 | + // a keyword (e.g. "foo/case"). |
| 146 | + // try base0, base1, ... |
| 147 | + pkgName := base |
| 148 | + i := 0 |
| 149 | + for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() { |
| 150 | + pkgName = base + strconv.Itoa(i) |
| 151 | + i++ |
| 152 | + } |
| 153 | + |
| 154 | + // Avoid importing package if source pkg == output pkg |
| 155 | + if pth == pkg.PkgPath && outputPkgName == pkg.Name { |
| 156 | + continue |
| 157 | + } |
| 158 | + |
| 159 | + g.packageMap[pth] = pkgName |
| 160 | + localNames[pkgName] = true |
| 161 | + } |
| 162 | + |
| 163 | + if *writePkgComment { |
| 164 | + g.p("// Package %v is a generated ImplGen package.", outputPkgName) |
| 165 | + } |
| 166 | + g.p("package %v", outputPkgName) |
| 167 | + g.p("") |
| 168 | + g.p("import (") |
| 169 | + g.in() |
| 170 | + for pkgPath, pkgName := range g.packageMap { |
| 171 | + if pkgPath == outputPackagePath { |
| 172 | + continue |
| 173 | + } |
| 174 | + g.p("%v %q", pkgName, pkgPath) |
| 175 | + } |
| 176 | + for _, pkgPath := range pkg.DotImports { |
| 177 | + g.p(". %q", pkgPath) |
| 178 | + } |
| 179 | + g.out() |
| 180 | + g.p(")") |
| 181 | +} |
| 182 | + |
| 183 | +func (g *generator) generate(pkg *model.Package, outputPkgName string, outputPackagePath string) error { |
| 184 | + for _, intf := range pkg.Interfaces { |
| 185 | + if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil { |
| 186 | + return err |
| 187 | + } |
| 188 | + } |
| 189 | + |
| 190 | + return nil |
| 191 | +} |
| 192 | + |
| 193 | +// The name of the mock type to use for the given interface identifier. |
| 194 | +func (g *generator) mockName(typeName string) string { |
| 195 | + if mockName, ok := g.mockNames[typeName]; ok { |
| 196 | + return mockName |
| 197 | + } |
| 198 | + |
| 199 | + suffix := "Interface" |
| 200 | + if suffix == typeName[len(typeName)-len(suffix):] { |
| 201 | + return typeName[:len(typeName)-len(suffix)] |
| 202 | + } |
| 203 | + |
| 204 | + return typeName |
| 205 | +} |
| 206 | + |
| 207 | +func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error { |
| 208 | + mockType := g.mockName(intf.Name) |
| 209 | + |
| 210 | + g.p("") |
| 211 | + for _, doc := range intf.Doc { |
| 212 | + g.p("%v", doc) |
| 213 | + } |
| 214 | + if 0 == len(intf.Comment) { |
| 215 | + g.p("type %v struct {", mockType) |
| 216 | + } else { |
| 217 | + g.p("type %v struct { // %v", mockType, intf.Comment) |
| 218 | + } |
| 219 | + g.in() |
| 220 | + g.out() |
| 221 | + g.p("}") |
| 222 | + g.p("") |
| 223 | + |
| 224 | + // TODO: Re-enable this if we can import the interface reliably. |
| 225 | + // g.p("// Verify that the mock satisfies the interface at compile time.") |
| 226 | + // g.p("var _ %v = (*%v)(nil)", typeName, mockType) |
| 227 | + // g.p("") |
| 228 | + |
| 229 | + for _, doc := range intf.Doc { |
| 230 | + g.p("%v", doc) |
| 231 | + } |
| 232 | + if 0 == len(intf.Comment) { |
| 233 | + g.p("func New%v() *%v {", mockType, mockType) |
| 234 | + } else { |
| 235 | + g.p("func New%v() *%v { // %v", mockType, mockType, intf.Comment) |
| 236 | + } |
| 237 | + |
| 238 | + g.in() |
| 239 | + g.p("interfaceImpl := &%v{}", mockType) |
| 240 | + g.p("") |
| 241 | + g.p("// TODO: ...") |
| 242 | + g.p("") |
| 243 | + g.p("return interfaceImpl") |
| 244 | + g.out() |
| 245 | + g.p("}") |
| 246 | + g.p("") |
| 247 | + |
| 248 | + g.GenerateMockMethods(mockType, intf, outputPackagePath) |
| 249 | + |
| 250 | + return nil |
| 251 | +} |
| 252 | + |
| 253 | +func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) { |
| 254 | + for _, m := range intf.Methods { |
| 255 | + g.p("") |
| 256 | + _ = g.GenerateMockMethod(mockType, m, pkgOverride) |
| 257 | + } |
| 258 | +} |
| 259 | + |
| 260 | +// GenerateMockMethod generates a mock method implementation. |
| 261 | +// If non-empty, pkgOverride is the package in which unqualified types reside. |
| 262 | +func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error { |
| 263 | + argNames := g.getArgNames(m) |
| 264 | + argTypes := g.getArgTypes(m, pkgOverride) |
| 265 | + argString := makeArgString(argNames, argTypes) |
| 266 | + |
| 267 | + rets := make([]string, len(m.Out)) |
| 268 | + for i, p := range m.Out { |
| 269 | + rets[i] = p.Type.String(g.packageMap, pkgOverride) |
| 270 | + } |
| 271 | + retString := strings.Join(rets, ", ") |
| 272 | + if len(rets) > 1 { |
| 273 | + retString = "(" + retString + ")" |
| 274 | + } |
| 275 | + if retString != "" { |
| 276 | + retString = " " + retString |
| 277 | + } |
| 278 | + |
| 279 | + ia := newIdentifierAllocator(argNames) |
| 280 | + idRecv := ia.allocateIdentifier("m") |
| 281 | + |
| 282 | + for _, doc := range m.Doc { |
| 283 | + g.p("%v", doc) |
| 284 | + } |
| 285 | + if 0 == len(m.Comment) { |
| 286 | + g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString) |
| 287 | + } else { |
| 288 | + g.pf("func (%v *%v) %v(%v)%v { // %v", idRecv, mockType, m.Name, argString, retString, m.Comment) |
| 289 | + } |
| 290 | + |
| 291 | + g.in() |
| 292 | + |
| 293 | + g.p("panic(\"*%v.%v(%v)%v Not implemented\")", mockType, m.Name, argString, retString) |
| 294 | + g.out() |
| 295 | + g.p("}") |
| 296 | + return nil |
| 297 | +} |
| 298 | + |
| 299 | +func (g *generator) getArgNames(m *model.Method) []string { |
| 300 | + argNames := make([]string, len(m.In)) |
| 301 | + for i, p := range m.In { |
| 302 | + name := p.Name |
| 303 | + if name == "" || name == "_" { |
| 304 | + name = fmt.Sprintf("arg%d", i) |
| 305 | + } |
| 306 | + argNames[i] = name |
| 307 | + } |
| 308 | + if m.Variadic != nil { |
| 309 | + name := m.Variadic.Name |
| 310 | + if name == "" { |
| 311 | + name = fmt.Sprintf("arg%d", len(m.In)) |
| 312 | + } |
| 313 | + argNames = append(argNames, name) |
| 314 | + } |
| 315 | + return argNames |
| 316 | +} |
| 317 | + |
| 318 | +func (g *generator) getArgTypes(m *model.Method, pkgOverride string) []string { |
| 319 | + argTypes := make([]string, len(m.In)) |
| 320 | + for i, p := range m.In { |
| 321 | + argTypes[i] = p.Type.String(g.packageMap, pkgOverride) |
| 322 | + } |
| 323 | + if m.Variadic != nil { |
| 324 | + argTypes = append(argTypes, "..."+m.Variadic.Type.String(g.packageMap, pkgOverride)) |
| 325 | + } |
| 326 | + return argTypes |
| 327 | +} |
| 328 | + |
| 329 | +// Output returns the generator's output, formatted in the standard Go style. |
| 330 | +func (g *generator) Output() (n int, err error) { |
| 331 | + src, err := format.Source(g.buf.Bytes()) |
| 332 | + if err != nil { |
| 333 | + log.Fatalf("Failed to format generated source code: %s\n%s", err, g.buf.String()) |
| 334 | + } |
| 335 | + |
| 336 | + dst := os.Stdout |
| 337 | + if len(g.dstFileName) > 0 { |
| 338 | + if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil { |
| 339 | + log.Fatalf("Unable to create directory: %v", err) |
| 340 | + } |
| 341 | + var f *os.File |
| 342 | + var err error |
| 343 | + if g.head { |
| 344 | + f, err = os.Create(*destination) |
| 345 | + } else { |
| 346 | + f, err = os.OpenFile(*destination, os.O_RDWR|os.O_APPEND, 0666) |
| 347 | + } |
| 348 | + |
| 349 | + if err != nil { |
| 350 | + log.Fatalf("Failed opening destination file: %v", err) |
| 351 | + } |
| 352 | + defer dst.Close() |
| 353 | + dst = f |
| 354 | + } |
| 355 | + |
| 356 | + return dst.Write(src) |
| 357 | +} |
0 commit comments