diff --git a/src/gino/schema.py b/src/gino/schema.py index cdb5c7fd..c859704d 100644 --- a/src/gino/schema.py +++ b/src/gino/schema.py @@ -398,31 +398,54 @@ async def _on_metadata_drop_async(self, target, bind, **kw): await getattr(t, "_on_metadata_drop_async")(target, bind, **kw) -async def _call_portable_instancemethod(fn, args, kw): +async def _call_portable_instancemethod(fn, target, connection, kw): m = None if hasattr(fn, "target"): m = getattr(fn.target, fn.name + "_async", None) if m is None: - return fn(*args, **kw) + return fn(target, connection, **kw) else: kw.update(fn.kwargs) - return await m(*args, **kw) + return await m(target, connection, **kw) class _Async: def __init__(self, listener): self._listener = listener - async def call(self, *args, **kw): + async def call(self, target, connection, **kw): for fn in self._listener.parent_listeners: - await _call_portable_instancemethod(fn, args, kw) + conn = _DelayedExecConn(connection) + await _call_portable_instancemethod(fn, target, conn, kw) + await conn.async_execute() for fn in self._listener.listeners: - await _call_portable_instancemethod(fn, args, kw) + conn = _DelayedExecConn(connection) + await _call_portable_instancemethod(fn, target, conn, kw) + await conn.async_execute() def __call__(self, *args, **kwargs): return self.call(*args, **kwargs) +class _DelayedExecConn: + # collect SQL statements to be executed in method :meth:`execute`, + # and to be triggered in :meth:`async_execute` + + def __init__(self, conn): + self._conn = conn + self._stmts = [] + + def __getattr__(self, item): + return getattr(self._conn, item) + + def execute(self, stmt, *args, **kwargs): + self._stmts.append((stmt, args, kwargs)) + + async def async_execute(self): + for i in self._stmts: + await self._conn.scalar(i[0], *i[1], **i[2]) + + def patch_schema(db): for st in {"Enum"}: setattr(db, st, type(st, (getattr(db, st), AsyncSchemaTypeMixin), {}))