diff --git a/test/common_utils.py b/test/common_utils.py index 697b6f6e4ca..bd945b09e21 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -351,7 +351,7 @@ def assert_close( def parametrized_error_message(*args, **kwargs): def to_str(obj): - if isinstance(obj, torch.Tensor) and obj.numel() > 10: + if isinstance(obj, torch.Tensor) and obj.numel() > 30: return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})" elif isinstance(obj, enum.Enum): return f"{type(obj).__name__}.{obj.name}" diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index e648b35d441..ee9576b6487 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -146,7 +146,7 @@ def test_scripted_vs_eager(self, test_id, info, args_kwargs, device): actual, expected, **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device), - msg=parametrized_error_message(*([actual, expected] + other_args), **kwargs), + msg=parametrized_error_message(input, other_args, **kwargs), ) def _unbatch(self, batch, *, data_dims): @@ -204,7 +204,7 @@ def test_batched_vs_single(self, test_id, info, args_kwargs, device): actual, expected, **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device), - msg=parametrized_error_message(*other_args, **kwargs), + msg=parametrized_error_message(batched_input, *other_args, **kwargs), ) @sample_inputs @@ -236,7 +236,7 @@ def test_cuda_vs_cpu(self, test_id, info, args_kwargs): output_cpu, check_device=False, **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device), - msg=parametrized_error_message(*other_args, **kwargs), + msg=parametrized_error_message(input_cpu, *other_args, **kwargs), ) @sample_inputs @@ -294,7 +294,7 @@ def test_float32_vs_uint8(self, test_id, info, args_kwargs): actual, expected, **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device), - msg=parametrized_error_message(*other_args, **kwargs), + msg=parametrized_error_message(input, *other_args, **kwargs), ) diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py index 3c3611cb8cc..6fea2513712 100644 --- a/test/transforms_v2_kernel_infos.py +++ b/test/transforms_v2_kernel_infos.py @@ -860,8 +860,8 @@ def sample_inputs_rotate_video(): reference_fn=reference_rotate_bounding_box, reference_inputs_fn=reference_inputs_rotate_bounding_box, closeness_kwargs={ - **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6), - **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5), + **scripted_vs_eager_float64_tolerances("cpu", atol=1e-4, rtol=1e-4), + **scripted_vs_eager_float64_tolerances("cuda", atol=1e-4, rtol=1e-4), }, ), KernelInfo(