@@ -197,13 +197,13 @@ def setup_class(cls):
197
197
cls .rois = torch .tensor ([[0 , 0 , 0 , 9 , 9 ], # format is (xyxy)
198
198
[0 , 0 , 5 , 4 , 9 ],
199
199
[0 , 5 , 5 , 9 , 9 ]],
200
- dtype = torch . float32 )
200
+ dtype = cls . dtype )
201
201
202
202
cls .gt_y_single = torch .tensor ([[[[0.41617328 , 0.5040753 , 0.25266218 , 0.4296828 , 0.29928464 ],
203
203
[0.5210769 , 0.57222337 , 0.2524979 , 0.32063985 , 0.32635176 ],
204
204
[0.73108256 , 0.6114335 , 0.62033176 , 0.8188273 , 0.5562218 ],
205
205
[0.83115816 , 0.70803946 , 0.7084047 , 0.74928707 , 0.7769296 ],
206
- [0.54266506 , 0.45964524 , 0.5780159 , 0.80522037 , 0.7321807 ]]]])
206
+ [0.54266506 , 0.45964524 , 0.5780159 , 0.80522037 , 0.7321807 ]]]], dtype = cls . dtype )
207
207
208
208
cls .gt_y_multiple = torch .tensor ([[[[0.49311584 , 0.35972416 , 0.40843594 , 0.3638034 , 0.49751836 ],
209
209
[0.70881474 , 0.75481665 , 0.5826779 , 0.34767765 , 0.46865487 ],
@@ -219,58 +219,118 @@ def setup_class(cls):
219
219
[0.49006107 , 0.42982674 , 0.34184104 , 0.15493104 , 0.49633422 ],
220
220
[0.54400194 , 0.5265246 , 0.22381854 , 0.3929715 , 0.6757667 ],
221
221
[0.32961223 , 0.38482672 , 0.68877804 , 0.71822757 , 0.711909 ],
222
- [0.561259 , 0.71047884 , 0.84651315 , 0.8541089 , 0.644432 ]]]])
222
+ [0.561259 , 0.71047884 , 0.84651315 , 0.8541089 , 0.644432 ]]]],
223
+ dtype = cls .dtype )
224
+
225
+ cls .x_grad = torch .tensor ([[[[0.075625 , 0.15125 , 0.15124999 , 0.15125002 , 0.15812504 , 0.15812503 , 0.15124999 , 0.15124999 , 0.15125006 , 0.0756249 ],
226
+ [0.15125 , 0.30250007 , 0.3025 , 0.30250007 , 0.31625012 ,
227
+ 0.31625003 , 0.3025 , 0.3025 , 0.30250013 , 0.1512498 ],
228
+ [0.15124999 , 0.3025 , 0.30249995 , 0.3025 , 0.31625006 ,
229
+ 0.31625 , 0.30249995 , 0.30249995 , 0.30250007 , 0.15124978 ],
230
+ [0.15125002 , 0.30250007 , 0.3025 , 0.30250007 , 0.31625012 ,
231
+ 0.3162501 , 0.3025 , 0.3025 , 0.30250013 , 0.15124981 ],
232
+ [0.15812504 , 0.31625012 , 0.31625006 , 0.31625012 , 0.33062524 ,
233
+ 0.3306251 , 0.31625006 , 0.31625006 , 0.3162502 , 0.15812483 ],
234
+ [0.5181251 , 1.0962502 , 1.0362502 , 1.0962503 , 0.69062525 , 0.6906252 ,
235
+ 1.0962502 , 1.0362502 , 1.0962503 , 0.5181248 ],
236
+ [0.93125 , 1.9925 , 1.8624997 , 1.9925 , 1.0962502 , 1.0962502 ,
237
+ 1.9925 , 1.8624998 , 1.9925 , 0.9312496 ],
238
+ [0.8712501 , 1.8625 , 1.7425002 , 1.8625001 , 1.0362502 , 1.0362502 ,
239
+ 1.8625 , 1.7425001 , 1.8625002 , 0.8712497 ],
240
+ [0.93125004 , 1.9925 , 1.8625002 , 1.9925 , 1.0962503 , 1.0962503 ,
241
+ 1.9925001 , 1.8625001 , 1.9925001 , 0.93124974 ],
242
+ [0.43562484 , 0.9312497 , 0.8712497 , 0.9312497 , 0.5181249 , 0.5181248 ,
243
+ 0.9312496 , 0.8712497 , 0.93124974 , 0.43562466 ]]]],
244
+ dtype = cls .dtype )
223
245
224
246
def test_roi_align_basic_cpu (self ):
225
247
device = torch .device ('cpu' )
226
- self . x = self .x .to (device )
227
- self . single_roi = self .single_roi .to (device )
228
- self . gt_y_multiple = self .gt_y_multiple .to (device )
248
+ x = self .x .to (device )
249
+ single_roi = self .single_roi .to (device )
250
+ gt_y_single = self .gt_y_single .to (device )
229
251
230
252
pool_h , pool_w = (5 , 5 )
231
- roi_align = layers .ROIAlign ((pool_h , pool_w ), spatial_scale = 1 , sampling_ratio = 2.0 )
232
- y = roi_align (self . x , self . single_roi )
233
-
234
- assert torch .equal (self . gt_y_single , y ), 'ROIAlign layer incorrect for single ROI on CPU'
253
+ roi_align = layers .ROIAlign ((pool_h , pool_w ), spatial_scale = 1 , sampling_ratio = 2 ). to ( device = device )
254
+ y = roi_align (x , single_roi )
255
+
256
+ assert torch .equal (gt_y_single , y ), 'ROIAlign layer incorrect for single ROI on CPU'
235
257
236
258
def test_roi_align_cpu (self ):
237
259
device = torch .device ('cpu' )
238
- self . x = self .x .to (device )
239
- self . rois = self .rois .to (device )
240
- self . gt_y_multiple = self .gt_y_multiple .to (device )
260
+ x = self .x .to (device )
261
+ rois = self .rois .to (device )
262
+ gt_y_multiple = self .gt_y_multiple .to (device )
241
263
242
264
pool_h , pool_w = (5 , 5 )
243
- roi_align = layers .ROIAlign ((pool_h , pool_w ), spatial_scale = 1 , sampling_ratio = 2.0 )
244
- y = roi_align (self . x , self . rois )
265
+ roi_align = layers .ROIAlign ((pool_h , pool_w ), spatial_scale = 1 , sampling_ratio = 2 ). to ( device = device )
266
+ y = roi_align (x , rois )
245
267
246
- assert torch .equal (self . gt_y_multiple , y ), 'ROIAlign layer incorrect for multiple ROIs on CPU'
268
+ assert torch .equal (gt_y_multiple , y ), 'ROIAlign layer incorrect for multiple ROIs on CPU'
247
269
248
270
@unittest .skipIf (not torch .cuda .is_available (), "CUDA unavailable" )
249
271
def test_roi_align_basic_cuda (self ):
250
272
device = torch .device ('cuda' )
251
- self . x = self .x .to (device )
252
- self . single_roi = self .single_roi .to (device )
253
- self . gt_y_single = self .gt_y_single .to (device )
273
+ x = self .x .to (device )
274
+ single_roi = self .single_roi .to (device )
275
+ gt_y_single = self .gt_y_single .to (device )
254
276
255
277
pool_h , pool_w = (5 , 5 )
256
- roi_align = layers .ROIAlign ((pool_h , pool_w ), spatial_scale = 1 , sampling_ratio = 2.0 )
257
- y = roi_align (self . x , self . single_roi )
258
-
259
- assert torch .allclose (self . gt_y_single , y ), 'ROIAlign layer incorrect for single ROI on CUDA'
278
+ roi_align = layers .ROIAlign ((pool_h , pool_w ), spatial_scale = 1 , sampling_ratio = 2 ). to ( device = device )
279
+ y = roi_align (x , single_roi )
280
+
281
+ assert torch .allclose (gt_y_single , y ), 'ROIAlign layer incorrect for single ROI on CUDA'
260
282
261
283
@unittest .skipIf (not torch .cuda .is_available (), "CUDA unavailable" )
262
284
def test_roi_align_cuda (self ):
263
285
device = torch .device ('cuda' )
264
- self .x = self .x .to (device )
265
- self .rois = self .rois .to (device )
266
- self .gt_y_multiple = self .gt_y_multiple .to (device )
286
+ x = self .x .to (device )
287
+ rois = self .rois .to (device )
288
+ gt_y_multiple = self .gt_y_multiple .to (device )
289
+
290
+ pool_h , pool_w = (5 , 5 )
291
+ roi_align = layers .ROIAlign ((pool_h , pool_w ), spatial_scale = 1 , sampling_ratio = 2 ).to (device = device )
292
+ y = roi_align (x , rois )
293
+
294
+ assert torch .allclose (gt_y_multiple , y ), 'ROIAlign layer incorrect for multiple ROIs on CUDA'
295
+
296
+ def test_roi_align_gradient_cpu (self ):
297
+ """
298
+ Compute gradients for ROIAlign with multiple bounding boxes on CPU
299
+ """
300
+ device = torch .device ('cpu' )
301
+ pool_h , pool_w = (5 , 5 )
302
+ roi_align = layers .ROIAlign ((pool_h , pool_w ), spatial_scale = 1 , sampling_ratio = 2 ).to (device = device )
303
+
304
+ x = self .x .to (device ).clone ()
305
+ rois = self .rois .to (device )
306
+ gt_grad = self .x_grad .to (device )
267
307
308
+ x .requires_grad = True
309
+ y = roi_align (x , rois )
310
+ s = y .sum ()
311
+ s .backward ()
312
+
313
+ assert torch .allclose (x .grad , gt_grad ), 'gradient incorrect for ROIAlign CPU'
314
+
315
+ @unittest .skipIf (not torch .cuda .is_available (), "CUDA unavailable" )
316
+ def test_roi_align_gradient_cuda (self ):
317
+ """
318
+ Compute gradients for ROIAlign with multiple bounding boxes on the GPU
319
+ """
320
+ device = torch .device ('cuda' )
268
321
pool_h , pool_w = (5 , 5 )
269
- roi_align = layers .ROIAlign ((pool_h , pool_w ), spatial_scale = 1 , sampling_ratio = 2.0 )
270
- y = roi_align (self .x , self .rois )
322
+ roi_align = layers .ROIAlign ((pool_h , pool_w ), spatial_scale = 1 , sampling_ratio = 2 ).to (device = device )
323
+
324
+ x = self .x .to (device ).clone ()
325
+ rois = self .rois .to (device )
326
+ gt_grad = self .x_grad .to (device )
271
327
272
- assert torch .allclose (self .gt_y_multiple , y ), 'ROIAlign layer incorrect for multiple ROIs on CUDA'
328
+ x .requires_grad = True
329
+ y = roi_align (x , rois )
330
+ s = y .sum ()
331
+ s .backward ()
273
332
333
+ assert torch .allclose (x .grad , gt_grad ), 'gradient incorrect for ROIAlign CUDA'
274
334
275
335
# Allow the test module to be executed directly as a script.
if __name__ == '__main__':
    unittest.main()
0 commit comments