diff --git a/.gitignore b/.gitignore index 17f1ccdc..b77b56ee 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ _generated/*_gen.go _generated/*_gen_test.go msgp/defgen_test.go msgp/cover.out -*~ \ No newline at end of file +*~ +*.coverprofile diff --git a/_generated/convert.go b/_generated/convert.go new file mode 100644 index 00000000..b983c2bc --- /dev/null +++ b/_generated/convert.go @@ -0,0 +1,75 @@ +package _generated + +import "errors" + +//go:generate msgp + +//msgp:shim ConvertStringVal as:string using:fromConvertStringVal/toConvertStringVal mode:convert +//msgp:ignore ConvertStringVal + +func fromConvertStringVal(v ConvertStringVal) (string, error) { + return string(v), nil +} + +func toConvertStringVal(s string) (ConvertStringVal, error) { + return ConvertStringVal(s), nil +} + +type ConvertStringVal string + +type ConvertString struct { + String ConvertStringVal +} + +//msgp:shim ConvertIntfVal as:interface{} using:fromConvertIntfVal/toConvertIntfVal mode:convert +//msgp:ignore ConvertIntfVal + +func fromConvertIntfVal(v ConvertIntfVal) (interface{}, error) { + return v.Test, nil +} + +func toConvertIntfVal(s interface{}) (ConvertIntfVal, error) { + return ConvertIntfVal{Test: s.(string)}, nil +} + +type ConvertIntfVal struct { + Test string +} + +type ConvertIntf struct { + Intf ConvertIntfVal +} + +//msgp:shim ConvertErrVal as:string using:fromConvertErrVal/toConvertErrVal mode:convert +//msgp:ignore ConvertErrVal + +var ( + errConvertFrom = errors.New("error: convert from") + errConvertTo = errors.New("error: convert to") +) + +const ( + fromFailStr = "fromfail" + toFailStr = "tofail" +) + +func fromConvertErrVal(v ConvertErrVal) (string, error) { + s := string(v) + if s == fromFailStr { + return "", errConvertFrom + } + return s, nil +} + +func toConvertErrVal(s string) (ConvertErrVal, error) { + if s == toFailStr { + return ConvertErrVal(""), errConvertTo + } + return ConvertErrVal(s), nil +} + +type ConvertErrVal string + +type ConvertErr struct { + Err ConvertErrVal +} diff --git a/_generated/convert_test.go b/_generated/convert_test.go new file mode 100644 index 00000000..5d22469a --- /dev/null +++ b/_generated/convert_test.go @@ -0,0 +1,59 @@ +package _generated + +import ( + "bytes" + "testing" + + "github.com/tinylib/msgp/msgp" +) + +func TestConvertFromEncodeError(t *testing.T) { + e := ConvertErr{ConvertErrVal(fromFailStr)} + var buf bytes.Buffer + w := msgp.NewWriter(&buf) + err := e.EncodeMsg(w) + if err != errConvertFrom { + t.Fatalf("expected conversion error, found %v", err.Error()) + } +} + +func TestConvertToEncodeError(t *testing.T) { + var in, out ConvertErr + in = ConvertErr{ConvertErrVal(toFailStr)} + var buf bytes.Buffer + w := msgp.NewWriter(&buf) + err := in.EncodeMsg(w) + if err != nil { + t.FailNow() + } + w.Flush() + + r := msgp.NewReader(&buf) + err = (&out).DecodeMsg(r) + if err != errConvertTo { + t.Fatalf("expected conversion error, found %v", err.Error()) + } +} + +func TestConvertFromMarshalError(t *testing.T) { + e := ConvertErr{ConvertErrVal(fromFailStr)} + var b []byte + _, err := e.MarshalMsg(b) + if err != errConvertFrom { + t.Fatalf("expected conversion error, found %v", err.Error()) + } +} + +func TestConvertToMarshalError(t *testing.T) { + var in, out ConvertErr + in = ConvertErr{ConvertErrVal(toFailStr)} + b, err := in.MarshalMsg(nil) + if err != nil { + t.FailNow() + } + + _, err = (&out).UnmarshalMsg(b) + if err != errConvertTo { + t.Fatalf("expected conversion error, found %v", err.Error()) + } +} diff --git a/_generated/def.go b/_generated/def.go index 5579b256..6672c472 100644 --- a/_generated/def.go +++ b/_generated/def.go @@ -1,9 +1,10 @@ package _generated import ( - "github.com/tinylib/msgp/msgp" "os" "time" + + "github.com/tinylib/msgp/msgp" ) //go:generate msgp -o generated.go diff --git a/gen/decode.go b/gen/decode.go index f3907601..7a674bc0 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -143,13 +143,17 @@ func (d *decodeGen) gBase(b *BaseElem) { d.p.printf("\n%s, err = dc.Read%s()", vname, bname) } } + d.p.print(errcheck) // close block for 'tmp' if b.Convert { - d.p.printf("\n%s = %s(%s)\n}", vname, b.FromBase(), tmp) + if b.ShimMode == Cast { + d.p.printf("\n%s = %s(%s)\n}", vname, b.FromBase(), tmp) + } else { + d.p.printf("\n%s, err = %s(%s)\n}", vname, b.FromBase(), tmp) + d.p.print(errcheck) + } } - - d.p.print(errcheck) } func (d *decodeGen) gMap(m *Map) { diff --git a/gen/elem.go b/gen/elem.go index 719df2e8..df7461e1 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -406,11 +406,19 @@ type StructField struct { FieldElem Elem // the field type } +type ShimMode int + +const ( + Cast ShimMode = iota + Convert +) + // BaseElem is an element that // can be represented by a primitive // MessagePack type. type BaseElem struct { common + ShimMode ShimMode // Method used to shim ShimToBase string // shim to base type, or empty ShimFromBase string // shim from base type, or empty Value Primitive // Type of element diff --git a/gen/encode.go b/gen/encode.go index a224a594..6585ee0f 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -2,8 +2,9 @@ package gen import ( "fmt" - "github.com/tinylib/msgp/msgp" "io" + + "github.com/tinylib/msgp/msgp" ) func encode(w io.Writer) *encodeGen { @@ -172,7 +173,14 @@ func (e *encodeGen) gBase(b *BaseElem) { e.fuseHook() vname := b.Varname() if b.Convert { - vname = tobaseConvert(b) + if b.ShimMode == Cast { + vname = tobaseConvert(b) + } else { + vname = randIdent() + e.p.printf("\nvar %s %s", vname, b.BaseType()) + e.p.printf("\n%s, err = %s", vname, tobaseConvert(b)) + e.p.printf(errcheck) + } } if b.Value == IDENT { // unknown identity diff --git a/gen/marshal.go b/gen/marshal.go index 90eccc22..922beeae 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -2,8 +2,9 @@ package gen import ( "fmt" - "github.com/tinylib/msgp/msgp" "io" + + "github.com/tinylib/msgp/msgp" ) func marshal(w io.Writer) *marshalGen { @@ -177,7 +178,14 @@ func (m *marshalGen) gBase(b *BaseElem) { vname := b.Varname() if b.Convert { - vname = tobaseConvert(b) + if b.ShimMode == Cast { + vname = tobaseConvert(b) + } else { + vname = randIdent() + m.p.printf("\nvar %s %s", vname, b.BaseType()) + m.p.printf("\n%s, err = %s", vname, tobaseConvert(b)) + m.p.printf(errcheck) + } } var echeck bool diff --git a/gen/size.go b/gen/size.go index 3e636e47..7563d935 100644 --- a/gen/size.go +++ b/gen/size.go @@ -2,9 +2,10 @@ package gen import ( "fmt" - "github.com/tinylib/msgp/msgp" "io" "strconv" + + "github.com/tinylib/msgp/msgp" ) type sizeState uint8 @@ -183,7 +184,20 @@ func (s *sizeGen) gBase(b *BaseElem) { if !s.p.ok() { return } - s.addConstant(basesizeExpr(b)) + if b.Convert && b.ShimMode == Convert { + s.state = add + vname := randIdent() + s.p.printf("\nvar %s %s", vname, b.BaseType()) + s.p.printf("\ns += %s", basesizeExpr(b.Value, vname, b.BaseName())) + s.state = expr + + } else { + vname := b.Varname() + if b.Convert { + vname = tobaseConvert(b) + } + s.addConstant(basesizeExpr(b.Value, vname, b.BaseName())) + } } // returns "len(slice)" @@ -250,12 +264,8 @@ func fixedsizeExpr(e Elem) (string, bool) { } // print size expression of a variable name -func basesizeExpr(b *BaseElem) string { - vname := b.Varname() - if b.Convert { - vname = tobaseConvert(b) - } - switch b.Value { +func basesizeExpr(value Primitive, vname, basename string) string { + switch value { case Ext: return "msgp.ExtensionPrefixSize + " + stripRef(vname) + ".Len()" case Intf: @@ -267,6 +277,6 @@ func basesizeExpr(b *BaseElem) string { case String: return "msgp.StringPrefixSize + len(" + vname + ")" default: - return builtinSize(b.BaseName()) + return builtinSize(basename) } } diff --git a/gen/unmarshal.go b/gen/unmarshal.go index 7b950313..24d8c111 100644 --- a/gen/unmarshal.go +++ b/gen/unmarshal.go @@ -132,12 +132,18 @@ func (u *unmarshalGen) gBase(b *BaseElem) { default: u.p.printf("\n%s, bts, err = msgp.Read%sBytes(bts)", refname, b.BaseName()) } + u.p.print(errcheck) + if b.Convert { // close 'tmp' block - u.p.printf("\n%s = %s(%s)\n}", b.Varname(), b.FromBase(), refname) + if b.ShimMode == Cast { + u.p.printf("\n%s = %s(%s)\n", b.Varname(), b.FromBase(), refname) + } else { + u.p.printf("\n%s, err = %s(%s)", b.Varname(), b.FromBase(), refname) + u.p.print(errcheck) + } + u.p.printf("}") } - - u.p.print(errcheck) } func (u *unmarshalGen) gArray(a *Array) { diff --git a/parse/directives.go b/parse/directives.go index fb78974b..73e441ef 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -2,9 +2,10 @@ package parse import ( "fmt" - "github.com/tinylib/msgp/gen" "go/ast" "strings" + + "github.com/tinylib/msgp/gen" ) const linePrefix = "//msgp:" @@ -52,10 +53,10 @@ func yieldComments(c []*ast.CommentGroup) []string { return out } -//msgp:shim {Type} as:{Newtype} using:{toFunc/fromFunc} +//msgp:shim {Type} as:{Newtype} using:{toFunc/fromFunc} mode:{Mode} func applyShim(text []string, f *FileSet) error { - if len(text) != 4 { - return fmt.Errorf("shim directive should have 3 arguments; found %d", len(text)-1) + if len(text) < 4 || len(text) > 5 { + return fmt.Errorf("shim directive should have 3 or 4 arguments; found %d", len(text)-1) } name := text[1] @@ -76,6 +77,18 @@ func applyShim(text []string, f *FileSet) error { be.ShimToBase = methods[0] be.ShimFromBase = methods[1] + if len(text) == 5 { + modestr := strings.TrimPrefix(strings.TrimSpace(text[4]), "mode:") // parse mode::{mode} + switch modestr { + case "cast": + be.ShimMode = gen.Cast + case "convert": + be.ShimMode = gen.Convert + default: + return fmt.Errorf("invalid shim mode; found %s, expected 'cast' or 'convert", modestr) + } + } + infof("%s -> %s\n", name, be.Value.String()) f.findShim(name, be)