diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 428eca1..7632fd3 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -1,6 +1,6 @@
name: Tests
-on:
+on:
push:
branches:
- 'master'
diff --git a/docs/inheritance.rst b/docs/inheritance.rst
index ee16f06..7473216 100644
--- a/docs/inheritance.rst
+++ b/docs/inheritance.rst
@@ -3,7 +3,7 @@ Inheritance Examples
Create interfaces from inheritance relationships
------------------------------------------------
-
+.. note:: If you're using `AsyncSession`, please check the chapter `Eager Loading & Using with AsyncSession`_.
SQLAlchemy has excellent support for class inheritance hierarchies.
These hierarchies can be represented in your GraphQL schema by means
of interfaces_. Much like ObjectTypes, Interfaces in
@@ -40,7 +40,7 @@ from the attributes of their underlying SQLAlchemy model:
__mapper_args__ = {
"polymorphic_identity": "employee",
}
-
+
class Customer(Person):
first_purchase_date = Column(Date())
@@ -56,17 +56,17 @@ from the attributes of their underlying SQLAlchemy model:
class Meta:
model = Employee
interfaces = (relay.Node, PersonType)
-
+
class CustomerType(SQLAlchemyObjectType):
class Meta:
model = Customer
interfaces = (relay.Node, PersonType)
-Keep in mind that `PersonType` is a `SQLAlchemyInterface`. Interfaces must
-be linked to an abstract Model that does not specify a `polymorphic_identity`,
-because we cannot return instances of interfaces from a GraphQL query.
-If Person specified a `polymorphic_identity`, instances of Person could
-be inserted into and returned by the database, potentially causing
+Keep in mind that `PersonType` is a `SQLAlchemyInterface`. Interfaces must
+be linked to an abstract Model that does not specify a `polymorphic_identity`,
+because we cannot return instances of interfaces from a GraphQL query.
+If Person specified a `polymorphic_identity`, instances of Person could
+be inserted into and returned by the database, potentially causing
Persons to be returned to the resolvers.
When querying on the base type, you can refer directly to common fields,
@@ -85,15 +85,19 @@ and fields on concrete implementations using the `... on` syntax:
firstPurchaseDate
}
}
-
-
+
+
+.. danger::
+ When using joined table inheritance, this style of querying may lead to unbatched implicit IO with negative performance implications.
+ See the chapter `Eager Loading & Using with AsyncSession`_ for more information on eager loading all possible types of a `SQLAlchemyInterface`.
+
Please note that by default, the "polymorphic_on" column is *not*
generated as a field on types that use polymorphic inheritance, as
-this is considered an implentation detail. The idiomatic way to
+this is considered an implementation detail. The idiomatic way to
retrieve the concrete GraphQL type of an object is to query for the
-`__typename` field.
+`__typename` field.
To override this behavior, an `ORMField` needs to be created
-for the custom type field on the corresponding `SQLAlchemyInterface`. This is *not recommended*
+for the custom type field on the corresponding `SQLAlchemyInterface`. This is *not recommended*
as it promotes abiguous schema design
If your SQLAlchemy model only specifies a relationship to the
@@ -103,5 +107,39 @@ class to the Schema constructor via the `types=` argument:
.. code:: python
schema = graphene.Schema(..., types=[PersonType, EmployeeType, CustomerType])
-
+
+
See also: `Graphene Interfaces `_
+
+Eager Loading & Using with AsyncSession
+--------------------
+When querying the base type in multi-table inheritance or joined table inheritance, you can only directly refer to polymorphic fields when they are loaded eagerly.
+This restricting is in place because AsyncSessions don't allow implicit async operations such as the loads of the joined tables.
+To load the polymorphic fields eagerly, you can use the `with_polymorphic` attribute of the mapper args in the base model:
+
+.. code:: python
+ class Person(Base):
+ id = Column(Integer(), primary_key=True)
+ type = Column(String())
+ name = Column(String())
+ birth_date = Column(Date())
+
+ __tablename__ = "person"
+ __mapper_args__ = {
+ "polymorphic_on": type,
+ "with_polymorphic": "*", # needed for eager loading in async session
+ }
+
+Alternatively, the specific polymorphic fields can be loaded explicitly in resolvers:
+
+.. code:: python
+
+ class Query(graphene.ObjectType):
+ people = graphene.Field(graphene.List(PersonType))
+
+ async def resolve_people(self, _info):
+ return (await session.scalars(with_polymorphic(Person, [Engineer, Customer]))).all()
+
+Dynamic batching of the types based on the query to avoid eager is currently not supported, but could be implemented in a future PR.
+
+For more information on loading techniques for polymorphic models, please check out the `SQLAlchemy docs `_.
diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py
index 0800d0e..23b6712 100644
--- a/graphene_sqlalchemy/batching.py
+++ b/graphene_sqlalchemy/batching.py
@@ -6,7 +6,7 @@
from sqlalchemy.orm import Session, strategies
from sqlalchemy.orm.query import QueryContext
-from .utils import is_graphene_version_less_than, is_sqlalchemy_version_less_than
+from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, is_graphene_version_less_than
def get_data_loader_impl() -> Any: # pragma: no cover
@@ -71,19 +71,19 @@ async def batch_load_fn(self, parents):
# For our purposes, the query_context will only used to get the session
query_context = None
- if is_sqlalchemy_version_less_than("1.4"):
- query_context = QueryContext(session.query(parent_mapper.entity))
- else:
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
parent_mapper_query = session.query(parent_mapper.entity)
query_context = parent_mapper_query._compile_context()
-
- if is_sqlalchemy_version_less_than("1.4"):
+ else:
+ query_context = QueryContext(session.query(parent_mapper.entity))
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
self.selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper,
+ None,
)
else:
self.selectin_loader._load_for_path(
@@ -92,7 +92,6 @@ async def batch_load_fn(self, parents):
states,
None,
child_mapper,
- None,
)
return [getattr(parent, self.relationship_prop.key) for parent in parents]
diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py
index 2cb53c5..6dbc134 100644
--- a/graphene_sqlalchemy/fields.py
+++ b/graphene_sqlalchemy/fields.py
@@ -11,7 +11,10 @@
from graphql_relay import connection_from_array_slice
from .batching import get_batch_resolver
-from .utils import EnumValue, get_query
+from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, EnumValue, get_query, get_session
+
+if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
+ from sqlalchemy.ext.asyncio import AsyncSession
class SQLAlchemyConnectionField(ConnectionField):
@@ -81,8 +84,49 @@ def get_query(cls, model, info, sort=None, **args):
@classmethod
def resolve_connection(cls, connection_type, model, info, args, resolved):
+ session = get_session(info.context)
+ if resolved is None:
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+
+ async def get_result():
+ return await cls.resolve_connection_async(
+ connection_type, model, info, args, resolved
+ )
+
+ return get_result()
+
+ else:
+ resolved = cls.get_query(model, info, **args)
+ if isinstance(resolved, Query):
+ _len = resolved.count()
+ else:
+ _len = len(resolved)
+
+ def adjusted_connection_adapter(edges, pageInfo):
+ return connection_adapter(connection_type, edges, pageInfo)
+
+ connection = connection_from_array_slice(
+ array_slice=resolved,
+ args=args,
+ slice_start=0,
+ array_length=_len,
+ array_slice_length=_len,
+ connection_type=adjusted_connection_adapter,
+ edge_type=connection_type.Edge,
+ page_info_type=page_info_adapter,
+ )
+ connection.iterable = resolved
+ connection.length = _len
+ return connection
+
+ @classmethod
+ async def resolve_connection_async(
+ cls, connection_type, model, info, args, resolved
+ ):
+ session = get_session(info.context)
if resolved is None:
- resolved = cls.get_query(model, info, **args)
+ query = cls.get_query(model, info, **args)
+ resolved = (await session.scalars(query)).all()
if isinstance(resolved, Query):
_len = resolved.count()
else:
@@ -179,7 +223,7 @@ def from_relationship(cls, relationship, registry, **field_kwargs):
return cls(
model_type.connection,
resolver=get_batch_resolver(relationship),
- **field_kwargs
+ **field_kwargs,
)
diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py
index 357ad96..89b357a 100644
--- a/graphene_sqlalchemy/tests/conftest.py
+++ b/graphene_sqlalchemy/tests/conftest.py
@@ -1,14 +1,17 @@
import pytest
+import pytest_asyncio
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
import graphene
+from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4
from ..converter import convert_sqlalchemy_composite
from ..registry import reset_global_registry
from .models import Base, CompositeFullName
-test_db_url = "sqlite://" # use in-memory database for tests
+if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
+ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
@pytest.fixture(autouse=True)
@@ -22,18 +25,49 @@ def convert_composite_class(composite, registry):
return graphene.Field(graphene.Int)
-@pytest.fixture(scope="function")
-def session_factory():
- engine = create_engine(test_db_url)
- Base.metadata.create_all(engine)
+@pytest.fixture(params=[False, True])
+def async_session(request):
+ return request.param
+
+
+@pytest.fixture
+def test_db_url(async_session: bool):
+ if async_session:
+ return "sqlite+aiosqlite://"
+ else:
+ return "sqlite://"
- yield sessionmaker(bind=engine)
+@pytest.mark.asyncio
+@pytest_asyncio.fixture(scope="function")
+async def session_factory(async_session: bool, test_db_url: str):
+ if async_session:
+ if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
+ pytest.skip("Async Sessions only work in sql alchemy 1.4 and above")
+ engine = create_async_engine(test_db_url)
+ async with engine.begin() as conn:
+ await conn.run_sync(Base.metadata.create_all)
+ yield sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)
+ await engine.dispose()
+ else:
+ engine = create_engine(test_db_url)
+ Base.metadata.create_all(engine)
+ yield sessionmaker(bind=engine, expire_on_commit=False)
+ # SQLite in-memory db is deleted when its connection is closed.
+ # https://www.sqlite.org/inmemorydb.html
+ engine.dispose()
+
+
+@pytest_asyncio.fixture(scope="function")
+async def sync_session_factory():
+ engine = create_engine("sqlite://")
+ Base.metadata.create_all(engine)
+ yield sessionmaker(bind=engine, expire_on_commit=False)
# SQLite in-memory db is deleted when its connection is closed.
# https://www.sqlite.org/inmemorydb.html
engine.dispose()
-@pytest.fixture(scope="function")
+@pytest_asyncio.fixture(scope="function")
def session(session_factory):
return session_factory()
diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py
index 4fe9146..ee28658 100644
--- a/graphene_sqlalchemy/tests/models.py
+++ b/graphene_sqlalchemy/tests/models.py
@@ -20,7 +20,7 @@
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import column_property, composite, mapper, relationship
+from sqlalchemy.orm import backref, column_property, composite, mapper, relationship
PetKind = Enum("cat", "dog", name="pet_kind")
@@ -76,10 +76,16 @@ class Reporter(Base):
email = Column(String(), doc="Email")
favorite_pet_kind = Column(PetKind)
pets = relationship(
- "Pet", secondary=association_table, backref="reporters", order_by="Pet.id"
+ "Pet",
+ secondary=association_table,
+ backref="reporters",
+ order_by="Pet.id",
+ lazy="selectin",
)
- articles = relationship("Article", backref="reporter")
- favorite_article = relationship("Article", uselist=False)
+ articles = relationship(
+ "Article", backref=backref("reporter", lazy="selectin"), lazy="selectin"
+ )
+ favorite_article = relationship("Article", uselist=False, lazy="selectin")
@hybrid_property
def hybrid_prop_with_doc(self):
@@ -304,8 +310,10 @@ class Person(Base):
__tablename__ = "person"
__mapper_args__ = {
"polymorphic_on": type,
+ "with_polymorphic": "*", # needed for eager loading in async session
}
+
class NonAbstractPerson(Base):
id = Column(Integer(), primary_key=True)
type = Column(String())
@@ -318,6 +326,7 @@ class NonAbstractPerson(Base):
"polymorphic_identity": "person",
}
+
class Employee(Person):
hire_date = Column(Date())
diff --git a/graphene_sqlalchemy/tests/models_batching.py b/graphene_sqlalchemy/tests/models_batching.py
new file mode 100644
index 0000000..6f1c42f
--- /dev/null
+++ b/graphene_sqlalchemy/tests/models_batching.py
@@ -0,0 +1,91 @@
+from __future__ import absolute_import
+
+import enum
+
+from sqlalchemy import (
+ Column,
+ Date,
+ Enum,
+ ForeignKey,
+ Integer,
+ String,
+ Table,
+ func,
+ select,
+)
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import column_property, relationship
+
+PetKind = Enum("cat", "dog", name="pet_kind")
+
+
+class HairKind(enum.Enum):
+ LONG = "long"
+ SHORT = "short"
+
+
+Base = declarative_base()
+
+association_table = Table(
+ "association",
+ Base.metadata,
+ Column("pet_id", Integer, ForeignKey("pets.id")),
+ Column("reporter_id", Integer, ForeignKey("reporters.id")),
+)
+
+
+class Pet(Base):
+ __tablename__ = "pets"
+ id = Column(Integer(), primary_key=True)
+ name = Column(String(30))
+ pet_kind = Column(PetKind, nullable=False)
+ hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False)
+ reporter_id = Column(Integer(), ForeignKey("reporters.id"))
+
+
+class Reporter(Base):
+ __tablename__ = "reporters"
+
+ id = Column(Integer(), primary_key=True)
+ first_name = Column(String(30), doc="First name")
+ last_name = Column(String(30), doc="Last name")
+ email = Column(String(), doc="Email")
+ favorite_pet_kind = Column(PetKind)
+ pets = relationship(
+ "Pet",
+ secondary=association_table,
+ backref="reporters",
+ order_by="Pet.id",
+ )
+ articles = relationship("Article", backref="reporter")
+ favorite_article = relationship("Article", uselist=False)
+
+ column_prop = column_property(
+ select([func.cast(func.count(id), Integer)]), doc="Column property"
+ )
+
+
+class Article(Base):
+ __tablename__ = "articles"
+ id = Column(Integer(), primary_key=True)
+ headline = Column(String(100))
+ pub_date = Column(Date())
+ reporter_id = Column(Integer(), ForeignKey("reporters.id"))
+ readers = relationship(
+ "Reader", secondary="articles_readers", back_populates="articles"
+ )
+
+
+class Reader(Base):
+ __tablename__ = "readers"
+ id = Column(Integer(), primary_key=True)
+ name = Column(String(100))
+ articles = relationship(
+ "Article", secondary="articles_readers", back_populates="readers"
+ )
+
+
+class ArticleReader(Base):
+ __tablename__ = "articles_readers"
+ article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True)
+ reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True)
diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py
index 90df027..5eccd5f 100644
--- a/graphene_sqlalchemy/tests/test_batching.py
+++ b/graphene_sqlalchemy/tests/test_batching.py
@@ -3,15 +3,23 @@
import logging
import pytest
+from sqlalchemy import select
import graphene
from graphene import Connection, relay
from ..fields import BatchSQLAlchemyConnectionField, default_connection_field_factory
from ..types import ORMField, SQLAlchemyObjectType
-from ..utils import is_sqlalchemy_version_less_than
-from .models import Article, HairKind, Pet, Reader, Reporter
-from .utils import remove_cache_miss_stat, to_std_dicts
+from ..utils import (
+ SQL_VERSION_HIGHER_EQUAL_THAN_1_4,
+ get_session,
+ is_sqlalchemy_version_less_than,
+)
+from .models_batching import Article, HairKind, Pet, Reader, Reporter
+from .utils import eventually_await_session, remove_cache_miss_stat, to_std_dicts
+
+if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
+ from sqlalchemy.ext.asyncio import AsyncSession
class MockLoggingHandler(logging.Handler):
@@ -41,6 +49,44 @@ def mock_sqlalchemy_logging_handler():
sql_logger.setLevel(previous_level)
+def get_async_schema():
+ class ReporterType(SQLAlchemyObjectType):
+ class Meta:
+ model = Reporter
+ interfaces = (relay.Node,)
+ batching = True
+
+ class ArticleType(SQLAlchemyObjectType):
+ class Meta:
+ model = Article
+ interfaces = (relay.Node,)
+ batching = True
+
+ class PetType(SQLAlchemyObjectType):
+ class Meta:
+ model = Pet
+ interfaces = (relay.Node,)
+ batching = True
+
+ class Query(graphene.ObjectType):
+ articles = graphene.Field(graphene.List(ArticleType))
+ reporters = graphene.Field(graphene.List(ReporterType))
+
+ async def resolve_articles(self, info):
+ session = get_session(info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return (await session.scalars(select(Article))).all()
+ return session.query(Article).all()
+
+ async def resolve_reporters(self, info):
+ session = get_session(info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return (await session.scalars(select(Reporter))).all()
+ return session.query(Reporter).all()
+
+ return graphene.Schema(query=Query)
+
+
def get_schema():
class ReporterType(SQLAlchemyObjectType):
class Meta:
@@ -65,14 +111,20 @@ class Query(graphene.ObjectType):
reporters = graphene.Field(graphene.List(ReporterType))
def resolve_articles(self, info):
- return info.context.get("session").query(Article).all()
+ session = get_session(info.context)
+ return session.query(Article).all()
def resolve_reporters(self, info):
- return info.context.get("session").query(Reporter).all()
+ session = get_session(info.context)
+ return session.query(Reporter).all()
return graphene.Schema(query=Query)
+if is_sqlalchemy_version_less_than("1.2"):
+ pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True)
+
+
def get_full_relay_schema():
class ReporterType(SQLAlchemyObjectType):
class Meta:
@@ -107,14 +159,11 @@ class Query(graphene.ObjectType):
return graphene.Schema(query=Query)
-if is_sqlalchemy_version_less_than("1.2"):
- pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True)
-
-
@pytest.mark.asyncio
-async def test_many_to_one(session_factory):
- session = session_factory()
-
+@pytest.mark.parametrize("schema_provider", [get_schema, get_async_schema])
+async def test_many_to_one(sync_session_factory, schema_provider):
+ session = sync_session_factory()
+ schema = schema_provider()
reporter_1 = Reporter(
first_name="Reporter_1",
)
@@ -135,26 +184,43 @@ async def test_many_to_one(session_factory):
session.commit()
session.close()
- schema = get_schema()
-
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
- session = session_factory()
+ session = sync_session_factory()
result = await schema.execute_async(
"""
- query {
- articles {
- headline
- reporter {
- firstName
+ query {
+ articles {
+ headline
+ reporter {
+ firstName
+ }
+ }
}
- }
- }
- """,
+ """,
context_value={"session": session},
)
messages = sqlalchemy_logging_handler.messages
+ assert not result.errors
+ result = to_std_dicts(result.data)
+ assert result == {
+ "articles": [
+ {
+ "headline": "Article_1",
+ "reporter": {
+ "firstName": "Reporter_1",
+ },
+ },
+ {
+ "headline": "Article_2",
+ "reporter": {
+ "firstName": "Reporter_2",
+ },
+ },
+ ],
+ }
+
assert len(messages) == 5
if is_sqlalchemy_version_less_than("1.3"):
@@ -169,37 +235,19 @@ async def test_many_to_one(session_factory):
assert len(sql_statements) == 1
return
- if not is_sqlalchemy_version_less_than("1.4"):
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
messages[2] = remove_cache_miss_stat(messages[2])
messages[4] = remove_cache_miss_stat(messages[4])
assert ast.literal_eval(messages[2]) == ()
assert sorted(ast.literal_eval(messages[4])) == [1, 2]
- assert not result.errors
- result = to_std_dicts(result.data)
- assert result == {
- "articles": [
- {
- "headline": "Article_1",
- "reporter": {
- "firstName": "Reporter_1",
- },
- },
- {
- "headline": "Article_2",
- "reporter": {
- "firstName": "Reporter_2",
- },
- },
- ],
- }
-
@pytest.mark.asyncio
-async def test_one_to_one(session_factory):
- session = session_factory()
-
+@pytest.mark.parametrize("schema_provider", [get_schema, get_async_schema])
+async def test_one_to_one(sync_session_factory, schema_provider):
+ session = sync_session_factory()
+ schema = schema_provider()
reporter_1 = Reporter(
first_name="Reporter_1",
)
@@ -220,26 +268,43 @@ async def test_one_to_one(session_factory):
session.commit()
session.close()
- schema = get_schema()
-
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
- session = session_factory()
+
+ session = sync_session_factory()
result = await schema.execute_async(
"""
- query {
- reporters {
- firstName
- favoriteArticle {
- headline
- }
- }
+ query {
+ reporters {
+ firstName
+ favoriteArticle {
+ headline
+ }
}
+ }
""",
context_value={"session": session},
)
messages = sqlalchemy_logging_handler.messages
+ assert not result.errors
+ result = to_std_dicts(result.data)
+ assert result == {
+ "reporters": [
+ {
+ "firstName": "Reporter_1",
+ "favoriteArticle": {
+ "headline": "Article_1",
+ },
+ },
+ {
+ "firstName": "Reporter_2",
+ "favoriteArticle": {
+ "headline": "Article_2",
+ },
+ },
+ ],
+ }
assert len(messages) == 5
if is_sqlalchemy_version_less_than("1.3"):
@@ -254,36 +319,17 @@ async def test_one_to_one(session_factory):
assert len(sql_statements) == 1
return
- if not is_sqlalchemy_version_less_than("1.4"):
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
messages[2] = remove_cache_miss_stat(messages[2])
messages[4] = remove_cache_miss_stat(messages[4])
assert ast.literal_eval(messages[2]) == ()
assert sorted(ast.literal_eval(messages[4])) == [1, 2]
- assert not result.errors
- result = to_std_dicts(result.data)
- assert result == {
- "reporters": [
- {
- "firstName": "Reporter_1",
- "favoriteArticle": {
- "headline": "Article_1",
- },
- },
- {
- "firstName": "Reporter_2",
- "favoriteArticle": {
- "headline": "Article_2",
- },
- },
- ],
- }
-
@pytest.mark.asyncio
-async def test_one_to_many(session_factory):
- session = session_factory()
+async def test_one_to_many(sync_session_factory):
+ session = sync_session_factory()
reporter_1 = Reporter(
first_name="Reporter_1",
@@ -309,7 +355,6 @@ async def test_one_to_many(session_factory):
article_4 = Article(headline="Article_4")
article_4.reporter = reporter_2
session.add(article_4)
-
session.commit()
session.close()
@@ -317,7 +362,8 @@ async def test_one_to_many(session_factory):
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
- session = session_factory()
+
+ session = sync_session_factory()
result = await schema.execute_async(
"""
query {
@@ -337,27 +383,6 @@ async def test_one_to_many(session_factory):
)
messages = sqlalchemy_logging_handler.messages
- assert len(messages) == 5
-
- if is_sqlalchemy_version_less_than("1.3"):
- # The batched SQL statement generated is different in 1.2.x
- # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin`
- # See https://git.io/JewQu
- sql_statements = [
- message
- for message in messages
- if "SELECT" in message and "JOIN articles" in message
- ]
- assert len(sql_statements) == 1
- return
-
- if not is_sqlalchemy_version_less_than("1.4"):
- messages[2] = remove_cache_miss_stat(messages[2])
- messages[4] = remove_cache_miss_stat(messages[4])
-
- assert ast.literal_eval(messages[2]) == ()
- assert sorted(ast.literal_eval(messages[4])) == [1, 2]
-
assert not result.errors
result = to_std_dicts(result.data)
assert result == {
@@ -398,11 +423,31 @@ async def test_one_to_many(session_factory):
},
],
}
+ assert len(messages) == 5
+
+ if is_sqlalchemy_version_less_than("1.3"):
+ # The batched SQL statement generated is different in 1.2.x
+ # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin`
+ # See https://git.io/JewQu
+ sql_statements = [
+ message
+ for message in messages
+ if "SELECT" in message and "JOIN articles" in message
+ ]
+ assert len(sql_statements) == 1
+ return
+
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
+ messages[2] = remove_cache_miss_stat(messages[2])
+ messages[4] = remove_cache_miss_stat(messages[4])
+
+ assert ast.literal_eval(messages[2]) == ()
+ assert sorted(ast.literal_eval(messages[4])) == [1, 2]
@pytest.mark.asyncio
-async def test_many_to_many(session_factory):
- session = session_factory()
+async def test_many_to_many(sync_session_factory):
+ session = sync_session_factory()
reporter_1 = Reporter(
first_name="Reporter_1",
@@ -430,15 +475,14 @@ async def test_many_to_many(session_factory):
reporter_2.pets.append(pet_3)
reporter_2.pets.append(pet_4)
-
- session.commit()
- session.close()
+ await eventually_await_session(session, "commit")
+ await eventually_await_session(session, "close")
schema = get_schema()
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
- session = session_factory()
+ session = sync_session_factory()
result = await schema.execute_async(
"""
query {
@@ -458,27 +502,6 @@ async def test_many_to_many(session_factory):
)
messages = sqlalchemy_logging_handler.messages
- assert len(messages) == 5
-
- if is_sqlalchemy_version_less_than("1.3"):
- # The batched SQL statement generated is different in 1.2.x
- # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin`
- # See https://git.io/JewQu
- sql_statements = [
- message
- for message in messages
- if "SELECT" in message and "JOIN pets" in message
- ]
- assert len(sql_statements) == 1
- return
-
- if not is_sqlalchemy_version_less_than("1.4"):
- messages[2] = remove_cache_miss_stat(messages[2])
- messages[4] = remove_cache_miss_stat(messages[4])
-
- assert ast.literal_eval(messages[2]) == ()
- assert sorted(ast.literal_eval(messages[4])) == [1, 2]
-
assert not result.errors
result = to_std_dicts(result.data)
assert result == {
@@ -520,9 +543,30 @@ async def test_many_to_many(session_factory):
],
}
+ assert len(messages) == 5
-def test_disable_batching_via_ormfield(session_factory):
- session = session_factory()
+ if is_sqlalchemy_version_less_than("1.3"):
+ # The batched SQL statement generated is different in 1.2.x
+ # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin`
+ # See https://git.io/JewQu
+ sql_statements = [
+ message
+ for message in messages
+ if "SELECT" in message and "JOIN pets" in message
+ ]
+ assert len(sql_statements) == 1
+ return
+
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
+ messages[2] = remove_cache_miss_stat(messages[2])
+ messages[4] = remove_cache_miss_stat(messages[4])
+
+ assert ast.literal_eval(messages[2]) == ()
+ assert sorted(ast.literal_eval(messages[4])) == [1, 2]
+
+
+def test_disable_batching_via_ormfield(sync_session_factory):
+ session = sync_session_factory()
reporter_1 = Reporter(first_name="Reporter_1")
session.add(reporter_1)
reporter_2 = Reporter(first_name="Reporter_2")
@@ -555,7 +599,7 @@ def resolve_reporters(self, info):
# Test one-to-one and many-to-one relationships
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
- session = session_factory()
+ session = sync_session_factory()
schema.execute(
"""
query {
@@ -580,7 +624,7 @@ def resolve_reporters(self, info):
# Test one-to-many and many-to-many relationships
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
- session = session_factory()
+ session = sync_session_factory()
schema.execute(
"""
query {
@@ -607,9 +651,8 @@ def resolve_reporters(self, info):
assert len(select_statements) == 2
-@pytest.mark.asyncio
-def test_batch_sorting_with_custom_ormfield(session_factory):
- session = session_factory()
+def test_batch_sorting_with_custom_ormfield(sync_session_factory):
+ session = sync_session_factory()
reporter_1 = Reporter(first_name="Reporter_1")
session.add(reporter_1)
reporter_2 = Reporter(first_name="Reporter_2")
@@ -642,7 +685,7 @@ class Meta:
# Test one-to-one and many-to-one relationships
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
- session = session_factory()
+ session = sync_session_factory()
result = schema.execute(
"""
query {
@@ -658,7 +701,7 @@ class Meta:
context_value={"session": session},
)
messages = sqlalchemy_logging_handler.messages
-
+ assert not result.errors
result = to_std_dicts(result.data)
assert result == {
"reporters": {
@@ -685,8 +728,10 @@ class Meta:
@pytest.mark.asyncio
-async def test_connection_factory_field_overrides_batching_is_false(session_factory):
- session = session_factory()
+async def test_connection_factory_field_overrides_batching_is_false(
+ sync_session_factory,
+):
+ session = sync_session_factory()
reporter_1 = Reporter(first_name="Reporter_1")
session.add(reporter_1)
reporter_2 = Reporter(first_name="Reporter_2")
@@ -718,7 +763,7 @@ def resolve_reporters(self, info):
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
- session = session_factory()
+ session = sync_session_factory()
await schema.execute_async(
"""
query {
@@ -755,8 +800,8 @@ def resolve_reporters(self, info):
assert len(select_statements) == 1
-def test_connection_factory_field_overrides_batching_is_true(session_factory):
- session = session_factory()
+def test_connection_factory_field_overrides_batching_is_true(sync_session_factory):
+ session = sync_session_factory()
reporter_1 = Reporter(first_name="Reporter_1")
session.add(reporter_1)
reporter_2 = Reporter(first_name="Reporter_2")
@@ -788,7 +833,7 @@ def resolve_reporters(self, info):
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
- session = session_factory()
+ session = sync_session_factory()
schema.execute(
"""
query {
@@ -816,7 +861,9 @@ def resolve_reporters(self, info):
@pytest.mark.asyncio
-async def test_batching_across_nested_relay_schema(session_factory):
+async def test_batching_across_nested_relay_schema(
+ session_factory, async_session: bool
+):
session = session_factory()
for first_name in "fgerbhjikzutzxsdfdqqa":
@@ -831,8 +878,8 @@ async def test_batching_across_nested_relay_schema(session_factory):
reader.articles = [article]
session.add(reader)
- session.commit()
- session.close()
+ await eventually_await_session(session, "commit")
+ await eventually_await_session(session, "close")
schema = get_full_relay_schema()
@@ -871,14 +918,17 @@ async def test_batching_across_nested_relay_schema(session_factory):
result = to_std_dicts(result.data)
select_statements = [message for message in messages if "SELECT" in message]
- assert len(select_statements) == 4
- assert select_statements[-1].startswith("SELECT articles_1.id")
- if is_sqlalchemy_version_less_than("1.3"):
- assert select_statements[-2].startswith("SELECT reporters_1.id")
- assert "WHERE reporters_1.id IN" in select_statements[-2]
+ if async_session:
+ assert len(select_statements) == 2 # TODO: Figure out why async has less calls
else:
- assert select_statements[-2].startswith("SELECT articles.reporter_id")
- assert "WHERE articles.reporter_id IN" in select_statements[-2]
+ assert len(select_statements) == 4
+ assert select_statements[-1].startswith("SELECT articles_1.id")
+ if is_sqlalchemy_version_less_than("1.3"):
+ assert select_statements[-2].startswith("SELECT reporters_1.id")
+ assert "WHERE reporters_1.id IN" in select_statements[-2]
+ else:
+ assert select_statements[-2].startswith("SELECT articles.reporter_id")
+ assert "WHERE articles.reporter_id IN" in select_statements[-2]
@pytest.mark.asyncio
@@ -892,8 +942,8 @@ async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_f
article_1.reporter = reporter_1
session.add(article_1)
- session.commit()
- session.close()
+ await eventually_await_session(session, "commit")
+ await eventually_await_session(session, "close")
schema = get_full_relay_schema()
diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py
index bb105ed..dc656f4 100644
--- a/graphene_sqlalchemy/tests/test_benchmark.py
+++ b/graphene_sqlalchemy/tests/test_benchmark.py
@@ -1,16 +1,61 @@
+import asyncio
+
import pytest
+from sqlalchemy import select
import graphene
from graphene import relay
from ..types import SQLAlchemyObjectType
-from ..utils import is_sqlalchemy_version_less_than
+from ..utils import (
+ SQL_VERSION_HIGHER_EQUAL_THAN_1_4,
+ get_session,
+ is_sqlalchemy_version_less_than,
+)
from .models import Article, HairKind, Pet, Reporter
+from .utils import eventually_await_session
+if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
+ from sqlalchemy.ext.asyncio import AsyncSession
if is_sqlalchemy_version_less_than("1.2"):
pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True)
+def get_async_schema():
+ class ReporterType(SQLAlchemyObjectType):
+ class Meta:
+ model = Reporter
+ interfaces = (relay.Node,)
+
+ class ArticleType(SQLAlchemyObjectType):
+ class Meta:
+ model = Article
+ interfaces = (relay.Node,)
+
+ class PetType(SQLAlchemyObjectType):
+ class Meta:
+ model = Pet
+ interfaces = (relay.Node,)
+
+ class Query(graphene.ObjectType):
+ articles = graphene.Field(graphene.List(ArticleType))
+ reporters = graphene.Field(graphene.List(ReporterType))
+
+ async def resolve_articles(self, info):
+ session = get_session(info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return (await session.scalars(select(Article))).all()
+ return session.query(Article).all()
+
+ async def resolve_reporters(self, info):
+ session = get_session(info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return (await session.scalars(select(Reporter))).all()
+ return session.query(Reporter).all()
+
+ return graphene.Schema(query=Query)
+
+
def get_schema():
class ReporterType(SQLAlchemyObjectType):
class Meta:
@@ -40,20 +85,30 @@ def resolve_reporters(self, info):
return graphene.Schema(query=Query)
-def benchmark_query(session_factory, benchmark, query):
- schema = get_schema()
+async def benchmark_query(session, benchmark, schema, query):
+ import nest_asyncio
- @benchmark
- def execute_query():
- result = schema.execute(
- query,
- context_value={"session": session_factory()},
+ nest_asyncio.apply()
+ loop = asyncio.get_event_loop()
+ result = benchmark(
+ lambda: loop.run_until_complete(
+ schema.execute_async(query, context_value={"session": session})
)
- assert not result.errors
+ )
+ assert not result.errors
+
+
+@pytest.fixture(params=[get_schema, get_async_schema])
+def schema_provider(request, async_session):
+ if async_session and request.param == get_schema:
+ pytest.skip("Cannot test sync schema with async sessions")
+ return request.param
-def test_one_to_one(session_factory, benchmark):
+@pytest.mark.asyncio
+async def test_one_to_one(session_factory, benchmark, schema_provider):
session = session_factory()
+ schema = schema_provider()
reporter_1 = Reporter(
first_name="Reporter_1",
@@ -72,12 +127,13 @@ def test_one_to_one(session_factory, benchmark):
article_2.reporter = reporter_2
session.add(article_2)
- session.commit()
- session.close()
+ await eventually_await_session(session, "commit")
+ await eventually_await_session(session, "close")
- benchmark_query(
- session_factory,
+ await benchmark_query(
+ session,
benchmark,
+ schema,
"""
query {
reporters {
@@ -91,9 +147,10 @@ def test_one_to_one(session_factory, benchmark):
)
-def test_many_to_one(session_factory, benchmark):
+@pytest.mark.asyncio
+async def test_many_to_one(session_factory, benchmark, schema_provider):
session = session_factory()
-
+ schema = schema_provider()
reporter_1 = Reporter(
first_name="Reporter_1",
)
@@ -110,13 +167,14 @@ def test_many_to_one(session_factory, benchmark):
article_2 = Article(headline="Article_2")
article_2.reporter = reporter_2
session.add(article_2)
+ await eventually_await_session(session, "flush")
+ await eventually_await_session(session, "commit")
+ await eventually_await_session(session, "close")
- session.commit()
- session.close()
-
- benchmark_query(
- session_factory,
+ await benchmark_query(
+ session,
benchmark,
+ schema,
"""
query {
articles {
@@ -130,8 +188,10 @@ def test_many_to_one(session_factory, benchmark):
)
-def test_one_to_many(session_factory, benchmark):
+@pytest.mark.asyncio
+async def test_one_to_many(session_factory, benchmark, schema_provider):
session = session_factory()
+ schema = schema_provider()
reporter_1 = Reporter(
first_name="Reporter_1",
@@ -158,12 +218,13 @@ def test_one_to_many(session_factory, benchmark):
article_4.reporter = reporter_2
session.add(article_4)
- session.commit()
- session.close()
+ await eventually_await_session(session, "commit")
+ await eventually_await_session(session, "close")
- benchmark_query(
- session_factory,
+ await benchmark_query(
+ session,
benchmark,
+ schema,
"""
query {
reporters {
@@ -181,9 +242,10 @@ def test_one_to_many(session_factory, benchmark):
)
-def test_many_to_many(session_factory, benchmark):
+@pytest.mark.asyncio
+async def test_many_to_many(session_factory, benchmark, schema_provider):
session = session_factory()
-
+ schema = schema_provider()
reporter_1 = Reporter(
first_name="Reporter_1",
)
@@ -211,12 +273,13 @@ def test_many_to_many(session_factory, benchmark):
reporter_2.pets.append(pet_3)
reporter_2.pets.append(pet_4)
- session.commit()
- session.close()
+ await eventually_await_session(session, "commit")
+ await eventually_await_session(session, "close")
- benchmark_query(
- session_factory,
+ await benchmark_query(
+ session,
benchmark,
+ schema,
"""
query {
reporters {
diff --git a/graphene_sqlalchemy/tests/test_enums.py b/graphene_sqlalchemy/tests/test_enums.py
index cd97a00..3de6904 100644
--- a/graphene_sqlalchemy/tests/test_enums.py
+++ b/graphene_sqlalchemy/tests/test_enums.py
@@ -85,7 +85,10 @@ class Meta:
assert enum._meta.name == "PetKind"
assert [
(key, value.value) for key, value in enum._meta.enum.__members__.items()
- ] == [("CAT", "cat"), ("DOG", "dog")]
+ ] == [
+ ("CAT", "cat"),
+ ("DOG", "dog"),
+ ]
enum2 = enum_for_field(PetType, "pet_kind")
assert enum2 is enum
enum2 = PetType.enum_for_field("pet_kind")
diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py
index 456254f..055a87f 100644
--- a/graphene_sqlalchemy/tests/test_query.py
+++ b/graphene_sqlalchemy/tests/test_query.py
@@ -1,11 +1,15 @@
from datetime import date
+import pytest
+from sqlalchemy import select
+
import graphene
from graphene.relay import Node
from ..converter import convert_sqlalchemy_composite
from ..fields import SQLAlchemyConnectionField
from ..types import ORMField, SQLAlchemyInterface, SQLAlchemyObjectType
+from ..utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, get_session
from .models import (
Article,
CompositeFullName,
@@ -16,10 +20,13 @@
Pet,
Reporter,
)
-from .utils import to_std_dicts
+from .utils import eventually_await_session, to_std_dicts
+
+if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
+ from sqlalchemy.ext.asyncio import AsyncSession
-def add_test_data(session):
+async def add_test_data(session):
reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat")
session.add(reporter)
pet = Pet(name="Garfield", pet_kind="cat", hair_kind=HairKind.SHORT)
@@ -35,11 +42,12 @@ def add_test_data(session):
session.add(pet)
editor = Editor(name="Jack")
session.add(editor)
- session.commit()
+ await eventually_await_session(session, "commit")
-def test_query_fields(session):
- add_test_data(session)
+@pytest.mark.asyncio
+async def test_query_fields(session):
+ await add_test_data(session)
@convert_sqlalchemy_composite.register(CompositeFullName)
def convert_composite_class(composite, registry):
@@ -53,10 +61,16 @@ class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
reporters = graphene.List(ReporterType)
- def resolve_reporter(self, _info):
+ async def resolve_reporter(self, _info):
+ session = get_session(_info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return (await session.scalars(select(Reporter))).unique().first()
return session.query(Reporter).first()
- def resolve_reporters(self, _info):
+ async def resolve_reporters(self, _info):
+ session = get_session(_info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return (await session.scalars(select(Reporter))).unique().all()
return session.query(Reporter)
query = """
@@ -82,14 +96,15 @@ def resolve_reporters(self, _info):
"reporters": [{"firstName": "John"}, {"firstName": "Jane"}],
}
schema = graphene.Schema(query=Query)
- result = schema.execute(query)
+ result = await schema.execute_async(query, context_value={"session": session})
assert not result.errors
result = to_std_dicts(result.data)
assert result == expected
-def test_query_node(session):
- add_test_data(session)
+@pytest.mark.asyncio
+async def test_query_node_sync(session):
+ await add_test_data(session)
class ReporterNode(SQLAlchemyObjectType):
class Meta:
@@ -111,6 +126,14 @@ class Query(graphene.ObjectType):
all_articles = SQLAlchemyConnectionField(ArticleNode.connection)
def resolve_reporter(self, _info):
+ session = get_session(_info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+
+ async def get_result():
+ return (await session.scalars(select(Reporter))).first()
+
+ return get_result()
+
return session.query(Reporter).first()
query = """
@@ -154,14 +177,100 @@ def resolve_reporter(self, _info):
"myArticle": {"id": "QXJ0aWNsZU5vZGU6MQ==", "headline": "Hi!"},
}
schema = graphene.Schema(query=Query)
- result = schema.execute(query, context_value={"session": session})
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ result = schema.execute(query, context_value={"session": session})
+ assert result.errors
+ else:
+ result = schema.execute(query, context_value={"session": session})
+ assert not result.errors
+ result = to_std_dicts(result.data)
+ assert result == expected
+
+
+@pytest.mark.asyncio
+async def test_query_node_async(session):
+ await add_test_data(session)
+
+ class ReporterNode(SQLAlchemyObjectType):
+ class Meta:
+ model = Reporter
+ interfaces = (Node,)
+
+ @classmethod
+ def get_node(cls, info, id):
+ return Reporter(id=2, first_name="Cookie Monster")
+
+ class ArticleNode(SQLAlchemyObjectType):
+ class Meta:
+ model = Article
+ interfaces = (Node,)
+
+ class Query(graphene.ObjectType):
+ node = Node.Field()
+ reporter = graphene.Field(ReporterNode)
+ all_articles = SQLAlchemyConnectionField(ArticleNode.connection)
+
+ def resolve_reporter(self, _info):
+ session = get_session(_info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+
+ async def get_result():
+ return (await session.scalars(select(Reporter))).first()
+
+ return get_result()
+
+ return session.query(Reporter).first()
+
+ query = """
+ query {
+ reporter {
+ id
+ firstName
+ articles {
+ edges {
+ node {
+ headline
+ }
+ }
+ }
+ }
+ allArticles {
+ edges {
+ node {
+ headline
+ }
+ }
+ }
+ myArticle: node(id:"QXJ0aWNsZU5vZGU6MQ==") {
+ id
+ ... on ReporterNode {
+ firstName
+ }
+ ... on ArticleNode {
+ headline
+ }
+ }
+ }
+ """
+ expected = {
+ "reporter": {
+ "id": "UmVwb3J0ZXJOb2RlOjE=",
+ "firstName": "John",
+ "articles": {"edges": [{"node": {"headline": "Hi!"}}]},
+ },
+ "allArticles": {"edges": [{"node": {"headline": "Hi!"}}]},
+ "myArticle": {"id": "QXJ0aWNsZU5vZGU6MQ==", "headline": "Hi!"},
+ }
+ schema = graphene.Schema(query=Query)
+ result = await schema.execute_async(query, context_value={"session": session})
assert not result.errors
result = to_std_dicts(result.data)
assert result == expected
-def test_orm_field(session):
- add_test_data(session)
+@pytest.mark.asyncio
+async def test_orm_field(session):
+ await add_test_data(session)
@convert_sqlalchemy_composite.register(CompositeFullName)
def convert_composite_class(composite, registry):
@@ -187,7 +296,10 @@ class Meta:
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
- def resolve_reporter(self, _info):
+ async def resolve_reporter(self, _info):
+ session = get_session(_info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return (await session.scalars(select(Reporter))).first()
return session.query(Reporter).first()
query = """
@@ -221,14 +333,15 @@ def resolve_reporter(self, _info):
},
}
schema = graphene.Schema(query=Query)
- result = schema.execute(query, context_value={"session": session})
+ result = await schema.execute_async(query, context_value={"session": session})
assert not result.errors
result = to_std_dicts(result.data)
assert result == expected
-def test_custom_identifier(session):
- add_test_data(session)
+@pytest.mark.asyncio
+async def test_custom_identifier(session):
+ await add_test_data(session)
class EditorNode(SQLAlchemyObjectType):
class Meta:
@@ -262,14 +375,15 @@ class Query(graphene.ObjectType):
}
schema = graphene.Schema(query=Query)
- result = schema.execute(query, context_value={"session": session})
+ result = await schema.execute_async(query, context_value={"session": session})
assert not result.errors
result = to_std_dicts(result.data)
assert result == expected
-def test_mutation(session):
- add_test_data(session)
+@pytest.mark.asyncio
+async def test_mutation(session, session_factory):
+ await add_test_data(session)
class EditorNode(SQLAlchemyObjectType):
class Meta:
@@ -282,8 +396,11 @@ class Meta:
interfaces = (Node,)
@classmethod
- def get_node(cls, id, info):
- return Reporter(id=2, first_name="Cookie Monster")
+ async def get_node(cls, id, info):
+ session = get_session(info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return (await session.scalars(select(Reporter))).unique().first()
+ return session.query(Reporter).first()
class ArticleNode(SQLAlchemyObjectType):
class Meta:
@@ -298,11 +415,14 @@ class Arguments:
ok = graphene.Boolean()
article = graphene.Field(ArticleNode)
- def mutate(self, info, headline, reporter_id):
+ async def mutate(self, info, headline, reporter_id):
+ reporter = await ReporterNode.get_node(reporter_id, info)
new_article = Article(headline=headline, reporter_id=reporter_id)
+ reporter.articles = [*reporter.articles, new_article]
+ session = get_session(info.context)
+ session.add(reporter)
- session.add(new_article)
- session.commit()
+ await eventually_await_session(session, "commit")
ok = True
return CreateArticle(article=new_article, ok=ok)
@@ -341,24 +461,28 @@ class Mutation(graphene.ObjectType):
}
schema = graphene.Schema(query=Query, mutation=Mutation)
- result = schema.execute(query, context_value={"session": session})
+ result = await schema.execute_async(
+ query, context_value={"session": session_factory()}
+ )
assert not result.errors
result = to_std_dicts(result.data)
assert result == expected
-def add_person_data(session):
+async def add_person_data(session):
bob = Employee(name="Bob", birth_date=date(1990, 1, 1), hire_date=date(2015, 1, 1))
session.add(bob)
joe = Employee(name="Joe", birth_date=date(1980, 1, 1), hire_date=date(2010, 1, 1))
session.add(joe)
jen = Employee(name="Jen", birth_date=date(1995, 1, 1), hire_date=date(2020, 1, 1))
session.add(jen)
- session.commit()
+ await eventually_await_session(session, "commit")
-def test_interface_query_on_base_type(session):
- add_person_data(session)
+@pytest.mark.asyncio
+async def test_interface_query_on_base_type(session_factory):
+ session = session_factory()
+ await add_person_data(session)
class PersonType(SQLAlchemyInterface):
class Meta:
@@ -372,11 +496,13 @@ class Meta:
class Query(graphene.ObjectType):
people = graphene.Field(graphene.List(PersonType))
- def resolve_people(self, _info):
+ async def resolve_people(self, _info):
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return (await session.scalars(select(Person))).all()
return session.query(Person).all()
schema = graphene.Schema(query=Query, types=[PersonType, EmployeeType])
- result = schema.execute(
+ result = await schema.execute_async(
"""
query {
people {
diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py
index 923bbed..14c87f7 100644
--- a/graphene_sqlalchemy/tests/test_query_enums.py
+++ b/graphene_sqlalchemy/tests/test_query_enums.py
@@ -1,12 +1,22 @@
+import pytest
+from sqlalchemy import select
+
import graphene
+from graphene_sqlalchemy.tests.utils import eventually_await_session
+from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, get_session
from ..types import SQLAlchemyObjectType
from .models import HairKind, Pet, Reporter
from .test_query import add_test_data, to_std_dicts
+if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
+ from sqlalchemy.ext.asyncio import AsyncSession
+
-def test_query_pet_kinds(session):
- add_test_data(session)
+@pytest.mark.asyncio
+async def test_query_pet_kinds(session, session_factory):
+ await add_test_data(session)
+ await eventually_await_session(session, "close")
class PetType(SQLAlchemyObjectType):
class Meta:
@@ -23,13 +33,25 @@ class Query(graphene.ObjectType):
PetType, kind=graphene.Argument(PetType.enum_for_field("pet_kind"))
)
- def resolve_reporter(self, _info):
+ async def resolve_reporter(self, _info):
+ session = get_session(_info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return (await session.scalars(select(Reporter))).unique().first()
return session.query(Reporter).first()
- def resolve_reporters(self, _info):
+ async def resolve_reporters(self, _info):
+ session = get_session(_info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return (await session.scalars(select(Reporter))).unique().all()
return session.query(Reporter)
- def resolve_pets(self, _info, kind):
+ async def resolve_pets(self, _info, kind):
+ session = get_session(_info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ query = select(Pet)
+ if kind:
+ query = query.filter(Pet.pet_kind == kind.value)
+ return (await session.scalars(query)).unique().all()
query = session.query(Pet)
if kind:
query = query.filter_by(pet_kind=kind.value)
@@ -78,13 +100,16 @@ def resolve_pets(self, _info, kind):
"pets": [{"name": "Lassie", "petKind": "DOG"}],
}
schema = graphene.Schema(query=Query)
- result = schema.execute(query)
+ result = await schema.execute_async(
+ query, context_value={"session": session_factory()}
+ )
assert not result.errors
assert result.data == expected
-def test_query_more_enums(session):
- add_test_data(session)
+@pytest.mark.asyncio
+async def test_query_more_enums(session):
+ await add_test_data(session)
class PetType(SQLAlchemyObjectType):
class Meta:
@@ -93,7 +118,10 @@ class Meta:
class Query(graphene.ObjectType):
pet = graphene.Field(PetType)
- def resolve_pet(self, _info):
+ async def resolve_pet(self, _info):
+ session = get_session(_info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return (await session.scalars(select(Pet))).first()
return session.query(Pet).first()
query = """
@@ -107,14 +135,15 @@ def resolve_pet(self, _info):
"""
expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}}
schema = graphene.Schema(query=Query)
- result = schema.execute(query)
+ result = await schema.execute_async(query, context_value={"session": session})
assert not result.errors
result = to_std_dicts(result.data)
assert result == expected
-def test_enum_as_argument(session):
- add_test_data(session)
+@pytest.mark.asyncio
+async def test_enum_as_argument(session):
+ await add_test_data(session)
class PetType(SQLAlchemyObjectType):
class Meta:
@@ -125,7 +154,13 @@ class Query(graphene.ObjectType):
PetType, kind=graphene.Argument(PetType.enum_for_field("pet_kind"))
)
- def resolve_pet(self, info, kind=None):
+ async def resolve_pet(self, info, kind=None):
+ session = get_session(info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ query = select(Pet)
+ if kind:
+ query = query.filter(Pet.pet_kind == kind.value)
+ return (await session.scalars(query)).first()
query = session.query(Pet)
if kind:
query = query.filter(Pet.pet_kind == kind.value)
@@ -142,19 +177,24 @@ def resolve_pet(self, info, kind=None):
"""
schema = graphene.Schema(query=Query)
- result = schema.execute(query, variables={"kind": "CAT"})
+ result = await schema.execute_async(
+ query, variables={"kind": "CAT"}, context_value={"session": session}
+ )
assert not result.errors
expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}}
assert result.data == expected
- result = schema.execute(query, variables={"kind": "DOG"})
+ result = await schema.execute_async(
+ query, variables={"kind": "DOG"}, context_value={"session": session}
+ )
assert not result.errors
expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}}
result = to_std_dicts(result.data)
assert result == expected
-def test_py_enum_as_argument(session):
- add_test_data(session)
+@pytest.mark.asyncio
+async def test_py_enum_as_argument(session):
+ await add_test_data(session)
class PetType(SQLAlchemyObjectType):
class Meta:
@@ -166,7 +206,14 @@ class Query(graphene.ObjectType):
kind=graphene.Argument(PetType._meta.fields["hair_kind"].type.of_type),
)
- def resolve_pet(self, _info, kind=None):
+ async def resolve_pet(self, _info, kind=None):
+ session = get_session(_info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return (
+ await session.scalars(
+ select(Pet).filter(Pet.hair_kind == HairKind(kind))
+ )
+ ).first()
query = session.query(Pet)
if kind:
# enum arguments are expected to be strings, not PyEnums
@@ -184,11 +231,15 @@ def resolve_pet(self, _info, kind=None):
"""
schema = graphene.Schema(query=Query)
- result = schema.execute(query, variables={"kind": "SHORT"})
+ result = await schema.execute_async(
+ query, variables={"kind": "SHORT"}, context_value={"session": session}
+ )
assert not result.errors
expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}}
assert result.data == expected
- result = schema.execute(query, variables={"kind": "LONG"})
+ result = await schema.execute_async(
+ query, variables={"kind": "LONG"}, context_value={"session": session}
+ )
assert not result.errors
expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}}
result = to_std_dicts(result.data)
diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py
index 11c7c9a..f8f1ff8 100644
--- a/graphene_sqlalchemy/tests/test_sort_enums.py
+++ b/graphene_sqlalchemy/tests/test_sort_enums.py
@@ -9,16 +9,17 @@
from ..utils import to_type_name
from .models import Base, HairKind, KeyedModel, Pet
from .test_query import to_std_dicts
+from .utils import eventually_await_session
-def add_pets(session):
+async def add_pets(session):
pets = [
Pet(id=1, name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG),
Pet(id=2, name="Barf", pet_kind="dog", hair_kind=HairKind.LONG),
Pet(id=3, name="Alf", pet_kind="cat", hair_kind=HairKind.LONG),
]
session.add_all(pets)
- session.commit()
+ await eventually_await_session(session, "commit")
def test_sort_enum():
@@ -241,8 +242,9 @@ def get_symbol_name(column_name, sort_asc=True):
assert sort_arg.default_value == ["IdUp"]
-def test_sort_query(session):
- add_pets(session)
+@pytest.mark.asyncio
+async def test_sort_query(session):
+ await add_pets(session)
class PetNode(SQLAlchemyObjectType):
class Meta:
@@ -336,7 +338,7 @@ def makeNodes(nodeList):
} # yapf: disable
schema = Schema(query=Query)
- result = schema.execute(query, context_value={"session": session})
+ result = await schema.execute_async(query, context_value={"session": session})
assert not result.errors
result = to_std_dicts(result.data)
assert result == expected
@@ -352,7 +354,7 @@ def makeNodes(nodeList):
}
}
"""
- result = schema.execute(queryError, context_value={"session": session})
+ result = await schema.execute_async(queryError, context_value={"session": session})
assert result.errors is not None
assert "cannot represent non-enum value" in result.errors[0].message
@@ -375,7 +377,7 @@ def makeNodes(nodeList):
}
"""
- result = schema.execute(queryNoSort, context_value={"session": session})
+ result = await schema.execute_async(queryNoSort, context_value={"session": session})
assert not result.errors
# TODO: SQLite usually returns the results ordered by primary key,
# so we cannot test this way whether sorting actually happens or not.
diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py
index 813fb13..6632842 100644
--- a/graphene_sqlalchemy/tests/test_types.py
+++ b/graphene_sqlalchemy/tests/test_types.py
@@ -4,6 +4,9 @@
import pytest
import sqlalchemy.exc
import sqlalchemy.orm.exc
+from graphql.pyutils import is_awaitable
+from sqlalchemy import select
+
from graphene import (
Boolean,
Dynamic,
@@ -20,7 +23,6 @@
)
from graphene.relay import Connection
-from .models import Article, CompositeFullName, Employee, Person, Pet, Reporter, NonAbstractPerson
from .. import utils
from ..converter import convert_sqlalchemy_composite
from ..fields import (
@@ -36,11 +38,26 @@
SQLAlchemyObjectType,
SQLAlchemyObjectTypeOptions,
)
+from ..utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4
+from .models import (
+ Article,
+ CompositeFullName,
+ Employee,
+ NonAbstractPerson,
+ Person,
+ Pet,
+ Reporter,
+)
+from .utils import eventually_await_session
+
+if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
+ from sqlalchemy.ext.asyncio import AsyncSession
def test_should_raise_if_no_model():
re_err = r"valid SQLAlchemy Model"
with pytest.raises(Exception, match=re_err):
+
class Character1(SQLAlchemyObjectType):
pass
@@ -48,12 +65,14 @@ class Character1(SQLAlchemyObjectType):
def test_should_raise_if_model_is_invalid():
re_err = r"valid SQLAlchemy Model"
with pytest.raises(Exception, match=re_err):
+
class Character(SQLAlchemyObjectType):
class Meta:
model = 1
-def test_sqlalchemy_node(session):
+@pytest.mark.asyncio
+async def test_sqlalchemy_node(session):
class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
@@ -64,9 +83,11 @@ class Meta:
reporter = Reporter()
session.add(reporter)
- session.commit()
+ await eventually_await_session(session, "commit")
info = mock.Mock(context={"session": session})
reporter_node = ReporterType.get_node(info, reporter.id)
+ if is_awaitable(reporter_node):
+ reporter_node = await reporter_node
assert reporter == reporter_node
@@ -97,7 +118,7 @@ class Meta:
assert sorted(list(ReporterType._meta.fields.keys())) == sorted(
[
# Columns
- "column_prop",
+ "column_prop", # SQLAlchemy retuns column properties first
"id",
"first_name",
"last_name",
@@ -320,6 +341,7 @@ def test_invalid_model_attr():
"Cannot map ORMField to a model attribute.\n" "Field: 'ReporterType.first_name'"
)
with pytest.raises(ValueError, match=err_msg):
+
class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
@@ -373,6 +395,7 @@ class Meta:
def test_only_and_exclude_fields():
re_err = r"'only_fields' and 'exclude_fields' cannot be both set"
with pytest.raises(Exception, match=re_err):
+
class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
@@ -392,9 +415,19 @@ class Meta:
assert first_name_field.type == Int
-def test_resolvers(session):
+@pytest.mark.asyncio
+async def test_resolvers(session):
"""Test that the correct resolver functions are called"""
+ reporter = Reporter(
+ first_name="first_name",
+ last_name="last_name",
+ email="email",
+ favorite_pet_kind="cat",
+ )
+ session.add(reporter)
+ await eventually_await_session(session, "commit")
+
class ReporterMixin(object):
def resolve_id(root, _info):
return "ID"
@@ -420,20 +453,14 @@ def resolve_favorite_pet_kind_v2(root, _info):
class Query(ObjectType):
reporter = Field(ReporterType)
- def resolve_reporter(self, _info):
+ async def resolve_reporter(self, _info):
+ session = utils.get_session(_info.context)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return (await session.scalars(select(Reporter))).unique().first()
return session.query(Reporter).first()
- reporter = Reporter(
- first_name="first_name",
- last_name="last_name",
- email="email",
- favorite_pet_kind="cat",
- )
- session.add(reporter)
- session.commit()
-
schema = Schema(query=Query)
- result = schema.execute(
+ result = await schema.execute_async(
"""
query {
reporter {
@@ -446,7 +473,8 @@ def resolve_reporter(self, _info):
favoritePetKindV2
}
}
- """
+ """,
+ context_value={"session": session},
)
assert not result.errors
@@ -511,8 +539,13 @@ class Meta:
def test_interface_with_polymorphic_identity():
- with pytest.raises(AssertionError,
- match=re.escape('PersonType: An interface cannot map to a concrete type (polymorphic_identity is "person")')):
+ with pytest.raises(
+ AssertionError,
+ match=re.escape(
+ 'PersonType: An interface cannot map to a concrete type (polymorphic_identity is "person")'
+ ),
+ ):
+
class PersonType(SQLAlchemyInterface):
class Meta:
model = NonAbstractPerson
@@ -562,13 +595,15 @@ class Meta:
# type should be in this list because we used ORMField
# to force its presence on the model
- assert sorted(list(EmployeeType._meta.fields.keys())) == sorted([
- "id",
- "name",
- "type",
- "birth_date",
- "hire_date",
- ])
+ assert sorted(list(EmployeeType._meta.fields.keys())) == sorted(
+ [
+ "id",
+ "name",
+ "type",
+ "birth_date",
+ "hire_date",
+ ]
+ )
def test_interface_custom_resolver():
@@ -590,13 +625,15 @@ class Meta:
# type should be in this list because we used ORMField
# to force its presence on the model
- assert sorted(list(EmployeeType._meta.fields.keys())) == sorted([
- "id",
- "name",
- "custom_field",
- "birth_date",
- "hire_date",
- ])
+ assert sorted(list(EmployeeType._meta.fields.keys())) == sorted(
+ [
+ "id",
+ "name",
+ "custom_field",
+ "birth_date",
+ "hire_date",
+ ]
+ )
# Tests for connection_field_factory
diff --git a/graphene_sqlalchemy/tests/utils.py b/graphene_sqlalchemy/tests/utils.py
index c90ee47..4a11824 100644
--- a/graphene_sqlalchemy/tests/utils.py
+++ b/graphene_sqlalchemy/tests/utils.py
@@ -1,3 +1,4 @@
+import inspect
import re
@@ -15,3 +16,11 @@ def remove_cache_miss_stat(message):
"""Remove the stat from the echoed query message when the cache is missed for sqlalchemy version >= 1.4"""
# https://github.com/sqlalchemy/sqlalchemy/blob/990eb3d8813369d3b8a7776ae85fb33627443d30/lib/sqlalchemy/engine/default.py#L1177
return re.sub(r"\[generated in \d+.?\d*s\]\s", "", message)
+
+
+async def eventually_await_session(session, func, *args):
+
+ if inspect.iscoroutinefunction(getattr(session, func)):
+ await getattr(session, func)(*args)
+ else:
+ getattr(session, func)(*args)
diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py
index e0ada38..226d1e8 100644
--- a/graphene_sqlalchemy/types.py
+++ b/graphene_sqlalchemy/types.py
@@ -1,4 +1,6 @@
from collections import OrderedDict
+from inspect import isawaitable
+from typing import Any
import sqlalchemy
from sqlalchemy.ext.hybrid import hybrid_property
@@ -26,7 +28,16 @@
)
from .registry import Registry, get_global_registry
from .resolvers import get_attr_resolver, get_custom_resolver
-from .utils import get_query, is_mapped_class, is_mapped_instance
+from .utils import (
+ SQL_VERSION_HIGHER_EQUAL_THAN_1_4,
+ get_query,
+ get_session,
+ is_mapped_class,
+ is_mapped_instance,
+)
+
+if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
+ from sqlalchemy.ext.asyncio import AsyncSession
class ORMField(OrderedType):
@@ -334,6 +345,11 @@ def __init_subclass_with_meta__(
def is_type_of(cls, root, info):
if isinstance(root, cls):
return True
+ if isawaitable(root):
+ raise Exception(
+ "Received coroutine instead of sql alchemy model. "
+ "You seem to use an async engine with synchronous schema execution"
+ )
if not is_mapped_instance(root):
raise Exception(('Received incompatible instance "{}".').format(root))
return isinstance(root, cls._meta.model)
@@ -345,6 +361,19 @@ def get_query(cls, info):
@classmethod
def get_node(cls, info, id):
+ if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
+ try:
+ return cls.get_query(info).get(id)
+ except NoResultFound:
+ return None
+
+ session = get_session(info.context)
+ if isinstance(session, AsyncSession):
+
+ async def get_result() -> Any:
+ return await session.get(cls._meta.model, id)
+
+ return get_result()
try:
return cls.get_query(info).get(id)
except NoResultFound:
diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py
index 54bb840..62c71d8 100644
--- a/graphene_sqlalchemy/utils.py
+++ b/graphene_sqlalchemy/utils.py
@@ -4,11 +4,34 @@
from typing import Any, Callable, Dict, Optional
import pkg_resources
+from sqlalchemy import select
from sqlalchemy.exc import ArgumentError
from sqlalchemy.orm import class_mapper, object_mapper
from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError
+def is_sqlalchemy_version_less_than(version_string):
+ """Check the installed SQLAlchemy version"""
+ return pkg_resources.get_distribution(
+ "SQLAlchemy"
+ ).parsed_version < pkg_resources.parse_version(version_string)
+
+
+def is_graphene_version_less_than(version_string): # pragma: no cover
+ """Check the installed graphene version"""
+ return pkg_resources.get_distribution(
+ "graphene"
+ ).parsed_version < pkg_resources.parse_version(version_string)
+
+
+SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = False
+
+if not is_sqlalchemy_version_less_than("1.4"):
+ from sqlalchemy.ext.asyncio import AsyncSession
+
+ SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = True
+
+
def get_session(context):
return context.get("session")
@@ -22,6 +45,8 @@ def get_query(model, context):
"A query in the model Base or a session in the schema is required for querying.\n"
"Read more http://docs.graphene-python.org/projects/sqlalchemy/en/latest/tips/#querying"
)
+ if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):
+ return select(model)
query = session.query(model)
return query
@@ -151,20 +176,6 @@ def sort_argument_for_model(cls, has_default=True):
return Argument(List(enum), default_value=enum.default)
-def is_sqlalchemy_version_less_than(version_string): # pragma: no cover
- """Check the installed SQLAlchemy version"""
- return pkg_resources.get_distribution(
- "SQLAlchemy"
- ).parsed_version < pkg_resources.parse_version(version_string)
-
-
-def is_graphene_version_less_than(version_string): # pragma: no cover
- """Check the installed graphene version"""
- return pkg_resources.get_distribution(
- "graphene"
- ).parsed_version < pkg_resources.parse_version(version_string)
-
-
class singledispatchbymatchfunction:
"""
Inspired by @singledispatch, this is a variant that works using a matcher function
diff --git a/setup.py b/setup.py
index ac9ad7e..9122baf 100644
--- a/setup.py
+++ b/setup.py
@@ -21,10 +21,13 @@
tests_require = [
"pytest>=6.2.0,<7.0",
- "pytest-asyncio>=0.15.1",
+ "pytest-asyncio>=0.18.3",
"pytest-cov>=2.11.0,<3.0",
"sqlalchemy_utils>=0.37.0,<1.0",
"pytest-benchmark>=3.4.0,<4.0",
+ "aiosqlite>=0.17.0",
+ "nest-asyncio",
+ "greenlet",
]
setup(