Fix flakiness on detection tests #2966
Changes from all commits
@@ -1,14 +1,15 @@
 from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state
 from collections import OrderedDict
 from itertools import product
+import functools
+import operator
 import torch
 import torch.nn as nn
 import numpy as np
 from torchvision import models
 import unittest
 import random

 from torchvision.models.detection._utils import overwrite_eps
+import warnings


 def set_rng_seed(seed):
@@ -88,14 +89,10 @@ def get_available_video_models():
 # trying autocast. However, they still try an autocasted forward pass, so they still ensure
 # autocast coverage suffices to prevent dtype errors in each model.
 autocast_flaky_numerics = (
-    "fasterrcnn_resnet50_fpn",
     "inception_v3",
-    "keypointrcnn_resnet50_fpn",
-    "maskrcnn_resnet50_fpn",
     "resnet101",
     "resnet152",
     "wide_resnet101_2",
-    "retinanet_resnet50_fpn",
 )
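The autocast coverage described in the comment above looks roughly like this (a minimal sketch, assuming a CUDA device; `resnet18` is chosen purely for illustration and is not part of this diff):

```python
import torch
from torchvision import models

# Even when a model's numerics are too flaky to compare under autocast,
# the test still runs a mixed-precision forward pass so dtype errors surface.
model = models.resnet18(pretrained=False).cuda().eval()
x = torch.rand(1, 3, 224, 224, device="cuda")
with torch.cuda.amp.autocast():
    out = model(x)  # a dtype mismatch would raise here
```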
@@ -148,10 +145,9 @@ def _test_detection_model(self, name, dev):
         set_rng_seed(0)
         kwargs = {}
         if "retinanet" in name:
-            kwargs["score_thresh"] = 0.013
+            # Reduce the default threshold to ensure the returned boxes are not empty.
+            kwargs["score_thresh"] = 0.01
Review comment: Use a threshold that will produce many more boxes and make things more interesting.
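A toy illustration of the effect (not part of the diff): detections survive when their score exceeds `score_thresh`, so lowering it from 0.013 to 0.01 keeps strictly more boxes.

```python
import torch

scores = torch.tensor([0.005, 0.011, 0.012, 0.02, 0.4])
print((scores > 0.013).sum().item())  # 2 boxes survive the old threshold
print((scores > 0.01).sum().item())   # 4 boxes survive the new one
```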
         model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
         if "keypointrcnn" in name or "retinanet" in name:
             overwrite_eps(model, 0.0)
         model.eval().to(device=dev)
         input_shape = (3, 300, 300)
         # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
@@ -163,15 +159,22 @@ def _test_detection_model(self, name, dev):
         def check_out(out):
             self.assertEqual(len(out), 1)

+            def compact(tensor):
+                size = tensor.size()
+                elements_per_sample = functools.reduce(operator.mul, size[1:], 1)
+                if elements_per_sample > 30:
+                    return compute_mean_std(tensor)
+                else:
+                    return subsample_tensor(tensor)
Review comment: Some tensors are too large to store in the expected file (for example masks). If a tensor has more than 30 elements per record, we assert on its mean and standard deviation instead.
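To make the dispatch concrete (toy shapes, not taken from the test): a masks tensor with 28×28 elements per record goes through `compute_mean_std`, while a boxes tensor with 4 elements per record is subsampled.

```python
import functools
import operator
import torch

masks = torch.rand(100, 1, 28, 28)  # 784 elements per record -> mean/std
boxes = torch.rand(100, 4)          # 4 elements per record   -> subsample

for t in (masks, boxes):
    per_record = functools.reduce(operator.mul, t.size()[1:], 1)
    print(per_record, "mean/std" if per_record > 30 else "subsample")
```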
             def subsample_tensor(tensor):
-                num_elems = tensor.numel()
+                num_elems = tensor.size(0)
                 num_samples = 20
                 if num_elems <= num_samples:
                     return tensor

-                flat_tensor = tensor.flatten()
                 ith_index = num_elems // num_samples
-                return flat_tensor[ith_index - 1::ith_index]
+                return tensor[ith_index - 1::ith_index]
Review comment: We no longer flatten the tensor before subsampling. This preserves the structure of the boxes in the expected results.
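A quick before/after comparison (toy data, not from the test): subsampling along dim 0 keeps each box's four coordinates together, whereas the old flattening interleaved coordinates from different boxes.

```python
import torch

boxes = torch.arange(400.0).reshape(100, 4)  # 100 boxes, 4 coords each
num_samples = 20

# old behaviour: flatten, then take every i-th element
i_old = boxes.numel() // num_samples        # 20
old = boxes.flatten()[i_old - 1::i_old]     # shape [20], coords mixed

# new behaviour: take every i-th row along dim 0
i_new = boxes.size(0) // num_samples        # 5
new = boxes[i_new - 1::i_new]               # shape [20, 4], boxes intact

print(old.shape, new.shape)  # torch.Size([20]) torch.Size([20, 4])
```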
             def compute_mean_std(tensor):
                 # can't compute mean of integral tensor
@@ -180,18 +183,32 @@ def compute_mean_std(tensor):
                 std = torch.std(tensor)
                 return {"mean": mean, "std": std}

if name == "maskrcnn_resnet50_fpn": | ||
# maskrcnn_resnet_50_fpn numerically unstable across platforms, so for now | ||
# compare results with mean and std | ||
test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std) | ||
# mean values are small, use large prec | ||
self.assertExpected(test_value, prec=.01, strip_suffix="_" + dev) | ||
else: | ||
self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor), | ||
prec=0.01, | ||
strip_suffix="_" + dev) | ||
|
||
check_out(out) | ||
+            output = map_nested_tensor_object(out, tensor_map_fn=compact)
+            prec = 0.01
+            strip_suffix = "_" + dev
+            try:
+                # We first try to assert the entire output if possible. This is not
+                # only the best way to assert results but also handles the cases
+                # where we need to create a new expected result.
+                self.assertExpected(output, prec=prec, strip_suffix=strip_suffix)
+            except AssertionError:
+                # Unfortunately detection models are flaky due to the unstable sort
+                # in NMS. If matching across all outputs fails, use the same approach
+                # as in NMSTester.test_nms_cuda to see if this is caused by duplicate
+                # scores.
+                expected_file = self._get_expected_file(strip_suffix=strip_suffix)
Review comment: Note for the future: we might move away from our custom base Tester in favor of the PyTorch one.
+                expected = torch.load(expected_file)
+                self.assertEqual(output[0]["scores"], expected[0]["scores"], prec=prec)
+
+                # Note: Fmassa proposed turning off NMS by adapting the threshold
+                # and then using the Hungarian algorithm as in DETR to find the
+                # best match between output and expected boxes and eliminate some
+                # of the flakiness. Worth exploring.
+                return False  # Partial validation performed
Review comment: If we reach this point, we only partially validated the results. Return False so that we can flag this accordingly.
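The Hungarian-matching idea from the note above could look roughly like this (an exploratory sketch, not code from this PR; it assumes `scipy` is available and uses `torchvision.ops.box_iou`):

```python
from scipy.optimize import linear_sum_assignment
from torchvision.ops import box_iou

def match_boxes(output_boxes, expected_boxes):
    # Cost is negative IoU, so the optimal assignment maximizes total overlap.
    cost = -box_iou(output_boxes, expected_boxes).cpu().numpy()
    return linear_sum_assignment(cost)
```

Matched pairs could then be compared directly, making the check insensitive to the (NMS-dependent) order in which each device returns its boxes.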
+
+            return True  # Full validation performed

-        check_out(out)
+        full_validation = check_out(out)

         scripted_model = torch.jit.script(model)
         scripted_model.eval()
@@ -200,9 +217,6 @@ def compute_mean_std(tensor):
         self.assertEqual(scripted_out[0]["scores"], out[0]["scores"])
         # labels currently float in script: need to investigate (though same result)
         self.assertEqual(scripted_out[0]["labels"].to(dtype=torch.long), out[0]["labels"])
-        self.assertTrue("boxes" in out[0])
-        self.assertTrue("scores" in out[0])
-        self.assertTrue("labels" in out[0])
         # don't check script because we are compiling it here:
         # TODO: refactor tests
         # self.check_script(model, name)
@@ -213,7 +227,15 @@ def compute_mean_std(tensor):
             out = model(model_input)
             # See autocast_flaky_numerics comment at top of file.
             if name not in autocast_flaky_numerics:
-                check_out(out)
+                full_validation &= check_out(out)
+
+        if not full_validation:
Review comment: If any of the validations was partial, flag the test as skipped and raise a warning. It's better to flag it as skipped than to mark it green.
msg = "The output of {} could only be partially validated. " \ | ||
"This is likely due to unit-test flakiness, but you may " \ | ||
"want to do additional manual checks if you made " \ | ||
"significant changes to the codebase.".format(self._testMethodName) | ||
warnings.warn(msg, RuntimeWarning) | ||
raise unittest.SkipTest(msg) | ||
|
||
def _test_detection_model_validation(self, name): | ||
set_rng_seed(0) | ||
|
Review comment: Do a manual check instead of an assertion. We don't want to throw an AssertionError exception that would be caught by the external try.
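In other words, any check placed inside the outer `try` must raise something other than `AssertionError`, or its failure would be silently routed into the partial-validation fallback. A minimal sketch of the pattern (all names hypothetical):

```python
try:
    out = compute_output()  # hypothetical helper
    # Manual check: raise a non-AssertionError so the except clause
    # below cannot mistake a structural problem for flaky numerics.
    if len(out) != 1:
        raise RuntimeError("unexpected number of outputs")
    assert_expected(out)    # may legitimately raise AssertionError
except AssertionError:
    run_partial_validation(out)  # hypothetical fallback
```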