Skip to content

Commit e00c129

Browse files
committed
Add subscribe_field_resolver to execution context
Replicates graphql/graphql-js@6aee19b
1 parent 5bfb52a commit e00c129

File tree

4 files changed

+41
-3
lines changed

4 files changed

+41
-3
lines changed

src/graphql/execution/execute.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ class ExecutionContext:
187187
variable_values: Dict[str, Any]
188188
field_resolver: GraphQLFieldResolver
189189
type_resolver: GraphQLTypeResolver
190+
subscribe_field_resolver: GraphQLFieldResolver
190191
errors: List[GraphQLError]
191192
middleware_manager: Optional[MiddlewareManager]
192193

@@ -202,6 +203,7 @@ def __init__(
202203
variable_values: Dict[str, Any],
203204
field_resolver: GraphQLFieldResolver,
204205
type_resolver: GraphQLTypeResolver,
206+
subscribe_field_resolver: GraphQLFieldResolver,
205207
errors: List[GraphQLError],
206208
middleware_manager: Optional[MiddlewareManager],
207209
is_awaitable: Optional[Callable[[Any], bool]],
@@ -214,6 +216,7 @@ def __init__(
214216
self.variable_values = variable_values
215217
self.field_resolver = field_resolver # type: ignore
216218
self.type_resolver = type_resolver # type: ignore
219+
self.subscribe_field_resolver = subscribe_field_resolver # type: ignore
217220
self.errors = errors
218221
self.middleware_manager = middleware_manager
219222
if is_awaitable:
@@ -231,6 +234,7 @@ def build(
231234
operation_name: Optional[str] = None,
232235
field_resolver: Optional[GraphQLFieldResolver] = None,
233236
type_resolver: Optional[GraphQLTypeResolver] = None,
237+
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
234238
middleware: Optional[Middleware] = None,
235239
is_awaitable: Optional[Callable[[Any], bool]] = None,
236240
) -> Union[List[GraphQLError], "ExecutionContext"]:
@@ -298,6 +302,7 @@ def build(
298302
coerced_variable_values, # coerced values
299303
field_resolver or default_field_resolver,
300304
type_resolver or default_type_resolver,
305+
subscribe_field_resolver or default_field_resolver,
301306
[],
302307
middleware_manager,
303308
is_awaitable,
@@ -978,6 +983,7 @@ def execute(
978983
operation_name: Optional[str] = None,
979984
field_resolver: Optional[GraphQLFieldResolver] = None,
980985
type_resolver: Optional[GraphQLTypeResolver] = None,
986+
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
981987
middleware: Optional[Middleware] = None,
982988
execution_context_class: Optional[Type["ExecutionContext"]] = None,
983989
is_awaitable: Optional[Callable[[Any], bool]] = None,
@@ -1009,6 +1015,7 @@ def execute(
10091015
operation_name,
10101016
field_resolver,
10111017
type_resolver,
1018+
subscribe_field_resolver,
10121019
middleware,
10131020
is_awaitable,
10141021
)
@@ -1071,6 +1078,7 @@ def execute_sync(
10711078
operation_name,
10721079
field_resolver,
10731080
type_resolver,
1081+
None,
10741082
middleware,
10751083
execution_context_class,
10761084
is_awaitable,

src/graphql/execution/subscribe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ async def create_source_event_stream(
100100
context_value: Any = None,
101101
variable_values: Optional[Dict[str, Any]] = None,
102102
operation_name: Optional[str] = None,
103-
field_resolver: Optional[GraphQLFieldResolver] = None,
103+
subscribe_field_resolver: Optional[GraphQLFieldResolver] = None,
104104
) -> Union[AsyncIterable[Any], ExecutionResult]:
105105
"""Create source event stream
106106
@@ -138,7 +138,7 @@ async def create_source_event_stream(
138138
context_value,
139139
variable_values,
140140
operation_name,
141-
field_resolver,
141+
subscribe_field_resolver=subscribe_field_resolver,
142142
)
143143

144144
# Return early errors if execution context failed.
@@ -193,7 +193,7 @@ async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
193193

194194
# Call the `subscribe()` resolver or the default resolver to produce an
195195
# AsyncIterable yielding raw payloads.
196-
resolve_fn = field_def.subscribe or context.field_resolver
196+
resolve_fn = field_def.subscribe or context.subscribe_field_resolver
197197

198198
event_stream = resolve_fn(context.root_value, info, **args)
199199
if context.is_awaitable(event_stream):

src/graphql/graphql.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def graphql_impl(
191191
operation_name,
192192
field_resolver,
193193
type_resolver,
194+
None,
194195
middleware,
195196
execution_context_class,
196197
is_awaitable,

tests/execution/test_subscribe.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,35 @@ async def foo_generator(_obj, _info):
209209

210210
await subscription.aclose()
211211

212+
@mark.asyncio
213+
async def uses_a_custom_default_subscribe_field_resolver():
214+
schema = GraphQLSchema(
215+
query=DummyQueryType,
216+
subscription=GraphQLObjectType(
217+
"Subscription", {"foo": GraphQLField(GraphQLString)}
218+
),
219+
)
220+
221+
class Root:
222+
@staticmethod
223+
async def custom_foo():
224+
yield {"foo": "FooValue"}
225+
226+
subscription = await subscribe(
227+
schema,
228+
document=parse("subscription { foo }"),
229+
root_value=Root(),
230+
subscribe_field_resolver=lambda root, _info: root.custom_foo(),
231+
)
232+
assert isinstance(subscription, MapAsyncIterator)
233+
234+
assert await anext(subscription) == (
235+
{"foo": "FooValue"},
236+
None,
237+
)
238+
239+
await subscription.aclose()
240+
212241
@mark.asyncio
213242
async def should_only_resolve_the_first_field_of_invalid_multi_field():
214243
did_resolve = {"foo": False, "bar": False}

0 commit comments

Comments
 (0)