Skip to content

First pass of #67 - Let the shim functions return an error #182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ _generated/*_gen.go
_generated/*_gen_test.go
msgp/defgen_test.go
msgp/cover.out
*~
*~
*.coverprofile
75 changes: 75 additions & 0 deletions _generated/convert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package _generated

import "errors"

//go:generate msgp

//msgp:shim ConvertStringVal as:string using:fromConvertStringVal/toConvertStringVal mode:convert
//msgp:ignore ConvertStringVal

func fromConvertStringVal(v ConvertStringVal) (string, error) {
return string(v), nil
}

func toConvertStringVal(s string) (ConvertStringVal, error) {
return ConvertStringVal(s), nil
}

type ConvertStringVal string

type ConvertString struct {
String ConvertStringVal
}

//msgp:shim ConvertIntfVal as:interface{} using:fromConvertIntfVal/toConvertIntfVal mode:convert
//msgp:ignore ConvertIntfVal

func fromConvertIntfVal(v ConvertIntfVal) (interface{}, error) {
return v.Test, nil
}

func toConvertIntfVal(s interface{}) (ConvertIntfVal, error) {
return ConvertIntfVal{Test: s.(string)}, nil
}

type ConvertIntfVal struct {
Test string
}

type ConvertIntf struct {
Intf ConvertIntfVal
}

//msgp:shim ConvertErrVal as:string using:fromConvertErrVal/toConvertErrVal mode:convert
//msgp:ignore ConvertErrVal

var (
errConvertFrom = errors.New("error: convert from")
errConvertTo = errors.New("error: convert to")
)

const (
fromFailStr = "fromfail"
toFailStr = "tofail"
)

func fromConvertErrVal(v ConvertErrVal) (string, error) {
s := string(v)
if s == fromFailStr {
return "", errConvertFrom
}
return s, nil
}

func toConvertErrVal(s string) (ConvertErrVal, error) {
if s == toFailStr {
return ConvertErrVal(""), errConvertTo
}
return ConvertErrVal(s), nil
}

type ConvertErrVal string

type ConvertErr struct {
Err ConvertErrVal
}
59 changes: 59 additions & 0 deletions _generated/convert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package _generated

import (
"bytes"
"testing"

"github.com/tinylib/msgp/msgp"
)

func TestConvertFromEncodeError(t *testing.T) {
e := ConvertErr{ConvertErrVal(fromFailStr)}
var buf bytes.Buffer
w := msgp.NewWriter(&buf)
err := e.EncodeMsg(w)
if err != errConvertFrom {
t.Fatalf("expected conversion error, found %v", err.Error())
}
}

func TestConvertToEncodeError(t *testing.T) {
var in, out ConvertErr
in = ConvertErr{ConvertErrVal(toFailStr)}
var buf bytes.Buffer
w := msgp.NewWriter(&buf)
err := in.EncodeMsg(w)
if err != nil {
t.FailNow()
}
w.Flush()

r := msgp.NewReader(&buf)
err = (&out).DecodeMsg(r)
if err != errConvertTo {
t.Fatalf("expected conversion error, found %v", err.Error())
}
}

func TestConvertFromMarshalError(t *testing.T) {
e := ConvertErr{ConvertErrVal(fromFailStr)}
var b []byte
_, err := e.MarshalMsg(b)
if err != errConvertFrom {
t.Fatalf("expected conversion error, found %v", err.Error())
}
}

func TestConvertToMarshalError(t *testing.T) {
var in, out ConvertErr
in = ConvertErr{ConvertErrVal(toFailStr)}
b, err := in.MarshalMsg(nil)
if err != nil {
t.FailNow()
}

_, err = (&out).UnmarshalMsg(b)
if err != errConvertTo {
t.Fatalf("expected conversion error, found %v", err.Error())
}
}
3 changes: 2 additions & 1 deletion _generated/def.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package _generated

import (
"github.com/tinylib/msgp/msgp"
"os"
"time"

"github.com/tinylib/msgp/msgp"
)

//go:generate msgp -o generated.go
Expand Down
10 changes: 7 additions & 3 deletions gen/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,17 @@ func (d *decodeGen) gBase(b *BaseElem) {
d.p.printf("\n%s, err = dc.Read%s()", vname, bname)
}
}
d.p.print(errcheck)

// close block for 'tmp'
if b.Convert {
d.p.printf("\n%s = %s(%s)\n}", vname, b.FromBase(), tmp)
if b.ShimMode == Cast {
d.p.printf("\n%s = %s(%s)\n}", vname, b.FromBase(), tmp)
} else {
d.p.printf("\n%s, err = %s(%s)\n}", vname, b.FromBase(), tmp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you're generating code that shadows err without checking it first.

d.p.print(errcheck)
}
}

d.p.print(errcheck)
}

func (d *decodeGen) gMap(m *Map) {
Expand Down
8 changes: 8 additions & 0 deletions gen/elem.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,19 @@ type StructField struct {
FieldElem Elem // the field type
}

type ShimMode int

const (
Cast ShimMode = iota
Convert
)

// BaseElem is an element that
// can be represented by a primitive
// MessagePack type.
type BaseElem struct {
common
ShimMode ShimMode // Method used to shim
ShimToBase string // shim to base type, or empty
ShimFromBase string // shim from base type, or empty
Value Primitive // Type of element
Expand Down
12 changes: 10 additions & 2 deletions gen/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package gen

import (
"fmt"
"github.com/tinylib/msgp/msgp"
"io"

"github.com/tinylib/msgp/msgp"
)

func encode(w io.Writer) *encodeGen {
Expand Down Expand Up @@ -172,7 +173,14 @@ func (e *encodeGen) gBase(b *BaseElem) {
e.fuseHook()
vname := b.Varname()
if b.Convert {
vname = tobaseConvert(b)
if b.ShimMode == Cast {
vname = tobaseConvert(b)
} else {
vname = randIdent()
e.p.printf("\nvar %s %s", vname, b.BaseType())
e.p.printf("\n%s, err = %s", vname, tobaseConvert(b))
e.p.printf(errcheck)
}
}

if b.Value == IDENT { // unknown identity
Expand Down
12 changes: 10 additions & 2 deletions gen/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package gen

import (
"fmt"
"github.com/tinylib/msgp/msgp"
"io"

"github.com/tinylib/msgp/msgp"
)

func marshal(w io.Writer) *marshalGen {
Expand Down Expand Up @@ -177,7 +178,14 @@ func (m *marshalGen) gBase(b *BaseElem) {
vname := b.Varname()

if b.Convert {
vname = tobaseConvert(b)
if b.ShimMode == Cast {
vname = tobaseConvert(b)
} else {
vname = randIdent()
m.p.printf("\nvar %s %s", vname, b.BaseType())
m.p.printf("\n%s, err = %s", vname, tobaseConvert(b))
m.p.printf(errcheck)
}
}

var echeck bool
Expand Down
28 changes: 19 additions & 9 deletions gen/size.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package gen

import (
"fmt"
"github.com/tinylib/msgp/msgp"
"io"
"strconv"

"github.com/tinylib/msgp/msgp"
)

type sizeState uint8
Expand Down Expand Up @@ -183,7 +184,20 @@ func (s *sizeGen) gBase(b *BaseElem) {
if !s.p.ok() {
return
}
s.addConstant(basesizeExpr(b))
if b.Convert && b.ShimMode == Convert {
s.state = add
vname := randIdent()
s.p.printf("\nvar %s %s", vname, b.BaseType())
s.p.printf("\ns += %s", basesizeExpr(b.Value, vname, b.BaseName()))
s.state = expr

} else {
vname := b.Varname()
if b.Convert {
vname = tobaseConvert(b)
}
s.addConstant(basesizeExpr(b.Value, vname, b.BaseName()))
}
}

// returns "len(slice)"
Expand Down Expand Up @@ -250,12 +264,8 @@ func fixedsizeExpr(e Elem) (string, bool) {
}

// print size expression of a variable name
func basesizeExpr(b *BaseElem) string {
vname := b.Varname()
if b.Convert {
vname = tobaseConvert(b)
}
switch b.Value {
func basesizeExpr(value Primitive, vname, basename string) string {
switch value {
case Ext:
return "msgp.ExtensionPrefixSize + " + stripRef(vname) + ".Len()"
case Intf:
Expand All @@ -267,6 +277,6 @@ func basesizeExpr(b *BaseElem) string {
case String:
return "msgp.StringPrefixSize + len(" + vname + ")"
default:
return builtinSize(b.BaseName())
return builtinSize(basename)
}
}
12 changes: 9 additions & 3 deletions gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,18 @@ func (u *unmarshalGen) gBase(b *BaseElem) {
default:
u.p.printf("\n%s, bts, err = msgp.Read%sBytes(bts)", refname, b.BaseName())
}
u.p.print(errcheck)

if b.Convert {
// close 'tmp' block
u.p.printf("\n%s = %s(%s)\n}", b.Varname(), b.FromBase(), refname)
if b.ShimMode == Cast {
u.p.printf("\n%s = %s(%s)\n", b.Varname(), b.FromBase(), refname)
} else {
u.p.printf("\n%s, err = %s(%s)", b.Varname(), b.FromBase(), refname)
u.p.print(errcheck)
}
u.p.printf("}")
}

u.p.print(errcheck)
}

func (u *unmarshalGen) gArray(a *Array) {
Expand Down
21 changes: 17 additions & 4 deletions parse/directives.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package parse

import (
"fmt"
"github.com/tinylib/msgp/gen"
"go/ast"
"strings"

"github.com/tinylib/msgp/gen"
)

const linePrefix = "//msgp:"
Expand Down Expand Up @@ -52,10 +53,10 @@ func yieldComments(c []*ast.CommentGroup) []string {
return out
}

//msgp:shim {Type} as:{Newtype} using:{toFunc/fromFunc}
//msgp:shim {Type} as:{Newtype} using:{toFunc/fromFunc} mode:{Mode}
func applyShim(text []string, f *FileSet) error {
if len(text) != 4 {
return fmt.Errorf("shim directive should have 3 arguments; found %d", len(text)-1)
if len(text) < 4 || len(text) > 5 {
return fmt.Errorf("shim directive should have 3 or 4 arguments; found %d", len(text)-1)
}

name := text[1]
Expand All @@ -76,6 +77,18 @@ func applyShim(text []string, f *FileSet) error {
be.ShimToBase = methods[0]
be.ShimFromBase = methods[1]

if len(text) == 5 {
modestr := strings.TrimPrefix(strings.TrimSpace(text[4]), "mode:") // parse mode::{mode}
switch modestr {
case "cast":
be.ShimMode = gen.Cast
case "convert":
be.ShimMode = gen.Convert
default:
return fmt.Errorf("invalid shim mode; found %s, expected 'cast' or 'convert", modestr)
}
}

infof("%s -> %s\n", name, be.Value.String())
f.findShim(name, be)

Expand Down