diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 3714fc13682..83b0de8ea53 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -137,7 +137,8 @@ def _apply_image_or_video_transform( elif transform_id == "Posterize": return F.posterize(image, bits=int(magnitude)) elif transform_id == "Solarize": - return F.solarize(image, threshold=magnitude) + bound = 1.0 if isinstance(image, torch.Tensor) and image.is_floating_point() else 255.0 + return F.solarize(image, threshold=bound * magnitude) elif transform_id == "AutoContrast": return F.autocontrast(image) elif transform_id == "Equalize": @@ -169,7 +170,7 @@ class AutoAugment(_AutoAugmentBase): lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(), False, ), - "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), + "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False), "AutoContrast": (lambda num_bins, height, width: None, False), "Equalize": (lambda num_bins, height, width: None, False), "Invert": (lambda num_bins, height, width: None, False), @@ -324,7 +325,7 @@ class RandAugment(_AutoAugmentBase): lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(), False, ), - "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), + "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False), "AutoContrast": (lambda num_bins, height, width: None, False), "Equalize": (lambda num_bins, height, width: None, False), } @@ -378,7 +379,7 @@ class TrivialAugmentWide(_AutoAugmentBase): lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(), False, ), - "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), + "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False), "AutoContrast": (lambda num_bins, height, width: None, False), "Equalize": (lambda num_bins, height, width: None, False), } @@ -423,7 +424,7 @@ class AugMix(_AutoAugmentBase): lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(), False, ), - "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), + "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False), "AutoContrast": (lambda num_bins, height, width: None, False), "Equalize": (lambda num_bins, height, width: None, False), } @@ -505,13 +506,7 @@ def forward(self, *inputs: Any) -> Any: aug = self._apply_image_or_video_transform( aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill ) - mix.add_( - # The multiplication below could become in-place provided `aug is not batch and aug.is_floating_point()` - # Currently we can't do this because `aug` has to be `unint8` to support ops like `equalize`. - # TODO: change this once all ops in `F` support floats. https://github.com/pytorch/vision/issues/6840 - combined_weights[:, i].reshape(batch_dims) - * aug - ) + mix.add_(combined_weights[:, i].reshape(batch_dims) * aug) mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype) if isinstance(orig_image_or_video, (features.Image, features.Video)):