@@ -220,7 +220,7 @@ def setup_class(cls):
220
220
[0.54400194 , 0.5265246 , 0.22381854 , 0.3929715 , 0.6757667 ],
221
221
[0.32961223 , 0.38482672 , 0.68877804 , 0.71822757 , 0.711909 ],
222
222
[0.561259 , 0.71047884 , 0.84651315 , 0.8541089 , 0.644432 ]]]],
223
- dtype = cls .dtype )
223
+ dtype = cls .dtype )
224
224
225
225
cls .x_grad = torch .tensor ([[[[0.075625 , 0.15125 , 0.15124999 , 0.15125002 , 0.15812504 , 0.15812503 , 0.15124999 , 0.15124999 , 0.15125006 , 0.0756249 ],
226
226
[0.15125 , 0.30250007 , 0.3025 , 0.30250007 , 0.31625012 ,
@@ -241,7 +241,7 @@ def setup_class(cls):
241
241
1.9925001 , 1.8625001 , 1.9925001 , 0.93124974 ],
242
242
[0.43562484 , 0.9312497 , 0.8712497 , 0.9312497 , 0.5181249 , 0.5181248 ,
243
243
0.9312496 , 0.8712497 , 0.93124974 , 0.43562466 ]]]],
244
- dtype = cls .dtype )
244
+ dtype = cls .dtype )
245
245
246
246
def test_roi_align_basic_cpu (self ):
247
247
device = torch .device ('cpu' )
@@ -252,8 +252,8 @@ def test_roi_align_basic_cpu(self):
252
252
pool_h , pool_w = (5 , 5 )
253
253
roi_align = layers .ROIAlign ((pool_h , pool_w ), spatial_scale = 1 , sampling_ratio = 2 ).to (device = device )
254
254
y = roi_align (x , single_roi )
255
-
256
- assert torch .equal (gt_y_single , y ), 'ROIAlign layer incorrect for single ROI on CPU'
255
+
256
+ assert torch .allclose (gt_y_single , y ), 'ROIAlign layer incorrect for single ROI on CPU'
257
257
258
258
def test_roi_align_cpu (self ):
259
259
device = torch .device ('cpu' )
@@ -265,7 +265,7 @@ def test_roi_align_cpu(self):
265
265
roi_align = layers .ROIAlign ((pool_h , pool_w ), spatial_scale = 1 , sampling_ratio = 2 ).to (device = device )
266
266
y = roi_align (x , rois )
267
267
268
- assert torch .equal (gt_y_multiple , y ), 'ROIAlign layer incorrect for multiple ROIs on CPU'
268
+ assert torch .allclose (gt_y_multiple , y ), 'ROIAlign layer incorrect for multiple ROIs on CPU'
269
269
270
270
@unittest .skipIf (not torch .cuda .is_available (), "CUDA unavailable" )
271
271
def test_roi_align_basic_cuda (self ):
@@ -277,7 +277,7 @@ def test_roi_align_basic_cuda(self):
277
277
pool_h , pool_w = (5 , 5 )
278
278
roi_align = layers .ROIAlign ((pool_h , pool_w ), spatial_scale = 1 , sampling_ratio = 2 ).to (device = device )
279
279
y = roi_align (x , single_roi )
280
-
280
+
281
281
assert torch .allclose (gt_y_single , y ), 'ROIAlign layer incorrect for single ROI on CUDA'
282
282
283
283
@unittest .skipIf (not torch .cuda .is_available (), "CUDA unavailable" )
@@ -309,9 +309,21 @@ def test_roi_align_gradient_cpu(self):
309
309
y = roi_align (x , rois )
310
310
s = y .sum ()
311
311
s .backward ()
312
-
312
+
313
313
assert torch .allclose (x .grad , gt_grad ), 'gradient incorrect for ROIAlign CPU'
314
314
315
+ def test_roi_align_gradcheck_cpu (self ):
316
+ dtype = torch .float64
317
+ device = torch .device ('cpu' )
318
+ m = layers .ROIAlign ((5 , 5 ), 0.5 , 1 ).to (dtype = dtype , device = device )
319
+ x = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = True )
320
+ rois = self .rois .to (device = device , dtype = dtype )
321
+
322
+ def func (input ):
323
+ return m (input , rois )
324
+
325
+ assert gradcheck (func , (x ,)), 'gradcheck failed for ROIAlign CPU'
326
+
315
327
@unittest .skipIf (not torch .cuda .is_available (), "CUDA unavailable" )
316
328
def test_roi_align_gradient_cuda (self ):
317
329
"""
@@ -332,5 +344,19 @@ def test_roi_align_gradient_cuda(self):
332
344
333
345
assert torch .allclose (x .grad , gt_grad ), 'gradient incorrect for ROIAlign CUDA'
334
346
347
+ @unittest .skipIf (not torch .cuda .is_available (), "CUDA unavailable" )
348
+ def test_roi_align_gradcheck_cuda (self ):
349
+ dtype = torch .float64
350
+ device = torch .device ('cuda' )
351
+ m = layers .ROIAlign ((5 , 5 ), 0.5 , 1 ).to (dtype = dtype , device = device )
352
+ x = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = True )
353
+ rois = self .rois .to (device = device , dtype = dtype )
354
+
355
+ def func (input ):
356
+ return m (input , rois )
357
+
358
+ assert gradcheck (func , (x ,)), 'gradcheck failed for ROIAlign CUDA'
359
+
360
+
335
361
if __name__ == '__main__' :
336
362
unittest .main ()
0 commit comments