Skip to content

Commit ae3c453

Browse files
committed
gradcheck tests for ROIAlign
1 parent d99e4d5 commit ae3c453

File tree

1 file changed

+33
-7
lines changed

1 file changed

+33
-7
lines changed

test/test_layers.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def setup_class(cls):
220220
[0.54400194, 0.5265246, 0.22381854, 0.3929715, 0.6757667],
221221
[0.32961223, 0.38482672, 0.68877804, 0.71822757, 0.711909],
222222
[0.561259, 0.71047884, 0.84651315, 0.8541089, 0.644432]]]],
223-
dtype=cls.dtype)
223+
dtype=cls.dtype)
224224

225225
cls.x_grad = torch.tensor([[[[0.075625, 0.15125, 0.15124999, 0.15125002, 0.15812504, 0.15812503, 0.15124999, 0.15124999, 0.15125006, 0.0756249],
226226
[0.15125, 0.30250007, 0.3025, 0.30250007, 0.31625012,
@@ -241,7 +241,7 @@ def setup_class(cls):
241241
1.9925001, 1.8625001, 1.9925001, 0.93124974],
242242
[0.43562484, 0.9312497, 0.8712497, 0.9312497, 0.5181249, 0.5181248,
243243
0.9312496, 0.8712497, 0.93124974, 0.43562466]]]],
244-
dtype=cls.dtype)
244+
dtype=cls.dtype)
245245

246246
def test_roi_align_basic_cpu(self):
247247
device = torch.device('cpu')
@@ -252,8 +252,8 @@ def test_roi_align_basic_cpu(self):
252252
pool_h, pool_w = (5, 5)
253253
roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
254254
y = roi_align(x, single_roi)
255-
256-
assert torch.equal(gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CPU'
255+
256+
assert torch.allclose(gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CPU'
257257

258258
def test_roi_align_cpu(self):
259259
device = torch.device('cpu')
@@ -265,7 +265,7 @@ def test_roi_align_cpu(self):
265265
roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
266266
y = roi_align(x, rois)
267267

268-
assert torch.equal(gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CPU'
268+
assert torch.allclose(gt_y_multiple, y), 'ROIAlign layer incorrect for multiple ROIs on CPU'
269269

270270
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
271271
def test_roi_align_basic_cuda(self):
@@ -277,7 +277,7 @@ def test_roi_align_basic_cuda(self):
277277
pool_h, pool_w = (5, 5)
278278
roi_align = layers.ROIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
279279
y = roi_align(x, single_roi)
280-
280+
281281
assert torch.allclose(gt_y_single, y), 'ROIAlign layer incorrect for single ROI on CUDA'
282282

283283
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
@@ -309,9 +309,21 @@ def test_roi_align_gradient_cpu(self):
309309
y = roi_align(x, rois)
310310
s = y.sum()
311311
s.backward()
312-
312+
313313
assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for ROIAlign CPU'
314314

315+
def test_roi_align_gradcheck_cpu(self):
316+
dtype = torch.float64
317+
device = torch.device('cpu')
318+
m = layers.ROIAlign((5, 5), 0.5, 1).to(dtype=dtype, device=device)
319+
x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True)
320+
rois = self.rois.to(device=device, dtype=dtype)
321+
322+
def func(input):
323+
return m(input, rois)
324+
325+
assert gradcheck(func, (x,)), 'gradcheck failed for ROIAlign CPU'
326+
315327
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
316328
def test_roi_align_gradient_cuda(self):
317329
"""
@@ -332,5 +344,19 @@ def test_roi_align_gradient_cuda(self):
332344

333345
assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for ROIAlign CUDA'
334346

347+
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
348+
def test_roi_align_gradcheck_cuda(self):
349+
dtype = torch.float64
350+
device = torch.device('cuda')
351+
m = layers.ROIAlign((5, 5), 0.5, 1).to(dtype=dtype, device=device)
352+
x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True)
353+
rois = self.rois.to(device=device, dtype=dtype)
354+
355+
def func(input):
356+
return m(input, rois)
357+
358+
assert gradcheck(func, (x,)), 'gradcheck failed for ROIAlign CUDA'
359+
360+
335361
if __name__ == '__main__':
336362
unittest.main()

0 commit comments

Comments
 (0)