diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index f4a0f5a0e..c4d969b49 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -1,3 +1,5 @@ +import inspect + from .base import BaseOptions, BaseType, BaseTypeMeta from .field import Field from .interface import Interface @@ -137,7 +139,7 @@ def __init_subclass_with_meta__( fields = {} for interface in interfaces: - assert issubclass( + assert inspect.isclass(interface) and issubclass( interface, Interface ), f'All interfaces of {cls.__name__} must be a subclass of Interface. Received "{interface}".' fields.update(interface._meta.fields) diff --git a/graphene/types/schema.py b/graphene/types/schema.py index ce0c74398..f40975b6e 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -23,6 +23,7 @@ GraphQLObjectType, GraphQLSchema, GraphQLString, + GraphQLType, Undefined, ) @@ -106,6 +107,11 @@ def __init__( def add_type(self, graphene_type): if inspect.isfunction(graphene_type): graphene_type = graphene_type() + + # If type is a GraphQLType from graphql-core then return it immediately + if isinstance(graphene_type, GraphQLType): + return graphene_type + if isinstance(graphene_type, List): return GraphQLList(self.add_type(graphene_type.of_type)) if isinstance(graphene_type, NonNull): @@ -252,7 +258,8 @@ def types(): union_types = [] for graphene_objecttype in graphene_type._meta.types: object_type = create_graphql_type(graphene_objecttype) - assert object_type.graphene_type == graphene_objecttype + if hasattr(object_type, "graphene_type"): + assert object_type.graphene_type == graphene_objecttype union_types.append(object_type) return union_types diff --git a/graphene/types/tests/test_type_map.py b/graphene/types/tests/test_type_map.py index 334eb2415..c26c5c0af 100644 --- a/graphene/types/tests/test_type_map.py +++ b/graphene/types/tests/test_type_map.py @@ -1,3 +1,7 @@ +from textwrap import dedent + +import pytest +from graphql import parse, build_ast_schema from graphql.type import ( GraphQLArgument, GraphQLEnumType, @@ -20,6 +24,7 @@ from ..scalars import Int, String from ..structures import List, NonNull from ..schema import Schema +from ..union import Union def create_type_map(types, auto_camelcase=True): @@ -270,3 +275,157 @@ class Meta: assert graphql_type.is_type_of assert graphql_type.is_type_of({}, None) is True assert graphql_type.is_type_of(MyObjectType(), None) is False + + +def test_graphql_type(): + """Type map should allow direct GraphQL types""" + MyGraphQLType = GraphQLObjectType( + name="MyGraphQLType", + fields={ + "hello": GraphQLField(GraphQLString, resolve=lambda obj, info: "world") + }, + ) + + class Query(ObjectType): + graphql_type = Field(MyGraphQLType) + + def resolve_graphql_type(root, info): + return {} + + schema = Schema(query=Query) + assert str(schema) == dedent( + """\ + type Query { + graphqlType: MyGraphQLType + } + + type MyGraphQLType { + hello: String + } + """ + ) + + results = schema.execute( + """ + query { + graphqlType { + hello + } + } + """ + ) + assert not results.errors + assert results.data == {"graphqlType": {"hello": "world"}} + + +def test_graphql_type_interface(): + MyGraphQLInterface = GraphQLInterfaceType( + name="MyGraphQLType", + fields={ + "hello": GraphQLField(GraphQLString, resolve=lambda obj, info: "world") + }, + ) + + with pytest.raises(AssertionError) as error: + + class MyGrapheneType(ObjectType): + class Meta: + interfaces = (MyGraphQLInterface,) + + assert str(error.value) == ( + "All interfaces of MyGrapheneType must be a subclass of Interface. " + 'Received "MyGraphQLType".' + ) + + +def test_graphql_type_union(): + MyGraphQLType = GraphQLObjectType( + name="MyGraphQLType", + fields={ + "hello": GraphQLField(GraphQLString, resolve=lambda obj, info: "world") + }, + ) + + class MyGrapheneType(ObjectType): + hi = String(default_value="world") + + class MyUnion(Union): + class Meta: + types = (MyGraphQLType, MyGrapheneType) + + @classmethod + def resolve_type(cls, instance, info): + return MyGraphQLType + + class Query(ObjectType): + my_union = Field(MyUnion) + + def resolve_my_union(root, info): + return {} + + schema = Schema(query=Query) + assert str(schema) == dedent( + """\ + type Query { + myUnion: MyUnion + } + + union MyUnion = MyGraphQLType | MyGrapheneType + + type MyGraphQLType { + hello: String + } + + type MyGrapheneType { + hi: String + } + """ + ) + + results = schema.execute( + """ + query { + myUnion { + __typename + } + } + """ + ) + assert not results.errors + assert results.data == {"myUnion": {"__typename": "MyGraphQLType"}} + + +def test_graphql_type_from_sdl(): + types = """ + type Pet { + name: String! + } + + type User { + name: String! + pets: [Pet!]! + } + """ + ast_document = parse(types) + sdl_schema = build_ast_schema(ast_document) + + class Query(ObjectType): + my_user = Field(sdl_schema.get_type("User")) + + schema = Schema(query=Query) + assert str(schema) == dedent( + """\ + type Query { + myUser: User + } + + type User { + name: String! + pets: [Pet!]! + } + + type Pet { + name: String! + } + """ + )