Skip to content

Commit 6c6979c

Browse files
committed
Generate accessors for structs with pointer fields
Fixes #45. Change-Id: Ib2b6cc5d713e2eb833ee3c7fcfbd804bfe8fa313
1 parent 2ec691a commit 6c6979c

File tree

3 files changed

+7523
-0
lines changed

3 files changed

+7523
-0
lines changed

cmd/tools/gen-accessors.go

+300
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
// Copyright 2017 The go-github AUTHORS. All rights reserved.
2+
//
3+
// Use of this source code is governed by a BSD-style
4+
// license that can be found in the LICENSE file.
5+
6+
// gen-accessors generates accessor methods for structs with pointer fields.
7+
//
8+
// It is meant to be used by the go-github authors in conjunction with the
9+
// go generate tool before sending a commit to GitHub.
10+
package main
11+
12+
import (
13+
"bytes"
14+
"flag"
15+
"fmt"
16+
"go/ast"
17+
"go/format"
18+
"go/parser"
19+
"go/token"
20+
"io/ioutil"
21+
"log"
22+
"os"
23+
"path/filepath"
24+
"sort"
25+
"strings"
26+
"text/template"
27+
"time"
28+
)
29+
30+
const (
31+
fileSuffix = "-accessors.go"
32+
)
33+
34+
var (
35+
verbose = flag.Bool("v", false, "Print verbose log messages")
36+
37+
sourceTmpl = template.Must(template.New("source").Parse(source))
38+
39+
// blacklist lists which "struct.method" combos to not generate.
40+
blacklist = map[string]bool{
41+
"RepositoryContent.GetContent": true,
42+
"Client.GetBaseURL": true,
43+
"Client.GetUploadURL": true,
44+
"ErrorResponse.GetResponse": true,
45+
"RateLimitError.GetResponse": true,
46+
"AbuseRateLimitError.GetResponse": true,
47+
}
48+
)
49+
50+
func logf(fmt string, args ...interface{}) {
51+
if *verbose {
52+
log.Printf(fmt, args...)
53+
}
54+
}
55+
56+
func main() {
57+
flag.Parse()
58+
fset := token.NewFileSet()
59+
60+
pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
61+
if err != nil {
62+
log.Fatal(err)
63+
return
64+
}
65+
66+
for pkgName, pkg := range pkgs {
67+
t := &templateData{
68+
Year: time.Now().Year(),
69+
Filename: pkgName + fileSuffix,
70+
Package: pkgName,
71+
Imports: map[string]string{},
72+
}
73+
for filename, f := range pkg.Files {
74+
logf("Processing %v...", filename)
75+
t.sourceFile = filename
76+
if err := t.processAST(f); err != nil {
77+
log.Fatal(err)
78+
}
79+
}
80+
if err := t.dump(); err != nil {
81+
log.Fatal(err)
82+
}
83+
}
84+
logf("Done.")
85+
}
86+
87+
func (t *templateData) processAST(f *ast.File) error {
88+
for _, decl := range f.Decls {
89+
gd, ok := decl.(*ast.GenDecl)
90+
if !ok {
91+
continue
92+
}
93+
for _, spec := range gd.Specs {
94+
ts, ok := spec.(*ast.TypeSpec)
95+
if !ok {
96+
continue
97+
}
98+
st, ok := ts.Type.(*ast.StructType)
99+
if !ok {
100+
continue
101+
}
102+
for _, field := range st.Fields.List {
103+
se, ok := field.Type.(*ast.StarExpr)
104+
if len(field.Names) == 0 || !ok {
105+
continue
106+
}
107+
108+
fieldName := field.Names[0]
109+
if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); blacklist[key] {
110+
logf("Method %v blacklisted; skipping.", key)
111+
continue
112+
}
113+
114+
switch x := se.X.(type) {
115+
case *ast.ArrayType:
116+
t.addArrayType(x, ts.Name.String(), fieldName.String())
117+
case *ast.Ident:
118+
t.addIdent(x, ts.Name.String(), fieldName.String())
119+
case *ast.MapType:
120+
t.addMapType(x, ts.Name.String(), fieldName.String())
121+
case *ast.SelectorExpr:
122+
t.addSelectorExpr(x, ts.Name.String(), fieldName.String())
123+
default:
124+
logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
125+
}
126+
}
127+
}
128+
}
129+
return nil
130+
}
131+
132+
func sourceFilter(fi os.FileInfo) bool {
133+
return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix)
134+
}
135+
136+
func (t *templateData) dump() error {
137+
if len(t.Getters) == 0 {
138+
logf("No getters for %v; skipping.", t.Filename)
139+
return nil
140+
}
141+
142+
// Sort getters by ReceiverType.FieldName
143+
sort.Sort(byName(t.Getters))
144+
145+
var buf bytes.Buffer
146+
if err := sourceTmpl.Execute(&buf, t); err != nil {
147+
return err
148+
}
149+
clean, err := format.Source(buf.Bytes())
150+
if err != nil {
151+
return err
152+
}
153+
154+
outFile := filepath.Join(filepath.Dir(t.sourceFile), t.Filename)
155+
logf("Writing %v...", outFile)
156+
return ioutil.WriteFile(outFile, clean, 0644)
157+
}
158+
159+
func newGetter(receiverType, fieldName, fieldType, zeroValue string) *getter {
160+
return &getter{
161+
sortVal: strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
162+
ReceiverVar: strings.ToLower(receiverType[:1]),
163+
ReceiverType: receiverType,
164+
FieldName: fieldName,
165+
FieldType: fieldType,
166+
ZeroValue: zeroValue,
167+
}
168+
}
169+
170+
func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string) {
171+
var eltType string
172+
switch elt := x.Elt.(type) {
173+
case *ast.Ident:
174+
eltType = elt.String()
175+
default:
176+
logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt)
177+
return
178+
}
179+
180+
t.Getters = append(t.Getters, newGetter(receiverType, fieldName, "[]"+eltType, "nil"))
181+
}
182+
183+
func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
184+
var zeroValue string
185+
switch x.String() {
186+
case "int":
187+
zeroValue = "0"
188+
case "string":
189+
zeroValue = `""`
190+
case "bool":
191+
zeroValue = "false"
192+
case "Timestamp":
193+
zeroValue = "Timestamp{}"
194+
default: // other structs handled by their receivers directly.
195+
return
196+
}
197+
198+
t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue))
199+
}
200+
201+
func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string) {
202+
var keyType string
203+
switch key := x.Key.(type) {
204+
case *ast.Ident:
205+
keyType = key.String()
206+
default:
207+
logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key)
208+
return
209+
}
210+
211+
var valueType string
212+
switch value := x.Value.(type) {
213+
case *ast.Ident:
214+
valueType = value.String()
215+
default:
216+
logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value)
217+
return
218+
}
219+
220+
fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType)
221+
zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType)
222+
t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue))
223+
}
224+
225+
func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
226+
if strings.ToLower(fieldName[:1]) == fieldName[:1] { // non-exported field
227+
return
228+
}
229+
230+
var xX string
231+
if xx, ok := x.X.(*ast.Ident); ok {
232+
xX = xx.String()
233+
}
234+
235+
switch xX {
236+
case "time", "json":
237+
if xX == "json" {
238+
t.Imports["encoding/json"] = "encoding/json"
239+
} else {
240+
t.Imports[xX] = xX
241+
}
242+
fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name)
243+
zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name)
244+
if xX == "time" && x.Sel.Name == "Duration" {
245+
zeroValue = "0"
246+
}
247+
t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue))
248+
default:
249+
logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x)
250+
}
251+
}
252+
253+
type templateData struct {
254+
sourceFile string
255+
Year int
256+
Filename string
257+
Package string
258+
Imports map[string]string
259+
Getters []*getter
260+
}
261+
262+
type getter struct {
263+
sortVal string // lower-case version of "ReceiverType.FieldName"
264+
ReceiverVar string // the one-letter variable name to match the ReceiverType
265+
ReceiverType string
266+
FieldName string
267+
FieldType string
268+
ZeroValue string
269+
}
270+
271+
type byName []*getter
272+
273+
func (b byName) Len() int { return len(b) }
274+
func (b byName) Less(i, j int) bool { return b[i].sortVal < b[j].sortVal }
275+
func (b byName) Swap(i, j int) { b[i], b[j] = b[j], b[i] }
276+
277+
const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
278+
//
279+
// Use of this source code is governed by a BSD-style
280+
// license that can be found in the LICENSE file.
281+
282+
// {{.Filename}} generated by gen-accessors; DO NOT EDIT
283+
284+
package {{.Package}}
285+
{{if .Imports}}
286+
import (
287+
{{range .Imports}}
288+
"{{.}}"{{end}}
289+
)
290+
{{end}}
291+
{{range .Getters}}
292+
// Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise.
293+
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
294+
if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
295+
return {{.ZeroValue}}
296+
}
297+
return *{{.ReceiverVar}}.{{.FieldName}}
298+
}
299+
{{end}}
300+
`

0 commit comments

Comments
 (0)