Skip to content

Commit 7fc0958

Browse files
cccclaifacebook-github-bot
authored andcommitted
Add a pass to convert rank-0 tensor to rank-1 tensor
Differential Revision: D69281867
1 parent b1d76c9 commit 7fc0958

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import torch
2+
from executorch.exir.pass_base import ExportPass, PassResult
3+
4+
class Rank0ToRank1Pass(ExportPass):
5+
"""
6+
Replace Rank-0 Tensor to Rank-1 Tensor for all the inputs.
7+
"""
8+
9+
def __init__(self) -> None:
10+
super().__init__()
11+
12+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
13+
for node in graph_module.graph.nodes:
14+
if node.op == "placeholder":
15+
node.meta["val"] = node.meta["val"].reshape(1, 1)
16+
graph_module.recompile()
17+
return PassResult(graph_module, True)

backends/transforms/targets.bzl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,20 @@ def define_common_targets():
187187
],
188188
)
189189

190+
runtime.python_library(
191+
name = "rank_0_to_rank_1",
192+
srcs = [
193+
"rank_0_to_rank_1.py",
194+
],
195+
visibility = [
196+
"//executorch/backends/...",
197+
],
198+
deps = [
199+
"//caffe2:torch",
200+
"//executorch/exir:pass_base",
201+
],
202+
)
203+
190204
runtime.python_test(
191205
name = "test_duplicate_dynamic_quant_chain",
192206
srcs = [

0 commit comments

Comments
 (0)