Skip to content

Commit 73508ba

Browse files
author
Pedro Freire
committed
Update DeformConv to be more consistent w/ Conv2d
* Rename some variables and arguments to match Conv2d; * add optional bias; * add weight, offset and bias as module parameters; * remove the n_parallel_imgs parameter; * fix __repr__; * etc. Initialization of weight and bias is the same as in Conv2d, and zero-initialization of offsets is the same as in the paper. This also includes some other small, unrelated fixes/improvements.
1 parent 163c99e commit 73508ba

File tree

9 files changed

+284
-228
lines changed

9 files changed

+284
-228
lines changed

test/test_ops.py

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from __future__ import division
22
import math
3-
from typing import Tuple
43
import unittest
54

65
import numpy as np
76

87
import torch
8+
from torch import Tensor
99
from torch.autograd import gradcheck
10+
from torch.jit.annotations import Tuple
1011
from torch.nn.modules.utils import _pair
11-
from torch import Tensor
1212
from torchvision import ops
1313

1414

@@ -374,44 +374,45 @@ def test_nms_cuda(self):
374374

375375

376376
class DeformConvTester(OpTester, unittest.TestCase):
377-
def expected_fn(self, x, offsets, weights, *args, stride=1, pad=0, dilation=1):
377+
def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
378378
stride_h, stride_w = _pair(stride)
379-
pad_h, pad_w = _pair(pad)
379+
pad_h, pad_w = _pair(padding)
380380
dil_h, dil_w = _pair(dilation)
381-
weights_h, weights_w = weights.shape[-2:]
381+
weight_h, weight_w = weight.shape[-2:]
382382

383383
n_batches, n_in_channels, in_h, in_w = x.shape
384-
n_out_channels = weights.shape[0]
384+
n_out_channels = weight.shape[0]
385385

386-
out_h = (in_h + 2 * pad_h - (dil_h * (weights_h - 1) + 1)) // stride_h + 1
387-
out_w = (in_w + 2 * pad_w - (dil_w * (weights_w - 1) + 1)) // stride_w + 1
386+
out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
387+
out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
388388

389-
n_offset_grps = offsets.shape[1] // (2 * weights_h * weights_w)
389+
n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w)
390390
in_c_per_offset_grp = n_in_channels // n_offset_grps
391391

392-
n_weight_grps = n_in_channels // weights.shape[1]
393-
in_c_per_weight_grp = weights.shape[1]
392+
n_weight_grps = n_in_channels // weight.shape[1]
393+
in_c_per_weight_grp = weight.shape[1]
394394
out_c_per_weight_grp = n_out_channels // n_weight_grps
395395

396396
out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
397397
for b in range(n_batches):
398398
for c_out in range(n_out_channels):
399399
for i in range(out_h):
400400
for j in range(out_w):
401-
for di in range(weights_h):
402-
for dj in range(weights_w):
401+
for di in range(weight_h):
402+
for dj in range(weight_w):
403403
for c in range(in_c_per_weight_grp):
404404
weight_grp = c_out // out_c_per_weight_grp
405405
c_in = weight_grp * in_c_per_weight_grp + c
406406

407407
offset_grp = c_in // in_c_per_offset_grp
408-
offset_idx = 2 * (offset_grp * (weights_h * weights_w) + di * weights_w + dj)
408+
offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj)
409409

410-
pi = stride_h * i - pad_h + dil_h * di + offsets[b, offset_idx, i, j]
411-
pj = stride_w * j - pad_w + dil_w * dj + offsets[b, offset_idx + 1, i, j]
410+
pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
411+
pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]
412412

413-
out[b, c_out, i, j] += (weights[c_out, c, di, dj] *
413+
out[b, c_out, i, j] += (weight[c_out, c, di, dj] *
414414
bilinear_interpolate(x[b, c_in, :, :], pi, pj))
415+
out += bias.view(1, n_out_channels, 1, 1)
415416
return out
416417

417418
def get_fn_args(self, device, contiguous):
@@ -442,36 +443,50 @@ def get_fn_args(self, device, contiguous):
442443
weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
443444
device=device, dtype=self.dtype, requires_grad=True)
444445

446+
bias = torch.randn(n_out_channels, device=device, dtype=self.dtype, requires_grad=True)
447+
445448
if not contiguous:
446449
x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
447450
offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
448451
weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
449452

450-
return x, offset, weight, stride, pad, dilation
453+
return x, weight, offset, bias, stride, pad, dilation
451454

452455
def _test_forward(self, device, contiguous):
453-
x, offset, weight, stride, pad, dilation = self.get_fn_args(device, contiguous)
456+
x, _, _, _, stride, padding, dilation = self.get_fn_args(device, contiguous)
457+
in_channels = 6
458+
out_channels = 2
459+
kernel_size = (3, 2)
460+
groups = 2
461+
offset_groups = 3
462+
463+
layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
464+
dilation=dilation, groups=groups, offset_groups=offset_groups)
465+
layer.offset_conv.weight.data = torch.randn_like(layer.offset_conv.weight.data)
466+
res = layer(x)
454467

455-
res = ops.DeformConv(stride=stride, pad=pad, dilation=dilation)(x, offset, weight)
456-
expected = self.expected_fn(x, offset, weight, stride=stride, pad=pad, dilation=dilation)
468+
weight = layer.weight.data.to(device=x.device, dtype=x.dtype)
469+
offset = layer.offset_conv.to(device=x.device, dtype=x.dtype)(x)
470+
bias = layer.bias.data.to(device=x.device, dtype=x.dtype)
471+
expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)
457472

458-
self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(x, res, expected))
473+
self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected))
459474

460475
def _test_backward(self, device, contiguous):
461-
x, offset, weight, stride, pad, dilation = self.get_fn_args(device, contiguous)
476+
x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous)
462477

463-
def func(x_, offset_, weight_):
464-
return ops.deform_conv(x_, offset_, weight_, stride=stride, pad=pad, dilation=dilation)
478+
def func(x_, weight_, offset_, bias_):
479+
return ops.deform_conv2d(x_, weight_, offset_, bias_, stride=stride, padding=padding, dilation=dilation)
465480

466-
gradcheck(func, (x, offset, weight), nondet_tol=1e-5)
481+
gradcheck(func, (x, weight, offset, bias), nondet_tol=1e-5)
467482

468483
@torch.jit.script
469-
def script_func(x_, offset_, weight_, stride_, pad_, dilation_):
470-
# type: (Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
471-
return ops.deform_conv(x_, offset_, weight_, stride=stride_, pad=pad_, dilation=dilation_)
484+
def script_func(x_, weight_, offset_, bias_, stride_, pad_, dilation_):
485+
# type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
486+
return ops.deform_conv2d(x_, weight_, offset_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
472487

473-
gradcheck(lambda z, off, wei: script_func(z, off, wei, stride, pad, dilation),
474-
(x, offset, weight), nondet_tol=1e-5)
488+
gradcheck(lambda z, wei, off, bi: script_func(z, wei, off, bi, stride, padding, dilation),
489+
(x, weight, offset, bias), nondet_tol=1e-5)
475490

476491

477492
if __name__ == '__main__':

torchvision/csrc/DeformConv.h

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,49 +6,48 @@
66
#include "cuda/vision_cuda.h"
77
#endif
88

9-
at::Tensor DCN_forward(
9+
at::Tensor DeformConv2d_forward(
1010
const Tensor& input,
11+
const Tensor& weight,
1112
const Tensor& offset,
12-
const Tensor& weights,
13+
const Tensor& bias,
1314
const std::pair<int, int>& stride,
14-
const std::pair<int, int>& pad,
15+
const std::pair<int, int>& padding,
1516
const std::pair<int, int>& dilation,
16-
const int groups,
17-
const int deformable_groups,
18-
const int n_parallel_imgs) {
17+
const int groups, const int offset_groups) {
1918
if (input.type().is_cuda()) {
2019
#ifdef WITH_CUDA
21-
return DCN_forward_cuda(input.contiguous(), offset.contiguous(), weights.contiguous(), stride, pad,
22-
dilation, groups, deformable_groups, n_parallel_imgs);
20+
return DeformConv2d_forward_cuda(input.contiguous(), weight.contiguous(), offset.contiguous(),
21+
bias.contiguous(), stride, padding, dilation, groups, offset_groups);
2322
#else
2423
AT_ERROR("Not compiled with GPU support");
2524
#endif
2625
}
27-
return DCN_forward_cpu(input.contiguous(), offset.contiguous(), weights.contiguous(), stride, pad,
28-
dilation, groups, deformable_groups, n_parallel_imgs);
26+
return DeformConv2d_forward_cpu(input.contiguous(), weight.contiguous(), offset.contiguous(),
27+
bias.contiguous(), stride, padding, dilation, groups, offset_groups);
2928
}
3029

31-
std::tuple<at::Tensor, at::Tensor, at::Tensor> DCN_backward(
30+
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> DeformConv2d_backward(
3231
const at::Tensor& grad,
3332
const Tensor& input,
33+
const Tensor& weight,
3434
const Tensor& offset,
35-
const Tensor& weights,
35+
const Tensor& bias,
3636
const std::pair<int, int>& stride,
37-
const std::pair<int, int>& pad,
37+
const std::pair<int, int>& padding,
3838
const std::pair<int, int>& dilation,
3939
const int groups,
40-
const int deformable_groups,
41-
const int n_parallel_imgs) {
40+
const int offset_groups) {
4241
if (grad.type().is_cuda()) {
4342
#ifdef WITH_CUDA
44-
return DCN_backward_cuda(grad.contiguous(), input.contiguous(), offset.contiguous(), weights.contiguous(), stride, pad,
45-
dilation, groups, deformable_groups, n_parallel_imgs);
43+
return DeformConv2d_backward_cuda(grad.contiguous(), input.contiguous(), weight.contiguous(), offset.contiguous(),
44+
bias.contiguous(), stride, padding, dilation, groups, offset_groups);
4645
#else
4746
AT_ERROR("Not compiled with GPU support");
4847
#endif
4948
}
50-
return DCN_backward_cpu(grad.contiguous(), input.contiguous(), offset.contiguous(), weights.contiguous(), stride, pad,
51-
dilation, groups, deformable_groups, n_parallel_imgs);
49+
return DeformConv2d_backward_cpu(grad.contiguous(), input.contiguous(), weight.contiguous(), offset.contiguous(),
50+
bias.contiguous(), stride, padding, dilation, groups, offset_groups);
5251
}
5352

5453
using namespace at;
@@ -57,35 +56,35 @@ using torch::autograd::AutogradContext;
5756
using torch::autograd::Variable;
5857
using torch::autograd::variable_list;
5958

60-
class DeformConvFunction : public torch::autograd::Function<DeformConvFunction> {
59+
class DeformConv2dFunction : public torch::autograd::Function<DeformConv2dFunction> {
6160
public:
6261
static variable_list forward(
6362
AutogradContext* ctx,
6463
Variable input,
64+
Variable weight,
6565
Variable offset,
66-
Variable weights,
66+
Variable bias,
6767
int64_t stride_h, int64_t stride_w,
6868
int64_t pad_h, int64_t pad_w,
6969
int64_t dilation_h, int64_t dilation_w,
7070
int64_t groups,
71-
int64_t deformable_groups,
72-
int64_t n_parallel_imgs) {
73-
auto output = DCN_forward(input, offset, weights,
71+
int64_t offset_groups) {
72+
auto output = DeformConv2d_forward(
73+
input, weight, offset, bias,
7474
{stride_h, stride_w},
7575
{pad_h, pad_w},
7676
{dilation_h, dilation_w},
77-
groups, deformable_groups, n_parallel_imgs);
77+
groups, offset_groups);
7878

79-
ctx->save_for_backward({input, offset, weights});
79+
ctx->save_for_backward({input, weight, offset, bias});
8080
ctx->saved_data["stride_h"] = stride_h;
8181
ctx->saved_data["stride_w"] = stride_w;
8282
ctx->saved_data["pad_h"] = pad_h;
8383
ctx->saved_data["pad_w"] = pad_w;
8484
ctx->saved_data["dilation_h"] = dilation_h;
8585
ctx->saved_data["dilation_w"] = dilation_w;
8686
ctx->saved_data["groups"] = groups;
87-
ctx->saved_data["deformable_groups"] = deformable_groups;
88-
ctx->saved_data["n_parallel_imgs"] = n_parallel_imgs;
87+
ctx->saved_data["offset_groups"] = offset_groups;
8988

9089
return {output,};
9190
}
@@ -95,8 +94,9 @@ class DeformConvFunction : public torch::autograd::Function<DeformConvFunction>
9594
variable_list grad_output) {
9695
auto saved = ctx->get_saved_variables();
9796
auto input = saved[0];
98-
auto offset = saved[1];
99-
auto weight = saved[2];
97+
auto weight = saved[1];
98+
auto offset = saved[2];
99+
auto bias = saved[3];
100100

101101
auto stride_h = ctx->saved_data["stride_h"].toInt();
102102
auto stride_w = ctx->saved_data["stride_w"].toInt();
@@ -105,37 +105,36 @@ class DeformConvFunction : public torch::autograd::Function<DeformConvFunction>
105105
auto dilation_h = ctx->saved_data["dilation_h"].toInt();
106106
auto dilation_w = ctx->saved_data["dilation_w"].toInt();
107107
auto groups = ctx->saved_data["groups"].toInt();
108-
auto deformable_groups = ctx->saved_data["deformable_groups"].toInt();
109-
auto n_parallel_imgs = ctx->saved_data["n_parallel_imgs"].toInt();
108+
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
110109

111-
auto grads = DCN_backward(grad_output[0],
112-
input, offset, weight,
110+
auto grads = DeformConv2d_backward(grad_output[0],
111+
input, weight, offset, bias,
113112
{stride_h, stride_w},
114113
{pad_h, pad_w},
115114
{dilation_h, dilation_w},
116-
groups, deformable_groups, n_parallel_imgs);
115+
groups, offset_groups);
117116
auto grad_input = std::get<0>(grads);
118-
auto grad_offset = std::get<1>(grads);
119-
auto grad_weight = std::get<2>(grads);
117+
auto grad_weight = std::get<1>(grads);
118+
auto grad_offset = std::get<2>(grads);
119+
auto grad_bias = std::get<3>(grads);
120120

121-
return {grad_input, grad_offset, grad_weight,
122-
Variable(), Variable(), Variable(),
121+
return {grad_input, grad_weight, grad_offset,
122+
grad_bias, Variable(), Variable(),
123123
Variable(), Variable(), Variable(),
124124
Variable(), Variable(), Variable(),};
125125
}
126126
};
127127

128-
Tensor deform_conv(
128+
Tensor deform_conv2d(
129129
const Tensor& input,
130+
const Tensor& weight,
130131
const Tensor& offset,
131-
const Tensor& weights,
132+
const Tensor& bias,
132133
int64_t stride_h, int64_t stride_w,
133134
int64_t pad_h, int64_t pad_w,
134135
int64_t dilation_h, int64_t dilation_w,
135-
int64_t groups,
136-
int64_t deformable_groups,
137-
int64_t n_parallel_imgs) {
138-
auto result = DeformConvFunction::apply(input, offset, weights, stride_h, stride_w, pad_h, pad_w,
139-
dilation_h, dilation_w, groups, deformable_groups, n_parallel_imgs);
136+
int64_t groups, int64_t offset_groups) {
137+
auto result = DeformConv2dFunction::apply(input, weight, offset, bias, stride_h, stride_w, pad_h, pad_w,
138+
dilation_h, dilation_w, groups, offset_groups);
140139
return result[0];
141140
}

0 commit comments

Comments
 (0)