
Commit c2e1a32

ROIPool: Support for all datatypes (pytorch#632)
* Use of torch7 naming scheme for ROIAlign forward and backward
* use common cuda helpers in ROIAlign
* use .options() in favor of .type() where applicable
* Added tests for forward pass of ROIAlign, as well as more consistent naming scheme for CPU vs CUDA
* working ROIAlign cuda backwards pass
* working ROIAlign backwards pass for CPU
* added relevant headers for ROIAlign backwards
* tests for ROIAlign layer
* replace .type() with .options() for tensor initialization in ROIAlign layers
* support for Half types in ROIAlign
* gradcheck tests for ROIAlign
* updated ROIPool on CPU to work with all datatypes
* updated and cleaned tests for ROI Pooling
1 parent f1079c1 commit c2e1a32

File tree: 7 files changed, +749 -327 lines changed

test/test_layers.py

Lines changed: 262 additions & 88 deletions
Large diffs are not rendered by default.

torchvision/csrc/ROIAlign.h

Lines changed: 8 additions & 7 deletions
@@ -7,12 +7,13 @@
 #endif

 // Interface for Python
-at::Tensor ROIAlign_forward(const at::Tensor& input,
-                            const at::Tensor& rois,
-                            const float spatial_scale,
-                            const int pooled_height,
-                            const int pooled_width,
-                            const int sampling_ratio) {
+at::Tensor ROIAlign_forward(const at::Tensor& input, // Input feature map.
+                            const at::Tensor& rois, // List of ROIs to pool over.
+                            const float spatial_scale, // The scale of the image features. ROIs will be scaled to this.
+                            const int pooled_height, // The height of the pooled feature map.
+                            const int pooled_width, // The width of the pooled feature map.
+                            const int sampling_ratio) // The number of points to sample in each bin along each axis.
+{
   if (input.type().is_cuda()) {
 #ifdef WITH_CUDA
     return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
@@ -40,6 +41,6 @@ at::Tensor ROIAlign_backward(const at::Tensor& grad,
     AT_ERROR("Not compiled with GPU support");
 #endif
   }
-  AT_ERROR("Not implemented on the CPU");
+  return ROIAlign_backward_cpu(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio);
 }
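
For context, only the tail of the backward dispatcher is visible in the second hunk. A sketch of the full ROIAlign_backward implied by the hunk header and by the forward dispatcher above; this is an inference, not verbatim file contents:

// Sketch: the dispatcher now falls through to the new CPU implementation
// instead of raising "Not implemented on the CPU" on non-CUDA tensors.
at::Tensor ROIAlign_backward(const at::Tensor& grad,
                             const at::Tensor& rois,
                             const float spatial_scale,
                             const int pooled_height,
                             const int pooled_width,
                             const int batch_size,
                             const int channels,
                             const int height,
                             const int width,
                             const int sampling_ratio) {
  if (grad.type().is_cuda()) {
#ifdef WITH_CUDA
    return ROIAlign_backward_cuda(grad, rois, spatial_scale, pooled_height, pooled_width,
                                  batch_size, channels, height, width, sampling_ratio);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  return ROIAlign_backward_cpu(grad, rois, spatial_scale, pooled_height, pooled_width,
                               batch_size, channels, height, width, sampling_ratio);
}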

torchvision/csrc/cpu/ROIAlign_cpu.cpp

Lines changed: 215 additions & 35 deletions
@@ -110,46 +110,37 @@ void pre_calc_for_bilinear_interpolate(
 }

 template <typename T>
-void ROIAlignForward_cpu_kernel(
+void ROIAlignForward(
     const int nthreads,
-    const T* bottom_data,
+    const T* input,
     const T& spatial_scale,
     const int channels,
     const int height,
     const int width,
     const int pooled_height,
     const int pooled_width,
     const int sampling_ratio,
-    const T* bottom_rois,
-    //int roi_cols,
-    T* top_data) {
-  //AT_ASSERT(roi_cols == 4 || roi_cols == 5);
-  int roi_cols = 5;
-
+    const T* rois,
+    T* output) {
   int n_rois = nthreads / channels / pooled_width / pooled_height;
   // (n, c, ph, pw) is an element in the pooled output
   // can be parallelized using omp
   // #pragma omp parallel for num_threads(32)
   for (int n = 0; n < n_rois; n++) {
     int index_n = n * channels * pooled_width * pooled_height;

-    // roi could have 4 or 5 columns
-    const T* offset_bottom_rois = bottom_rois + n * roi_cols;
-    int roi_batch_ind = 0;
-    if (roi_cols == 5) {
-      roi_batch_ind = offset_bottom_rois[0];
-      offset_bottom_rois++;
-    }
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];

     // Do not using rounding; this implementation detail is critical
-    T roi_start_w = offset_bottom_rois[0] * spatial_scale;
-    T roi_start_h = offset_bottom_rois[1] * spatial_scale;
-    T roi_end_w = offset_bottom_rois[2] * spatial_scale;
-    T roi_end_h = offset_bottom_rois[3] * spatial_scale;
-    // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
-    // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
-    // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
-    // T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);
+    T roi_start_w = offset_rois[1] * spatial_scale;
+    T roi_start_h = offset_rois[2] * spatial_scale;
+    T roi_end_w = offset_rois[3] * spatial_scale;
+    T roi_end_h = offset_rois[4] * spatial_scale;
+    // T roi_start_w = round(offset_rois[0] * spatial_scale);
+    // T roi_start_h = round(offset_rois[1] * spatial_scale);
+    // T roi_end_w = round(offset_rois[2] * spatial_scale);
+    // T roi_end_h = round(offset_rois[3] * spatial_scale);

     // Force malformed ROIs to be 1x1
     T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
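
The main simplification in this hunk: ROIs are now always 5 columns wide, (batch_index, x1, y1, x2, y2), instead of optionally 4 or 5. A minimal standalone sketch of the per-ROI pointer arithmetic the kernel performs (all values illustrative):

#include <cstdio>

int main() {
  // Hypothetical ROIs: each row is (batch_index, x1, y1, x2, y2) in
  // input-image coordinates, exactly the layout the kernel now assumes.
  const float rois[2][5] = {
      {0.f, 10.f, 10.f, 50.f, 50.f},  // ROI 0 pools from batch image 0
      {1.f, 20.f, 30.f, 80.f, 90.f}}; // ROI 1 pools from batch image 1
  const float spatial_scale = 0.25f;  // e.g. features at 1/4 input resolution

  for (int n = 0; n < 2; n++) {
    const float* offset_rois = &rois[n][0];   // same arithmetic as rois + n * 5
    int roi_batch_ind = (int)offset_rois[0];  // column 0: batch index
    float roi_start_w = offset_rois[1] * spatial_scale;
    float roi_start_h = offset_rois[2] * spatial_scale;
    float roi_end_w = offset_rois[3] * spatial_scale;
    float roi_end_h = offset_rois[4] * spatial_scale;
    std::printf("roi %d -> image %d, feature-space box (%.1f,%.1f)-(%.1f,%.1f)\n",
                n, roi_batch_ind, roi_start_w, roi_start_h, roi_end_w, roi_end_h);
  }
  return 0;
}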
@@ -188,8 +179,8 @@ void ROIAlignForward_cpu_kernel(

     for (int c = 0; c < channels; c++) {
       int index_n_c = index_n + c * pooled_width * pooled_height;
-      const T* offset_bottom_data =
-          bottom_data + (roi_batch_ind * channels + c) * height * width;
+      const T* offset_input =
+          input + (roi_batch_ind * channels + c) * height * width;
       int pre_calc_index = 0;

       for (int ph = 0; ph < pooled_height; ph++) {
@@ -200,46 +191,186 @@ void ROIAlignForward(
           for (int iy = 0; iy < roi_bin_grid_h; iy++) {
             for (int ix = 0; ix < roi_bin_grid_w; ix++) {
               PreCalc<T> pc = pre_calc[pre_calc_index];
-              output_val += pc.w1 * offset_bottom_data[pc.pos1] +
-                  pc.w2 * offset_bottom_data[pc.pos2] +
-                  pc.w3 * offset_bottom_data[pc.pos3] +
-                  pc.w4 * offset_bottom_data[pc.pos4];
+              output_val += pc.w1 * offset_input[pc.pos1] +
+                  pc.w2 * offset_input[pc.pos2] +
+                  pc.w3 * offset_input[pc.pos3] +
+                  pc.w4 * offset_input[pc.pos4];

               pre_calc_index += 1;
             }
           }
           output_val /= count;

-          top_data[index] = output_val;
+          output[index] = output_val;
         } // for pw
       } // for ph
     } // for c
   } // for n
 }
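
Both this forward kernel and the backward kernel added below place roi_bin_grid_h x roi_bin_grid_w sample points at the centers of each bin's sub-cells. A small standalone example of the y-coordinate formula, with illustrative numbers:

#include <cstdio>

int main() {
  // Hypothetical bin geometry: an ROI starting at feature row 1.0, spanning
  // 6 rows, pooled to 3 output rows, with sampling_ratio = 2.
  float roi_start_h = 1.0f, roi_height = 6.0f;
  int pooled_height = 3, sampling_ratio = 2;
  float bin_size_h = roi_height / pooled_height; // 2.0 feature rows per bin
  int roi_bin_grid_h = sampling_ratio;           // ceil(roi_height / pooled_height) if <= 0

  int ph = 1; // second output row: its bin covers feature rows [3.0, 5.0)
  for (int iy = 0; iy < roi_bin_grid_h; iy++) {
    float y = roi_start_h + ph * bin_size_h +
              (iy + .5f) * bin_size_h / roi_bin_grid_h;
    std::printf("iy=%d -> sample at y=%.2f\n", iy, y); // 3.50, then 4.50
  }
  return 0;
}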

+template <typename T>
+void bilinear_interpolate_gradient(
+    const int height, const int width,
+    T y, T x,
+    T& w1, T& w2, T& w3, T& w4,
+    int& x_low, int& x_high, int& y_low, int& y_high,
+    const int index /* index for debug only*/) {
+
+  // deal with cases that inverse elements are out of feature map boundary
+  if (y < -1.0 || y > height || x < -1.0 || x > width) {
+    // empty
+    w1 = w2 = w3 = w4 = 0.;
+    x_low = x_high = y_low = y_high = -1;
+    return;
+  }
+
+  if (y <= 0) y = 0;
+  if (x <= 0) x = 0;
+
+  y_low = (int)y;
+  x_low = (int)x;
+
+  if (y_low >= height - 1) {
+    y_high = y_low = height - 1;
+    y = (T)y_low;
+  } else {
+    y_high = y_low + 1;
+  }
+
+  if (x_low >= width - 1) {
+    x_high = x_low = width - 1;
+    x = (T)x_low;
+  } else {
+    x_high = x_low + 1;
+  }
+
+  T ly = y - y_low;
+  T lx = x - x_low;
+  T hy = 1. - ly, hx = 1. - lx;
+
+  // reference in forward
+  // T v1 = input[y_low * width + x_low];
+  // T v2 = input[y_low * width + x_high];
+  // T v3 = input[y_high * width + x_low];
+  // T v4 = input[y_high * width + x_high];
+  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+  return;
+}
+
+template <class T>
+inline void add(T* address, const T& val) {
+  *address += val;
+}
+
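A quick sanity check on the weights computed by bilinear_interpolate_gradient: for an interior point they are the standard bilinear coefficients and sum to 1, which is what lets the backward pass below split grad_output_this_bin / count exactly across the four neighboring pixels. A standalone sketch with an arbitrary point (boundary clamping omitted):

#include <cstdio>

int main() {
  // Arbitrary interior sample point (y, x) = (2.3, 4.7).
  double y = 2.3, x = 4.7;
  int y_low = (int)y, x_low = (int)x;         // (2, 4)
  int y_high = y_low + 1, x_high = x_low + 1; // (3, 5)
  double ly = y - y_low, lx = x - x_low;      // fractional parts: 0.3, 0.7
  double hy = 1. - ly, hx = 1. - lx;
  double w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
  std::printf("w1..w4 = %.2f %.2f %.2f %.2f, sum = %.2f\n",
              w1, w2, w3, w4, w1 + w2 + w3 + w4); // 0.21 0.49 0.09 0.21, sum 1.00
  return 0;
}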
+template <typename T>
+void ROIAlignBackward(
+    const int nthreads,
+    const T* grad_output,
+    const T& spatial_scale,
+    const int channels,
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio,
+    T* grad_input,
+    const T* rois,
+    const int n_stride, const int c_stride,
+    const int h_stride, const int w_stride) {
+  for (int index = 0; index < nthreads; index++) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    // Do not using rounding; this implementation detail is critical
+    T roi_start_w = offset_rois[1] * spatial_scale;
+    T roi_start_h = offset_rois[2] * spatial_scale;
+    T roi_end_w = offset_rois[3] * spatial_scale;
+    T roi_end_h = offset_rois[4] * spatial_scale;
+
+    // Force malformed ROIs to be 1x1
+    T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
+    T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    T* offset_grad_input = grad_input + ((roi_batch_ind * channels + c) * height * width);
+
+    int output_offset = n * n_stride + c * c_stride;
+    const T* offset_grad_output = grad_output + output_offset;
+    const T grad_output_this_bin = offset_grad_output[ph * h_stride + pw * w_stride];
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++)
+    {
+      const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++)
+      {
+        const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
+
+        T w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient(height, width, y, x,
+            w1, w2, w3, w4,
+            x_low, x_high, y_low, y_high,
+            index);
+
+        T g1 = grad_output_this_bin * w1 / count;
+        T g2 = grad_output_this_bin * w2 / count;
+        T g3 = grad_output_this_bin * w3 / count;
+        T g4 = grad_output_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+          // atomic add is not needed for now since it is single threaded
+          add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
+          add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
+          add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
+          add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
+        } // if
+      } // ix
+    } // iy
+  } // for
+} // ROIAlignBackward
+
+
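Note that ROIAlignBackward indexes grad_output through explicit strides rather than assuming a contiguous NCHW layout, so a non-contiguous gradient (e.g. a permuted view) can be read without a copy; the strides come from grad.stride(0..3) at the call site below. A minimal illustration of the offset arithmetic, using the stride values a contiguous tensor would report:

#include <cstdio>

int main() {
  // Hypothetical contiguous gradient of shape (N=2, C=3, PH=4, PW=4):
  // these are the values grad.stride(0..3) would return.
  int channels = 3, pooled_height = 4, pooled_width = 4;
  int n_stride = channels * pooled_height * pooled_width; // 48
  int c_stride = pooled_height * pooled_width;            // 16
  int h_stride = pooled_width;                            // 4
  int w_stride = 1;

  // Element (n, c, ph, pw) = (1, 2, 3, 1), indexed as in the kernel.
  int n = 1, c = 2, ph = 3, pw = 1;
  int offset = n * n_stride + c * c_stride + ph * h_stride + pw * w_stride;
  std::printf("flat offset = %d\n", offset); // 48 + 32 + 12 + 1 = 93
  return 0;
}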
 at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
                                 const at::Tensor& rois,
                                 const float spatial_scale,
                                 const int pooled_height,
                                 const int pooled_width,
                                 const int sampling_ratio) {
-  AT_ASSERTM(!input.type().is_cuda(), "input must be a CPU tensor");
-  AT_ASSERTM(!rois.type().is_cuda(), "rois must be a CPU tensor");
+  AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
+  AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");

   auto num_rois = rois.size(0);
   auto channels = input.size(1);
   auto height = input.size(2);
   auto width = input.size(3);

-  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
+  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());

   auto output_size = num_rois * pooled_height * pooled_width * channels;

   if (output.numel() == 0)
     return output;

-  AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
-    ROIAlignForward_cpu_kernel<scalar_t>(
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] {
+    ROIAlignForward<scalar_t>(
       output_size,
       input.data<scalar_t>(),
       spatial_scale,
@@ -254,3 +385,52 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
   });
   return output;
 }
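
The datatype support in the commit title comes from switching AT_DISPATCH_FLOATING_TYPES to AT_DISPATCH_FLOATING_TYPES_AND_HALF in both entry points. Conceptually, the macro switches on the tensor's runtime dtype and instantiates the templated kernel with the matching scalar_t; a hand-written sketch of the idea (not the real macro expansion):

#include <ATen/ATen.h>

// Hypothetical kernel stub standing in for ROIAlignForward<scalar_t>.
template <typename scalar_t>
void kernel_stub(const scalar_t* /*data*/) { /* ... */ }

void dispatch_sketch(const at::Tensor& t) {
  switch (t.scalar_type()) {
    case at::ScalarType::Double: kernel_stub<double>(t.data<double>()); break;
    case at::ScalarType::Float:  kernel_stub<float>(t.data<float>()); break;
    // The _AND_HALF variant adds this case, enabling float16 tensors:
    case at::ScalarType::Half:   kernel_stub<at::Half>(t.data<at::Half>()); break;
    default: AT_ERROR("unsupported dtype");
  }
}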
+
+
+at::Tensor ROIAlign_backward_cpu(const at::Tensor& grad,
+                                 const at::Tensor& rois,
+                                 const float spatial_scale,
+                                 const int pooled_height,
+                                 const int pooled_width,
+                                 const int batch_size,
+                                 const int channels,
+                                 const int height,
+                                 const int width,
+                                 const int sampling_ratio) {
+  AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
+  AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
+
+  at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
+
+  // handle possibly empty gradients
+  if (grad.numel() == 0)
+  {
+    return grad_input;
+  }
+
+  // get stride values to ensure indexing into gradients is correct.
+  int n_stride = grad.stride(0);
+  int c_stride = grad.stride(1);
+  int h_stride = grad.stride(2);
+  int w_stride = grad.stride(3);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_forward", [&] {
+    ROIAlignBackward<scalar_t>(
+      grad.numel(),
+      grad.data<scalar_t>(),
+      spatial_scale,
+      channels,
+      height,
+      width,
+      pooled_height,
+      pooled_width,
+      sampling_ratio,
+      grad_input.data<scalar_t>(),
+      rois.data<scalar_t>(),
+      n_stride,
+      c_stride,
+      h_stride,
+      w_stride);
+  });
+  return grad_input;
+}
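
A hypothetical standalone call matching the new signature, for one ROI on a 1x3x16x16 input pooled to 4x4 at half scale (all numbers illustrative):

#include <ATen/ATen.h>

void backward_cpu_example() {
  // Upstream gradient w.r.t. the pooled output: shape (num_rois, C, PH, PW).
  at::Tensor grad = at::ones({1, 3, 4, 4});

  // One ROI: (batch_index, x1, y1, x2, y2) = (0, 0, 0, 8, 8).
  float roi_data[5] = {0.f, 0.f, 0.f, 8.f, 8.f};
  at::Tensor rois = at::from_blob(roi_data, {1, 5}).clone();

  at::Tensor grad_input = ROIAlign_backward_cpu(
      grad, rois, /*spatial_scale=*/0.5f,
      /*pooled_height=*/4, /*pooled_width=*/4,
      /*batch_size=*/1, /*channels=*/3, /*height=*/16, /*width=*/16,
      /*sampling_ratio=*/2);
  // grad_input has shape (1, 3, 16, 16) and the same dtype as grad.
}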
