Skip to content

Commit 4ff26d9

Browse files
authored
Add pointer receiver directive (#357)
Adds `//msgp:pointer` file directive. This will generate all functions with pointer receivers. Tested with base types. Not tested with various other directives, so there may be quirks. Fixes #332
1 parent f80292a commit 4ff26d9

File tree

8 files changed

+255
-8
lines changed

8 files changed

+255
-8
lines changed

_generated/pointer.go

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
package _generated
2+
3+
import (
4+
"fmt"
5+
"time"
6+
7+
"github.com/tinylib/msgp/msgp"
8+
)
9+
10+
//go:generate msgp $GOFILE$
11+
12+
// Generate only pointer receivers:
13+
14+
//msgp:pointer
15+
16+
var mustNoInterf = []interface{}{
17+
Pointer0{},
18+
NamedBoolPointer(true),
19+
NamedIntPointer(0),
20+
NamedFloat64Pointer(0),
21+
NamedStringPointer(""),
22+
NamedMapStructPointer(nil),
23+
NamedMapStructPointer2(nil),
24+
NamedMapStringPointer(nil),
25+
NamedMapStringPointer2(nil),
26+
EmbeddableStructPointer{},
27+
EmbeddableStruct2Pointer{},
28+
PointerHalfFull{},
29+
PointerNoName{},
30+
}
31+
32+
var mustHaveInterf = []interface{}{
33+
&Pointer0{},
34+
mustPtr(NamedBoolPointer(true)),
35+
mustPtr(NamedIntPointer(0)),
36+
mustPtr(NamedFloat64Pointer(0)),
37+
mustPtr(NamedStringPointer("")),
38+
mustPtr(NamedMapStructPointer(nil)),
39+
mustPtr(NamedMapStructPointer2(nil)),
40+
mustPtr(NamedMapStringPointer(nil)),
41+
mustPtr(NamedMapStringPointer2(nil)),
42+
&EmbeddableStructPointer{},
43+
&EmbeddableStruct2Pointer{},
44+
&PointerHalfFull{},
45+
&PointerNoName{},
46+
}
47+
48+
func mustPtr[T any](v T) *T {
49+
return &v
50+
}
51+
52+
func init() {
53+
for _, v := range mustNoInterf {
54+
if _, ok := v.(msgp.Marshaler); ok {
55+
panic(fmt.Sprintf("type %T supports interface", v))
56+
}
57+
if _, ok := v.(msgp.Encodable); ok {
58+
panic(fmt.Sprintf("type %T supports interface", v))
59+
}
60+
}
61+
for _, v := range mustHaveInterf {
62+
if _, ok := v.(msgp.Marshaler); !ok {
63+
panic(fmt.Sprintf("type %T does not support interface", v))
64+
}
65+
if _, ok := v.(msgp.Encodable); !ok {
66+
panic(fmt.Sprintf("type %T does not support interface", v))
67+
}
68+
}
69+
}
70+
71+
type Pointer0 struct {
72+
ABool bool `msg:"abool"`
73+
AInt int `msg:"aint"`
74+
AInt8 int8 `msg:"aint8"`
75+
AInt16 int16 `msg:"aint16"`
76+
AInt32 int32 `msg:"aint32"`
77+
AInt64 int64 `msg:"aint64"`
78+
AUint uint `msg:"auint"`
79+
AUint8 uint8 `msg:"auint8"`
80+
AUint16 uint16 `msg:"auint16"`
81+
AUint32 uint32 `msg:"auint32"`
82+
AUint64 uint64 `msg:"auint64"`
83+
AFloat32 float32 `msg:"afloat32"`
84+
AFloat64 float64 `msg:"afloat64"`
85+
AComplex64 complex64 `msg:"acomplex64"`
86+
AComplex128 complex128 `msg:"acomplex128"`
87+
88+
ANamedBool bool `msg:"anamedbool"`
89+
ANamedInt int `msg:"anamedint"`
90+
ANamedFloat64 float64 `msg:"anamedfloat64"`
91+
92+
AMapStrStr map[string]string `msg:"amapstrstr"`
93+
94+
APtrNamedStr *NamedString `msg:"aptrnamedstr"`
95+
96+
AString string `msg:"astring"`
97+
ANamedString string `msg:"anamedstring"`
98+
AByteSlice []byte `msg:"abyteslice"`
99+
100+
ASliceString []string `msg:"aslicestring"`
101+
ASliceNamedString []NamedString `msg:"aslicenamedstring"`
102+
103+
ANamedStruct NamedStruct `msg:"anamedstruct"`
104+
APtrNamedStruct *NamedStruct `msg:"aptrnamedstruct"`
105+
106+
AUnnamedStruct struct {
107+
A string `msg:"a"`
108+
} `msg:"aunnamedstruct"` // omitempty not supported on unnamed struct
109+
110+
EmbeddableStruct `msg:",flatten"` // embed flat
111+
112+
EmbeddableStruct2 `msg:"embeddablestruct2"` // embed non-flat
113+
114+
AArrayInt [5]int `msg:"aarrayint"` // not supported
115+
116+
ATime time.Time `msg:"atime"`
117+
}
118+
119+
type (
120+
NamedBoolPointer bool
121+
NamedIntPointer int
122+
NamedFloat64Pointer float64
123+
NamedStringPointer string
124+
NamedMapStructPointer map[string]Pointer0
125+
NamedMapStructPointer2 map[string]*Pointer0
126+
NamedMapStringPointer map[string]NamedStringPointer
127+
NamedMapStringPointer2 map[string]*NamedStringPointer
128+
)
129+
130+
type EmbeddableStructPointer struct {
131+
SomeEmbed string `msg:"someembed"`
132+
}
133+
134+
type EmbeddableStruct2Pointer struct {
135+
SomeEmbed2 string `msg:"someembed2"`
136+
}
137+
138+
type NamedStructPointer struct {
139+
A string `msg:"a"`
140+
B string `msg:"b"`
141+
}
142+
143+
type PointerHalfFull struct {
144+
Field00 string `msg:"field00"`
145+
Field01 string `msg:"field01"`
146+
Field02 string `msg:"field02"`
147+
Field03 string `msg:"field03"`
148+
}
149+
150+
type PointerNoName struct {
151+
ABool bool `msg:""`
152+
AInt int `msg:""`
153+
AInt8 int8 `msg:""`
154+
AInt16 int16 `msg:""`
155+
AInt32 int32 `msg:""`
156+
AInt64 int64 `msg:""`
157+
AUint uint `msg:""`
158+
AUint8 uint8 `msg:""`
159+
AUint16 uint16 `msg:""`
160+
AUint32 uint32 `msg:""`
161+
AUint64 uint64 `msg:""`
162+
AFloat32 float32 `msg:""`
163+
AFloat64 float64 `msg:""`
164+
AComplex64 complex64 `msg:""`
165+
AComplex128 complex128 `msg:""`
166+
167+
ANamedBool bool `msg:""`
168+
ANamedInt int `msg:""`
169+
ANamedFloat64 float64 `msg:""`
170+
171+
AMapStrF map[string]NamedFloat64Pointer `msg:""`
172+
AMapStrStruct map[string]PointerHalfFull `msg:""`
173+
AMapStrStruct2 map[string]*PointerHalfFull `msg:""`
174+
175+
APtrNamedStr *NamedStringPointer `msg:""`
176+
177+
AString string `msg:""`
178+
AByteSlice []byte `msg:""`
179+
180+
ASliceString []string `msg:""`
181+
ASliceNamedString []NamedStringPointer `msg:""`
182+
183+
ANamedStruct NamedStructPointer `msg:""`
184+
APtrNamedStruct *NamedStructPointer `msg:""`
185+
186+
AUnnamedStruct struct {
187+
A string `msg:""`
188+
} `msg:""` // omitempty not supported on unnamed struct
189+
190+
EmbeddableStructPointer `msg:",flatten"` // embed flat
191+
192+
EmbeddableStruct2Pointer `msg:""` // embed non-flat
193+
194+
AArrayInt [5]int `msg:""` // not supported
195+
196+
ATime time.Time `msg:""`
197+
ADur time.Duration `msg:""`
198+
}

gen/elem.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,22 @@ var builtins = map[string]struct{}{
134134
}
135135

136136
// common data/methods for every Elem
137-
type common struct{ vname, alias string }
137+
type common struct {
138+
vname, alias string
139+
ptrRcv bool
140+
}
138141

139142
func (c *common) SetVarname(s string) { c.vname = s }
140143
func (c *common) Varname() string { return c.vname }
141144
func (c *common) Alias(typ string) { c.alias = typ }
142145
func (c *common) hidden() {}
143146
func (c *common) AllowNil() bool { return false }
147+
func (c *common) AlwaysPtr(set *bool) bool {
148+
if c != nil && set != nil {
149+
c.ptrRcv = *set
150+
}
151+
return c.ptrRcv
152+
}
144153

145154
func IsPrintable(e Elem) bool {
146155
if be, ok := e.(*BaseElem); ok && !be.Printable() {
@@ -191,6 +200,9 @@ type Elem interface {
191200
// This is true for slices and maps.
192201
AllowNil() bool
193202

203+
// AlwaysPtr will return true if receiver should always be a pointer.
204+
AlwaysPtr(set *bool) bool
205+
194206
// IfZeroExpr returns the expression to compare to an empty value
195207
// for this type, per the rules of the `omitempty` feature.
196208
// It is meant to be used in an if statement

gen/encode.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,16 @@ func (e *encodeGen) Execute(p Elem) error {
6262
e.ctx = &Context{}
6363

6464
e.p.comment("EncodeMsg implements msgp.Encodable")
65-
66-
e.p.printf("\nfunc (%s %s) EncodeMsg(en *msgp.Writer) (err error) {", p.Varname(), imutMethodReceiver(p))
65+
rcv := imutMethodReceiver(p)
66+
ogVar := p.Varname()
67+
if p.AlwaysPtr(nil) {
68+
rcv = methodReceiver(p)
69+
}
70+
e.p.printf("\nfunc (%s %s) EncodeMsg(en *msgp.Writer) (err error) {", ogVar, rcv)
6771
next(e, p)
72+
if p.AlwaysPtr(nil) {
73+
p.SetVarname(ogVar)
74+
}
6875
e.p.nakedReturn()
6976
return e.p.err
7077
}

gen/marshal.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,18 @@ func (m *marshalGen) Execute(p Elem) error {
4747
// calling methodReceiver so
4848
// that z.Msgsize() is printed correctly
4949
c := p.Varname()
50-
51-
m.p.printf("\nfunc (%s %s) MarshalMsg(b []byte) (o []byte, err error) {", p.Varname(), imutMethodReceiver(p))
50+
rcv := imutMethodReceiver(p)
51+
ogVar := p.Varname()
52+
if p.AlwaysPtr(nil) {
53+
rcv = methodReceiver(p)
54+
}
55+
m.p.printf("\nfunc (%s %s) MarshalMsg(b []byte) (o []byte, err error) {", ogVar, rcv)
5256
m.p.printf("\no = msgp.Require(b, %s.Msgsize())", c)
5357
next(m, p)
58+
if p.AlwaysPtr(nil) {
59+
p.SetVarname(ogVar)
60+
}
61+
5462
m.p.nakedReturn()
5563
return m.p.err
5664
}
@@ -280,7 +288,6 @@ func (m *marshalGen) gBase(b *BaseElem) {
280288
}
281289
m.fuseHook()
282290
vname := b.Varname()
283-
284291
if b.Convert {
285292
if b.ShimMode == Cast {
286293
vname = tobaseConvert(b)

gen/size.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,17 @@ func (s *sizeGen) Execute(p Elem) error {
8686

8787
s.p.comment("Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message")
8888

89-
s.p.printf("\nfunc (%s %s) Msgsize() (s int) {", p.Varname(), imutMethodReceiver(p))
89+
rcv := imutMethodReceiver(p)
90+
ogVar := p.Varname()
91+
if p.AlwaysPtr(nil) {
92+
rcv = methodReceiver(p)
93+
}
94+
s.p.printf("\nfunc (%s %s) Msgsize() (s int) {", ogVar, rcv)
9095
s.state = assign
9196
next(s, p)
97+
if p.AlwaysPtr(nil) {
98+
p.SetVarname(ogVar)
99+
}
92100
s.p.nakedReturn()
93101
return s.p.err
94102
}

parse/directives.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ var directives = map[string]directive{
3434
// to add an early directive, define a func([]string, *FileSet) error
3535
// and then add it to this list.
3636
var earlyDirectives = map[string]directive{
37-
"tag": tag,
37+
"tag": tag,
38+
"pointer": pointer,
3839
}
3940

4041
var passDirectives = map[string]passDirective{
@@ -120,6 +121,7 @@ func replace(text []string, f *FileSet) error {
120121
return err
121122
}
122123
e := f.parseExpr(expr)
124+
e.AlwaysPtr(&f.pointerRcv)
123125

124126
if be, ok := e.(*gen.BaseElem); ok {
125127
be.Convert = true
@@ -178,3 +180,9 @@ func tag(text []string, f *FileSet) error {
178180
f.tagName = strings.TrimSpace(text[1])
179181
return nil
180182
}
183+
184+
//msgp:pointer
185+
func pointer(text []string, f *FileSet) error {
186+
f.pointerRcv = true
187+
return nil
188+
}

parse/getast.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ type FileSet struct {
2222
Directives []string // raw preprocessor directives
2323
Imports []*ast.ImportSpec // imports
2424
tagName string // tag to read field names from
25+
pointerRcv bool // generate with pointer receivers.
2526
}
2627

2728
// File parses a file at the relative path
@@ -199,6 +200,7 @@ parse:
199200
popstate()
200201
continue parse
201202
}
203+
el.AlwaysPtr(&f.pointerRcv)
202204
// push unresolved identities into
203205
// the graph of links and resolve after
204206
// we've handled every possible named type.

printer/print.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ func PrintFile(file string, f *parse.FileSet, mode gen.Method) error {
4242
}
4343
err = <-res
4444
if err != nil {
45+
os.WriteFile(file+".broken", out.Bytes(), os.ModePerm)
46+
if Logf != nil {
47+
Logf("Error: %s. Wrote broken output to %s\n", err, file+".broken")
48+
}
49+
4550
return err
4651
}
4752
return nil

0 commit comments

Comments
 (0)