Skip to content

Commit 4016b62

Browse files
authored
Merge pull request #143 from curvetips/custom-connection
Add support for using custom connections
2 parents 70df6c7 + 79eb7ce commit 4016b62

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

graphene_sqlalchemy/fields.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,18 @@
1111

1212
class UnsortedSQLAlchemyConnectionField(ConnectionField):
1313

14+
@property
15+
def type(self):
16+
from .types import SQLAlchemyObjectType
17+
_type = super(ConnectionField, self).type
18+
if issubclass(_type, Connection):
19+
return _type
20+
assert issubclass(_type, SQLAlchemyObjectType), (
21+
"SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
22+
).format(_type.__name__)
23+
assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
24+
return _type._meta.connection
25+
1426
@property
1527
def model(self):
1628
return self.type._meta.node._meta.model

graphene_sqlalchemy/types.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ class SQLAlchemyObjectType(ObjectType):
9090
@classmethod
9191
def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False,
9292
only_fields=(), exclude_fields=(), connection=None,
93-
use_connection=None, interfaces=(), id=None, _meta=None, **options):
93+
connection_class=None, use_connection=None, interfaces=(),
94+
id=None, _meta=None, **options):
9495
assert is_mapped_class(model), (
9596
'You need to pass a valid SQLAlchemy Model in '
9697
'{}.Meta, received "{}".'
@@ -114,7 +115,11 @@ def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=Fa
114115

115116
if use_connection and not connection:
116117
# We create the connection automatically
117-
connection = Connection.create_type('{}Connection'.format(cls.__name__), node=cls)
118+
if not connection_class:
119+
connection_class = Connection
120+
121+
connection = connection_class.create_type(
122+
'{}Connection'.format(cls.__name__), node=cls)
118123

119124
if connection is not None:
120125
assert issubclass(connection, Connection), (

0 commit comments

Comments
 (0)