Skip to content

Unify input checks on detection models #2295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,36 +197,52 @@ def compute_mean_std(tensor):
# TODO: refactor tests
# self.check_script(model, name)
self.checkModule(model, name, ([x],))

if dev == "cuda":
with torch.cuda.amp.autocast():
out = model(model_input)
# See autocast_flaky_numerics comment at top of file.
if name not in autocast_flaky_numerics:
check_out(out)

def _test_detection_model_validation(self, name):
def _test_detection_model_checks(self, name):
set_rng_seed(0)
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
input_shape = (3, 300, 300)

input_shape = (1, 3, 300, 300)
x = [torch.rand(input_shape)]

# validate that targets are present in training
self.assertRaises(ValueError, model, x)
N = 4 # nb of boxes
targets = [{"boxes": None, "labels":None}]

def test_tensor_checks(tname, dtype, shape):
# presence check
if tname in targets[0]:
del targets[0][tname]
self.assertRaises(ValueError, model, x, targets=targets)

# validate type
targets = [{'boxes': 0.}]
self.assertRaises(ValueError, model, x, targets=targets)
# type check
targets[0][tname] = torch.zeros((1,),
dtype=torch.bool if dtype != torch.bool else torch.float)
self.assertRaises(ValueError, model, x, targets=targets)

# validate boxes shape
for boxes in (torch.rand((4,)), torch.rand((1, 5))):
targets = [{'boxes': boxes}]
# shape check
targets[0][tname] = torch.zeros((*shape, 1), dtype=dtype)
self.assertRaises(ValueError, model, x, targets=targets)

# validate that no degenerate boxes are present
boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]])
targets = [{'boxes': boxes}]
self.assertRaises(ValueError, model, x, targets=targets)
# set the Tensor to the correct shape and dtype for next tests
targets[0][tname] = torch.zeros(shape, dtype=dtype)

# check that targets are available when training
self.assertRaises(ValueError, model, x)

test_tensor_checks("boxes", torch.float, (N, 4))
test_tensor_checks("labels", torch.int64, (N,))

if "mask" in name:
test_tensor_checks("masks", torch.uint8, (N, 300, 300))
if "keypoint" in name:
test_tensor_checks("keypoints", torch.float, (N, 5, 3))

def _test_video_model(self, name, dev):
# the default input shape is
Expand Down Expand Up @@ -412,9 +428,9 @@ def do_test(self, model_name=model_name, dev=dev):
setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)

def do_validation_test(self, model_name=model_name):
self._test_detection_model_validation(model_name)
self._test_detection_model_checks(model_name)

setattr(ModelTester, "test_" + model_name + "_validation", do_validation_test)
setattr(ModelTester, "test_" + model_name + "_checks", do_validation_test)


for model_name in get_available_video_models():
Expand Down
2 changes: 1 addition & 1 deletion test/test_models_detection_negative_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _make_empty_sample(self, add_masks=False, add_keypoints=False):
negative_target["masks"] = torch.zeros(0, 100, 100, dtype=torch.uint8)

if add_keypoints:
negative_target["keypoints"] = torch.zeros(17, 0, 3, dtype=torch.float32)
negative_target["keypoints"] = torch.zeros(0, 17, 3, dtype=torch.float32)

targets = [negative_target]
return images, targets
Expand Down
12 changes: 0 additions & 12 deletions torchvision/models/detection/generalized_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,6 @@ def forward(self, images, targets=None):
"""
if self.training and targets is None:
raise ValueError("In training mode, targets should be passed")
if self.training:
assert targets is not None
for target in targets:
boxes = target["boxes"]
if isinstance(boxes, torch.Tensor):
if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
raise ValueError("Expected target boxes to be a tensor"
"of shape [N, 4], got {:}.".format(
boxes.shape))
else:
raise ValueError("Expected target boxes to be of type "
"Tensor, got {:}.".format(type(boxes)))

original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
for img in images:
Expand Down
93 changes: 78 additions & 15 deletions torchvision/models/detection/roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,57 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1):
return ret


def _check_target_item(target, # type: Dict[str, Tensor]
key, # type: str
dtype, # type: List[torch.dtype]
shape, # type: List[Optional[int]]
shape_string="", # type: Optional[str]
):
# type: (...) -> None
"""
Checks that a key in target corresponds to a Tensor with correct shape and
dtype.

Args:
target (Dict[str, Tensor]): target for a training instance containing
the Tensors to be checked
key (str): key for the Tensor to be checked in target.
dtype (List[torch.dtype]): a list containing the possible dtypes of
the Tensor.
shape (List[int]): the expected shape of the Tensor. This function checks
the number of dimensions and any non-None values of this list.
shape_string (Optional[str]): optional string for exception messages.
If messages. If not specified, shape will be cast into string and
used instead.

Raises:
ValueError if the Tensor fails a check.
"""
if len(shape_string) == 0:
shape_string = str(shape)

if key not in target:
raise ValueError("Key '{:}' not found in targets.".format(key))

arr = target[key]
if not isinstance(arr, torch.Tensor):
raise ValueError("Expected target {:} to be of type Tensor, got {:}."
.format(key, type(arr)))
if arr.dtype not in dtype:
raise ValueError("Expected target {:} to be a Tensor with dtype {:}, "
"got {:}.".format(key, dtype, arr.dtype))
if len(arr.shape) != len(shape):
raise ValueError("Expected target {:} to be Tensor with shape {:}, got"
" {:}.".format(key, shape_string, arr.shape))
for i in range(len(shape)):
item = shape[i]
if item is not None:
if arr.shape[i] != shape[i]:
raise ValueError("Expected target {:} to be Tensor with shape "
"{:}, got {:}.".format(key,
shape_string, arr.shape))


class RoIHeads(torch.nn.Module):
__annotations__ = {
'box_coder': det_utils.BoxCoder,
Expand All @@ -505,6 +556,7 @@ def __init__(self,
mask_roi_pool=None,
mask_head=None,
mask_predictor=None,
# Keypoints
keypoint_roi_pool=None,
keypoint_head=None,
keypoint_predictor=None,
Expand Down Expand Up @@ -619,18 +671,37 @@ def add_gt_proposals(self, proposals, gt_boxes):

def check_targets(self, targets):
# type: (Optional[List[Dict[str, Tensor]]]) -> None
assert targets is not None
assert all(["boxes" in t for t in targets])
assert all(["labels" in t for t in targets])
if self.has_mask():
assert all(["masks" in t for t in targets])
"""
Check that the training targets contain the necessary Tensors with
correct shape and dtype.

Args:
targets (List[Dict[str, Tensor]]): ground-truth boxes present in
the images.

Raises:
ValueError if targets fails a check.
"""
# TODO: https://github.com/pytorch/pytorch/issues/26731
floating_point_types = [torch.float, torch.double, torch.half]

if targets is None:
raise ValueError("In training mode, targets should be passed")

for t in targets:
_check_target_item(t, "boxes", floating_point_types, [None, 4], shape_string="[N, 4]")
N = t["boxes"].shape[0] # must match for labels, masks and keypoints
_check_target_item(t, "labels", [torch.int64], [N], shape_string="[N,]")
if self.has_mask():
_check_target_item(t, "masks", [torch.uint8], [N, None, None], shape_string="[N, H, W]")
if self.has_keypoint():
_check_target_item(t, "keypoints", floating_point_types, [N, None, 3], shape_string="[N, K, 3]")

def select_training_samples(self,
proposals, # type: List[Tensor]
targets # type: Optional[List[Dict[str, Tensor]]]
):
# type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
self.check_targets(targets)
assert targets is not None
dtype = proposals[0].dtype
device = proposals[0].device
Expand Down Expand Up @@ -733,16 +804,8 @@ def forward(self,
image_shapes (List[Tuple[H, W]])
targets (List[Dict])
"""
if targets is not None:
for t in targets:
# TODO: https://github.com/pytorch/pytorch/issues/26731
floating_point_types = (torch.float, torch.double, torch.half)
assert t["boxes"].dtype in floating_point_types, 'target boxes must of float type'
assert t["labels"].dtype == torch.int64, 'target labels must of int64 type'
if self.has_keypoint():
assert t["keypoints"].dtype == torch.float32, 'target keypoints must of float type'

if self.training:
self.check_targets(targets)
proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
else:
labels = None
Expand Down