diff --git a/torchdynamo/allowed_functions.py b/torchdynamo/allowed_functions.py index 7b17461b03..ffc04f8885 100644 --- a/torchdynamo/allowed_functions.py +++ b/torchdynamo/allowed_functions.py @@ -116,7 +116,12 @@ def _allowed_function_ids(): def _is_allowed_module_prefix(obj): allowed_modules = ("torch", "math") - disallowed_modules = "torch.optim." + # torch.nn.modules.rnn is disallowed because these modules internally + # flatten their parameters. This flattening process will call + # Tensor.set_ with a Storage, and Storages cannot be traced with + # AOTAutograd, so we need to graph-break. To ensure this, we inline + # these functions, rather than keep them opaquely in the graph. + disallowed_modules = ("torch.optim.", "torch.nn.modules.rnn.") allowed_modules_dot = tuple([x + "." for x in allowed_modules]) module = inspect.getmodule(obj) if module is None: @@ -124,7 +129,7 @@ def _is_allowed_module_prefix(obj): mod_name = module.__name__ - if mod_name.startswith(disallowed_modules): + if any(mod_name.startswith(m) for m in disallowed_modules): return False return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)