diff --git a/msgp/number.go b/msgp/number.go index ad07ef99..609188be 100644 --- a/msgp/number.go +++ b/msgp/number.go @@ -91,6 +91,58 @@ func (n *Number) Float() (float64, bool) { } } +// CastInt64 returns the number as an int64 and +// returns whether the cast is valid. +func (n *Number) CastInt64() (int64, bool) { + i, ok := n.Int() + if !ok { + return 0, false + } + + return i, true +} + +// CastInt32 returns the number as an int32 and +// returns whether the cast is valid. +func (n *Number) CastInt32() (int32, bool) { + i, ok := n.CastInt64() + if !ok { + return 0, false + } + + if i >= math.MinInt32 && i <= math.MaxInt32 { + return int32(i), true + } + + return 0, false +} + +// CastUint64 returns the number as an uint64 and +// returns whether the cast is valid. +func (n *Number) CastUint64() (uint64, bool) { + i, ok := n.Uint() + if i-1 == math.MaxInt64 && !ok { + return 0, false + } + + return i, true +} + +// CastUint32 returns the number as an uint32 and +// returns whether the cast is valid. +func (n *Number) CastUint32() (uint32, bool) { + i, ok := n.CastUint64() + if !ok { + return 0, false + } + + if i <= math.MaxInt32 { + return uint32(i), true + } + + return 0, false +} + // Type will return one of: // Float64Type, Float32Type, UintType, or IntType. func (n *Number) Type() Type { diff --git a/msgp/number_test.go b/msgp/number_test.go index 3490647c..0ce06256 100644 --- a/msgp/number_test.go +++ b/msgp/number_test.go @@ -2,6 +2,7 @@ package msgp import ( "bytes" + "math" "testing" ) @@ -92,3 +93,68 @@ func TestNumber(t *testing.T) { } } + +func TestNumber_CastInt64(t *testing.T) { + var n Number + n.AsUint(math.MaxUint64) + + _, ok := n.CastInt64() + if ok { + t.Error("CastInt64() failed: MaxUint64 > MaxInt64") + } +} + +func TestNumber_CastInt32(t *testing.T) { + cases := []struct { + in int64 + want bool + }{ + {math.MinInt32, true}, + {math.MinInt32 - 1, false}, + {math.MaxInt32, true}, + {math.MaxInt32 + 1, false}, + {0, true}, + } + + var n Number + for _, c := range cases { + n.AsInt(c.in) + + _, ok := n.CastInt32() + if ok != c.want { + t.Errorf("cast %v to int32 invalid: %v, got %v", c.in, c.want, ok) + } + } +} + +func TestNumber_CastUint64(t *testing.T) { + var n Number + n.AsInt(math.MinInt64) + + _, ok := n.CastUint64() + if ok { + t.Error("CastUint64() failed: MinInt64 < 0") + } +} + +func TestNumber_CastUint32(t *testing.T) { + cases := []struct { + in int64 + want bool + }{ + {math.MinInt64, false}, + {math.MaxInt32, true}, + {math.MaxInt32 + 1, false}, + {0, true}, + } + + var n Number + for _, c := range cases { + n.AsInt(c.in) + + _, ok := n.CastUint32() + if ok != c.want { + t.Errorf("cast %v to uint32 invalid: %v, got %v", c.in, c.want, ok) + } + } +}