
Commit bbedc71

ezyang authored and pytorchmergebot committed
test: ensure editable cached wrapper is respected (pytorch#160943)
## Summary

- add a test verifying that editing the local cache wrapper is picked up after a Dynamo reset

## Testing

- `lintrunner -a` *(fails: FLAKE8 failure, TEST_HAS_MAIN failure, CODESPELL failure, PYFMT failure)*
- `PYTHONPATH=. python test/inductor/test_codecache.py TestPyCodeCache.test_editable_cached_wrapper -v`

------

https://chatgpt.com/codex/tasks/task_e_68a3aa3fcc9883239b17d1f4250d1e89

Pull Request resolved: pytorch#160943
Approved by: https://github.com/xmfan
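For context, the test's core mechanism is to pin Inductor's on-disk cache to a throwaway directory via `TORCHINDUCTOR_CACHE_DIR` and then drive each step in a fresh interpreter, so the steps share state only through the file cache. A minimal sketch of that setup, mirroring the diff below (the env var and the subprocess pattern are the real ones used by the test; the child payload here is a placeholder):

```python
import os
import subprocess
import sys
import tempfile

# Pin Inductor's file cache to a temp dir; child processes then share
# compiled artifacts only through the filesystem, never through memory.
with tempfile.TemporaryDirectory() as tmpdir:
    env = os.environ.copy()
    env["TORCHINDUCTOR_CACHE_DIR"] = tmpdir

    # Each step runs in a fresh interpreter: in-memory caches start cold,
    # while the on-disk FX graph cache persists between steps.
    payload = "print('placeholder step')"  # the test compiles a model here
    out = subprocess.check_output([sys.executable, "-c", payload], env=env)
    print(out.decode().strip())
```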
1 parent e9481b6 commit bbedc71

File tree

1 file changed: +95 −0 lines changed


test/inductor/test_codecache.py

Lines changed: 95 additions & 0 deletions
@@ -7,6 +7,7 @@
 import subprocess
 import sys
 import tempfile
+import textwrap
 import unittest
 from contextlib import contextmanager
 from typing import Optional, Union
@@ -137,6 +138,100 @@ def test_linemaps_empty(self):
         stack_frames = PyCodeCache.stack_frames_for_code(path, 0)
         self.assertEqual(stack_frames, None)
 
+    def test_editable_cached_wrapper(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            env = os.environ.copy()
+            env["TORCHINDUCTOR_CACHE_DIR"] = tmpdir
+
+            step1 = textwrap.dedent(
+                """
+                import glob
+                import os
+                import torch
+                import warnings
+                from torch._inductor import config
+
+                warnings.filterwarnings("ignore")
+                config.fx_graph_cache = True
+                config.fx_graph_remote_cache = False
+                torch._dynamo.reset()
+
+                @torch.compile(backend="inductor")
+                def f(x):
+                    return x * 2
+
+                f(torch.ones(2))
+                cache_dir = os.environ["TORCHINDUCTOR_CACHE_DIR"]
+                pyfiles = glob.glob(os.path.join(cache_dir, "**", "*.py"), recursive=True)
+                print(pyfiles[0])
+                """
+            )
+            wrapper_path = (
+                subprocess.check_output([sys.executable, "-c", step1], env=env)
+                .decode()
+                .strip()
+            )
+
+            step2 = textwrap.dedent(
+                """
+                import torch
+                import warnings
+                from torch._dynamo.utils import counters
+                from torch._inductor import config
+
+                warnings.filterwarnings("ignore")
+                config.fx_graph_cache = True
+                config.fx_graph_remote_cache = False
+                torch._dynamo.reset()
+
+                @torch.compile(backend="inductor")
+                def f(x):
+                    return x * 2
+
+                f(torch.ones(2))
+                print(counters["inductor"]["fxgraph_cache_hit"])
+                """
+            )
+            hit = (
+                subprocess.check_output([sys.executable, "-c", step2], env=env)
+                .decode()
+                .strip()
+            )
+            self.assertEqual(hit, "1")
+
+            with open(wrapper_path) as f:
+                src = f.read()
+            with open(wrapper_path, "w") as f:
+                f.write(
+                    src.replace(
+                        "def call(self, args):",
+                        "def call(self, args):\n print('debug')",
+                    )
+                )
+
+            step3 = textwrap.dedent(
+                """
+                import torch
+                import warnings
+                from torch._inductor import config
+
+                warnings.filterwarnings("ignore")
+                config.fx_graph_cache = True
+                config.fx_graph_remote_cache = False
+                torch._dynamo.reset()
+
+                @torch.compile(backend="inductor")
+                def f(x):
+                    return x * 2
+
+                f(torch.ones(2))
+                """
+            )
+            out = subprocess.check_output(
+                [sys.executable, "-c", step3], env=env
+            ).decode()
+            self.assertIn("debug", out)
+
 
 @instantiate_parametrized_tests
 class TestFxGraphCache(TestCase):
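A note on why the middle step proves anything: the test splices a `print('debug')` into the cached wrapper's `call` definition, so if the third subprocess really gets an FX-graph cache hit and execs the edited file, `debug` reaches stdout; a recompile would regenerate the wrapper and the final `assertIn` would fail. A standalone sketch of that splice (hypothetical helper script; the anchor string `def call(self, args):` comes from the test, and the wrapper's layout is otherwise an Inductor implementation detail):

```python
import sys

# Usage: python edit_wrapper.py <path-to-cached-wrapper.py>
# (hypothetical standalone version of the in-test edit)
wrapper_path = sys.argv[1]
anchor = "def call(self, args):"

with open(wrapper_path) as f:
    src = f.read()

# Splice a visible side effect into the wrapper; only a later cache hit
# that execs this edited source can print the marker.
with open(wrapper_path, "w") as f:
    f.write(src.replace(anchor, anchor + "\n print('debug')"))
```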

0 commit comments
