Skip to content

Commit d3dc4a1

Browse files
committed
tests for ROIAlign layer
1 parent 64735ab commit d3dc4a1

File tree

1 file changed

+89
-29
lines changed

1 file changed

+89
-29
lines changed

test/test_layers.py

Lines changed: 89 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,13 @@ def setup_class(cls):
197197
cls.rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy)
198198
[0, 0, 5, 4, 9],
199199
[0, 5, 5, 9, 9]],
200-
dtype=torch.float32)
200+
dtype=cls.dtype)
201201

202202
cls.gt_y_single = torch.tensor([[[[0.41617328, 0.5040753, 0.25266218, 0.4296828, 0.29928464],
203203
[0.5210769, 0.57222337, 0.2524979, 0.32063985, 0.32635176],
204204
[0.73108256, 0.6114335, 0.62033176, 0.8188273, 0.5562218],
205205
[0.83115816, 0.70803946, 0.7084047, 0.74928707, 0.7769296],
206-
[0.54266506, 0.45964524, 0.5780159, 0.80522037, 0.7321807]]]])
206+
[0.54266506, 0.45964524, 0.5780159, 0.80522037, 0.7321807]]]], dtype=cls.dtype)
207207

208208
cls.gt_y_multiple = torch.tensor([[[[0.49311584, 0.35972416, 0.40843594, 0.3638034, 0.49751836],
209209
[0.70881474, 0.75481665, 0.5826779, 0.34767765, 0.46865487],
@@ -219,58 +219,118 @@ def setup_class(cls):
219219
[0.49006107, 0.42982674, 0.34184104, 0.15493104, 0.49633422],
220220
[0.54400194, 0.5265246, 0.22381854, 0.3929715, 0.6757667],
221221
[0.32961223, 0.38482672, 0.68877804, 0.71822757, 0.711909],
222-
[0.561259, 0.71047884, 0.84651315, 0.8541089, 0.644432]]]])
222+
[0.561259, 0.71047884, 0.84651315, 0.8541089, 0.644432]]]],
223+
dtype=cls.dtype)
224+
225+
cls.x_grad = torch.tensor([[[[0.075625, 0.15125, 0.15124999, 0.15125002, 0.15812504, 0.15812503, 0.15124999, 0.15124999, 0.15125006, 0.0756249],
226+
[0.15125, 0.30250007, 0.3025, 0.30250007, 0.31625012,
227+
0.31625003, 0.3025, 0.3025, 0.30250013, 0.1512498],
228+
[0.15124999, 0.3025, 0.30249995, 0.3025, 0.31625006,
229+
0.31625, 0.30249995, 0.30249995, 0.30250007, 0.15124978],
230+
[0.15125002, 0.30250007, 0.3025, 0.30250007, 0.31625012,
231+
0.3162501, 0.3025, 0.3025, 0.30250013, 0.15124981],
232+
[0.15812504, 0.31625012, 0.31625006, 0.31625012, 0.33062524,
233+
0.3306251, 0.31625006, 0.31625006, 0.3162502, 0.15812483],
234+
[0.5181251, 1.0962502, 1.0362502, 1.0962503, 0.69062525, 0.6906252,
235+
1.0962502, 1.0362502, 1.0962503, 0.5181248],
236+
[0.93125, 1.9925, 1.8624997, 1.9925, 1.0962502, 1.0962502,
237+
1.9925, 1.8624998, 1.9925, 0.9312496],
238+
[0.8712501, 1.8625, 1.7425002, 1.8625001, 1.0362502, 1.0362502,
239+
1.8625, 1.7425001, 1.8625002, 0.8712497],
240+
[0.93125004, 1.9925, 1.8625002, 1.9925, 1.0962503, 1.0962503,
241+
1.9925001, 1.8625001, 1.9925001, 0.93124974],
242+
[0.43562484, 0.9312497, 0.8712497, 0.9312497, 0.5181249, 0.5181248,
243+
0.9312496, 0.8712497, 0.93124974, 0.43562466]]]],
244+
dtype=cls.dtype)
223245

224246
def test_roi_align_basic_cpu(self):
225247
device = torch.device('cpu')
226-
self.x = self.x.to(device)
227-
self.single_roi = self.single_roi.to(device)
228-
self.gt_y_multiple = self.gt_y_multiple.to(device)
248+
x = self.x.to(device)
249+
single_roi = self.single_roi.to(device)
250+
gt_y_single = self.gt_y_single.to(device)
229251

230252
pool_h, pool_w = (5, 5)
231-
roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2.0)
232-
y = roi_align(self.x, self.single_roi)
233-
234-
assert torch.equal(self.gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CPU'
253+
roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
254+
y = roi_align(x, single_roi)
255+
256+
assert torch.equal(gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CPU'
235257

236258
def test_roi_align_cpu(self):
237259
device = torch.device('cpu')
238-
self.x = self.x.to(device)
239-
self.rois = self.rois.to(device)
240-
self.gt_y_multiple = self.gt_y_multiple.to(device)
260+
x = self.x.to(device)
261+
rois = self.rois.to(device)
262+
gt_y_multiple = self.gt_y_multiple.to(device)
241263

242264
pool_h, pool_w = (5, 5)
243-
roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2.0)
244-
y = roi_align(self.x, self.rois)
265+
roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
266+
y = roi_align(x, rois)
245267

246-
assert torch.equal(self.gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CPU'
268+
assert torch.equal(gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CPU'
247269

248270
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
249271
def test_roi_align_basic_cuda(self):
250272
device = torch.device('cuda')
251-
self.x = self.x.to(device)
252-
self.single_roi = self.single_roi.to(device)
253-
self.gt_y_single = self.gt_y_single.to(device)
273+
x = self.x.to(device)
274+
single_roi = self.single_roi.to(device)
275+
gt_y_single = self.gt_y_single.to(device)
254276

255277
pool_h, pool_w = (5, 5)
256-
roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2.0)
257-
y = roi_align(self.x, self.single_roi)
258-
259-
assert torch.allclose(self.gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CUDA'
278+
roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
279+
y = roi_align(x, single_roi)
280+
281+
assert torch.allclose(gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CUDA'
260282

261283
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
262284
def test_roi_align_cuda(self):
263285
device = torch.device('cuda')
264-
self.x = self.x.to(device)
265-
self.rois = self.rois.to(device)
266-
self.gt_y_multiple = self.gt_y_multiple.to(device)
286+
x = self.x.to(device)
287+
rois = self.rois.to(device)
288+
gt_y_multiple = self.gt_y_multiple.to(device)
289+
290+
pool_h, pool_w = (5, 5)
291+
roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
292+
y = roi_align(x, rois)
293+
294+
assert torch.allclose(gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CUDA'
295+
296+
def test_roi_align_gradient_cpu(self):
297+
"""
298+
Compute gradients for ROIAlign with multiple bounding boxes on CPU
299+
"""
300+
device = torch.device('cpu')
301+
pool_h, pool_w = (5, 5)
302+
roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
303+
304+
x = self.x.to(device).clone()
305+
rois = self.rois.to(device)
306+
gt_grad = self.x_grad.to(device)
267307

308+
x.requires_grad = True
309+
y = roi_align(x, rois)
310+
s = y.sum()
311+
s.backward()
312+
313+
assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for ROIAlign CPU'
314+
315+
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
316+
def test_roi_align_gradient_cuda(self):
317+
"""
318+
Compute gradients for ROIAlign with multiple bounding boxes on the GPU
319+
"""
320+
device = torch.device('cuda')
268321
pool_h, pool_w = (5, 5)
269-
roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2.0)
270-
y = roi_align(self.x, self.rois)
322+
roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
323+
324+
x = self.x.to(device).clone()
325+
rois = self.rois.to(device)
326+
gt_grad = self.x_grad.to(device)
271327

272-
assert torch.allclose(self.gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CUDA'
328+
x.requires_grad = True
329+
y = roi_align(x, rois)
330+
s = y.sum()
331+
s.backward()
273332

333+
assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for ROIAlign CUDA'
274334

275335
if __name__ == '__main__':
276336
unittest.main()

0 commit comments

Comments
 (0)