This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 2489be5

Reduce overhead of AOT Module (#660)
Adds aot_module_simplified and aot_function_simplified. Falls back to the original aot_module until we prevent tracing of leaf modules.
1 parent: 0a8647c
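
For orientation (not part of the commit): a minimal sketch of calling the new entry point, mirroring the test added below. MyModule is an illustrative name; nop is the no-op compiler that the test imports from functorch.compile.

    import torch
    from functorch.compile import nop
    from functorch._src.aot_autograd import aot_module_simplified

    class MyModule(torch.nn.Module):  # illustrative module; forward must return a tuple/list
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(20, 30)

        def forward(self, x):
            return (self.linear(x),)

    mod = MyModule()
    # fw_compiler is the first extra argument; bw_compiler defaults to fw_compiler
    aot_mod = aot_module_simplified(mod, nop)
    out = aot_mod(torch.randn(128, 20, requires_grad=True))
    out[0].sum().backward()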

3 files changed: +107 additions, 0 deletions

functorch/_src/aot_autograd.py

Lines changed: 74 additions & 0 deletions
@@ -527,5 +527,79 @@ def forward(self, *args, **kwargs):
     return AOTModule()
 
 
+def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
+    """
+    This is the simplified or low overhead version of aot_module. For frontends
+    like TorchDynamo, the input functions/modules to AOT are static and have
+    unpacked inputs/outputs. This gives us an opportunity to remove the
+    (1) pytree overhead to parse inputs/outputs,
+    (2) AOT Autograd cache,
+    (3) reading of params/buffers in every forward call.
+
+    :func:`aot_module_simplified` removes these overheads.
+    """
+    #########################################################
+
+    params = {
+        **dict(_named_parameters(mod, remove_duplicate=False)),
+        **dict(_named_buffers(mod, remove_duplicate=False)),
+    }
+    params_flat, params_spec = pytree.tree_flatten(params)
+    params_flat = tuple(params_flat)
+    params_len = len(params_flat)
+
+    def functional_call(*args, **kwargs):
+        with _stateless.reparametrize_module(
+            mod, pytree.tree_unflatten(args[:params_len], params_spec)
+        ):
+            out = mod(*args[params_len:], **kwargs)
+        if not isinstance(out, (tuple, list)):
+            raise RuntimeError(
+                "Graph output must be a tuple(). This is so that we can avoid "
+                "pytree processing of the outputs. Please change the module to "
+                "have tuple outputs or use aot_module instead."
+            )
+        return out
+
+    def aot_function_simplified(
+        fn: Callable,
+        fw_compiler: Callable,
+        bw_compiler: Optional[Callable] = None,
+        partition_fn: Callable = default_partition,
+        decompositions: Dict = {},
+        hasher_type: str = "StaticShapeHasher",
+        static_argnums: Optional[Tuple[int]] = None,
+    ) -> Callable:
+        assert static_argnums is None
+        if bw_compiler is None:
+            bw_compiler = fw_compiler
+        compiled_fn = create_aot_autograd_function(
+            fn,
+            fw_compiler,
+            bw_compiler,
+            partition_fn,
+            decompositions,
+            grad_state=torch.is_grad_enabled(),
+        ).apply
+
+        return compiled_fn
+
+    compiled_f = aot_function_simplified(functional_call, *top_args, **top_kwargs)
+
+    class AOTModule(nn.Module):
+        def __init__(self):
+            super(AOTModule, self).__init__()
+            self.orig_module = mod
+
+        def forward(self, *args, **kwargs):
+            return compiled_f(
+                *params_flat,
+                *args,
+                **kwargs,
+            )
+
+    return AOTModule()
+
+
 compiled_function = aot_function
 compiled_module = aot_module
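
Note the RuntimeError above: because the simplified variant skips pytree processing of outputs, the wrapped module must return a tuple or list. A hedged sketch of what that looks like from the caller's side (BareOutput/TupleOutput are illustrative names, not from the commit; nop is the no-op compiler from functorch.compile):

    import torch
    from functorch.compile import nop
    from functorch._src.aot_autograd import aot_module_simplified

    class BareOutput(torch.nn.Module):   # illustrative: bare tensor output is rejected
        def forward(self, x):
            return x.relu()

    class TupleOutput(torch.nn.Module):  # illustrative: tuple output is accepted
        def forward(self, x):
            return (x.relu(),)

    x = torch.randn(4, requires_grad=True)
    try:
        aot_module_simplified(BareOutput(), nop)(x)
    except RuntimeError as err:
        print(err)  # complains that the graph output must be a tuple
    print(aot_module_simplified(TupleOutput(), nop)(x))  # traced through nop; tuple of outputs

Modules that cannot be changed to return tuples can keep using aot_module, which retains the pytree handling of outputs.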

functorch/compile/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
     compiled_module,
     num_of_recompilations,
     clear_compile_cache,
+    aot_module_simplified,
 )
 from .._src.compilers import (
     ts_compile,
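
With this line, the new entry point is also re-exported from the public functorch.compile namespace (the test below imports it from the private module instead), so a caller can simply write:

    from functorch.compile import aot_module_simplified, nop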

test/test_pythonkey.py

Lines changed: 32 additions & 0 deletions
@@ -17,6 +17,7 @@
     grad, vjp, vmap, jacrev,
     make_fx
 )
+from functorch._src.aot_autograd import aot_module_simplified
 from functorch.compile import (
     nnc_jit, compiled_function, compiled_module,
     min_cut_rematerialization_partition, aot_function, aot_module, decomposition_table, nop,
@@ -540,6 +541,37 @@ def f(x):
         torch.autograd.grad(out, inp, torch.randn(3, 2))
 
 
+class TestAOTModuleSimplified(TestCase):
+    def test_aot_module_simplified(self):
+        class MockModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(20, 30)
+
+            def forward(self, x, y):
+                return (self.linear(x) + y, )
+
+        mod = MockModule()
+        mod.zero_grad()
+
+        x = torch.randn(128, 20, requires_grad=True)
+        y = torch.randn(128, 30, requires_grad=True)
+        inputs = [x, y]
+        cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs]
+
+        ref = mod(*inputs)
+        ref[0].sum().backward()
+
+        aot_mod = aot_module_simplified(mod, nop)
+        aot_mod.zero_grad()
+        res = aot_mod(*cloned_inputs)
+        res[0].sum().backward()
+
+        assert torch.allclose(ref[0], res[0])
+        assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad)
+        assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad)
+
+
 only_for = ("cpu")
 instantiate_device_type_tests(
     TestPythonKey,
