@@ -29,7 +29,7 @@ def list_model_fns(module):
     return [get_model_builder(name) for name in list_models(module)]
 
 
-def _get_image(input_shape, real_image, device):
+def _get_image(input_shape, real_image, device, dtype=None):
     """This routine loads a real or random image based on `real_image` argument.
     Currently, the real image is utilized for the following list of models:
     - `retinanet_resnet50_fpn`,
@@ -60,10 +60,10 @@ def _get_image(input_shape, real_image, device):
         convert_tensor = transforms.ToTensor()
         image = convert_tensor(img)
         assert tuple(image.size()) == input_shape
-        return image.to(device=device)
+        return image.to(device=device, dtype=dtype)
 
     # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
-    return torch.rand(input_shape).to(device=device)
+    return torch.rand(input_shape).to(device=device, dtype=dtype)
 
 
 @pytest.fixture
@@ -278,6 +278,11 @@ def _check_input_backprop(model, inputs):
 # tests under test_quantized_classification_model will be skipped for the following models.
 quantized_flaky_models = ("inception_v3", "resnet50")
 
+# The tests for the following detection models are flaky.
+# We run those tests on float64 to avoid floating point errors.
+# FIXME: we shouldn't have to do that :'/
+detection_flaky_models = ("keypointrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn_v2")
+
 
 # The following contains configuration parameters for all models which are used by
 # the _test_*_model methods.
@@ -777,13 +782,17 @@ def test_detection_model(model_fn, dev):
         "input_shape": (3, 300, 300),
     }
     model_name = model_fn.__name__
+    if model_name in detection_flaky_models:
+        dtype = torch.float64
+    else:
+        dtype = torch.get_default_dtype()
     kwargs = {**defaults, **_model_params.get(model_name, {})}
     input_shape = kwargs.pop("input_shape")
     real_image = kwargs.pop("real_image", False)
 
     model = model_fn(**kwargs)
-    model.eval().to(device=dev)
-    x = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
+    model.eval().to(device=dev, dtype=dtype)
+    x = _get_image(input_shape=input_shape, real_image=real_image, device=dev, dtype=dtype)
     model_input = [x]
     with torch.no_grad(), freeze_rng_state():
         out = model(model_input)
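
For reference, the float64 workaround in the patch amounts to casting both the model and its input to double precision before the forward pass. A minimal standalone sketch of that idea, not part of the patch: the model choice and input size are illustrative, and it assumes a torchvision version that accepts weights=None.

import torch
import torchvision

# Build one of the flaky detection models and run it entirely in float64.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=None)
model.eval().to(dtype=torch.float64)

# Random input image in the same dtype as the model's parameters.
x = torch.rand(3, 300, 300, dtype=torch.float64)

with torch.no_grad():
    out = model([x])  # detection models take a list of images

# Each prediction dict contains boxes, labels, scores and masks.
print({k: (v.shape, v.dtype) for k, v in out[0].items()})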