diff --git a/exir/memory_planning.py b/exir/memory_planning.py index 6ec740bd9f4..6f0ab2a3922 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -268,7 +268,9 @@ def _is_inplace_node(node: torch.fx.Node) -> bool: ) -def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None: +def update_tensor_lifetime( + node: torch.fx.Node, spec: TensorSpec, node_idx: int +) -> None: r""" Update the lifetime of the tensor to cover node_idx. A tensor's lifetime are represented by the index of the first and last node referring @@ -279,7 +281,10 @@ def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None: node_idx: extend the tensor's lifetime to cover node_idx """ start, end = spec.lifetime - start = node_idx if start is None or start > node_idx else start + if node.op == "placeholder": + start = 0 + else: + start = node_idx if start is None or start > node_idx else start end = node_idx if end is None or end < node_idx else end spec.lifetime = [start, end] @@ -444,7 +449,7 @@ def update_all_tensors_lifetime( do_assertion=False, ignore_dynamic_unbound_tensor=False, ): - update_tensor_lifetime(spec, node_idx) + update_tensor_lifetime(node, spec, node_idx) specs.add(spec) return specs diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py index 90398035e7d..5e4573a2bab 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -14,6 +14,7 @@ import torch from executorch.exir import ExecutorchBackendConfig, to_edge +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.memory_planning import ( filter_nodes, get_node_tensor_specs, @@ -21,7 +22,7 @@ naive, Verifier, ) -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.pass_manager import PassManager from executorch.exir.passes import ( # noqa MemoryPlanningPass, @@ -593,3 +594,65 @@ def count_planned_inputs( num_placeholders, 5, ) + + def test_placeholder_lifetime(self) -> None: + class TestModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, a, b, x): + a = a + b + b = a + b + y = self.linear(x) + return a, b, y + + model = TestModel() + example_inputs = (torch.rand(1, 6, 2), torch.rand(1, 6, 2), torch.randn(5, 5)) + exported_model = torch.export.export(model, example_inputs) + edge = to_edge(exported_model) + + class TestPass(ExportPass): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + permute_dims = [1, 0, 2] + for node in graph_module.graph.nodes: + if node.op == "placeholder" and str(node) == "a": + inverse_dims = [ + permute_dims.index(x) for x in range(len(permute_dims)) + ] + + with graph_module.graph.inserting_after(node): + permute = graph_module.graph.call_function( + exir_ops.edge.aten.permute_copy.default, + args=(node, inverse_dims), + ) + permute.meta = node.meta.copy() + node.meta["val"] = node.meta["val"].permute(permute_dims) + node.replace_all_uses_with( + permute, lambda x, permute=permute: x is not permute + ) + break + return PassResult(graph_module, True) + + edge = edge.transform([TestPass()]) + et = edge.to_executorch() + et_program = et.executorch_program + inputs = et_program.execution_plan[0].inputs + self.assertNotEqual( + et_program.execution_plan[0] # pyre-ignore + .values[inputs[0]] + .val.allocation_info.memory_offset_low, + et_program.execution_plan[0] # pyre-ignore + .values[inputs[1]] + .val.allocation_info.memory_offset_low, + ) + + constants = 0 + for node in et.exported_program().graph_module.graph.nodes: + if node.op == "placeholder" and node.meta.get("spec"): + meta_spec = node.meta["spec"] + if meta_spec.const is True: + constants += 1 + self.assertIsNone(node.meta["spec"].mem_offset) + self.assertIsNone(node.meta["spec"].mem_id) + self.assertEqual(constants, 2)