Skip to content

Commit 1f93e5b

Browse files
committed
Python: Relax restriction of flow through async with
1 parent 43af8d7 commit 1f93e5b

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

python/ql/lib/semmle/python/dataflow/new/internal/DataFlowPrivate.qll

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,12 @@ module EssaFlow {
304304
// see `with_flow` in `python/ql/src/semmle/python/dataflow/Implementation.qll`
305305
with.getContextExpr() = contextManager.getNode() and
306306
with.getOptionalVars() = var.getNode() and
307-
not with.isAsync() and
308307
contextManager.strictlyDominates(var)
308+
// note: we allow this for both `with` and `async with`, since some
309+
// implementations do `async def __aenter__(self): return self`, so you can do
310+
// both:
311+
// * `foo = x.foo(); await foo.async_methoid(); foo.close()` and
312+
// * `async with x.foo() as foo: await foo.async_method()`.
309313
)
310314
or
311315
// Async with var definition
@@ -314,6 +318,12 @@ module EssaFlow {
314318
// nodeTo is `x`, essa var
315319
//
316320
// This makes the cfg node the local source of the awaited value.
321+
//
322+
// We have this step in addition to the step above, to handle cases where the QL
323+
// modeling of `f(42)` requires a `.getAwaited()` step (in API graphs) when not
324+
// using `async with`, so you can do both:
325+
// * `foo = await x.foo(); await foo.async_methoid(); foo.close()` and
326+
// * `async with x.foo() as foo: await foo.async_method()`.
317327
exists(With with, ControlFlowNode var |
318328
nodeFrom.(CfgNode).getNode() = var and
319329
nodeTo.(EssaNode).getVar().getDefinition().(WithDefinition).getDefiningNode() = var and

python/ql/test/library-tests/frameworks/aiohttp/client_request.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ async def test():
77
resp = await s.request("method", url="url") # $ clientRequestUrlPart="url"
88

99
async with aiohttp.ClientSession() as session:
10-
resp = await session.get("url") # $ MISSING: clientRequestUrlPart="url"
11-
resp = await session.request(method="GET", url="url") # $ MISSING: clientRequestUrlPart="url"
10+
resp = await session.get("url") # $ clientRequestUrlPart="url"
11+
resp = await session.request(method="GET", url="url") # $ clientRequestUrlPart="url"
1212

1313
# other methods than GET
1414
s = aiohttp.ClientSession()

0 commit comments

Comments
 (0)