Skip to content

Commit f484009

Browse files
committed
linter, thought I ran it :/
1 parent 7313516 commit f484009

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

test/test_ops.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,19 +1215,37 @@ def test_deform_conv2d_opcheck(dtype, device, requires_grad):
12151215
out_h = (height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
12161216
out_w = (width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1
12171217
x = torch.randn(batch_size, channels_in, height, width, dtype=dtype, device=device, requires_grad=requires_grad)
1218-
offset = torch.randn(batch_size, 2 * kernel_size[0] * kernel_size[1], out_h, out_w,
1219-
dtype=dtype, device=device, requires_grad=requires_grad)
1220-
weight = torch.randn(out_channels, channels_in // groups, kernel_size[0], kernel_size[1],
1221-
dtype=dtype, device=device, requires_grad=requires_grad)
1222-
bias = torch.randn(out_channels, dtype=dtype, device=device, requires_grad=requires_grad)
1223-
use_mask = True
1224-
mask = torch.sigmoid(torch.randn(
1218+
offset = torch.randn(
12251219
batch_size,
1226-
kernel_size[0] * kernel_size[1],
1220+
2 * kernel_size[0] * kernel_size[1],
12271221
out_h,
12281222
out_w,
1229-
dtype=dtype, device=device, requires_grad=requires_grad
1230-
))
1223+
dtype=dtype,
1224+
device=device,
1225+
requires_grad=requires_grad,
1226+
)
1227+
weight = torch.randn(
1228+
out_channels,
1229+
channels_in // groups,
1230+
kernel_size[0],
1231+
kernel_size[1],
1232+
dtype=dtype,
1233+
device=device,
1234+
requires_grad=requires_grad,
1235+
)
1236+
bias = torch.randn(out_channels, dtype=dtype, device=device, requires_grad=requires_grad)
1237+
use_mask = True
1238+
mask = torch.sigmoid(
1239+
torch.randn(
1240+
batch_size,
1241+
kernel_size[0] * kernel_size[1],
1242+
out_h,
1243+
out_w,
1244+
dtype=dtype,
1245+
device=device,
1246+
requires_grad=requires_grad,
1247+
)
1248+
)
12311249
kwargs = {
12321250
"offset": offset,
12331251
"weight": weight,
@@ -1246,7 +1264,6 @@ def test_deform_conv2d_opcheck(dtype, device, requires_grad):
12461264
optests.opcheck(torch.ops.torchvision.deform_conv2d, args=(x,), kwargs=kwargs)
12471265

12481266

1249-
12501267
class TestFrozenBNT:
12511268
def test_frozenbatchnorm2d_repr(self):
12521269
num_features = 32

torchvision/csrc/ops/mps/deform_conv2d_kernel.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,4 +146,4 @@
146146
}
147147

148148
} // namespace ops
149-
} // namespace vision
149+
} // namespace vision

0 commit comments

Comments
 (0)