Skip to content

Commit 78004a6

Browse files
committed
Support GQL interfaces for polymorphic SQLA models
1 parent 8bfa1e9 commit 78004a6

File tree

4 files changed

+55
-30
lines changed

4 files changed

+55
-30
lines changed

graphene_sqlalchemy/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from .fields import SQLAlchemyConnectionField
2-
from .types import SQLAlchemyObjectType
2+
from .types import SQLAlchemyInterface, SQLAlchemyObjectType
33
from .utils import get_query, get_session
44

55
__version__ = "3.0.0b3"
66

77
__all__ = [
88
"__version__",
9+
"SQLAlchemyInterface",
910
"SQLAlchemyObjectType",
1011
"SQLAlchemyConnectionField",
1112
"get_query",

graphene_sqlalchemy/registry.py

+6-15
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,10 @@ def __init__(self):
1818
self._registry_unions = {}
1919

2020
def register(self, obj_type):
21+
from .types import SQLAlchemyBase
2122

22-
from .types import SQLAlchemyObjectType
23-
24-
if not isinstance(obj_type, type) or not issubclass(
25-
obj_type, SQLAlchemyObjectType
26-
):
27-
raise TypeError(
28-
"Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)
29-
)
23+
if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyBase):
24+
raise TypeError("Expected SQLAlchemyBase, but got: {!r}".format(obj_type))
3025
assert obj_type._meta.registry == self, "Registry for a Model have to match."
3126
# assert self.get_type_for_model(cls._meta.model) in [None, cls], (
3227
# 'SQLAlchemy model "{}" already associated with '
@@ -38,14 +33,10 @@ def get_type_for_model(self, model):
3833
return self._registry.get(model)
3934

4035
def register_orm_field(self, obj_type, field_name, orm_field):
41-
from .types import SQLAlchemyObjectType
36+
from .types import SQLAlchemyBase
4237

43-
if not isinstance(obj_type, type) or not issubclass(
44-
obj_type, SQLAlchemyObjectType
45-
):
46-
raise TypeError(
47-
"Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)
48-
)
38+
if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyBase):
39+
raise TypeError("Expected SQLAlchemyBase, but got: {!r}".format(obj_type))
4940
if not field_name or not isinstance(field_name, str):
5041
raise TypeError("Expected a field name, but got: {!r}".format(field_name))
5142
self._registry_orm_fields[obj_type][field_name] = orm_field

graphene_sqlalchemy/tests/test_registry.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_register_incorrect_object_type():
2828
class Spam:
2929
pass
3030

31-
re_err = "Expected SQLAlchemyObjectType, but got: .*Spam"
31+
re_err = "Expected SQLAlchemyBase, but got: .*Spam"
3232
with pytest.raises(TypeError, match=re_err):
3333
reg.register(Spam)
3434

@@ -51,7 +51,7 @@ def test_register_orm_field_incorrect_types():
5151
class Spam:
5252
pass
5353

54-
re_err = "Expected SQLAlchemyObjectType, but got: .*Spam"
54+
re_err = "Expected SQLAlchemyBase, but got: .*Spam"
5555
with pytest.raises(TypeError, match=re_err):
5656
reg.register_orm_field(Spam, "name", Pet.name)
5757

graphene_sqlalchemy/types.py

+45-12
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from graphene import Field
99
from graphene.relay import Connection, Node
10+
from graphene.types.base import BaseType
11+
from graphene.types.interface import Interface, InterfaceOptions
1012
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
1113
from graphene.types.utils import yank_fields_from_attrs
1214
from graphene.utils.orderedtype import OrderedType
@@ -211,14 +213,7 @@ def construct_fields(
211213
return fields
212214

213215

214-
class SQLAlchemyObjectTypeOptions(ObjectTypeOptions):
215-
model = None # type: sqlalchemy.Model
216-
registry = None # type: sqlalchemy.Registry
217-
connection = None # type: sqlalchemy.Type[sqlalchemy.Connection]
218-
id = None # type: str
219-
220-
221-
class SQLAlchemyObjectType(ObjectType):
216+
class SQLAlchemyBase(BaseType):
222217
@classmethod
223218
def __init_subclass_with_meta__(
224219
cls,
@@ -237,6 +232,11 @@ def __init_subclass_with_meta__(
237232
_meta=None,
238233
**options
239234
):
235+
# We always want to bypass this hook unless we're defining a concrete
236+
# `SQLAlchemyObjectType` or `SQLAlchemyInterface`.
237+
if not _meta:
238+
return
239+
240240
# Make sure model is a valid SQLAlchemy model
241241
if not is_mapped_class(model):
242242
raise ValueError(
@@ -290,9 +290,6 @@ def __init_subclass_with_meta__(
290290
"The connection must be a Connection. Received {}"
291291
).format(connection.__name__)
292292

293-
if not _meta:
294-
_meta = SQLAlchemyObjectTypeOptions(cls)
295-
296293
_meta.model = model
297294
_meta.registry = registry
298295

@@ -306,7 +303,7 @@ def __init_subclass_with_meta__(
306303

307304
cls.connection = connection # Public way to get the connection
308305

309-
super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__(
306+
super(SQLAlchemyBase, cls).__init_subclass_with_meta__(
310307
_meta=_meta, interfaces=interfaces, **options
311308
)
312309

@@ -345,3 +342,39 @@ def enum_for_field(cls, field_name):
345342
sort_enum = classmethod(sort_enum_for_object_type)
346343

347344
sort_argument = classmethod(sort_argument_for_object_type)
345+
346+
347+
class SQLAlchemyObjectTypeOptions(ObjectTypeOptions):
348+
model = None # type: sqlalchemy.Model
349+
registry = None # type: sqlalchemy.Registry
350+
connection = None # type: sqlalchemy.Type[sqlalchemy.Connection]
351+
id = None # type: str
352+
353+
354+
class SQLAlchemyObjectType(SQLAlchemyBase, ObjectType):
355+
@classmethod
356+
def __init_subclass_with_meta__(cls, _meta=None, **options):
357+
if not _meta:
358+
_meta = SQLAlchemyObjectTypeOptions(cls)
359+
360+
super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__(
361+
_meta=_meta, **options
362+
)
363+
364+
365+
class SQLAlchemyInterfaceOptions(InterfaceOptions):
366+
model = None # type: sqlalchemy.Model
367+
registry = None # type: sqlalchemy.Registry
368+
connection = None # type: sqlalchemy.Type[sqlalchemy.Connection]
369+
id = None # type: str
370+
371+
372+
class SQLAlchemyInterface(SQLAlchemyBase, Interface):
373+
@classmethod
374+
def __init_subclass_with_meta__(cls, _meta=None, **options):
375+
if not _meta:
376+
_meta = SQLAlchemyInterfaceOptions(cls)
377+
378+
super(SQLAlchemyInterface, cls).__init_subclass_with_meta__(
379+
_meta=_meta, **options
380+
)

0 commit comments

Comments
 (0)