Skip to content

Commit 8170055

Browse files
authored
Test some flaky detection models on float64 instead of float32 (#7204)
1 parent d75a524 commit 8170055

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

test/test_models.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ def list_model_fns(module):
2929
return [get_model_builder(name) for name in list_models(module)]
3030

3131

32-
def _get_image(input_shape, real_image, device):
32+
def _get_image(input_shape, real_image, device, dtype=None):
3333
"""This routine loads a real or random image based on `real_image` argument.
3434
Currently, the real image is utilized for the following list of models:
3535
- `retinanet_resnet50_fpn`,
@@ -60,10 +60,10 @@ def _get_image(input_shape, real_image, device):
6060
convert_tensor = transforms.ToTensor()
6161
image = convert_tensor(img)
6262
assert tuple(image.size()) == input_shape
63-
return image.to(device=device)
63+
return image.to(device=device, dtype=dtype)
6464

6565
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
66-
return torch.rand(input_shape).to(device=device)
66+
return torch.rand(input_shape).to(device=device, dtype=dtype)
6767

6868

6969
@pytest.fixture
@@ -278,6 +278,11 @@ def _check_input_backprop(model, inputs):
278278
# tests under test_quantized_classification_model will be skipped for the following models.
279279
quantized_flaky_models = ("inception_v3", "resnet50")
280280

281+
# The tests for the following detection models are flaky.
282+
# We run those tests on float64 to avoid floating point errors.
283+
# FIXME: we shouldn't have to do that :'/
284+
detection_flaky_models = ("keypointrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn_v2")
285+
281286

282287
# The following contains configuration parameters for all models which are used by
283288
# the _test_*_model methods.
@@ -777,13 +782,17 @@ def test_detection_model(model_fn, dev):
777782
"input_shape": (3, 300, 300),
778783
}
779784
model_name = model_fn.__name__
785+
if model_name in detection_flaky_models:
786+
dtype = torch.float64
787+
else:
788+
dtype = torch.get_default_dtype()
780789
kwargs = {**defaults, **_model_params.get(model_name, {})}
781790
input_shape = kwargs.pop("input_shape")
782791
real_image = kwargs.pop("real_image", False)
783792

784793
model = model_fn(**kwargs)
785-
model.eval().to(device=dev)
786-
x = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
794+
model.eval().to(device=dev, dtype=dtype)
795+
x = _get_image(input_shape=input_shape, real_image=real_image, device=dev, dtype=dtype)
787796
model_input = [x]
788797
with torch.no_grad(), freeze_rng_state():
789798
out = model(model_input)

0 commit comments

Comments
 (0)