diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index d3100bc81d5..143f1279c54 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -251,6 +251,8 @@ def __init__( ArgsKwargs(p=0), ArgsKwargs(p=1), ], + # Use default tolerances of `torch.testing.assert_close` + closeness_kwargs=dict(rtol=None, atol=None), ), ConsistencyConfig( prototype_transforms.RandomAdjustSharpness, diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index e46dfd74935..671f3fcf90f 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -377,17 +377,23 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: return image bound = _FT._max_value(image.dtype) - dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + fp = image.is_floating_point() + float_image = image if fp else image.to(torch.float32) - minimum = image.amin(dim=(-2, -1), keepdim=True).to(dtype) - maximum = image.amax(dim=(-2, -1), keepdim=True).to(dtype) + minimum = float_image.amin(dim=(-2, -1), keepdim=True) + maximum = float_image.amax(dim=(-2, -1), keepdim=True) - scale = bound / (maximum - minimum) eq_idxs = maximum == minimum + inv_scale = maximum.sub_(minimum).div_(bound) minimum[eq_idxs] = 0.0 - scale[eq_idxs] = 1.0 + inv_scale[eq_idxs] = 1.0 + + if fp: + diff = float_image.sub(minimum) + else: + diff = float_image.sub_(minimum) - return (image - minimum).mul_(scale).clamp_(0, bound).to(image.dtype) + return diff.div_(inv_scale).clamp_(0, bound).to(image.dtype) autocontrast_image_pil = _FP.autocontrast