From 06c4018463f27dde435dc3f3f65bddbbf84906da Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 12 Oct 2018 19:14:43 -0400 Subject: [PATCH 01/11] Use of torch7 naming scheme for ROIAlign forward and backward --- torchvision/csrc/cpu/ROIAlign_cpu.cpp | 42 +++++----- torchvision/csrc/cuda/ROIAlign_cuda.cu | 106 +++++++++++++------------ 2 files changed, 76 insertions(+), 72 deletions(-) diff --git a/torchvision/csrc/cpu/ROIAlign_cpu.cpp b/torchvision/csrc/cpu/ROIAlign_cpu.cpp index 3850b2833ab..7ed01733d9d 100644 --- a/torchvision/csrc/cpu/ROIAlign_cpu.cpp +++ b/torchvision/csrc/cpu/ROIAlign_cpu.cpp @@ -112,7 +112,7 @@ void pre_calc_for_bilinear_interpolate( template void ROIAlignForward_cpu_kernel( const int nthreads, - const T* bottom_data, + const T* input, const T& spatial_scale, const int channels, const int height, @@ -120,9 +120,9 @@ void ROIAlignForward_cpu_kernel( const int pooled_height, const int pooled_width, const int sampling_ratio, - const T* bottom_rois, + const T* rois, //int roi_cols, - T* top_data) { + T* output) { //AT_ASSERT(roi_cols == 4 || roi_cols == 5); int roi_cols = 5; @@ -134,22 +134,22 @@ void ROIAlignForward_cpu_kernel( int index_n = n * channels * pooled_width * pooled_height; // roi could have 4 or 5 columns - const T* offset_bottom_rois = bottom_rois + n * roi_cols; + const T* offset_rois = rois + n * roi_cols; int roi_batch_ind = 0; if (roi_cols == 5) { - roi_batch_ind = offset_bottom_rois[0]; - offset_bottom_rois++; + roi_batch_ind = offset_rois[0]; + offset_rois++; } // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_bottom_rois[0] * spatial_scale; - T roi_start_h = offset_bottom_rois[1] * spatial_scale; - T roi_end_w = offset_bottom_rois[2] * spatial_scale; - T roi_end_h = offset_bottom_rois[3] * spatial_scale; - // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale); - // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale); - // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale); - // T roi_end_h = round(offset_bottom_rois[3] * spatial_scale); + T roi_start_w = offset_rois[0] * spatial_scale; + T roi_start_h = offset_rois[1] * spatial_scale; + T roi_end_w = offset_rois[2] * spatial_scale; + T roi_end_h = offset_rois[3] * spatial_scale; + // T roi_start_w = round(offset_rois[0] * spatial_scale); + // T roi_start_h = round(offset_rois[1] * spatial_scale); + // T roi_end_w = round(offset_rois[2] * spatial_scale); + // T roi_end_h = round(offset_rois[3] * spatial_scale); // Force malformed ROIs to be 1x1 T roi_width = std::max(roi_end_w - roi_start_w, (T)1.); @@ -188,8 +188,8 @@ void ROIAlignForward_cpu_kernel( for (int c = 0; c < channels; c++) { int index_n_c = index_n + c * pooled_width * pooled_height; - const T* offset_bottom_data = - bottom_data + (roi_batch_ind * channels + c) * height * width; + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; int pre_calc_index = 0; for (int ph = 0; ph < pooled_height; ph++) { @@ -200,17 +200,17 @@ void ROIAlignForward_cpu_kernel( for (int iy = 0; iy < roi_bin_grid_h; iy++) { for (int ix = 0; ix < roi_bin_grid_w; ix++) { PreCalc pc = pre_calc[pre_calc_index]; - output_val += pc.w1 * offset_bottom_data[pc.pos1] + - pc.w2 * offset_bottom_data[pc.pos2] + - pc.w3 * offset_bottom_data[pc.pos3] + - pc.w4 * offset_bottom_data[pc.pos4]; + output_val += pc.w1 * offset_input[pc.pos1] + + pc.w2 * offset_input[pc.pos2] + + pc.w3 * offset_input[pc.pos3] + + pc.w4 * offset_input[pc.pos4]; pre_calc_index += 1; } } output_val /= 
count; - top_data[index] = output_val; + output[index] = output_val; } // for pw } // for ph } // for c diff --git a/torchvision/csrc/cuda/ROIAlign_cuda.cu b/torchvision/csrc/cuda/ROIAlign_cuda.cu index 9cc5ae28934..535911ef443 100644 --- a/torchvision/csrc/cuda/ROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/ROIAlign_cuda.cu @@ -12,7 +12,7 @@ template -__device__ T bilinear_interpolate(const T* bottom_data, +__device__ T bilinear_interpolate(const T* input, const int height, const int width, T y, T x, const int index /* index for debug only*/) { @@ -48,11 +48,12 @@ __device__ T bilinear_interpolate(const T* bottom_data, T ly = y - y_low; T lx = x - x_low; T hy = 1. - ly, hx = 1. - lx; + // do bilinear interpolation - T v1 = bottom_data[y_low * width + x_low]; - T v2 = bottom_data[y_low * width + x_high]; - T v3 = bottom_data[y_high * width + x_low]; - T v4 = bottom_data[y_high * width + x_high]; + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); @@ -61,12 +62,12 @@ __device__ T bilinear_interpolate(const T* bottom_data, } template -__global__ void RoIAlignForward(const int nthreads, const T* bottom_data, +__global__ void RoIAlignForward(const int nthreads, const T* input, const T spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int sampling_ratio, - const T* bottom_rois, T* top_data) { + const T* rois, T* output) { CUDA_1D_KERNEL_LOOP(index, nthreads) { // (n, c, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -74,18 +75,14 @@ __global__ void RoIAlignForward(const int nthreads, const T* bottom_data, int c = (index / pooled_width / pooled_height) % channels; int n = index / pooled_width / pooled_height / channels; - const T* offset_bottom_rois = bottom_rois + n * 5; - int roi_batch_ind = offset_bottom_rois[0]; + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_bottom_rois[1] * spatial_scale; - T roi_start_h = offset_bottom_rois[2] * spatial_scale; - T roi_end_w = offset_bottom_rois[3] * spatial_scale; - T roi_end_h = offset_bottom_rois[4] * spatial_scale; - // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale); - // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale); - // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale); - // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + T roi_start_w = offset_rois[1] * spatial_scale; + T roi_start_h = offset_rois[2] * spatial_scale; + T roi_end_w = offset_rois[3] * spatial_scale; + T roi_end_h = offset_rois[4] * spatial_scale; // Force malformed ROIs to be 1x1 T roi_width = max(roi_end_w - roi_start_w, (T)1.); @@ -93,7 +90,7 @@ __global__ void RoIAlignForward(const int nthreads, const T* bottom_data, T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width; + const T* offset_input = input + (roi_batch_ind * channels + c) * height * width; // We use roi_bin_grid to sample the grid and mimic integral int roi_bin_grid_h = (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 @@ -110,13 +107,13 @@ __global__ void RoIAlignForward(const int nthreads, const T* bottom_data, { const T x = roi_start_w + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); - T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index); + T val = bilinear_interpolate(offset_input, height, width, y, x, index); output_val += val; } } output_val /= count; - top_data[index] = output_val; + output[index] = output_val; } } @@ -162,10 +159,10 @@ __device__ void bilinear_interpolate_gradient( T hy = 1. - ly, hx = 1. - lx; // reference in forward - // T v1 = bottom_data[y_low * width + x_low]; - // T v2 = bottom_data[y_low * width + x_high]; - // T v3 = bottom_data[y_high * width + x_low]; - // T v4 = bottom_data[y_high * width + x_high]; + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; @@ -174,13 +171,15 @@ __device__ void bilinear_interpolate_gradient( } template -__global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff, +__global__ void RoIAlignBackwardFeature(const int nthreads, const T* grad_output, const int num_rois, const T spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int sampling_ratio, - T* bottom_diff, - const T* bottom_rois) { + T* grad_input, + const T* rois, + const int n_stride, const int c_stride, + const int h_stride, const int w_stride) { CUDA_1D_KERNEL_LOOP(index, nthreads) { // (n, c, ph, pw) is an element in the pooled output int pw = index % pooled_width; @@ -188,30 +187,26 @@ __global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff, int c = (index / pooled_width / pooled_height) % channels; int n = index / pooled_width / pooled_height / channels; - const T* offset_bottom_rois = bottom_rois + n * 5; - int roi_batch_ind = offset_bottom_rois[0]; + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_bottom_rois[1] * spatial_scale; - T roi_start_h = offset_bottom_rois[2] * spatial_scale; - T roi_end_w = offset_bottom_rois[3] * spatial_scale; - T roi_end_h = offset_bottom_rois[4] * spatial_scale; - // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale); - // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale); - // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale); - // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale); - + T roi_start_w = offset_rois[1] * spatial_scale; + T roi_start_h = offset_rois[2] * spatial_scale; + T roi_end_w = offset_rois[3] * spatial_scale; + T roi_end_h = offset_rois[4] * spatial_scale; + // Force malformed ROIs to be 1x1 T roi_width = max(roi_end_w - roi_start_w, (T)1.); T roi_height = max(roi_end_h - roi_start_h, (T)1.); T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width; + T* offset_grad_input = grad_input + (roi_batch_ind * channels + c) * height * width; - int top_offset = (n * channels + c) * pooled_height * pooled_width; - const T* offset_top_diff = top_diff + top_offset; 
- const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; + int top_offset = (n * channels + c) * pooled_height * pooled_width; + const T* offset_grad_output = grad_output + top_offset; + const T grad_output_this_bin = offset_grad_output[ph * pooled_width + pw]; // We use roi_bin_grid to sample the grid and mimic integral int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 @@ -235,17 +230,17 @@ __global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff, x_low, x_high, y_low, y_high, index); - T g1 = top_diff_this_bin * w1 / count; - T g2 = top_diff_this_bin * w2 / count; - T g3 = top_diff_this_bin * w3 / count; - T g4 = top_diff_this_bin * w4 / count; + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast(g1)); - atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast(g2)); - atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast(g3)); - atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast(g4)); + atomicAdd(offset_grad_input + y_low * width + x_low, static_cast(g1)); + atomicAdd(offset_grad_input + y_low * width + x_high, static_cast(g2)); + atomicAdd(offset_grad_input + y_high * width + x_low, static_cast(g3)); + atomicAdd(offset_grad_input + y_high * width + x_high, static_cast(g4)); } // if } // ix } // iy @@ -326,6 +321,11 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, return grad_input; } + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] { RoIAlignBackwardFeature<<>>( grad.numel(), @@ -339,7 +339,11 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, pooled_width, sampling_ratio, grad_input.data(), - rois.data()); + rois.data(), + n_stride, + c_stride, + h_stride, + w_stride); }); THCudaCheck(cudaGetLastError()); return grad_input; From 8c0cdf7c55f9075e3c971c4fb4f07b636e184eae Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 16 Oct 2018 14:23:31 -0400 Subject: [PATCH 02/11] use common cuda helpers in ROIAlign --- torchvision/csrc/cuda/ROIAlign_cuda.cu | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/torchvision/csrc/cuda/ROIAlign_cuda.cu b/torchvision/csrc/cuda/ROIAlign_cuda.cu index 535911ef443..adbec2fa1bb 100644 --- a/torchvision/csrc/cuda/ROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/ROIAlign_cuda.cu @@ -5,10 +5,7 @@ #include #include -// TODO make it in a common file -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ - i += blockDim.x * gridDim.x) +#include "cuda_helpers.h" template @@ -48,7 +45,7 @@ __device__ T bilinear_interpolate(const T* input, T ly = y - y_low; T lx = x - x_low; T hy = 1. - ly, hx = 1. 
- lx; - + // do bilinear interpolation T v1 = input[y_low * width + x_low]; T v2 = input[y_low * width + x_high]; @@ -171,7 +168,7 @@ __device__ void bilinear_interpolate_gradient( } template -__global__ void RoIAlignBackwardFeature(const int nthreads, const T* grad_output, +__global__ void RoIAlignBackward(const int nthreads, const T* grad_output, const int num_rois, const T spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, @@ -327,7 +324,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, int w_stride = grad.stride(3); AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] { - RoIAlignBackwardFeature<<>>( + RoIAlignBackward<<>>( grad.numel(), grad.data(), num_rois, From c57ef086fcec0914f838f4d235fa4da228c6de1e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 17 Oct 2018 13:13:47 -0400 Subject: [PATCH 03/11] use .options() in favor of .type() where applicable --- torchvision/csrc/cpu/ROIPool_cpu.cpp | 6 +++--- torchvision/csrc/cuda/ROIPool_cuda.cu | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/torchvision/csrc/cpu/ROIPool_cpu.cpp b/torchvision/csrc/cpu/ROIPool_cpu.cpp index 8ae35930533..eba66c18a14 100644 --- a/torchvision/csrc/cpu/ROIPool_cpu.cpp +++ b/torchvision/csrc/cpu/ROIPool_cpu.cpp @@ -16,8 +16,8 @@ std::tuple ROIPool_forward_cpu(const at::Tensor &input, int input_height = input.size(2); int input_width = input.size(3); - at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type()); - at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type().toScalarType(at::kInt)); + at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); + at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt)); // define accessors for indexing auto input_a = input.accessor(); @@ -107,7 +107,7 @@ at::Tensor ROIPool_backward_cpu(const at::Tensor &grad, auto num_rois = rois.size(0); - at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type()); + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); // handle possibly empty gradients if (grad.numel() == 0) diff --git a/torchvision/csrc/cuda/ROIPool_cuda.cu b/torchvision/csrc/cuda/ROIPool_cuda.cu index 2ba8dc33e25..514400b0267 100644 --- a/torchvision/csrc/cuda/ROIPool_cuda.cu +++ b/torchvision/csrc/cuda/ROIPool_cuda.cu @@ -108,16 +108,16 @@ std::tuple ROIPool_forward_cuda(const at::Tensor& input, const float spatial_scale, const int pooled_height, const int pooled_width) { - AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); - AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); auto num_rois = rois.size(0); auto channels = input.size(1); auto height = input.size(2); auto width = input.size(3); - at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type()); - at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type().toScalarType(at::kInt)); + at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); + at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt)); auto 
output_size = num_rois * pooled_height * pooled_width * channels; cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -159,13 +159,13 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, const int height, const int width) { // Check if input tensors are CUDA tensors - AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor"); - AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); - AT_ASSERTM(argmax.type().is_cuda(), "argmax must be a CUDA tensor"); + AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); + AT_ASSERTM(argmax.device().is_cuda(), "argmax must be a CUDA tensor"); auto num_rois = rois.size(0); - at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type()); + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); From 20d080309cec6fdd6f94289d61e56c1c90fdb40c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 16 Oct 2018 14:28:18 -0400 Subject: [PATCH 04/11] Added tests for forward pass of ROIAlign, as well as more consistent naming scheme for CPU vs CUDA --- test/test_layers.py | 92 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/test/test_layers.py b/test/test_layers.py index d508393c64a..5f6de3fb07d 100644 --- a/test/test_layers.py +++ b/test/test_layers.py @@ -95,7 +95,7 @@ def test_roi_pool_gradient_cpu(self): assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool' @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") - def test_roi_pool_basic_gpu(self): + def test_roi_pool_basic_cuda(self): dtype = torch.float32 device = torch.device('cuda') x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device) @@ -120,7 +120,7 @@ def test_roi_pool_basic_gpu(self): assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect' @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") - def test_roi_pool_gpu(self): + def test_roi_pool_cuda(self): dtype = torch.float32 device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device) @@ -153,7 +153,7 @@ def test_roi_pool_gpu(self): assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect' @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") - def test_roi_pool_gradient_gpu(self): + def test_roi_pool_gradient_cuda(self): dtype = torch.float32 device = torch.device('cuda') layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device) @@ -186,5 +186,91 @@ def func(input): assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool' +class ROIAlignTester(unittest.TestCase): + @classmethod + def setup_class(cls): + torch.manual_seed(123) + cls.dtype = torch.float32 + cls.x = torch.rand(1, 1, 10, 10, dtype=cls.dtype) + cls.single_roi = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy) + dtype=cls.dtype) + cls.rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) + [0, 0, 5, 4, 9], + [0, 5, 5, 9, 9]], + dtype=torch.float32) + + cls.gt_y_single = torch.tensor([[[[0.41617328, 0.5040753, 0.25266218, 0.4296828, 0.29928464], + [0.5210769, 0.57222337, 0.2524979, 0.32063985, 0.32635176], + [0.73108256, 0.6114335, 0.62033176, 0.8188273, 0.5562218], + [0.83115816, 0.70803946, 0.7084047, 0.74928707, 0.7769296], + [0.54266506, 0.45964524, 0.5780159, 0.80522037, 0.7321807]]]]) + + cls.gt_y_multiple = 
torch.tensor([[[[0.49311584, 0.35972416, 0.40843594, 0.3638034, 0.49751836], + [0.70881474, 0.75481665, 0.5826779, 0.34767765, 0.46865487], + [0.4740328, 0.69306874, 0.3617804, 0.47145438, 0.66130304], + [0.6861706, 0.17634538, 0.47194335, 0.42473823, 0.37930614], + [0.62666404, 0.49973848, 0.37911576, 0.5842756, 0.7176864]]], + [[[0.67499936, 0.6607055, 0.42656037, 0.46134934, 0.42144877], + [0.7471722, 0.7235433, 0.14512213, 0.13031253, 0.289369], + [0.8443615, 0.6659734, 0.23614208, 0.14719573, 0.4268827], + [0.69429564, 0.5621515, 0.5019923, 0.40678093, 0.34556213], + [0.51315194, 0.7177093, 0.6494485, 0.6775592, 0.43865064]]], + [[[0.24465509, 0.36108392, 0.64635646, 0.4051828, 0.33956185], + [0.49006107, 0.42982674, 0.34184104, 0.15493104, 0.49633422], + [0.54400194, 0.5265246, 0.22381854, 0.3929715, 0.6757667], + [0.32961223, 0.38482672, 0.68877804, 0.71822757, 0.711909], + [0.561259, 0.71047884, 0.84651315, 0.8541089, 0.644432]]]]) + + def test_roi_align_basic_cpu(self): + device = torch.device('cpu') + self.x = self.x.to(device) + self.single_roi = self.single_roi.to(device) + self.gt_y_multiple = self.gt_y_multiple.to(device) + + pool_h, pool_w = (5, 5) + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2.0) + y = roi_align(self.x, self.single_roi) + + assert torch.equal(self.gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CPU' + + def test_roi_align_cpu(self): + device = torch.device('cpu') + self.x = self.x.to(device) + self.rois = self.rois.to(device) + self.gt_y_multiple = self.gt_y_multiple.to(device) + + pool_h, pool_w = (5, 5) + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2.0) + y = roi_align(self.x, self.rois) + + assert torch.equal(self.gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CPU' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_roi_align_basic_cuda(self): + device = torch.device('cuda') + self.x = self.x.to(device) + self.single_roi = self.single_roi.to(device) + self.gt_y_single = self.gt_y_single.to(device) + + pool_h, pool_w = (5, 5) + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2.0) + y = roi_align(self.x, self.single_roi) + + assert torch.allclose(self.gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CUDA' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_roi_align_cuda(self): + device = torch.device('cuda') + self.x = self.x.to(device) + self.rois = self.rois.to(device) + self.gt_y_multiple = self.gt_y_multiple.to(device) + + pool_h, pool_w = (5, 5) + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2.0) + y = roi_align(self.x, self.rois) + + assert torch.allclose(self.gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CUDA' + + if __name__ == '__main__': unittest.main() From 13828e54fc5651ff9fc497827df60c5dc5fcab3d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 16 Oct 2018 17:23:58 -0400 Subject: [PATCH 05/11] working ROIAlign cuda backwards pass --- torchvision/csrc/cuda/ROIAlign_cuda.cu | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/torchvision/csrc/cuda/ROIAlign_cuda.cu b/torchvision/csrc/cuda/ROIAlign_cuda.cu index adbec2fa1bb..d4d9154a621 100644 --- a/torchvision/csrc/cuda/ROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/ROIAlign_cuda.cu @@ -169,7 +169,7 @@ __device__ void bilinear_interpolate_gradient( template __global__ void RoIAlignBackward(const 
int nthreads, const T* grad_output, - const int num_rois, const T spatial_scale, + const T spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int sampling_ratio, @@ -199,11 +199,12 @@ __global__ void RoIAlignBackward(const int nthreads, const T* grad_output, T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - T* offset_grad_input = grad_input + (roi_batch_ind * channels + c) * height * width; + T* offset_grad_input = grad_input + ((roi_batch_ind * channels + c) * height * width); - int top_offset = (n * channels + c) * pooled_height * pooled_width; - const T* offset_grad_output = grad_output + top_offset; - const T grad_output_this_bin = offset_grad_output[ph * pooled_width + pw]; + // We need to index the gradient using the tensor strides to access the correct values. + int output_offset = n*n_stride + c*c_stride; + const T* offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = offset_grad_output[ph*h_stride + pw*w_stride]; // We use roi_bin_grid to sample the grid and mimic integral int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 @@ -251,8 +252,8 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const int pooled_height, const int pooled_width, const int sampling_ratio) { - AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); - AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); auto num_rois = rois.size(0); auto channels = input.size(1); @@ -290,7 +291,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, return output; } -// TODO remove the dependency on input and use instead its sizes -> save memory + at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale, @@ -301,10 +302,9 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, const int height, const int width, const int sampling_ratio) { - AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor"); - AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); - auto num_rois = rois.size(0); at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -322,12 +322,11 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, int c_stride = grad.stride(1); int h_stride = grad.stride(2); int w_stride = grad.stride(3); - + AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] { RoIAlignBackward<<>>( grad.numel(), grad.data(), - num_rois, spatial_scale, channels, height, From 3722e66214911121920ee0a8cfc3fb6b6682a347 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 16 Oct 2018 17:58:05 -0400 Subject: [PATCH 06/11] working ROIAlign backwards pass for CPU --- torchvision/csrc/cpu/ROIAlign_cpu.cpp | 218 +++++++++++++++++++++++--- 1 file changed, 199 insertions(+), 19 deletions(-) diff --git a/torchvision/csrc/cpu/ROIAlign_cpu.cpp b/torchvision/csrc/cpu/ROIAlign_cpu.cpp index 7ed01733d9d..14136f06d45 100644 --- a/torchvision/csrc/cpu/ROIAlign_cpu.cpp +++ b/torchvision/csrc/cpu/ROIAlign_cpu.cpp @@ -110,7 
+110,7 @@ void pre_calc_for_bilinear_interpolate( } template -void ROIAlignForward_cpu_kernel( +void ROIAlignForward( const int nthreads, const T* input, const T& spatial_scale, @@ -121,11 +121,7 @@ void ROIAlignForward_cpu_kernel( const int pooled_width, const int sampling_ratio, const T* rois, - //int roi_cols, T* output) { - //AT_ASSERT(roi_cols == 4 || roi_cols == 5); - int roi_cols = 5; - int n_rois = nthreads / channels / pooled_width / pooled_height; // (n, c, ph, pw) is an element in the pooled output // can be parallelized using omp @@ -133,19 +129,14 @@ void ROIAlignForward_cpu_kernel( for (int n = 0; n < n_rois; n++) { int index_n = n * channels * pooled_width * pooled_height; - // roi could have 4 or 5 columns - const T* offset_rois = rois + n * roi_cols; - int roi_batch_ind = 0; - if (roi_cols == 5) { - roi_batch_ind = offset_rois[0]; - offset_rois++; - } + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_rois[0] * spatial_scale; - T roi_start_h = offset_rois[1] * spatial_scale; - T roi_end_w = offset_rois[2] * spatial_scale; - T roi_end_h = offset_rois[3] * spatial_scale; + T roi_start_w = offset_rois[1] * spatial_scale; + T roi_start_h = offset_rois[2] * spatial_scale; + T roi_end_w = offset_rois[3] * spatial_scale; + T roi_end_h = offset_rois[4] * spatial_scale; // T roi_start_w = round(offset_rois[0] * spatial_scale); // T roi_start_h = round(offset_rois[1] * spatial_scale); // T roi_end_w = round(offset_rois[2] * spatial_scale); @@ -217,14 +208,154 @@ void ROIAlignForward_cpu_kernel( } // for n } +template +void bilinear_interpolate_gradient( + const int height, const int width, + T y, T x, + T& w1, T& w2, T& w3, T& w4, + int& x_low, int& x_high, int& y_low, int& y_high, + const int index /* index for debug only*/) { + + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + return; +} + +template +inline void add(T* address, const T& val) { + *address += val; +} + +template +void ROIAlignBackward( + const int nthreads, + const T* grad_output, + const T& spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + T* grad_input, + const T* rois, + const int n_stride, const int c_stride, + const int h_stride, const int w_stride) { + for (int index = 0; index < nthreads; index++) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale; + T roi_start_h = offset_rois[2] * spatial_scale; + T roi_end_w = offset_rois[3] * spatial_scale; + T roi_end_h = offset_rois[4] * spatial_scale; + + // Force malformed ROIs to be 1x1 + T roi_width = std::max(roi_end_w - roi_start_w, (T)1.); + T roi_height = std::max(roi_end_h - roi_start_h, (T)1.); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + T* offset_grad_input = grad_input + ((roi_batch_ind * channels + c) * height * width); + + int output_offset = n*n_stride + c*c_stride; + const T* offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = offset_grad_output[ph*h_stride + pw*w_stride]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy++) + { + const T y = roi_start_h + ph * bin_size_h + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) + { + const T x = roi_start_w + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient(height, width, y, x, + w1, w2, w3, w4, + x_low, x_high, y_low, y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + // atomic add is not needed for now since it is single threaded + add(offset_grad_input + y_low * width + x_low, static_cast(g1)); + add(offset_grad_input + y_low * width + x_high, static_cast(g2)); + add(offset_grad_input + y_high * width + x_low, static_cast(g3)); + add(offset_grad_input + y_high * width + x_high, static_cast(g4)); + } // if + } // ix + } // iy + } // for +} // ROIAlignBackward + + at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale, const int pooled_height, const int pooled_width, const int sampling_ratio) { - AT_ASSERTM(!input.type().is_cuda(), "input must be a CPU tensor"); - AT_ASSERTM(!rois.type().is_cuda(), "rois must be a CPU tensor"); + AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor"); + AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); auto num_rois = rois.size(0); auto channels = input.size(1); @@ -239,7 +370,7 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, return output; AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] { - ROIAlignForward_cpu_kernel( + ROIAlignForward( output_size, input.data(), spatial_scale, @@ -254,3 +385,52 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, }); return output; } + + +at::Tensor ROIAlign_backward_cpu(const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio) { + AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor"); + AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); + + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + // handle possibly empty gradients + if (grad.numel() == 0) + { + return grad_input; + } + + // get stride values to ensure indexing into gradients is correct. 
+ int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_forward", [&] { + ROIAlignBackward( + grad.numel(), + grad.data(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + grad_input.data(), + rois.data(), + n_stride, + c_stride, + h_stride, + w_stride); + }); + return grad_input; +} From 64735ab95ab80f5393d2f3ca327a4c23cc745bca Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 16 Oct 2018 17:58:31 -0400 Subject: [PATCH 07/11] added relevant headers for ROIAlign backwards --- torchvision/csrc/ROIAlign.h | 15 ++++++++------- torchvision/csrc/cpu/vision.h | 11 +++++++++++ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/torchvision/csrc/ROIAlign.h b/torchvision/csrc/ROIAlign.h index 94348abec09..c2e6090857f 100644 --- a/torchvision/csrc/ROIAlign.h +++ b/torchvision/csrc/ROIAlign.h @@ -7,12 +7,13 @@ #endif // Interface for Python -at::Tensor ROIAlign_forward(const at::Tensor& input, - const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio) { +at::Tensor ROIAlign_forward(const at::Tensor& input, // Input feature map. + const at::Tensor& rois, // List of ROIs to pool over. + const float spatial_scale, // The scale of the image features. ROIs will be scaled to this. + const int pooled_height, // The height of the pooled feature map. + const int pooled_width, // The width of the pooled feature + const int sampling_ratio) // The number of points to sample in each bin along each axis. +{ if (input.type().is_cuda()) { #ifdef WITH_CUDA return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); @@ -40,6 +41,6 @@ at::Tensor ROIAlign_backward(const at::Tensor& grad, AT_ERROR("Not compiled with GPU support"); #endif } - AT_ERROR("Not implemented on the CPU"); + return ROIAlign_backward_cpu(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio); } diff --git a/torchvision/csrc/cpu/vision.h b/torchvision/csrc/cpu/vision.h index eebf4b95ad5..e9b32559fda 100644 --- a/torchvision/csrc/cpu/vision.h +++ b/torchvision/csrc/cpu/vision.h @@ -25,6 +25,17 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor &input, const int pooled_width, const int sampling_ratio); +at::Tensor ROIAlign_backward_cpu(const at::Tensor &grad, + const at::Tensor &rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio); + at::Tensor nms_cpu(const at::Tensor &dets, const at::Tensor &scores, const float threshold); From d3dc4a139c365db90a2fd12206c2083293513800 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 16 Oct 2018 17:59:11 -0400 Subject: [PATCH 08/11] tests for ROIAlign layer --- test/test_layers.py | 118 +++++++++++++++++++++++++++++++++----------- 1 file changed, 89 insertions(+), 29 deletions(-) diff --git a/test/test_layers.py b/test/test_layers.py index 5f6de3fb07d..390f5701437 100644 --- a/test/test_layers.py +++ b/test/test_layers.py @@ -197,13 +197,13 @@ def setup_class(cls): cls.rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) [0, 0, 5, 4, 9], [0, 5, 5, 9, 9]], - dtype=torch.float32) + dtype=cls.dtype) cls.gt_y_single = torch.tensor([[[[0.41617328, 0.5040753, 0.25266218, 0.4296828, 0.29928464], 
[0.5210769, 0.57222337, 0.2524979, 0.32063985, 0.32635176], [0.73108256, 0.6114335, 0.62033176, 0.8188273, 0.5562218], [0.83115816, 0.70803946, 0.7084047, 0.74928707, 0.7769296], - [0.54266506, 0.45964524, 0.5780159, 0.80522037, 0.7321807]]]]) + [0.54266506, 0.45964524, 0.5780159, 0.80522037, 0.7321807]]]], dtype=cls.dtype) cls.gt_y_multiple = torch.tensor([[[[0.49311584, 0.35972416, 0.40843594, 0.3638034, 0.49751836], [0.70881474, 0.75481665, 0.5826779, 0.34767765, 0.46865487], @@ -219,58 +219,118 @@ def setup_class(cls): [0.49006107, 0.42982674, 0.34184104, 0.15493104, 0.49633422], [0.54400194, 0.5265246, 0.22381854, 0.3929715, 0.6757667], [0.32961223, 0.38482672, 0.68877804, 0.71822757, 0.711909], - [0.561259, 0.71047884, 0.84651315, 0.8541089, 0.644432]]]]) + [0.561259, 0.71047884, 0.84651315, 0.8541089, 0.644432]]]], + dtype=cls.dtype) + + cls.x_grad = torch.tensor([[[[0.075625, 0.15125, 0.15124999, 0.15125002, 0.15812504, 0.15812503, 0.15124999, 0.15124999, 0.15125006, 0.0756249], + [0.15125, 0.30250007, 0.3025, 0.30250007, 0.31625012, + 0.31625003, 0.3025, 0.3025, 0.30250013, 0.1512498], + [0.15124999, 0.3025, 0.30249995, 0.3025, 0.31625006, + 0.31625, 0.30249995, 0.30249995, 0.30250007, 0.15124978], + [0.15125002, 0.30250007, 0.3025, 0.30250007, 0.31625012, + 0.3162501, 0.3025, 0.3025, 0.30250013, 0.15124981], + [0.15812504, 0.31625012, 0.31625006, 0.31625012, 0.33062524, + 0.3306251, 0.31625006, 0.31625006, 0.3162502, 0.15812483], + [0.5181251, 1.0962502, 1.0362502, 1.0962503, 0.69062525, 0.6906252, + 1.0962502, 1.0362502, 1.0962503, 0.5181248], + [0.93125, 1.9925, 1.8624997, 1.9925, 1.0962502, 1.0962502, + 1.9925, 1.8624998, 1.9925, 0.9312496], + [0.8712501, 1.8625, 1.7425002, 1.8625001, 1.0362502, 1.0362502, + 1.8625, 1.7425001, 1.8625002, 0.8712497], + [0.93125004, 1.9925, 1.8625002, 1.9925, 1.0962503, 1.0962503, + 1.9925001, 1.8625001, 1.9925001, 0.93124974], + [0.43562484, 0.9312497, 0.8712497, 0.9312497, 0.5181249, 0.5181248, + 0.9312496, 0.8712497, 0.93124974, 0.43562466]]]], + dtype=cls.dtype) def test_roi_align_basic_cpu(self): device = torch.device('cpu') - self.x = self.x.to(device) - self.single_roi = self.single_roi.to(device) - self.gt_y_multiple = self.gt_y_multiple.to(device) + x = self.x.to(device) + single_roi = self.single_roi.to(device) + gt_y_single = self.gt_y_single.to(device) pool_h, pool_w = (5, 5) - roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2.0) - y = roi_align(self.x, self.single_roi) - - assert torch.equal(self.gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CPU' + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) + y = roi_align(x, single_roi) + + assert torch.equal(gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CPU' def test_roi_align_cpu(self): device = torch.device('cpu') - self.x = self.x.to(device) - self.rois = self.rois.to(device) - self.gt_y_multiple = self.gt_y_multiple.to(device) + x = self.x.to(device) + rois = self.rois.to(device) + gt_y_multiple = self.gt_y_multiple.to(device) pool_h, pool_w = (5, 5) - roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2.0) - y = roi_align(self.x, self.rois) + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) + y = roi_align(x, rois) - assert torch.equal(self.gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CPU' + assert torch.equal(gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CPU' 
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_roi_align_basic_cuda(self): device = torch.device('cuda') - self.x = self.x.to(device) - self.single_roi = self.single_roi.to(device) - self.gt_y_single = self.gt_y_single.to(device) + x = self.x.to(device) + single_roi = self.single_roi.to(device) + gt_y_single = self.gt_y_single.to(device) pool_h, pool_w = (5, 5) - roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2.0) - y = roi_align(self.x, self.single_roi) - - assert torch.allclose(self.gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CUDA' + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) + y = roi_align(x, single_roi) + + assert torch.allclose(gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CUDA' @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_roi_align_cuda(self): device = torch.device('cuda') - self.x = self.x.to(device) - self.rois = self.rois.to(device) - self.gt_y_multiple = self.gt_y_multiple.to(device) + x = self.x.to(device) + rois = self.rois.to(device) + gt_y_multiple = self.gt_y_multiple.to(device) + + pool_h, pool_w = (5, 5) + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) + y = roi_align(x, rois) + + assert torch.allclose(gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CUDA' + + def test_roi_align_gradient_cpu(self): + """ + Compute gradients for ROIAlign with multiple bounding boxes on CPU + """ + device = torch.device('cpu') + pool_h, pool_w = (5, 5) + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) + + x = self.x.to(device).clone() + rois = self.rois.to(device) + gt_grad = self.x_grad.to(device) + x.requires_grad = True + y = roi_align(x, rois) + s = y.sum() + s.backward() + + assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for ROIAlign CPU' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_roi_align_gradient_cuda(self): + """ + Compute gradients for ROIAlign with multiple bounding boxes on the GPU + """ + device = torch.device('cuda') pool_h, pool_w = (5, 5) - roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2.0) - y = roi_align(self.x, self.rois) + roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) + + x = self.x.to(device).clone() + rois = self.rois.to(device) + gt_grad = self.x_grad.to(device) - assert torch.allclose(self.gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CUDA' + x.requires_grad = True + y = roi_align(x, rois) + s = y.sum() + s.backward() + assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for ROIAlign CUDA' if __name__ == '__main__': unittest.main() From 110d998f7c35db91b609fdbacbce2bd1215d53c9 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 17 Oct 2018 13:15:33 -0400 Subject: [PATCH 09/11] replace .type() with .options() for tensor initialization in ROIAlign layers --- torchvision/csrc/cpu/ROIAlign_cpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/csrc/cpu/ROIAlign_cpu.cpp b/torchvision/csrc/cpu/ROIAlign_cpu.cpp index 14136f06d45..142843e60c8 100644 --- a/torchvision/csrc/cpu/ROIAlign_cpu.cpp +++ b/torchvision/csrc/cpu/ROIAlign_cpu.cpp @@ -362,7 +362,7 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, auto height = input.size(2); auto width = input.size(3); - at::Tensor 
output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type()); + at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); auto output_size = num_rois * pooled_height * pooled_width * channels; From d99e4d5d6ad51a8f9b3b316b35362d0e4d72138a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 17 Oct 2018 18:05:58 -0400 Subject: [PATCH 10/11] support for Half types in ROIAlign --- torchvision/csrc/cpu/ROIAlign_cpu.cpp | 4 ++-- torchvision/csrc/cuda/ROIAlign_cuda.cu | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/csrc/cpu/ROIAlign_cpu.cpp b/torchvision/csrc/cpu/ROIAlign_cpu.cpp index 142843e60c8..295aa8415f2 100644 --- a/torchvision/csrc/cpu/ROIAlign_cpu.cpp +++ b/torchvision/csrc/cpu/ROIAlign_cpu.cpp @@ -369,7 +369,7 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, if (output.numel() == 0) return output; - AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] { ROIAlignForward( output_size, input.data(), @@ -414,7 +414,7 @@ at::Tensor ROIAlign_backward_cpu(const at::Tensor& grad, int h_stride = grad.stride(2); int w_stride = grad.stride(3); - AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_forward", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_forward", [&] { ROIAlignBackward( grad.numel(), grad.data(), diff --git a/torchvision/csrc/cuda/ROIAlign_cuda.cu b/torchvision/csrc/cuda/ROIAlign_cuda.cu index d4d9154a621..c21e5538997 100644 --- a/torchvision/csrc/cuda/ROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/ROIAlign_cuda.cu @@ -273,7 +273,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, return output; } - AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] { RoIAlignForward<<>>( output_size, input.data(), @@ -323,7 +323,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, int h_stride = grad.stride(2); int w_stride = grad.stride(3); - AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_backward", [&] { RoIAlignBackward<<>>( grad.numel(), grad.data(), From ae3c4531fb3b79c7e3e443a1311cccd8c2fbdaa0 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 17 Oct 2018 18:12:00 -0400 Subject: [PATCH 11/11] gradcheck tests for ROIAlign --- test/test_layers.py | 40 +++++++++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/test/test_layers.py b/test/test_layers.py index 390f5701437..03b6a589f8e 100644 --- a/test/test_layers.py +++ b/test/test_layers.py @@ -220,7 +220,7 @@ def setup_class(cls): [0.54400194, 0.5265246, 0.22381854, 0.3929715, 0.6757667], [0.32961223, 0.38482672, 0.68877804, 0.71822757, 0.711909], [0.561259, 0.71047884, 0.84651315, 0.8541089, 0.644432]]]], - dtype=cls.dtype) + dtype=cls.dtype) cls.x_grad = torch.tensor([[[[0.075625, 0.15125, 0.15124999, 0.15125002, 0.15812504, 0.15812503, 0.15124999, 0.15124999, 0.15125006, 0.0756249], [0.15125, 0.30250007, 0.3025, 0.30250007, 0.31625012, @@ -241,7 +241,7 @@ def setup_class(cls): 1.9925001, 1.8625001, 1.9925001, 0.93124974], [0.43562484, 0.9312497, 0.8712497, 0.9312497, 0.5181249, 0.5181248, 0.9312496, 0.8712497, 0.93124974, 0.43562466]]]], - dtype=cls.dtype) + dtype=cls.dtype) def test_roi_align_basic_cpu(self): device = torch.device('cpu') @@ -252,8 +252,8 @@ def 
test_roi_align_basic_cpu(self): pool_h, pool_w = (5, 5) roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) y = roi_align(x, single_roi) - - assert torch.equal(gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CPU' + + assert torch.allclose(gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CPU' def test_roi_align_cpu(self): device = torch.device('cpu') @@ -265,7 +265,7 @@ def test_roi_align_cpu(self): roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) y = roi_align(x, rois) - assert torch.equal(gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CPU' + assert torch.allclose(gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CPU' @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_roi_align_basic_cuda(self): @@ -277,7 +277,7 @@ def test_roi_align_basic_cuda(self): pool_h, pool_w = (5, 5) roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device) y = roi_align(x, single_roi) - + assert torch.allclose(gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CUDA' @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") @@ -309,9 +309,21 @@ def test_roi_align_gradient_cpu(self): y = roi_align(x, rois) s = y.sum() s.backward() - + assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for ROIAlign CPU' + def test_roi_align_gradcheck_cpu(self): + dtype = torch.float64 + device = torch.device('cpu') + m = layers.ROIAlign((5, 5), 0.5, 1).to(dtype=dtype, device=device) + x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True) + rois = self.rois.to(device=device, dtype=dtype) + + def func(input): + return m(input, rois) + + assert gradcheck(func, (x,)), 'gradcheck failed for ROIAlign CPU' + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_roi_align_gradient_cuda(self): """ @@ -332,5 +344,19 @@ def test_roi_align_gradient_cuda(self): assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for ROIAlign CUDA' + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_roi_align_gradcheck_cuda(self): + dtype = torch.float64 + device = torch.device('cuda') + m = layers.ROIAlign((5, 5), 0.5, 1).to(dtype=dtype, device=device) + x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True) + rois = self.rois.to(device=device, dtype=dtype) + + def func(input): + return m(input, rois) + + assert gradcheck(func, (x,)), 'gradcheck failed for ROIAlign CUDA' + + if __name__ == '__main__': unittest.main()
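
For reference, the forward pass that ROIAlignForward / RoIAlignForward implement can be sketched in a few lines of pure PyTorch: scale the box by spatial_scale, split it into pooled_h x pooled_w bins, bilinearly sample a fixed grid of points inside each bin, and average. This is a minimal, illustrative sketch rather than patch code: the function name and the single-image (C, H, W) / (x1, y1, x2, y2) conventions are assumptions, while the real kernels are batched and take the batch index as the first ROI column.

import math
import torch

def roi_align_reference(feat, box, pooled_h, pooled_w, spatial_scale, sampling_ratio):
    """feat: (C, H, W) feature map; box: (x1, y1, x2, y2) in input coordinates."""
    C, H, W = feat.shape
    x1, y1, x2, y2 = (float(v) * spatial_scale for v in box)
    roi_w = max(x2 - x1, 1.0)  # force malformed ROIs to be 1x1, as in the kernels
    roi_h = max(y2 - y1, 1.0)
    bin_h, bin_w = roi_h / pooled_h, roi_w / pooled_w
    grid_h = sampling_ratio if sampling_ratio > 0 else math.ceil(roi_h / pooled_h)
    grid_w = sampling_ratio if sampling_ratio > 0 else math.ceil(roi_w / pooled_w)

    def bilinear(y, x):
        # zero outside the feature map; clamp border samples, as in bilinear_interpolate
        if y < -1.0 or y > H or x < -1.0 or x > W:
            return feat.new_zeros(C)
        y, x = min(max(y, 0.0), H - 1.0), min(max(x, 0.0), W - 1.0)
        y_lo, x_lo = int(y), int(x)
        y_hi, x_hi = min(y_lo + 1, H - 1), min(x_lo + 1, W - 1)
        ly, lx = y - y_lo, x - x_lo
        return ((1 - ly) * (1 - lx) * feat[:, y_lo, x_lo] +
                (1 - ly) * lx * feat[:, y_lo, x_hi] +
                ly * (1 - lx) * feat[:, y_hi, x_lo] +
                ly * lx * feat[:, y_hi, x_hi])

    out = feat.new_zeros(C, pooled_h, pooled_w)
    for ph in range(pooled_h):
        for pw in range(pooled_w):
            acc = feat.new_zeros(C)
            for iy in range(grid_h):  # fixed sampling grid inside each bin
                y = y1 + ph * bin_h + (iy + 0.5) * bin_h / grid_h
                for ix in range(grid_w):
                    x = x1 + pw * bin_w + (ix + 0.5) * bin_w / grid_w
                    acc = acc + bilinear(y, x)
            out[:, ph, pw] = acc / (grid_h * grid_w)  # average over the sample grid
    return out

# e.g., mirroring the single-ROI test above (batch index stripped from the ROI):
# x = torch.rand(1, 1, 10, 10)
# y = roi_align_reference(x[0], (0, 0, 4, 4), 5, 5, spatial_scale=1.0, sampling_ratio=2)

The n_stride/c_stride/h_stride/w_stride arguments threaded through RoIAlignBackward above exist because grad_output is not guaranteed to be contiguous, so it cannot be indexed as if it were packed NCHW; offsets built from the tensor's actual strides read the right element regardless of layout. A quick sketch of the idea, again with assumed names rather than patch code:

g = torch.rand(2, 3, 5, 5).transpose(2, 3)    # a non-contiguous grad_output view
ns, cs, hs, ws = g.stride()
n, c, ph, pw = 1, 2, 3, 4
offset = n * ns + c * cs + ph * hs + pw * ws  # what the patched kernels compute
assert g.as_strided((1,), (1,), offset)[0] == g[n, c, ph, pw]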