7
7
#endif
8
8
9
9
at::Tensor DeformConv2d_forward (
10
- const Tensor& input,
11
- const Tensor& weight,
12
- const Tensor& offset,
13
- const Tensor& bias,
10
+ const at:: Tensor& input,
11
+ const at:: Tensor& weight,
12
+ const at:: Tensor& offset,
13
+ const at:: Tensor& bias,
14
14
const std::pair<int , int >& stride,
15
15
const std::pair<int , int >& padding,
16
16
const std::pair<int , int >& dilation,
17
- const int groups, const int offset_groups) {
17
+ const int groups,
18
+ const int offset_groups) {
18
19
if (input.type ().is_cuda ()) {
19
20
#ifdef WITH_CUDA
20
- return DeformConv2d_forward_cuda (input.contiguous (), weight.contiguous (), offset.contiguous (),
21
- bias.contiguous (), stride, padding, dilation, groups, offset_groups);
21
+ return DeformConv2d_forward_cuda (
22
+ input.contiguous (),
23
+ weight.contiguous (),
24
+ offset.contiguous (),
25
+ bias.contiguous (),
26
+ stride,
27
+ padding,
28
+ dilation,
29
+ groups,
30
+ offset_groups);
22
31
#else
23
32
AT_ERROR (" Not compiled with GPU support" );
24
33
#endif
25
34
}
26
- return DeformConv2d_forward_cpu (input.contiguous (), weight.contiguous (), offset.contiguous (),
27
- bias.contiguous (), stride, padding, dilation, groups, offset_groups);
35
+ return DeformConv2d_forward_cpu (
36
+ input.contiguous (),
37
+ weight.contiguous (),
38
+ offset.contiguous (),
39
+ bias.contiguous (),
40
+ stride,
41
+ padding,
42
+ dilation,
43
+ groups,
44
+ offset_groups);
28
45
}
29
46
30
47
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> DeformConv2d_backward (
31
48
const at::Tensor& grad,
32
- const Tensor& input,
33
- const Tensor& weight,
34
- const Tensor& offset,
35
- const Tensor& bias,
49
+ const at:: Tensor& input,
50
+ const at:: Tensor& weight,
51
+ const at:: Tensor& offset,
52
+ const at:: Tensor& bias,
36
53
const std::pair<int , int >& stride,
37
54
const std::pair<int , int >& padding,
38
55
const std::pair<int , int >& dilation,
39
56
const int groups,
40
57
const int offset_groups) {
41
58
if (grad.type ().is_cuda ()) {
42
59
#ifdef WITH_CUDA
43
- return DeformConv2d_backward_cuda (grad.contiguous (), input.contiguous (), weight.contiguous (), offset.contiguous (),
44
- bias.contiguous (), stride, padding, dilation, groups, offset_groups);
60
+ return DeformConv2d_backward_cuda (
61
+ grad.contiguous (),
62
+ input.contiguous (),
63
+ weight.contiguous (),
64
+ offset.contiguous (),
65
+ bias.contiguous (),
66
+ stride,
67
+ padding,
68
+ dilation,
69
+ groups,
70
+ offset_groups);
45
71
#else
46
72
AT_ERROR (" Not compiled with GPU support" );
47
73
#endif
48
74
}
49
- return DeformConv2d_backward_cpu (grad.contiguous (), input.contiguous (), weight.contiguous (), offset.contiguous (),
50
- bias.contiguous (), stride, padding, dilation, groups, offset_groups);
75
+ return DeformConv2d_backward_cpu (
76
+ grad.contiguous (),
77
+ input.contiguous (),
78
+ weight.contiguous (),
79
+ offset.contiguous (),
80
+ bias.contiguous (),
81
+ stride,
82
+ padding,
83
+ dilation,
84
+ groups,
85
+ offset_groups);
51
86
}
52
87
53
88
using namespace at ;
@@ -56,25 +91,33 @@ using torch::autograd::AutogradContext;
56
91
using torch::autograd::Variable;
57
92
using torch::autograd::variable_list;
58
93
59
- class DeformConv2dFunction : public torch ::autograd::Function<DeformConv2dFunction> {
94
+ class DeformConv2dFunction
95
+ : public torch::autograd::Function<DeformConv2dFunction> {
60
96
public:
61
97
static variable_list forward (
62
98
AutogradContext* ctx,
63
99
Variable input,
64
100
Variable weight,
65
101
Variable offset,
66
102
Variable bias,
67
- int64_t stride_h, int64_t stride_w,
68
- int64_t pad_h, int64_t pad_w,
69
- int64_t dilation_h, int64_t dilation_w,
103
+ int64_t stride_h,
104
+ int64_t stride_w,
105
+ int64_t pad_h,
106
+ int64_t pad_w,
107
+ int64_t dilation_h,
108
+ int64_t dilation_w,
70
109
int64_t groups,
71
110
int64_t offset_groups) {
72
111
auto output = DeformConv2d_forward (
73
- input, weight, offset, bias,
112
+ input,
113
+ weight,
114
+ offset,
115
+ bias,
74
116
{stride_h, stride_w},
75
117
{pad_h, pad_w},
76
118
{dilation_h, dilation_w},
77
- groups, offset_groups);
119
+ groups,
120
+ offset_groups);
78
121
79
122
ctx->save_for_backward ({input, weight, offset, bias});
80
123
ctx->saved_data [" stride_h" ] = stride_h;
@@ -86,7 +129,9 @@ class DeformConv2dFunction : public torch::autograd::Function<DeformConv2dFuncti
86
129
ctx->saved_data [" groups" ] = groups;
87
130
ctx->saved_data [" offset_groups" ] = offset_groups;
88
131
89
- return {output,};
132
+ return {
133
+ output,
134
+ };
90
135
}
91
136
92
137
static variable_list backward (
@@ -107,34 +152,64 @@ class DeformConv2dFunction : public torch::autograd::Function<DeformConv2dFuncti
107
152
auto groups = ctx->saved_data [" groups" ].toInt ();
108
153
auto offset_groups = ctx->saved_data [" offset_groups" ].toInt ();
109
154
110
- auto grads = DeformConv2d_backward (grad_output[0 ],
111
- input, weight, offset, bias,
155
+ auto grads = DeformConv2d_backward (
156
+ grad_output[0 ],
157
+ input,
158
+ weight,
159
+ offset,
160
+ bias,
112
161
{stride_h, stride_w},
113
162
{pad_h, pad_w},
114
163
{dilation_h, dilation_w},
115
- groups, offset_groups);
164
+ groups,
165
+ offset_groups);
116
166
auto grad_input = std::get<0 >(grads);
117
167
auto grad_weight = std::get<1 >(grads);
118
168
auto grad_offset = std::get<2 >(grads);
119
169
auto grad_bias = std::get<3 >(grads);
120
170
121
- return {grad_input, grad_weight, grad_offset,
122
- grad_bias, Variable (), Variable (),
123
- Variable (), Variable (), Variable (),
124
- Variable (), Variable (), Variable (),};
171
+ return {
172
+ grad_input,
173
+ grad_weight,
174
+ grad_offset,
175
+ grad_bias,
176
+ Variable (),
177
+ Variable (),
178
+ Variable (),
179
+ Variable (),
180
+ Variable (),
181
+ Variable (),
182
+ Variable (),
183
+ Variable (),
184
+ };
125
185
}
126
186
};
127
187
128
- Tensor deform_conv2d (
129
- const Tensor& input,
130
- const Tensor& weight,
131
- const Tensor& offset,
132
- const Tensor& bias,
133
- int64_t stride_h, int64_t stride_w,
134
- int64_t pad_h, int64_t pad_w,
135
- int64_t dilation_h, int64_t dilation_w,
136
- int64_t groups, int64_t offset_groups) {
137
- auto result = DeformConv2dFunction::apply (input, weight, offset, bias, stride_h, stride_w, pad_h, pad_w,
138
- dilation_h, dilation_w, groups, offset_groups);
188
+ at::Tensor deform_conv2d (
189
+ const at::Tensor& input,
190
+ const at::Tensor& weight,
191
+ const at::Tensor& offset,
192
+ const at::Tensor& bias,
193
+ int64_t stride_h,
194
+ int64_t stride_w,
195
+ int64_t pad_h,
196
+ int64_t pad_w,
197
+ int64_t dilation_h,
198
+ int64_t dilation_w,
199
+ int64_t groups,
200
+ int64_t offset_groups) {
201
+ auto result = DeformConv2dFunction::apply (
202
+ input,
203
+ weight,
204
+ offset,
205
+ bias,
206
+ stride_h,
207
+ stride_w,
208
+ pad_h,
209
+ pad_w,
210
+ dilation_h,
211
+ dilation_w,
212
+ groups,
213
+ offset_groups);
139
214
return result[0 ];
140
215
}
0 commit comments