File tree Expand file tree Collapse file tree 2 files changed +31
-0
lines changed Expand file tree Collapse file tree 2 files changed +31
-0
lines changed Original file line number Diff line number Diff line change
1
+ import torch
2
+ from executorch .exir .pass_base import ExportPass , PassResult
3
+
4
+ class Rank0ToRank1Pass (ExportPass ):
5
+ """
6
+ Replace Rank-0 Tensor to Rank-1 Tensor for all the inputs.
7
+ """
8
+
9
+ def __init__ (self ) -> None :
10
+ super ().__init__ ()
11
+
12
+ def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
13
+ for node in graph_module .graph .nodes :
14
+ if node .op == "placeholder" :
15
+ node .meta ["val" ] = node .meta ["val" ].reshape (1 , 1 )
16
+ graph_module .recompile ()
17
+ return PassResult (graph_module , True )
Original file line number Diff line number Diff line change @@ -187,6 +187,20 @@ def define_common_targets():
187
187
],
188
188
)
189
189
190
+ runtime .python_library (
191
+ name = "rank_0_to_rank_1" ,
192
+ srcs = [
193
+ "rank_0_to_rank_1.py" ,
194
+ ],
195
+ visibility = [
196
+ "//executorch/backends/..." ,
197
+ ],
198
+ deps = [
199
+ "//caffe2:torch" ,
200
+ "//executorch/exir:pass_base" ,
201
+ ],
202
+ )
203
+
190
204
runtime .python_test (
191
205
name = "test_duplicate_dynamic_quant_chain" ,
192
206
srcs = [
You can’t perform that action at this time.
0 commit comments