Skip to content

Commit ee1c244

Browse files
williamwen42pytorchmergebot
authored andcommitted
[dynamo] delete dynamo cache entry when guard function is invalidated [attempt 2] (pytorch#119107)
Attempt #2 for pytorch#117875 to fix pytorch#112090. Summary of changes: - ~Changed CacheEntry linked list into a doubly-linked list structure to support deletion.~ (done by C++ refactor) - Added CacheEntry and ExtraState borrowed references to GuardFn so that GuardFn can tell ExtraState to delete CacheEntry when the GuardFn is invalidated. - ~Added ExtraState raw reference to CacheEntry so that we can get ExtraState to correctly point to the first CacheEntry if it gets deleted.~ (done by C++ refactor) - CacheEntry destructor needs to reset GuardFn refs to ExtraState/CacheEntry in order to prevent use-after-free. - code_context values that are nn.GraphModules need to be weakrefs in order to prevent circular references. - Added tests that check for memory leaks and cache deletion operations. Pull Request resolved: pytorch#119107 Approved by: https://github.com/jansel
1 parent fcc36de commit ee1c244

13 files changed

+169
-29
lines changed

test/dynamo/test_misc.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9412,6 +9412,44 @@ def fn(x):
94129412
c2 = _debug_get_cache_entry_list(fn.__code__)
94139413
self.assertEqual(len(c2), 0)
94149414

9415+
@unittest.skipIf(not TEST_CUDA, "requires cuda")
9416+
def test_module_free(self):
9417+
"""Test that CUDA memory is freed when a model goes out of scope"""
9418+
9419+
class Mod(torch.nn.Module):
9420+
def __init__(self):
9421+
super(Mod, self).__init__()
9422+
self.fc = torch.nn.Linear(10000, 10000)
9423+
9424+
def forward(self, out):
9425+
return self.fc(out)
9426+
9427+
def run(compile):
9428+
mod = Mod().cuda()
9429+
if compile:
9430+
mod = torch.compile(mod, backend="eager")
9431+
inp = torch.rand(10000, 10000).cuda()
9432+
mod(inp)
9433+
9434+
def clean_and_report_memory():
9435+
import gc
9436+
9437+
gc.collect()
9438+
return torch.cuda.memory_allocated()
9439+
9440+
run(False)
9441+
# mem1 = clean_and_report_memory()
9442+
run(True)
9443+
mem2 = clean_and_report_memory()
9444+
torch._dynamo.reset_code_caches()
9445+
mem3 = clean_and_report_memory()
9446+
9447+
# it's possible for dynamo to hold on to more memory
9448+
# even after a _dynamo.reset[_code_caches], so we omit the following check.
9449+
# self.assertEqual(mem1, mem2)
9450+
9451+
self.assertEqual(mem2, mem3)
9452+
94159453
def test_dynamo_cache_move_to_front(self):
94169454
class Mod(torch.nn.Module):
94179455
def __init__(self):
@@ -9445,6 +9483,56 @@ def fn(x, mod):
94459483
c2 = _debug_get_cache_entry_list(fn.__code__)
94469484
self.assertIs(c1[1], c2[0])
94479485

9486+
def test_dynamo_cache_invalidate(self):
9487+
class Mod(torch.nn.Module):
9488+
def __init__(self):
9489+
super(Mod, self).__init__()
9490+
self.fc = torch.nn.Linear(3, 3)
9491+
9492+
def forward(self, out):
9493+
return self.fc(out)
9494+
9495+
def fn(x, mod):
9496+
return mod(x)
9497+
9498+
opt_fn = torch.compile(fn, backend="eager")
9499+
9500+
m1 = Mod()
9501+
m2 = Mod()
9502+
m3 = Mod()
9503+
inp = torch.randn(3, 3)
9504+
9505+
# NOTE: assumes that each cache entry is guarded
9506+
# on unique Mod instance
9507+
opt_fn(inp, m1)
9508+
opt_fn(inp, m2)
9509+
opt_fn(inp, m3)
9510+
9511+
c1 = _debug_get_cache_entry_list(fn.__code__)
9512+
self.assertEqual(len(c1), 3)
9513+
9514+
# move cache entry to front
9515+
opt_fn(inp, m2)
9516+
c2 = _debug_get_cache_entry_list(fn.__code__)
9517+
self.assertIs(c1[1], c2[0])
9518+
9519+
# delete center of cache
9520+
del m3
9521+
c3 = _debug_get_cache_entry_list(fn.__code__)
9522+
self.assertEqual(len(c3), 2)
9523+
self.assertIs(c3[0], c2[0])
9524+
self.assertIs(c3[1], c2[2])
9525+
9526+
# delete end of cache
9527+
del m1
9528+
c4 = _debug_get_cache_entry_list(fn.__code__)
9529+
self.assertEqual(len(c4), 1)
9530+
self.assertIs(c4[0], c3[0])
9531+
9532+
del m2
9533+
c5 = _debug_get_cache_entry_list(fn.__code__)
9534+
self.assertEqual(len(c5), 0)
9535+
94489536

94499537
class TestTracer(JitTestCase):
94509538
def test_jit_save(self):

torch/_C/_dynamo/eval_frame.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,7 @@ class _CacheEntry:
1818
code: types.CodeType
1919
next: Optional[_CacheEntry]
2020

21+
class _ExtraState:
22+
def invalidate(self, cache_entry: _CacheEntry): ...
23+
2124
def _debug_get_cache_entry_list(code: types.CodeType) -> List[_CacheEntry]: ...

torch/_dynamo/__init__.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,7 @@
6767
def reset() -> None:
6868
"""Clear all compile caches and restore initial state"""
6969
with convert_frame.compile_lock:
70-
for weak_code in (
71-
convert_frame.input_codes.seen + convert_frame.output_codes.seen
72-
):
73-
code = weak_code()
74-
if code:
75-
reset_code(code)
70+
reset_code_caches()
7671
convert_frame.input_codes.clear()
7772
convert_frame.output_codes.clear()
7873
orig_code_map.clear()
@@ -82,4 +77,15 @@ def reset() -> None:
8277
_reset_guarded_backend_cache()
8378
reset_frame_count()
8479
torch._C._dynamo.compiled_autograd.clear_cache()
80+
81+
82+
def reset_code_caches() -> None:
83+
"""Clear compile caches that are keyed by code objects"""
84+
with convert_frame.compile_lock:
85+
for weak_code in (
86+
convert_frame.input_codes.seen + convert_frame.output_codes.seen
87+
):
88+
code = weak_code()
89+
if code:
90+
reset_code(code)
8591
code_context.clear()

torch/_dynamo/eval_frame.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import traceback
2121
import types
2222
import warnings
23+
import weakref
2324
from enum import Enum
2425
from os.path import dirname, join
2526
from typing import (
@@ -384,7 +385,9 @@ def get_compiler_config():
384385
# Assume that the underlying node metadata of `fn`,
385386
# a GraphModule instance, accurately represents
386387
# all instances of type(fn).
387-
code_context.get_context(fn.forward.__code__)["orig_graphmodule"] = fn
388+
code_context.get_context(fn.forward.__code__)[
389+
"orig_graphmodule"
390+
] = weakref.ref(fn)
388391

389392
# Optimize the forward method of torch.nn.Module object
390393
if isinstance(fn, torch.nn.Module):

torch/_dynamo/guards.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from . import config, convert_frame, exc, mutation_guard
5656
from .eval_frame import set_guard_error_hook
5757
from .source import DefaultsSource, LocalSource, TypeSource
58-
from .types import GuardedCode, GuardFail, GuardFn # noqa: F401
58+
from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn # noqa: F401
5959
from .utils import (
6060
common_constant_types,
6161
dict_keys_repr,
@@ -931,24 +931,22 @@ def must_add_nn_module_guards(guard):
931931
)
932932

933933

934+
class DeletedGuardFn:
935+
pass
936+
937+
934938
# NB: Naively, you'd expect this to only be a function that produces
935939
# the callable that constitutes the guard. However, there is some
936940
# delicate handling for invalidating this check function when the
937941
# locals/globals get invalidated, so there's some extra state
938942
# we have to hold in this manager class.
939-
#
940-
# TODO: this object has reference cycle with itself, via check_fn which
941-
# references back to CheckFunction via ___guarded_code in closure_vars.
942-
# Ideally, there shouldn't be any ref cycle so that guards are
943-
# promptly disposed of.
944943
class CheckFunctionManager:
945944
def __init__(
946945
self,
947946
output_graph=None,
948947
guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
949948
):
950949
guards = output_graph.guards if output_graph else None
951-
self.valid = True
952950
self._weakrefs: Dict[int, ReferenceType[object]] = {}
953951
self.output_graph = output_graph
954952

@@ -1025,7 +1023,7 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn):
10251023
guards_log.debug("GUARDS:")
10261024

10271025
# Don't report this guard, it's always the same, useless!
1028-
code_parts = ["___guarded_code.valid", "___check_global_state()"]
1026+
code_parts = ["___check_global_state()"]
10291027
verbose_code_parts = code_parts[:]
10301028

10311029
def add_code_part(code, guard, log_only=False):
@@ -1157,7 +1155,6 @@ def convert(size_or_stride):
11571155
# we should only hit this case in NopTests()
11581156
global_state = convert_frame.GlobalStateGuard()
11591157
closure_vars = {
1160-
"___guarded_code": self,
11611158
"___check_tensors": check_tensors_fn,
11621159
"___check_tensors_verbose": check_tensors_verbose_fn,
11631160
"___check_global_state": global_state.check,
@@ -1194,14 +1191,28 @@ def convert(size_or_stride):
11941191
# Grab only G, but preserve "G" because guards access it as "G"
11951192
guard_fn.global_scope = globals_for_guard_fn
11961193
guard_fn.guard_fail_fn = guard_fail_fn
1194+
# will be populated by a non-owning reference to CacheEntry/ExtraState
1195+
# when the CacheEntry is constructed
1196+
guard_fn.cache_entry = None
1197+
guard_fn.extra_state = None
11971198
return guard_fn
11981199

11991200
def invalidate(self):
1200-
# A weakref is no longer valid, self.check_fn should return false
1201-
# TODO(janimesh) - Free up cache entry after the cache entry formation
1202-
# is in python, and the underlying data structure is a doubly linked
1203-
# list.
1204-
self.valid = False
1201+
# Some tests reveal that CheckFunctionManager has no attribute
1202+
# check_fn, but this case should not be of any concern.
1203+
# This case doesn't seem easy to repro.
1204+
if (
1205+
hasattr(self, "check_fn")
1206+
and self.check_fn is not DeletedGuardFn
1207+
and (cache_entry := self.check_fn.cache_entry) is not None
1208+
and (extra_state := self.check_fn.extra_state) is not None
1209+
):
1210+
assert isinstance(cache_entry, CacheEntry)
1211+
assert isinstance(extra_state, ExtraState)
1212+
extra_state.invalidate(cache_entry)
1213+
self.check_fn.cache_entry = None
1214+
self.check_fn.extra_state = None
1215+
self.check_fn = DeletedGuardFn
12051216

12061217
def id_ref(self, obj):
12071218
"""add a weakref, return the id"""

torch/_dynamo/output_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1669,8 +1669,8 @@ def get_trace_call_log_str():
16691669
is_retracing = False
16701670
if tx.f_code is not self._cur_code:
16711671
orig_graphmodule_maybe = code_context.get_context(tx.f_code).get(
1672-
"orig_graphmodule", None
1673-
)
1672+
"orig_graphmodule", lambda: None
1673+
)()
16741674
if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule):
16751675
is_retracing = True
16761676
self._orig_gm_meta = [

torch/_dynamo/symbolic_convert.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2194,12 +2194,12 @@ def create_call_resume_at(self, inst):
21942194
# Add original GraphModule context to the resume function to handle
21952195
# the case of a graph break while tracing a GraphModule
21962196
orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
2197-
"orig_graphmodule", None
2198-
)
2197+
"orig_graphmodule", lambda: None
2198+
)()
21992199
if orig_graphmodule_maybe is not None:
2200-
code_context.get_context(new_code)[
2201-
"orig_graphmodule"
2202-
] = orig_graphmodule_maybe
2200+
code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
2201+
orig_graphmodule_maybe
2202+
)
22032203

22042204
if new_code.co_freevars:
22052205
cg.make_function_with_closure(name, new_code, True, stack_len)
@@ -2347,7 +2347,7 @@ def get_trace_call_log_str():
23472347
# but it is enough to add a context for `forward` in case it is called.
23482348
code_context.get_context(module.forward.__code__)[
23492349
"orig_graphmodule"
2350-
] = module
2350+
] = weakref.ref(module)
23512351

23522352
tracer: InliningInstructionTranslator
23532353
if is_generator(code):

torch/_dynamo/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
# and a `code` field for the code object.
2020
CacheEntry = torch._C._dynamo.eval_frame._CacheEntry
2121

22+
ExtraState = torch._C._dynamo.eval_frame._ExtraState
23+
2224
# We use a dict to store additional data per frame.
2325
FrameState = Dict[Any, Any]
2426

@@ -37,6 +39,8 @@ class GuardFn(Protocol):
3739
verbose_code_parts: List[str]
3840
global_scope: Dict[str, object]
3941
guard_fail_fn: Optional[Callable[[GuardFail], None]]
42+
cache_entry: Optional[CacheEntry]
43+
extra_state: Optional[ExtraState]
4044

4145
# maps locals of user function to bool
4246
def __call__(self, f_locals: Dict[str, object]) -> bool:

torch/csrc/dynamo/cache_entry.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ CacheEntry::CacheEntry(const py::handle& guarded_code) {
88
this->code = guarded_code.attr("code");
99
}
1010

11+
CacheEntry::~CacheEntry() {
12+
// prevent check_fn from use-after-free when invalidating
13+
this->check_fn.attr("cache_entry") = py::none();
14+
this->check_fn.attr("extra_state") = py::none();
15+
}
16+
1117
py::object CacheEntry::next() {
1218
NULL_CHECK(this->_owner);
1319
auto it = this->_owner_loc;

torch/csrc/dynamo/cache_entry.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ typedef struct VISIBILITY_HIDDEN CacheEntry {
4949
std::list<CacheEntry>::iterator _owner_loc;
5050

5151
CacheEntry(const py::handle& guarded_code);
52+
~CacheEntry();
5253

5354
// Warning: returns a reference whose lifetime is controlled by C++
5455
py::object next();

torch/csrc/dynamo/extra_state.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ void ExtraState::move_to_front(CacheEntry* cache_entry) {
2222
cache_entry->_owner_loc);
2323
}
2424

25+
void ExtraState::invalidate(CacheEntry* cache_entry) {
26+
CHECK(cache_entry->_owner == this);
27+
CHECK(!this->cache_entry_list.empty());
28+
CHECK(cache_entry == &*cache_entry->_owner_loc);
29+
this->cache_entry_list.erase(cache_entry->_owner_loc);
30+
}
31+
2532
CacheEntry* extract_cache_entry(ExtraState* extra_state) {
2633
if (extra_state == NULL || extra_state == SKIP_CODE) {
2734
return NULL;
@@ -110,6 +117,13 @@ CacheEntry* create_cache_entry(
110117
auto new_iter = extra_state->cache_entry_list.begin();
111118
new_iter->_owner = extra_state;
112119
new_iter->_owner_loc = new_iter;
120+
// Set check_fn references to extra_state and CacheEntry
121+
// Warning: lifetime is controlled by C++!
122+
py::handle check_fn = py::handle(guarded_code).attr("check_fn");
123+
check_fn.attr("cache_entry") =
124+
py::cast(*new_iter, py::return_value_policy::reference);
125+
check_fn.attr("extra_state") =
126+
py::cast(extra_state, py::return_value_policy::reference);
113127
return &*new_iter;
114128
}
115129

torch/csrc/dynamo/extra_state.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ typedef struct VISIBILITY_HIDDEN ExtraState {
4141

4242
CacheEntry* get_first_entry();
4343
void move_to_front(CacheEntry* cache_entry);
44+
void invalidate(CacheEntry* cache_entry);
4445
} ExtraState;
4546

4647
#else

torch/csrc/dynamo/init.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ void initDynamoBindings(PyObject* torch) {
4444
.def_readonly("code", &CacheEntry::code)
4545
.def_property_readonly("next", &CacheEntry::next);
4646

47+
py::class_<ExtraState>(m, "_ExtraState")
48+
.def("invalidate", &ExtraState::invalidate);
49+
4750
m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list);
4851
}
4952

0 commit comments

Comments
 (0)