|
| 1 | +// Copyright 2017 The go-github AUTHORS. All rights reserved. |
| 2 | +// |
| 3 | +// Use of this source code is governed by a BSD-style |
| 4 | +// license that can be found in the LICENSE file. |
| 5 | + |
| 6 | +// gen-accessors generates accessor methods for structs with pointer fields. |
| 7 | +// |
| 8 | +// It is meant to be used by the go-github authors in conjunction with the |
| 9 | +// go generate tool before sending a commit to GitHub. |
| 10 | +package main |
| 11 | + |
| 12 | +import ( |
| 13 | + "bytes" |
| 14 | + "flag" |
| 15 | + "fmt" |
| 16 | + "go/ast" |
| 17 | + "go/format" |
| 18 | + "go/parser" |
| 19 | + "go/token" |
| 20 | + "io/ioutil" |
| 21 | + "log" |
| 22 | + "os" |
| 23 | + "path/filepath" |
| 24 | + "sort" |
| 25 | + "strings" |
| 26 | + "text/template" |
| 27 | + "time" |
| 28 | +) |
| 29 | + |
| 30 | +const ( |
| 31 | + fileSuffix = "-accessors.go" |
| 32 | +) |
| 33 | + |
| 34 | +var ( |
| 35 | + verbose = flag.Bool("v", false, "Print verbose log messages") |
| 36 | + |
| 37 | + sourceTmpl = template.Must(template.New("source").Parse(source)) |
| 38 | + |
| 39 | + // blacklist lists which "struct.method" combos to not generate. |
| 40 | + blacklist = map[string]bool{ |
| 41 | + "RepositoryContent.GetContent": true, |
| 42 | + "Client.GetBaseURL": true, |
| 43 | + "Client.GetUploadURL": true, |
| 44 | + "ErrorResponse.GetResponse": true, |
| 45 | + "RateLimitError.GetResponse": true, |
| 46 | + "AbuseRateLimitError.GetResponse": true, |
| 47 | + } |
| 48 | +) |
| 49 | + |
| 50 | +func logf(fmt string, args ...interface{}) { |
| 51 | + if *verbose { |
| 52 | + log.Printf(fmt, args...) |
| 53 | + } |
| 54 | +} |
| 55 | + |
| 56 | +func main() { |
| 57 | + flag.Parse() |
| 58 | + fset := token.NewFileSet() |
| 59 | + |
| 60 | + pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0) |
| 61 | + if err != nil { |
| 62 | + log.Fatal(err) |
| 63 | + return |
| 64 | + } |
| 65 | + |
| 66 | + for pkgName, pkg := range pkgs { |
| 67 | + t := &templateData{ |
| 68 | + Year: time.Now().Year(), |
| 69 | + Filename: pkgName + fileSuffix, |
| 70 | + Package: pkgName, |
| 71 | + Imports: map[string]string{}, |
| 72 | + } |
| 73 | + for filename, f := range pkg.Files { |
| 74 | + logf("Processing %v...", filename) |
| 75 | + t.sourceFile = filename |
| 76 | + if err := t.processAST(f); err != nil { |
| 77 | + log.Fatal(err) |
| 78 | + } |
| 79 | + } |
| 80 | + if err := t.dump(); err != nil { |
| 81 | + log.Fatal(err) |
| 82 | + } |
| 83 | + } |
| 84 | + logf("Done.") |
| 85 | +} |
| 86 | + |
| 87 | +func (t *templateData) processAST(f *ast.File) error { |
| 88 | + for _, decl := range f.Decls { |
| 89 | + gd, ok := decl.(*ast.GenDecl) |
| 90 | + if !ok { |
| 91 | + continue |
| 92 | + } |
| 93 | + for _, spec := range gd.Specs { |
| 94 | + ts, ok := spec.(*ast.TypeSpec) |
| 95 | + if !ok { |
| 96 | + continue |
| 97 | + } |
| 98 | + st, ok := ts.Type.(*ast.StructType) |
| 99 | + if !ok { |
| 100 | + continue |
| 101 | + } |
| 102 | + for _, field := range st.Fields.List { |
| 103 | + se, ok := field.Type.(*ast.StarExpr) |
| 104 | + if len(field.Names) == 0 || !ok { |
| 105 | + continue |
| 106 | + } |
| 107 | + |
| 108 | + fieldName := field.Names[0] |
| 109 | + if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); blacklist[key] { |
| 110 | + logf("Method %v blacklisted; skipping.", key) |
| 111 | + continue |
| 112 | + } |
| 113 | + |
| 114 | + switch x := se.X.(type) { |
| 115 | + case *ast.ArrayType: |
| 116 | + t.addArrayType(x, ts.Name.String(), fieldName.String()) |
| 117 | + case *ast.Ident: |
| 118 | + t.addIdent(x, ts.Name.String(), fieldName.String()) |
| 119 | + case *ast.MapType: |
| 120 | + t.addMapType(x, ts.Name.String(), fieldName.String()) |
| 121 | + case *ast.SelectorExpr: |
| 122 | + t.addSelectorExpr(x, ts.Name.String(), fieldName.String()) |
| 123 | + default: |
| 124 | + logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x) |
| 125 | + } |
| 126 | + } |
| 127 | + } |
| 128 | + } |
| 129 | + return nil |
| 130 | +} |
| 131 | + |
| 132 | +func sourceFilter(fi os.FileInfo) bool { |
| 133 | + return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix) |
| 134 | +} |
| 135 | + |
| 136 | +func (t *templateData) dump() error { |
| 137 | + if len(t.Getters) == 0 { |
| 138 | + logf("No getters for %v; skipping.", t.Filename) |
| 139 | + return nil |
| 140 | + } |
| 141 | + |
| 142 | + // Sort getters by ReceiverType.FieldName |
| 143 | + sort.Sort(byName(t.Getters)) |
| 144 | + |
| 145 | + var buf bytes.Buffer |
| 146 | + if err := sourceTmpl.Execute(&buf, t); err != nil { |
| 147 | + return err |
| 148 | + } |
| 149 | + clean, err := format.Source(buf.Bytes()) |
| 150 | + if err != nil { |
| 151 | + return err |
| 152 | + } |
| 153 | + |
| 154 | + outFile := filepath.Join(filepath.Dir(t.sourceFile), t.Filename) |
| 155 | + logf("Writing %v...", outFile) |
| 156 | + return ioutil.WriteFile(outFile, clean, 0644) |
| 157 | +} |
| 158 | + |
| 159 | +func newGetter(receiverType, fieldName, fieldType, zeroValue string) *getter { |
| 160 | + return &getter{ |
| 161 | + sortVal: strings.ToLower(receiverType) + "." + strings.ToLower(fieldName), |
| 162 | + ReceiverVar: strings.ToLower(receiverType[:1]), |
| 163 | + ReceiverType: receiverType, |
| 164 | + FieldName: fieldName, |
| 165 | + FieldType: fieldType, |
| 166 | + ZeroValue: zeroValue, |
| 167 | + } |
| 168 | +} |
| 169 | + |
| 170 | +func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string) { |
| 171 | + var eltType string |
| 172 | + switch elt := x.Elt.(type) { |
| 173 | + case *ast.Ident: |
| 174 | + eltType = elt.String() |
| 175 | + default: |
| 176 | + logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt) |
| 177 | + return |
| 178 | + } |
| 179 | + |
| 180 | + t.Getters = append(t.Getters, newGetter(receiverType, fieldName, "[]"+eltType, "nil")) |
| 181 | +} |
| 182 | + |
| 183 | +func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) { |
| 184 | + var zeroValue string |
| 185 | + switch x.String() { |
| 186 | + case "int": |
| 187 | + zeroValue = "0" |
| 188 | + case "string": |
| 189 | + zeroValue = `""` |
| 190 | + case "bool": |
| 191 | + zeroValue = "false" |
| 192 | + case "Timestamp": |
| 193 | + zeroValue = "Timestamp{}" |
| 194 | + default: // other structs handled by their receivers directly. |
| 195 | + return |
| 196 | + } |
| 197 | + |
| 198 | + t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue)) |
| 199 | +} |
| 200 | + |
| 201 | +func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string) { |
| 202 | + var keyType string |
| 203 | + switch key := x.Key.(type) { |
| 204 | + case *ast.Ident: |
| 205 | + keyType = key.String() |
| 206 | + default: |
| 207 | + logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key) |
| 208 | + return |
| 209 | + } |
| 210 | + |
| 211 | + var valueType string |
| 212 | + switch value := x.Value.(type) { |
| 213 | + case *ast.Ident: |
| 214 | + valueType = value.String() |
| 215 | + default: |
| 216 | + logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value) |
| 217 | + return |
| 218 | + } |
| 219 | + |
| 220 | + fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType) |
| 221 | + zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType) |
| 222 | + t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue)) |
| 223 | +} |
| 224 | + |
| 225 | +func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) { |
| 226 | + if strings.ToLower(fieldName[:1]) == fieldName[:1] { // non-exported field |
| 227 | + return |
| 228 | + } |
| 229 | + |
| 230 | + var xX string |
| 231 | + if xx, ok := x.X.(*ast.Ident); ok { |
| 232 | + xX = xx.String() |
| 233 | + } |
| 234 | + |
| 235 | + switch xX { |
| 236 | + case "time", "json": |
| 237 | + if xX == "json" { |
| 238 | + t.Imports["encoding/json"] = "encoding/json" |
| 239 | + } else { |
| 240 | + t.Imports[xX] = xX |
| 241 | + } |
| 242 | + fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name) |
| 243 | + zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name) |
| 244 | + if xX == "time" && x.Sel.Name == "Duration" { |
| 245 | + zeroValue = "0" |
| 246 | + } |
| 247 | + t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue)) |
| 248 | + default: |
| 249 | + logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x) |
| 250 | + } |
| 251 | +} |
| 252 | + |
| 253 | +type templateData struct { |
| 254 | + sourceFile string |
| 255 | + Year int |
| 256 | + Filename string |
| 257 | + Package string |
| 258 | + Imports map[string]string |
| 259 | + Getters []*getter |
| 260 | +} |
| 261 | + |
| 262 | +type getter struct { |
| 263 | + sortVal string // lower-case version of "ReceiverType.FieldName" |
| 264 | + ReceiverVar string // the one-letter variable name to match the ReceiverType |
| 265 | + ReceiverType string |
| 266 | + FieldName string |
| 267 | + FieldType string |
| 268 | + ZeroValue string |
| 269 | +} |
| 270 | + |
| 271 | +type byName []*getter |
| 272 | + |
| 273 | +func (b byName) Len() int { return len(b) } |
| 274 | +func (b byName) Less(i, j int) bool { return b[i].sortVal < b[j].sortVal } |
| 275 | +func (b byName) Swap(i, j int) { b[i], b[j] = b[j], b[i] } |
| 276 | + |
| 277 | +const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved. |
| 278 | +// |
| 279 | +// Use of this source code is governed by a BSD-style |
| 280 | +// license that can be found in the LICENSE file. |
| 281 | +
|
| 282 | +// {{.Filename}} generated by gen-accessors; DO NOT EDIT |
| 283 | +
|
| 284 | +package {{.Package}} |
| 285 | +{{if .Imports}} |
| 286 | +import ( |
| 287 | + {{range .Imports}} |
| 288 | + "{{.}}"{{end}} |
| 289 | +) |
| 290 | +{{end}} |
| 291 | +{{range .Getters}} |
| 292 | +// Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise. |
| 293 | +func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} { |
| 294 | + if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil { |
| 295 | + return {{.ZeroValue}} |
| 296 | + } |
| 297 | + return *{{.ReceiverVar}}.{{.FieldName}} |
| 298 | +} |
| 299 | +{{end}} |
| 300 | +` |
0 commit comments