
Commit 2f2255e

renamed bottom and top to follow torch conventions
1 parent 7b6d1d2 commit 2f2255e

2 files changed: +26 −27 lines

torchvision/csrc/cpu/ROIPool_cpu.cpp

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(const at::Tensor &input,
     int roi_end_w = round(rois_a[n][3] * spatial_scale);
     int roi_end_h = round(rois_a[n][4] * spatial_scale);

-    // Force malformed ROIs to be 1x1 or HxW
+    // Force malformed ROIs to be 1x1
     int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
     int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
     float bin_size_h = static_cast<float>(roi_height) / static_cast<float>(pooled_height);
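(Annotation, not part of the commit.) The updated comment reflects what the clamp below it already does: after the ROI corners are scaled by spatial_scale, a malformed box whose end coordinate lands before its start still produces at least a 1x1 pooling region. A minimal standalone C++ sketch of that behaviour, using a made-up ROI and scale:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  // Hypothetical ROI corners (x1, y1, x2, y2) with x2 < x1, i.e. malformed.
  float roi[4] = {20.f, 10.f, 12.f, 30.f};
  float spatial_scale = 0.25f;

  int roi_start_w = std::round(roi[0] * spatial_scale);
  int roi_start_h = std::round(roi[1] * spatial_scale);
  int roi_end_w   = std::round(roi[2] * spatial_scale);
  int roi_end_h   = std::round(roi[3] * spatial_scale);

  // Same clamp as in ROIPool_forward_cpu: force malformed ROIs to be 1x1.
  int roi_width  = std::max(roi_end_w - roi_start_w + 1, 1);
  int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);

  std::printf("roi_width=%d roi_height=%d\n", roi_width, roi_height);  // 1 and 6
  return 0;
}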

torchvision/csrc/cuda/ROIPool_cuda.cu

Lines changed: 25 additions & 26 deletions
@@ -10,25 +10,25 @@


 template <typename T>
-__global__ void RoIPoolForward(const int nthreads, const T* bottom_data,
+__global__ void RoIPoolForward(const int nthreads, const T* input,
     const T spatial_scale, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const T* bottom_rois, T* top_data, int* argmax_data) {
+    const T* rois, T* output, int* argmax_data) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;

-    const T* offset_bottom_rois = bottom_rois + n * 5;
-    int roi_batch_ind = offset_bottom_rois[0];
-    int roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
-    int roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
-    int roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
-    int roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+    int roi_start_w = round(offset_rois[1] * spatial_scale);
+    int roi_start_h = round(offset_rois[2] * spatial_scale);
+    int roi_end_w = round(offset_rois[3] * spatial_scale);
+    int roi_end_h = round(offset_rois[4] * spatial_scale);

-    // Force malformed ROIs to be 1x1 or HxW
+    // Force malformed ROIs to be 1x1
     int roi_width = max(roi_end_w - roi_start_w + 1, 1);
     int roi_height = max(roi_end_h - roi_start_h + 1, 1);
     T bin_size_h = static_cast<T>(roi_height)
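(Annotation, not part of the diff.) The rename does not change the indexing scheme: each iteration of CUDA_1D_KERNEL_LOOP handles one pooled output element, decoding the flat index into (n, c, ph, pw), and each ROI is a row of five values whose first entry is the batch index. A small single-threaded C++ sketch of that decoding, with hypothetical sizes:

#include <cstdio>

int main() {
  // Hypothetical output shape: num_rois x channels x pooled_height x pooled_width.
  const int channels = 3, pooled_height = 2, pooled_width = 2;

  // Same decomposition as RoIPoolForward, just on the host.
  int index = 22;  // some flat output position
  int pw = index % pooled_width;
  int ph = (index / pooled_width) % pooled_height;
  int c  = (index / pooled_width / pooled_height) % channels;
  int n  = index / pooled_width / pooled_height / channels;
  std::printf("n=%d c=%d ph=%d pw=%d\n", n, c, ph, pw);  // n=1 c=2 ph=1 pw=0

  // Each ROI row holds 5 values (batch_index, x1, y1, x2, y2),
  // which is why the kernel advances by n * 5.
  float rois[2 * 5] = {0, 4, 4, 20, 20,
                       1, 0, 0,  8,  8};
  const float* offset_rois = rois + n * 5;
  std::printf("roi batch index = %d\n", static_cast<int>(offset_rois[0]));  // 1
  return 0;
}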
@@ -56,28 +56,28 @@ __global__ void RoIPoolForward(const int nthreads, const T* bottom_data,
     T maxval = is_empty ? 0 : -FLT_MAX;
     // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
     int maxidx = -1;
-    const T* offset_bottom_data =
-        bottom_data + (roi_batch_ind * channels + c) * height * width;
+    const T* offset_input =
+        input + (roi_batch_ind * channels + c) * height * width;
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
-        int bottom_index = h * width + w;
-        if (offset_bottom_data[bottom_index] > maxval) {
-          maxval = offset_bottom_data[bottom_index];
-          maxidx = bottom_index;
+        int input_index = h * width + w;
+        if (offset_input[input_index] > maxval) {
+          maxval = offset_input[input_index];
+          maxidx = input_index;
         }
       }
     }
-    top_data[index] = maxval;
+    output[index] = maxval;
     argmax_data[index] = maxidx;
   }
 }

 template <typename T>
-__global__ void RoIPoolBackward(const int nthreads, const T* top_grad,
+__global__ void RoIPoolBackward(const int nthreads, const T* grad_output,
     const int* argmax_data, const int num_rois, const T spatial_scale,
     const int channels, const int height, const int width,
-    const int pooled_height, const int pooled_width, T* bottom_data,
-    const T* bottom_rois,
+    const int pooled_height, const int pooled_width, T* grad_input,
+    const T* rois,
     const int n_stride, const int c_stride,
     const int h_stride, const int w_stride) {
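(Annotation, not part of the diff.) Renamings aside, the forward pooling logic is unchanged: within one bin the kernel scans the (roi_batch_ind, c) plane of the input, keeps the maximum, and records its flat offset h * width + w in argmax_data, or -1 for an empty bin so that nothing is back-propagated. A rough single-threaded C++ equivalent of pooling one bin (hstart/hend/wstart/wend assumed already computed from the bin sizes):

#include <cfloat>
#include <cstdio>

// Max-pool one bin over a single (batch, channel) plane of the input,
// returning the max value and writing its flat index (h * width + w) to *argmax.
float pool_one_bin(const float* offset_input, int width,
                   int hstart, int hend, int wstart, int wend, int* argmax) {
  bool is_empty = (hend <= hstart) || (wend <= wstart);  // assumed meaning of is_empty
  float maxval = is_empty ? 0.f : -FLT_MAX;
  int maxidx = -1;  // -1 means nothing is back-propagated for this bin
  for (int h = hstart; h < hend; ++h) {
    for (int w = wstart; w < wend; ++w) {
      int input_index = h * width + w;
      if (offset_input[input_index] > maxval) {
        maxval = offset_input[input_index];
        maxidx = input_index;
      }
    }
  }
  *argmax = maxidx;
  return maxval;
}

int main() {
  // Hypothetical 2x4 feature plane.
  float plane[8] = {1, 5, 2, 0,
                    3, 9, 4, 1};
  int argmax = -1;
  float v = pool_one_bin(plane, /*width=*/4, /*hstart=*/0, /*hend=*/2,
                         /*wstart=*/1, /*wend=*/3, &argmax);
  std::printf("max=%.1f argmax=%d\n", v, argmax);  // max=9.0 argmax=5
  return 0;
}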

@@ -88,18 +88,17 @@ __global__ void RoIPoolBackward(const int nthreads, const T* top_grad,
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;

-    const T* offset_bottom_rois = bottom_rois + n * 5;
-    int roi_batch_ind = offset_bottom_rois[0];
-    int bottom_offset = (roi_batch_ind * channels + c) * height * width;
-    T* bottom_data_offset = bottom_data + bottom_offset;
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+    T* grad_input_offset = grad_input + ((roi_batch_ind * channels + c) * height * width);

-    int top_offset = n*n_stride + c*c_stride;
+    int output_offset = n*n_stride + c*c_stride;
     const int* argmax_data_offset = argmax_data + n*channels*pooled_height*pooled_width;
     int argmax = argmax_data_offset[c*pooled_height*pooled_width + ph*pooled_width + pw];

     if (argmax != -1) {
-      atomicAdd(bottom_data_offset + argmax,
-          static_cast<T>(top_grad[top_offset + ph*h_stride + pw*w_stride]));
+      atomicAdd(grad_input_offset + argmax,
+          static_cast<T>(grad_output[output_offset + ph*h_stride + pw*w_stride]));
     }
   }
 }
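(Annotation, not part of the diff.) In the backward kernel the saved argmax tells each pooled element which input position produced it, so the incoming gradient is accumulated at exactly that position of grad_input. The n/c/h/w strides are passed in by the host wrapper; the sketch below assumes a contiguous (num_rois, channels, pooled_height, pooled_width) grad_output, so the strides are the usual products of trailing dimensions, and uses a plain += where the kernel needs atomicAdd because several ROIs can hit the same input cell:

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical sizes.
  const int channels = 2, height = 4, width = 4;
  const int pooled_height = 2, pooled_width = 2;

  // Strides of a contiguous (num_rois, channels, pooled_height, pooled_width) grad_output.
  const int w_stride = 1;
  const int h_stride = pooled_width;
  const int c_stride = pooled_height * pooled_width;
  const int n_stride = channels * c_stride;

  std::vector<float> grad_output(2 * n_stride, 1.f);                   // all-ones upstream grad, 2 ROIs
  std::vector<float> grad_input(1 * channels * height * width, 0.f);   // batch size 1

  // One pooled element: ROI n=0, channel c=1, bin (ph=1, pw=0),
  // whose forward argmax was flat position 9 (= h*width + w = 2*4 + 1).
  int n = 0, c = 1, ph = 1, pw = 0, roi_batch_ind = 0, argmax = 9;

  if (argmax != -1) {
    float* grad_input_offset =
        grad_input.data() + (roi_batch_ind * channels + c) * height * width;
    int output_offset = n * n_stride + c * c_stride;
    // The CUDA kernel does this with atomicAdd; a plain += suffices single-threaded.
    grad_input_offset[argmax] +=
        grad_output[output_offset + ph * h_stride + pw * w_stride];
  }

  std::printf("grad_input[(0,1,2,1)] = %.1f\n",
              grad_input[(0 * channels + 1) * height * width + 9]);  // 1.0
  return 0;
}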
