diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 3f017aae..2766c2ab 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -1,10 +1,10 @@ - +from collections import OrderedDict from graphene import Field, Int, Interface, ObjectType from graphene.relay import Node, is_node import six from ..registry import Registry -from ..types import SQLAlchemyObjectType +from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions from .models import Article, Reporter registry = Registry() @@ -116,3 +116,45 @@ def test_custom_objecttype_registered(): 'pets', 'articles', 'favorite_article'] + + +# Test Custom SQLAlchemyObjectType with Custom Options +class CustomOptions(SQLAlchemyObjectTypeOptions): + custom_option = None + custom_fields = None + + +class SQLAlchemyObjectTypeWithCustomOptions(SQLAlchemyObjectType): + class Meta: + abstract = True + + @classmethod + def __init_subclass_with_meta__(cls, custom_option=None, custom_fields=None, **options): + _meta = CustomOptions(cls) + _meta.custom_option = custom_option + _meta.fields = custom_fields + super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__(_meta=_meta, **options) + + +class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): + class Meta: + model = Reporter + custom_option = 'custom_option' + custom_fields = OrderedDict([('custom_field', Field(Int()))]) + + +def test_objecttype_with_custom_options(): + assert issubclass(ReporterWithCustomOptions, ObjectType) + assert ReporterWithCustomOptions._meta.model == Reporter + assert list( + ReporterWithCustomOptions._meta.fields.keys()) == [ + 'custom_field', + 'id', + 'first_name', + 'last_name', + 'email', + 'pets', + 'articles', + 'favorite_article'] + assert ReporterWithCustomOptions._meta.custom_option == 'custom_option' + assert isinstance(ReporterWithCustomOptions._meta.fields['custom_field'].type, Int) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 04d1a8a6..69bf310c 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -90,7 +90,7 @@ class SQLAlchemyObjectType(ObjectType): @classmethod def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False, only_fields=(), exclude_fields=(), connection=None, - use_connection=None, interfaces=(), id=None, **options): + use_connection=None, interfaces=(), id=None, _meta=None, **options): assert is_mapped_class(model), ( 'You need to pass a valid SQLAlchemy Model in ' '{}.Meta, received "{}".' @@ -121,10 +121,17 @@ def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=Fa "The connection must be a Connection. Received {}" ).format(connection.__name__) - _meta = SQLAlchemyObjectTypeOptions(cls) + if not _meta: + _meta = SQLAlchemyObjectTypeOptions(cls) + _meta.model = model _meta.registry = registry - _meta.fields = sqla_fields + + if _meta.fields: + _meta.fields.update(sqla_fields) + else: + _meta.fields = sqla_fields + _meta.connection = connection _meta.id = id or 'id'