diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py index c65aff6bd..7c5bec2ca 100644 --- a/functorch/_src/aot_autograd.py +++ b/functorch/_src/aot_autograd.py @@ -12,6 +12,12 @@ from .named_members_polyfill import _named_parameters, _named_buffers from typing import Callable, List, Dict, Any, Tuple, Optional +try: + from torchdynamo import disable as disable_torchdynamo +except ImportError: + def disable_torchdynamo(x): + return x + pytree._register_pytree_node( immutable_collections.immutable_list, lambda x: (list(x), None), @@ -129,6 +135,7 @@ def create_aot_autograd_function( class CompiledFunction(torch.autograd.Function): @staticmethod + @disable_torchdynamo def forward(ctx, *flat_tensor_args): nonlocal compiled_fw, compiled_bw, num_outs if compiled_fw is None: @@ -163,6 +170,7 @@ def forward(ctx, *flat_tensor_args): return tuple(fw_outs[0:num_outs]) @staticmethod + @disable_torchdynamo def backward(ctx, *flat_args): contiguous_args = [t.contiguous() for t in flat_args] # contiguous_args = [t for t in flat_args]