diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py index 76e994abdbf..a90abada897 100644 --- a/exir/tests/test_memory_format_ops_pass.py +++ b/exir/tests/test_memory_format_ops_pass.py @@ -109,6 +109,7 @@ def test_op_dim_order_update(self) -> None: ) def test_op_dim_order_propagation(self) -> None: + print("test_op_dim_order_propagation: unambiguous path") MemoryFormatOpsPassTestUtils.memory_format_test_runner( self, MemoryFormatTestSet( @@ -126,6 +127,24 @@ def test_op_dim_order_propagation(self) -> None: ), ) + print("test_op_dim_order_propagation: ambiguous path") + MemoryFormatOpsPassTestUtils.memory_format_test_runner( + self, + MemoryFormatTestSet( + module=PropagateToCopyChannalsLastModule().eval(), + op=torch.ops.aten._to_copy.default, + sample_input=( + torch.rand_like( + torch.zeros([2, 1, 2, 2]), + dtype=torch.float32, + memory_format=torch.contiguous_format, + ), + ), + target_memory_format=torch.channels_last, + _load_for_executorch_from_buffer=_load_for_executorch_from_buffer, + ), + ) + # Only test dim order replacement result in lean mode test. # This test is irrelevant with operator mode. def test_dim_order_replacement(self) -> None: diff --git a/exir/tests/test_memory_format_ops_pass_utils.py b/exir/tests/test_memory_format_ops_pass_utils.py index 8bf810e847e..10b8e32e7dc 100644 --- a/exir/tests/test_memory_format_ops_pass_utils.py +++ b/exir/tests/test_memory_format_ops_pass_utils.py @@ -20,8 +20,13 @@ is_channel_last_dim_order, is_contiguous_dim_order, ) +from executorch.exir.pass_base import ExportPass + +from exir.passes.memory_format_ops_pass import MemoryFormatOpsPass from torch.export import export + +from torch.fx.passes.infra.pass_manager import PassManager from torch.testing import FileCheck from torch.utils._pytree import tree_flatten @@ -99,6 +104,50 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return t1 * t2 +def assert_unambiguous_dim_order(gm): + # This is just an example, you can add your own pass or passes. + class ExampleNOPPass(ExportPass): + """ + Does nothing! + """ + + def call_operator(self, op, args, kwargs, meta): + return super().call_operator( + op, + args, + kwargs, + meta, + ) + + # This is an example of how one can detect ambiguous dim_order anywhere in the graph. + # You can be surgical and only detect it in the nodes you are interested in or something else. + def detect_ambiguity(gm): + """ + Check every node's output tensor dim_order and raise if it is ambiguous for a list of formats. + """ + for node in gm.graph.nodes: + if node.op == "call_function": + tensor = node.meta["val"] + # Let's make sure dim_order is not ambiguous, raise otherwise. + # This is raising because we can't do anything about it. + # The right course of follow up action is to ask user to try with a different example input. + print(f"node: {node}, shape: {tensor.shape}, ", end="") + + try: + dim_order = tensor.dim_order( + ambiguity_check=[torch.contiguous_format, torch.channels_last] + ) + print(f"dim_order: {dim_order}") + except Exception as e: + print("") + raise RuntimeError(e) + + # any pass or passes, just using MemoryFormatOpsPass as an example + dim_order_pass_manager = PassManager(passes=[ExampleNOPPass()]) + dim_order_pass_manager.add_checks(detect_ambiguity) + dim_order_pass_manager(gm) + + class MemoryFormatOpsPassTestUtils: @staticmethod def memory_format_test_runner( @@ -121,6 +170,9 @@ def memory_format_test_runner( before, compile_config=EdgeCompileConfig(_skip_dim_order=False) ) + # Just as an example + assert_unambiguous_dim_order(epm.exported_program().graph_module) + # check memory format ops, if needed if test_set.op_level_check: aten_op_str, edge_op_str = MemoryFormatOps2Str[test_set.op]