Commit 5abeae6

tarun292 authored and facebook-github-bot committed
Fix placeholder lifetime bug in memory planning
Differential Revision: D66184849
1 parent 809a1a5 commit 5abeae6
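
In brief (inferred from the diff below): update_tensor_lifetime used to derive every tensor's lifetime start from the index of its first observed use. For placeholders (graph inputs) that is wrong once a pass rewrites their uses, as the new test does by inserting a permute_copy directly after each placeholder: the first observed use shifts past node index 0, and the memory planner may plan another tensor into the input's buffer while the input is still needed. The fix passes the node into update_tensor_lifetime and pins a placeholder's lifetime start to 0.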

File tree

2 files changed: +56 −4

exir/memory_planning.py
exir/tests/test_memory_planning.py

exir/memory_planning.py

Lines changed: 8 additions & 3 deletions
@@ -268,7 +268,9 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
     )
 
 
-def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None:
+def update_tensor_lifetime(
+    node: torch.fx.Node, spec: TensorSpec, node_idx: int
+) -> None:
     r"""
     Update the lifetime of the tensor to cover node_idx. A tensor's lifetime
     are represented by the index of the first and last node referring
@@ -279,7 +281,10 @@ def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None:
         node_idx: extend the tensor's lifetime to cover node_idx
     """
     start, end = spec.lifetime
-    start = node_idx if start is None or start > node_idx else start
+    if node.op == "placeholder":
+        start = 0
+    else:
+        start = node_idx if start is None or start > node_idx else start
     end = node_idx if end is None or end < node_idx else end
     spec.lifetime = [start, end]
 
@@ -444,7 +449,7 @@ def update_all_tensors_lifetime(
             do_assertion=False,
             ignore_dynamic_unbound_tensor=False,
         ):
-            update_tensor_lifetime(spec, node_idx)
+            update_tensor_lifetime(node, spec, node_idx)
             specs.add(spec)
     return specs

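For illustration, here is a minimal, self-contained sketch of the patched logic. FakeSpec and FakeNode are hypothetical stand-ins for exir's TensorSpec and torch.fx.Node; this is not the ExecuTorch source, just a demonstration of the new guarantee:

# A minimal sketch of the patched behavior, not the ExecuTorch source.
# FakeSpec and FakeNode are hypothetical stand-ins for exir's TensorSpec
# and torch.fx.Node.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeSpec:
    # lifetime is [start_node_idx, end_node_idx], both initially unset
    lifetime: List[Optional[int]] = field(default_factory=lambda: [None, None])


@dataclass
class FakeNode:
    op: str  # e.g. "placeholder" or "call_function"


def update_tensor_lifetime(node: FakeNode, spec: FakeSpec, node_idx: int) -> None:
    start, end = spec.lifetime
    if node.op == "placeholder":
        # Graph inputs are live from the start of execution, so their
        # lifetime begins at node index 0 regardless of first observed use.
        start = 0
    else:
        start = node_idx if start is None or start > node_idx else start
    end = node_idx if end is None or end < node_idx else end
    spec.lifetime = [start, end]


# Suppose a pass moved the placeholder's first observed use to node index 3.
spec = FakeSpec()
update_tensor_lifetime(FakeNode(op="placeholder"), spec, 3)
assert spec.lifetime == [0, 3]  # before the fix this would have been [3, 3]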
exir/tests/test_memory_planning.py

Lines changed: 48 additions & 1 deletion
@@ -14,14 +14,15 @@
 
 import torch
 from executorch.exir import ExecutorchBackendConfig, to_edge
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.memory_planning import (
     filter_nodes,
     get_node_tensor_specs,
     greedy,
     naive,
     Verifier,
 )
-from executorch.exir.pass_base import PassResult
+from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.pass_manager import PassManager
 from executorch.exir.passes import (  # noqa
     MemoryPlanningPass,
@@ -591,3 +592,49 @@ def count_planned_inputs(
             num_placeholders,
             5,
         )
+
+    def test_placeholder_lifetime(self) -> None:
+        class TestModel(torch.nn.Module):
+            def forward(self, a, b):
+                a = a + b
+                b = a + b
+                return a, b
+
+        model = TestModel()
+        example_inputs = (torch.rand(1, 6000, 2), torch.rand(1, 6000, 2))
+        exported_model = torch.export.export(model, example_inputs)
+        edge = to_edge(exported_model)
+
+        class TestPass(ExportPass):
+            def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+                permute_dims = [1, 0, 2]
+                for node in graph_module.graph.nodes:
+                    if node.op == "placeholder":
+                        inverse_dims = [
+                            permute_dims.index(x) for x in range(len(permute_dims))
+                        ]
+
+                        with graph_module.graph.inserting_after(node):
+                            permute = graph_module.graph.call_function(
+                                exir_ops.edge.aten.permute_copy.default,
+                                args=(node, inverse_dims),
+                            )
+                            permute.meta = node.meta.copy()
+                        node.meta["val"] = node.meta["val"].permute(permute_dims)
+                        node.replace_all_uses_with(
+                            permute, lambda x, permute=permute: x is not permute
+                        )
+                        break
+                return PassResult(graph_module, True)
+
+        edge = edge.transform([TestPass()])
+        et_program = edge.to_executorch().executorch_program
+        inputs = et_program.execution_plan[0].inputs
+        self.assertNotEqual(
+            et_program.execution_plan[0]
+            .values[inputs[0]]
+            .val.allocation_info.memory_offset_low,
+            et_program.execution_plan[0]
+            .values[inputs[1]]
+            .val.allocation_info.memory_offset_low,
+        )  # pyre-ignore
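
Two notes on the test, summarizing rather than quoting the commit: inverse_dims is the inverse permutation of permute_dims, so the inserted permute_copy undoes the permutation applied to the placeholder's meta["val"] and downstream behavior is unchanged; the assertion then requires the two inputs to be given distinct memory_offset_low values, i.e. the planner must not overlap the input buffers. A tiny sketch of the inverse-permutation identity (plain Python lists, purely illustrative):

# Hypothetical illustration (plain lists, not part of the commit).
permute_dims = [1, 0, 2]
inverse_dims = [permute_dims.index(x) for x in range(len(permute_dims))]  # [1, 0, 2]

# Applying permute_dims and then inverse_dims restores the original order.
order = ["d0", "d1", "d2"]
permuted = [order[d] for d in permute_dims]     # ['d1', 'd0', 'd2']
restored = [permuted[d] for d in inverse_dims]  # ['d0', 'd1', 'd2']
assert restored == order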
