From 82662d26c8be8804cd6da99ecfcb483c1dd4dac6 Mon Sep 17 00:00:00 2001
From: "deepsource-autofix[bot]" <62050782+deepsource-autofix[bot]@users.noreply.github.com>
Date: Wed, 13 Oct 2021 00:28:28 +0200
Subject: [PATCH] Refactor unnecessary `else` / `elif` when `if` block has a `return` statement

Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec
---
 .circleci/regenerate.py | 6 +-
 test/common_utils.py | 8 +-
 test/test_datasets.py | 7 +-
 test/test_datasets_download.py | 3 +-
 test/test_image.py | 3 +-
 test/test_models.py | 3 +-
 test/test_onnx.py | 3 +-
 torchvision/datasets/cityscapes.py | 7 +-
 torchvision/datasets/folder.py | 3 +-
 torchvision/datasets/imagenet.py | 11 ++-
 torchvision/datasets/inaturalist.py | 14 ++-
 torchvision/datasets/stl10.py | 2 +-
 torchvision/io/video.py | 3 +-
 .../models/detection/backbone_utils.py | 15 ++--
 .../models/detection/generalized_rcnn.py | 3 +-
 torchvision/models/googlenet.py | 6 +-
 torchvision/models/inception.py | 6 +-
 torchvision/models/mnasnet.py | 3 +-
 torchvision/models/mobilenetv2.py | 3 +-
 torchvision/models/quantization/googlenet.py | 3 +-
 torchvision/models/quantization/inception.py | 3 +-
 .../models/quantization/mobilenetv2.py | 3 +-
 .../models/quantization/mobilenetv3.py | 3 +-
 torchvision/ops/boxes.py | 6 +-
 .../prototype/datasets/_builtin/coco.py | 5 +-
 .../prototype/datasets/_builtin/mnist.py | 5 +-
 .../prototype/datasets/_builtin/sbd.py | 7 +-
 .../prototype/datasets/_builtin/voc.py | 7 +-
 .../prototype/datasets/utils/_internal.py | 5 +-
 torchvision/transforms/autoaugment.py | 7 +-
 torchvision/transforms/functional.py | 6 +-
 torchvision/transforms/functional_pil.py | 87 +++++++++----------
 torchvision/transforms/functional_tensor.py | 53 ++++++-----
 33 files changed, 134 insertions(+), 175 deletions(-)

diff --git a/.circleci/regenerate.py b/.circleci/regenerate.py
index 2e1d397394c..4912607ea2c 100755
--- a/.circleci/regenerate.py
+++ b/.circleci/regenerate.py
@@ -134,10 +134,10 @@ def upload_doc_job(filter_branch):
 def get_manylinux_image(cu_version):
     if cu_version == "cpu":
         return "pytorch/manylinux-cuda102"
-    elif cu_version.startswith("cu"):
+    if cu_version.startswith("cu"):
         cu_suffix = cu_version[len("cu") :]
         return f"pytorch/manylinux-cuda{cu_suffix}"
-    elif cu_version.startswith("rocm"):
+    if cu_version.startswith("rocm"):
         rocm_suffix = cu_version[len("rocm") :]
         return f"pytorch/manylinux-rocm:{rocm_suffix}"

@@ -145,7 +145,7 @@ def get_conda_image(cu_version):
     if cu_version == "cpu":
         return "pytorch/conda-builder:cpu"
-    elif cu_version.startswith("cu"):
+    if cu_version.startswith("cu"):
         cu_suffix = cu_version[len("cu") :]
         return f"pytorch/conda-builder:cuda{cu_suffix}"

diff --git a/test/common_utils.py b/test/common_utils.py
index f782613971c..a6eec717f9a 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -45,20 +45,18 @@ def __call__(self, object):
         if isinstance(object, torch.Tensor):
             return self.tensor_map_fn(object)

-        elif isinstance(object, dict):
+        if isinstance(object, dict):
             mapped_dict = {}
             for key, value in object.items():
                 mapped_dict[self(key)] = self(value)
             return mapped_dict

-        elif isinstance(object, (list, tuple)):
+        if isinstance(object, (list, tuple)):
             mapped_iter = []
             for iter in object:
                 mapped_iter.append(self(iter))
             return mapped_iter if not isinstance(object, tuple) else tuple(mapped_iter)
-
-        else:
-            return object
+        return object


 def map_nested_tensor_object(object, tensor_map_fn):
diff --git a/test/test_datasets.py b/test/test_datasets.py
index d2dc4ea6958..687e8c1a676 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -1482,16 +1482,15 @@ class QMNISTTestCase(MNISTTestCase):
     def _num_images(self, config):
         if config["what"] == "nist":
             return 3
-        elif config["what"] == "train":
+        if config["what"] == "train":
             return 2
-        elif config["what"] == "test50k":
+        if config["what"] == "test50k":
             # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
             # more than 10000 images for the dataset to not be empty. Since this takes significantly longer than the
             # creation of all other splits, this is excluded from the 'ADDITIONAL_CONFIGS' and is tested only once in
             # 'test_num_examples_test50k'.
             return 10001
-        else:
-            return 1
+        return 1

     def _labels_file(self, config):
         return f"{self._prefix(config)}-labels-idx2-int"
diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py
index 4bf31eba92b..274fb890ede 100644
--- a/test/test_datasets_download.py
+++ b/test/test_datasets_download.py
@@ -94,8 +94,7 @@ def add_mock(stack, name, file, **kwargs):
         except AttributeError as error:
             if file != "utils":
                 return add_mock(stack, name, "utils", **kwargs)
-            else:
-                raise pytest.UsageError from error
+            raise pytest.UsageError from error

     if urls_and_md5s is None:
         urls_and_md5s = set()
diff --git a/test/test_image.py b/test/test_image.py
index 9c6a73b8362..9b10573dca4 100644
--- a/test/test_image.py
+++ b/test/test_image.py
@@ -401,8 +401,7 @@ def _collect_if(cond):
     def _inner(test_func):
         if cond:
             return test_func
-        else:
-            return pytest.mark.dont_collect(test_func)
+        return pytest.mark.dont_collect(test_func)

     return _inner

diff --git a/test/test_models.py b/test/test_models.py
index 5e5b3429778..adc43a4e086 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -557,8 +557,7 @@ def compact(tensor):
         elements_per_sample = functools.reduce(operator.mul, size[1:], 1)
         if elements_per_sample > 30:
             return compute_mean_std(tensor)
-        else:
-            return subsample_tensor(tensor)
+        return subsample_tensor(tensor)

     def subsample_tensor(tensor):
         num_elems = tensor.size(0)
diff --git a/test/test_onnx.py b/test/test_onnx.py
index c81d490a882..6246f0e0244 100644
--- a/test/test_onnx.py
+++ b/test/test_onnx.py
@@ -70,8 +70,7 @@ def ort_validate(self, onnx_io, inputs, outputs, tolerate_small_mismatch=False):
         def to_numpy(tensor):
             if tensor.requires_grad:
                 return tensor.detach().cpu().numpy()
-            else:
-                return tensor.cpu().numpy()
+            return tensor.cpu().numpy()

         inputs = list(map(to_numpy, inputs))
         outputs = list(map(to_numpy, outputs))
diff --git a/torchvision/datasets/cityscapes.py b/torchvision/datasets/cityscapes.py
index cfc3e8bab71..19595eebafe 100644
--- a/torchvision/datasets/cityscapes.py
+++ b/torchvision/datasets/cityscapes.py
@@ -213,9 +213,8 @@ def _load_json(self, path: str) -> Dict[str, Any]:
     def _get_target_suffix(self, mode: str, target_type: str) -> str:
         if target_type == "instance":
             return "{}_instanceIds.png".format(mode)
-        elif target_type == "semantic":
+        if target_type == "semantic":
             return "{}_labelIds.png".format(mode)
-        elif target_type == "color":
+        if target_type == "color":
             return "{}_color.png".format(mode)
-        else:
-            return "{}_polygons.json".format(mode)
+        return "{}_polygons.json".format(mode)
diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py
index fedf4a35539..a1a77f2b897 100644
--- a/torchvision/datasets/folder.py
+++ b/torchvision/datasets/folder.py
@@ -264,8 +264,7 @@ def default_loader(path: str) -> Any:

     if get_image_backend() == "accimage":
         return accimage_loader(path)
-    else:
-        return pil_loader(path)
+    return pil_loader(path)


 class ImageFolder(DatasetFolder):
diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py
index 0fdb3395a5e..d7b5c5f7ab7 100644
--- a/torchvision/datasets/imagenet.py
+++ b/torchvision/datasets/imagenet.py
@@ -91,12 +91,11 @@ def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str

     if check_integrity(file):
         return torch.load(file)
-    else:
-        msg = (
-            "The meta file {} is not present in the root directory or is corrupted. "
-            "This file is automatically created by the ImageNet dataset."
-        )
-        raise RuntimeError(msg.format(file, root))
+    msg = (
+        "The meta file {} is not present in the root directory or is corrupted. "
+        "This file is automatically created by the ImageNet dataset."
+    )
+    raise RuntimeError(msg.format(file, root))


 def _verify_archive(root: str, file: str, md5: str) -> None:
diff --git a/torchvision/datasets/inaturalist.py b/torchvision/datasets/inaturalist.py
index 1e2d09d39f8..dbef2fe0d64 100644
--- a/torchvision/datasets/inaturalist.py
+++ b/torchvision/datasets/inaturalist.py
@@ -211,14 +211,12 @@ def category_name(self, category_type: str, category_id: int) -> str:
         """
         if category_type == "full":
             return self.all_categories[category_id]
-        else:
-            if category_type not in self.categories_index:
-                raise ValueError(f"Invalid category type '{category_type}'")
-            else:
-                for name, id in self.categories_index[category_type].items():
-                    if id == category_id:
-                        return name
-                raise ValueError(f"Invalid category id {category_id} for {category_type}")
+        if category_type not in self.categories_index:
+            raise ValueError(f"Invalid category type '{category_type}'")
+        for name, id in self.categories_index[category_type].items():
+            if id == category_id:
+                return name
+        raise ValueError(f"Invalid category id {category_id} for {category_type}")

     def _check_integrity(self) -> bool:
         return os.path.exists(self.root) and len(os.listdir(self.root)) > 0
diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py
index 20ebbc3b0ee..2dee71c06a2 100644
--- a/torchvision/datasets/stl10.py
+++ b/torchvision/datasets/stl10.py
@@ -89,7 +89,7 @@ def __init__(
     def _verify_folds(self, folds: Optional[int]) -> Optional[int]:
         if folds is None:
             return folds
-        elif isinstance(folds, int):
+        if isinstance(folds, int):
             if folds in range(10):
                 return folds
             msg = "Value for argument folds should be in the range [0, 10), " "but got {}."
diff --git a/torchvision/io/video.py b/torchvision/io/video.py
index e5648459113..d1e8fe8e09b 100644
--- a/torchvision/io/video.py
+++ b/torchvision/io/video.py
@@ -355,8 +355,7 @@ def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
     if _can_read_timestamps_from_packets(container):
         # fast path
         return [x.pts for x in container.demux(video=0) if x.pts is not None]
-    else:
-        return [x.pts for x in container.decode(video=0) if x.pts is not None]
+    return [x.pts for x in container.decode(video=0) if x.pts is not None]


 def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py
index 70a7b40bd50..1092f7c1d5a 100644
--- a/torchvision/models/detection/backbone_utils.py
+++ b/torchvision/models/detection/backbone_utils.py
@@ -170,11 +170,10 @@ def mobilenet_backbone(
         in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]

         return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
-    else:
-        m = nn.Sequential(
-            backbone,
-            # depthwise linear combination of channels to reduce their size
-            nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
-        )
-        m.out_channels = out_channels
-        return m
+    m = nn.Sequential(
+        backbone,
+        # depthwise linear combination of channels to reduce their size
+        nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
+    )
+    m.out_channels = out_channels
+    return m
diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py
index c77c892e63e..ba585e93db1 100644
--- a/torchvision/models/detection/generalized_rcnn.py
+++ b/torchvision/models/detection/generalized_rcnn.py
@@ -107,5 +107,4 @@ def forward(self, images, targets=None):
                 warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                 self._has_warned = True
             return losses, detections
-        else:
-            return self.eager_outputs(losses, detections)
+        return self.eager_outputs(losses, detections)
diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py
index 132805389a7..4af9d9f1aa0 100644
--- a/torchvision/models/googlenet.py
+++ b/torchvision/models/googlenet.py
@@ -200,8 +200,7 @@ def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor
     def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> GoogLeNetOutputs:
         if self.training and self.aux_logits:
             return _GoogLeNetOutputs(x, aux2, aux1)
-        else:
-            return x  # type: ignore[return-value]
+        return x  # type: ignore[return-value]

     def forward(self, x: Tensor) -> GoogLeNetOutputs:
         x = self._transform_input(x)
@@ -211,8 +210,7 @@ def forward(self, x: Tensor) -> GoogLeNetOutputs:
         if not aux_defined:
             warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
             return GoogLeNetOutputs(x, aux2, aux1)
-        else:
-            return self.eager_outputs(x, aux2, aux1)
+        return self.eager_outputs(x, aux2, aux1)


 class Inception(nn.Module):
diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py
index 2f18b8bc569..12aae8b489f 100644
--- a/torchvision/models/inception.py
+++ b/torchvision/models/inception.py
@@ -192,8 +192,7 @@ def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
     def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
         if self.training and self.aux_logits:
             return InceptionOutputs(x, aux)
-        else:
-            return x  # type: ignore[return-value]
+        return x  # type: ignore[return-value]

     def forward(self, x: Tensor) -> InceptionOutputs:
         x = self._transform_input(x)
@@ -203,8 +202,7 @@ def forward(self, x: Tensor) -> InceptionOutputs:
         if not aux_defined:
             warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
             return InceptionOutputs(x, aux)
-        else:
-            return self.eager_outputs(x, aux)
+        return self.eager_outputs(x, aux)


 class InceptionA(nn.Module):
diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py
index 3f48f82c41e..0fcee9386cf 100644
--- a/torchvision/models/mnasnet.py
+++ b/torchvision/models/mnasnet.py
@@ -47,8 +47,7 @@ def __init__(
     def forward(self, input: Tensor) -> Tensor:
         if self.apply_residual:
             return self.layers(input) + input
-        else:
-            return self.layers(input)
+        return self.layers(input)


 def _stack(
diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py
index 4b12478ba18..b593038fbbd 100644
--- a/torchvision/models/mobilenetv2.py
+++ b/torchvision/models/mobilenetv2.py
@@ -80,8 +80,7 @@ def __init__(
     def forward(self, x: Tensor) -> Tensor:
         if self.use_res_connect:
             return x + self.conv(x)
-        else:
-            return self.conv(x)
+        return self.conv(x)


 class MobileNetV2(nn.Module):
diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py
index 4b6d25e013c..acc420006ca 100644
--- a/torchvision/models/quantization/googlenet.py
+++ b/torchvision/models/quantization/googlenet.py
@@ -155,8 +155,7 @@ def forward(self, x: Tensor) -> GoogLeNetOutputs:
         if not aux_defined:
             warnings.warn("Scripted QuantizableGoogleNet always returns GoogleNetOutputs Tuple")
             return GoogLeNetOutputs(x, aux2, aux1)
-        else:
-            return self.eager_outputs(x, aux2, aux1)
+        return self.eager_outputs(x, aux2, aux1)

     def fuse_model(self) -> None:
         r"""Fuse conv/bn/relu modules in googlenet model
diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py
index acad3f6df53..cfe995b93e1 100644
--- a/torchvision/models/quantization/inception.py
+++ b/torchvision/models/quantization/inception.py
@@ -235,8 +235,7 @@ def forward(self, x: Tensor) -> InceptionOutputs:
         if not aux_defined:
             warnings.warn("Scripted QuantizableInception3 always returns QuantizableInception3 Tuple")
             return InceptionOutputs(x, aux)
-        else:
-            return self.eager_outputs(x, aux)
+        return self.eager_outputs(x, aux)

     def fuse_model(self) -> None:
         r"""Fuse conv/bn/relu modules in inception model
diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py
index a2c88cdd388..34020f58e79 100644
--- a/torchvision/models/quantization/mobilenetv2.py
+++ b/torchvision/models/quantization/mobilenetv2.py
@@ -25,8 +25,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
     def forward(self, x: Tensor) -> Tensor:
         if self.use_res_connect:
             return self.skip_add.add(x, self.conv(x))
-        else:
-            return self.conv(x)
+        return self.conv(x)

     def fuse_model(self) -> None:
         for idx in range(len(self.conv)):
diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py
index ad195d178c7..378759f3a00 100644
--- a/torchvision/models/quantization/mobilenetv3.py
+++ b/torchvision/models/quantization/mobilenetv3.py
@@ -75,8 +75,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
     def forward(self, x: Tensor) -> Tensor:
         if self.use_res_connect:
             return self.skip_add.add(x, self.block(x))
-        else:
-            return self.block(x)
+        return self.block(x)


 class QuantizableMobileNetV3(MobileNetV3):
diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py
index f8d9c596606..f0c0acddca3 100644
--- a/torchvision/ops/boxes.py
+++ b/torchvision/ops/boxes.py
@@ -66,8 +66,7 @@ def batched_nms(
     # Ideally for GPU we'd use a higher threshold
     if boxes.numel() > 4_000 and not torchvision._is_tracing():
         return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
-    else:
-        return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
+    return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)


 @torch.jit._script_if_tracing
@@ -210,8 +209,7 @@ def _upcast(t: Tensor) -> Tensor:
     # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
     if t.is_floating_point():
         return t if t.dtype in (torch.float32, torch.float64) else t.float()
-    else:
-        return t if t.dtype in (torch.int32, torch.int64) else t.int()
+    return t if t.dtype in (torch.int32, torch.int64) else t.int()


 def box_area(boxes: Tensor) -> Tensor:
diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py
index 94a361aff2d..c18aedd603f 100644
--- a/torchvision/prototype/datasets/_builtin/coco.py
+++ b/torchvision/prototype/datasets/_builtin/coco.py
@@ -71,10 +71,9 @@ def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
         key, _ = data
         if key == "images":
             return 0
-        elif key == "annotations":
+        if key == "annotations":
             return 1
-        else:
-            return None
+        return None

     def _decode_ann(self, ann: Dict[str, Any]) -> Dict[str, Any]:
         area = torch.tensor(ann["area"])
diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index b20c3ed6266..5dff85314ab 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -272,10 +272,9 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) ->
         (images_file, _), (labels_file, _) = self._files_and_checksums(config)
         if path.name == images_file:
             return 0
-        elif path.name == labels_file:
+        if path.name == labels_file:
             return 1
-        else:
-            return None
+        return None

     _LABEL_OFFSETS = {
         38: 1,
diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py
index 9262351a118..c21988ed414 100644
--- a/torchvision/prototype/datasets/_builtin/sbd.py
+++ b/torchvision/prototype/datasets/_builtin/sbd.py
@@ -60,13 +60,12 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:

         if parent.name == "dataset":
             return 0
-        elif grandparent.name == "dataset":
+        if grandparent.name == "dataset":
             if parent.name == "img":
                 return 1
-            elif parent.name == "cls":
+            if parent.name == "cls":
                 return 2
-            else:
-                return None
+            return None
         else:
             return None

diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py
index 9a1f4d05728..914ff906e68 100644
--- a/torchvision/prototype/datasets/_builtin/voc.py
+++ b/torchvision/prototype/datasets/_builtin/voc.py
@@ -76,12 +76,11 @@ def _is_in_folder(self, data: Tuple[str, Any], *, name: str, depth: int = 1) ->
     def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> Optional[int]:
         if self._is_in_folder(data, name="ImageSets", depth=2):
             return 0
-        elif self._is_in_folder(data, name="JPEGImages"):
+        if self._is_in_folder(data, name="JPEGImages"):
             return 1
-        elif self._is_in_folder(data, name=self._ANNS_FOLDER[config.task]):
+        if self._is_in_folder(data, name=self._ANNS_FOLDER[config.task]):
             return 2
-        else:
-            return None
+        return None

     def _decode_detection_ann(self, buffer: io.IOBase) -> torch.Tensor:
         result = VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())  # type: ignore[arg-type]
diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py
index 72c55233e7d..d6d3f34e0f4 100644
--- a/torchvision/prototype/datasets/utils/_internal.py
+++ b/torchvision/prototype/datasets/utils/_internal.py
@@ -173,10 +173,9 @@ def _detect_compression_type(self, path: str) -> CompressionType:
         ext = os.path.splitext(path)[1]
         if ext == ".gz":
             return self.types.GZIP
-        elif ext == ".xz":
+        if ext == ".xz":
             return self.types.LZMA
-        else:
-            raise RuntimeError("FIXME")
+        raise RuntimeError("FIXME")

     def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
         for path, file in self.datapipe:
diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py
index f99e0aa2950..9a5264b8295 100644
--- a/torchvision/transforms/autoaugment.py
+++ b/torchvision/transforms/autoaugment.py
@@ -151,7 +151,7 @@ def _get_policies(
                 (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                 (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
             ]
-        elif policy == AutoAugmentPolicy.CIFAR10:
+        if policy == AutoAugmentPolicy.CIFAR10:
             return [
                 (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                 (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
@@ -179,7+179,7 @@ def _get_policies(
                 (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                 (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
             ]
-        elif policy == AutoAugmentPolicy.SVHN:
+        if policy == AutoAugmentPolicy.SVHN:
             return [
                 (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                 (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
@@ -207,8 +207,7 @@ def _get_policies(
                 (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                 (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
             ]
-        else:
-            raise ValueError("The provided policy {} is not recognized.".format(policy))
+        raise ValueError("The provided policy {} is not recognized.".format(policy))

     def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
         return {
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 9578134cae0..285209a861b 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -127,8 +127,7 @@ def to_tensor(pic):
         # backward compatibility
         if isinstance(img, torch.ByteTensor):
             return img.to(dtype=default_float_dtype).div(255)
-        else:
-            return img
+        return img

     if accimage is not None and isinstance(pic, accimage.Image):
         nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
@@ -146,8 +145,7 @@ def to_tensor(pic):
     img = img.permute((2, 0, 1)).contiguous()
     if isinstance(img, torch.ByteTensor):
         return img.to(dtype=default_float_dtype).div(255)
-    else:
-        return img
+    return img


 def pil_to_tensor(pic):
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index eb2ab31a4a9..08265b93c7c 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -15,8 +15,7 @@ def _is_pil_image(img: Any) -> bool:
     if accimage is not None:
         return isinstance(img, (Image.Image, accimage.Image))
-    else:
-        return isinstance(img, Image.Image)
+    return isinstance(img, Image.Image)


 @torch.jit.unused
@@ -167,44 +166,43 @@ def pad(
             return image

         return ImageOps.expand(img, border=padding, **opts)
-    else:
-        if isinstance(padding, int):
-            pad_left = pad_right = pad_top = pad_bottom = padding
-        if isinstance(padding, tuple) and len(padding) == 2:
-            pad_left = pad_right = padding[0]
-            pad_top = pad_bottom = padding[1]
-        if isinstance(padding, tuple) and len(padding) == 4:
-            pad_left = padding[0]
-            pad_top = padding[1]
-            pad_right = padding[2]
-            pad_bottom = padding[3]
-
-        p = [pad_left, pad_top, pad_right, pad_bottom]
-        cropping = -np.minimum(p, 0)
-
-        if cropping.any():
-            crop_left, crop_top, crop_right, crop_bottom = cropping
-            img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom))
-
-        pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)
-
-        if img.mode == "P":
-            palette = img.getpalette()
-            img = np.asarray(img)
-            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
-            img = Image.fromarray(img)
-            img.putpalette(palette)
-            return img
-
+    if isinstance(padding, int):
+        pad_left = pad_right = pad_top = pad_bottom = padding
+    if isinstance(padding, tuple) and len(padding) == 2:
+        pad_left = pad_right = padding[0]
+        pad_top = pad_bottom = padding[1]
+    if isinstance(padding, tuple) and len(padding) == 4:
+        pad_left = padding[0]
+        pad_top = padding[1]
+        pad_right = padding[2]
+        pad_bottom = padding[3]
+
+    p = [pad_left, pad_top, pad_right, pad_bottom]
+    cropping = -np.minimum(p, 0)
+
+    if cropping.any():
+        crop_left, crop_top, crop_right, crop_bottom = cropping
+        img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom))
+
+    pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)
+
+    if img.mode == "P":
+        palette = img.getpalette()
         img = np.asarray(img)
-        # RGB image
-        if len(img.shape) == 3:
-            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
-        # Grayscale image
-        if len(img.shape) == 2:
-            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
+        img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
+        img = Image.fromarray(img)
+        img.putpalette(palette)
+        return img
+
+    img = np.asarray(img)
+    # RGB image
+    if len(img.shape) == 3:
+        img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
+    # Grayscale image
+    if len(img.shape) == 2:
+        img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)

-        return Image.fromarray(img)
+    return Image.fromarray(img)


 @torch.jit.unused
@@ -257,13 +255,12 @@ def resize(
             new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)

         return img.resize((new_w, new_h), interpolation)
-    else:
-        if max_size is not None:
-            raise ValueError(
-                "max_size should only be passed if size specifies the length of the smaller edge, "
-                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
-            )
-        return img.resize(size[::-1], interpolation)
+    if max_size is not None:
+        raise ValueError(
+            "max_size should only be passed if size specifies the length of the smaller edge, "
+            "i.e. size should be an int or a sequence of length 1 in torchscript mode."
+        )
+    return img.resize(size[::-1], interpolation)


 @torch.jit.unused
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index d0fd78346b6..85e105110b2 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -25,7 +25,7 @@ def get_image_size(img: Tensor) -> List[int]:
 def get_image_num_channels(img: Tensor) -> int:
     if img.ndim == 2:
         return 1
-    elif img.ndim > 2:
+    if img.ndim > 2:
         return img.shape[-3]

     raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim))
@@ -81,30 +81,28 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
         max_val = _max_value(dtype)
         result = image.mul(max_val + 1.0 - eps)
         return result.to(dtype)
-    else:
-        input_max = _max_value(image.dtype)
-
-        # int to float
-        # TODO: replace with dtype.is_floating_point when torchscript supports it
-        if torch.tensor(0, dtype=dtype).is_floating_point():
-            image = image.to(dtype)
-            return image / input_max
-
-        output_max = _max_value(dtype)
-
-        # int to int
-        if input_max > output_max:
-            # factor should be forced to int for torch jit script
-            # otherwise factor is a float and image // factor can produce different results
-            factor = int((input_max + 1) // (output_max + 1))
-            image = torch.div(image, factor, rounding_mode="floor")
-            return image.to(dtype)
-        else:
-            # factor should be forced to int for torch jit script
-            # otherwise factor is a float and image * factor can produce different results
-            factor = int((output_max + 1) // (input_max + 1))
-            image = image.to(dtype)
-            return image * factor
+    input_max = _max_value(image.dtype)
+
+    # int to float
+    # TODO: replace with dtype.is_floating_point when torchscript supports it
+    if torch.tensor(0, dtype=dtype).is_floating_point():
+        image = image.to(dtype)
+        return image / input_max
+
+    output_max = _max_value(dtype)
+
+    # int to int
+    if input_max > output_max:
+        # factor should be forced to int for torch jit script
+        # otherwise factor is a float and image // factor can produce different results
+        factor = int((input_max + 1) // (output_max + 1))
+        image = torch.div(image, factor, rounding_mode="floor")
+        return image.to(dtype)
+    # factor should be forced to int for torch jit script
+    # otherwise factor is a float and image * factor can produce different results
+    factor = int((output_max + 1) // (input_max + 1))
+    image = image.to(dtype)
+    return image * factor


 def vflip(img: Tensor) -> Tensor:
@@ -401,10 +399,9 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
     ndim = img.ndim
     if ndim == 3:
         return img[:, y_indices[:, None], x_indices[None, :]]
-    elif ndim == 4:
+    if ndim == 4:
         return img[:, :, y_indices[:, None], x_indices[None, :]]
-    else:
-        raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")
+    raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")


 def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor: