
Commit ab81771

Force RNN modules to be inlined (#975)

They call Tensor.set_ internally with a Storage, which is a no-go for AOTAutograd. Inline into them so that we can graph-break.

Fixes pytorch/functorch#586

Test strategy:

```
./benchmarks/torchbench.py --inductor -dcuda --no-skip -k tts_angular
```

Note that inductor is still failing, but differently, after this PR.

Signed-off-by: Edward Z. Yang <[email protected]>
1 parent ea455b7 · commit ab81771
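
For context on the "Tensor.set_ with a Storage" pattern the message refers to: RNN modules flatten their parameters, and that flattening ends up rebinding weight tensors onto a shared Storage. The snippet below is a minimal, hypothetical sketch of that rebinding, not the nn.RNN source (the names `weight` and `flat_buffer` are invented here); it shows the kind of `Tensor.set_(Storage, ...)` call that AOTAutograd cannot trace, which is why TorchDynamo has to inline these modules and graph-break rather than keep the call opaque in the graph.

```python
import torch

# Hypothetical, simplified illustration (names invented for this sketch; this
# is not the nn.RNN source): parameter flattening rebinds an existing tensor
# onto a shared Storage via Tensor.set_.
weight = torch.randn(4, 4)
flat_buffer = torch.empty(32)  # one contiguous buffer shared by several weights

# Tensor.set_ with a Storage argument -- the call AOTAutograd cannot trace,
# which is why TorchDynamo must see it (via inlining) and graph-break on it.
weight.set_(flat_buffer.storage(), 0, weight.shape)

assert weight.data_ptr() == flat_buffer.data_ptr()  # weight now views the buffer
```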

torchdynamo/allowed_functions.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -116,15 +116,20 @@ def _allowed_function_ids():
 
 def _is_allowed_module_prefix(obj):
     allowed_modules = ("torch", "math")
-    disallowed_modules = "torch.optim."
+    # torch.nn.modules.rnn is disallowed because these modules internally
+    # flatten their parameters. This flattening process will call
+    # Tensor.set_ with a Storage, and Storages cannot be traced with
+    # AOTAutograd; so we need to graph-break. To ensure this, we inline
+    # these functions, rather than keep them opaque-ly in the graph.
+    disallowed_modules = ("torch.optim.", "torch.nn.modules.rnn.")
     allowed_modules_dot = tuple([x + "." for x in allowed_modules])
     module = inspect.getmodule(obj)
     if module is None:
         return False
 
     mod_name = module.__name__
 
-    if mod_name.startswith(disallowed_modules):
+    if any(mod_name.startswith(m) for m in disallowed_modules):
         return False
 
     return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)
```
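
To make the new check concrete, here is a small standalone sketch of the same prefix logic applied directly to module-name strings (the helper name `allowed_by_prefix` and the sample names, including the made-up `torch.nn.modules.rnn.example`, are for illustration only and are not part of the module's code):

```python
# Standalone sketch of the prefix check above, simplified and renamed for
# illustration; it mirrors the diff's logic on plain module-name strings.
ALLOWED_MODULES = ("torch", "math")
DISALLOWED_MODULES = ("torch.optim.", "torch.nn.modules.rnn.")
ALLOWED_MODULES_DOT = tuple(x + "." for x in ALLOWED_MODULES)

def allowed_by_prefix(mod_name: str) -> bool:
    # A disallowed prefix wins even when the name also falls under "torch.".
    if any(mod_name.startswith(m) for m in DISALLOWED_MODULES):
        return False
    return mod_name in ALLOWED_MODULES or mod_name.startswith(ALLOWED_MODULES_DOT)

print(allowed_by_prefix("math"))                          # True
print(allowed_by_prefix("torch.nn.functional"))           # True
print(allowed_by_prefix("torch.optim.sgd"))               # False (already disallowed)
print(allowed_by_prefix("torch.nn.modules.rnn.example"))  # False (new in this commit;
                                                          # "example" is a made-up name)
```

As a side note on the design, str.startswith also accepts a tuple of prefixes, so `mod_name.startswith(disallowed_modules)` would behave the same as the `any(...)` form; the explicit `any()` simply makes the iteration over the disallowed prefixes obvious.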
