|
14 | 14 |
|
15 | 15 | import torch
|
16 | 16 | from executorch.exir import ExecutorchBackendConfig, to_edge
|
| 17 | +from executorch.exir.dialects._ops import ops as exir_ops |
17 | 18 | from executorch.exir.memory_planning import (
|
18 | 19 | filter_nodes,
|
19 | 20 | get_node_tensor_specs,
|
20 | 21 | greedy,
|
21 | 22 | naive,
|
22 | 23 | Verifier,
|
23 | 24 | )
|
24 |
| -from executorch.exir.pass_base import PassResult |
| 25 | +from executorch.exir.pass_base import ExportPass, PassResult |
25 | 26 | from executorch.exir.pass_manager import PassManager
|
26 | 27 | from executorch.exir.passes import ( # noqa
|
27 | 28 | MemoryPlanningPass,
|
@@ -591,3 +592,49 @@ def count_planned_inputs(
|
591 | 592 | num_placeholders,
|
592 | 593 | 5,
|
593 | 594 | )
|
| 595 | + |
| 596 | + def test_placeholder_lifetime(self) -> None: |
| 597 | + class TestModel(torch.nn.Module): |
| 598 | + def forward(self, a, b): |
| 599 | + a = a + b |
| 600 | + b = a + b |
| 601 | + return a, b |
| 602 | + |
| 603 | + model = TestModel() |
| 604 | + example_inputs = (torch.rand(1, 6000, 2), torch.rand(1, 6000, 2)) |
| 605 | + exported_model = torch.export.export(model, example_inputs) |
| 606 | + edge = to_edge(exported_model) |
| 607 | + |
| 608 | + class TestPass(ExportPass): |
| 609 | + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: |
| 610 | + permute_dims = [1, 0, 2] |
| 611 | + for node in graph_module.graph.nodes: |
| 612 | + if node.op == "placeholder": |
| 613 | + inverse_dims = [ |
| 614 | + permute_dims.index(x) for x in range(len(permute_dims)) |
| 615 | + ] |
| 616 | + |
| 617 | + with graph_module.graph.inserting_after(node): |
| 618 | + permute = graph_module.graph.call_function( |
| 619 | + exir_ops.edge.aten.permute_copy.default, |
| 620 | + args=(node, inverse_dims), |
| 621 | + ) |
| 622 | + permute.meta = node.meta.copy() |
| 623 | + node.meta["val"] = node.meta["val"].permute(permute_dims) |
| 624 | + node.replace_all_uses_with( |
| 625 | + permute, lambda x, permute=permute: x is not permute |
| 626 | + ) |
| 627 | + break |
| 628 | + return PassResult(graph_module, True) |
| 629 | + |
| 630 | + edge = edge.transform([TestPass()]) |
| 631 | + et_program = edge.to_executorch().executorch_program |
| 632 | + inputs = et_program.execution_plan[0].inputs |
| 633 | + self.assertNotEqual( |
| 634 | + et_program.execution_plan[0] |
| 635 | + .values[inputs[0]] |
| 636 | + .val.allocation_info.memory_offset_low, |
| 637 | + et_program.execution_plan[0] |
| 638 | + .values[inputs[1]] |
| 639 | + .val.allocation_info.memory_offset_low, |
| 640 | + ) # pyre-ignore |
0 commit comments