Skip to content

Commit 4bdc54c

Browse files
committed
Fix RoIPool reference implementation in Python 2
Also fixes a bug in the clip_boxes_to_image -- this function needs a test!
1 parent dfe8ec1 commit 4bdc54c

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

test/test_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def slow_roi_pooling(self, x, rois, pool_h, pool_w, spatial_scale=1,
2626
start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
2727
start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
2828
roi_x = x[roi[0].long(), :, start_h:end_h, start_w:end_w]
29-
bin_h, bin_w = roi_x.size(-2) / pool_h, roi_x.size(-1) / pool_w
29+
bin_h, bin_w = roi_x.size(-2) / float(pool_h), roi_x.size(-1) / float(pool_w)
3030

3131
for j in range(0, pool_h):
3232
cj = slice(int(np.floor(j * bin_h)), int(np.ceil((j + 1) * bin_h)))

torchvision/ops/boxes.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,14 @@ def clip_boxes_to_image(boxes, size):
7272
Returns:
7373
clipped_boxes (Tensor[N, 4])
7474
"""
75-
boxes_x = boxes[:, 0::2]
76-
boxes_y = boxes[:, 1::2]
75+
dim = boxes.dim()
76+
boxes_x = boxes[..., 0::2]
77+
boxes_y = boxes[..., 1::2]
7778
height, width = size
7879
boxes_x = boxes_x.clamp(min=0, max=width)
79-
boxes_y = boxes_x.clamp(min=0, max=height)
80-
clipped_boxes = torch.stack((boxes_x, boxes_y), dim=2)
81-
return clipped_boxes.reshape(-1, 4)
80+
boxes_y = boxes_y.clamp(min=0, max=height)
81+
clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
82+
return clipped_boxes.reshape(boxes.shape)
8283

8384

8485
def box_area(boxes):

0 commit comments

Comments
 (0)