12
12
import torch
13
13
from executorch .backends .arm .test import common
14
14
from executorch .backends .arm .test .tester .arm_tester import ArmTester
15
+ from parameterized import parameterized
15
16
16
17
logger = logging .getLogger (__name__ )
17
18
logger .setLevel (logging .INFO )
@@ -126,6 +127,32 @@ def forward(self, x):
126
127
return x
127
128
128
129
130
+ class ComboConvRelu6 (torch .nn .Module ):
131
+ edge_op_list = [
132
+ "executorch_exir_dialects_edge__ops_aten_convolution_default" ,
133
+ "executorch_exir_dialects_edge__ops_aten_hardtanh_default" ,
134
+ ]
135
+
136
+ test_data = [
137
+ (20 * torch .randn (1 , 3 , 256 , 256 ),),
138
+ (5 * torch .randn (1 , 3 , 256 , 256 ),),
139
+ (torch .randn (1 , 3 , 256 , 256 ),),
140
+ (- 5 * torch .randn (1 , 3 , 256 , 256 ),),
141
+ ]
142
+
143
+ def __init__ (self ):
144
+ super ().__init__ ()
145
+ self .conv2d = torch .nn .Conv2d (
146
+ in_channels = 3 , out_channels = 3 , kernel_size = 3 , stride = 1 , groups = 1
147
+ )
148
+ self .relu6 = torch .nn .ReLU6 ()
149
+
150
+ def forward (self , x ):
151
+ x = self .conv2d (x )
152
+ x = self .relu6 (x )
153
+ return x
154
+
155
+
129
156
class TestConvCombos (unittest .TestCase ):
130
157
def _test_conv_combo_tosa_MI_pipeline (
131
158
self , module : torch .nn .Module , test_data : Tuple [torch .Tensor ]
@@ -222,15 +249,9 @@ def test_conv_batchnorm_relu_tosa_MI(self):
222
249
model = ComboConvBatchnormRelu ()
223
250
self ._test_conv_combo_tosa_MI_pipeline (model , model .get_inputs ())
224
251
225
- # TODO(MLETORCH-85): Investigate numerical issue. This diff is present in legacy
226
- # testcase as well (and also not tested). For now, just increase the
227
- # tolerance, such that we don't skip the test entirely (i.e. we maintain
228
- # functionality).
229
252
def test_conv_batchnorm_relu_tosa_BI (self ):
230
253
model = ComboConvBatchnormRelu ()
231
- self ._test_conv_combo_tosa_BI_pipeline (
232
- model , model .get_inputs (), atol = 1.0 , rtol = 1.0
233
- )
254
+ self ._test_conv_combo_tosa_BI_pipeline (model , model .get_inputs ())
234
255
235
256
@unittest .skipIf (
236
257
not common .VELA_INSTALLED ,
@@ -240,21 +261,41 @@ def test_conv_batchnorm_relu_u55_BI(self):
240
261
model = ComboConvBatchnormRelu ()
241
262
self ._test_conv_combo_u55_BI_pipeline (model , model .get_inputs ())
242
263
264
+ ##################
265
+ ## Conv + ReLU6 ##
266
+ ##################
267
+ @parameterized .expand (ComboConvRelu6 .test_data )
268
+ def test_conv_relu6_tosa_MI (self , test_data : torch .Tensor ):
269
+ model = ComboConvRelu6 ()
270
+ test_data = (test_data ,)
271
+ self ._test_conv_combo_tosa_MI_pipeline (model , test_data )
272
+
273
+ @parameterized .expand (ComboConvRelu6 .test_data )
274
+ def test_conv_relu6_tosa_BI (self , test_data : torch .Tensor ):
275
+ model = ComboConvRelu6 ()
276
+ test_data = (test_data ,)
277
+ self ._test_conv_combo_tosa_BI_pipeline (model , test_data )
278
+
279
+ @parameterized .expand (ComboConvRelu6 .test_data )
280
+ @unittest .skipIf (
281
+ not common .VELA_INSTALLED ,
282
+ "There is no point in running U55 tests if the Vela tool is not installed" ,
283
+ )
284
+ def test_conv_relu6_u55_BI (self , test_data : torch .Tensor ):
285
+ model = ComboConvRelu6 ()
286
+ test_data = (test_data ,)
287
+ self ._test_conv_combo_u55_BI_pipeline (model , test_data )
288
+
243
289
###############################
244
290
## Block bottleneck residual ##
245
291
###############################
246
292
def test_block_bottleneck_residual_tosa_MI (self ):
247
293
model = ComboBlockBottleneckResidual ()
248
294
self ._test_conv_combo_tosa_MI_pipeline (model , model .get_inputs ())
249
295
250
- # TODO(MLETORCH-85): Investigate numerical issue. This diff was present in legacy
251
- # testcase as well. For now, just increase the tolerance, such that
252
- # we don't skip the test entirely (i.e. we maintain functionality).
253
296
def test_block_bottleneck_residual_tosa_BI (self ):
254
297
model = ComboBlockBottleneckResidual ()
255
- self ._test_conv_combo_tosa_BI_pipeline (
256
- model , model .get_inputs (), atol = 1.0 , rtol = 1.0
257
- )
298
+ self ._test_conv_combo_tosa_BI_pipeline (model , model .get_inputs ())
258
299
259
300
@unittest .skipIf (
260
301
not common .VELA_INSTALLED ,
0 commit comments