diff --git a/docs/types/enums.rst b/docs/types/enums.rst index 02cc267c6..a3215cada 100644 --- a/docs/types/enums.rst +++ b/docs/types/enums.rst @@ -61,7 +61,8 @@ you can add description etc. to your enum without changing the original: graphene.Enum.from_enum( AlreadyExistingPyEnum, - description=lambda v: return 'foo' if v == AlreadyExistingPyEnum.Foo else 'bar') + description=lambda v: return 'foo' if v == AlreadyExistingPyEnum.Foo else 'bar' + ) Notes @@ -76,6 +77,7 @@ In the Python ``Enum`` implementation you can access a member by initing the Enu .. code:: python from enum import Enum + class Color(Enum): RED = 1 GREEN = 2 @@ -89,6 +91,7 @@ However, in Graphene ``Enum`` you need to call get to have the same effect: .. code:: python from graphene import Enum + class Color(Enum): RED = 1 GREEN = 2 diff --git a/graphene/relay/tests/test_node.py b/graphene/relay/tests/test_node.py index 92d851054..d46838acd 100644 --- a/graphene/relay/tests/test_node.py +++ b/graphene/relay/tests/test_node.py @@ -1,7 +1,7 @@ import re from graphql_relay import to_global_id -from graphql.pyutils import dedent +from graphene.tests.utils import dedent from ...types import ObjectType, Schema, String from ..node import Node, is_node diff --git a/graphene/relay/tests/test_node_custom.py b/graphene/relay/tests/test_node_custom.py index 30d62e7ba..76a2cad36 100644 --- a/graphene/relay/tests/test_node_custom.py +++ b/graphene/relay/tests/test_node_custom.py @@ -1,5 +1,6 @@ from graphql import graphql_sync -from graphql.pyutils import dedent + +from graphene.tests.utils import dedent from ...types import Interface, ObjectType, Schema from ...types.scalars import Int, String diff --git a/graphene/tests/utils.py b/graphene/tests/utils.py new file mode 100644 index 000000000..b9804d9be --- /dev/null +++ b/graphene/tests/utils.py @@ -0,0 +1,9 @@ +from textwrap import dedent as _dedent + + +def dedent(text: str) -> str: + """Fix indentation of given text by removing leading spaces and tabs. + Also removes leading newlines and trailing spaces and tabs, but keeps trailing + newlines. + """ + return _dedent(text.lstrip("\n").rstrip(" \t")) diff --git a/graphene/types/definitions.py b/graphene/types/definitions.py index 009169201..908cc7c86 100644 --- a/graphene/types/definitions.py +++ b/graphene/types/definitions.py @@ -1,3 +1,5 @@ +from enum import Enum as PyEnum + from graphql import ( GraphQLEnumType, GraphQLInputObjectType, @@ -5,6 +7,7 @@ GraphQLObjectType, GraphQLScalarType, GraphQLUnionType, + Undefined, ) @@ -36,7 +39,19 @@ class GrapheneScalarType(GrapheneGraphQLType, GraphQLScalarType): class GrapheneEnumType(GrapheneGraphQLType, GraphQLEnumType): - pass + def serialize(self, value): + if not isinstance(value, PyEnum): + enum = self.graphene_type._meta.enum + try: + # Try and get enum by value + value = enum(value) + except ValueError: + # Try and get enum by name + try: + value = enum[value] + except KeyError: + return Undefined + return super(GrapheneEnumType, self).serialize(value) class GrapheneInputObjectType(GrapheneGraphQLType, GraphQLInputObjectType): diff --git a/graphene/types/schema.py b/graphene/types/schema.py index 29ead4a70..ce0c74398 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -172,7 +172,7 @@ def create_enum(graphene_type): deprecation_reason = graphene_type._meta.deprecation_reason(value) values[name] = GraphQLEnumValue( - value=value.value, + value=value, description=description, deprecation_reason=deprecation_reason, ) diff --git a/graphene/types/tests/test_enum.py b/graphene/types/tests/test_enum.py index 1b6181208..8d5e87af4 100644 --- a/graphene/types/tests/test_enum.py +++ b/graphene/types/tests/test_enum.py @@ -1,7 +1,12 @@ +from textwrap import dedent + from ..argument import Argument from ..enum import Enum, PyEnum from ..field import Field from ..inputfield import InputField +from ..inputobjecttype import InputObjectType +from ..mutation import Mutation +from ..scalars import String from ..schema import ObjectType, Schema @@ -224,3 +229,245 @@ class Meta: "GREEN": RGB1.GREEN, "BLUE": RGB1.BLUE, } + + +def test_enum_types(): + from enum import Enum as PyEnum + + class Color(PyEnum): + """Primary colors""" + + RED = 1 + YELLOW = 2 + BLUE = 3 + + GColor = Enum.from_enum(Color) + + class Query(ObjectType): + color = GColor(required=True) + + def resolve_color(_, info): + return Color.RED + + schema = Schema(query=Query) + + assert str(schema) == dedent( + '''\ + type Query { + color: Color! + } + + """Primary colors""" + enum Color { + RED + YELLOW + BLUE + } + ''' + ) + + +def test_enum_resolver(): + from enum import Enum as PyEnum + + class Color(PyEnum): + RED = 1 + GREEN = 2 + BLUE = 3 + + GColor = Enum.from_enum(Color) + + class Query(ObjectType): + color = GColor(required=True) + + def resolve_color(_, info): + return Color.RED + + schema = Schema(query=Query) + + results = schema.execute("query { color }") + assert not results.errors + + assert results.data["color"] == Color.RED.name + + +def test_enum_resolver_compat(): + from enum import Enum as PyEnum + + class Color(PyEnum): + RED = 1 + GREEN = 2 + BLUE = 3 + + GColor = Enum.from_enum(Color) + + class Query(ObjectType): + color = GColor(required=True) + color_by_name = GColor(required=True) + + def resolve_color(_, info): + return Color.RED.value + + def resolve_color_by_name(_, info): + return Color.RED.name + + schema = Schema(query=Query) + + results = schema.execute( + """query { + color + colorByName + }""" + ) + assert not results.errors + + assert results.data["color"] == Color.RED.name + assert results.data["colorByName"] == Color.RED.name + + +def test_enum_resolver_invalid(): + from enum import Enum as PyEnum + + class Color(PyEnum): + RED = 1 + GREEN = 2 + BLUE = 3 + + GColor = Enum.from_enum(Color) + + class Query(ObjectType): + color = GColor(required=True) + + def resolve_color(_, info): + return "BLACK" + + schema = Schema(query=Query) + + results = schema.execute("query { color }") + assert results.errors + assert ( + results.errors[0].message + == "Expected a value of type 'Color' but received: 'BLACK'" + ) + + +def test_field_enum_argument(): + class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + class Brick(ObjectType): + color = Color(required=True) + + color_filter = None + + class Query(ObjectType): + bricks_by_color = Field(Brick, color=Color(required=True)) + + def resolve_bricks_by_color(_, info, color): + nonlocal color_filter + color_filter = color + return Brick(color=color) + + schema = Schema(query=Query) + + results = schema.execute( + """ + query { + bricksByColor(color: RED) { + color + } + } + """ + ) + assert not results.errors + assert results.data == {"bricksByColor": {"color": "RED"}} + assert color_filter == Color.RED + + +def test_mutation_enum_input(): + class RGB(Enum): + """Available colors""" + + RED = 1 + GREEN = 2 + BLUE = 3 + + color_input = None + + class CreatePaint(Mutation): + class Arguments: + color = RGB(required=True) + + color = RGB(required=True) + + def mutate(_, info, color): + nonlocal color_input + color_input = color + return CreatePaint(color=color) + + class MyMutation(ObjectType): + create_paint = CreatePaint.Field() + + class Query(ObjectType): + a = String() + + schema = Schema(query=Query, mutation=MyMutation) + result = schema.execute( + """ mutation MyMutation { + createPaint(color: RED) { + color + } + } + """ + ) + assert not result.errors + assert result.data == {"createPaint": {"color": "RED"}} + + assert color_input == RGB.RED + + +def test_mutation_enum_input_type(): + class RGB(Enum): + """Available colors""" + + RED = 1 + GREEN = 2 + BLUE = 3 + + class ColorInput(InputObjectType): + color = RGB(required=True) + + color_input_value = None + + class CreatePaint(Mutation): + class Arguments: + color_input = ColorInput(required=True) + + color = RGB(required=True) + + def mutate(_, info, color_input): + nonlocal color_input_value + color_input_value = color_input.color + return CreatePaint(color=color_input.color) + + class MyMutation(ObjectType): + create_paint = CreatePaint.Field() + + class Query(ObjectType): + a = String() + + schema = Schema(query=Query, mutation=MyMutation) + result = schema.execute( + """ mutation MyMutation { + createPaint(colorInput: { color: RED }) { + color + } + } + """, + ) + assert not result.errors + assert result.data == {"createPaint": {"color": "RED"}} + + assert color_input_value == RGB.RED diff --git a/graphene/types/tests/test_schema.py b/graphene/types/tests/test_schema.py index 0c85e1708..fe4739c98 100644 --- a/graphene/types/tests/test_schema.py +++ b/graphene/types/tests/test_schema.py @@ -1,7 +1,7 @@ +from graphql.type import GraphQLObjectType, GraphQLSchema from pytest import raises -from graphql.type import GraphQLObjectType, GraphQLSchema -from graphql.pyutils import dedent +from graphene.tests.utils import dedent from ..field import Field from ..objecttype import ObjectType diff --git a/graphene/types/utils.py b/graphene/types/utils.py index 3b195d692..1976448aa 100644 --- a/graphene/types/utils.py +++ b/graphene/types/utils.py @@ -41,3 +41,10 @@ def get_type(_type): if inspect.isfunction(_type) or isinstance(_type, partial): return _type() return _type + + +def get_underlying_type(_type): + """Get the underlying type even if it is wrapped in structures like NonNull""" + while hasattr(_type, "of_type"): + _type = _type.of_type + return _type diff --git a/setup.py b/setup.py index 24bddcf90..48d7d285d 100644 --- a/setup.py +++ b/setup.py @@ -82,7 +82,7 @@ def run_tests(self): keywords="api graphql protocol rest relay graphene", packages=find_packages(exclude=["examples*"]), install_requires=[ - "graphql-core>=3.1.1,<4", + "graphql-core>=3.1.2,<4", "graphql-relay>=3.0,<4", "aniso8601>=8,<9", ],