1
1
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2
- # All rights reserved.
3
2
#
4
3
# This source code is licensed under the BSD-style license found in the
5
4
# LICENSE file in the root directory of this source tree.
@@ -36,8 +35,8 @@ def forward(self, x: torch.Tensor):
36
35
return torch .full ((2 , 2 , 3 , 3 ), 4.5 , dtype = torch .float32 ) + x
37
36
38
37
class AddVariableFull (torch .nn .Module ):
39
- sizes = [
40
- (5 ),
38
+ sizes : list [ tuple [ int , ...]] = [
39
+ (5 , ),
41
40
(5 , 5 ),
42
41
(5 , 5 , 5 ),
43
42
(1 , 5 , 5 , 5 ),
@@ -48,6 +47,21 @@ def forward(self, x: torch.Tensor, y):
48
47
# Input + a full with the shape from the input and a given value 'y'.
49
48
return x + torch .full (x .shape , y )
50
49
50
class FullLike(torch.nn.Module):
    """Since full_like is replaced with full, we only need to test on reference model, not FVP."""

    # (input tensor, fill value) pairs: float32 and int32 inputs, each
    # combined with a float and an int fill value.
    test_parameters = [
        ((torch.randn(2, 2, 2, 2) * 50, 3.2),),
        ((torch.randn(2, 2, 2, 2) * 50, 3),),
        (((torch.randn(2, 2, 2, 2) * 50).to(torch.int32), 3.2),),
        (((torch.randn(2, 2, 2, 2) * 50).to(torch.int32), 3),),
    ]

    def forward(self, input_tensor: torch.Tensor, value):
        # Our backend can't handle tensors without users, which input_tensor
        # doesn't have when the full_like is converted to a full. Therefore
        # involve it in the output.
        filled = torch.full_like(input_tensor, value)
        return input_tensor + filled
64
+
51
65
def _test_full_tosa_MI_pipeline (
52
66
self ,
53
67
module : torch .nn .Module ,
@@ -63,9 +77,7 @@ def _test_full_tosa_MI_pipeline(
63
77
compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" ),
64
78
)
65
79
.export ()
66
- .check_count ({"torch.ops.aten.full.default" : 1 })
67
- .to_edge ()
68
- .partition ()
80
+ .to_edge_transform_and_lower ()
69
81
.check_not (["executorch_exir_dialects_edge__ops_aten_full_default" ])
70
82
.check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
71
83
.to_executorch ()
@@ -85,9 +97,7 @@ def _test_full_tosa_BI_pipeline(
85
97
)
86
98
.quantize ()
87
99
.export ()
88
- .check_count ({"torch.ops.aten.full.default" : 1 })
89
- .to_edge ()
90
- .partition ()
100
+ .to_edge_transform_and_lower ()
91
101
.check_not (["executorch_exir_dialects_edge__ops_aten_full_default" ])
92
102
.check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
93
103
.to_executorch ()
@@ -101,9 +111,7 @@ def _test_full_tosa_ethos_pipeline(
101
111
ArmTester (module , example_inputs = test_data , compile_spec = compile_spec )
102
112
.quantize ()
103
113
.export ()
104
- .check_count ({"torch.ops.aten.full.default" : 1 })
105
- .to_edge ()
106
- .partition ()
114
+ .to_edge_transform_and_lower ()
107
115
.check_not (["executorch_exir_dialects_edge__ops_aten_full_default" ])
108
116
.check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
109
117
.to_executorch ()
@@ -129,6 +137,10 @@ def test_const_full_tosa_MI(self):
129
137
_input = torch .rand ((2 , 2 , 3 , 3 )) * 10
130
138
self ._test_full_tosa_MI_pipeline (self .AddConstFull (), (_input ,))
131
139
140
@parameterized.expand(FullLike.test_parameters)
def test_full_like_tosa_MI(self, test_tensor: Tuple):
    # full_like is rewritten to full during lowering, so the TOSA MI
    # reference pipeline is sufficient coverage (no FVP run needed).
    module = self.FullLike()
    self._test_full_tosa_MI_pipeline(module, test_tensor)
143
+
132
144
def test_const_full_nhwc_tosa_BI(self):
    # Random NCHW tensor scaled into roughly [0, 10) before the
    # quantized (BI) reference-model pipeline.
    example = 10 * torch.rand((2, 2, 3, 3))
    self._test_full_tosa_BI_pipeline(self.AddConstFull(), (example,))
@@ -143,6 +155,10 @@ def test_full_tosa_MI(self, test_tensor: Tuple):
143
155
def test_full_tosa_BI(self, test_tensor: Tuple):
    # Quantized (BI) reference-model run of the variable-shape full test.
    module = self.AddVariableFull()
    self._test_full_tosa_BI_pipeline(module, test_tensor)
145
157
158
@parameterized.expand(FullLike.test_parameters)
def test_full_like_tosa_BI(self, test_tensor: Tuple):
    # Quantized (BI) reference-model run; FVP coverage is intentionally
    # skipped since full_like is lowered to full (see FullLike docstring).
    module = self.FullLike()
    self._test_full_tosa_BI_pipeline(module, test_tensor)
161
+
146
162
@parameterized .expand (AddVariableFull .test_parameters )
147
163
@pytest .mark .corstone_fvp
148
164
def test_full_u55_BI (self , test_tensor : Tuple ):
0 commit comments