diff --git a/extensions.go b/extensions.go index 1c448fbf..01376541 100644 --- a/extensions.go +++ b/extensions.go @@ -27,6 +27,9 @@ type ( ResolveFieldFinishFunc func(interface{}, error) // resolveFieldFinishFuncHandler calls the resolveFieldFinishFns for all the extensions resolveFieldFinishFuncHandler func(interface{}, error) []gqlerrors.FormattedError + + // ExecuteFunc computes a GraphQL response. + ExecuteFunc func(p ExecuteParams) (result *Result) ) // Extension is an interface for extensions in graphql @@ -49,6 +52,11 @@ type Extension interface { // ResolveFieldDidStart notifies about the start of the resolving of a field ResolveFieldDidStart(context.Context, *ResolveInfo) (context.Context, ResolveFieldFinishFunc) + // ExecuteMiddleware allows the extension to wrap the execution using a chain-of-responsibility pattern, + // modeled after http.Handler. next must be called somewhere within the implementation to preserve the chain. The + // final call and only the final call should trigger ExecutionDidStart and ExecutionFinish hooks. + ExecuteMiddleware(next ExecuteFunc) ExecuteFunc + // HasResult returns if the extension wants to add data to the result HasResult() bool diff --git a/extensions_test.go b/extensions_test.go index ea23f752..731d2dba 100644 --- a/extensions_test.go +++ b/extensions_test.go @@ -247,6 +247,40 @@ func TestExtensionExecutionFinishFuncPanic(t *testing.T) { } } +func TestExecutionMiddlewareWasCalled(t *testing.T) { + ext := newtestExt("testExt") + var called bool + ext.executionFunc = func(next graphql.ExecuteFunc) graphql.ExecuteFunc { + return func(p graphql.ExecuteParams) *graphql.Result { + called = true + return next(p) + } + } + + schema := tinit(t) + query := `query Example { a }` + schema.AddExtensions(ext) + + result := graphql.Do(graphql.Params{ + Schema: schema, + RequestString: query, + }) + + expected := &graphql.Result{ + Data: map[string]interface{}{ + "a": "foo", + }, + } + + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) + } + + if !reflect.DeepEqual(called, true) { + t.Fatalf("Middleware was not called") + } +} + func TestExtensionResolveFieldDidStartPanic(t *testing.T) { ext := newtestExt("testExt") ext.resolveFieldDidStartFn = func(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc) { @@ -424,6 +458,11 @@ func newtestExt(name string) *testExt { return nil } } + if ext.executionFunc == nil { + ext.executionFunc = func(next graphql.ExecuteFunc) graphql.ExecuteFunc { + return next + } + } return ext } @@ -436,6 +475,7 @@ type testExt struct { validationDidStartFn func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) executionDidStartFn func(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) resolveFieldDidStartFn func(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc) + executionFunc func(next graphql.ExecuteFunc) graphql.ExecuteFunc } func (t *testExt) Init(ctx context.Context, p *graphql.Params) context.Context { @@ -469,3 +509,7 @@ func (t *testExt) ExecutionDidStart(ctx context.Context) (context.Context, graph func (t *testExt) ResolveFieldDidStart(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc) { return t.resolveFieldDidStartFn(ctx, i) } + +func (t *testExt) ExecuteMiddleware(next graphql.ExecuteFunc) graphql.ExecuteFunc { + return t.executionFunc(next) +} diff --git a/graphql.go b/graphql.go index 2b1f6a29..5c620a54 100644 --- a/graphql.go +++ b/graphql.go @@ -105,7 +105,13 @@ func Do(p Params) *Result { } } - return Execute(ExecuteParams{ + var exe ExecuteFunc + exe = Execute + for _, e := range p.Schema.extensions { + exe = e.ExecuteMiddleware(exe) + } + + return exe(ExecuteParams{ Schema: p.Schema, Root: p.RootObject, AST: AST,