@@ -15,7 +15,8 @@ def setUpClass(cls):
15
15
16
16
def slow_roi_pooling(self, x, rois, pool_h, pool_w, spatial_scale=1,
                     device=torch.device('cpu'), dtype=torch.float64):
    """Reference (slow, pure-Python) RoI max pooling used to check ops.RoIPool.

    Args:
        x: input feature map of shape (batch, channels, height, width).
        rois: tensor of shape (num_rois, 5); each row is
            (batch_index, x1, y1, x2, y2) in input coordinates.
            NOTE(review): coordinate order inferred from the start_h/start_w
            indexing below (roi[2]/roi[4] are y, roi[1]/roi[3] are x) — confirm.
        pool_h, pool_w: output bin grid size.
        spatial_scale: multiplier applied to roi coordinates before rounding.
        device, dtype: placement/precision of the returned tensor.

    Returns:
        Tensor of shape (num_rois, channels, pool_h, pool_w) with the max of
        each bin; bins that cover zero pixels are left at 0.
    """
    c = x.size(1)
    y = torch.zeros(rois.size(0), c, pool_h, pool_w, dtype=dtype, device=device)

    # Scale roi coordinates into feature-map space and snap to integer pixels.
    rois = torch.round(rois * spatial_scale)

    for n in range(x.size(0)):
        for r, roi in enumerate(rois):
            # Only pool rois that belong to batch image n.
            if roi[0] == n:
                # +1 makes the end coordinates inclusive.
                start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
                start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
                # Index the batch dim with a scalar so roi_x is (c, h, w).
                roi_x = x[roi[0].long(), :, start_h:end_h, start_w:end_w]
                bin_h, bin_w = roi_x.size(-2) / pool_h, roi_x.size(-1) / pool_w

                for j in range(0, pool_h):
                    cj = slice(int(np.floor(j * bin_h)), int(np.ceil((j + 1) * bin_h)))
                    for i in range(0, pool_w):
                        ci = slice(int(np.floor(i * bin_w)), int(np.ceil((i + 1) * bin_w)))
                        # Flatten the bin per channel; skip empty bins so the
                        # output stays 0 there (and torch.max never sees an
                        # empty tensor, which would raise).
                        t = roi_x[:, cj, ci].reshape(c, -1)
                        if t.numel() > 0:
                            y[r, :, j, i] = torch.max(t, 1)[0]
    return y
36
39
37
40
def test_roi_pool_basic_cpu (self ):
@@ -75,6 +78,34 @@ def test_roi_pool_cpu(self):
75
78
gt_y = self .slow_roi_pooling (x .permute (0 , 1 , 3 , 2 ), rois , pool_h , pool_w , device = device , dtype = self .dtype )
76
79
assert torch .allclose (gt_y , y ), 'RoIPool layer incorrect on CPU for batch > 1'
77
80
81
def test_roi_pool_cpu_empty_rois(self):
    """RoIPool on CPU must handle degenerate (zero-area) rois and match the
    slow reference implementation, for both contiguous and non-contiguous
    input layouts."""
    device = torch.device('cpu')
    x = torch.tensor(
        [[[[0.1767, 1.2851, 4.2325, 4.8645, 7.1496]],
          [[2.5916, 4.3361, 3.8143, 6.1329, 2.0230]],
          [[1.4492, 3.3384, 4.0816, 6.3116, 5.1068]]]],
        dtype=self.dtype, device=device)
    # Rows 3 and 4 are zero-area ("empty") rois: x1 == x2 and y1 == y2.
    rois = torch.tensor(
        [[0., 1., 0., 4., 0.],
         [0., 2., 0., 3., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 2., 0., 2., 0.]],
        dtype=self.dtype, device=device)

    pool_h, pool_w = (1, 2)
    roi_pool = ops.RoIPool((pool_h, pool_w), 1)
    y = roi_pool(x, rois)

    gt_y = self.slow_roi_pooling(x, rois, pool_h, pool_w, device=device, dtype=self.dtype)

    assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU empty rois'

    # non-contiguous
    y = roi_pool(x.permute(0, 1, 3, 2), rois)
    gt_y = self.slow_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device=device, dtype=self.dtype)
    assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU for empty rois non-contiguous'
108
+
78
109
def test_roi_pool_gradient_cpu (self ):
79
110
device = torch .device ('cpu' )
80
111
x = torch .ones (1 , 1 , 10 , 10 , dtype = self .dtype , device = device , requires_grad = True )
0 commit comments