diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 60e14ddd..1e7846eb 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -1,16 +1,16 @@ import datetime +import sys import typing import warnings from decimal import Decimal from functools import singledispatch -from typing import Any +from typing import Any, cast -from sqlalchemy import types +from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql from sqlalchemy.orm import interfaces, strategies -from graphene import (ID, Boolean, Date, DateTime, Dynamic, Enum, Field, Float, - Int, List, String, Time) +import graphene from graphene.types.json import JSONString from .batching import get_batch_resolver @@ -19,8 +19,9 @@ default_connection_field_factory) from .registry import get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver -from .utils import (registry_sqlalchemy_model_from_str, safe_isinstance, - singledispatchbymatchfunction, value_equals) +from .utils import (DummyImport, registry_sqlalchemy_model_from_str, + safe_isinstance, singledispatchbymatchfunction, + value_equals) try: from typing import ForwardRef @@ -29,15 +30,14 @@ from typing import _ForwardRef as ForwardRef try: - from sqlalchemy_utils import (ChoiceType, JSONType, ScalarListType, - TSVectorType) + from sqlalchemy_utils.types.choice import EnumTypeImpl except ImportError: - ChoiceType = JSONType = ScalarListType = TSVectorType = object + EnumTypeImpl = object try: - from sqlalchemy_utils.types.choice import EnumTypeImpl + import sqlalchemy_utils as sqa_utils except ImportError: - EnumTypeImpl = object + sqa_utils = DummyImport() is_selectin_available = getattr(strategies, 'SelectInLoader', None) @@ -79,7 +79,7 @@ def dynamic_type(): return _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching_, connection_field_factory, **field_kwargs) - return Dynamic(dynamic_type) + return graphene.Dynamic(dynamic_type) def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_field_name, **field_kwargs): @@ -100,7 +100,7 @@ def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_ resolver = get_batch_resolver(relationship_prop) if batching else \ get_attr_resolver(obj_type, relationship_prop.key) - return Field(child_type, resolver=resolver, **field_kwargs) + return graphene.Field(child_type, resolver=resolver, **field_kwargs) def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs): @@ -117,7 +117,7 @@ def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, conn child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) if not child_type._meta.connection: - return Field(List(child_type), **field_kwargs) + return graphene.Field(graphene.List(child_type), **field_kwargs) # TODO Allow override of connection_field_factory and resolver via ORMField if connection_field_factory is None: @@ -134,7 +134,7 @@ def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs): if 'description' not in field_kwargs: field_kwargs['description'] = getattr(hybrid_prop, "__doc__", None) - return Field( + return graphene.Field( resolver=resolver, **field_kwargs ) @@ -181,7 +181,7 @@ def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): field_kwargs.setdefault('required', not is_column_nullable(column)) field_kwargs.setdefault('description', get_column_doc(column)) - return Field( + return graphene.Field( resolver=resolver, **field_kwargs ) @@ -195,75 +195,90 @@ def convert_sqlalchemy_type(type, column, registry=None): ) -@convert_sqlalchemy_type.register(types.Date) -@convert_sqlalchemy_type.register(types.Time) -@convert_sqlalchemy_type.register(types.String) -@convert_sqlalchemy_type.register(types.Text) -@convert_sqlalchemy_type.register(types.Unicode) -@convert_sqlalchemy_type.register(types.UnicodeText) -@convert_sqlalchemy_type.register(postgresql.UUID) +@convert_sqlalchemy_type.register(sqa_types.String) +@convert_sqlalchemy_type.register(sqa_types.Text) +@convert_sqlalchemy_type.register(sqa_types.Unicode) +@convert_sqlalchemy_type.register(sqa_types.UnicodeText) @convert_sqlalchemy_type.register(postgresql.INET) @convert_sqlalchemy_type.register(postgresql.CIDR) -@convert_sqlalchemy_type.register(TSVectorType) +@convert_sqlalchemy_type.register(sqa_utils.TSVectorType) +@convert_sqlalchemy_type.register(sqa_utils.EmailType) +@convert_sqlalchemy_type.register(sqa_utils.URLType) +@convert_sqlalchemy_type.register(sqa_utils.IPAddressType) def convert_column_to_string(type, column, registry=None): - return String + return graphene.String + + +@convert_sqlalchemy_type.register(postgresql.UUID) +@convert_sqlalchemy_type.register(sqa_utils.UUIDType) +def convert_column_to_uuid(type, column, registry=None): + return graphene.UUID -@convert_sqlalchemy_type.register(types.DateTime) +@convert_sqlalchemy_type.register(sqa_types.DateTime) def convert_column_to_datetime(type, column, registry=None): - from graphene.types.datetime import DateTime - return DateTime + return graphene.DateTime -@convert_sqlalchemy_type.register(types.SmallInteger) -@convert_sqlalchemy_type.register(types.Integer) +@convert_sqlalchemy_type.register(sqa_types.Time) +def convert_column_to_time(type, column, registry=None): + return graphene.Time + + +@convert_sqlalchemy_type.register(sqa_types.Date) +def convert_column_to_date(type, column, registry=None): + return graphene.Date + + +@convert_sqlalchemy_type.register(sqa_types.SmallInteger) +@convert_sqlalchemy_type.register(sqa_types.Integer) def convert_column_to_int_or_id(type, column, registry=None): - return ID if column.primary_key else Int + return graphene.ID if column.primary_key else graphene.Int -@convert_sqlalchemy_type.register(types.Boolean) +@convert_sqlalchemy_type.register(sqa_types.Boolean) def convert_column_to_boolean(type, column, registry=None): - return Boolean + return graphene.Boolean -@convert_sqlalchemy_type.register(types.Float) -@convert_sqlalchemy_type.register(types.Numeric) -@convert_sqlalchemy_type.register(types.BigInteger) +@convert_sqlalchemy_type.register(sqa_types.Float) +@convert_sqlalchemy_type.register(sqa_types.Numeric) +@convert_sqlalchemy_type.register(sqa_types.BigInteger) def convert_column_to_float(type, column, registry=None): - return Float + return graphene.Float -@convert_sqlalchemy_type.register(types.Enum) +@convert_sqlalchemy_type.register(sqa_types.Enum) def convert_enum_to_enum(type, column, registry=None): return lambda: enum_for_sa_enum(type, registry or get_global_registry()) # TODO Make ChoiceType conversion consistent with other enums -@convert_sqlalchemy_type.register(ChoiceType) +@convert_sqlalchemy_type.register(sqa_utils.ChoiceType) def convert_choice_to_enum(type, column, registry=None): - name = "{}_{}".format(column.table.name, column.name).upper() + name = "{}_{}".format(column.table.name, column.key).upper() if isinstance(type.type_impl, EnumTypeImpl): # type.choices may be Enum/IntEnum, in ChoiceType both presented as EnumMeta # do not use from_enum here because we can have more than one enum column in table - return Enum(name, list((v.name, v.value) for v in type.choices)) + return graphene.Enum(name, list((v.name, v.value) for v in type.choices)) else: - return Enum(name, type.choices) + return graphene.Enum(name, type.choices) -@convert_sqlalchemy_type.register(ScalarListType) +@convert_sqlalchemy_type.register(sqa_utils.ScalarListType) def convert_scalar_list_to_list(type, column, registry=None): - return List(String) + return graphene.List(graphene.String) def init_array_list_recursive(inner_type, n): - return inner_type if n == 0 else List(init_array_list_recursive(inner_type, n - 1)) + return inner_type if n == 0 else graphene.List(init_array_list_recursive(inner_type, n - 1)) -@convert_sqlalchemy_type.register(types.ARRAY) +@convert_sqlalchemy_type.register(sqa_types.ARRAY) @convert_sqlalchemy_type.register(postgresql.ARRAY) def convert_array_to_list(_type, column, registry=None): inner_type = convert_sqlalchemy_type(column.type.item_type, column) - return List(init_array_list_recursive(inner_type, (column.type.dimensions or 1) - 1)) + return graphene.List(init_array_list_recursive(inner_type, (column.type.dimensions or 1) - 1)) @convert_sqlalchemy_type.register(postgresql.HSTORE) @@ -273,38 +288,50 @@ def convert_json_to_string(type, column, registry=None): return JSONString -@convert_sqlalchemy_type.register(JSONType) +@convert_sqlalchemy_type.register(sqa_utils.JSONType) +@convert_sqlalchemy_type.register(sqa_types.JSON) def convert_json_type_to_string(type, column, registry=None): return JSONString +@convert_sqlalchemy_type.register(sqa_types.Variant) +def convert_variant_to_impl_type(type, column, registry=None): + return convert_sqlalchemy_type(type.impl, column, registry=registry) + + @singledispatchbymatchfunction def convert_sqlalchemy_hybrid_property_type(arg: Any): existing_graphql_type = get_global_registry().get_type_for_model(arg) if existing_graphql_type: return existing_graphql_type + if isinstance(arg, type(graphene.ObjectType)): + return arg + + if isinstance(arg, type(graphene.Scalar)): + return arg + # No valid type found, warn and fall back to graphene.String warnings.warn( (f"I don't know how to generate a GraphQL type out of a \"{arg}\" type." "Falling back to \"graphene.String\"") ) - return String + return graphene.String @convert_sqlalchemy_hybrid_property_type.register(value_equals(str)) def convert_sqlalchemy_hybrid_property_type_str(arg): - return String + return graphene.String @convert_sqlalchemy_hybrid_property_type.register(value_equals(int)) def convert_sqlalchemy_hybrid_property_type_int(arg): - return Int + return graphene.Int @convert_sqlalchemy_hybrid_property_type.register(value_equals(float)) def convert_sqlalchemy_hybrid_property_type_float(arg): - return Float + return graphene.Float @convert_sqlalchemy_hybrid_property_type.register(value_equals(Decimal)) @@ -312,39 +339,85 @@ def convert_sqlalchemy_hybrid_property_type_decimal(arg): # The reason Decimal should be serialized as a String is because this is a # base10 type used in things like money, and string allows it to not # lose precision (which would happen if we downcasted to a Float, for example) - return String + return graphene.String @convert_sqlalchemy_hybrid_property_type.register(value_equals(bool)) def convert_sqlalchemy_hybrid_property_type_bool(arg): - return Boolean + return graphene.Boolean @convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.datetime)) def convert_sqlalchemy_hybrid_property_type_datetime(arg): - return DateTime + return graphene.DateTime @convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.date)) def convert_sqlalchemy_hybrid_property_type_date(arg): - return Date + return graphene.Date @convert_sqlalchemy_hybrid_property_type.register(value_equals(datetime.time)) def convert_sqlalchemy_hybrid_property_type_time(arg): - return Time + return graphene.Time -@convert_sqlalchemy_hybrid_property_type.register(lambda x: getattr(x, '__origin__', None) == typing.Union) -def convert_sqlalchemy_hybrid_property_type_option_t(arg): - # Option is actually Union[T, ] +def is_union(arg) -> bool: + if sys.version_info >= (3, 10): + from types import UnionType + + if isinstance(arg, UnionType): + return True + return getattr(arg, '__origin__', None) == typing.Union + + +def graphene_union_for_py_union(obj_types: typing.List[graphene.ObjectType], registry) -> graphene.Union: + union_type = registry.get_union_for_object_types(obj_types) + + if union_type is None: + # Union Name is name of the three + union_name = ''.join(sorted([obj_type._meta.name for obj_type in obj_types])) + union_type = graphene.Union(union_name, obj_types) + registry.register_union_type(union_type, obj_types) + + return union_type + + +@convert_sqlalchemy_hybrid_property_type.register(is_union) +def convert_sqlalchemy_hybrid_property_union(arg): + """ + Converts Unions (Union[X,Y], or X | Y for python > 3.10) to the corresponding graphene schema object. + Since Optionals are internally represented as Union[T, ], they are handled here as well. + + The GQL Spec currently only allows for ObjectType unions: + GraphQL Unions represent an object that could be one of a list of GraphQL Object types, but provides for no + guaranteed fields between those types. + That's why we have to check for the nested types to be instances of graphene.ObjectType, except for the union case. + + type(x) == _types.UnionType is necessary to support X | Y notation, but might break in future python releases. + """ + from .registry import get_global_registry + # Option is actually Union[T, ] # Just get the T out of the list of arguments by filtering out the NoneType - internal_type = next(filter(lambda x: not type(None) == x, arg.__args__)) + nested_types = list(filter(lambda x: not type(None) == x, arg.__args__)) - graphql_internal_type = convert_sqlalchemy_hybrid_property_type(internal_type) + # Map the graphene types to the nested types. + # We use convert_sqlalchemy_hybrid_property_type instead of the registry to account for ForwardRefs, Lists,... + graphene_types = list(map(convert_sqlalchemy_hybrid_property_type, nested_types)) + + # If only one type is left after filtering out NoneType, the Union was an Optional + if len(graphene_types) == 1: + return graphene_types[0] + + # Now check if every type is instance of an ObjectType + if not all(isinstance(graphene_type, type(graphene.ObjectType)) for graphene_type in graphene_types): + raise ValueError("Cannot convert hybrid_property Union to graphene.Union: the Union contains scalars. " + "Please add the corresponding hybrid_property to the excluded fields in the ObjectType, " + "or use an ORMField to override this behaviour.") - return graphql_internal_type + return graphene_union_for_py_union(cast(typing.List[graphene.ObjectType], list(graphene_types)), + get_global_registry()) @convert_sqlalchemy_hybrid_property_type.register(lambda x: getattr(x, '__origin__', None) in [list, typing.List]) @@ -354,7 +427,7 @@ def convert_sqlalchemy_hybrid_property_type_list_t(arg): graphql_internal_type = convert_sqlalchemy_hybrid_property_type(internal_type) - return List(graphql_internal_type) + return graphene.List(graphql_internal_type) @convert_sqlalchemy_hybrid_property_type.register(safe_isinstance(ForwardRef)) @@ -363,11 +436,12 @@ def convert_sqlalchemy_hybrid_property_forwardref(arg): Generate a lambda that will resolve the type at runtime This takes care of self-references """ + from .registry import get_global_registry def forward_reference_solver(): model = registry_sqlalchemy_model_from_str(arg.__forward_arg__) if not model: - return String + return graphene.String # Always fall back to string if no ForwardRef type found. return get_global_registry().get_type_for_model(model) diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index f100be19..a2ed17ad 100644 --- a/graphene_sqlalchemy/enums.py +++ b/graphene_sqlalchemy/enums.py @@ -144,9 +144,9 @@ def sort_enum_for_object_type( column = orm_field.columns[0] if only_indexed and not (column.primary_key or column.index): continue - asc_name = get_name(column.name, True) + asc_name = get_name(column.key, True) asc_value = EnumValue(asc_name, column.asc()) - desc_name = get_name(column.name, False) + desc_name = get_name(column.key, False) desc_value = EnumValue(desc_name, column.desc()) if column.primary_key: default.append(asc_value) diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index acfa744b..80470d9b 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,7 +1,9 @@ from collections import defaultdict +from typing import List, Type from sqlalchemy.types import Enum as SQLAlchemyEnumType +import graphene from graphene import Enum @@ -13,12 +15,13 @@ def __init__(self): self._registry_composites = {} self._registry_enums = {} self._registry_sort_enums = {} + self._registry_unions = {} def register(self, obj_type): - from .types import SQLAlchemyObjectType + from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -37,7 +40,7 @@ def register_orm_field(self, obj_type, field_name, orm_field): from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -55,7 +58,7 @@ def register_composite_converter(self, composite, converter): def get_converter_for_composite(self, composite): return self._registry_composites.get(composite) - def register_enum(self, sa_enum, graphene_enum): + def register_enum(self, sa_enum: SQLAlchemyEnumType, graphene_enum: Enum): if not isinstance(sa_enum, SQLAlchemyEnumType): raise TypeError( "Expected SQLAlchemyEnumType, but got: {!r}".format(sa_enum) @@ -67,14 +70,14 @@ def register_enum(self, sa_enum, graphene_enum): self._registry_enums[sa_enum] = graphene_enum - def get_graphene_enum_for_sa_enum(self, sa_enum): + def get_graphene_enum_for_sa_enum(self, sa_enum: SQLAlchemyEnumType): return self._registry_enums.get(sa_enum) - def register_sort_enum(self, obj_type, sort_enum): - from .types import SQLAlchemyObjectType + def register_sort_enum(self, obj_type, sort_enum: Enum): + from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -83,9 +86,26 @@ def register_sort_enum(self, obj_type, sort_enum): raise TypeError("Expected Graphene Enum, but got: {!r}".format(sort_enum)) self._registry_sort_enums[obj_type] = sort_enum - def get_sort_enum_for_object_type(self, obj_type): + def get_sort_enum_for_object_type(self, obj_type: graphene.ObjectType): return self._registry_sort_enums.get(obj_type) + def register_union_type(self, union: graphene.Union, obj_types: List[Type[graphene.ObjectType]]): + if not isinstance(union, graphene.Union): + raise TypeError( + "Expected graphene.Union, but got: {!r}".format(union) + ) + + for obj_type in obj_types: + if not isinstance(obj_type, type(graphene.ObjectType)): + raise TypeError( + "Expected Graphene ObjectType, but got: {!r}".format(obj_type) + ) + + self._registry_unions[frozenset(obj_types)] = union + + def get_union_for_object_types(self, obj_types : List[Type[graphene.ObjectType]]): + return self._registry_unions.get(frozenset(obj_types)) + registry = None diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index e41adb51..dc399ee0 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -5,8 +5,8 @@ from decimal import Decimal from typing import List, Optional, Tuple -from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, String, Table, - func, select) +from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, Numeric, + String, Table, func, select) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import column_property, composite, mapper, relationship @@ -228,3 +228,9 @@ def hybrid_prop_self_referential_list(self) -> List['ShoppingCart']: @hybrid_property def hybrid_prop_optional_self_referential(self) -> Optional['ShoppingCart']: return None + + +class KeyedModel(Base): + __tablename__ = "test330" + id = Column(Integer(), primary_key=True) + reporter_number = Column("% reporter_number", Numeric, key="reporter_number") diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 70e11713..a6c2b1bf 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,28 +1,28 @@ import enum +import sys from typing import Dict, Union import pytest +import sqlalchemy_utils as sqa_utils from sqlalchemy import Column, func, select, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property, composite -from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType import graphene -from graphene import Boolean, Float, Int, Scalar, String from graphene.relay import Node -from graphene.types.datetime import Date, DateTime, Time -from graphene.types.json import JSONString -from graphene.types.structures import List, Structure +from graphene.types.structures import Structure from ..converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, + convert_sqlalchemy_hybrid_method, convert_sqlalchemy_relationship) from ..fields import (UnsortedSQLAlchemyConnectionField, default_connection_field_factory) from ..registry import Registry, get_global_registry -from ..types import SQLAlchemyObjectType +from ..types import ORMField, SQLAlchemyObjectType from .models import (Article, CompositeFullName, Pet, Reporter, ShoppingCart, ShoppingCartItem) @@ -51,23 +51,117 @@ class Model(declarative_base()): return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver) -def test_should_unknown_sqlalchemy_field_raise_exception(): - re_err = "Don't know how to convert the SQLAlchemy field" - with pytest.raises(Exception, match=re_err): - # support legacy Binary type and subsequent LargeBinary - get_field(getattr(types, 'LargeBinary', types.BINARY)()) +def get_hybrid_property_type(prop_method): + class Model(declarative_base()): + __tablename__ = 'model' + id_ = Column(types.Integer, primary_key=True) + prop = prop_method + + column_prop = inspect(Model).all_orm_descriptors['prop'] + return convert_sqlalchemy_hybrid_method(column_prop, mock_resolver(), **ORMField().kwargs) + + +def test_hybrid_prop_int(): + @hybrid_property + def prop_method() -> int: + return 42 + + assert get_hybrid_property_type(prop_method).type == graphene.Int + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") +def test_hybrid_prop_scalar_union_310(): + @hybrid_property + def prop_method() -> int | str: + return "not allowed in gql schema" + + with pytest.raises(ValueError, + match=r"Cannot convert hybrid_property Union to " + r"graphene.Union: the Union contains scalars. \.*"): + get_hybrid_property_type(prop_method) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") +def test_hybrid_prop_scalar_union_and_optional_310(): + """Checks if the use of Optionals does not interfere with non-conform scalar return types""" + + @hybrid_property + def prop_method() -> int | None: + return 42 + + assert get_hybrid_property_type(prop_method).type == graphene.Int + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") +def test_should_union_work_310(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + registry = reg + + @hybrid_property + def prop_method() -> Union[PetType, ShoppingCartType]: + return None + + @hybrid_property + def prop_method_2() -> Union[ShoppingCartType, PetType]: + return None + + field_type_1 = get_hybrid_property_type(prop_method).type + field_type_2 = get_hybrid_property_type(prop_method_2).type + + assert isinstance(field_type_1, graphene.Union) + assert field_type_1 is field_type_2 + + # TODO verify types of the union + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10") +def test_should_union_work_310(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + class ShoppingCartType(SQLAlchemyObjectType): + class Meta: + model = ShoppingCartItem + registry = reg + + @hybrid_property + def prop_method() -> PetType | ShoppingCartType: + return None + @hybrid_property + def prop_method_2() -> ShoppingCartType | PetType: + return None -def test_should_date_convert_string(): - assert get_field(types.Date()).type == graphene.String + field_type_1 = get_hybrid_property_type(prop_method).type + field_type_2 = get_hybrid_property_type(prop_method_2).type + + assert isinstance(field_type_1, graphene.Union) + assert field_type_1 is field_type_2 def test_should_datetime_convert_datetime(): - assert get_field(types.DateTime()).type == DateTime + assert get_field(types.DateTime()).type == graphene.DateTime + +def test_should_time_convert_time(): + assert get_field(types.Time()).type == graphene.Time -def test_should_time_convert_string(): - assert get_field(types.Time()).type == graphene.String + +def test_should_date_convert_date(): + assert get_field(types.Date()).type == graphene.Date def test_should_string_convert_string(): @@ -86,6 +180,30 @@ def test_should_unicodetext_convert_string(): assert get_field(types.UnicodeText()).type == graphene.String +def test_should_tsvector_convert_string(): + assert get_field(sqa_utils.TSVectorType()).type == graphene.String + + +def test_should_email_convert_string(): + assert get_field(sqa_utils.EmailType()).type == graphene.String + + +def test_should_URL_convert_string(): + assert get_field(sqa_utils.URLType()).type == graphene.String + + +def test_should_IPaddress_convert_string(): + assert get_field(sqa_utils.IPAddressType()).type == graphene.String + + +def test_should_inet_convert_string(): + assert get_field(postgresql.INET()).type == graphene.String + + +def test_should_cidr_convert_string(): + assert get_field(postgresql.CIDR()).type == graphene.String + + def test_should_enum_convert_enum(): field = get_field(types.Enum(enum.Enum("TwoNumbers", ("one", "two")))) field_type = field.type() @@ -142,7 +260,7 @@ def test_should_numeric_convert_float(): def test_should_choice_convert_enum(): - field = get_field(ChoiceType([(u"es", u"Spanish"), (u"en", u"English")])) + field = get_field(sqa_utils.ChoiceType([(u"es", u"Spanish"), (u"en", u"English")])) graphene_type = field.type assert issubclass(graphene_type, graphene.Enum) assert graphene_type._meta.name == "MODEL_COLUMN" @@ -155,7 +273,7 @@ class TestEnum(enum.Enum): es = u"Spanish" en = u"English" - field = get_field(ChoiceType(TestEnum, impl=types.String())) + field = get_field(sqa_utils.ChoiceType(TestEnum, impl=types.String())) graphene_type = field.type assert issubclass(graphene_type, graphene.Enum) assert graphene_type._meta.name == "MODEL_COLUMN" @@ -163,12 +281,32 @@ class TestEnum(enum.Enum): assert graphene_type._meta.enum.__members__["en"].value == "English" +def test_choice_enum_column_key_name_issue_301(): + """ + Verifies that the sort enum name is generated from the column key instead of the name, + in case the column has an invalid enum name. See #330 + """ + + class TestEnum(enum.Enum): + es = u"Spanish" + en = u"English" + + testChoice = Column("% descuento1", sqa_utils.ChoiceType(TestEnum, impl=types.String()), key="descuento1") + field = get_field_from_column(testChoice) + + graphene_type = field.type + assert issubclass(graphene_type, graphene.Enum) + assert graphene_type._meta.name == "MODEL_DESCUENTO1" + assert graphene_type._meta.enum.__members__["es"].value == "Spanish" + assert graphene_type._meta.enum.__members__["en"].value == "English" + + def test_should_intenum_choice_convert_enum(): class TestEnum(enum.IntEnum): one = 1 two = 2 - field = get_field(ChoiceType(TestEnum, impl=types.String())) + field = get_field(sqa_utils.ChoiceType(TestEnum, impl=types.String())) graphene_type = field.type assert issubclass(graphene_type, graphene.Enum) assert graphene_type._meta.name == "MODEL_COLUMN" @@ -185,13 +323,22 @@ def test_should_columproperty_convert(): def test_should_scalar_list_convert_list(): - field = get_field(ScalarListType()) + field = get_field(sqa_utils.ScalarListType()) assert isinstance(field.type, graphene.List) assert field.type.of_type == graphene.String def test_should_jsontype_convert_jsonstring(): - assert get_field(JSONType()).type == JSONString + assert get_field(sqa_utils.JSONType()).type == graphene.JSONString + assert get_field(types.JSON).type == graphene.JSONString + + +def test_should_variant_int_convert_int(): + assert get_field(types.Variant(types.Integer(), {})).type == graphene.Int + + +def test_should_variant_string_convert_string(): + assert get_field(types.Variant(types.String(), {})).type == graphene.String def test_should_manytomany_convert_connectionorlist(): @@ -291,7 +438,11 @@ class Meta: def test_should_postgresql_uuid_convert(): - assert get_field(postgresql.UUID()).type == graphene.String + assert get_field(postgresql.UUID()).type == graphene.UUID + + +def test_should_sqlalchemy_utils_uuid_convert(): + assert get_field(sqa_utils.UUIDType()).type == graphene.UUID def test_should_postgresql_enum_convert(): @@ -405,8 +556,8 @@ class Meta: # Check ShoppingCartItem's Properties and Return Types ####################################################### - shopping_cart_item_expected_types: Dict[str, Union[Scalar, Structure]] = { - 'hybrid_prop_shopping_cart': List(ShoppingCartType) + shopping_cart_item_expected_types: Dict[str, Union[graphene.Scalar, Structure]] = { + 'hybrid_prop_shopping_cart': graphene.List(ShoppingCartType) } assert sorted(list(ShoppingCartItemType._meta.fields.keys())) == sorted([ @@ -421,9 +572,9 @@ class Meta: # this is a simple way of showing the failed property name # instead of having to unroll the loop. - assert ( - (hybrid_prop_name, str(hybrid_prop_field.type)) == - (hybrid_prop_name, str(hybrid_prop_expected_return_type)) + assert (hybrid_prop_name, str(hybrid_prop_field.type)) == ( + hybrid_prop_name, + str(hybrid_prop_expected_return_type), ) assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property @@ -431,27 +582,27 @@ class Meta: # Check ShoppingCart's Properties and Return Types ################################################### - shopping_cart_expected_types: Dict[str, Union[Scalar, Structure]] = { + shopping_cart_expected_types: Dict[str, Union[graphene.Scalar, Structure]] = { # Basic types - "hybrid_prop_str": String, - "hybrid_prop_int": Int, - "hybrid_prop_float": Float, - "hybrid_prop_bool": Boolean, - "hybrid_prop_decimal": String, # Decimals should be serialized Strings - "hybrid_prop_date": Date, - "hybrid_prop_time": Time, - "hybrid_prop_datetime": DateTime, + "hybrid_prop_str": graphene.String, + "hybrid_prop_int": graphene.Int, + "hybrid_prop_float": graphene.Float, + "hybrid_prop_bool": graphene.Boolean, + "hybrid_prop_decimal": graphene.String, # Decimals should be serialized Strings + "hybrid_prop_date": graphene.Date, + "hybrid_prop_time": graphene.Time, + "hybrid_prop_datetime": graphene.DateTime, # Lists and Nested Lists - "hybrid_prop_list_int": List(Int), - "hybrid_prop_list_date": List(Date), - "hybrid_prop_nested_list_int": List(List(Int)), - "hybrid_prop_deeply_nested_list_int": List(List(List(Int))), + "hybrid_prop_list_int": graphene.List(graphene.Int), + "hybrid_prop_list_date": graphene.List(graphene.Date), + "hybrid_prop_nested_list_int": graphene.List(graphene.List(graphene.Int)), + "hybrid_prop_deeply_nested_list_int": graphene.List(graphene.List(graphene.List(graphene.Int))), "hybrid_prop_first_shopping_cart_item": ShoppingCartItemType, - "hybrid_prop_shopping_cart_item_list": List(ShoppingCartItemType), - "hybrid_prop_unsupported_type_tuple": String, + "hybrid_prop_shopping_cart_item_list": graphene.List(ShoppingCartItemType), + "hybrid_prop_unsupported_type_tuple": graphene.String, # Self Referential List "hybrid_prop_self_referential": ShoppingCartType, - "hybrid_prop_self_referential_list": List(ShoppingCartType), + "hybrid_prop_self_referential_list": graphene.List(ShoppingCartType), # Optionals "hybrid_prop_optional_self_referential": ShoppingCartType, } @@ -468,8 +619,8 @@ class Meta: # this is a simple way of showing the failed property name # instead of having to unroll the loop. - assert ( - (hybrid_prop_name, str(hybrid_prop_field.type)) == - (hybrid_prop_name, str(hybrid_prop_expected_return_type)) + assert (hybrid_prop_name, str(hybrid_prop_field.type)) == ( + hybrid_prop_name, + str(hybrid_prop_expected_return_type), ) assert hybrid_prop_field.description is None # "doc" is ignored by hybrid property diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py index 0403c4f0..f451f355 100644 --- a/graphene_sqlalchemy/tests/test_registry.py +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -1,12 +1,13 @@ import pytest from sqlalchemy.types import Enum as SQLAlchemyEnum +import graphene from graphene import Enum as GrapheneEnum from ..registry import Registry from ..types import SQLAlchemyObjectType from ..utils import EnumValue -from .models import Pet +from .models import Pet, Reporter def test_register_object_type(): @@ -126,3 +127,56 @@ class Meta: re_err = r"Expected Graphene Enum, but got: .*PetType.*" with pytest.raises(TypeError, match=re_err): reg.register_sort_enum(PetType, PetType) + + +def test_register_union(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + union_types = [PetType, ReporterType] + union = graphene.Union('ReporterPet', tuple(union_types)) + + reg.register_union_type(union, union_types) + + assert reg.get_union_for_object_types(union_types) == union + # Order should not matter + assert reg.get_union_for_object_types([ReporterType, PetType]) == union + + +def test_register_union_scalar(): + reg = Registry() + + union_types = [graphene.String, graphene.Int] + union = graphene.Union('StringInt', tuple(union_types)) + + re_err = r"Expected Graphene ObjectType, but got: .*String.*" + with pytest.raises(TypeError, match=re_err): + reg.register_union_type(union, union_types) + + +def test_register_union_incorrect_types(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + union_types = [PetType, ReporterType] + union = PetType + + re_err = r"Expected graphene.Union, but got: .*PetType.*" + with pytest.raises(TypeError, match=re_err): + reg.register_union_type(union, union_types) diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index 6291d4f8..e2510abc 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -7,7 +7,7 @@ from ..fields import SQLAlchemyConnectionField from ..types import SQLAlchemyObjectType from ..utils import to_type_name -from .models import Base, HairKind, Pet +from .models import Base, HairKind, KeyedModel, Pet from .test_query import to_std_dicts @@ -383,3 +383,26 @@ def makeNodes(nodeList): assert [node["node"]["name"] for node in result.data["noSort"]["edges"]] == [ node["node"]["name"] for node in result.data["noDefaultSort"]["edges"] ] + + +def test_sort_enum_from_key_issue_330(): + """ + Verifies that the sort enum name is generated from the column key instead of the name, + in case the column has an invalid enum name. See #330 + """ + + class KeyedType(SQLAlchemyObjectType): + class Meta: + model = KeyedModel + + sort_enum = KeyedType.sort_enum() + assert isinstance(sort_enum, type(Enum)) + assert sort_enum._meta.name == "KeyedTypeSortEnum" + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "REPORTER_NUMBER_ASC", + "REPORTER_NUMBER_DESC", + ] + assert str(sort_enum.REPORTER_NUMBER_ASC.value.value) == 'test330."% reporter_number" ASC' + assert str(sort_enum.REPORTER_NUMBER_DESC.value.value) == 'test330."% reporter_number" DESC' diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 9a2e992d..00e8b3af 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -76,7 +76,7 @@ class Meta: assert sorted(list(ReporterType._meta.fields.keys())) == sorted([ # Columns - "column_prop", # SQLAlchemy retuns column properties first + "column_prop", "id", "first_name", "last_name", diff --git a/graphene_sqlalchemy/tests/test_utils.py b/graphene_sqlalchemy/tests/test_utils.py index e13d919c..de359e05 100644 --- a/graphene_sqlalchemy/tests/test_utils.py +++ b/graphene_sqlalchemy/tests/test_utils.py @@ -3,8 +3,8 @@ from graphene import Enum, List, ObjectType, Schema, String -from ..utils import (get_session, sort_argument_for_model, sort_enum_for_model, - to_enum_value_name, to_type_name) +from ..utils import (DummyImport, get_session, sort_argument_for_model, + sort_enum_for_model, to_enum_value_name, to_type_name) from .models import Base, Editor, Pet @@ -99,3 +99,7 @@ class MultiplePK(Base): assert set(arg.default_value) == set( (MultiplePK.foo.name + "_asc", MultiplePK.bar.name + "_asc") ) + +def test_dummy_import(): + dummy_module = DummyImport() + assert dummy_module.foo == object diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 084f9b86..f6ee9b62 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -8,8 +8,6 @@ from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError -from graphene_sqlalchemy.registry import get_global_registry - def get_session(context): return context.get("session") @@ -203,7 +201,14 @@ def safe_isinstance_checker(arg): def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: + from graphene_sqlalchemy.registry import get_global_registry try: return next(filter(lambda x: x.__name__ == model_name, list(get_global_registry()._registry.keys()))) except StopIteration: pass + + +class DummyImport: + """The dummy module returns 'object' for a query for any member""" + def __getattr__(self, name): + return object