Skip to content

Commit b36c57e

Browse files
atombendersimaotwx
authored andcommitted
Fix coercion of typedef primitives and their pointers. For example,
with "type MyInt int", values of type MyInt and *MyInt should be treated as ints. Fixes graphql-go#488. Signed-off-by: Simao Gomes Viana <[email protected]>
1 parent 35d1c17 commit b36c57e

File tree

2 files changed

+191
-11
lines changed

2 files changed

+191
-11
lines changed

scalars.go

Lines changed: 99 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,79 @@ package graphql
33
import (
44
"fmt"
55
"math"
6+
"reflect"
67
"strconv"
78
"time"
89

910
"github.com/graphql-go/graphql/language/ast"
1011
)
1112

13+
func unwrapInt(value interface{}) (interface{}, bool) {
14+
r := reflect.Indirect(reflect.ValueOf(value))
15+
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
16+
return nil, false
17+
}
18+
19+
switch r.Kind() {
20+
case reflect.Int:
21+
return int(r.Int()), true
22+
case reflect.Int8:
23+
return int8(r.Int()), true
24+
case reflect.Int16:
25+
return int16(r.Int()), true
26+
case reflect.Int32:
27+
return int32(r.Int()), true
28+
case reflect.Int64:
29+
return r.Int(), true
30+
default:
31+
return nil, false
32+
}
33+
}
34+
35+
func unwrapFloat(value interface{}) (interface{}, bool) {
36+
r := reflect.Indirect(reflect.ValueOf(value))
37+
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
38+
return nil, false
39+
}
40+
41+
switch r.Kind() {
42+
case reflect.Float32:
43+
return float32(r.Float()), true
44+
case reflect.Float64:
45+
return r.Float(), true
46+
default:
47+
return nil, false
48+
}
49+
}
50+
51+
func unwrapBool(value interface{}) (interface{}, bool) {
52+
r := reflect.Indirect(reflect.ValueOf(value))
53+
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
54+
return nil, false
55+
}
56+
57+
switch r.Kind() {
58+
case reflect.Bool:
59+
return r.Bool(), true
60+
default:
61+
return nil, false
62+
}
63+
}
64+
65+
func unwrapString(value interface{}) (interface{}, bool) {
66+
r := reflect.Indirect(reflect.ValueOf(value))
67+
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
68+
return nil, false
69+
}
70+
71+
switch r.Kind() {
72+
case reflect.String:
73+
return r.String(), true
74+
default:
75+
return nil, false
76+
}
77+
}
78+
1279
// As per the GraphQL Spec, Integers are only treated as valid when a valid
1380
// 32-bit signed integer, providing the broadest support across platforms.
1481
//
@@ -142,11 +209,14 @@ func coerceInt(value interface{}) interface{} {
142209
return nil
143210
}
144211
return coerceInt(*value)
212+
default:
213+
if v, ok := unwrapInt(value); ok {
214+
return coerceInt(v)
215+
}
216+
// If the value cannot be transformed into an int, return nil instead of '0'
217+
// to denote 'no integer found'
218+
return nil
145219
}
146-
147-
// If the value cannot be transformed into an int, return nil instead of '0'
148-
// to denote 'no integer found'
149-
return nil
150220
}
151221

152222
// Int is the GraphQL Integer type definition.
@@ -276,6 +346,10 @@ func coerceFloat(value interface{}) interface{} {
276346
return coerceFloat(*value)
277347
}
278348

349+
if v, ok := unwrapFloat(value); ok {
350+
return coerceFloat(v)
351+
}
352+
279353
// If the value cannot be transformed into an float, return nil instead of '0.0'
280354
// to denote 'no float found'
281355
return nil
@@ -305,13 +379,23 @@ var Float = NewScalar(ScalarConfig{
305379
})
306380

307381
func coerceString(value interface{}) interface{} {
308-
if v, ok := value.(*string); ok {
309-
if v == nil {
382+
switch t := value.(type) {
383+
case *string:
384+
if t == nil {
310385
return nil
311386
}
312-
return *v
387+
return *t
388+
case string:
389+
return t
390+
default:
391+
if v, ok := unwrapString(value); ok {
392+
return coerceString(v)
393+
}
394+
if r := reflect.ValueOf(value); r.Kind() == reflect.Ptr && r.IsNil() {
395+
return nil
396+
}
397+
return fmt.Sprintf("%v", value)
313398
}
314-
return fmt.Sprintf("%v", value)
315399
}
316400

317401
// String is the GraphQL string type definition
@@ -472,6 +556,13 @@ func coerceBool(value interface{}) interface{} {
472556
}
473557
return coerceBool(*value)
474558
}
559+
560+
if v, ok := unwrapBool(value); ok {
561+
return coerceBool(v)
562+
}
563+
if r := reflect.ValueOf(value); r.Kind() == reflect.Ptr && r.IsNil() {
564+
return nil
565+
}
475566
return false
476567
}
477568

scalars_test.go

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,50 @@ import (
55
"testing"
66
)
77

8+
type (
9+
myInt int
10+
myString string
11+
myBool bool
12+
myFloat32 float32
13+
)
14+
15+
func TestCoerceString(t *testing.T) {
16+
tests := []struct {
17+
in interface{}
18+
want interface{}
19+
}{
20+
{
21+
in: "hello",
22+
want: "hello",
23+
},
24+
{
25+
in: func() interface{} { s := "hello"; return &s }(),
26+
want: "hello",
27+
},
28+
// Typedef
29+
{
30+
in: myString("hello"),
31+
want: "hello",
32+
},
33+
// Typedef with pointer
34+
{
35+
in: func() interface{} { v := myString("hello"); return &v }(),
36+
want: "hello",
37+
},
38+
// Typedef with nil pointer
39+
{
40+
in: (*myString)(nil),
41+
want: nil,
42+
},
43+
}
44+
45+
for i, tt := range tests {
46+
if got, want := coerceString(tt.in), tt.want; got != want {
47+
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
48+
}
49+
}
50+
}
51+
852
func TestCoerceInt(t *testing.T) {
953
tests := []struct {
1054
in interface{}
@@ -240,11 +284,26 @@ func TestCoerceInt(t *testing.T) {
240284
in: make(map[string]interface{}),
241285
want: nil,
242286
},
287+
// Typedef
288+
{
289+
in: myInt(42),
290+
want: int(42),
291+
},
292+
// Typedef with pointer
293+
{
294+
in: func() interface{} { v := myInt(42); return &v }(),
295+
want: int(42),
296+
},
297+
// Typedef with nil pointer
298+
{
299+
in: (*myInt)(nil),
300+
want: nil,
301+
},
243302
}
244303

245304
for i, tt := range tests {
246305
if got, want := coerceInt(tt.in), tt.want; got != want {
247-
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
306+
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
248307
}
249308
}
250309
}
@@ -438,11 +497,26 @@ func TestCoerceFloat(t *testing.T) {
438497
in: make(map[string]interface{}),
439498
want: nil,
440499
},
500+
// Typedef
501+
{
502+
in: myFloat32(3.14),
503+
want: float32(3.14),
504+
},
505+
// Typedef with pointer
506+
{
507+
in: func() interface{} { v := myFloat32(3.14); return &v }(),
508+
want: float32(3.14),
509+
},
510+
// Typedef with nil pointer
511+
{
512+
in: (*myFloat32)(nil),
513+
want: nil,
514+
},
441515
}
442516

443517
for i, tt := range tests {
444518
if got, want := coerceFloat(tt.in), tt.want; got != want {
445-
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
519+
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
446520
}
447521
}
448522
}
@@ -740,11 +814,26 @@ func TestCoerceBool(t *testing.T) {
740814
in: make(map[string]interface{}),
741815
want: false,
742816
},
817+
// Typedef
818+
{
819+
in: myBool(true),
820+
want: true,
821+
},
822+
// Typedef with pointer
823+
{
824+
in: func() interface{} { v := myBool(true); return &v }(),
825+
want: true,
826+
},
827+
// Typedef with nil pointer
828+
{
829+
in: (*myBool)(nil),
830+
want: nil,
831+
},
743832
}
744833

745834
for i, tt := range tests {
746835
if got, want := coerceBool(tt.in), tt.want; got != want {
747-
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
836+
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
748837
}
749838
}
750839
}

0 commit comments

Comments
 (0)