Skip to content

Commit 90cfb09

Browse files
authored
Split parsing, validation and execution (#43) (#53)
Instead of graphql()/graphql_sync() we now call execute() directly. This also allows adding custom validation rules and limiting the number of reported errors.
1 parent db23e62 commit 90cfb09

File tree

2 files changed

+111
-34
lines changed

2 files changed

+111
-34
lines changed

graphql_server/__init__.py

+48-34
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@
99
import json
1010
from collections import namedtuple
1111
from collections.abc import MutableMapping
12-
from typing import Any, Callable, Dict, List, Optional, Type, Union
12+
from typing import Any, Callable, Collection, Dict, List, Optional, Type, Union
1313

14-
from graphql import ExecutionResult, GraphQLError, GraphQLSchema, OperationType
15-
from graphql import format_error as format_error_default
16-
from graphql import get_operation_ast, parse
17-
from graphql.graphql import graphql, graphql_sync
14+
from graphql.error import GraphQLError
15+
from graphql.error import format_error as format_error_default
16+
from graphql.execution import ExecutionResult, execute
17+
from graphql.language import OperationType, parse
1818
from graphql.pyutils import AwaitableOrValue
19+
from graphql.type import GraphQLSchema, validate_schema
20+
from graphql.utilities import get_operation_ast
21+
from graphql.validation import ASTValidationRule, validate
1922

2023
from .error import HttpQueryError
2124
from .version import version, version_info
@@ -223,36 +226,48 @@ def load_json_variables(variables: Optional[Union[str, Dict]]) -> Optional[Dict]
223226
return variables # type: ignore
224227

225228

229+
def assume_not_awaitable(_value: Any) -> bool:
230+
"""Replacement for isawaitable if everything is assumed to be synchronous."""
231+
return False
232+
233+
226234
def get_response(
227235
schema: GraphQLSchema,
228236
params: GraphQLParams,
229237
catch_exc: Type[BaseException],
230238
allow_only_query: bool = False,
231239
run_sync: bool = True,
240+
validation_rules: Optional[Collection[Type[ASTValidationRule]]] = None,
241+
max_errors: Optional[int] = None,
232242
**kwargs,
233243
) -> Optional[AwaitableOrValue[ExecutionResult]]:
234244
"""Get an individual execution result as response, with option to catch errors.
235245
236-
This does the same as graphql_impl() except that you can either
237-
throw an error on the ExecutionResult if allow_only_query is set to True
238-
or catch errors that belong to an exception class that you need to pass
239-
as a parameter.
246+
This will validate the schema (if the schema is used for the first time),
247+
parse the query, check if this is a query if allow_only_query is set to True,
248+
validate the query (optionally with additional validation rules and limiting
249+
the number of errors), execute the request (asynchronously if run_sync is not
250+
set to True), and return the ExecutionResult. You can also catch all errors that
251+
belong to an exception class specified by catch_exc.
240252
"""
241-
242253
# noinspection PyBroadException
243254
try:
244255
if not params.query:
245256
raise HttpQueryError(400, "Must provide query string.")
246257

258+
schema_validation_errors = validate_schema(schema)
259+
if schema_validation_errors:
260+
return ExecutionResult(data=None, errors=schema_validation_errors)
261+
262+
try:
263+
document = parse(params.query)
264+
except GraphQLError as e:
265+
return ExecutionResult(data=None, errors=[e])
266+
except Exception as e:
267+
e = GraphQLError(str(e), original_error=e)
268+
return ExecutionResult(data=None, errors=[e])
269+
247270
if allow_only_query:
248-
# Parse document to check that only query operations are used
249-
try:
250-
document = parse(params.query)
251-
except GraphQLError as e:
252-
return ExecutionResult(data=None, errors=[e])
253-
except Exception as e:
254-
e = GraphQLError(str(e), original_error=e)
255-
return ExecutionResult(data=None, errors=[e])
256271
operation_ast = get_operation_ast(document, params.operation_name)
257272
if operation_ast:
258273
operation = operation_ast.operation.value
@@ -264,22 +279,21 @@ def get_response(
264279
headers={"Allow": "POST"},
265280
)
266281

267-
if run_sync:
268-
execution_result = graphql_sync(
269-
schema=schema,
270-
source=params.query,
271-
variable_values=params.variables,
272-
operation_name=params.operation_name,
273-
**kwargs,
274-
)
275-
else:
276-
execution_result = graphql( # type: ignore
277-
schema=schema,
278-
source=params.query,
279-
variable_values=params.variables,
280-
operation_name=params.operation_name,
281-
**kwargs,
282-
)
282+
validation_errors = validate(
283+
schema, document, rules=validation_rules, max_errors=max_errors
284+
)
285+
if validation_errors:
286+
return ExecutionResult(data=None, errors=validation_errors)
287+
288+
execution_result = execute(
289+
schema,
290+
document,
291+
variable_values=params.variables,
292+
operation_name=params.operation_name,
293+
is_awaitable=assume_not_awaitable if run_sync else None,
294+
**kwargs,
295+
)
296+
283297
except catch_exc:
284298
return None
285299

tests/test_query.py

+63
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from graphql.error import GraphQLError
44
from graphql.execution import ExecutionResult
5+
from graphql.validation import ValidationRule
56
from pytest import raises
67

78
from graphql_server import (
@@ -123,6 +124,68 @@ def test_reports_validation_errors():
123124
assert response.status_code == 400
124125

125126

127+
def test_reports_custom_validation_errors():
128+
class CustomValidationRule(ValidationRule):
129+
def enter_field(self, node, *_args):
130+
self.report_error(GraphQLError("Custom validation error.", node))
131+
132+
results, params = run_http_query(
133+
schema,
134+
"get",
135+
{},
136+
query_data=dict(query="{ test }"),
137+
validation_rules=[CustomValidationRule],
138+
)
139+
140+
assert as_dicts(results) == [
141+
{
142+
"data": None,
143+
"errors": [
144+
{
145+
"message": "Custom validation error.",
146+
"locations": [{"line": 1, "column": 3}],
147+
"path": None,
148+
}
149+
],
150+
}
151+
]
152+
153+
response = encode_execution_results(results)
154+
assert response.status_code == 400
155+
156+
157+
def test_reports_max_num_of_validation_errors():
158+
results, params = run_http_query(
159+
schema,
160+
"get",
161+
{},
162+
query_data=dict(query="{ test, unknownOne, unknownTwo }"),
163+
max_errors=1,
164+
)
165+
166+
assert as_dicts(results) == [
167+
{
168+
"data": None,
169+
"errors": [
170+
{
171+
"message": "Cannot query field 'unknownOne' on type 'QueryRoot'.",
172+
"locations": [{"line": 1, "column": 9}],
173+
"path": None,
174+
},
175+
{
176+
"message": "Too many validation errors, error limit reached."
177+
" Validation aborted.",
178+
"locations": None,
179+
"path": None,
180+
},
181+
],
182+
}
183+
]
184+
185+
response = encode_execution_results(results)
186+
assert response.status_code == 400
187+
188+
126189
def test_non_dict_params_in_non_batch_query():
127190
with raises(HttpQueryError) as exc_info:
128191
# noinspection PyTypeChecker

0 commit comments

Comments
 (0)