@@ -1215,19 +1215,37 @@ def test_deform_conv2d_opcheck(dtype, device, requires_grad):
1215
1215
out_h = (height + 2 * padding [0 ] - dilation [0 ] * (kernel_size [0 ] - 1 ) - 1 ) // stride [0 ] + 1
1216
1216
out_w = (width + 2 * padding [1 ] - dilation [1 ] * (kernel_size [1 ] - 1 ) - 1 ) // stride [1 ] + 1
1217
1217
x = torch .randn (batch_size , channels_in , height , width , dtype = dtype , device = device , requires_grad = requires_grad )
1218
- offset = torch .randn (batch_size , 2 * kernel_size [0 ] * kernel_size [1 ], out_h , out_w ,
1219
- dtype = dtype , device = device , requires_grad = requires_grad )
1220
- weight = torch .randn (out_channels , channels_in // groups , kernel_size [0 ], kernel_size [1 ],
1221
- dtype = dtype , device = device , requires_grad = requires_grad )
1222
- bias = torch .randn (out_channels , dtype = dtype , device = device , requires_grad = requires_grad )
1223
- use_mask = True
1224
- mask = torch .sigmoid (torch .randn (
1218
+ offset = torch .randn (
1225
1219
batch_size ,
1226
- kernel_size [0 ] * kernel_size [1 ],
1220
+ 2 * kernel_size [0 ] * kernel_size [1 ],
1227
1221
out_h ,
1228
1222
out_w ,
1229
- dtype = dtype , device = device , requires_grad = requires_grad
1230
- ))
1223
+ dtype = dtype ,
1224
+ device = device ,
1225
+ requires_grad = requires_grad ,
1226
+ )
1227
+ weight = torch .randn (
1228
+ out_channels ,
1229
+ channels_in // groups ,
1230
+ kernel_size [0 ],
1231
+ kernel_size [1 ],
1232
+ dtype = dtype ,
1233
+ device = device ,
1234
+ requires_grad = requires_grad ,
1235
+ )
1236
+ bias = torch .randn (out_channels , dtype = dtype , device = device , requires_grad = requires_grad )
1237
+ use_mask = True
1238
+ mask = torch .sigmoid (
1239
+ torch .randn (
1240
+ batch_size ,
1241
+ kernel_size [0 ] * kernel_size [1 ],
1242
+ out_h ,
1243
+ out_w ,
1244
+ dtype = dtype ,
1245
+ device = device ,
1246
+ requires_grad = requires_grad ,
1247
+ )
1248
+ )
1231
1249
kwargs = {
1232
1250
"offset" : offset ,
1233
1251
"weight" : weight ,
@@ -1246,7 +1264,6 @@ def test_deform_conv2d_opcheck(dtype, device, requires_grad):
1246
1264
optests .opcheck (torch .ops .torchvision .deform_conv2d , args = (x ,), kwargs = kwargs )
1247
1265
1248
1266
1249
-
1250
1267
class TestFrozenBNT :
1251
1268
def test_frozenbatchnorm2d_repr (self ):
1252
1269
num_features = 32
0 commit comments