Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions exir/memory_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,9 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
)


def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None:
def update_tensor_lifetime(
node: torch.fx.Node, spec: TensorSpec, node_idx: int
) -> None:
r"""
Update the lifetime of the tensor to cover node_idx. A tensor's lifetime
are represented by the index of the first and last node referring
Expand All @@ -279,7 +281,10 @@ def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None:
node_idx: extend the tensor's lifetime to cover node_idx
"""
start, end = spec.lifetime
start = node_idx if start is None or start > node_idx else start
if node.op == "placeholder":
start = 0
else:
start = node_idx if start is None or start > node_idx else start
end = node_idx if end is None or end < node_idx else end
spec.lifetime = [start, end]

Expand Down Expand Up @@ -444,7 +449,7 @@ def update_all_tensors_lifetime(
do_assertion=False,
ignore_dynamic_unbound_tensor=False,
):
update_tensor_lifetime(spec, node_idx)
update_tensor_lifetime(node, spec, node_idx)
specs.add(spec)
return specs

Expand Down
65 changes: 64 additions & 1 deletion exir/tests/test_memory_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

import torch
from executorch.exir import ExecutorchBackendConfig, to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.memory_planning import (
filter_nodes,
get_node_tensor_specs,
greedy,
naive,
Verifier,
)
from executorch.exir.pass_base import PassResult
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.pass_manager import PassManager
from executorch.exir.passes import ( # noqa
MemoryPlanningPass,
Expand Down Expand Up @@ -593,3 +594,65 @@ def count_planned_inputs(
num_placeholders,
5,
)

def test_placeholder_lifetime(self) -> None:
class TestModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)

def forward(self, a, b, x):
a = a + b
b = a + b
y = self.linear(x)
return a, b, y

model = TestModel()
example_inputs = (torch.rand(1, 6, 2), torch.rand(1, 6, 2), torch.randn(5, 5))
exported_model = torch.export.export(model, example_inputs)
edge = to_edge(exported_model)

class TestPass(ExportPass):
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
permute_dims = [1, 0, 2]
for node in graph_module.graph.nodes:
if node.op == "placeholder" and str(node) == "a":
inverse_dims = [
permute_dims.index(x) for x in range(len(permute_dims))
]

with graph_module.graph.inserting_after(node):
permute = graph_module.graph.call_function(
exir_ops.edge.aten.permute_copy.default,
args=(node, inverse_dims),
)
permute.meta = node.meta.copy()
node.meta["val"] = node.meta["val"].permute(permute_dims)
node.replace_all_uses_with(
permute, lambda x, permute=permute: x is not permute
)
break
return PassResult(graph_module, True)

edge = edge.transform([TestPass()])
et = edge.to_executorch()
et_program = et.executorch_program
inputs = et_program.execution_plan[0].inputs
self.assertNotEqual(
et_program.execution_plan[0] # pyre-ignore
.values[inputs[0]]
.val.allocation_info.memory_offset_low,
et_program.execution_plan[0] # pyre-ignore
.values[inputs[1]]
.val.allocation_info.memory_offset_low,
)

constants = 0
for node in et.exported_program().graph_module.graph.nodes:
if node.op == "placeholder" and node.meta.get("spec"):
meta_spec = node.meta["spec"]
if meta_spec.const is True:
constants += 1
self.assertIsNone(node.meta["spec"].mem_offset)
self.assertIsNone(node.meta["spec"].mem_id)
self.assertEqual(constants, 2)