Torch7 naming convention for ROIAlign C++ code #629

Closed · wants to merge 3 commits
42 changes: 21 additions & 21 deletions torchvision/csrc/cpu/ROIAlign_cpu.cpp
@@ -112,17 +112,17 @@ void pre_calc_for_bilinear_interpolate(
template <typename T>
void ROIAlignForward_cpu_kernel(
const int nthreads,
- const T* bottom_data,
+ const T* input,
const T& spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
- const T* bottom_rois,
+ const T* rois,
//int roi_cols,
- T* top_data) {
+ T* output) {
//AT_ASSERT(roi_cols == 4 || roi_cols == 5);
int roi_cols = 5;

@@ -134,22 +134,22 @@ void ROIAlignForward_cpu_kernel(
int index_n = n * channels * pooled_width * pooled_height;

// roi could have 4 or 5 columns
- const T* offset_bottom_rois = bottom_rois + n * roi_cols;
+ const T* offset_rois = rois + n * roi_cols;
int roi_batch_ind = 0;
if (roi_cols == 5) {
- roi_batch_ind = offset_bottom_rois[0];
- offset_bottom_rois++;
+ roi_batch_ind = offset_rois[0];
+ offset_rois++;
}

// Do not using rounding; this implementation detail is critical
- T roi_start_w = offset_bottom_rois[0] * spatial_scale;
- T roi_start_h = offset_bottom_rois[1] * spatial_scale;
- T roi_end_w = offset_bottom_rois[2] * spatial_scale;
- T roi_end_h = offset_bottom_rois[3] * spatial_scale;
- // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
- // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
- // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
- // T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);
+ T roi_start_w = offset_rois[0] * spatial_scale;
+ T roi_start_h = offset_rois[1] * spatial_scale;
+ T roi_end_w = offset_rois[2] * spatial_scale;
+ T roi_end_h = offset_rois[3] * spatial_scale;
+ // T roi_start_w = round(offset_rois[0] * spatial_scale);
+ // T roi_start_h = round(offset_rois[1] * spatial_scale);
+ // T roi_end_w = round(offset_rois[2] * spatial_scale);
+ // T roi_end_h = round(offset_rois[3] * spatial_scale);

// Force malformed ROIs to be 1x1
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
@@ -188,8 +188,8 @@ void ROIAlignForward_cpu_kernel(

for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * pooled_width * pooled_height;
- const T* offset_bottom_data =
- bottom_data + (roi_batch_ind * channels + c) * height * width;
+ const T* offset_input =
+ input + (roi_batch_ind * channels + c) * height * width;
int pre_calc_index = 0;

for (int ph = 0; ph < pooled_height; ph++) {
@@ -200,17 +200,17 @@
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
PreCalc<T> pc = pre_calc[pre_calc_index];
- output_val += pc.w1 * offset_bottom_data[pc.pos1] +
- pc.w2 * offset_bottom_data[pc.pos2] +
- pc.w3 * offset_bottom_data[pc.pos3] +
- pc.w4 * offset_bottom_data[pc.pos4];
+ output_val += pc.w1 * offset_input[pc.pos1] +
+ pc.w2 * offset_input[pc.pos2] +
+ pc.w3 * offset_input[pc.pos3] +
+ pc.w4 * offset_input[pc.pos4];

pre_calc_index += 1;
}
}
output_val /= count;

- top_data[index] = output_val;
+ output[index] = output_val;
} // for pw
} // for ph
} // for c
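Note on the CPU kernel above (not part of the patch): the rename from bottom_data/top_data/bottom_rois to input/output/rois is purely cosmetic; the kernel still averages bilinearly interpolated samples over each pooled bin using the precomputed w1..w4 weights. A minimal standalone sketch of that weighting scheme, for illustration only (it omits the boundary clamping the real kernel performs):

#include <algorithm>
#include <cmath>

// Bilinearly sample a row-major (height x width) map at fractional (y, x).
// The four weights mirror pc.w1..pc.w4 in the kernel above.
template <typename T>
T bilinear_sample(const T* input, int height, int width, T y, T x) {
  int y_low = static_cast<int>(std::floor(y));
  int x_low = static_cast<int>(std::floor(x));
  int y_high = std::min(y_low + 1, height - 1);
  int x_high = std::min(x_low + 1, width - 1);
  T ly = y - y_low, lx = x - x_low;   // fractional parts
  T hy = 1 - ly, hx = 1 - lx;         // complementary weights
  T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
  return w1 * input[y_low * width + x_low] + w2 * input[y_low * width + x_high] +
         w3 * input[y_high * width + x_low] + w4 * input[y_high * width + x_high];
}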
6 changes: 3 additions & 3 deletions torchvision/csrc/cpu/ROIPool_cpu.cpp
@@ -16,8 +16,8 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(const at::Tensor &input,
int input_height = input.size(2);
int input_width = input.size(3);

- at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
- at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type().toScalarType(at::kInt));
+ at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
+ at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt));

// define accessors for indexing
auto input_a = input.accessor<float, 4>();
@@ -107,7 +107,7 @@ at::Tensor ROIPool_backward_cpu(const at::Tensor &grad,

auto num_rois = rois.size(0);

- at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type());
+ at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());

// handle possibly empty gradients
if (grad.numel() == 0)
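Note on the allocation change above (not part of the patch): at::zeros(sizes, input.type()) goes through a deprecated overload, whereas input.options() carries dtype, device and layout in one TensorOptions object, and .dtype(at::kInt) overrides only the scalar type for the integer argmax buffer. A small sketch of the pattern, for illustration only:

#include <ATen/ATen.h>

// Allocate an int32 argmax buffer on the same device/layout as `input`.
at::Tensor make_argmax_like(const at::Tensor& input, int64_t num_rois,
                            int64_t pooled_h, int64_t pooled_w) {
  return at::zeros({num_rois, input.size(1), pooled_h, pooled_w},
                   input.options().dtype(at::kInt));
}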
113 changes: 57 additions & 56 deletions torchvision/csrc/cuda/ROIAlign_cuda.cu
@@ -5,14 +5,11 @@
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>

- // TODO make it in a common file
- #define CUDA_1D_KERNEL_LOOP(i, n) \
- for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
- i += blockDim.x * gridDim.x)
+ #include "cuda_helpers.h"


template <typename T>
- __device__ T bilinear_interpolate(const T* bottom_data,
+ __device__ T bilinear_interpolate(const T* input,
const int height, const int width,
T y, T x,
const int index /* index for debug only*/) {
@@ -48,11 +45,12 @@ __device__ T bilinear_interpolate(const T* bottom_data,
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;

// do bilinear interpolation
- T v1 = bottom_data[y_low * width + x_low];
- T v2 = bottom_data[y_low * width + x_high];
- T v3 = bottom_data[y_high * width + x_low];
- T v4 = bottom_data[y_high * width + x_high];
+ T v1 = input[y_low * width + x_low];
+ T v2 = input[y_low * width + x_high];
+ T v3 = input[y_high * width + x_low];
+ T v4 = input[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
@@ -61,39 +59,35 @@
}

template <typename T>
- __global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
+ __global__ void RoIAlignForward(const int nthreads, const T* input,
const T spatial_scale, const int channels,
const int height, const int width,
const int pooled_height, const int pooled_width,
const int sampling_ratio,
- const T* bottom_rois, T* top_data) {
+ const T* rois, T* output) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;

- const T* offset_bottom_rois = bottom_rois + n * 5;
- int roi_batch_ind = offset_bottom_rois[0];
+ const T* offset_rois = rois + n * 5;
+ int roi_batch_ind = offset_rois[0];

// Do not using rounding; this implementation detail is critical
- T roi_start_w = offset_bottom_rois[1] * spatial_scale;
- T roi_start_h = offset_bottom_rois[2] * spatial_scale;
- T roi_end_w = offset_bottom_rois[3] * spatial_scale;
- T roi_end_h = offset_bottom_rois[4] * spatial_scale;
- // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
- // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
- // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
- // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
+ T roi_start_w = offset_rois[1] * spatial_scale;
+ T roi_start_h = offset_rois[2] * spatial_scale;
+ T roi_end_w = offset_rois[3] * spatial_scale;
+ T roi_end_h = offset_rois[4] * spatial_scale;

// Force malformed ROIs to be 1x1
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

- const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
+ const T* offset_input = input + (roi_batch_ind * channels + c) * height * width;

// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
@@ -110,13 +104,13 @@ __global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
{
const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);

- T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index);
+ T val = bilinear_interpolate(offset_input, height, width, y, x, index);
output_val += val;
}
}
output_val /= count;

- top_data[index] = output_val;
+ output[index] = output_val;
}
}

@@ -162,10 +156,10 @@ __device__ void bilinear_interpolate_gradient(
T hy = 1. - ly, hx = 1. - lx;

// reference in forward
- // T v1 = bottom_data[y_low * width + x_low];
- // T v2 = bottom_data[y_low * width + x_high];
- // T v3 = bottom_data[y_high * width + x_low];
- // T v4 = bottom_data[y_high * width + x_high];
+ // T v1 = input[y_low * width + x_low];
+ // T v2 = input[y_low * width + x_high];
+ // T v3 = input[y_high * width + x_low];
+ // T v4 = input[y_high * width + x_high];
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
@@ -174,44 +168,42 @@
}

template <typename T>
- __global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
+ __global__ void RoIAlignBackward(const int nthreads, const T* grad_output,
const int num_rois, const T spatial_scale,
const int channels, const int height, const int width,
const int pooled_height, const int pooled_width,
const int sampling_ratio,
- T* bottom_diff,
- const T* bottom_rois) {
+ T* grad_input,
+ const T* rois,
+ const int n_stride, const int c_stride,
+ const int h_stride, const int w_stride) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;

- const T* offset_bottom_rois = bottom_rois + n * 5;
- int roi_batch_ind = offset_bottom_rois[0];
+ const T* offset_rois = rois + n * 5;
+ int roi_batch_ind = offset_rois[0];

// Do not using rounding; this implementation detail is critical
- T roi_start_w = offset_bottom_rois[1] * spatial_scale;
- T roi_start_h = offset_bottom_rois[2] * spatial_scale;
- T roi_end_w = offset_bottom_rois[3] * spatial_scale;
- T roi_end_h = offset_bottom_rois[4] * spatial_scale;
- // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
- // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
- // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
- // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);

+ T roi_start_w = offset_rois[1] * spatial_scale;
+ T roi_start_h = offset_rois[2] * spatial_scale;
+ T roi_end_w = offset_rois[3] * spatial_scale;
+ T roi_end_h = offset_rois[4] * spatial_scale;

// Force malformed ROIs to be 1x1
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

- T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width;
+ T* offset_grad_input = grad_input + (roi_batch_ind * channels + c) * height * width;

- int top_offset = (n * channels + c) * pooled_height * pooled_width;
- const T* offset_top_diff = top_diff + top_offset;
- const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
+ int top_offset = (n * channels + c) * pooled_height * pooled_width;
+ const T* offset_grad_output = grad_output + top_offset;
+ const T grad_output_this_bin = offset_grad_output[ph * pooled_width + pw];

// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
@@ -235,17 +227,17 @@ __global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
x_low, x_high, y_low, y_high,
index);

- T g1 = top_diff_this_bin * w1 / count;
- T g2 = top_diff_this_bin * w2 / count;
- T g3 = top_diff_this_bin * w3 / count;
- T g4 = top_diff_this_bin * w4 / count;
+ T g1 = grad_output_this_bin * w1 / count;
+ T g2 = grad_output_this_bin * w2 / count;
+ T g3 = grad_output_this_bin * w3 / count;
+ T g4 = grad_output_this_bin * w4 / count;

if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0)
{
- atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast<T>(g1));
- atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast<T>(g2));
- atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast<T>(g3));
- atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast<T>(g4));
+ atomicAdd(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
+ atomicAdd(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
+ atomicAdd(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
+ atomicAdd(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
} // if
} // ix
} // iy
@@ -326,8 +318,13 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
return grad_input;
}

+ int n_stride = grad.stride(0);
+ int c_stride = grad.stride(1);
+ int h_stride = grad.stride(2);
+ int w_stride = grad.stride(3);

AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] {
- RoIAlignBackwardFeature<scalar_t><<<grid, block, 0, stream>>>(
+ RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
grad.numel(),
grad.data<scalar_t>(),
num_rois,
Expand All @@ -339,7 +336,11 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
pooled_width,
sampling_ratio,
grad_input.data<scalar_t>(),
- rois.data<scalar_t>());
+ rois.data<scalar_t>(),
+ n_stride,
+ c_stride,
+ h_stride,
+ w_stride);
});
THCudaCheck(cudaGetLastError());
return grad_input;
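Two notes on the CUDA changes above (not part of the patch): the grid-stride loop macro is no longer redefined per file but pulled in from cuda_helpers.h, and the backward kernel, renamed from RoIAlignBackwardFeature to RoIAlignBackward, now receives the strides of grad so it can address a non-contiguous gradient tensor; the lines that actually consume n_stride/c_stride/h_stride/w_stride fall outside this excerpt. For reference, the macro removed here (and presumably what cuda_helpers.h provides) is the usual grid-stride loop:

// Each thread processes indices i, i + blockDim.x * gridDim.x, ... up to n.
#define CUDA_1D_KERNEL_LOOP(i, n)                            \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
       i += blockDim.x * gridDim.x)

With the strides available, the per-bin gradient would typically be read as grad_output[n * n_stride + c * c_stride + ph * h_stride + pw * w_stride] rather than assuming a packed NCHW layout.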
16 changes: 8 additions & 8 deletions torchvision/csrc/cuda/ROIPool_cuda.cu
@@ -108,16 +108,16 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
const float spatial_scale,
const int pooled_height,
const int pooled_width) {
- AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
- AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
+ AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
+ AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");

auto num_rois = rois.size(0);
auto channels = input.size(1);
auto height = input.size(2);
auto width = input.size(3);

- at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
- at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type().toScalarType(at::kInt));
+ at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
+ at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt));

auto output_size = num_rois * pooled_height * pooled_width * channels;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -159,13 +159,13 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
const int height,
const int width) {
// Check if input tensors are CUDA tensors
- AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
- AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
- AT_ASSERTM(argmax.type().is_cuda(), "argmax must be a CUDA tensor");
+ AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
+ AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+ AT_ASSERTM(argmax.device().is_cuda(), "argmax must be a CUDA tensor");

auto num_rois = rois.size(0);

- at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type());
+ at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

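Note on the assertion change above (not part of the patch): Tensor::type() is deprecated in ATen, so the device query now goes through tensor.device().is_cuda(). A minimal sketch of the guard pattern, for illustration only:

#include <ATen/ATen.h>

// Fail fast if the inputs are not CUDA tensors.
void check_cuda_inputs(const at::Tensor& grad, const at::Tensor& rois) {
  AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
}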