Commit 563604c

varunagrawal authored and fmassa committed
Support for ROI Pooling (#592)
* ROI Pooling with tests. Fix for cuda context in ROI Align.
* renamed bottom and top to follow torch conventions
1 parent 366f493 commit 563604c

File tree: 10 files changed, +466 -164 lines changed

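The tests added in this commit treat the new layer as a module that takes a feature map and a set of regions of interest whose rows are (batch_index, x1, y1, x2, y2). A minimal CPU usage sketch based on those tests, assuming torchvision is built from this commit and that the layer is reached through the test file's "from torchvision import layers" import:

import torch
from torchvision import layers  # assumes the compiled extension from this commit

x = torch.rand(1, 1, 10, 10)            # N x C x H x W feature map
rois = torch.tensor([[0, 0, 0, 4, 4]],  # rows are (batch_index, x1, y1, x2, y2)
                    dtype=torch.float32)
roi_pool = layers.ROIPool((5, 5), 1)    # output size (5, 5), spatial_scale 1
y = roi_pool(x, rois)                   # y.shape == (1, 1, 5, 5)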

setup.py

Lines changed: 6 additions & 2 deletions
@@ -8,7 +8,7 @@
 import subprocess
 import glob
 
-import torch.cuda
+import torch
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
 
 
@@ -84,12 +84,15 @@ def get_extensions():
     sources = main_file + source_cpu
     extension = CppExtension
 
+    extra_compile_args = {'cxx': []}
     define_macros = []
 
     if torch.cuda.is_available() and CUDA_HOME is not None:
         extension = CUDAExtension
         sources += source_cuda
         define_macros += [('WITH_CUDA', None)]
+        extra_compile_args['nvcc'] = ['-DCUDA_HAS_FP16=1', '-D__CUDA_NO_HALF_OPERATORS__',
+                                      '-D__CUDA_NO_HALF_CONVERSIONS__', '-D__CUDA_NO_HALF2_OPERATORS__']
 
     sources = [os.path.join(extensions_dir, s) for s in sources]
 
@@ -100,7 +103,8 @@ def get_extensions():
             'torchvision._C',
             sources,
             include_dirs=include_dirs,
-            define_macros=define_macros
+            define_macros=define_macros,
+            extra_compile_args=extra_compile_args,
         )
     ]
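
The change above selects CUDAExtension only when a CUDA toolchain is present and then forwards the nvcc half-precision guards via extra_compile_args. A minimal, hypothetical standalone setup.py following the same pattern (module name and source paths here are illustrative, not taken from the commit):

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME

sources = ['csrc/vision.cpp']           # hypothetical C++ sources
extra_compile_args = {'cxx': []}
define_macros = []
extension = CppExtension                # CPU-only fallback

if torch.cuda.is_available() and CUDA_HOME is not None:
    extension = CUDAExtension
    sources += ['csrc/cuda/ROIPool_cuda.cu']   # hypothetical CUDA source
    define_macros += [('WITH_CUDA', None)]
    extra_compile_args['nvcc'] = ['-DCUDA_HAS_FP16=1', '-D__CUDA_NO_HALF_OPERATORS__',
                                  '-D__CUDA_NO_HALF_CONVERSIONS__', '-D__CUDA_NO_HALF2_OPERATORS__']

setup(
    name='example_ext',
    ext_modules=[
        extension('example_ext._C', sources,
                  define_macros=define_macros,
                  extra_compile_args=extra_compile_args)
    ],
    cmdclass={'build_ext': BuildExtension},
)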

test/test_layers.py

Lines changed: 164 additions & 45 deletions
@@ -7,65 +7,184 @@
 import unittest
 
 
-class Tester(unittest.TestCase):
+class ROIPoolTester(unittest.TestCase):
 
-    def test_roi_align(self):
-        outputs = []
+    def test_roi_pool_basic_cpu(self):
         dtype = torch.float32
-        x = torch.rand(1, 1, 10, 10, dtype=dtype)
-        rois = torch.tensor([
-            [0, 0, 0, 10, 10],
-            [0, 0, 5, 5, 10],
-            [0, 5, 5, 10, 10]], dtype=dtype)
+        device = torch.device('cpu')
+        x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device)
+        rois = torch.tensor([[0, 0, 0, 4, 4]],  # format is (xyxy)
+                            dtype=dtype, device=device)
+
+        pool_h, pool_w = (5, 5)
+        roi_pool = layers.ROIPool((pool_h, pool_w), 1)
+        y = roi_pool(x, rois)
+
+        gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w)
 
-        for device in ['cpu', 'cuda']:
-            device = torch.device(device)
-            x_n = x.to(device)
-            rois_n = rois.to(device)
-            output = layers.roi_align(x_n, rois_n, (5, 5), 0.5, 1).to('cpu')
-            outputs.append(output)
+        for n in range(0, gt_y.size(0)):
+            start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1
+            start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1
+            roi_x = x[:, :, start_h:end_h, start_w:end_w]
+            bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
+            for j in range(0, pool_h):
+                for i in range(0, pool_w):
+                    gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w])
 
-        assert (outputs[0] - outputs[1]).abs().max() < 1e-6
+        assert torch.equal(gt_y, y), 'ROIPool layer incorrect'
 
-    def test_roi_align_gradient(self):
-        dtype = torch.float64
+    def test_roi_pool_cpu(self):
+        dtype = torch.float32
+        device = torch.device('cpu')
+        x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device)
+        rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (xyxy)
+                             [0, 0, 5, 4, 9],
+                             [0, 5, 5, 9, 9],
+                             [1, 0, 0, 9, 9]],
+                            dtype=dtype, device=device)
+
+        pool_h, pool_w = (5, 5)
+        roi_pool = layers.ROIPool((pool_h, pool_w), 1)
+        y = roi_pool(x, rois)
+
+        gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device)
+        for n in range(0, gt_y.size(0)):
+            for r, roi in enumerate(rois):
+                if roi[0] == n:
+                    start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
+                    start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
+                    roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w]
+                    bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
+                    for j in range(0, pool_h):
+                        for i in range(0, pool_w):
+                            gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i],
+                                                         torch.max(roi_x[:, :,
+                                                                         j * bin_h:(j + 1) * bin_h,
+                                                                         i * bin_w:(i + 1) * bin_w])
+                                                         )
+
+        assert torch.equal(gt_y, y), 'ROIPool layer incorrect'
+
+    def test_roi_pool_gradient_cpu(self):
+        dtype = torch.float32
+        device = torch.device('cpu')
+        layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device)
+        x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True)
+        cx = torch.ones(1, 1, 10, 10, dtype=dtype, requires_grad=True).cuda()
+        rois = torch.tensor([
+            [0, 0, 0, 9, 9],
+            [0, 0, 5, 4, 9],
+            [0, 0, 0, 4, 4]],
+            dtype=dtype, device=device)
+
+        y = layer(x, rois)
+        s = y.sum()
+        s.backward()
+
+        gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device)
+
+        assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool'
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
+    def test_roi_pool_basic_gpu(self):
+        dtype = torch.float32
         device = torch.device('cuda')
-        m = layers.ROIAlign((5, 5), 0.5, 1).to(dtype=dtype, device=device)
         x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device)
-        rois = torch.tensor([
-            [0, 0, 0, 10, 10],
-            [0, 0, 5, 5, 10],
-            [0, 5, 5, 10, 10]], dtype=dtype, device=device)
+        rois = torch.tensor([[0, 0, 0, 4, 4]],  # format is (xyxy)
+                            dtype=dtype, device=device)
 
-        def func(input):
-            return m(input, rois)
+        pool_h, pool_w = (5, 5)
+        roi_pool = layers.ROIPool((pool_h, pool_w), 1)
+        y = roi_pool(x, rois)
+
+        gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w)
+
+        for n in range(0, gt_y.size(0)):
+            start_h, end_h = int(rois[n, 2].item()), int(rois[n, 4].item()) + 1
+            start_w, end_w = int(rois[n, 1].item()), int(rois[n, 3].item()) + 1
+            roi_x = x[:, :, start_h:end_h, start_w:end_w]
+            bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
+            for j in range(0, pool_h):
+                for i in range(0, pool_w):
+                    gt_y[n, :, j, i] = torch.max(roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w])
 
-        assert gradcheck(func, (x,)), 'gradcheck failed for roi_align'
+        assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect'
 
-    def test_roi_pool_gradient(self):
-        dtype = torch.float64
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
+    def test_roi_pool_gpu(self):
+        dtype = torch.float32
+        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+        x = torch.rand(2, 1, 10, 10, dtype=dtype, device=device)
+        rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (xyxy)
+                             [0, 0, 5, 4, 9],
+                             [0, 5, 5, 9, 9],
+                             [1, 0, 0, 9, 9]],
+                            dtype=dtype, device=device)
+
+        pool_h, pool_w = (5, 5)
+        roi_pool = layers.ROIPool((pool_h, pool_w), 1)
+        y = roi_pool(x, rois)
+
+        gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, device=device)
+        for n in range(0, gt_y.size(0)):
+            for r, roi in enumerate(rois):
+                if roi[0] == n:
+                    start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
+                    start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
+                    roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w]
+                    bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
+                    for j in range(0, pool_h):
+                        for i in range(0, pool_w):
+                            gt_y[r, :, j, i] = torch.max(gt_y[r, :, j, i],
+                                                         torch.max(roi_x[:, :,
+                                                                         j * bin_h:(j + 1) * bin_h,
+                                                                         i * bin_w:(i + 1) * bin_w])
+                                                         )
+
+        assert torch.equal(gt_y.cuda(), y), 'ROIPool layer incorrect'
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
+    def test_roi_pool_gradient_gpu(self):
+        dtype = torch.float32
         device = torch.device('cuda')
-        m = layers.ROIPool((5, 5), 0.5).to(dtype=dtype, device=device)
-        x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device)
+        layer = layers.ROIPool((5, 5), 1).to(dtype=dtype, device=device)
+        x = torch.ones(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True)
         rois = torch.tensor([
-            [0, 0, 0, 10, 10],
-            [0, 0, 5, 5, 10],
-            [0, 5, 5, 10, 10]], dtype=dtype, device=device)
+            [0, 0, 0, 9, 9],
+            [0, 0, 5, 4, 9],
+            [0, 0, 0, 4, 4]],
+            dtype=dtype, device=device)
 
         def func(input):
-            return m(input, rois)
-
-        assert gradcheck(func, (x,)), 'gradcheck failed for roi_pool'
-
-    def test_nms(self):
-        boxes = torch.tensor([
-            [0, 0, 100, 100],
-            [2, 2, 98, 98],
-            [50, 50, 200, 200],
-            [50, 50, 200, 200]], dtype=torch.float32)
-        scores = torch.tensor([1, 2, 0.5, 1], dtype=torch.float32)
-        keep = layers.nms(boxes, scores, 0.5)
-        assert keep.tolist() == [1, 3]
+            return layer(input, rois)
+
+        x.requires_grad = True
+        y = layer(x, rois)
+        # print(argmax, argmax.shape)
+        s = y.sum()
+        s.backward()
+        gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
+                                  [2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
+                                  [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]], device=device)
+
+        assert torch.equal(x.grad, gt_grad), 'gradient incorrect for roi_pool'
+
 
 if __name__ == '__main__':
     unittest.main()
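
The tests above check ROIPool against a brute-force reference: each (xyxy) ROI is split into pool_h x pool_w bins and the maximum of each bin is taken. A small self-contained sketch of that reference computation, using only plain torch (the function name roi_pool_reference is ours, not part of the commit):

import torch

def roi_pool_reference(x, rois, pool_h, pool_w):
    # rois rows are (batch_index, x1, y1, x2, y2) at spatial_scale == 1
    gt_y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w)
    for r, roi in enumerate(rois):
        b = int(roi[0].item())
        start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
        start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
        roi_x = x[b:b + 1, :, start_h:end_h, start_w:end_w]
        bin_h, bin_w = roi_x.size(2) // pool_h, roi_x.size(3) // pool_w
        for j in range(pool_h):
            for i in range(pool_w):
                # max over each spatial bin of the cropped ROI
                gt_y[r, :, j, i] = torch.max(
                    roi_x[:, :, j * bin_h:(j + 1) * bin_h, i * bin_w:(i + 1) * bin_w])
    return gt_y

x = torch.rand(2, 1, 10, 10)
rois = torch.tensor([[0, 0, 0, 9, 9], [1, 0, 0, 4, 4]], dtype=torch.float32)
print(roi_pool_reference(x, rois, 5, 5).shape)  # torch.Size([2, 1, 5, 5])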

torchvision/csrc/ROIPool.h

Lines changed: 30 additions & 31 deletions
@@ -6,42 +6,41 @@
 #include "cuda/vision.h"
 #endif
 
-
-std::tuple<at::Tensor, at::Tensor> ROIPool_forward(const at::Tensor& input,
-                                                   const at::Tensor& rois,
-                                                   const float spatial_scale,
-                                                   const int pooled_height,
-                                                   const int pooled_width) {
-  if (input.type().is_cuda()) {
+std::tuple<at::Tensor, at::Tensor> ROIPool_forward(const at::Tensor &input,
+                                                   const at::Tensor &rois,
+                                                   const float spatial_scale,
+                                                   const int pooled_height,
+                                                   const int pooled_width)
+{
+  if (input.type().is_cuda())
+  {
 #ifdef WITH_CUDA
-    return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width);
+    return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width);
 #else
-    AT_ERROR("Not compiled with GPU support");
+    AT_ERROR("Not compiled with GPU support");
 #endif
-  }
-  AT_ERROR("Not implemented on the CPU");
+  }
+  return ROIPool_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width);
 }
 
-at::Tensor ROIPool_backward(const at::Tensor& grad,
-                            const at::Tensor& input,
-                            const at::Tensor& rois,
-                            const at::Tensor& argmax,
-                            const float spatial_scale,
-                            const int pooled_height,
-                            const int pooled_width,
-                            const int batch_size,
-                            const int channels,
-                            const int height,
-                            const int width) {
-  if (grad.type().is_cuda()) {
+at::Tensor ROIPool_backward(const at::Tensor &grad,
+                            const at::Tensor &rois,
+                            const at::Tensor &argmax,
+                            const float spatial_scale,
+                            const int pooled_height,
+                            const int pooled_width,
+                            const int batch_size,
+                            const int channels,
+                            const int height,
+                            const int width)
+{
+  if (grad.type().is_cuda())
+  {
 #ifdef WITH_CUDA
-    return ROIPool_backward_cuda(grad, input, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
+    return ROIPool_backward_cuda(grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
 #else
-    AT_ERROR("Not compiled with GPU support");
+    AT_ERROR("Not compiled with GPU support");
 #endif
-  }
-  AT_ERROR("Not implemented on the CPU");
-}
-
-
-
+  }
+  return ROIPool_backward_cpu(grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
+}