diff --git a/definition.go b/definition.go index 1d226fcc..19d6bcf8 100644 --- a/definition.go +++ b/definition.go @@ -7,6 +7,7 @@ import ( "regexp" "github.com/graphql-go/graphql/language/ast" + "golang.org/x/net/context" ) // These are all of the possible kinds of @@ -552,6 +553,11 @@ type ResolveInfo struct { RootValue interface{} Operation ast.Definition VariableValues map[string]interface{} + + // Context is passed through to resolve functions from either Params.Context + // or ExecuteParams.Context. This can be used to provide per-request state + // from the application. + Context context.Context } type Fields map[string]*Field diff --git a/executor.go b/executor.go index 9be274c1..bc3bc552 100644 --- a/executor.go +++ b/executor.go @@ -8,6 +8,7 @@ import ( "github.com/graphql-go/graphql/gqlerrors" "github.com/graphql-go/graphql/language/ast" + "golang.org/x/net/context" ) type ExecuteParams struct { @@ -16,6 +17,10 @@ type ExecuteParams struct { AST *ast.Document OperationName string Args map[string]interface{} + + // Context may be provided to pass application-specific per-request + // information to resolve functions. + Context context.Context } func Execute(p ExecuteParams) (result *Result) { @@ -29,6 +34,7 @@ func Execute(p ExecuteParams) (result *Result) { Args: p.Args, Errors: nil, Result: result, + Context: p.Context, }) if err != nil { @@ -62,6 +68,7 @@ type BuildExecutionCtxParams struct { Args map[string]interface{} Errors []gqlerrors.FormattedError Result *Result + Context context.Context } type ExecutionContext struct { Schema Schema @@ -70,6 +77,7 @@ type ExecutionContext struct { Operation ast.Definition VariableValues map[string]interface{} Errors []gqlerrors.FormattedError + Context context.Context } func buildExecutionContext(p BuildExecutionCtxParams) (*ExecutionContext, error) { @@ -124,6 +132,7 @@ func buildExecutionContext(p BuildExecutionCtxParams) (*ExecutionContext, error) eCtx.Operation = operation eCtx.VariableValues = variableValues eCtx.Errors = p.Errors + eCtx.Context = p.Context return eCtx, nil } @@ -492,6 +501,7 @@ func resolveField(eCtx *ExecutionContext, parentType *Object, source interface{} RootValue: eCtx.Root, Operation: eCtx.Operation, VariableValues: eCtx.VariableValues, + Context: eCtx.Context, } // TODO: If an error occurs while calling the field `resolve` function, ensure that diff --git a/executor_test.go b/executor_test.go index e643f31e..3e7779f5 100644 --- a/executor_test.go +++ b/executor_test.go @@ -11,6 +11,7 @@ import ( "github.com/graphql-go/graphql/gqlerrors" "github.com/graphql-go/graphql/language/location" "github.com/graphql-go/graphql/testutil" + "golang.org/x/net/context" ) func TestExecutesArbitraryCode(t *testing.T) { @@ -295,17 +296,17 @@ func TestMergesParallelFragments(t *testing.T) { } } -func TestThreadsContextCorrectly(t *testing.T) { +func TestThreadsSourceCorrectly(t *testing.T) { query := ` query Example { a } ` data := map[string]interface{}{ - "contextThing": "thing", + "key": "value", } - var resolvedContext map[string]interface{} + var resolvedSource map[string]interface{} schema, err := graphql.NewSchema(graphql.SchemaConfig{ Query: graphql.NewObject(graphql.ObjectConfig{ @@ -314,8 +315,8 @@ func TestThreadsContextCorrectly(t *testing.T) { "a": &graphql.Field{ Type: graphql.String, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - resolvedContext = p.Source.(map[string]interface{}) - return resolvedContext, nil + resolvedSource = p.Source.(map[string]interface{}) + return resolvedSource, nil }, }, }, @@ -339,9 +340,9 @@ func TestThreadsContextCorrectly(t *testing.T) { t.Fatalf("wrong result, unexpected errors: %v", result.Errors) } - expected := "thing" - if resolvedContext["contextThing"] != expected { - t.Fatalf("Expected context.contextThing to equal %v, got %v", expected, resolvedContext["contextThing"]) + expected := "value" + if resolvedSource["key"] != expected { + t.Fatalf("Expected context.key to equal %v, got %v", expected, resolvedSource["key"]) } } @@ -404,6 +405,53 @@ func TestCorrectlyThreadsArguments(t *testing.T) { } } +func TestThreadsContextCorrectly(t *testing.T) { + + query := ` + query Example { a } + ` + + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: graphql.NewObject(graphql.ObjectConfig{ + Name: "Type", + Fields: graphql.Fields{ + "a": &graphql.Field{ + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return p.Info.Context.Value("foo"), nil + }, + }, + }, + }), + }) + if err != nil { + t.Fatalf("Error in schema %v", err.Error()) + } + + // parse query + ast := testutil.TestParse(t, query) + + // execute + ep := graphql.ExecuteParams{ + Schema: schema, + AST: ast, + Context: context.WithValue(context.Background(), "foo", "bar"), + } + result := testutil.TestExecute(t, ep) + if len(result.Errors) > 0 { + t.Fatalf("wrong result, unexpected errors: %v", result.Errors) + } + + expected := &graphql.Result{ + Data: map[string]interface{}{ + "a": "bar", + }, + } + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) + } +} + func TestNullsOutErrorSubtrees(t *testing.T) { // TODO: TestNullsOutErrorSubtrees test for go-routines if implemented diff --git a/graphql.go b/graphql.go index 0621b6b5..db6b86ab 100644 --- a/graphql.go +++ b/graphql.go @@ -4,6 +4,7 @@ import ( "github.com/graphql-go/graphql/gqlerrors" "github.com/graphql-go/graphql/language/parser" "github.com/graphql-go/graphql/language/source" + "golang.org/x/net/context" ) type Params struct { @@ -12,6 +13,10 @@ type Params struct { RootObject map[string]interface{} VariableValues map[string]interface{} OperationName string + + // Context may be provided to pass application-specific per-request + // information to resolve functions. + Context context.Context } func Do(p Params) *Result { @@ -39,5 +44,6 @@ func Do(p Params) *Result { AST: AST, OperationName: p.OperationName, Args: p.VariableValues, + Context: p.Context, }) } diff --git a/graphql_test.go b/graphql_test.go index 07c3e508..85a4860e 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -6,6 +6,7 @@ import ( "github.com/graphql-go/graphql" "github.com/graphql-go/graphql/testutil" + "golang.org/x/net/context" ) type T struct { @@ -131,3 +132,42 @@ func TestBasicGraphQLExample(t *testing.T) { } } + +func TestThreadsContextFromParamsThrough(t *testing.T) { + extractFieldFromContextFn := func(p graphql.ResolveParams) (interface{}, error) { + return p.Info.Context.Value(p.Args["key"]), nil + } + + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: graphql.NewObject(graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + "value": &graphql.Field{ + Type: graphql.String, + Args: graphql.FieldConfigArgument{ + "key": &graphql.ArgumentConfig{Type: graphql.String}, + }, + Resolve: extractFieldFromContextFn, + }, + }, + }), + }) + if err != nil { + t.Fatalf("wrong result, unexpected errors: %v", err.Error()) + } + query := `{ value(key:"a") }` + + result := graphql.Do(graphql.Params{ + Schema: schema, + RequestString: query, + Context: context.WithValue(context.TODO(), "a", "xyz"), + }) + if len(result.Errors) > 0 { + t.Fatalf("wrong result, unexpected errors: %v", result.Errors) + } + expected := map[string]interface{}{"value": "xyz"} + if !reflect.DeepEqual(result.Data, expected) { + t.Fatalf("wrong result, query: %v, graphql result diff: %v", query, testutil.Diff(expected, result)) + } + +}