Skip to content

Commit 9aa41fe

Browse files
tarun292facebook-github-bot
authored andcommitted
Fix placeholder lifetime bug in memory planning (#6971)
Summary: If we have a node that is inserted between placeholders via a pass after to_edge then there is a bug in memory planning lifetime calculations which results in two placeholders being allocated the same memory segment. Two placeholders should never be using the same memory segment and to prevent this we set the beginning lifetime of the placeholder to be always 0 in the memory planning pass. In the test case added this is the graph: ``` graph(): %a : [num_users=1] = placeholder[target=a] %aten_permute_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.permute_copy.default](args = (%a, [1, 0, 2]), kwargs = {}) %b : [num_users=2] = placeholder[target=b] %aten_add_tensor : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_permute_copy_default, %b), kwargs = {}) %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_add_tensor, %b), kwargs = {}) return (aten_add_tensor, aten_add_tensor_1) ``` Without this fix the lifetimes of placeholders a and b are: `a => [0, 2]` `b => [3, 7]` Thus the same memory segment is allocated for both of them. After this fix the lifetimes of the placeholders a and b are: `a => [0, 2]` `b => [0, 7]` Reviewed By: JacobSzwejbka Differential Revision: D66184849
1 parent c726a9b commit 9aa41fe

File tree

2 files changed

+72
-4
lines changed

2 files changed

+72
-4
lines changed

exir/memory_planning.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,9 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
268268
)
269269

270270

271-
def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None:
271+
def update_tensor_lifetime(
272+
node: torch.fx.Node, spec: TensorSpec, node_idx: int
273+
) -> None:
272274
r"""
273275
Update the lifetime of the tensor to cover node_idx. A tensor's lifetime
274276
are represented by the index of the first and last node referring
@@ -279,7 +281,10 @@ def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None:
279281
node_idx: extend the tensor's lifetime to cover node_idx
280282
"""
281283
start, end = spec.lifetime
282-
start = node_idx if start is None or start > node_idx else start
284+
if node.op == "placeholder":
285+
start = 0
286+
else:
287+
start = node_idx if start is None or start > node_idx else start
283288
end = node_idx if end is None or end < node_idx else end
284289
spec.lifetime = [start, end]
285290

@@ -444,7 +449,7 @@ def update_all_tensors_lifetime(
444449
do_assertion=False,
445450
ignore_dynamic_unbound_tensor=False,
446451
):
447-
update_tensor_lifetime(spec, node_idx)
452+
update_tensor_lifetime(node, spec, node_idx)
448453
specs.add(spec)
449454
return specs
450455

exir/tests/test_memory_planning.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414

1515
import torch
1616
from executorch.exir import ExecutorchBackendConfig, to_edge
17+
from executorch.exir.dialects._ops import ops as exir_ops
1718
from executorch.exir.memory_planning import (
1819
filter_nodes,
1920
get_node_tensor_specs,
2021
greedy,
2122
naive,
2223
Verifier,
2324
)
24-
from executorch.exir.pass_base import PassResult
25+
from executorch.exir.pass_base import ExportPass, PassResult
2526
from executorch.exir.pass_manager import PassManager
2627
from executorch.exir.passes import ( # noqa
2728
MemoryPlanningPass,
@@ -593,3 +594,65 @@ def count_planned_inputs(
593594
num_placeholders,
594595
5,
595596
)
597+
598+
def test_placeholder_lifetime(self) -> None:
599+
class TestModel(torch.nn.Module):
600+
def __init__(self) -> None:
601+
super().__init__()
602+
self.linear = torch.nn.Linear(5, 5)
603+
604+
def forward(self, a, b, x):
605+
a = a + b
606+
b = a + b
607+
y = self.linear(x)
608+
return a, b, y
609+
610+
model = TestModel()
611+
example_inputs = (torch.rand(1, 6, 2), torch.rand(1, 6, 2), torch.randn(5, 5))
612+
exported_model = torch.export.export(model, example_inputs)
613+
edge = to_edge(exported_model)
614+
615+
class TestPass(ExportPass):
616+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
617+
permute_dims = [1, 0, 2]
618+
for node in graph_module.graph.nodes:
619+
if node.op == "placeholder" and str(node) == "a":
620+
inverse_dims = [
621+
permute_dims.index(x) for x in range(len(permute_dims))
622+
]
623+
624+
with graph_module.graph.inserting_after(node):
625+
permute = graph_module.graph.call_function(
626+
exir_ops.edge.aten.permute_copy.default,
627+
args=(node, inverse_dims),
628+
)
629+
permute.meta = node.meta.copy()
630+
node.meta["val"] = node.meta["val"].permute(permute_dims)
631+
node.replace_all_uses_with(
632+
permute, lambda x, permute=permute: x is not permute
633+
)
634+
break
635+
return PassResult(graph_module, True)
636+
637+
edge = edge.transform([TestPass()])
638+
et = edge.to_executorch()
639+
et_program = et.executorch_program
640+
inputs = et_program.execution_plan[0].inputs
641+
self.assertNotEqual(
642+
et_program.execution_plan[0] # pyre-ignore
643+
.values[inputs[0]]
644+
.val.allocation_info.memory_offset_low,
645+
et_program.execution_plan[0] # pyre-ignore
646+
.values[inputs[1]]
647+
.val.allocation_info.memory_offset_low,
648+
)
649+
650+
constants = 0
651+
for node in et.exported_program().graph_module.graph.nodes:
652+
if node.op == "placeholder" and node.meta.get("spec"):
653+
meta_spec = node.meta["spec"]
654+
if meta_spec.const is True:
655+
constants += 1
656+
self.assertIsNone(node.meta["spec"].mem_offset)
657+
self.assertIsNone(node.meta["spec"].mem_id)
658+
self.assertEqual(constants, 2)

0 commit comments

Comments
 (0)