@@ -81,9 +81,15 @@ def forward(self, x):
81
81
z = torch .add (y , z )
82
82
return z
83
83
84
- def _test_conv1d (self , module , inputs , conv_count , quantized = False ):
84
+ def _test_conv1d (
85
+ self , module , inputs , conv_count , quantized = False , dynamic_shape = None
86
+ ):
85
87
(
86
- (Tester (module , inputs ).quantize () if quantized else Tester (module , inputs ))
88
+ (
89
+ Tester (module , inputs , dynamic_shape ).quantize ()
90
+ if quantized
91
+ else Tester (module , inputs )
92
+ )
87
93
.export ()
88
94
.check_count ({"torch.ops.aten.convolution.default" : conv_count })
89
95
.to_edge ()
@@ -101,21 +107,41 @@ def _test_conv1d(self, module, inputs, conv_count, quantized=False):
101
107
)
102
108
103
109
def test_fp16_conv1d (self ):
104
- inputs = (torch .randn (1 , 2 , 4 ).to (torch .float16 ),)
105
- self ._test_conv1d (self .Conv1d (dtype = torch .float16 ), inputs , conv_count = 1 )
110
+ inputs = (torch .randn (2 , 2 , 4 ).to (torch .float16 ),)
111
+ dynamic_shapes = ({0 : torch .export .Dim ("batch" , min = 2 , max = 10 )},)
112
+ self ._test_conv1d (
113
+ self .Conv1d (dtype = torch .float16 ),
114
+ inputs ,
115
+ conv_count = 1 ,
116
+ dynamic_shape = dynamic_shapes ,
117
+ )
106
118
107
119
def test_fp32_conv1d (self ):
108
- inputs = (torch .randn (1 , 2 , 4 ),)
109
- self ._test_conv1d (self .Conv1d (), inputs , 1 )
120
+ inputs = (torch .randn (2 , 2 , 4 ),)
121
+ dynamic_shapes = ({0 : torch .export .Dim ("batch" , min = 2 , max = 10 )},)
122
+ self ._test_conv1d (self .Conv1d (), inputs , 1 , dynamic_shape = dynamic_shapes )
110
123
111
124
def test_fp32_conv1d_batchnorm_seq (self ):
112
- inputs = (torch .randn (1 , 2 , 4 ),)
113
- self ._test_conv1d (self .Conv1dBatchNormSequential (), inputs , 2 )
125
+ inputs = (torch .randn (2 , 2 , 4 ),)
126
+ dynamic_shapes = ({0 : torch .export .Dim ("batch" , min = 2 , max = 10 )},)
127
+ self ._test_conv1d (
128
+ self .Conv1dBatchNormSequential (), inputs , 2 , dynamic_shape = dynamic_shapes
129
+ )
114
130
115
131
def test_qs8_conv1d (self ):
116
- inputs = (torch .randn (1 , 2 , 4 ),)
117
- self ._test_conv1d (self .Conv1d (), inputs , 1 , quantized = True )
132
+ inputs = (torch .randn (2 , 2 , 4 ),)
133
+ dynamic_shapes = ({0 : torch .export .Dim ("batch" , min = 2 , max = 10 )},)
134
+ self ._test_conv1d (
135
+ self .Conv1d (), inputs , 1 , quantized = True , dynamic_shape = dynamic_shapes
136
+ )
118
137
119
138
def test_qs8_conv1d_batchnorm_seq (self ):
120
- inputs = (torch .randn (1 , 2 , 4 ),)
121
- self ._test_conv1d (self .Conv1dBatchNormSequential (), inputs , 2 , quantized = True )
139
+ inputs = (torch .randn (2 , 2 , 4 ),)
140
+ dynamic_shapes = ({0 : torch .export .Dim ("batch" , min = 2 , max = 10 )},)
141
+ self ._test_conv1d (
142
+ self .Conv1dBatchNormSequential (),
143
+ inputs ,
144
+ 2 ,
145
+ quantized = True ,
146
+ dynamic_shape = dynamic_shapes ,
147
+ )
0 commit comments