Skip to content

Commit 7e74237

Browse files
[3.11] gh-98086: Now patch.dict can decorate async functions (GH-98095) (#99365)
gh-98086: Now ``patch.dict`` can decorate async functions (GH-98095) (cherry picked from commit 67b4d27) Co-authored-by: Nikita Sobolev <[email protected]> Co-authored-by: Nikita Sobolev <[email protected]>
1 parent 369cb3e commit 7e74237

File tree

3 files changed

+36
-0
lines changed

3 files changed

+36
-0
lines changed

Lib/unittest/mock.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1809,6 +1809,12 @@ def __init__(self, in_dict, values=(), clear=False, **kwargs):
18091809
def __call__(self, f):
18101810
if isinstance(f, type):
18111811
return self.decorate_class(f)
1812+
if inspect.iscoroutinefunction(f):
1813+
return self.decorate_async_callable(f)
1814+
return self.decorate_callable(f)
1815+
1816+
1817+
def decorate_callable(self, f):
18121818
@wraps(f)
18131819
def _inner(*args, **kw):
18141820
self._patch_dict()
@@ -1820,6 +1826,18 @@ def _inner(*args, **kw):
18201826
return _inner
18211827

18221828

1829+
def decorate_async_callable(self, f):
1830+
@wraps(f)
1831+
async def _inner(*args, **kw):
1832+
self._patch_dict()
1833+
try:
1834+
return await f(*args, **kw)
1835+
finally:
1836+
self._unpatch_dict()
1837+
1838+
return _inner
1839+
1840+
18231841
def decorate_class(self, klass):
18241842
for attr in dir(klass):
18251843
attr_value = getattr(klass, attr)

Lib/unittest/test/testmock/testasync.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,23 @@ async def test_async():
149149

150150
run(test_async())
151151

152+
def test_patch_dict_async_def(self):
153+
foo = {'a': 'a'}
154+
@patch.dict(foo, {'a': 'b'})
155+
async def test_async():
156+
self.assertEqual(foo['a'], 'b')
157+
158+
self.assertTrue(iscoroutinefunction(test_async))
159+
run(test_async())
160+
161+
def test_patch_dict_async_def_context(self):
162+
foo = {'a': 'a'}
163+
async def test_async():
164+
with patch.dict(foo, {'a': 'b'}):
165+
self.assertEqual(foo['a'], 'b')
166+
167+
run(test_async())
168+
152169

153170
class AsyncMockTest(unittest.TestCase):
154171
def test_iscoroutinefunction_default(self):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make sure ``patch.dict()`` can be applied on async functions.

0 commit comments

Comments
 (0)