Skip to content

Commit 701aacd

Browse files
shabbyrobephilhofer
authored andcommitted
First pass of #67 - Let the shim functions return an error (#182)
* #67 - Let the shim functions return an error
1 parent b2b6a67 commit 701aacd

File tree

11 files changed

+218
-25
lines changed

11 files changed

+218
-25
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ _generated/*_gen.go
44
_generated/*_gen_test.go
55
msgp/defgen_test.go
66
msgp/cover.out
7-
*~
7+
*~
8+
*.coverprofile

_generated/convert.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package _generated
2+
3+
import "errors"
4+
5+
//go:generate msgp
6+
7+
//msgp:shim ConvertStringVal as:string using:fromConvertStringVal/toConvertStringVal mode:convert
8+
//msgp:ignore ConvertStringVal
9+
10+
func fromConvertStringVal(v ConvertStringVal) (string, error) {
11+
return string(v), nil
12+
}
13+
14+
func toConvertStringVal(s string) (ConvertStringVal, error) {
15+
return ConvertStringVal(s), nil
16+
}
17+
18+
type ConvertStringVal string
19+
20+
type ConvertString struct {
21+
String ConvertStringVal
22+
}
23+
24+
//msgp:shim ConvertIntfVal as:interface{} using:fromConvertIntfVal/toConvertIntfVal mode:convert
25+
//msgp:ignore ConvertIntfVal
26+
27+
func fromConvertIntfVal(v ConvertIntfVal) (interface{}, error) {
28+
return v.Test, nil
29+
}
30+
31+
func toConvertIntfVal(s interface{}) (ConvertIntfVal, error) {
32+
return ConvertIntfVal{Test: s.(string)}, nil
33+
}
34+
35+
type ConvertIntfVal struct {
36+
Test string
37+
}
38+
39+
type ConvertIntf struct {
40+
Intf ConvertIntfVal
41+
}
42+
43+
//msgp:shim ConvertErrVal as:string using:fromConvertErrVal/toConvertErrVal mode:convert
44+
//msgp:ignore ConvertErrVal
45+
46+
var (
47+
errConvertFrom = errors.New("error: convert from")
48+
errConvertTo = errors.New("error: convert to")
49+
)
50+
51+
const (
52+
fromFailStr = "fromfail"
53+
toFailStr = "tofail"
54+
)
55+
56+
func fromConvertErrVal(v ConvertErrVal) (string, error) {
57+
s := string(v)
58+
if s == fromFailStr {
59+
return "", errConvertFrom
60+
}
61+
return s, nil
62+
}
63+
64+
func toConvertErrVal(s string) (ConvertErrVal, error) {
65+
if s == toFailStr {
66+
return ConvertErrVal(""), errConvertTo
67+
}
68+
return ConvertErrVal(s), nil
69+
}
70+
71+
type ConvertErrVal string
72+
73+
type ConvertErr struct {
74+
Err ConvertErrVal
75+
}

_generated/convert_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package _generated
2+
3+
import (
4+
"bytes"
5+
"testing"
6+
7+
"github.com/tinylib/msgp/msgp"
8+
)
9+
10+
func TestConvertFromEncodeError(t *testing.T) {
11+
e := ConvertErr{ConvertErrVal(fromFailStr)}
12+
var buf bytes.Buffer
13+
w := msgp.NewWriter(&buf)
14+
err := e.EncodeMsg(w)
15+
if err != errConvertFrom {
16+
t.Fatalf("expected conversion error, found %v", err.Error())
17+
}
18+
}
19+
20+
func TestConvertToEncodeError(t *testing.T) {
21+
var in, out ConvertErr
22+
in = ConvertErr{ConvertErrVal(toFailStr)}
23+
var buf bytes.Buffer
24+
w := msgp.NewWriter(&buf)
25+
err := in.EncodeMsg(w)
26+
if err != nil {
27+
t.FailNow()
28+
}
29+
w.Flush()
30+
31+
r := msgp.NewReader(&buf)
32+
err = (&out).DecodeMsg(r)
33+
if err != errConvertTo {
34+
t.Fatalf("expected conversion error, found %v", err.Error())
35+
}
36+
}
37+
38+
func TestConvertFromMarshalError(t *testing.T) {
39+
e := ConvertErr{ConvertErrVal(fromFailStr)}
40+
var b []byte
41+
_, err := e.MarshalMsg(b)
42+
if err != errConvertFrom {
43+
t.Fatalf("expected conversion error, found %v", err.Error())
44+
}
45+
}
46+
47+
func TestConvertToMarshalError(t *testing.T) {
48+
var in, out ConvertErr
49+
in = ConvertErr{ConvertErrVal(toFailStr)}
50+
b, err := in.MarshalMsg(nil)
51+
if err != nil {
52+
t.FailNow()
53+
}
54+
55+
_, err = (&out).UnmarshalMsg(b)
56+
if err != errConvertTo {
57+
t.Fatalf("expected conversion error, found %v", err.Error())
58+
}
59+
}

_generated/def.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
package _generated
22

33
import (
4-
"github.com/tinylib/msgp/msgp"
54
"os"
65
"time"
6+
7+
"github.com/tinylib/msgp/msgp"
78
)
89

910
//go:generate msgp -o generated.go

gen/decode.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,17 @@ func (d *decodeGen) gBase(b *BaseElem) {
143143
d.p.printf("\n%s, err = dc.Read%s()", vname, bname)
144144
}
145145
}
146+
d.p.print(errcheck)
146147

147148
// close block for 'tmp'
148149
if b.Convert {
149-
d.p.printf("\n%s = %s(%s)\n}", vname, b.FromBase(), tmp)
150+
if b.ShimMode == Cast {
151+
d.p.printf("\n%s = %s(%s)\n}", vname, b.FromBase(), tmp)
152+
} else {
153+
d.p.printf("\n%s, err = %s(%s)\n}", vname, b.FromBase(), tmp)
154+
d.p.print(errcheck)
155+
}
150156
}
151-
152-
d.p.print(errcheck)
153157
}
154158

155159
func (d *decodeGen) gMap(m *Map) {

gen/elem.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,11 +406,19 @@ type StructField struct {
406406
FieldElem Elem // the field type
407407
}
408408

409+
type ShimMode int
410+
411+
const (
412+
Cast ShimMode = iota
413+
Convert
414+
)
415+
409416
// BaseElem is an element that
410417
// can be represented by a primitive
411418
// MessagePack type.
412419
type BaseElem struct {
413420
common
421+
ShimMode ShimMode // Method used to shim
414422
ShimToBase string // shim to base type, or empty
415423
ShimFromBase string // shim from base type, or empty
416424
Value Primitive // Type of element

gen/encode.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ package gen
22

33
import (
44
"fmt"
5-
"github.com/tinylib/msgp/msgp"
65
"io"
6+
7+
"github.com/tinylib/msgp/msgp"
78
)
89

910
func encode(w io.Writer) *encodeGen {
@@ -172,7 +173,14 @@ func (e *encodeGen) gBase(b *BaseElem) {
172173
e.fuseHook()
173174
vname := b.Varname()
174175
if b.Convert {
175-
vname = tobaseConvert(b)
176+
if b.ShimMode == Cast {
177+
vname = tobaseConvert(b)
178+
} else {
179+
vname = randIdent()
180+
e.p.printf("\nvar %s %s", vname, b.BaseType())
181+
e.p.printf("\n%s, err = %s", vname, tobaseConvert(b))
182+
e.p.printf(errcheck)
183+
}
176184
}
177185

178186
if b.Value == IDENT { // unknown identity

gen/marshal.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ package gen
22

33
import (
44
"fmt"
5-
"github.com/tinylib/msgp/msgp"
65
"io"
6+
7+
"github.com/tinylib/msgp/msgp"
78
)
89

910
func marshal(w io.Writer) *marshalGen {
@@ -177,7 +178,14 @@ func (m *marshalGen) gBase(b *BaseElem) {
177178
vname := b.Varname()
178179

179180
if b.Convert {
180-
vname = tobaseConvert(b)
181+
if b.ShimMode == Cast {
182+
vname = tobaseConvert(b)
183+
} else {
184+
vname = randIdent()
185+
m.p.printf("\nvar %s %s", vname, b.BaseType())
186+
m.p.printf("\n%s, err = %s", vname, tobaseConvert(b))
187+
m.p.printf(errcheck)
188+
}
181189
}
182190

183191
var echeck bool

gen/size.go

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ package gen
22

33
import (
44
"fmt"
5-
"github.com/tinylib/msgp/msgp"
65
"io"
76
"strconv"
7+
8+
"github.com/tinylib/msgp/msgp"
89
)
910

1011
type sizeState uint8
@@ -183,7 +184,20 @@ func (s *sizeGen) gBase(b *BaseElem) {
183184
if !s.p.ok() {
184185
return
185186
}
186-
s.addConstant(basesizeExpr(b))
187+
if b.Convert && b.ShimMode == Convert {
188+
s.state = add
189+
vname := randIdent()
190+
s.p.printf("\nvar %s %s", vname, b.BaseType())
191+
s.p.printf("\ns += %s", basesizeExpr(b.Value, vname, b.BaseName()))
192+
s.state = expr
193+
194+
} else {
195+
vname := b.Varname()
196+
if b.Convert {
197+
vname = tobaseConvert(b)
198+
}
199+
s.addConstant(basesizeExpr(b.Value, vname, b.BaseName()))
200+
}
187201
}
188202

189203
// returns "len(slice)"
@@ -250,12 +264,8 @@ func fixedsizeExpr(e Elem) (string, bool) {
250264
}
251265

252266
// print size expression of a variable name
253-
func basesizeExpr(b *BaseElem) string {
254-
vname := b.Varname()
255-
if b.Convert {
256-
vname = tobaseConvert(b)
257-
}
258-
switch b.Value {
267+
func basesizeExpr(value Primitive, vname, basename string) string {
268+
switch value {
259269
case Ext:
260270
return "msgp.ExtensionPrefixSize + " + stripRef(vname) + ".Len()"
261271
case Intf:
@@ -267,6 +277,6 @@ func basesizeExpr(b *BaseElem) string {
267277
case String:
268278
return "msgp.StringPrefixSize + len(" + vname + ")"
269279
default:
270-
return builtinSize(b.BaseName())
280+
return builtinSize(basename)
271281
}
272282
}

gen/unmarshal.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,18 @@ func (u *unmarshalGen) gBase(b *BaseElem) {
132132
default:
133133
u.p.printf("\n%s, bts, err = msgp.Read%sBytes(bts)", refname, b.BaseName())
134134
}
135+
u.p.print(errcheck)
136+
135137
if b.Convert {
136138
// close 'tmp' block
137-
u.p.printf("\n%s = %s(%s)\n}", b.Varname(), b.FromBase(), refname)
139+
if b.ShimMode == Cast {
140+
u.p.printf("\n%s = %s(%s)\n", b.Varname(), b.FromBase(), refname)
141+
} else {
142+
u.p.printf("\n%s, err = %s(%s)", b.Varname(), b.FromBase(), refname)
143+
u.p.print(errcheck)
144+
}
145+
u.p.printf("}")
138146
}
139-
140-
u.p.print(errcheck)
141147
}
142148

143149
func (u *unmarshalGen) gArray(a *Array) {

parse/directives.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ package parse
22

33
import (
44
"fmt"
5-
"github.com/tinylib/msgp/gen"
65
"go/ast"
76
"strings"
7+
8+
"github.com/tinylib/msgp/gen"
89
)
910

1011
const linePrefix = "//msgp:"
@@ -52,10 +53,10 @@ func yieldComments(c []*ast.CommentGroup) []string {
5253
return out
5354
}
5455

55-
//msgp:shim {Type} as:{Newtype} using:{toFunc/fromFunc}
56+
//msgp:shim {Type} as:{Newtype} using:{toFunc/fromFunc} mode:{Mode}
5657
func applyShim(text []string, f *FileSet) error {
57-
if len(text) != 4 {
58-
return fmt.Errorf("shim directive should have 3 arguments; found %d", len(text)-1)
58+
if len(text) < 4 || len(text) > 5 {
59+
return fmt.Errorf("shim directive should have 3 or 4 arguments; found %d", len(text)-1)
5960
}
6061

6162
name := text[1]
@@ -76,6 +77,18 @@ func applyShim(text []string, f *FileSet) error {
7677
be.ShimToBase = methods[0]
7778
be.ShimFromBase = methods[1]
7879

80+
if len(text) == 5 {
81+
modestr := strings.TrimPrefix(strings.TrimSpace(text[4]), "mode:") // parse mode::{mode}
82+
switch modestr {
83+
case "cast":
84+
be.ShimMode = gen.Cast
85+
case "convert":
86+
be.ShimMode = gen.Convert
87+
default:
88+
return fmt.Errorf("invalid shim mode; found %s, expected 'cast' or 'convert", modestr)
89+
}
90+
}
91+
7992
infof("%s -> %s\n", name, be.Value.String())
8093
f.findShim(name, be)
8194

0 commit comments

Comments
 (0)