Skip to content

Commit 924d373

Browse files
pmeierNicolasHug
andauthored
fix flaky test for rotate_bounding_box (#7362)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent feda8b7 commit 924d373

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

test/common_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def assert_close(
351351

352352
def parametrized_error_message(*args, **kwargs):
353353
def to_str(obj):
354-
if isinstance(obj, torch.Tensor) and obj.numel() > 10:
354+
if isinstance(obj, torch.Tensor) and obj.numel() > 30:
355355
return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})"
356356
elif isinstance(obj, enum.Enum):
357357
return f"{type(obj).__name__}.{obj.name}"

test/test_transforms_v2_functional.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
146146
actual,
147147
expected,
148148
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
149-
msg=parametrized_error_message(*([actual, expected] + other_args), **kwargs),
149+
msg=parametrized_error_message(input, other_args, **kwargs),
150150
)
151151

152152
def _unbatch(self, batch, *, data_dims):
@@ -204,7 +204,7 @@ def test_batched_vs_single(self, test_id, info, args_kwargs, device):
204204
actual,
205205
expected,
206206
**info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
207-
msg=parametrized_error_message(*other_args, **kwargs),
207+
msg=parametrized_error_message(batched_input, *other_args, **kwargs),
208208
)
209209

210210
@sample_inputs
@@ -236,7 +236,7 @@ def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
236236
output_cpu,
237237
check_device=False,
238238
**info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
239-
msg=parametrized_error_message(*other_args, **kwargs),
239+
msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
240240
)
241241

242242
@sample_inputs
@@ -294,7 +294,7 @@ def test_float32_vs_uint8(self, test_id, info, args_kwargs):
294294
actual,
295295
expected,
296296
**info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
297-
msg=parametrized_error_message(*other_args, **kwargs),
297+
msg=parametrized_error_message(input, *other_args, **kwargs),
298298
)
299299

300300

test/transforms_v2_kernel_infos.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -860,8 +860,8 @@ def sample_inputs_rotate_video():
860860
reference_fn=reference_rotate_bounding_box,
861861
reference_inputs_fn=reference_inputs_rotate_bounding_box,
862862
closeness_kwargs={
863-
**scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
864-
**scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
863+
**scripted_vs_eager_float64_tolerances("cpu", atol=1e-4, rtol=1e-4),
864+
**scripted_vs_eager_float64_tolerances("cuda", atol=1e-4, rtol=1e-4),
865865
},
866866
),
867867
KernelInfo(

0 commit comments

Comments
 (0)