diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index 467d0bd84e..ab5dd7fe19 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -1,5 +1,6 @@ from typing import ( Any, + Callable, Dict, Mapping, Optional, @@ -17,15 +18,17 @@ from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession from sqlalchemy.ext.asyncio.result import _ensure_sync_result from sqlalchemy.ext.asyncio.session import _EXECUTE_OPTIONS +from sqlalchemy.orm import Session as _Session from sqlalchemy.orm._typing import OrmExecuteOptionsParameter from sqlalchemy.sql.base import Executable as _Executable from sqlalchemy.util.concurrency import greenlet_spawn -from typing_extensions import deprecated +from typing_extensions import Concatenate, ParamSpec, deprecated from ...orm.session import Session from ...sql.base import Executable from ...sql.expression import Select, SelectOfScalar +_P = ParamSpec("_P") _TSelectParam = TypeVar("_TSelectParam", bound=Any) @@ -148,3 +151,17 @@ async def execute( # type: ignore _parent_execute_state=_parent_execute_state, _add_event=_add_event, ) + + async def run_sync( + self, + fn: Callable[Concatenate[Session, _P], _TSelectParam], + *arg: _P.args, + **kw: _P.kwargs, + ) -> _TSelectParam: + base_fn = cast(Callable[Concatenate[_Session, _P], _TSelectParam], fn) + + return await super().run_sync( + base_fn, + *arg, + **kw, + )