diff --git a/model.py b/model.py index 302e111..16910b9 100644 --- a/model.py +++ b/model.py @@ -565,8 +565,9 @@ def detection_target_layer(proposals, gt_class_ids, gt_boxes, gt_masks, config): # Handle COCO crowds # A crowd box in COCO is a bounding box around several instances. Exclude # them from training. A crowd box is given a negative class ID. - if torch.nonzero(gt_class_ids < 0).size(): - crowd_ix = torch.nonzero(gt_class_ids < 0)[:, 0] + negative_gt_class_ids = torch.nonzero(gt_class_ids < 0) + if len(negative_gt_class_ids) > 0: + crowd_ix = negative_gt_class_ids[:, 0] non_crowd_ix = torch.nonzero(gt_class_ids > 0)[:, 0] crowd_boxes = gt_boxes[crowd_ix.data, :] crowd_masks = gt_masks[crowd_ix.data, :, :]