Skip to content

Commit 070d6fe

Browse files
committed
Fix coercion of typedef primitives and their pointers. For example,
with "type MyInt int", values of type MyInt and *MyInt should be treated as ints. Fixes graphql-go#488.
1 parent 66aaed7 commit 070d6fe

File tree

2 files changed

+162
-7
lines changed

2 files changed

+162
-7
lines changed

scalars.go

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,53 @@ package graphql
33
import (
44
"fmt"
55
"math"
6+
"reflect"
67
"strconv"
78
"time"
89

910
"github.com/graphql-go/graphql/language/ast"
1011
)
1112

13+
func unwrapPrimitive(value interface{}) (interface{}, bool) {
14+
r := reflect.Indirect(reflect.ValueOf(value))
15+
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
16+
return nil, false
17+
}
18+
19+
switch r.Kind() {
20+
case reflect.Int:
21+
return int(r.Int()), true
22+
case reflect.Int8:
23+
return int8(r.Int()), true
24+
case reflect.Int16:
25+
return int16(r.Int()), true
26+
case reflect.Int32:
27+
return int32(r.Int()), true
28+
case reflect.Int64:
29+
return r.Int(), true
30+
case reflect.Uint:
31+
return uint(r.Uint()), true
32+
case reflect.Uint8:
33+
return uint8(r.Uint()), true
34+
case reflect.Uint16:
35+
return uint16(r.Uint()), true
36+
case reflect.Uint32:
37+
return uint32(r.Uint()), true
38+
case reflect.Uint64:
39+
return r.Uint(), true
40+
case reflect.Float32:
41+
return float32(r.Float()), true
42+
case reflect.Float64:
43+
return r.Float(), true
44+
case reflect.Bool:
45+
return r.Bool(), true
46+
case reflect.String:
47+
return r.String(), true
48+
default:
49+
return nil, false
50+
}
51+
}
52+
1253
// As per the GraphQL Spec, Integers are only treated as valid when a valid
1354
// 32-bit signed integer, providing the broadest support across platforms.
1455
//
@@ -144,6 +185,10 @@ func coerceInt(value interface{}) interface{} {
144185
return coerceInt(*value)
145186
}
146187

188+
if v, ok := unwrapPrimitive(value); ok {
189+
return coerceInt(v)
190+
}
191+
147192
// If the value cannot be transformed into an int, return nil instead of '0'
148193
// to denote 'no integer found'
149194
return nil
@@ -276,6 +321,10 @@ func coerceFloat(value interface{}) interface{} {
276321
return coerceFloat(*value)
277322
}
278323

324+
if v, ok := unwrapPrimitive(value); ok {
325+
return coerceFloat(v)
326+
}
327+
279328
// If the value cannot be transformed into an float, return nil instead of '0.0'
280329
// to denote 'no float found'
281330
return nil
@@ -305,13 +354,23 @@ var Float = NewScalar(ScalarConfig{
305354
})
306355

307356
func coerceString(value interface{}) interface{} {
308-
if v, ok := value.(*string); ok {
309-
if v == nil {
357+
switch t := value.(type) {
358+
case *string:
359+
if t == nil {
310360
return nil
311361
}
312-
return *v
362+
return *t
363+
case string:
364+
return t
365+
default:
366+
if v, ok := unwrapPrimitive(value); ok {
367+
return coerceString(v)
368+
}
369+
if r := reflect.ValueOf(value); r.Kind() == reflect.Ptr && r.IsNil() {
370+
return nil
371+
}
372+
return fmt.Sprintf("%v", value)
313373
}
314-
return fmt.Sprintf("%v", value)
315374
}
316375

317376
// String is the GraphQL string type definition
@@ -472,6 +531,13 @@ func coerceBool(value interface{}) interface{} {
472531
}
473532
return coerceBool(*value)
474533
}
534+
535+
if v, ok := unwrapPrimitive(value); ok {
536+
return coerceBool(v)
537+
}
538+
if r := reflect.ValueOf(value); r.Kind() == reflect.Ptr && r.IsNil() {
539+
return nil
540+
}
475541
return false
476542
}
477543

scalars_test.go

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,50 @@ import (
55
"testing"
66
)
77

8+
type (
9+
myInt int
10+
myString string
11+
myBool bool
12+
myFloat32 float32
13+
)
14+
15+
func TestCoerceString(t *testing.T) {
16+
tests := []struct {
17+
in interface{}
18+
want interface{}
19+
}{
20+
{
21+
in: "hello",
22+
want: "hello",
23+
},
24+
{
25+
in: func() interface{} { s := "hello"; return &s }(),
26+
want: "hello",
27+
},
28+
// Typedef
29+
{
30+
in: myString("hello"),
31+
want: "hello",
32+
},
33+
// Typedef with pointer
34+
{
35+
in: func() interface{} { v := myString("hello"); return &v }(),
36+
want: "hello",
37+
},
38+
// Typedef with nil pointer
39+
{
40+
in: (*myString)(nil),
41+
want: nil,
42+
},
43+
}
44+
45+
for i, tt := range tests {
46+
if got, want := coerceString(tt.in), tt.want; got != want {
47+
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
48+
}
49+
}
50+
}
51+
852
func TestCoerceInt(t *testing.T) {
953
tests := []struct {
1054
in interface{}
@@ -240,11 +284,26 @@ func TestCoerceInt(t *testing.T) {
240284
in: make(map[string]interface{}),
241285
want: nil,
242286
},
287+
// Typedef
288+
{
289+
in: myInt(42),
290+
want: int(42),
291+
},
292+
// Typedef with pointer
293+
{
294+
in: func() interface{} { v := myInt(42); return &v }(),
295+
want: int(42),
296+
},
297+
// Typedef with nil pointer
298+
{
299+
in: (*myInt)(nil),
300+
want: nil,
301+
},
243302
}
244303

245304
for i, tt := range tests {
246305
if got, want := coerceInt(tt.in), tt.want; got != want {
247-
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
306+
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
248307
}
249308
}
250309
}
@@ -438,11 +497,26 @@ func TestCoerceFloat(t *testing.T) {
438497
in: make(map[string]interface{}),
439498
want: nil,
440499
},
500+
// Typedef
501+
{
502+
in: myFloat32(3.14),
503+
want: float32(3.14),
504+
},
505+
// Typedef with pointer
506+
{
507+
in: func() interface{} { v := myFloat32(3.14); return &v }(),
508+
want: float32(3.14),
509+
},
510+
// Typedef with nil pointer
511+
{
512+
in: (*myFloat32)(nil),
513+
want: nil,
514+
},
441515
}
442516

443517
for i, tt := range tests {
444518
if got, want := coerceFloat(tt.in), tt.want; got != want {
445-
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
519+
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
446520
}
447521
}
448522
}
@@ -740,11 +814,26 @@ func TestCoerceBool(t *testing.T) {
740814
in: make(map[string]interface{}),
741815
want: false,
742816
},
817+
// Typedef
818+
{
819+
in: myBool(true),
820+
want: true,
821+
},
822+
// Typedef with pointer
823+
{
824+
in: func() interface{} { v := myBool(true); return &v }(),
825+
want: true,
826+
},
827+
// Typedef with nil pointer
828+
{
829+
in: (*myBool)(nil),
830+
want: nil,
831+
},
743832
}
744833

745834
for i, tt := range tests {
746835
if got, want := coerceBool(tt.in), tt.want; got != want {
747-
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
836+
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
748837
}
749838
}
750839
}

0 commit comments

Comments
 (0)