diff --git a/backends/transforms/rank_0_to_rank_1.py b/backends/transforms/rank_0_to_rank_1.py new file mode 100644 index 00000000000..81159efcc24 --- /dev/null +++ b/backends/transforms/rank_0_to_rank_1.py @@ -0,0 +1,18 @@ +import torch +from executorch.exir.pass_base import ExportPass, PassResult + + +class Rank0ToRank1Pass(ExportPass): + """ + Replace Rank-0 Tensor to Rank-1 Tensor for all the inputs. + """ + + def __init__(self) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + if node.op == "placeholder" and node.meta["val"].shape == (): + node.meta["val"] = node.meta["val"].reshape(1, 1) + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index 09ef0f59c59..c532798546d 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -187,6 +187,20 @@ def define_common_targets(): ], ) + runtime.python_library( + name = "rank_0_to_rank_1", + srcs = [ + "rank_0_to_rank_1.py", + ], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + ], + ) + runtime.python_test( name = "test_duplicate_dynamic_quant_chain", srcs = [ @@ -200,3 +214,16 @@ def define_common_targets(): "//executorch/exir:lib", ], ) + + + runtime.python_test( + name = "test_rank_0_to_rank_1", + srcs = [ + "test/test_rank_0_to_rank_1.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + ":rank_0_to_rank_1", + ], + ) diff --git a/backends/transforms/test/test_rank_0_to_rank_1.py b/backends/transforms/test/test_rank_0_to_rank_1.py new file mode 100644 index 00000000000..50c6357fb67 --- /dev/null +++ b/backends/transforms/test/test_rank_0_to_rank_1.py @@ -0,0 +1,32 @@ +import unittest + +import torch +from executorch.backends.transforms.rank_0_to_rank_1 import Rank0ToRank1Pass +from executorch.exir import to_edge + + +class TestRank0ToRank1Pass(unittest.TestCase): + def test_pass( + self, + ): + class Model(torch.nn.Module): + def forward(self, x, y): + return x + y + + model = Model() + model.eval() + + example_inputs = (torch.tensor(1.0), torch.tensor(2.0)) + aten = torch.export.export(model, example_inputs) + + # Check that the input rank is 0 + for node in aten.graph.nodes: + if node.op == "placeholder": + self.assertTrue(node.meta["val"].shape == ()) + + edge = to_edge(aten).transform([Rank0ToRank1Pass()]) + + # Check that the input rank is 1 + for node in edge.exported_program().graph.nodes: + if node.op == "placeholder": + self.assertTrue(node.meta["val"].shape == (1, 1))