diff --git a/definition.go b/definition.go index 2f129ccf..c1ec6d58 100644 --- a/definition.go +++ b/definition.go @@ -592,6 +592,16 @@ type ResolveInfo struct { VariableValues map[string]interface{} } +type ResolveResult struct { + Value interface{} + Error error +} + +// When returned from a resolve function, ResolvePromise indicates that the resolution will be done +// asynchronously. When used, an IdleHandler should be specified. This handler must fulfill one or +// more promises each time it is invoked. +type ResolvePromise chan *ResolveResult + type Fields map[string]*Field type Field struct { diff --git a/executor.go b/executor.go index e3245f45..7bcec8bc 100644 --- a/executor.go +++ b/executor.go @@ -9,6 +9,7 @@ import ( "github.com/graphql-go/graphql/gqlerrors" "github.com/graphql-go/graphql/language/ast" + "github.com/graphql-go/graphql/promise" ) type ExecuteParams struct { @@ -18,6 +19,8 @@ type ExecuteParams struct { OperationName string Args map[string]interface{} + IdleHandler func() + // Context may be provided to pass application-specific per-request // information to resolve functions. Context context.Context @@ -74,6 +77,7 @@ func Execute(p ExecuteParams) (result *Result) { ExecutionContext: exeContext, Root: p.Root, Operation: exeContext.Operation, + IdleHandler: p.IdleHandler, }) select { case out <- result: @@ -164,6 +168,28 @@ type executeOperationParams struct { ExecutionContext *executionContext Root interface{} Operation ast.Definition + IdleHandler func() +} + +func waitForPromise(p executeOperationParams, promise *promise.Promise) interface{} { + var result interface{} + done := false + promise.Then(func(value interface{}) interface{} { + result = value + done = true + return nil + }) + promise.Schedule() + for !done { + if p.IdleHandler == nil { + panic(errors.New("Asynchronous resolution attempted with no idle handler defined")) + } + p.IdleHandler() + if !promise.Schedule() { + panic(errors.New("No progress was made after idle handler was invoked")) + } + } + return result } func executeOperation(p executeOperationParams) *Result { @@ -185,11 +211,22 @@ func executeOperation(p executeOperationParams) *Result { Fields: fields, } + var data interface{} + if p.Operation.GetOperation() == ast.OperationTypeMutation { - return executeFieldsSerially(executeFieldsParams) + data = executeFieldsSerially(executeFieldsParams) + } else { + data = executeFields(executeFieldsParams) } - return executeFields(executeFieldsParams) + result := &Result{} + if isPromise(data) { + result.Data = waitForPromise(p, data.(*promise.Promise)) + } else { + result.Data = data + } + result.Errors = p.ExecutionContext.Errors + return result } // Extracts the root type of the operation from the schema. @@ -247,7 +284,7 @@ type executeFieldsParams struct { } // Implements the "Evaluating selection sets" section of the spec for "write" mode. -func executeFieldsSerially(p executeFieldsParams) *Result { +func executeFieldsSerially(p executeFieldsParams) interface{} { if p.Source == nil { p.Source = map[string]interface{}{} } @@ -256,22 +293,50 @@ func executeFieldsSerially(p executeFieldsParams) *Result { } finalResults := make(map[string]interface{}, len(p.Fields)) + chain := promise.Resolve(nil) for responseName, fieldASTs := range p.Fields { - resolved, state := resolveField(p.ExecutionContext, p.ParentType, p.Source, fieldASTs) - if state.hasNoFieldDefs { - continue - } - finalResults[responseName] = resolved + fieldASTs := fieldASTs + responseName := responseName + chain = chain.Then(func(interface{}) interface{} { + resolved, state := resolveField(p.ExecutionContext, p.ParentType, p.Source, fieldASTs) + if state.hasNoFieldDefs { + return nil + } + if isPromise(resolved) { + return resolved.(*promise.Promise).Then(func(value interface{}) interface{} { + finalResults[responseName] = value + return nil + }) + } + finalResults[responseName] = resolved + return nil + }) } - return &Result{ - Data: finalResults, - Errors: p.ExecutionContext.Errors, - } + return chain.Then(func(interface{}) interface{} { + return finalResults + }) +} + +func promiseForObject(object map[string]interface{}) *promise.Promise { + keys := make([]string, 0, len(object)) + values := make([]interface{}, 0, len(object)) + for key, value := range object { + keys = append(keys, key) + values = append(values, value) + } + return promise.All(values).Then(func(values interface{}) interface{} { + list := values.([]interface{}) + result := make(map[string]interface{}, len(list)) + for i, value := range list { + result[keys[i]] = value + } + return result + }) } // Implements the "Evaluating selection sets" section of the spec for "read" mode. -func executeFields(p executeFieldsParams) *Result { +func executeFields(p executeFieldsParams) interface{} { if p.Source == nil { p.Source = map[string]interface{}{} } @@ -279,19 +344,24 @@ func executeFields(p executeFieldsParams) *Result { p.Fields = map[string][]*ast.Field{} } + containsPromise := false finalResults := make(map[string]interface{}, len(p.Fields)) for responseName, fieldASTs := range p.Fields { resolved, state := resolveField(p.ExecutionContext, p.ParentType, p.Source, fieldASTs) if state.hasNoFieldDefs { continue } + if isPromise(resolved) { + containsPromise = true + } finalResults[responseName] = resolved } - return &Result{ - Data: finalResults, - Errors: p.ExecutionContext.Errors, + if !containsPromise { + return finalResults } + + return promiseForObject(finalResults) } type collectFieldsParams struct { @@ -504,6 +574,50 @@ type resolveFieldResultState struct { hasNoFieldDefs bool } +func promiseForResolvePromise(ch ResolvePromise) *promise.Promise { + return promise.New(func(resolve func(interface{}), reject func(error)) { + select { + case r := <-ch: + if r.Error != nil { + reject(r.Error) + } else { + resolve(r.Value) + } + default: + } + }) +} + +// If the given value is a ResolvePromise or an iterable containing one or more ResolvePromises, a +// promise is returned for it. +func maybePromise(v interface{}) *promise.Promise { + if isIterable(v) { + v := reflect.ValueOf(v) + containsPromise := false + for i := 0; i < v.Len(); i++ { + if _, ok := v.Index(i).Interface().(ResolvePromise); ok { + containsPromise = true + break + } + } + if containsPromise { + elements := make([]interface{}, v.Len()) + for i := 0; i < v.Len(); i++ { + v := v.Index(i).Interface() + if rp, ok := v.(ResolvePromise); ok { + elements[i] = promiseForResolvePromise(rp) + } else { + elements[i] = v + } + } + return promise.All(elements) + } + } else if ch, ok := v.(ResolvePromise); ok { + return promiseForResolvePromise(ch) + } + return nil +} + // Resolves the field on the given source object. In particular, this // figures out the value that the field returns by calling its resolve function, // then calls completeValue to complete promises, serialize scalars, or execute @@ -570,7 +684,7 @@ func resolveField(eCtx *executionContext, parentType *Object, source interface{} var resolveFnError error - result, resolveFnError = resolveFn(ResolveParams{ + resolveFnValue, resolveFnError := resolveFn(ResolveParams{ Source: source, Args: args, Info: info, @@ -581,8 +695,23 @@ func resolveField(eCtx *executionContext, parentType *Object, source interface{} panic(gqlerrors.FormatError(resolveFnError)) } - completed := completeValueCatchingError(eCtx, returnType, fieldASTs, info, result) - return completed, resultState + // If the value is a promise, chain value completion. + if promise := maybePromise(resolveFnValue); promise != nil { + promise = promise.Then(func(value interface{}) interface{} { + return completeValueCatchingError(eCtx, returnType, fieldASTs, info, value) + }) + // If the value isn't non-null, catch errors. + if _, ok := returnType.(*NonNull); !ok { + promise = promise.Catch(func(err error) interface{} { + eCtx.Errors = append(eCtx.Errors, gqlerrors.FormatError(err)) + return nil + }) + } + return promise, resultState + } + + result = completeValueCatchingError(eCtx, returnType, fieldASTs, info, resolveFnValue) + return result, resultState } func completeValueCatchingError(eCtx *executionContext, returnType Type, fieldASTs []*ast.Field, info ResolveInfo, result interface{}) (completed interface{}) { @@ -759,10 +888,7 @@ func completeObjectValue(eCtx *executionContext, returnType *Object, fieldASTs [ Source: result, Fields: subFieldASTs, } - results := executeFields(executeFieldsParams) - - return results.Data - + return executeFields(executeFieldsParams) } // completeLeafValue complete a leaf value (Scalar / Enum) by serializing to a valid value, returning nil if serialization is not possible. @@ -792,11 +918,19 @@ func completeListValue(eCtx *executionContext, returnType *List, fieldASTs []*as itemType := returnType.OfType completedResults := make([]interface{}, 0, resultVal.Len()) + containsPromise := false for i := 0; i < resultVal.Len(); i++ { val := resultVal.Index(i).Interface() completedItem := completeValueCatchingError(eCtx, itemType, fieldASTs, info, val) + if isPromise(completedItem) { + containsPromise = true + } completedResults = append(completedResults, completedItem) } + + if containsPromise { + return promise.All(completedResults) + } return completedResults } diff --git a/executor_test.go b/executor_test.go index 954d6d30..bf35ce0f 100644 --- a/executor_test.go +++ b/executor_test.go @@ -1807,3 +1807,259 @@ func TestContextDeadline(t *testing.T) { t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expectedErrors, result.Errors)) } } + +func TestAsynchronousResolver(t *testing.T) { + stringChan := make(graphql.ResolvePromise, 1) + stringListChan := make(graphql.ResolvePromise, 1) + + subObjectType := graphql.NewObject( + graphql.ObjectConfig{ + Name: "SubObject", + Fields: graphql.Fields{ + "source": &graphql.Field{ + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return p.Source, nil + }, + }, + "s": &graphql.Field{ + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return graphql.ResolvePromise(stringChan), nil + }, + }, + }, + }, + ) + + for name, tc := range map[string]struct { + Object *graphql.Object + RequestString string + Expected *graphql.Result + }{ + "String": { + Object: graphql.NewObject( + graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + "s": &graphql.Field{ + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return graphql.ResolvePromise(stringChan), nil + }, + }, + }, + }, + ), + RequestString: "{s}", + Expected: &graphql.Result{ + Data: map[string]interface{}{ + "s": "foo", + }, + }, + }, + "[String]": { + Object: graphql.NewObject( + graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + "l": &graphql.Field{ + Type: graphql.NewList(graphql.String), + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return graphql.ResolvePromise(stringListChan), nil + }, + }, + }, + }, + ), + RequestString: "{l}", + Expected: &graphql.Result{ + Data: map[string]interface{}{ + "l": []interface{}{"foo"}, + }, + }, + }, + "[StringPromise]": { + Object: graphql.NewObject( + graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + "l": &graphql.Field{ + Type: graphql.NewList(graphql.String), + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return []graphql.ResolvePromise{graphql.ResolvePromise(stringChan), graphql.ResolvePromise(stringChan)}, nil + }, + }, + }, + }, + ), + RequestString: "{l}", + Expected: &graphql.Result{ + Data: map[string]interface{}{ + "l": []interface{}{"foo", "foo"}, + }, + }, + }, + "[SubObject]": { + Object: graphql.NewObject( + graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + "l": &graphql.Field{ + Type: graphql.NewList(subObjectType), + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return []string{"bar", "baz"}, nil + }, + }, + }, + }, + ), + RequestString: "{l{source, s}}", + Expected: &graphql.Result{ + Data: map[string]interface{}{ + "l": []interface{}{ + map[string]interface{}{ + "source": "bar", + "s": "foo", + }, + map[string]interface{}{ + "source": "baz", + "s": "foo", + }, + }, + }, + }, + }, + } { + t.Run(name, func(t *testing.T) { + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: tc.Object, + }) + if err != nil { + t.Fatalf("unexpected error, got: %v", err) + } + + result := graphql.Do(graphql.Params{ + Schema: schema, + RequestString: tc.RequestString, + IdleHandler: func() { + select { + case stringChan <- &graphql.ResolveResult{ + Value: "foo", + }: + default: + } + select { + case stringListChan <- &graphql.ResolveResult{ + Value: []string{"foo"}, + }: + default: + } + }, + }) + if !reflect.DeepEqual(tc.Expected, result) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(tc.Expected, result)) + } + }) + } +} + +func TestAsynchronousError(t *testing.T) { + ch := make(graphql.ResolvePromise, 1) + + var queryType = graphql.NewObject( + graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + "hello": &graphql.Field{ + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return graphql.ResolvePromise(ch), nil + }, + }, + }, + }) + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: queryType, + }) + if err != nil { + t.Fatalf("unexpected error, got: %v", err) + } + + result := graphql.Do(graphql.Params{ + Schema: schema, + RequestString: "{hello}", + IdleHandler: func() { + ch <- &graphql.ResolveResult{ + Error: errors.New("world"), + } + }, + }) + + expected := &graphql.Result{ + Data: map[string]interface{}{ + "hello": nil, + }, + Errors: []gqlerrors.FormattedError{gqlerrors.FormatError(errors.New("world"))}, + } + + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) + } +} + +func TestAsynchronousError_NonNull(t *testing.T) { + ch := make(graphql.ResolvePromise, 1) + + var queryType = graphql.NewObject( + graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + "hello": &graphql.Field{ + Type: graphql.NewObject( + graphql.ObjectConfig{ + Name: "SubObject", + Fields: graphql.Fields{ + "world": &graphql.Field{ + Type: graphql.NewNonNull(graphql.String), + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return graphql.ResolvePromise(ch), nil + }, + }, + }, + }, + ), + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return graphql.ResolvePromise(ch), nil + }, + }, + }, + }) + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: queryType, + }) + if err != nil { + t.Fatalf("unexpected error, got: %v", err) + } + + result := graphql.Do(graphql.Params{ + Schema: schema, + RequestString: "{hello{world}}", + IdleHandler: func() { + ch <- &graphql.ResolveResult{ + Error: errors.New("!"), + } + }, + }) + + expected := &graphql.Result{ + Data: map[string]interface{}{ + "hello": nil, + }, + Errors: []gqlerrors.FormattedError{gqlerrors.FormatError(errors.New("!"))}, + } + + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) + } +} diff --git a/graphql.go b/graphql.go index c9bdb168..5b5c8e29 100644 --- a/graphql.go +++ b/graphql.go @@ -28,6 +28,11 @@ type Params struct { // one operation. OperationName string + // IdleHandler is invoked when asynchronous resolution is used and no more work can be completed + // until asynchronous resolution progresses. At least some values must be resolved prior to this + // function's return. + IdleHandler func() + // Context may be provided to pass application-specific per-request // information to resolve functions. Context context.Context @@ -58,6 +63,7 @@ func Do(p Params) *Result { AST: AST, OperationName: p.OperationName, Args: p.VariableValues, + IdleHandler: p.IdleHandler, Context: p.Context, }) } diff --git a/promise/promise.go b/promise/promise.go new file mode 100644 index 00000000..a911c0d2 --- /dev/null +++ b/promise/promise.go @@ -0,0 +1,185 @@ +package promise + +import "reflect" + +type Promise struct { + isFullfilled bool + value interface{} + + isRejected bool + err error + + next []*Promise + dependencies []*Promise + + source func(resolve func(interface{}), reject func(error)) + + parent *Promise + onFulfilled func(value interface{}) interface{} + onRejected func(value error) interface{} +} + +// New returns a new Promise, which is very much like the JavaScript equivalent, but with one +// exception: the function given to New should not actually perform any work asynchronously (or if +// it does, it should be done transparently). The function will be invoked when the promise is +// scheduled. If the promise cannot be fulfilled yet, simply don't invoke resolve until the next +// time the function is called. +func New(f func(resolve func(interface{}), reject func(error))) *Promise { + return &Promise{ + source: f, + } +} + +type Thenable interface { + Then(f func(value interface{}) interface{}) *Promise +} + +// Then appends a handler to be promise, invoking it when the promise is fulfilled. If the handler +// returns a value, it'll be passed as input to the next handler in the chain. If the handler +// returns another promise, the next handler in the chain will receive that promise's value when it +// is fulfilled. +func (p *Promise) Then(onFulfilled func(value interface{}) interface{}) *Promise { + newPromise := &Promise{ + parent: p, + onFulfilled: onFulfilled, + } + p.next = append(p.next, newPromise) + return newPromise +} + +// Returns a Promise and deals with rejected cases only. The Promise returned by Catch is rejected +// if onRejected throws an error or returns a Promise which is itself rejected; otherwise, it is +// resolved. +func (p *Promise) Catch(onRejected func(err error) interface{}) *Promise { + newPromise := &Promise{ + parent: p, + onRejected: onRejected, + } + p.next = append(p.next, newPromise) + return newPromise +} + +// Schedule invokes pending functions for unfulfilled promises and returns true if any progress was +// made. +func (p *Promise) Schedule() (didProgress bool) { + for i := 0; ; i++ { + didProgress := false + for _, dependency := range p.dependencies { + if !dependency.isFullfilled { + if dependency.Schedule() { + didProgress = true + } + } + } + if !p.isFullfilled && !p.isRejected { + if p.source != nil { + p.source(func(value interface{}) { + p.isFullfilled = true + p.value = value + didProgress = true + }, func(err error) { + p.isRejected = true + p.err = err + didProgress = true + }) + } else if p.parent != nil { + if p.parent.isFullfilled || p.parent.isRejected { + if p.parent.isFullfilled { + p.isFullfilled = true + p.value = p.parent.value + if p.onFulfilled != nil { + p.value = p.onFulfilled(p.value) + } + } else { + if p.onRejected != nil { + p.isFullfilled = true + p.value = p.onRejected(p.parent.err) + } else { + p.isRejected = true + p.err = p.parent.err + } + } + didProgress = true + if promise, ok := p.value.(*Promise); ok { + for _, next := range p.next { + next.parent = promise + } + promise.next = append(promise.next, p.next...) + p.next = []*Promise{promise} + } + } else { + didProgress = p.parent.Schedule() + } + } + } + if p.isFullfilled { + for _, next := range p.next { + if next.Schedule() { + didProgress = true + } + } + } + if !didProgress { + return i > 0 + } + } +} + +// Returns a Promise object that is resolved with the given value. If the value is a Thenable (i.e. +// has a Then method), the returned promise will "follow" that thenable, adopting its eventual +// state; otherwise the returned promise will be fulfilled with the value. +func Resolve(value interface{}) *Promise { + if thenable, ok := value.(Thenable); ok { + return thenable.Then(func(value interface{}) interface{} { + return value + }) + } + return New(func(resolve func(interface{}), reject func(error)) { + resolve(value) + }) +} + +// Returns a Promise that is rejected with the given reason. +func Reject(reason error) *Promise { + return New(func(resolve func(interface{}), reject func(error)) { + reject(reason) + }) +} + +// All returns a single Promise that resolves when all of the promises in the argument have resolved +// or when the iterable argument contains no promises. It rejects with the reason of the first +// promise that rejects. +func All(iterable interface{}) *Promise { + v := reflect.ValueOf(iterable) + result := make([]interface{}, v.Len()) + var rejectReason error + remaining := 0 + all := New(func(resolve func(interface{}), reject func(error)) { + if rejectReason != nil { + reject(rejectReason) + } else if remaining == 0 { + resolve(result) + } + }) + for i := 0; i < v.Len(); i++ { + value := v.Index(i).Interface() + promise, ok := value.(*Promise) + if !ok { + result[i] = value + continue + } else if promise == nil { + continue + } + i := i + remaining++ + all.dependencies = append(all.dependencies, promise.Then(func(value interface{}) interface{} { + result[i] = value + remaining-- + return nil + }).Catch(func(err error) interface{} { + rejectReason = err + return nil + })) + } + return all +} diff --git a/promise/promise_test.go b/promise/promise_test.go new file mode 100644 index 00000000..d88064a5 --- /dev/null +++ b/promise/promise_test.go @@ -0,0 +1,144 @@ +package promise + +import ( + "errors" + "testing" +) + +func TestPromiseValueChaining(t *testing.T) { + n := 0 + Resolve(1).Then(func(v interface{}) interface{} { + n = v.(int) + if n != 1 { + t.Fatalf("expected 1, got %v", n) + } + return n + 1 + }).Then(func(v interface{}) interface{} { + n = v.(int) + if n != 2 { + t.Fatalf("expected 2, got %v", n) + } + return n + 1 + }).Then(func(v interface{}) interface{} { + n = v.(int) + if n != 3 { + t.Fatalf("expected 3, got %v", n) + } + return nil + }).Schedule() + if n != 3 { + t.Fatalf("expected 3, got %v", n) + } +} + +func TestCatch(t *testing.T) { + var val interface{} + var err error + Resolve(1).Then(func(interface{}) interface{} { + return Reject(errors.New("reject")) + }).Then(func(interface{}) interface{} { + return nil + }).Catch(func(caught error) interface{} { + err = caught + return "foo" + }).Then(func(value interface{}) interface{} { + val = value + return nil + }).Schedule() + if err == nil { + t.Fatalf("expected non-nil error") + } + if val != "foo" { + t.Fatalf("expected \"foo\", got %v", val) + } +} + +func TestAll(t *testing.T) { + p1 := Resolve(1) + var p2 *Promise + p3 := Resolve(3) + var result []interface{} + All([]interface{}{p1, p2, p3, 4}).Then(func(value interface{}) interface{} { + result = value.([]interface{}) + return nil + }).Schedule() + if len(result) != 4 { + t.Fatalf("expected 4 results, got %v", len(result)) + } + if n, _ := result[0].(int); n != 1 { + t.Fatalf("expected 1, got %v", n) + } + if result[1] != nil { + t.Fatalf("expected nil, got %v", result[1]) + } + if n, _ := result[2].(int); n != 3 { + t.Fatalf("expected 3, got %v", n) + } + if n, _ := result[3].(int); n != 4 { + t.Fatalf("expected 4, got %v", n) + } +} + +func TestAll_Reject(t *testing.T) { + p1 := Resolve(1) + p2 := Reject(errors.New("foo")) + p3 := Resolve(3) + var result []interface{} + var rejectReason error + All([]interface{}{p1, p2, p3, 4}).Then(func(value interface{}) interface{} { + result = value.([]interface{}) + return nil + }).Catch(func(err error) interface{} { + rejectReason = err + return nil + }).Schedule() + if result != nil { + t.Fatalf("expected nil result, got %v", result) + } + if rejectReason == nil { + t.Fatalf("expected non-nil reject reason") + } +} + +func TestAll_Schedule(t *testing.T) { + step := 0 + p1 := New(func(resolve func(interface{}), reject func(error)) { + if step > 1 { + resolve(1) + } + }) + p2 := New(func(resolve func(interface{}), reject func(error)) { + if step > 0 { + resolve(2) + } + }) + var result []interface{} + all := All([]interface{}{p1, p2}).Then(func(value interface{}) interface{} { + result = value.([]interface{}) + return nil + }) + + if all.Schedule() { + t.Fatalf("expected false") + } + + step++ + if !all.Schedule() { + t.Fatalf("expected true") + } + + step++ + if !all.Schedule() { + t.Fatalf("expected true") + } + + if len(result) != 2 { + t.Fatalf("expected 2 results, got %v", len(result)) + } + if n, _ := result[0].(int); n != 1 { + t.Fatalf("expected 1, got %v", n) + } + if n, _ := result[1].(int); n != 2 { + t.Fatalf("expected 2, got %v", n) + } +} diff --git a/values.go b/values.go index 8d0410b7..f07943a4 100644 --- a/values.go +++ b/values.go @@ -13,6 +13,7 @@ import ( "github.com/graphql-go/graphql/language/ast" "github.com/graphql-go/graphql/language/kinds" "github.com/graphql-go/graphql/language/printer" + "github.com/graphql-go/graphql/promise" ) // Prepares an object map of variableValues of the correct type based on the @@ -339,6 +340,12 @@ func isNullish(src interface{}) bool { return false } +// Returns true if src is a *promise.Promise +func isPromise(src interface{}) bool { + _, ok := src.(*promise.Promise) + return ok +} + // Returns true if src is a slice or an array func isIterable(src interface{}) bool { if src == nil {