diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 428eca1..7632fd3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,6 +1,6 @@ name: Tests -on: +on: push: branches: - 'master' diff --git a/docs/inheritance.rst b/docs/inheritance.rst index ee16f06..7473216 100644 --- a/docs/inheritance.rst +++ b/docs/inheritance.rst @@ -3,7 +3,7 @@ Inheritance Examples Create interfaces from inheritance relationships ------------------------------------------------ - +.. note:: If you're using `AsyncSession`, please check the chapter `Eager Loading & Using with AsyncSession`_. SQLAlchemy has excellent support for class inheritance hierarchies. These hierarchies can be represented in your GraphQL schema by means of interfaces_. Much like ObjectTypes, Interfaces in @@ -40,7 +40,7 @@ from the attributes of their underlying SQLAlchemy model: __mapper_args__ = { "polymorphic_identity": "employee", } - + class Customer(Person): first_purchase_date = Column(Date()) @@ -56,17 +56,17 @@ from the attributes of their underlying SQLAlchemy model: class Meta: model = Employee interfaces = (relay.Node, PersonType) - + class CustomerType(SQLAlchemyObjectType): class Meta: model = Customer interfaces = (relay.Node, PersonType) -Keep in mind that `PersonType` is a `SQLAlchemyInterface`. Interfaces must -be linked to an abstract Model that does not specify a `polymorphic_identity`, -because we cannot return instances of interfaces from a GraphQL query. -If Person specified a `polymorphic_identity`, instances of Person could -be inserted into and returned by the database, potentially causing +Keep in mind that `PersonType` is a `SQLAlchemyInterface`. Interfaces must +be linked to an abstract Model that does not specify a `polymorphic_identity`, +because we cannot return instances of interfaces from a GraphQL query. +If Person specified a `polymorphic_identity`, instances of Person could +be inserted into and returned by the database, potentially causing Persons to be returned to the resolvers. When querying on the base type, you can refer directly to common fields, @@ -85,15 +85,19 @@ and fields on concrete implementations using the `... on` syntax: firstPurchaseDate } } - - + + +.. danger:: + When using joined table inheritance, this style of querying may lead to unbatched implicit IO with negative performance implications. + See the chapter `Eager Loading & Using with AsyncSession`_ for more information on eager loading all possible types of a `SQLAlchemyInterface`. + Please note that by default, the "polymorphic_on" column is *not* generated as a field on types that use polymorphic inheritance, as -this is considered an implentation detail. The idiomatic way to +this is considered an implementation detail. The idiomatic way to retrieve the concrete GraphQL type of an object is to query for the -`__typename` field. +`__typename` field. To override this behavior, an `ORMField` needs to be created -for the custom type field on the corresponding `SQLAlchemyInterface`. This is *not recommended* +for the custom type field on the corresponding `SQLAlchemyInterface`. This is *not recommended* as it promotes abiguous schema design If your SQLAlchemy model only specifies a relationship to the @@ -103,5 +107,39 @@ class to the Schema constructor via the `types=` argument: .. code:: python schema = graphene.Schema(..., types=[PersonType, EmployeeType, CustomerType]) - + + See also: `Graphene Interfaces `_ + +Eager Loading & Using with AsyncSession +-------------------- +When querying the base type in multi-table inheritance or joined table inheritance, you can only directly refer to polymorphic fields when they are loaded eagerly. +This restricting is in place because AsyncSessions don't allow implicit async operations such as the loads of the joined tables. +To load the polymorphic fields eagerly, you can use the `with_polymorphic` attribute of the mapper args in the base model: + +.. code:: python + class Person(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + name = Column(String()) + birth_date = Column(Date()) + + __tablename__ = "person" + __mapper_args__ = { + "polymorphic_on": type, + "with_polymorphic": "*", # needed for eager loading in async session + } + +Alternatively, the specific polymorphic fields can be loaded explicitly in resolvers: + +.. code:: python + + class Query(graphene.ObjectType): + people = graphene.Field(graphene.List(PersonType)) + + async def resolve_people(self, _info): + return (await session.scalars(with_polymorphic(Person, [Engineer, Customer]))).all() + +Dynamic batching of the types based on the query to avoid eager is currently not supported, but could be implemented in a future PR. + +For more information on loading techniques for polymorphic models, please check out the `SQLAlchemy docs `_. diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index 0800d0e..23b6712 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session, strategies from sqlalchemy.orm.query import QueryContext -from .utils import is_graphene_version_less_than, is_sqlalchemy_version_less_than +from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, is_graphene_version_less_than def get_data_loader_impl() -> Any: # pragma: no cover @@ -71,19 +71,19 @@ async def batch_load_fn(self, parents): # For our purposes, the query_context will only used to get the session query_context = None - if is_sqlalchemy_version_less_than("1.4"): - query_context = QueryContext(session.query(parent_mapper.entity)) - else: + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: parent_mapper_query = session.query(parent_mapper.entity) query_context = parent_mapper_query._compile_context() - - if is_sqlalchemy_version_less_than("1.4"): + else: + query_context = QueryContext(session.query(parent_mapper.entity)) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: self.selectin_loader._load_for_path( query_context, parent_mapper._path_registry, states, None, child_mapper, + None, ) else: self.selectin_loader._load_for_path( @@ -92,7 +92,6 @@ async def batch_load_fn(self, parents): states, None, child_mapper, - None, ) return [getattr(parent, self.relationship_prop.key) for parent in parents] diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 2cb53c5..6dbc134 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -11,7 +11,10 @@ from graphql_relay import connection_from_array_slice from .batching import get_batch_resolver -from .utils import EnumValue, get_query +from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, EnumValue, get_query, get_session + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession class SQLAlchemyConnectionField(ConnectionField): @@ -81,8 +84,49 @@ def get_query(cls, model, info, sort=None, **args): @classmethod def resolve_connection(cls, connection_type, model, info, args, resolved): + session = get_session(info.context) + if resolved is None: + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + + async def get_result(): + return await cls.resolve_connection_async( + connection_type, model, info, args, resolved + ) + + return get_result() + + else: + resolved = cls.get_query(model, info, **args) + if isinstance(resolved, Query): + _len = resolved.count() + else: + _len = len(resolved) + + def adjusted_connection_adapter(edges, pageInfo): + return connection_adapter(connection_type, edges, pageInfo) + + connection = connection_from_array_slice( + array_slice=resolved, + args=args, + slice_start=0, + array_length=_len, + array_slice_length=_len, + connection_type=adjusted_connection_adapter, + edge_type=connection_type.Edge, + page_info_type=page_info_adapter, + ) + connection.iterable = resolved + connection.length = _len + return connection + + @classmethod + async def resolve_connection_async( + cls, connection_type, model, info, args, resolved + ): + session = get_session(info.context) if resolved is None: - resolved = cls.get_query(model, info, **args) + query = cls.get_query(model, info, **args) + resolved = (await session.scalars(query)).all() if isinstance(resolved, Query): _len = resolved.count() else: @@ -179,7 +223,7 @@ def from_relationship(cls, relationship, registry, **field_kwargs): return cls( model_type.connection, resolver=get_batch_resolver(relationship), - **field_kwargs + **field_kwargs, ) diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 357ad96..89b357a 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -1,14 +1,17 @@ import pytest +import pytest_asyncio from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker import graphene +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 from ..converter import convert_sqlalchemy_composite from ..registry import reset_global_registry from .models import Base, CompositeFullName -test_db_url = "sqlite://" # use in-memory database for tests +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine @pytest.fixture(autouse=True) @@ -22,18 +25,49 @@ def convert_composite_class(composite, registry): return graphene.Field(graphene.Int) -@pytest.fixture(scope="function") -def session_factory(): - engine = create_engine(test_db_url) - Base.metadata.create_all(engine) +@pytest.fixture(params=[False, True]) +def async_session(request): + return request.param + + +@pytest.fixture +def test_db_url(async_session: bool): + if async_session: + return "sqlite+aiosqlite://" + else: + return "sqlite://" - yield sessionmaker(bind=engine) +@pytest.mark.asyncio +@pytest_asyncio.fixture(scope="function") +async def session_factory(async_session: bool, test_db_url: str): + if async_session: + if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + pytest.skip("Async Sessions only work in sql alchemy 1.4 and above") + engine = create_async_engine(test_db_url) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False) + await engine.dispose() + else: + engine = create_engine(test_db_url) + Base.metadata.create_all(engine) + yield sessionmaker(bind=engine, expire_on_commit=False) + # SQLite in-memory db is deleted when its connection is closed. + # https://www.sqlite.org/inmemorydb.html + engine.dispose() + + +@pytest_asyncio.fixture(scope="function") +async def sync_session_factory(): + engine = create_engine("sqlite://") + Base.metadata.create_all(engine) + yield sessionmaker(bind=engine, expire_on_commit=False) # SQLite in-memory db is deleted when its connection is closed. # https://www.sqlite.org/inmemorydb.html engine.dispose() -@pytest.fixture(scope="function") +@pytest_asyncio.fixture(scope="function") def session(session_factory): return session_factory() diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 4fe9146..ee28658 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -20,7 +20,7 @@ ) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import column_property, composite, mapper, relationship +from sqlalchemy.orm import backref, column_property, composite, mapper, relationship PetKind = Enum("cat", "dog", name="pet_kind") @@ -76,10 +76,16 @@ class Reporter(Base): email = Column(String(), doc="Email") favorite_pet_kind = Column(PetKind) pets = relationship( - "Pet", secondary=association_table, backref="reporters", order_by="Pet.id" + "Pet", + secondary=association_table, + backref="reporters", + order_by="Pet.id", + lazy="selectin", ) - articles = relationship("Article", backref="reporter") - favorite_article = relationship("Article", uselist=False) + articles = relationship( + "Article", backref=backref("reporter", lazy="selectin"), lazy="selectin" + ) + favorite_article = relationship("Article", uselist=False, lazy="selectin") @hybrid_property def hybrid_prop_with_doc(self): @@ -304,8 +310,10 @@ class Person(Base): __tablename__ = "person" __mapper_args__ = { "polymorphic_on": type, + "with_polymorphic": "*", # needed for eager loading in async session } + class NonAbstractPerson(Base): id = Column(Integer(), primary_key=True) type = Column(String()) @@ -318,6 +326,7 @@ class NonAbstractPerson(Base): "polymorphic_identity": "person", } + class Employee(Person): hire_date = Column(Date()) diff --git a/graphene_sqlalchemy/tests/models_batching.py b/graphene_sqlalchemy/tests/models_batching.py new file mode 100644 index 0000000..6f1c42f --- /dev/null +++ b/graphene_sqlalchemy/tests/models_batching.py @@ -0,0 +1,91 @@ +from __future__ import absolute_import + +import enum + +from sqlalchemy import ( + Column, + Date, + Enum, + ForeignKey, + Integer, + String, + Table, + func, + select, +) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import column_property, relationship + +PetKind = Enum("cat", "dog", name="pet_kind") + + +class HairKind(enum.Enum): + LONG = "long" + SHORT = "short" + + +Base = declarative_base() + +association_table = Table( + "association", + Base.metadata, + Column("pet_id", Integer, ForeignKey("pets.id")), + Column("reporter_id", Integer, ForeignKey("reporters.id")), +) + + +class Pet(Base): + __tablename__ = "pets" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + pet_kind = Column(PetKind, nullable=False) + hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False) + reporter_id = Column(Integer(), ForeignKey("reporters.id")) + + +class Reporter(Base): + __tablename__ = "reporters" + + id = Column(Integer(), primary_key=True) + first_name = Column(String(30), doc="First name") + last_name = Column(String(30), doc="Last name") + email = Column(String(), doc="Email") + favorite_pet_kind = Column(PetKind) + pets = relationship( + "Pet", + secondary=association_table, + backref="reporters", + order_by="Pet.id", + ) + articles = relationship("Article", backref="reporter") + favorite_article = relationship("Article", uselist=False) + + column_prop = column_property( + select([func.cast(func.count(id), Integer)]), doc="Column property" + ) + + +class Article(Base): + __tablename__ = "articles" + id = Column(Integer(), primary_key=True) + headline = Column(String(100)) + pub_date = Column(Date()) + reporter_id = Column(Integer(), ForeignKey("reporters.id")) + readers = relationship( + "Reader", secondary="articles_readers", back_populates="articles" + ) + + +class Reader(Base): + __tablename__ = "readers" + id = Column(Integer(), primary_key=True) + name = Column(String(100)) + articles = relationship( + "Article", secondary="articles_readers", back_populates="readers" + ) + + +class ArticleReader(Base): + __tablename__ = "articles_readers" + article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True) + reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True) diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py index 90df027..5eccd5f 100644 --- a/graphene_sqlalchemy/tests/test_batching.py +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -3,15 +3,23 @@ import logging import pytest +from sqlalchemy import select import graphene from graphene import Connection, relay from ..fields import BatchSQLAlchemyConnectionField, default_connection_field_factory from ..types import ORMField, SQLAlchemyObjectType -from ..utils import is_sqlalchemy_version_less_than -from .models import Article, HairKind, Pet, Reader, Reporter -from .utils import remove_cache_miss_stat, to_std_dicts +from ..utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + get_session, + is_sqlalchemy_version_less_than, +) +from .models_batching import Article, HairKind, Pet, Reader, Reporter +from .utils import eventually_await_session, remove_cache_miss_stat, to_std_dicts + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession class MockLoggingHandler(logging.Handler): @@ -41,6 +49,44 @@ def mock_sqlalchemy_logging_handler(): sql_logger.setLevel(previous_level) +def get_async_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = True + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + batching = True + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (relay.Node,) + batching = True + + class Query(graphene.ObjectType): + articles = graphene.Field(graphene.List(ArticleType)) + reporters = graphene.Field(graphene.List(ReporterType)) + + async def resolve_articles(self, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Article))).all() + return session.query(Article).all() + + async def resolve_reporters(self, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).all() + return session.query(Reporter).all() + + return graphene.Schema(query=Query) + + def get_schema(): class ReporterType(SQLAlchemyObjectType): class Meta: @@ -65,14 +111,20 @@ class Query(graphene.ObjectType): reporters = graphene.Field(graphene.List(ReporterType)) def resolve_articles(self, info): - return info.context.get("session").query(Article).all() + session = get_session(info.context) + return session.query(Article).all() def resolve_reporters(self, info): - return info.context.get("session").query(Reporter).all() + session = get_session(info.context) + return session.query(Reporter).all() return graphene.Schema(query=Query) +if is_sqlalchemy_version_less_than("1.2"): + pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) + + def get_full_relay_schema(): class ReporterType(SQLAlchemyObjectType): class Meta: @@ -107,14 +159,11 @@ class Query(graphene.ObjectType): return graphene.Schema(query=Query) -if is_sqlalchemy_version_less_than("1.2"): - pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) - - @pytest.mark.asyncio -async def test_many_to_one(session_factory): - session = session_factory() - +@pytest.mark.parametrize("schema_provider", [get_schema, get_async_schema]) +async def test_many_to_one(sync_session_factory, schema_provider): + session = sync_session_factory() + schema = schema_provider() reporter_1 = Reporter( first_name="Reporter_1", ) @@ -135,26 +184,43 @@ async def test_many_to_one(session_factory): session.commit() session.close() - schema = get_schema() - with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() result = await schema.execute_async( """ - query { - articles { - headline - reporter { - firstName + query { + articles { + headline + reporter { + firstName + } + } } - } - } - """, + """, context_value={"session": session}, ) messages = sqlalchemy_logging_handler.messages + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "articles": [ + { + "headline": "Article_1", + "reporter": { + "firstName": "Reporter_1", + }, + }, + { + "headline": "Article_2", + "reporter": { + "firstName": "Reporter_2", + }, + }, + ], + } + assert len(messages) == 5 if is_sqlalchemy_version_less_than("1.3"): @@ -169,37 +235,19 @@ async def test_many_to_one(session_factory): assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than("1.4"): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) assert ast.literal_eval(messages[2]) == () assert sorted(ast.literal_eval(messages[4])) == [1, 2] - assert not result.errors - result = to_std_dicts(result.data) - assert result == { - "articles": [ - { - "headline": "Article_1", - "reporter": { - "firstName": "Reporter_1", - }, - }, - { - "headline": "Article_2", - "reporter": { - "firstName": "Reporter_2", - }, - }, - ], - } - @pytest.mark.asyncio -async def test_one_to_one(session_factory): - session = session_factory() - +@pytest.mark.parametrize("schema_provider", [get_schema, get_async_schema]) +async def test_one_to_one(sync_session_factory, schema_provider): + session = sync_session_factory() + schema = schema_provider() reporter_1 = Reporter( first_name="Reporter_1", ) @@ -220,26 +268,43 @@ async def test_one_to_one(session_factory): session.commit() session.close() - schema = get_schema() - with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + + session = sync_session_factory() result = await schema.execute_async( """ - query { - reporters { - firstName - favoriteArticle { - headline - } - } + query { + reporters { + firstName + favoriteArticle { + headline + } } + } """, context_value={"session": session}, ) messages = sqlalchemy_logging_handler.messages + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": [ + { + "firstName": "Reporter_1", + "favoriteArticle": { + "headline": "Article_1", + }, + }, + { + "firstName": "Reporter_2", + "favoriteArticle": { + "headline": "Article_2", + }, + }, + ], + } assert len(messages) == 5 if is_sqlalchemy_version_less_than("1.3"): @@ -254,36 +319,17 @@ async def test_one_to_one(session_factory): assert len(sql_statements) == 1 return - if not is_sqlalchemy_version_less_than("1.4"): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: messages[2] = remove_cache_miss_stat(messages[2]) messages[4] = remove_cache_miss_stat(messages[4]) assert ast.literal_eval(messages[2]) == () assert sorted(ast.literal_eval(messages[4])) == [1, 2] - assert not result.errors - result = to_std_dicts(result.data) - assert result == { - "reporters": [ - { - "firstName": "Reporter_1", - "favoriteArticle": { - "headline": "Article_1", - }, - }, - { - "firstName": "Reporter_2", - "favoriteArticle": { - "headline": "Article_2", - }, - }, - ], - } - @pytest.mark.asyncio -async def test_one_to_many(session_factory): - session = session_factory() +async def test_one_to_many(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter( first_name="Reporter_1", @@ -309,7 +355,6 @@ async def test_one_to_many(session_factory): article_4 = Article(headline="Article_4") article_4.reporter = reporter_2 session.add(article_4) - session.commit() session.close() @@ -317,7 +362,8 @@ async def test_one_to_many(session_factory): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + + session = sync_session_factory() result = await schema.execute_async( """ query { @@ -337,27 +383,6 @@ async def test_one_to_many(session_factory): ) messages = sqlalchemy_logging_handler.messages - assert len(messages) == 5 - - if is_sqlalchemy_version_less_than("1.3"): - # The batched SQL statement generated is different in 1.2.x - # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` - # See https://git.io/JewQu - sql_statements = [ - message - for message in messages - if "SELECT" in message and "JOIN articles" in message - ] - assert len(sql_statements) == 1 - return - - if not is_sqlalchemy_version_less_than("1.4"): - messages[2] = remove_cache_miss_stat(messages[2]) - messages[4] = remove_cache_miss_stat(messages[4]) - - assert ast.literal_eval(messages[2]) == () - assert sorted(ast.literal_eval(messages[4])) == [1, 2] - assert not result.errors result = to_std_dicts(result.data) assert result == { @@ -398,11 +423,31 @@ async def test_one_to_many(session_factory): }, ], } + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than("1.3"): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN articles" in message + ] + assert len(sql_statements) == 1 + return + + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] @pytest.mark.asyncio -async def test_many_to_many(session_factory): - session = session_factory() +async def test_many_to_many(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter( first_name="Reporter_1", @@ -430,15 +475,14 @@ async def test_many_to_many(session_factory): reporter_2.pets.append(pet_3) reporter_2.pets.append(pet_4) - - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") schema = get_schema() with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() result = await schema.execute_async( """ query { @@ -458,27 +502,6 @@ async def test_many_to_many(session_factory): ) messages = sqlalchemy_logging_handler.messages - assert len(messages) == 5 - - if is_sqlalchemy_version_less_than("1.3"): - # The batched SQL statement generated is different in 1.2.x - # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` - # See https://git.io/JewQu - sql_statements = [ - message - for message in messages - if "SELECT" in message and "JOIN pets" in message - ] - assert len(sql_statements) == 1 - return - - if not is_sqlalchemy_version_less_than("1.4"): - messages[2] = remove_cache_miss_stat(messages[2]) - messages[4] = remove_cache_miss_stat(messages[4]) - - assert ast.literal_eval(messages[2]) == () - assert sorted(ast.literal_eval(messages[4])) == [1, 2] - assert not result.errors result = to_std_dicts(result.data) assert result == { @@ -520,9 +543,30 @@ async def test_many_to_many(session_factory): ], } + assert len(messages) == 5 -def test_disable_batching_via_ormfield(session_factory): - session = session_factory() + if is_sqlalchemy_version_less_than("1.3"): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [ + message + for message in messages + if "SELECT" in message and "JOIN pets" in message + ] + assert len(sql_statements) == 1 + return + + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + messages[2] = remove_cache_miss_stat(messages[2]) + messages[4] = remove_cache_miss_stat(messages[4]) + + assert ast.literal_eval(messages[2]) == () + assert sorted(ast.literal_eval(messages[4])) == [1, 2] + + +def test_disable_batching_via_ormfield(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") @@ -555,7 +599,7 @@ def resolve_reporters(self, info): # Test one-to-one and many-to-one relationships with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() schema.execute( """ query { @@ -580,7 +624,7 @@ def resolve_reporters(self, info): # Test one-to-many and many-to-many relationships with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() schema.execute( """ query { @@ -607,9 +651,8 @@ def resolve_reporters(self, info): assert len(select_statements) == 2 -@pytest.mark.asyncio -def test_batch_sorting_with_custom_ormfield(session_factory): - session = session_factory() +def test_batch_sorting_with_custom_ormfield(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") @@ -642,7 +685,7 @@ class Meta: # Test one-to-one and many-to-one relationships with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() result = schema.execute( """ query { @@ -658,7 +701,7 @@ class Meta: context_value={"session": session}, ) messages = sqlalchemy_logging_handler.messages - + assert not result.errors result = to_std_dicts(result.data) assert result == { "reporters": { @@ -685,8 +728,10 @@ class Meta: @pytest.mark.asyncio -async def test_connection_factory_field_overrides_batching_is_false(session_factory): - session = session_factory() +async def test_connection_factory_field_overrides_batching_is_false( + sync_session_factory, +): + session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") @@ -718,7 +763,7 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() await schema.execute_async( """ query { @@ -755,8 +800,8 @@ def resolve_reporters(self, info): assert len(select_statements) == 1 -def test_connection_factory_field_overrides_batching_is_true(session_factory): - session = session_factory() +def test_connection_factory_field_overrides_batching_is_true(sync_session_factory): + session = sync_session_factory() reporter_1 = Reporter(first_name="Reporter_1") session.add(reporter_1) reporter_2 = Reporter(first_name="Reporter_2") @@ -788,7 +833,7 @@ def resolve_reporters(self, info): with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: # Starts new session to fully reset the engine / connection logging level - session = session_factory() + session = sync_session_factory() schema.execute( """ query { @@ -816,7 +861,9 @@ def resolve_reporters(self, info): @pytest.mark.asyncio -async def test_batching_across_nested_relay_schema(session_factory): +async def test_batching_across_nested_relay_schema( + session_factory, async_session: bool +): session = session_factory() for first_name in "fgerbhjikzutzxsdfdqqa": @@ -831,8 +878,8 @@ async def test_batching_across_nested_relay_schema(session_factory): reader.articles = [article] session.add(reader) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") schema = get_full_relay_schema() @@ -871,14 +918,17 @@ async def test_batching_across_nested_relay_schema(session_factory): result = to_std_dicts(result.data) select_statements = [message for message in messages if "SELECT" in message] - assert len(select_statements) == 4 - assert select_statements[-1].startswith("SELECT articles_1.id") - if is_sqlalchemy_version_less_than("1.3"): - assert select_statements[-2].startswith("SELECT reporters_1.id") - assert "WHERE reporters_1.id IN" in select_statements[-2] + if async_session: + assert len(select_statements) == 2 # TODO: Figure out why async has less calls else: - assert select_statements[-2].startswith("SELECT articles.reporter_id") - assert "WHERE articles.reporter_id IN" in select_statements[-2] + assert len(select_statements) == 4 + assert select_statements[-1].startswith("SELECT articles_1.id") + if is_sqlalchemy_version_less_than("1.3"): + assert select_statements[-2].startswith("SELECT reporters_1.id") + assert "WHERE reporters_1.id IN" in select_statements[-2] + else: + assert select_statements[-2].startswith("SELECT articles.reporter_id") + assert "WHERE articles.reporter_id IN" in select_statements[-2] @pytest.mark.asyncio @@ -892,8 +942,8 @@ async def test_sorting_can_be_used_with_batching_when_using_full_relay(session_f article_1.reporter = reporter_1 session.add(article_1) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") schema = get_full_relay_schema() diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py index bb105ed..dc656f4 100644 --- a/graphene_sqlalchemy/tests/test_benchmark.py +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -1,16 +1,61 @@ +import asyncio + import pytest +from sqlalchemy import select import graphene from graphene import relay from ..types import SQLAlchemyObjectType -from ..utils import is_sqlalchemy_version_less_than +from ..utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + get_session, + is_sqlalchemy_version_less_than, +) from .models import Article, HairKind, Pet, Reporter +from .utils import eventually_await_session +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession if is_sqlalchemy_version_less_than("1.2"): pytest.skip("SQL batching only works for SQLAlchemy 1.2+", allow_module_level=True) +def get_async_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + articles = graphene.Field(graphene.List(ArticleType)) + reporters = graphene.Field(graphene.List(ReporterType)) + + async def resolve_articles(self, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Article))).all() + return session.query(Article).all() + + async def resolve_reporters(self, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).all() + return session.query(Reporter).all() + + return graphene.Schema(query=Query) + + def get_schema(): class ReporterType(SQLAlchemyObjectType): class Meta: @@ -40,20 +85,30 @@ def resolve_reporters(self, info): return graphene.Schema(query=Query) -def benchmark_query(session_factory, benchmark, query): - schema = get_schema() +async def benchmark_query(session, benchmark, schema, query): + import nest_asyncio - @benchmark - def execute_query(): - result = schema.execute( - query, - context_value={"session": session_factory()}, + nest_asyncio.apply() + loop = asyncio.get_event_loop() + result = benchmark( + lambda: loop.run_until_complete( + schema.execute_async(query, context_value={"session": session}) ) - assert not result.errors + ) + assert not result.errors + + +@pytest.fixture(params=[get_schema, get_async_schema]) +def schema_provider(request, async_session): + if async_session and request.param == get_schema: + pytest.skip("Cannot test sync schema with async sessions") + return request.param -def test_one_to_one(session_factory, benchmark): +@pytest.mark.asyncio +async def test_one_to_one(session_factory, benchmark, schema_provider): session = session_factory() + schema = schema_provider() reporter_1 = Reporter( first_name="Reporter_1", @@ -72,12 +127,13 @@ def test_one_to_one(session_factory, benchmark): article_2.reporter = reporter_2 session.add(article_2) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") - benchmark_query( - session_factory, + await benchmark_query( + session, benchmark, + schema, """ query { reporters { @@ -91,9 +147,10 @@ def test_one_to_one(session_factory, benchmark): ) -def test_many_to_one(session_factory, benchmark): +@pytest.mark.asyncio +async def test_many_to_one(session_factory, benchmark, schema_provider): session = session_factory() - + schema = schema_provider() reporter_1 = Reporter( first_name="Reporter_1", ) @@ -110,13 +167,14 @@ def test_many_to_one(session_factory, benchmark): article_2 = Article(headline="Article_2") article_2.reporter = reporter_2 session.add(article_2) + await eventually_await_session(session, "flush") + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") - session.commit() - session.close() - - benchmark_query( - session_factory, + await benchmark_query( + session, benchmark, + schema, """ query { articles { @@ -130,8 +188,10 @@ def test_many_to_one(session_factory, benchmark): ) -def test_one_to_many(session_factory, benchmark): +@pytest.mark.asyncio +async def test_one_to_many(session_factory, benchmark, schema_provider): session = session_factory() + schema = schema_provider() reporter_1 = Reporter( first_name="Reporter_1", @@ -158,12 +218,13 @@ def test_one_to_many(session_factory, benchmark): article_4.reporter = reporter_2 session.add(article_4) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") - benchmark_query( - session_factory, + await benchmark_query( + session, benchmark, + schema, """ query { reporters { @@ -181,9 +242,10 @@ def test_one_to_many(session_factory, benchmark): ) -def test_many_to_many(session_factory, benchmark): +@pytest.mark.asyncio +async def test_many_to_many(session_factory, benchmark, schema_provider): session = session_factory() - + schema = schema_provider() reporter_1 = Reporter( first_name="Reporter_1", ) @@ -211,12 +273,13 @@ def test_many_to_many(session_factory, benchmark): reporter_2.pets.append(pet_3) reporter_2.pets.append(pet_4) - session.commit() - session.close() + await eventually_await_session(session, "commit") + await eventually_await_session(session, "close") - benchmark_query( - session_factory, + await benchmark_query( + session, benchmark, + schema, """ query { reporters { diff --git a/graphene_sqlalchemy/tests/test_enums.py b/graphene_sqlalchemy/tests/test_enums.py index cd97a00..3de6904 100644 --- a/graphene_sqlalchemy/tests/test_enums.py +++ b/graphene_sqlalchemy/tests/test_enums.py @@ -85,7 +85,10 @@ class Meta: assert enum._meta.name == "PetKind" assert [ (key, value.value) for key, value in enum._meta.enum.__members__.items() - ] == [("CAT", "cat"), ("DOG", "dog")] + ] == [ + ("CAT", "cat"), + ("DOG", "dog"), + ] enum2 = enum_for_field(PetType, "pet_kind") assert enum2 is enum enum2 = PetType.enum_for_field("pet_kind") diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 456254f..055a87f 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -1,11 +1,15 @@ from datetime import date +import pytest +from sqlalchemy import select + import graphene from graphene.relay import Node from ..converter import convert_sqlalchemy_composite from ..fields import SQLAlchemyConnectionField from ..types import ORMField, SQLAlchemyInterface, SQLAlchemyObjectType +from ..utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, get_session from .models import ( Article, CompositeFullName, @@ -16,10 +20,13 @@ Pet, Reporter, ) -from .utils import to_std_dicts +from .utils import eventually_await_session, to_std_dicts + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession -def add_test_data(session): +async def add_test_data(session): reporter = Reporter(first_name="John", last_name="Doe", favorite_pet_kind="cat") session.add(reporter) pet = Pet(name="Garfield", pet_kind="cat", hair_kind=HairKind.SHORT) @@ -35,11 +42,12 @@ def add_test_data(session): session.add(pet) editor = Editor(name="Jack") session.add(editor) - session.commit() + await eventually_await_session(session, "commit") -def test_query_fields(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_fields(session): + await add_test_data(session) @convert_sqlalchemy_composite.register(CompositeFullName) def convert_composite_class(composite, registry): @@ -53,10 +61,16 @@ class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) reporters = graphene.List(ReporterType) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() - def resolve_reporters(self, _info): + async def resolve_reporters(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().all() return session.query(Reporter) query = """ @@ -82,14 +96,15 @@ def resolve_reporters(self, _info): "reporters": [{"firstName": "John"}, {"firstName": "Jane"}], } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_query_node(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_node_sync(session): + await add_test_data(session) class ReporterNode(SQLAlchemyObjectType): class Meta: @@ -111,6 +126,14 @@ class Query(graphene.ObjectType): all_articles = SQLAlchemyConnectionField(ArticleNode.connection) def resolve_reporter(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + + async def get_result(): + return (await session.scalars(select(Reporter))).first() + + return get_result() + return session.query(Reporter).first() query = """ @@ -154,14 +177,100 @@ def resolve_reporter(self, _info): "myArticle": {"id": "QXJ0aWNsZU5vZGU6MQ==", "headline": "Hi!"}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + result = schema.execute(query, context_value={"session": session}) + assert result.errors + else: + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +@pytest.mark.asyncio +async def test_query_node_async(session): + await add_test_data(session) + + class ReporterNode(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + @classmethod + def get_node(cls, info, id): + return Reporter(id=2, first_name="Cookie Monster") + + class ArticleNode(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + + class Query(graphene.ObjectType): + node = Node.Field() + reporter = graphene.Field(ReporterNode) + all_articles = SQLAlchemyConnectionField(ArticleNode.connection) + + def resolve_reporter(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + + async def get_result(): + return (await session.scalars(select(Reporter))).first() + + return get_result() + + return session.query(Reporter).first() + + query = """ + query { + reporter { + id + firstName + articles { + edges { + node { + headline + } + } + } + } + allArticles { + edges { + node { + headline + } + } + } + myArticle: node(id:"QXJ0aWNsZU5vZGU6MQ==") { + id + ... on ReporterNode { + firstName + } + ... on ArticleNode { + headline + } + } + } + """ + expected = { + "reporter": { + "id": "UmVwb3J0ZXJOb2RlOjE=", + "firstName": "John", + "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, + }, + "allArticles": {"edges": [{"node": {"headline": "Hi!"}}]}, + "myArticle": {"id": "QXJ0aWNsZU5vZGU6MQ==", "headline": "Hi!"}, + } + schema = graphene.Schema(query=Query) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_orm_field(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_orm_field(session): + await add_test_data(session) @convert_sqlalchemy_composite.register(CompositeFullName) def convert_composite_class(composite, registry): @@ -187,7 +296,10 @@ class Meta: class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).first() return session.query(Reporter).first() query = """ @@ -221,14 +333,15 @@ def resolve_reporter(self, _info): }, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_custom_identifier(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_custom_identifier(session): + await add_test_data(session) class EditorNode(SQLAlchemyObjectType): class Meta: @@ -262,14 +375,15 @@ class Query(graphene.ObjectType): } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_mutation(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_mutation(session, session_factory): + await add_test_data(session) class EditorNode(SQLAlchemyObjectType): class Meta: @@ -282,8 +396,11 @@ class Meta: interfaces = (Node,) @classmethod - def get_node(cls, id, info): - return Reporter(id=2, first_name="Cookie Monster") + async def get_node(cls, id, info): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() + return session.query(Reporter).first() class ArticleNode(SQLAlchemyObjectType): class Meta: @@ -298,11 +415,14 @@ class Arguments: ok = graphene.Boolean() article = graphene.Field(ArticleNode) - def mutate(self, info, headline, reporter_id): + async def mutate(self, info, headline, reporter_id): + reporter = await ReporterNode.get_node(reporter_id, info) new_article = Article(headline=headline, reporter_id=reporter_id) + reporter.articles = [*reporter.articles, new_article] + session = get_session(info.context) + session.add(reporter) - session.add(new_article) - session.commit() + await eventually_await_session(session, "commit") ok = True return CreateArticle(article=new_article, ok=ok) @@ -341,24 +461,28 @@ class Mutation(graphene.ObjectType): } schema = graphene.Schema(query=Query, mutation=Mutation) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async( + query, context_value={"session": session_factory()} + ) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def add_person_data(session): +async def add_person_data(session): bob = Employee(name="Bob", birth_date=date(1990, 1, 1), hire_date=date(2015, 1, 1)) session.add(bob) joe = Employee(name="Joe", birth_date=date(1980, 1, 1), hire_date=date(2010, 1, 1)) session.add(joe) jen = Employee(name="Jen", birth_date=date(1995, 1, 1), hire_date=date(2020, 1, 1)) session.add(jen) - session.commit() + await eventually_await_session(session, "commit") -def test_interface_query_on_base_type(session): - add_person_data(session) +@pytest.mark.asyncio +async def test_interface_query_on_base_type(session_factory): + session = session_factory() + await add_person_data(session) class PersonType(SQLAlchemyInterface): class Meta: @@ -372,11 +496,13 @@ class Meta: class Query(graphene.ObjectType): people = graphene.Field(graphene.List(PersonType)) - def resolve_people(self, _info): + async def resolve_people(self, _info): + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Person))).all() return session.query(Person).all() schema = graphene.Schema(query=Query, types=[PersonType, EmployeeType]) - result = schema.execute( + result = await schema.execute_async( """ query { people { diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py index 923bbed..14c87f7 100644 --- a/graphene_sqlalchemy/tests/test_query_enums.py +++ b/graphene_sqlalchemy/tests/test_query_enums.py @@ -1,12 +1,22 @@ +import pytest +from sqlalchemy import select + import graphene +from graphene_sqlalchemy.tests.utils import eventually_await_session +from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, get_session from ..types import SQLAlchemyObjectType from .models import HairKind, Pet, Reporter from .test_query import add_test_data, to_std_dicts +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession + -def test_query_pet_kinds(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_pet_kinds(session, session_factory): + await add_test_data(session) + await eventually_await_session(session, "close") class PetType(SQLAlchemyObjectType): class Meta: @@ -23,13 +33,25 @@ class Query(graphene.ObjectType): PetType, kind=graphene.Argument(PetType.enum_for_field("pet_kind")) ) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() - def resolve_reporters(self, _info): + async def resolve_reporters(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().all() return session.query(Reporter) - def resolve_pets(self, _info, kind): + async def resolve_pets(self, _info, kind): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + query = select(Pet) + if kind: + query = query.filter(Pet.pet_kind == kind.value) + return (await session.scalars(query)).unique().all() query = session.query(Pet) if kind: query = query.filter_by(pet_kind=kind.value) @@ -78,13 +100,16 @@ def resolve_pets(self, _info, kind): "pets": [{"name": "Lassie", "petKind": "DOG"}], } schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = await schema.execute_async( + query, context_value={"session": session_factory()} + ) assert not result.errors assert result.data == expected -def test_query_more_enums(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_query_more_enums(session): + await add_test_data(session) class PetType(SQLAlchemyObjectType): class Meta: @@ -93,7 +118,10 @@ class Meta: class Query(graphene.ObjectType): pet = graphene.Field(PetType) - def resolve_pet(self, _info): + async def resolve_pet(self, _info): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Pet))).first() return session.query(Pet).first() query = """ @@ -107,14 +135,15 @@ def resolve_pet(self, _info): """ expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} schema = graphene.Schema(query=Query) - result = schema.execute(query) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected -def test_enum_as_argument(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_enum_as_argument(session): + await add_test_data(session) class PetType(SQLAlchemyObjectType): class Meta: @@ -125,7 +154,13 @@ class Query(graphene.ObjectType): PetType, kind=graphene.Argument(PetType.enum_for_field("pet_kind")) ) - def resolve_pet(self, info, kind=None): + async def resolve_pet(self, info, kind=None): + session = get_session(info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + query = select(Pet) + if kind: + query = query.filter(Pet.pet_kind == kind.value) + return (await session.scalars(query)).first() query = session.query(Pet) if kind: query = query.filter(Pet.pet_kind == kind.value) @@ -142,19 +177,24 @@ def resolve_pet(self, info, kind=None): """ schema = graphene.Schema(query=Query) - result = schema.execute(query, variables={"kind": "CAT"}) + result = await schema.execute_async( + query, variables={"kind": "CAT"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} assert result.data == expected - result = schema.execute(query, variables={"kind": "DOG"}) + result = await schema.execute_async( + query, variables={"kind": "DOG"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} result = to_std_dicts(result.data) assert result == expected -def test_py_enum_as_argument(session): - add_test_data(session) +@pytest.mark.asyncio +async def test_py_enum_as_argument(session): + await add_test_data(session) class PetType(SQLAlchemyObjectType): class Meta: @@ -166,7 +206,14 @@ class Query(graphene.ObjectType): kind=graphene.Argument(PetType._meta.fields["hair_kind"].type.of_type), ) - def resolve_pet(self, _info, kind=None): + async def resolve_pet(self, _info, kind=None): + session = get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return ( + await session.scalars( + select(Pet).filter(Pet.hair_kind == HairKind(kind)) + ) + ).first() query = session.query(Pet) if kind: # enum arguments are expected to be strings, not PyEnums @@ -184,11 +231,15 @@ def resolve_pet(self, _info, kind=None): """ schema = graphene.Schema(query=Query) - result = schema.execute(query, variables={"kind": "SHORT"}) + result = await schema.execute_async( + query, variables={"kind": "SHORT"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} assert result.data == expected - result = schema.execute(query, variables={"kind": "LONG"}) + result = await schema.execute_async( + query, variables={"kind": "LONG"}, context_value={"session": session} + ) assert not result.errors expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} result = to_std_dicts(result.data) diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py index 11c7c9a..f8f1ff8 100644 --- a/graphene_sqlalchemy/tests/test_sort_enums.py +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -9,16 +9,17 @@ from ..utils import to_type_name from .models import Base, HairKind, KeyedModel, Pet from .test_query import to_std_dicts +from .utils import eventually_await_session -def add_pets(session): +async def add_pets(session): pets = [ Pet(id=1, name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG), Pet(id=2, name="Barf", pet_kind="dog", hair_kind=HairKind.LONG), Pet(id=3, name="Alf", pet_kind="cat", hair_kind=HairKind.LONG), ] session.add_all(pets) - session.commit() + await eventually_await_session(session, "commit") def test_sort_enum(): @@ -241,8 +242,9 @@ def get_symbol_name(column_name, sort_asc=True): assert sort_arg.default_value == ["IdUp"] -def test_sort_query(session): - add_pets(session) +@pytest.mark.asyncio +async def test_sort_query(session): + await add_pets(session) class PetNode(SQLAlchemyObjectType): class Meta: @@ -336,7 +338,7 @@ def makeNodes(nodeList): } # yapf: disable schema = Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) + result = await schema.execute_async(query, context_value={"session": session}) assert not result.errors result = to_std_dicts(result.data) assert result == expected @@ -352,7 +354,7 @@ def makeNodes(nodeList): } } """ - result = schema.execute(queryError, context_value={"session": session}) + result = await schema.execute_async(queryError, context_value={"session": session}) assert result.errors is not None assert "cannot represent non-enum value" in result.errors[0].message @@ -375,7 +377,7 @@ def makeNodes(nodeList): } """ - result = schema.execute(queryNoSort, context_value={"session": session}) + result = await schema.execute_async(queryNoSort, context_value={"session": session}) assert not result.errors # TODO: SQLite usually returns the results ordered by primary key, # so we cannot test this way whether sorting actually happens or not. diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 813fb13..6632842 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -4,6 +4,9 @@ import pytest import sqlalchemy.exc import sqlalchemy.orm.exc +from graphql.pyutils import is_awaitable +from sqlalchemy import select + from graphene import ( Boolean, Dynamic, @@ -20,7 +23,6 @@ ) from graphene.relay import Connection -from .models import Article, CompositeFullName, Employee, Person, Pet, Reporter, NonAbstractPerson from .. import utils from ..converter import convert_sqlalchemy_composite from ..fields import ( @@ -36,11 +38,26 @@ SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions, ) +from ..utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 +from .models import ( + Article, + CompositeFullName, + Employee, + NonAbstractPerson, + Person, + Pet, + Reporter, +) +from .utils import eventually_await_session + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession def test_should_raise_if_no_model(): re_err = r"valid SQLAlchemy Model" with pytest.raises(Exception, match=re_err): + class Character1(SQLAlchemyObjectType): pass @@ -48,12 +65,14 @@ class Character1(SQLAlchemyObjectType): def test_should_raise_if_model_is_invalid(): re_err = r"valid SQLAlchemy Model" with pytest.raises(Exception, match=re_err): + class Character(SQLAlchemyObjectType): class Meta: model = 1 -def test_sqlalchemy_node(session): +@pytest.mark.asyncio +async def test_sqlalchemy_node(session): class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -64,9 +83,11 @@ class Meta: reporter = Reporter() session.add(reporter) - session.commit() + await eventually_await_session(session, "commit") info = mock.Mock(context={"session": session}) reporter_node = ReporterType.get_node(info, reporter.id) + if is_awaitable(reporter_node): + reporter_node = await reporter_node assert reporter == reporter_node @@ -97,7 +118,7 @@ class Meta: assert sorted(list(ReporterType._meta.fields.keys())) == sorted( [ # Columns - "column_prop", + "column_prop", # SQLAlchemy retuns column properties first "id", "first_name", "last_name", @@ -320,6 +341,7 @@ def test_invalid_model_attr(): "Cannot map ORMField to a model attribute.\n" "Field: 'ReporterType.first_name'" ) with pytest.raises(ValueError, match=err_msg): + class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -373,6 +395,7 @@ class Meta: def test_only_and_exclude_fields(): re_err = r"'only_fields' and 'exclude_fields' cannot be both set" with pytest.raises(Exception, match=re_err): + class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -392,9 +415,19 @@ class Meta: assert first_name_field.type == Int -def test_resolvers(session): +@pytest.mark.asyncio +async def test_resolvers(session): """Test that the correct resolver functions are called""" + reporter = Reporter( + first_name="first_name", + last_name="last_name", + email="email", + favorite_pet_kind="cat", + ) + session.add(reporter) + await eventually_await_session(session, "commit") + class ReporterMixin(object): def resolve_id(root, _info): return "ID" @@ -420,20 +453,14 @@ def resolve_favorite_pet_kind_v2(root, _info): class Query(ObjectType): reporter = Field(ReporterType) - def resolve_reporter(self, _info): + async def resolve_reporter(self, _info): + session = utils.get_session(_info.context) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return (await session.scalars(select(Reporter))).unique().first() return session.query(Reporter).first() - reporter = Reporter( - first_name="first_name", - last_name="last_name", - email="email", - favorite_pet_kind="cat", - ) - session.add(reporter) - session.commit() - schema = Schema(query=Query) - result = schema.execute( + result = await schema.execute_async( """ query { reporter { @@ -446,7 +473,8 @@ def resolve_reporter(self, _info): favoritePetKindV2 } } - """ + """, + context_value={"session": session}, ) assert not result.errors @@ -511,8 +539,13 @@ class Meta: def test_interface_with_polymorphic_identity(): - with pytest.raises(AssertionError, - match=re.escape('PersonType: An interface cannot map to a concrete type (polymorphic_identity is "person")')): + with pytest.raises( + AssertionError, + match=re.escape( + 'PersonType: An interface cannot map to a concrete type (polymorphic_identity is "person")' + ), + ): + class PersonType(SQLAlchemyInterface): class Meta: model = NonAbstractPerson @@ -562,13 +595,15 @@ class Meta: # type should be in this list because we used ORMField # to force its presence on the model - assert sorted(list(EmployeeType._meta.fields.keys())) == sorted([ - "id", - "name", - "type", - "birth_date", - "hire_date", - ]) + assert sorted(list(EmployeeType._meta.fields.keys())) == sorted( + [ + "id", + "name", + "type", + "birth_date", + "hire_date", + ] + ) def test_interface_custom_resolver(): @@ -590,13 +625,15 @@ class Meta: # type should be in this list because we used ORMField # to force its presence on the model - assert sorted(list(EmployeeType._meta.fields.keys())) == sorted([ - "id", - "name", - "custom_field", - "birth_date", - "hire_date", - ]) + assert sorted(list(EmployeeType._meta.fields.keys())) == sorted( + [ + "id", + "name", + "custom_field", + "birth_date", + "hire_date", + ] + ) # Tests for connection_field_factory diff --git a/graphene_sqlalchemy/tests/utils.py b/graphene_sqlalchemy/tests/utils.py index c90ee47..4a11824 100644 --- a/graphene_sqlalchemy/tests/utils.py +++ b/graphene_sqlalchemy/tests/utils.py @@ -1,3 +1,4 @@ +import inspect import re @@ -15,3 +16,11 @@ def remove_cache_miss_stat(message): """Remove the stat from the echoed query message when the cache is missed for sqlalchemy version >= 1.4""" # https://github.com/sqlalchemy/sqlalchemy/blob/990eb3d8813369d3b8a7776ae85fb33627443d30/lib/sqlalchemy/engine/default.py#L1177 return re.sub(r"\[generated in \d+.?\d*s\]\s", "", message) + + +async def eventually_await_session(session, func, *args): + + if inspect.iscoroutinefunction(getattr(session, func)): + await getattr(session, func)(*args) + else: + getattr(session, func)(*args) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index e0ada38..226d1e8 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,4 +1,6 @@ from collections import OrderedDict +from inspect import isawaitable +from typing import Any import sqlalchemy from sqlalchemy.ext.hybrid import hybrid_property @@ -26,7 +28,16 @@ ) from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver -from .utils import get_query, is_mapped_class, is_mapped_instance +from .utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + get_query, + get_session, + is_mapped_class, + is_mapped_instance, +) + +if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + from sqlalchemy.ext.asyncio import AsyncSession class ORMField(OrderedType): @@ -334,6 +345,11 @@ def __init_subclass_with_meta__( def is_type_of(cls, root, info): if isinstance(root, cls): return True + if isawaitable(root): + raise Exception( + "Received coroutine instead of sql alchemy model. " + "You seem to use an async engine with synchronous schema execution" + ) if not is_mapped_instance(root): raise Exception(('Received incompatible instance "{}".').format(root)) return isinstance(root, cls._meta.model) @@ -345,6 +361,19 @@ def get_query(cls, info): @classmethod def get_node(cls, info, id): + if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4: + try: + return cls.get_query(info).get(id) + except NoResultFound: + return None + + session = get_session(info.context) + if isinstance(session, AsyncSession): + + async def get_result() -> Any: + return await session.get(cls._meta.model, id) + + return get_result() try: return cls.get_query(info).get(id) except NoResultFound: diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 54bb840..62c71d8 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -4,11 +4,34 @@ from typing import Any, Callable, Dict, Optional import pkg_resources +from sqlalchemy import select from sqlalchemy.exc import ArgumentError from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError +def is_sqlalchemy_version_less_than(version_string): + """Check the installed SQLAlchemy version""" + return pkg_resources.get_distribution( + "SQLAlchemy" + ).parsed_version < pkg_resources.parse_version(version_string) + + +def is_graphene_version_less_than(version_string): # pragma: no cover + """Check the installed graphene version""" + return pkg_resources.get_distribution( + "graphene" + ).parsed_version < pkg_resources.parse_version(version_string) + + +SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = False + +if not is_sqlalchemy_version_less_than("1.4"): + from sqlalchemy.ext.asyncio import AsyncSession + + SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = True + + def get_session(context): return context.get("session") @@ -22,6 +45,8 @@ def get_query(model, context): "A query in the model Base or a session in the schema is required for querying.\n" "Read more http://docs.graphene-python.org/projects/sqlalchemy/en/latest/tips/#querying" ) + if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession): + return select(model) query = session.query(model) return query @@ -151,20 +176,6 @@ def sort_argument_for_model(cls, has_default=True): return Argument(List(enum), default_value=enum.default) -def is_sqlalchemy_version_less_than(version_string): # pragma: no cover - """Check the installed SQLAlchemy version""" - return pkg_resources.get_distribution( - "SQLAlchemy" - ).parsed_version < pkg_resources.parse_version(version_string) - - -def is_graphene_version_less_than(version_string): # pragma: no cover - """Check the installed graphene version""" - return pkg_resources.get_distribution( - "graphene" - ).parsed_version < pkg_resources.parse_version(version_string) - - class singledispatchbymatchfunction: """ Inspired by @singledispatch, this is a variant that works using a matcher function diff --git a/setup.py b/setup.py index ac9ad7e..9122baf 100644 --- a/setup.py +++ b/setup.py @@ -21,10 +21,13 @@ tests_require = [ "pytest>=6.2.0,<7.0", - "pytest-asyncio>=0.15.1", + "pytest-asyncio>=0.18.3", "pytest-cov>=2.11.0,<3.0", "sqlalchemy_utils>=0.37.0,<1.0", "pytest-benchmark>=3.4.0,<4.0", + "aiosqlite>=0.17.0", + "nest-asyncio", + "greenlet", ] setup(