diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dccc1c1a9b2..463a97359ab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,8 +15,8 @@ repos: hooks: - id: ufmt additional_dependencies: - - black == 21.9b0 - - usort == 0.6.4 + - black == 22.3.0 + - usort == 1.0.2 - repo: https://gitlab.com/pycqa/flake8 rev: 3.9.2 diff --git a/hubconf.py b/hubconf.py index 9455805bcfd..1231b0bbea6 100644 --- a/hubconf.py +++ b/hubconf.py @@ -3,8 +3,8 @@ from torchvision.models import get_weight from torchvision.models.alexnet import alexnet -from torchvision.models.convnext import convnext_tiny, convnext_small, convnext_base, convnext_large -from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161 +from torchvision.models.convnext import convnext_base, convnext_large, convnext_small, convnext_tiny +from torchvision.models.densenet import densenet121, densenet161, densenet169, densenet201 from torchvision.models.efficientnet import ( efficientnet_b0, efficientnet_b1, @@ -14,9 +14,9 @@ efficientnet_b5, efficientnet_b6, efficientnet_b7, - efficientnet_v2_s, - efficientnet_v2_m, efficientnet_v2_l, + efficientnet_v2_m, + efficientnet_v2_s, ) from torchvision.models.googlenet import googlenet from torchvision.models.inception import inception_v3 @@ -25,40 +25,40 @@ from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small from torchvision.models.optical_flow import raft_large, raft_small from torchvision.models.regnet import ( - regnet_y_400mf, - regnet_y_800mf, - regnet_y_1_6gf, - regnet_y_3_2gf, - regnet_y_8gf, - regnet_y_16gf, - regnet_y_32gf, - regnet_y_128gf, - regnet_x_400mf, - regnet_x_800mf, + regnet_x_16gf, regnet_x_1_6gf, + regnet_x_32gf, regnet_x_3_2gf, + regnet_x_400mf, + regnet_x_800mf, regnet_x_8gf, - regnet_x_16gf, - regnet_x_32gf, + regnet_y_128gf, + regnet_y_16gf, + regnet_y_1_6gf, + regnet_y_32gf, + regnet_y_3_2gf, + regnet_y_400mf, + regnet_y_800mf, + regnet_y_8gf, ) from torchvision.models.resnet import ( + resnet101, + resnet152, resnet18, resnet34, resnet50, - resnet101, - resnet152, - resnext50_32x4d, resnext101_32x8d, resnext101_64x4d, - wide_resnet50_2, + resnext50_32x4d, wide_resnet101_2, + wide_resnet50_2, ) from torchvision.models.segmentation import ( - fcn_resnet50, - fcn_resnet101, - deeplabv3_resnet50, - deeplabv3_resnet101, deeplabv3_mobilenet_v3_large, + deeplabv3_resnet101, + deeplabv3_resnet50, + fcn_resnet101, + fcn_resnet50, lraspp_mobilenet_v3_large, ) from torchvision.models.shufflenetv2 import ( @@ -68,12 +68,6 @@ shufflenet_v2_x2_0, ) from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1 -from torchvision.models.swin_transformer import swin_t, swin_s, swin_b -from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn -from torchvision.models.vision_transformer import ( - vit_b_16, - vit_b_32, - vit_l_16, - vit_l_32, - vit_h_14, -) +from torchvision.models.swin_transformer import swin_b, swin_s, swin_t +from torchvision.models.vgg import vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn +from torchvision.models.vision_transformer import vit_b_16, vit_b_32, vit_h_14, vit_l_16, vit_l_32 diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index c0e5af1dcfc..a66a47f8674 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -9,7 +9,7 @@ import torchvision import utils from torch import nn -from train import train_one_epoch, evaluate, load_data +from train import evaluate, load_data, train_one_epoch def main(args): diff --git a/references/detection/group_by_aspect_ratio.py b/references/detection/group_by_aspect_ratio.py index 1323849a6a4..5312cc036d6 100644 --- a/references/detection/group_by_aspect_ratio.py +++ b/references/detection/group_by_aspect_ratio.py @@ -2,7 +2,7 @@ import copy import math from collections import defaultdict -from itertools import repeat, chain +from itertools import chain, repeat import numpy as np import torch diff --git a/references/detection/train.py b/references/detection/train.py index f56ac66881c..178f7460417 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -29,8 +29,8 @@ import torchvision.models.detection.mask_rcnn import utils from coco_utils import get_coco, get_coco_kp -from engine import train_one_epoch, evaluate -from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups +from engine import evaluate, train_one_epoch +from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler from torchvision.transforms import InterpolationMode from transforms import SimpleCopyPaste diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 35ae34bd56a..7da854505f2 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -1,11 +1,10 @@ -from typing import List, Tuple, Dict, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch import torchvision from torch import nn, Tensor from torchvision import ops -from torchvision.transforms import functional as F -from torchvision.transforms import transforms as T, InterpolationMode +from torchvision.transforms import functional as F, InterpolationMode, transforms as T def _flip_coco_person_keypoints(kps, width): diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 7c4c45ab275..0327d92bdf9 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -6,8 +6,8 @@ import torch import torchvision.models.optical_flow import utils -from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval -from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K +from presets import OpticalFlowPresetEval, OpticalFlowPresetTrain +from torchvision.datasets import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel def get_train_dataset(stage, dataset_root): diff --git a/references/optical_flow/utils.py b/references/optical_flow/utils.py index 065a2be8bfc..8b07e9de35c 100644 --- a/references/optical_flow/utils.py +++ b/references/optical_flow/utils.py @@ -1,8 +1,7 @@ import datetime import os import time -from collections import defaultdict -from collections import deque +from collections import defaultdict, deque import torch import torch.distributed as dist @@ -158,7 +157,7 @@ def log_every(self, iterable, print_freq=5, header=None): def compute_metrics(flow_pred, flow_gt, valid_flow_mask=None): epe = ((flow_pred - flow_gt) ** 2).sum(dim=1).sqrt() - flow_norm = (flow_gt ** 2).sum(dim=1).sqrt() + flow_norm = (flow_gt**2).sum(dim=1).sqrt() if valid_flow_mask is not None: epe = epe[valid_flow_mask] @@ -183,7 +182,7 @@ def sequence_loss(flow_preds, flow_gt, valid_flow_mask, gamma=0.8, max_flow=400) raise ValueError(f"Gamma should be < 1, got {gamma}.") # exlude invalid pixels and extremely large diplacements - flow_norm = torch.sum(flow_gt ** 2, dim=1).sqrt() + flow_norm = torch.sum(flow_gt**2, dim=1).sqrt() valid_flow_mask = valid_flow_mask & (flow_norm < max_flow) valid_flow_mask = valid_flow_mask[:, None, :, :] diff --git a/references/segmentation/utils.py b/references/segmentation/utils.py index dfd12726b53..4ea24db83ed 100644 --- a/references/segmentation/utils.py +++ b/references/segmentation/utils.py @@ -75,7 +75,7 @@ def update(self, a, b): with torch.inference_mode(): k = (a >= 0) & (a < n) inds = n * a[k].to(torch.int64) + b[k] - self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n) + self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) def reset(self): self.mat.zero_() diff --git a/setup.py b/setup.py index 54c961159bb..54319451521 100644 --- a/setup.py +++ b/setup.py @@ -7,9 +7,9 @@ import sys import torch -from pkg_resources import parse_version, get_distribution, DistributionNotFound -from setuptools import setup, find_packages -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME +from pkg_resources import DistributionNotFound, get_distribution, parse_version +from setuptools import find_packages, setup +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDA_HOME, CUDAExtension def read(*names, **kwargs): diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 0fe0cbd6dd7..8c5484a2823 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -14,12 +14,12 @@ import unittest.mock import warnings import xml.etree.ElementTree as ET -from collections import defaultdict, Counter +from collections import Counter, defaultdict import numpy as np import pytest import torch -from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid +from datasets_utils import combinations_grid, create_image_file, create_image_folder, make_tar, make_zip from torch.nn.functional import one_hot from torch.testing import make_tensor as _make_tensor from torchvision.prototype import datasets diff --git a/test/conftest.py b/test/conftest.py index a8b9054a4e5..1a9b2db7f5c 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,7 +3,7 @@ import numpy as np import pytest import torch -from common_utils import IN_CIRCLE_CI, CIRCLECI_GPU_NO_CUDA_MSG, IN_FBCODE, IN_RE_WORKER, CUDA_NOT_AVAILABLE_MSG +from common_utils import CIRCLECI_GPU_NO_CUDA_MSG, CUDA_NOT_AVAILABLE_MSG, IN_CIRCLE_CI, IN_FBCODE, IN_RE_WORKER def pytest_configure(config): diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 88eb4e17823..2043caae0a2 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -22,7 +22,7 @@ import torch import torchvision.datasets import torchvision.io -from common_utils import get_tmp_dir, disable_console_output +from common_utils import disable_console_output, get_tmp_dir __all__ = [ diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py index 4d2e475e1df..5fa7c6bca44 100644 --- a/test/test_datasets_download.py +++ b/test/test_datasets_download.py @@ -9,15 +9,15 @@ from os import path from urllib.error import HTTPError, URLError from urllib.parse import urlparse -from urllib.request import urlopen, Request +from urllib.request import Request, urlopen import pytest from torchvision import datasets from torchvision.datasets.utils import ( - download_url, + _get_redirect_url, check_integrity, download_file_from_google_drive, - _get_redirect_url, + download_url, USER_AGENT, ) diff --git a/test/test_datasets_samplers.py b/test/test_datasets_samplers.py index 7174d6321f7..9e3826b2c13 100644 --- a/test/test_datasets_samplers.py +++ b/test/test_datasets_samplers.py @@ -1,12 +1,8 @@ import pytest import torch -from common_utils import get_list_of_videos, assert_equal +from common_utils import assert_equal, get_list_of_videos from torchvision import io -from torchvision.datasets.samplers import ( - DistributedSampler, - RandomClipSampler, - UniformClipSampler, -) +from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler from torchvision.datasets.video_utils import VideoClips diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py index cfdbd6f6d02..adaa4f5446c 100644 --- a/test/test_datasets_video_utils.py +++ b/test/test_datasets_video_utils.py @@ -1,8 +1,8 @@ import pytest import torch -from common_utils import get_list_of_videos, assert_equal +from common_utils import assert_equal, get_list_of_videos from torchvision import io -from torchvision.datasets.video_utils import VideoClips, unfold +from torchvision.datasets.video_utils import unfold, VideoClips class TestVideo: diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 7961d173e3f..677d19d18f7 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -5,7 +5,7 @@ import test_models as TM import torch from torchvision import models -from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._api import Weights, WeightsEnum from torchvision.models._utils import handle_legacy_interface diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 969aedf6d2d..52979a019e7 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -14,14 +14,14 @@ import torchvision.transforms.functional_pil as F_pil import torchvision.transforms.functional_tensor as F_t from common_utils import ( - cpu_and_gpu, - needs_cuda, + _assert_approx_equal_tensor_to_pil, + _assert_equal_tensor_to_pil, _create_data, _create_data_batch, - _assert_equal_tensor_to_pil, - _assert_approx_equal_tensor_to_pil, _test_fn_on_batch, assert_equal, + cpu_and_gpu, + needs_cuda, ) from torchvision.transforms import InterpolationMode diff --git a/test/test_image.py b/test/test_image.py index e4358f6f1e1..89374ebc8c5 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -8,21 +8,21 @@ import pytest import torch import torchvision.transforms.functional as F -from common_utils import needs_cuda, assert_equal -from PIL import Image, __version__ as PILLOW_VERSION +from common_utils import assert_equal, needs_cuda +from PIL import __version__ as PILLOW_VERSION, Image from torchvision.io.image import ( - decode_png, + _read_png_16, + decode_image, decode_jpeg, + decode_png, encode_jpeg, - write_jpeg, - decode_image, - read_file, encode_png, - write_png, - write_file, ImageReadMode, + read_file, read_image, - _read_png_16, + write_file, + write_jpeg, + write_png, ) IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") @@ -168,7 +168,7 @@ def test_decode_png(img_path, pil_mode, mode): img_lpng = _read_png_16(img_path, mode=mode) assert img_lpng.dtype == torch.int32 # PIL converts 16 bits pngs in uint8 - img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8) + img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8) else: data = read_file(img_path) img_lpng = decode_image(data, mode=mode) diff --git a/test/test_models.py b/test/test_models.py index 866fafae5f6..05bab11e479 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -14,7 +14,7 @@ import torch.fx import torch.nn as nn from _utils_internal import get_relative_path -from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda +from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed from torchvision import models ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py index c4efbd96cf3..13db78d53fc 100644 --- a/test/test_models_detection_negative_samples.py +++ b/test/test_models_detection_negative_samples.py @@ -4,7 +4,7 @@ from common_utils import assert_equal from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead from torchvision.models.detection.roi_heads import RoIHeads -from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork +from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead from torchvision.ops import MultiScaleRoIAlign @@ -60,7 +60,7 @@ def test_assign_targets_to_proposals(self): resolution = box_roi_pool.output_size[0] representation_size = 1024 - box_head = TwoMLPHead(4 * resolution ** 2, representation_size) + box_head = TwoMLPHead(4 * resolution**2, representation_size) representation_size = 1024 box_predictor = FastRCNNPredictor(representation_size, 2) diff --git a/test/test_onnx.py b/test/test_onnx.py index ba0880a621d..d5dae64b4d0 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -4,13 +4,12 @@ import pytest import torch -from common_utils import set_rng_seed, assert_equal -from torchvision import models -from torchvision import ops +from common_utils import assert_equal, set_rng_seed +from torchvision import models, ops from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead from torchvision.models.detection.image_list import ImageList from torchvision.models.detection.roi_heads import RoIHeads -from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork +from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead from torchvision.models.detection.transform import GeneralizedRCNNTransform from torchvision.ops._register_onnx_ops import _onnx_opset_version @@ -265,7 +264,7 @@ def _init_test_roi_heads_faster_rcnn(self): resolution = box_roi_pool.output_size[0] representation_size = 1024 - box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size) + box_head = TwoMLPHead(out_channels * resolution**2, representation_size) representation_size = 1024 box_predictor = FastRCNNPredictor(representation_size, num_classes) diff --git a/test/test_ops.py b/test/test_ops.py index 96cfb630e8d..bc4f9d19464 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -79,7 +79,7 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwar rois_dtype = self.dtype if rois_dtype is None else rois_dtype pool_size = 5 # n_channels % (pool_size ** 2) == 0 required for PS opeartions. - n_channels = 2 * (pool_size ** 2) + n_channels = 2 * (pool_size**2) x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device) if not contiguous: x = x.permute(0, 1, 3, 2) @@ -115,7 +115,7 @@ def test_is_leaf_node(self, device): def test_backward(self, seed, device, contiguous): torch.random.manual_seed(seed) pool_size = 2 - x = torch.rand(1, 2 * (pool_size ** 2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) + x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) if not contiguous: x = x.permute(0, 1, 3, 2) rois = torch.tensor( diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 5a8c9e7eff8..6ddba1806c6 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -5,14 +5,14 @@ import pytest import torch -from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS -from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair +from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks +from torch.testing._comparison import assert_equal, ObjectPair, TensorLikePair from torch.utils.data import DataLoader from torch.utils.data.graph import traverse from torch.utils.data.graph_settings import get_all_graph_pipes -from torchdata.datapipes.iter import Shuffler, ShardingFilter +from torchdata.datapipes.iter import ShardingFilter, Shuffler from torchvision._utils import sequence_to_str -from torchvision.prototype import transforms, datasets +from torchvision.prototype import datasets, transforms from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE from torchvision.prototype.features import Image, Label diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py index 8790b1638f9..2098ac736ac 100644 --- a/test/test_prototype_datasets_utils.py +++ b/test/test_prototype_datasets_utils.py @@ -9,8 +9,8 @@ from torchdata.datapipes.iter import FileOpener, TarArchiveLoader from torchvision.datasets._optical_flow import _read_flo as read_flo_ref from torchvision.datasets.utils import _decompress -from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset, OnlineResource -from torchvision.prototype.datasets.utils._internal import read_flo, fromfile +from torchvision.prototype.datasets.utils import Dataset, GDriveResource, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import fromfile, read_flo @pytest.mark.filterwarnings("error:The given NumPy array is not writeable:UserWarning") diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index c76a84f8634..eefb1669901 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -2,7 +2,7 @@ import test_models as TM import torch import torchvision.prototype.models.depth.stereo.raft_stereo as raft_stereo -from common_utils import set_rng_seed, cpu_and_gpu +from common_utils import cpu_and_gpu, set_rng_seed @pytest.mark.parametrize("model_builder", (raft_stereo.raft_stereo_base, raft_stereo.raft_stereo_realtime)) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index eb92af41071..1a56e8d3928 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -3,13 +3,9 @@ import pytest import torch from common_utils import assert_equal -from test_prototype_transforms_functional import ( - make_images, - make_bounding_boxes, - make_one_hot_labels, -) -from torchvision.prototype import transforms, features -from torchvision.transforms.functional import to_pil_image, pil_to_tensor +from test_prototype_transforms_functional import make_bounding_boxes, make_images, make_one_hot_labels +from torchvision.prototype import features, transforms +from torchvision.transforms.functional import pil_to_tensor, to_pil_image def make_vanilla_tensor_images(*args, **kwargs): diff --git a/test/test_transforms.py b/test/test_transforms.py index 427239d6d70..6ec670ffa78 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -24,7 +24,7 @@ except ImportError: stats = None -from common_utils import cycle_over, int_dtypes, float_dtypes, assert_equal +from common_utils import assert_equal, cycle_over, float_dtypes, int_dtypes GRACE_HOPPER = get_file_path_2( diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 7dc6dbd95d9..f4ca544deb8 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -6,19 +6,18 @@ import torch import torchvision.transforms._pil_constants as _pil_constants from common_utils import ( - get_tmp_dir, - int_dtypes, - float_dtypes, + _assert_approx_equal_tensor_to_pil, + _assert_equal_tensor_to_pil, _create_data, _create_data_batch, - _assert_equal_tensor_to_pil, - _assert_approx_equal_tensor_to_pil, - cpu_and_gpu, assert_equal, + cpu_and_gpu, + float_dtypes, + get_tmp_dir, + int_dtypes, ) from torchvision import transforms as T -from torchvision.transforms import InterpolationMode -from torchvision.transforms import functional as F +from torchvision.transforms import functional as F, InterpolationMode from torchvision.transforms.autoaugment import _apply_op NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC diff --git a/test/test_utils.py b/test/test_utils.py index 7cff53e98a3..dde3ee90dc3 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -10,7 +10,7 @@ import torchvision.transforms.functional as F import torchvision.utils as utils from common_utils import assert_equal -from PIL import Image, __version__ as PILLOW_VERSION, ImageColor +from PIL import __version__ as PILLOW_VERSION, Image, ImageColor PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split(".")) @@ -45,8 +45,8 @@ def test_normalize_in_make_grid(): # Rounding the result to one decimal for comparison n_digits = 1 - rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits) - rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits) + rounded_grid_max = torch.round(grid_max * 10**n_digits) / (10**n_digits) + rounded_grid_min = torch.round(grid_min * 10**n_digits) / (10**n_digits) assert_equal(norm_max, rounded_grid_max, msg="Normalized max is not equal to 1") assert_equal(norm_min, rounded_grid_min, msg="Normalized min is not equal to 0") diff --git a/torchvision/__init__.py b/torchvision/__init__.py index 32b522cbc42..739f79407b3 100644 --- a/torchvision/__init__.py +++ b/torchvision/__init__.py @@ -2,12 +2,7 @@ import warnings import torch -from torchvision import datasets -from torchvision import io -from torchvision import models -from torchvision import ops -from torchvision import transforms -from torchvision import utils +from torchvision import datasets, io, models, ops, transforms, utils from .extension import _HAS_OPS diff --git a/torchvision/_utils.py b/torchvision/_utils.py index 8e8fe1b8a83..b739ef0966e 100644 --- a/torchvision/_utils.py +++ b/torchvision/_utils.py @@ -1,5 +1,5 @@ import enum -from typing import Sequence, TypeVar, Type +from typing import Sequence, Type, TypeVar T = TypeVar("T", bound=enum.Enum) diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 295fe922478..099d10da35d 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,4 +1,4 @@ -from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D, HD1K +from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel from .caltech import Caltech101, Caltech256 from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 @@ -11,19 +11,19 @@ from .fakedata import FakeData from .fer2013 import FER2013 from .fgvc_aircraft import FGVCAircraft -from .flickr import Flickr8k, Flickr30k +from .flickr import Flickr30k, Flickr8k from .flowers102 import Flowers102 -from .folder import ImageFolder, DatasetFolder +from .folder import DatasetFolder, ImageFolder from .food101 import Food101 from .gtsrb import GTSRB from .hmdb51 import HMDB51 from .imagenet import ImageNet from .inaturalist import INaturalist -from .kinetics import Kinetics400, Kinetics +from .kinetics import Kinetics, Kinetics400 from .kitti import Kitti -from .lfw import LFWPeople, LFWPairs +from .lfw import LFWPairs, LFWPeople from .lsun import LSUN, LSUNClass -from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST, QMNIST +from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST from .omniglot import Omniglot from .oxford_iiit_pet import OxfordIIITPet from .pcam import PCAM @@ -40,7 +40,7 @@ from .ucf101 import UCF101 from .usps import USPS from .vision import VisionDataset -from .voc import VOCSegmentation, VOCDetection +from .voc import VOCDetection, VOCSegmentation from .widerface import WIDERFace __all__ = ( diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index 8a36c1b8d04..bc26f51dc75 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -9,7 +9,7 @@ from PIL import Image from ..io.image import _read_png_16 -from .utils import verify_str_arg, _read_pfm +from .utils import _read_pfm, verify_str_arg from .vision import VisionDataset @@ -466,7 +466,7 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name): flow_and_valid = _read_png_16(file_name).to(torch.float32) flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :] - flow = (flow - 2 ** 15) / 64 # This conversion is explained somewhere on the kitti archive + flow = (flow - 2**15) / 64 # This conversion is explained somewhere on the kitti archive valid_flow_mask = valid_flow_mask.bool() # For consistency with other datasets, we convert to numpy diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py index e95043ce2de..3a9635dfe09 100644 --- a/torchvision/datasets/caltech.py +++ b/torchvision/datasets/caltech.py @@ -1,6 +1,6 @@ import os import os.path -from typing import Any, Callable, List, Optional, Union, Tuple +from typing import Any, Callable, List, Optional, Tuple, Union from PIL import Image diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py index e9dd883b92e..dbacece88c9 100644 --- a/torchvision/datasets/celeba.py +++ b/torchvision/datasets/celeba.py @@ -1,12 +1,12 @@ import csv import os from collections import namedtuple -from typing import Any, Callable, List, Optional, Union, Tuple +from typing import Any, Callable, List, Optional, Tuple, Union import PIL import torch -from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive +from .utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg from .vision import VisionDataset CSV = namedtuple("CSV", ["header", "index", "data"]) diff --git a/torchvision/datasets/cityscapes.py b/torchvision/datasets/cityscapes.py index b07c093e10c..86d65c7c091 100644 --- a/torchvision/datasets/cityscapes.py +++ b/torchvision/datasets/cityscapes.py @@ -1,11 +1,11 @@ import json import os from collections import namedtuple -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from PIL import Image -from .utils import extract_archive, verify_str_arg, iterable_to_str +from .utils import extract_archive, iterable_to_str, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/clevr.py b/torchvision/datasets/clevr.py index 112765a6b5d..94e261e3355 100644 --- a/torchvision/datasets/clevr.py +++ b/torchvision/datasets/clevr.py @@ -1,6 +1,6 @@ import json import pathlib -from typing import Any, Callable, Optional, Tuple, List +from typing import Any, Callable, List, Optional, Tuple from urllib.parse import urlparse from PIL import Image diff --git a/torchvision/datasets/coco.py b/torchvision/datasets/coco.py index 9bb8bda67d1..f53aba16e5f 100644 --- a/torchvision/datasets/coco.py +++ b/torchvision/datasets/coco.py @@ -1,5 +1,5 @@ import os.path -from typing import Any, Callable, Optional, Tuple, List +from typing import Any, Callable, List, Optional, Tuple from PIL import Image diff --git a/torchvision/datasets/country211.py b/torchvision/datasets/country211.py index b5c650cb276..9a62520fe2b 100644 --- a/torchvision/datasets/country211.py +++ b/torchvision/datasets/country211.py @@ -2,7 +2,7 @@ from typing import Callable, Optional from .folder import ImageFolder -from .utils import verify_str_arg, download_and_extract_archive +from .utils import download_and_extract_archive, verify_str_arg class Country211(ImageFolder): diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py index deb27312573..2d8314346b9 100644 --- a/torchvision/datasets/dtd.py +++ b/torchvision/datasets/dtd.py @@ -1,10 +1,10 @@ import os import pathlib -from typing import Optional, Callable +from typing import Callable, Optional import PIL.Image -from .utils import verify_str_arg, download_and_extract_archive +from .utils import download_and_extract_archive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/fer2013.py b/torchvision/datasets/fer2013.py index 60cbfd9bf28..bcd20c1e4a2 100644 --- a/torchvision/datasets/fer2013.py +++ b/torchvision/datasets/fer2013.py @@ -5,7 +5,7 @@ import torch from PIL import Image -from .utils import verify_str_arg, check_integrity +from .utils import check_integrity, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/flowers102.py b/torchvision/datasets/flowers102.py index 97a8fb416ba..ad3a6dda0e8 100644 --- a/torchvision/datasets/flowers102.py +++ b/torchvision/datasets/flowers102.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Tuple, Callable, Optional +from typing import Any, Callable, Optional, Tuple import PIL.Image diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index d5a7e88083b..40d5e26d242 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -1,7 +1,6 @@ import os import os.path -from typing import Any, Callable, cast, Dict, List, Optional, Tuple -from typing import Union +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union from PIL import Image diff --git a/torchvision/datasets/food101.py b/torchvision/datasets/food101.py index 1bb4d8094d5..aa405eedcf9 100644 --- a/torchvision/datasets/food101.py +++ b/torchvision/datasets/food101.py @@ -1,10 +1,10 @@ import json from pathlib import Path -from typing import Any, Tuple, Callable, Optional +from typing import Any, Callable, Optional, Tuple import PIL.Image -from .utils import verify_str_arg, download_and_extract_archive +from .utils import download_and_extract_archive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/hmdb51.py b/torchvision/datasets/hmdb51.py index f7341f4aa30..9067418d847 100644 --- a/torchvision/datasets/hmdb51.py +++ b/torchvision/datasets/hmdb51.py @@ -1,6 +1,6 @@ import glob import os -from typing import Optional, Callable, Tuple, Dict, Any, List +from typing import Any, Callable, Dict, List, Optional, Tuple from torch import Tensor diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index a272bb86e57..4b86bf2f2b9 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -2,7 +2,7 @@ import shutil import tempfile from contextlib import contextmanager -from typing import Any, Dict, List, Iterator, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch diff --git a/torchvision/datasets/inaturalist.py b/torchvision/datasets/inaturalist.py index 7d5fc279820..50b32ef0f4a 100644 --- a/torchvision/datasets/inaturalist.py +++ b/torchvision/datasets/inaturalist.py @@ -1,6 +1,6 @@ import os import os.path -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from PIL import Image diff --git a/torchvision/datasets/kinetics.py b/torchvision/datasets/kinetics.py index 2ba5e50845e..9352355522d 100644 --- a/torchvision/datasets/kinetics.py +++ b/torchvision/datasets/kinetics.py @@ -11,7 +11,7 @@ from torch import Tensor from .folder import find_classes, make_dataset -from .utils import download_and_extract_archive, download_url, verify_str_arg, check_integrity +from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg from .video_utils import VideoClips from .vision import VisionDataset diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py index c290e6dc0e8..a936351cdcc 100644 --- a/torchvision/datasets/lsun.py +++ b/torchvision/datasets/lsun.py @@ -7,7 +7,7 @@ from PIL import Image -from .utils import verify_str_arg, iterable_to_str +from .utils import iterable_to_str, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 9f9ec457499..fd742544935 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -12,7 +12,7 @@ import torch from PIL import Image -from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity +from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index 5a09d61ccca..41d18c1bdd5 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -3,7 +3,7 @@ from PIL import Image -from .utils import download_and_extract_archive, check_integrity, list_dir, list_files +from .utils import check_integrity, download_and_extract_archive, list_dir, list_files from .vision import VisionDataset diff --git a/torchvision/datasets/oxford_iiit_pet.py b/torchvision/datasets/oxford_iiit_pet.py index 733aa78256b..667ee13717d 100644 --- a/torchvision/datasets/oxford_iiit_pet.py +++ b/torchvision/datasets/oxford_iiit_pet.py @@ -1,8 +1,7 @@ import os import os.path import pathlib -from typing import Any, Callable, Optional, Union, Tuple -from typing import Sequence +from typing import Any, Callable, Optional, Sequence, Tuple, Union from PIL import Image diff --git a/torchvision/datasets/pcam.py b/torchvision/datasets/pcam.py index 4f124674961..63faf721a0f 100644 --- a/torchvision/datasets/pcam.py +++ b/torchvision/datasets/pcam.py @@ -3,7 +3,7 @@ from PIL import Image -from .utils import download_file_from_google_drive, _decompress, verify_str_arg +from .utils import _decompress, download_file_from_google_drive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/places365.py b/torchvision/datasets/places365.py index dd11d7331ae..c26b6f03074 100644 --- a/torchvision/datasets/places365.py +++ b/torchvision/datasets/places365.py @@ -4,7 +4,7 @@ from urllib.parse import urljoin from .folder import default_loader -from .utils import verify_str_arg, check_integrity, download_and_extract_archive +from .utils import check_integrity, download_and_extract_archive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/rendered_sst2.py b/torchvision/datasets/rendered_sst2.py index 02445dddb05..89adf8cf8d8 100644 --- a/torchvision/datasets/rendered_sst2.py +++ b/torchvision/datasets/rendered_sst2.py @@ -1,10 +1,10 @@ from pathlib import Path -from typing import Any, Tuple, Callable, Optional +from typing import Any, Callable, Optional, Tuple import PIL.Image from .folder import make_dataset -from .utils import verify_str_arg, download_and_extract_archive +from .utils import download_and_extract_archive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/samplers/__init__.py b/torchvision/datasets/samplers/__init__.py index 861a029a9ec..58b2d2abd93 100644 --- a/torchvision/datasets/samplers/__init__.py +++ b/torchvision/datasets/samplers/__init__.py @@ -1,3 +1,3 @@ -from .clip_sampler import DistributedSampler, UniformClipSampler, RandomClipSampler +from .clip_sampler import DistributedSampler, RandomClipSampler, UniformClipSampler __all__ = ("DistributedSampler", "UniformClipSampler", "RandomClipSampler") diff --git a/torchvision/datasets/samplers/clip_sampler.py b/torchvision/datasets/samplers/clip_sampler.py index f4975f8c021..026c3d75d3b 100644 --- a/torchvision/datasets/samplers/clip_sampler.py +++ b/torchvision/datasets/samplers/clip_sampler.py @@ -1,5 +1,5 @@ import math -from typing import Optional, List, Iterator, Sized, Union, cast +from typing import cast, Iterator, List, Optional, Sized, Union import torch import torch.distributed as dist diff --git a/torchvision/datasets/sbd.py b/torchvision/datasets/sbd.py index 030643dc794..8399d025b1b 100644 --- a/torchvision/datasets/sbd.py +++ b/torchvision/datasets/sbd.py @@ -5,7 +5,7 @@ import numpy as np from PIL import Image -from .utils import download_url, verify_str_arg, download_and_extract_archive +from .utils import download_and_extract_archive, download_url, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/sbu.py b/torchvision/datasets/sbu.py index cd483a46190..6bfe0b88cba 100644 --- a/torchvision/datasets/sbu.py +++ b/torchvision/datasets/sbu.py @@ -3,7 +3,7 @@ from PIL import Image -from .utils import download_url, check_integrity +from .utils import check_integrity, download_url from .vision import VisionDataset diff --git a/torchvision/datasets/semeion.py b/torchvision/datasets/semeion.py index eb9ee247f13..c47703afbde 100644 --- a/torchvision/datasets/semeion.py +++ b/torchvision/datasets/semeion.py @@ -4,7 +4,7 @@ import numpy as np from PIL import Image -from .utils import download_url, check_integrity +from .utils import check_integrity, download_url from .vision import VisionDataset diff --git a/torchvision/datasets/stanford_cars.py b/torchvision/datasets/stanford_cars.py index daca0b0b46a..3e9430ef214 100644 --- a/torchvision/datasets/stanford_cars.py +++ b/torchvision/datasets/stanford_cars.py @@ -1,5 +1,5 @@ import pathlib -from typing import Callable, Optional, Any, Tuple +from typing import Any, Callable, Optional, Tuple from PIL import Image diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py index 1ef50cf0a24..8a906619a9d 100644 --- a/torchvision/datasets/stl10.py +++ b/torchvision/datasets/stl10.py @@ -1,5 +1,5 @@ import os.path -from typing import Any, Callable, Optional, Tuple, cast +from typing import Any, Callable, cast, Optional, Tuple import numpy as np from PIL import Image diff --git a/torchvision/datasets/sun397.py b/torchvision/datasets/sun397.py index cc3457fb16f..05cb910dde8 100644 --- a/torchvision/datasets/sun397.py +++ b/torchvision/datasets/sun397.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Tuple, Callable, Optional +from typing import Any, Callable, Optional, Tuple import PIL.Image diff --git a/torchvision/datasets/svhn.py b/torchvision/datasets/svhn.py index 8a2e5839971..facb2d8858e 100644 --- a/torchvision/datasets/svhn.py +++ b/torchvision/datasets/svhn.py @@ -4,7 +4,7 @@ import numpy as np from PIL import Image -from .utils import download_url, check_integrity, verify_str_arg +from .utils import check_integrity, download_url, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/datasets/ucf101.py b/torchvision/datasets/ucf101.py index 4ee5f1f3df9..c82b509e535 100644 --- a/torchvision/datasets/ucf101.py +++ b/torchvision/datasets/ucf101.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List, Tuple, Optional, Callable +from typing import Any, Callable, Dict, List, Optional, Tuple from torch import Tensor diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index b14f25e986b..30506b3fc79 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -15,7 +15,7 @@ import urllib.request import warnings import zipfile -from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator +from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar from urllib.parse import urlparse import numpy as np @@ -23,10 +23,7 @@ import torch from torch.utils.model_zoo import tqdm -from .._internally_replaced_utils import ( - _download_file_from_remote_location, - _is_remote_location_available, -) +from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available USER_AGENT = "pytorch/vision" diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py index 3fdd50d19c7..c4890ff4416 100644 --- a/torchvision/datasets/video_utils.py +++ b/torchvision/datasets/video_utils.py @@ -2,15 +2,10 @@ import math import warnings from fractions import Fraction -from typing import Any, Dict, List, Optional, Callable, Union, Tuple, TypeVar, cast +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TypeVar, Union import torch -from torchvision.io import ( - _probe_video_from_file, - _read_video_from_file, - read_video, - read_video_timestamps, -) +from torchvision.io import _probe_video_from_file, _read_video_from_file, read_video, read_video_timestamps from .utils import tqdm diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py index 3448d62702c..32888cd5c8c 100644 --- a/torchvision/datasets/voc.py +++ b/torchvision/datasets/voc.py @@ -9,7 +9,7 @@ except ImportError: from xml.etree.ElementTree import parse as ET_parse import warnings -from typing import Any, Callable, Dict, Optional, Tuple, List +from typing import Any, Callable, Dict, List, Optional, Tuple from PIL import Image diff --git a/torchvision/datasets/widerface.py b/torchvision/datasets/widerface.py index a0f1e1fe285..b46c7982d8b 100644 --- a/torchvision/datasets/widerface.py +++ b/torchvision/datasets/widerface.py @@ -1,16 +1,11 @@ import os from os.path import abspath, expanduser -from typing import Any, Callable, List, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from PIL import Image -from .utils import ( - download_file_from_google_drive, - download_and_extract_archive, - extract_archive, - verify_str_arg, -) +from .utils import download_and_extract_archive, download_file_from_google_drive, extract_archive, verify_str_arg from .vision import VisionDataset diff --git a/torchvision/extension.py b/torchvision/extension.py index ae1da9c0d04..3bad8351b23 100644 --- a/torchvision/extension.py +++ b/torchvision/extension.py @@ -23,7 +23,6 @@ def _has_ops(): def _has_ops(): # noqa: F811 return True - except (ImportError, OSError): pass diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index 22788cef71e..ba7d4f69f26 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -9,8 +9,6 @@ except ModuleNotFoundError: _HAS_GPU_VIDEO_DECODER = False from ._video_opt import ( - Timebase, - VideoMetaData, _HAS_VIDEO_OPT, _probe_video_from_file, _probe_video_from_memory, @@ -18,25 +16,23 @@ _read_video_from_memory, _read_video_timestamps_from_file, _read_video_timestamps_from_memory, + Timebase, + VideoMetaData, ) from .image import ( - ImageReadMode, decode_image, decode_jpeg, decode_png, encode_jpeg, encode_png, + ImageReadMode, read_file, read_image, write_file, write_jpeg, write_png, ) -from .video import ( - read_video, - read_video_timestamps, - write_video, -) +from .video import read_video, read_video_timestamps, write_video from .video_reader import VideoReader diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index 055b195a8f4..b598196d413 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -1,7 +1,7 @@ import math import warnings from fractions import Fraction -from typing import List, Tuple, Dict, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch diff --git a/torchvision/io/video_reader.py b/torchvision/io/video_reader.py index 881b9d75bd4..c2ffa049d31 100644 --- a/torchvision/io/video_reader.py +++ b/torchvision/io/video_reader.py @@ -8,16 +8,13 @@ from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER except ModuleNotFoundError: _HAS_GPU_VIDEO_DECODER = False -from ._video_opt import ( - _HAS_VIDEO_OPT, -) +from ._video_opt import _HAS_VIDEO_OPT if _HAS_VIDEO_OPT: def _has_video_opt() -> bool: return True - else: def _has_video_opt() -> bool: diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 00b5ebefe55..7bca0276c34 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -13,9 +13,5 @@ from .vgg import * from .vision_transformer import * from .swin_transformer import * -from . import detection -from . import optical_flow -from . import quantization -from . import segmentation -from . import video +from . import detection, optical_flow, quantization, segmentation, video from ._api import get_weight diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 7c6530d66c4..901bb0015e4 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -3,7 +3,7 @@ import sys from dataclasses import dataclass, fields from inspect import signature -from typing import Any, Callable, Dict, Mapping, cast +from typing import Any, Callable, cast, Dict, Mapping from torchvision._utils import StrEnum diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index c565f611999..5d930e60295 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -2,7 +2,7 @@ import inspect import warnings from collections import OrderedDict -from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union from torch import nn diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 6c461a501c9..5d1401dcb36 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -6,9 +6,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 435789ca0e2..5b79e5934f4 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -9,9 +9,9 @@ from ..ops.stochastic_depth import StochasticDepth from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index e8a66f5771b..8eaac615c86 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -11,9 +11,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index 12b3784099f..7d28e96d305 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -3,9 +3,9 @@ from typing import Dict, List, Optional, Tuple import torch -from torch import Tensor, nn +from torch import nn, Tensor from torch.nn import functional as F -from torchvision.ops import FrozenBatchNorm2d, complete_box_iou_loss, distance_box_iou_loss, generalized_box_iou_loss +from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss class BalancedPositiveNegativeSampler: diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index fbef524b99c..4941d7ec440 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -6,8 +6,8 @@ from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool from .. import mobilenet, resnet -from .._api import WeightsEnum, _get_enum_from_fn -from .._utils import IntermediateLayerGetter, handle_legacy_interface +from .._api import _get_enum_from_fn, WeightsEnum +from .._utils import handle_legacy_interface, IntermediateLayerGetter class BackboneWithFPN(nn.Module): diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index fb98ca86b34..de46aadfe4f 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -7,17 +7,17 @@ from ...ops import misc as misc_nn_ops from ...transforms._presets import ObjectDetection -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large -from ..resnet import ResNet50_Weights, resnet50 +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights +from ..resnet import resnet50, ResNet50_Weights from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator -from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, _mobilenet_extractor +from .backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers from .generalized_rcnn import GeneralizedRCNN from .roi_heads import RoIHeads -from .rpn import RPNHead, RegionProposalNetwork +from .rpn import RegionProposalNetwork, RPNHead from .transform import GeneralizedRCNNTransform @@ -250,7 +250,7 @@ def __init__( if box_head is None: resolution = box_roi_pool.output_size[0] representation_size = 1024 - box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size) + box_head = TwoMLPHead(out_channels * resolution**2, representation_size) if box_predictor is None: representation_size = 1024 diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 9851b7f7c05..efaac721328 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -2,21 +2,19 @@ import warnings from collections import OrderedDict from functools import partial -from typing import Any, Callable, Dict, List, Tuple, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from torch import nn, Tensor -from ...ops import sigmoid_focal_loss, generalized_box_iou_loss -from ...ops import boxes as box_ops -from ...ops import misc as misc_nn_ops +from ...ops import boxes as box_ops, generalized_box_iou_loss, misc as misc_nn_ops, sigmoid_focal_loss from ...ops.feature_pyramid_network import LastLevelP6P7 from ...transforms._presets import ObjectDetection from ...utils import _log_api_usage_once -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..resnet import resnet50, ResNet50_Weights from . import _utils as det_utils from .anchor_utils import AnchorGenerator from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py index fdcaea5a3eb..b481265077f 100644 --- a/torchvision/models/detection/generalized_rcnn.py +++ b/torchvision/models/detection/generalized_rcnn.py @@ -4,7 +4,7 @@ import warnings from collections import OrderedDict -from typing import Tuple, List, Dict, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch from torch import nn, Tensor diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 0052e49409c..f4044a2c1a2 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -6,10 +6,10 @@ from ...ops import misc as misc_nn_ops from ...transforms._presets import ObjectDetection -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..resnet import resnet50, ResNet50_Weights from ._utils import overwrite_eps from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .faster_rcnn import FasterRCNN diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 66dde13adff..422bacd135b 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -6,13 +6,13 @@ from ...ops import misc as misc_nn_ops from ...transforms._presets import ObjectDetection -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..resnet import resnet50, ResNet50_Weights from ._utils import overwrite_eps from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers -from .faster_rcnn import FasterRCNN, FastRCNNConvFCHead, RPNHead, _default_anchorgen +from .faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPNHead __all__ = [ diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 18e6b432a4f..57c75354389 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -2,23 +2,21 @@ import warnings from collections import OrderedDict from functools import partial -from typing import Any, Callable, Dict, List, Tuple, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from torch import nn, Tensor -from ...ops import sigmoid_focal_loss -from ...ops import boxes as box_ops -from ...ops import misc as misc_nn_ops +from ...ops import boxes as box_ops, misc as misc_nn_ops, sigmoid_focal_loss from ...ops.feature_pyramid_network import LastLevelP6P7 from ...transforms._presets import ObjectDetection from ...utils import _log_api_usage_once -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..resnet import resnet50, ResNet50_Weights from . import _utils as det_utils -from ._utils import overwrite_eps, _box_loss +from ._utils import _box_loss, overwrite_eps from .anchor_utils import AnchorGenerator from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .transform import GeneralizedRCNNTransform diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index d2abebfca68..18a6782a06b 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict, Tuple +from typing import Dict, List, Optional, Tuple import torch import torch.nn.functional as F diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 39f82ca323b..07a8b931150 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -1,10 +1,9 @@ -from typing import List, Optional, Dict, Tuple +from typing import Dict, List, Optional, Tuple import torch from torch import nn, Tensor from torch.nn import functional as F -from torchvision.ops import Conv2dNormActivation -from torchvision.ops import boxes as box_ops +from torchvision.ops import boxes as box_ops, Conv2dNormActivation from . import _utils as det_utils @@ -322,15 +321,12 @@ def compute_loss( labels = torch.cat(labels, dim=0) regression_targets = torch.cat(regression_targets, dim=0) - box_loss = ( - F.smooth_l1_loss( - pred_bbox_deltas[sampled_pos_inds], - regression_targets[sampled_pos_inds], - beta=1 / 9, - reduction="sum", - ) - / (sampled_inds.numel()) - ) + box_loss = F.smooth_l1_loss( + pred_bbox_deltas[sampled_pos_inds], + regression_targets[sampled_pos_inds], + beta=1 / 9, + reduction="sum", + ) / (sampled_inds.numel()) objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds]) diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index bcbea25d6d7..1a926116450 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -9,10 +9,10 @@ from ...ops import boxes as box_ops from ...transforms._presets import ObjectDetection from ...utils import _log_api_usage_once -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..vgg import VGG, VGG16_Weights, vgg16 +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..vgg import VGG, vgg16, VGG16_Weights from . import _utils as det_utils from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 3be9b6fb9f2..7d695823b39 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -10,10 +10,10 @@ from ...transforms._presets import ObjectDetection from ...utils import _log_api_usage_once from .. import mobilenet -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from .._utils import _ovewrite_value_param, handle_legacy_interface +from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights from . import _utils as det_utils from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 4f653a86acd..dd2d728abf9 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -1,5 +1,5 @@ import math -from typing import List, Tuple, Dict, Optional, Any +from typing import Any, Dict, List, Optional, Tuple import torch import torchvision diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index bfd59aee951..417ebabcbe5 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -3,7 +3,7 @@ import warnings from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Dict, Optional, List, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch from torch import nn, Tensor @@ -12,9 +12,9 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible +from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py index 1b380076b2a..d247d9a3e26 100644 --- a/torchvision/models/feature_extraction.py +++ b/torchvision/models/feature_extraction.py @@ -5,12 +5,11 @@ from collections import OrderedDict from copy import deepcopy from itertools import chain -from typing import Dict, Callable, List, Union, Optional, Tuple, Any +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torchvision -from torch import fx -from torch import nn +from torch import fx, nn from torch.fx.graph_module import _copy_attr diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 5b0a91d4791..895fcd1e4e6 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -1,7 +1,7 @@ import warnings from collections import namedtuple from functools import partial -from typing import Optional, Tuple, List, Callable, Any +from typing import Any, Callable, List, Optional, Tuple import torch import torch.nn as nn @@ -10,9 +10,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"] diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index 9207485085f..c1a87954f7c 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -1,7 +1,7 @@ import warnings from collections import namedtuple from functools import partial -from typing import Callable, Any, Optional, Tuple, List +from typing import Any, Callable, List, Optional, Tuple import torch import torch.nn.functional as F @@ -9,9 +9,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"] diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 8286674d232..27117ae3a83 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -8,9 +8,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index 4c4a7d1e293..06fbff2802a 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -1,17 +1,16 @@ import warnings from functools import partial -from typing import Callable, Any, Optional, List +from typing import Any, Callable, List, Optional import torch -from torch import Tensor -from torch import nn +from torch import nn, Tensor from ..ops.misc import Conv2dNormActivation from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible +from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface __all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"] diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index dfdd529bfc2..10d2a1c91ac 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -8,9 +8,9 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible +from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index c535300a68c..f0205ef608c 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -8,10 +8,10 @@ from torch.nn import functional as F from ...transforms._presets import ImageClassification -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, GoogLeNet_Weights +from .._utils import _ovewrite_named_param, handle_legacy_interface +from ..googlenet import BasicConv2d, GoogLeNet, GoogLeNet_Weights, GoogLeNetOutputs, Inception, InceptionAux from .utils import _fuse_modules, _replace_relu, quantize_model diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index ba4b21d4112..1698cec7557 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -7,12 +7,12 @@ import torch.nn.functional as F from torch import Tensor from torchvision.models import inception as inception_module -from torchvision.models.inception import InceptionOutputs, Inception_V3_Weights +from torchvision.models.inception import Inception_V3_Weights, InceptionOutputs from ...transforms._presets import ImageClassification -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param +from .._utils import _ovewrite_named_param, handle_legacy_interface from .utils import _fuse_modules, _replace_relu, quantize_model diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 936e9bcc1b1..61a3cb7eeba 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -1,16 +1,15 @@ from functools import partial from typing import Any, Optional, Union -from torch import Tensor -from torch import nn -from torch.ao.quantization import QuantStub, DeQuantStub -from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, MobileNet_V2_Weights +from torch import nn, Tensor +from torch.ao.quantization import DeQuantStub, QuantStub +from torchvision.models.mobilenetv2 import InvertedResidual, MobileNet_V2_Weights, MobileNetV2 from ...ops.misc import Conv2dNormActivation from ...transforms._presets import ImageClassification -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param +from .._utils import _ovewrite_named_param, handle_legacy_interface from .utils import _fuse_modules, _replace_relu, quantize_model diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 94036143138..56341bb280e 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -3,19 +3,19 @@ import torch from torch import nn, Tensor -from torch.ao.quantization import QuantStub, DeQuantStub +from torch.ao.quantization import DeQuantStub, QuantStub from ...ops.misc import Conv2dNormActivation, SqueezeExcitation from ...transforms._presets import ImageClassification -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param +from .._utils import _ovewrite_named_param, handle_legacy_interface from ..mobilenetv3 import ( + _mobilenet_v3_conf, InvertedResidual, InvertedResidualConfig, - MobileNetV3, - _mobilenet_v3_conf, MobileNet_V3_Large_Weights, + MobileNetV3, ) from .utils import _fuse_modules, _replace_relu diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index d51bde50a57..bf3c733887e 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -1,12 +1,12 @@ from functools import partial -from typing import Any, Type, Union, List, Optional +from typing import Any, List, Optional, Type, Union import torch import torch.nn as nn from torch import Tensor from torchvision.models.resnet import ( - Bottleneck, BasicBlock, + Bottleneck, ResNet, ResNet18_Weights, ResNet50_Weights, @@ -15,9 +15,9 @@ ) from ...transforms._presets import ImageClassification -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param +from .._utils import _ovewrite_named_param, handle_legacy_interface from .utils import _fuse_modules, _replace_relu, quantize_model diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 781591ae118..028df8be982 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -7,9 +7,9 @@ from torchvision.models import shufflenetv2 from ...transforms._presets import ImageClassification -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param +from .._utils import _ovewrite_named_param, handle_legacy_interface from ..shufflenetv2 import ( ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights, diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index d2958e8686c..d4b4147404c 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -9,9 +9,9 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible +from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 70602705521..667bece5730 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Type, Any, Callable, Union, List, Optional +from typing import Any, Callable, List, Optional, Type, Union import torch import torch.nn as nn @@ -7,9 +7,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/segmentation/_utils.py b/torchvision/models/segmentation/_utils.py index 44a60a95c54..56560e9dab5 100644 --- a/torchvision/models/segmentation/_utils.py +++ b/torchvision/models/segmentation/_utils.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from typing import Optional, Dict +from typing import Dict, Optional from torch import nn, Tensor from torch.nn import functional as F diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index e232235f0ff..0937369a1e7 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -6,11 +6,11 @@ from torch.nn import functional as F from ...transforms._presets import SemanticSegmentation -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _VOC_CATEGORIES -from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large -from ..resnet import ResNet, resnet50, resnet101, ResNet50_Weights, ResNet101_Weights +from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter +from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3 +from ..resnet import ResNet, resnet101, ResNet101_Weights, resnet50, ResNet50_Weights from ._utils import _SimpleSegmentationModel from .fcn import FCNHead diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index b44d0d7547a..2782d675ffe 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -4,10 +4,10 @@ from torch import nn from ...transforms._presets import SemanticSegmentation -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _VOC_CATEGORIES -from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet, ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 +from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter +from ..resnet import ResNet, resnet101, ResNet101_Weights, resnet50, ResNet50_Weights from ._utils import _SimpleSegmentationModel diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index 385960cbde4..339d5feffe6 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -7,10 +7,10 @@ from ...transforms._presets import SemanticSegmentation from ...utils import _log_api_usage_once -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _VOC_CATEGORIES -from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large +from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter +from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3 __all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"] diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 48695c70193..cc4291c9a86 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Callable, Any, List, Optional +from typing import Any, Callable, List, Optional import torch import torch.nn as nn @@ -7,9 +7,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index dbc0f54fb77..8d43d3a0330 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -7,9 +7,9 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"] diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 7bc6b46c674..db5604fb377 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Optional, Callable, List, Any +from typing import Any, Callable, List, Optional import torch import torch.nn.functional as F @@ -9,7 +9,7 @@ from ..ops.stochastic_depth import StochasticDepth from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param @@ -366,7 +366,7 @@ def __init__( # build SwinTransformer blocks for i_stage in range(len(depths)): stage: List[nn.Module] = [] - dim = embed_dim * 2 ** i_stage + dim = embed_dim * 2**i_stage for i_layer in range(depths[i_stage]): # adjust stochastic depth probability based on the depth of the stage block sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 937458b48cd..7c141381ee8 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -1,14 +1,14 @@ from functools import partial -from typing import Union, List, Dict, Any, Optional, cast +from typing import Any, cast, Dict, List, Optional, Union import torch import torch.nn as nn from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 0fd76399b5e..702116f047c 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -7,10 +7,10 @@ import torch.fx import torch.nn as nn -from ...ops import StochasticDepth, MLP +from ...ops import MLP, StochasticDepth from ...transforms._presets import VideoClassification from ...utils import _log_api_usage_once -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _KINETICS400_CATEGORIES from .._utils import _ovewrite_named_param diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index 6ec8bfc0b3e..ab369c55553 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -1,14 +1,14 @@ from functools import partial -from typing import Tuple, Optional, Callable, List, Sequence, Type, Any, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union import torch.nn as nn from torch import Tensor from ...transforms._presets import VideoClassification from ...utils import _log_api_usage_once -from .._api import WeightsEnum, Weights +from .._api import Weights, WeightsEnum from .._meta import _KINETICS400_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param +from .._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 57c1479b13d..e9a8c94cc67 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -1,7 +1,7 @@ import math from collections import OrderedDict from functools import partial -from typing import Any, Callable, List, NamedTuple, Optional, Dict +from typing import Any, Callable, Dict, List, NamedTuple, Optional import torch import torch.nn as nn @@ -9,9 +9,9 @@ from ..ops.misc import Conv2dNormActivation, MLP from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once -from ._api import WeightsEnum, Weights +from ._api import Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param +from ._utils import _ovewrite_named_param, handle_legacy_interface __all__ = [ diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 5d56f0bca42..827505b842d 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -1,25 +1,25 @@ from ._register_onnx_ops import _register_custom_op from .boxes import ( - nms, batched_nms, - remove_small_boxes, - clip_boxes_to_image, box_area, box_convert, box_iou, - generalized_box_iou, - distance_box_iou, + clip_boxes_to_image, complete_box_iou, + distance_box_iou, + generalized_box_iou, masks_to_boxes, + nms, + remove_small_boxes, ) from .ciou_loss import complete_box_iou_loss from .deform_conv import deform_conv2d, DeformConv2d from .diou_loss import distance_box_iou_loss -from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d +from .drop_block import drop_block2d, drop_block3d, DropBlock2d, DropBlock3d from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss from .giou_loss import generalized_box_iou_loss -from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP, Permute +from .misc import Conv2dNormActivation, Conv3dNormActivation, FrozenBatchNorm2d, MLP, Permute, SqueezeExcitation from .poolers import MultiScaleRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign from .ps_roi_pool import ps_roi_pool, PSRoIPool diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 72c95442b78..e42e7e04a70 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -6,7 +6,7 @@ from torchvision.extension import _assert_has_ops from ..utils import _log_api_usage_once -from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh +from ._box_convert import _box_cxcywh_to_xyxy, _box_xywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xyxy_to_xywh from ._utils import _upcast @@ -331,7 +331,7 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso w_gt = boxes2[:, 2] - boxes2[:, 0] h_gt = boxes2[:, 3] - boxes2[:, 1] - v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) + v = (4 / (torch.pi**2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) with torch.no_grad(): alpha = v / (1 - iou + v + eps) return diou - alpha * v diff --git a/torchvision/ops/ciou_loss.py b/torchvision/ops/ciou_loss.py index 1f271fb0a1d..a71baf28e70 100644 --- a/torchvision/ops/ciou_loss.py +++ b/torchvision/ops/ciou_loss.py @@ -58,7 +58,7 @@ def complete_box_iou_loss( h_pred = y2 - y1 w_gt = x2g - x1g h_gt = y2g - y1g - v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) + v = (4 / (torch.pi**2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) with torch.no_grad(): alpha = v / (1 - iou + v + eps) diff --git a/torchvision/ops/drop_block.py b/torchvision/ops/drop_block.py index a798677f60f..e65496ea29a 100644 --- a/torchvision/ops/drop_block.py +++ b/torchvision/ops/drop_block.py @@ -37,7 +37,7 @@ def drop_block2d( N, C, H, W = input.size() block_size = min(block_size, W, H) # compute the gamma of Bernoulli distribution - gamma = (p * H * W) / ((block_size ** 2) * ((H - block_size + 1) * (W - block_size + 1))) + gamma = (p * H * W) / ((block_size**2) * ((H - block_size + 1) * (W - block_size + 1))) noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device) noise.bernoulli_(gamma) @@ -83,7 +83,7 @@ def drop_block3d( N, C, D, H, W = input.size() block_size = min(block_size, D, H, W) # compute the gamma of Bernoulli distribution - gamma = (p * D * H * W) / ((block_size ** 3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1))) + gamma = (p * D * H * W) / ((block_size**3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1))) noise = torch.empty( (N, C, D - block_size + 1, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device ) diff --git a/torchvision/ops/feature_pyramid_network.py b/torchvision/ops/feature_pyramid_network.py index 9062405a997..ffec3505ec0 100644 --- a/torchvision/ops/feature_pyramid_network.py +++ b/torchvision/ops/feature_pyramid_network.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from typing import Tuple, List, Dict, Callable, Optional +from typing import Callable, Dict, List, Optional, Tuple import torch.nn.functional as F from torch import nn, Tensor diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index a7210f5739b..0c555ec4fe9 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -1,7 +1,7 @@ import torch from ..utils import _log_api_usage_once -from ._utils import _upcast_non_float, _loss_inter_union +from ._utils import _loss_inter_union, _upcast_non_float def generalized_box_iou_loss( diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 422119ceaec..d4bda7decc5 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, List, Optional, Union, Tuple, Sequence +from typing import Callable, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor diff --git a/torchvision/ops/poolers.py b/torchvision/ops/poolers.py index f881201a2d2..cfcb9e94056 100644 --- a/torchvision/ops/poolers.py +++ b/torchvision/ops/poolers.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, List, Dict, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch import torch.fx diff --git a/torchvision/ops/ps_roi_align.py b/torchvision/ops/ps_roi_align.py index 7153e49ac05..0228a2a5554 100644 --- a/torchvision/ops/ps_roi_align.py +++ b/torchvision/ops/ps_roi_align.py @@ -4,7 +4,7 @@ from torchvision.extension import _assert_has_ops from ..utils import _log_api_usage_once -from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape +from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format def ps_roi_align( diff --git a/torchvision/ops/ps_roi_pool.py b/torchvision/ops/ps_roi_pool.py index a27c36ee76c..1a3eed35915 100644 --- a/torchvision/ops/ps_roi_pool.py +++ b/torchvision/ops/ps_roi_pool.py @@ -4,7 +4,7 @@ from torchvision.extension import _assert_has_ops from ..utils import _log_api_usage_once -from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape +from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format def ps_roi_pool( diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 131c1b81d0f..afe9e42af16 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -7,7 +7,7 @@ from torchvision.extension import _assert_has_ops from ..utils import _log_api_usage_once -from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape +from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format def roi_align( diff --git a/torchvision/ops/roi_pool.py b/torchvision/ops/roi_pool.py index 37cbf7febee..50dc2f64421 100644 --- a/torchvision/ops/roi_pool.py +++ b/torchvision/ops/roi_pool.py @@ -7,7 +7,7 @@ from torchvision.extension import _assert_has_ops from ..utils import _log_api_usage_once -from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape +from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format def roi_pool( diff --git a/torchvision/prototype/__init__.py b/torchvision/prototype/__init__.py index e1be6c81f59..bef5ecc411d 100644 --- a/torchvision/prototype/__init__.py +++ b/torchvision/prototype/__init__.py @@ -1,5 +1 @@ -from . import datasets -from . import features -from . import models -from . import transforms -from . import utils +from . import datasets, features, models, transforms, utils diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 407dc23f64b..f6f06c60a21 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -1,5 +1,5 @@ import pathlib -from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union from torchvision.prototype.datasets import home from torchvision.prototype.datasets.utils import Dataset diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index 4acc1d53b4d..d84e9af9fc4 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -11,7 +11,7 @@ from .food101 import Food101 from .gtsrb import GTSRB from .imagenet import ImageNet -from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST +from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST from .oxford_iiit_pet import OxfordIIITPet from .pcam import PCAM from .sbd import SBD diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index fe3dc2000e6..a00bf2e2cc9 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -1,23 +1,18 @@ import pathlib import re -from typing import Any, Dict, List, Tuple, BinaryIO, Union +from typing import Any, BinaryIO, Dict, List, Tuple, Union import numpy as np -from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, - Filter, - IterKeyZipper, -) +from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, - read_mat, hint_sharding, hint_shuffling, + INFINITE_BUFFER_SIZE, read_categories_file, + read_mat, ) -from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage +from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index 46ccf8de6f7..e42657e826e 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -1,27 +1,17 @@ import csv import pathlib -from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO, Union - -from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, - Filter, - Zipper, - IterKeyZipper, -) -from torchvision.prototype.datasets.utils import ( - Dataset, - GDriveResource, - OnlineResource, -) +from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union + +from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper +from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, getitem, - path_accessor, hint_sharding, hint_shuffling, + INFINITE_BUFFER_SIZE, + path_accessor, ) -from torchvision.prototype.features import EncodedImage, _Feature, Label, BoundingBox +from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 514938d6e5f..26196ded638 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -2,22 +2,18 @@ import io import pathlib import pickle -from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO, Union +from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Tuple, Union import numpy as np -from torchdata.datapipes.iter import ( - IterDataPipe, - Filter, - Mapper, -) +from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( + hint_sharding, hint_shuffling, path_comparator, - hint_sharding, read_categories_file, ) -from torchvision.prototype.features import Label, Image +from torchvision.prototype.features import Image, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/clevr.py b/torchvision/prototype/datasets/_builtin/clevr.py index 3a139787c6f..4ddacdfb982 100644 --- a/torchvision/prototype/datasets/_builtin/clevr.py +++ b/torchvision/prototype/datasets/_builtin/clevr.py @@ -1,17 +1,17 @@ import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union +from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union -from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher +from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, + getitem, hint_sharding, hint_shuffling, - path_comparator, + INFINITE_BUFFER_SIZE, path_accessor, - getitem, + path_comparator, ) -from torchvision.prototype.features import Label, EncodedImage +from torchvision.prototype.features import EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index ff3b5f37c96..16a16998bf7 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -1,35 +1,30 @@ import pathlib import re -from collections import OrderedDict -from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union +from collections import defaultdict, OrderedDict +from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union import torch from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, - Filter, Demultiplexer, + Filter, Grouper, + IterDataPipe, IterKeyZipper, JsonParser, + Mapper, UnBatcher, ) -from torchvision.prototype.datasets.utils import ( - HttpResource, - OnlineResource, - Dataset, -) +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - MappingIterator, - INFINITE_BUFFER_SIZE, getitem, - read_categories_file, - path_accessor, hint_sharding, hint_shuffling, + INFINITE_BUFFER_SIZE, + MappingIterator, + path_accessor, + read_categories_file, ) -from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage +from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label from .._api import register_dataset, register_info @@ -151,7 +146,7 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, ) _META_FILE_PATTERN = re.compile( - fr"(?P({'|'.join(_ANN_DECODERS.keys())}))_(?P[a-zA-Z]+)(?P\d+)[.]json" + rf"(?P({'|'.join(_ANN_DECODERS.keys())}))_(?P[a-zA-Z]+)(?P\d+)[.]json" ) def _filter_meta_files(self, data: Tuple[str, Any]) -> bool: diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py index 012ecae19e2..f9821ea4eb6 100644 --- a/torchvision/prototype/datasets/_builtin/country211.py +++ b/torchvision/prototype/datasets/_builtin/country211.py @@ -1,12 +1,12 @@ import pathlib from typing import Any, Dict, List, Tuple, Union -from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter +from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - path_comparator, hint_sharding, hint_shuffling, + path_comparator, read_categories_file, ) from torchvision.prototype.features import EncodedImage, Label diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index 0e5a80de825..bb3f712c59d 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -1,30 +1,30 @@ import csv import functools import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable, Union +from typing import Any, BinaryIO, Callable, Dict, List, Optional, Tuple, Union from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, + CSVDictParser, + CSVParser, + Demultiplexer, Filter, + IterDataPipe, IterKeyZipper, - Demultiplexer, LineReader, - CSVParser, - CSVDictParser, + Mapper, ) from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, - read_mat, + getitem, hint_sharding, hint_shuffling, - getitem, + INFINITE_BUFFER_SIZE, + path_accessor, path_comparator, read_categories_file, - path_accessor, + read_mat, ) -from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage +from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index b082ada19ce..e7ff1e79559 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -1,22 +1,18 @@ import enum import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union +from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union -from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, LineReader, CSVParser -from torchvision.prototype.datasets.utils import ( - Dataset, - HttpResource, - OnlineResource, -) +from torchdata.datapipes.iter import CSVParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, + getitem, hint_sharding, + hint_shuffling, + INFINITE_BUFFER_SIZE, path_comparator, - getitem, read_categories_file, - hint_shuffling, ) -from torchvision.prototype.features import Label, EncodedImage +from torchvision.prototype.features import EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py index c1a914c6f63..b2693aa96c0 100644 --- a/torchvision/prototype/datasets/_builtin/fer2013.py +++ b/torchvision/prototype/datasets/_builtin/fer2013.py @@ -2,17 +2,10 @@ from typing import Any, Dict, List, Union import torch -from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser -from torchvision.prototype.datasets.utils import ( - Dataset, - OnlineResource, - KaggleDownloadResource, -) -from torchvision.prototype.datasets.utils._internal import ( - hint_sharding, - hint_shuffling, -) -from torchvision.prototype.features import Label, Image +from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper +from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling +from torchvision.prototype.features import Image, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py index 5100e5d8c74..3657116ae7a 100644 --- a/torchvision/prototype/datasets/_builtin/food101.py +++ b/torchvision/prototype/datasets/_builtin/food101.py @@ -1,24 +1,17 @@ from pathlib import Path -from typing import Any, Tuple, List, Dict, Optional, BinaryIO, Union - -from torchdata.datapipes.iter import ( - IterDataPipe, - Filter, - Mapper, - LineReader, - Demultiplexer, - IterKeyZipper, -) +from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union + +from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - hint_shuffling, - hint_sharding, - path_comparator, getitem, + hint_sharding, + hint_shuffling, INFINITE_BUFFER_SIZE, + path_comparator, read_categories_file, ) -from torchvision.prototype.features import Label, EncodedImage +from torchvision.prototype.features import EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index 01f754208e2..8dc0a8240c8 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -1,19 +1,15 @@ import pathlib from typing import Any, Dict, List, Optional, Tuple, Union -from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, Zipper, Demultiplexer -from torchvision.prototype.datasets.utils import ( - Dataset, - OnlineResource, - HttpResource, -) +from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - path_comparator, hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE, + path_comparator, ) -from torchvision.prototype.features import Label, BoundingBox, EncodedImage +from torchvision.prototype.features import BoundingBox, EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 1307757cef6..062e240a8b8 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -2,33 +2,29 @@ import functools import pathlib import re -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast, Union +from typing import Any, BinaryIO, cast, Dict, List, Match, Optional, Tuple, Union from torchdata.datapipes.iter import ( + Demultiplexer, + Enumerator, + Filter, IterDataPipe, - LineReader, IterKeyZipper, + LineReader, Mapper, - Filter, - Demultiplexer, TarArchiveLoader, - Enumerator, -) -from torchvision.prototype.datasets.utils import ( - OnlineResource, - ManualDownloadResource, - Dataset, ) +from torchvision.prototype.datasets.utils import Dataset, ManualDownloadResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, getitem, - read_mat, hint_sharding, hint_shuffling, - read_categories_file, + INFINITE_BUFFER_SIZE, path_accessor, + read_categories_file, + read_mat, ) -from torchvision.prototype.features import Label, EncodedImage +from torchvision.prototype.features import EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index e5537a1ef66..7a459b2d0ea 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -3,12 +3,12 @@ import operator import pathlib import string -from typing import Any, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, Union, Sequence +from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Sequence, Tuple, Union import torch -from torchdata.datapipes.iter import IterDataPipe, Demultiplexer, Mapper, Zipper, Decompressor +from torchdata.datapipes.iter import Decompressor, Demultiplexer, IterDataPipe, Mapper, Zipper from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE from torchvision.prototype.features import Image, Label from torchvision.prototype.utils._internal import fromfile diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index f7da02a4765..499dbd837ed 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -1,23 +1,19 @@ import enum import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union +from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union -from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser -from torchvision.prototype.datasets.utils import ( - Dataset, - HttpResource, - OnlineResource, -) +from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, Mapper +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, + getitem, hint_sharding, hint_shuffling, - getitem, + INFINITE_BUFFER_SIZE, path_accessor, - read_categories_file, path_comparator, + read_categories_file, ) -from torchvision.prototype.features import Label, EncodedImage +from torchvision.prototype.features import EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py index 7cd31469139..162f22f1abd 100644 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ b/torchvision/prototype/datasets/_builtin/pcam.py @@ -1,19 +1,12 @@ import io import pathlib from collections import namedtuple -from typing import Any, Dict, List, Optional, Tuple, Iterator, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper from torchvision.prototype import features -from torchvision.prototype.datasets.utils import ( - Dataset, - OnlineResource, - GDriveResource, -) -from torchvision.prototype.datasets.utils._internal import ( - hint_sharding, - hint_shuffling, -) +from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index 0c806fe098c..c7a79c4188e 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -1,26 +1,19 @@ import pathlib import re -from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union +from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union import numpy as np -from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, - Demultiplexer, - Filter, - IterKeyZipper, - LineReader, -) +from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - INFINITE_BUFFER_SIZE, - read_mat, getitem, - path_accessor, - path_comparator, hint_sharding, hint_shuffling, + INFINITE_BUFFER_SIZE, + path_accessor, + path_comparator, read_categories_file, + read_mat, ) from torchvision.prototype.features import _Feature, EncodedImage diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py index 5051bde4047..8107f6565e4 100644 --- a/torchvision/prototype/datasets/_builtin/semeion.py +++ b/torchvision/prototype/datasets/_builtin/semeion.py @@ -2,16 +2,8 @@ from typing import Any, Dict, List, Tuple, Union import torch -from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, - CSVParser, -) -from torchvision.prototype.datasets.utils import ( - Dataset, - HttpResource, - OnlineResource, -) +from torchdata.datapipes.iter import CSVParser, IterDataPipe, Mapper +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Image, OneHotLabel diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py b/torchvision/prototype/datasets/_builtin/stanford_cars.py index 465d753c2e5..011204f2bfb 100644 --- a/torchvision/prototype/datasets/_builtin/stanford_cars.py +++ b/torchvision/prototype/datasets/_builtin/stanford_cars.py @@ -1,5 +1,5 @@ import pathlib -from typing import Any, Dict, List, Tuple, Iterator, BinaryIO, Union +from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource @@ -7,8 +7,8 @@ hint_sharding, hint_shuffling, path_comparator, - read_mat, read_categories_file, + read_mat, ) from torchvision.prototype.features import BoundingBox, EncodedImage, Label diff --git a/torchvision/prototype/datasets/_builtin/svhn.py b/torchvision/prototype/datasets/_builtin/svhn.py index 175aa6c0a51..6dd55a77c99 100644 --- a/torchvision/prototype/datasets/_builtin/svhn.py +++ b/torchvision/prototype/datasets/_builtin/svhn.py @@ -1,23 +1,11 @@ import pathlib -from typing import Any, Dict, List, Tuple, BinaryIO, Union +from typing import Any, BinaryIO, Dict, List, Tuple, Union import numpy as np -from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, - UnBatcher, -) -from torchvision.prototype.datasets.utils import ( - Dataset, - HttpResource, - OnlineResource, -) -from torchvision.prototype.datasets.utils._internal import ( - read_mat, - hint_sharding, - hint_shuffling, -) -from torchvision.prototype.features import Label, Image +from torchdata.datapipes.iter import IterDataPipe, Mapper, UnBatcher +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, read_mat +from torchvision.prototype.features import Image, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py index e732f3b788a..e5ca58f8428 100644 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -2,8 +2,8 @@ from typing import Any, Dict, List, Union import torch -from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor -from torchvision.prototype.datasets.utils import Dataset, OnlineResource, HttpResource +from torchdata.datapipes.iter import Decompressor, IterDataPipe, LineReader, Mapper +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Image, Label diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index d875df521f2..2f13ce10d6f 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -1,29 +1,22 @@ import enum import functools import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Union +from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union from xml.etree import ElementTree -from torchdata.datapipes.iter import ( - IterDataPipe, - Mapper, - Filter, - Demultiplexer, - IterKeyZipper, - LineReader, -) +from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper from torchvision.datasets import VOCDetection -from torchvision.prototype.datasets.utils import OnlineResource, HttpResource, Dataset +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( - path_accessor, getitem, - INFINITE_BUFFER_SIZE, - path_comparator, hint_sharding, hint_shuffling, + INFINITE_BUFFER_SIZE, + path_accessor, + path_comparator, read_categories_file, ) -from torchvision.prototype.features import BoundingBox, Label, EncodedImage +from torchvision.prototype.features import BoundingBox, EncodedImage, Label from .._api import register_dataset, register_info diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py index c3a38becb6c..b2ec23c5e3d 100644 --- a/torchvision/prototype/datasets/_folder.py +++ b/torchvision/prototype/datasets/_folder.py @@ -2,11 +2,11 @@ import os import os.path import pathlib -from typing import BinaryIO, Optional, Collection, Union, Tuple, List, Dict, Any +from typing import Any, BinaryIO, Collection, Dict, List, Optional, Tuple, Union -from torchdata.datapipes.iter import IterDataPipe, FileLister, Mapper, Filter, FileOpener +from torchdata.datapipes.iter import FileLister, FileOpener, Filter, IterDataPipe, Mapper from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling -from torchvision.prototype.features import Label, EncodedImage, EncodedData +from torchvision.prototype.features import EncodedData, EncodedImage, Label __all__ = ["from_data_folder", "from_image_folder"] diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index 94c5907b47d..41ccbf48951 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,3 +1,3 @@ from . import _internal # usort: skip from ._dataset import Dataset -from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource +from ._resource import GDriveResource, HttpResource, KaggleDownloadResource, ManualDownloadResource, OnlineResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 528d0a0f25f..e7486c854ac 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -1,7 +1,7 @@ import abc import importlib import pathlib -from typing import Any, Dict, List, Optional, Sequence, Union, Collection, Iterator +from typing import Any, Collection, Dict, Iterator, List, Optional, Sequence, Union from torch.utils.data import IterDataPipe from torchvision.datasets.utils import verify_str_arg diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index 007e91eb657..6768469be67 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -2,21 +2,7 @@ import functools import pathlib import pickle -from typing import BinaryIO -from typing import ( - Sequence, - Callable, - Union, - Any, - Tuple, - TypeVar, - List, - Iterator, - Dict, - IO, - Sized, -) -from typing import cast +from typing import Any, BinaryIO, Callable, cast, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union import torch import torch.distributed as dist diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index 3c9b95cb498..dc01c72de28 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -2,26 +2,26 @@ import hashlib import itertools import pathlib -from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn, Set +from typing import Any, Callable, IO, NoReturn, Optional, Sequence, Set, Tuple, Union from urllib.parse import urlparse from torchdata.datapipes.iter import ( - IterableWrapper, FileLister, FileOpener, + IterableWrapper, IterDataPipe, - ZipArchiveLoader, - TarArchiveLoader, RarArchiveLoader, + TarArchiveLoader, + ZipArchiveLoader, ) from torchvision.datasets.utils import ( - download_url, - _detect_file_type, - extract_archive, _decompress, - download_file_from_google_drive, - _get_redirect_url, + _detect_file_type, _get_google_drive_file_id, + _get_redirect_url, + download_file_from_google_drive, + download_url, + extract_archive, ) from typing_extensions import Literal diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index eb9d1f6ac3a..c704954c03f 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Tuple, Union, Optional, Sequence +from typing import Any, List, Optional, Sequence, Tuple, Union import torch from torchvision._utils import StrEnum diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py index 612714c4c3a..ccab0b1b8a8 100644 --- a/torchvision/prototype/features/_encoded.py +++ b/torchvision/prototype/features/_encoded.py @@ -2,7 +2,7 @@ import os import sys -from typing import BinaryIO, Tuple, Type, TypeVar, Union, Optional, Any +from typing import Any, BinaryIO, Optional, Tuple, Type, TypeVar, Union import PIL.Image import torch diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 775f09f2f4b..85f758c638c 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -1,4 +1,4 @@ -from typing import Any, cast, TypeVar, Union, Optional, Type, Callable, List, Tuple, Sequence, Mapping +from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union import torch from torch._C import _TensorBase, DisableTorchFunction diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 447e67b33e9..70b93478d17 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -1,13 +1,12 @@ from __future__ import annotations import warnings -from typing import Any, List, Optional, Union, Sequence, Tuple, cast +from typing import Any, cast, List, Optional, Sequence, Tuple, Union import torch from torchvision._utils import StrEnum -from torchvision.transforms.functional import to_pil_image, InterpolationMode -from torchvision.utils import draw_bounding_boxes -from torchvision.utils import make_grid +from torchvision.transforms.functional import InterpolationMode, to_pil_image +from torchvision.utils import draw_bounding_boxes, make_grid from ._bounding_box import BoundingBox from ._feature import _Feature diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/features/_label.py index e3433b7bb08..c61419a61b6 100644 --- a/torchvision/prototype/features/_label.py +++ b/torchvision/prototype/features/_label.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Optional, Sequence, cast, Union +from typing import Any, cast, Optional, Sequence, Union import torch from torchvision.prototype.utils._internal import apply_recursively diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py index f894f33d1b2..fdb71358a8f 100644 --- a/torchvision/prototype/features/_segmentation_mask.py +++ b/torchvision/prototype/features/_segmentation_mask.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional, Union, Sequence +from typing import List, Optional, Sequence, Union from torchvision.transforms import InterpolationMode diff --git a/torchvision/prototype/models/depth/stereo/raft_stereo.py b/torchvision/prototype/models/depth/stereo/raft_stereo.py index 418e2629c48..fa636f8ef00 100644 --- a/torchvision/prototype/models/depth/stereo/raft_stereo.py +++ b/torchvision/prototype/models/depth/stereo/raft_stereo.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Callable, Tuple +from typing import Callable, List, Optional, Tuple import torch import torch.nn as nn @@ -6,8 +6,8 @@ import torchvision.models.optical_flow.raft as raft from torch import Tensor from torchvision.models._api import WeightsEnum -from torchvision.models.optical_flow._utils import make_coords_grid, grid_sample, upsample_flow -from torchvision.models.optical_flow.raft import ResidualBlock, MotionEncoder, FlowHead +from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow +from torchvision.models.optical_flow.raft import FlowHead, MotionEncoder, ResidualBlock from torchvision.ops import Conv2dNormActivation from torchvision.utils import _log_api_usage_once diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 2075ea7c52b..3f4299f6fb9 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -2,26 +2,26 @@ from ._transform import Transform # usort: skip -from ._augment import RandomErasing, RandomMixup, RandomCutmix -from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix -from ._color import ColorJitter, RandomPhotometricDistort, RandomEqualize +from ._augment import RandomCutmix, RandomErasing, RandomMixup +from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide +from ._color import ColorJitter, RandomEqualize, RandomPhotometricDistort from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import ( - Resize, + BatchMultiCrop, CenterCrop, - RandomResizedCrop, FiveCrop, - TenCrop, - BatchMultiCrop, + Pad, + RandomAffine, RandomHorizontalFlip, + RandomResizedCrop, + RandomRotation, RandomVerticalFlip, - Pad, RandomZoomOut, - RandomRotation, - RandomAffine, + Resize, + TenCrop, ) -from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace -from ._misc import Identity, Normalize, ToDtype, Lambda +from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype +from ._misc import Identity, Lambda, Normalize, ToDtype from ._type_conversion import DecodeImage, LabelToOneHot from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 8ed81eef8f2..d1c3db816ad 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -6,10 +6,10 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, functional as F +from torchvision.prototype.transforms import functional as F, Transform from ._transform import _RandomApplyTransform -from ._utils import query_image, get_image_dimensions, has_any, has_all +from ._utils import get_image_dimensions, has_all, has_any, query_image class RandomErasing(_RandomApplyTransform): diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 03aa96e08fb..f4f1a3547b1 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -1,13 +1,13 @@ import math -from typing import Any, Dict, Tuple, Optional, Callable, List, cast, Sequence, TypeVar, Union, Type +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, functional as F +from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.utils._internal import query_recursively from torchvision.transforms.autoaugment import AutoAugmentPolicy -from torchvision.transforms.functional import pil_to_tensor, to_pil_image, InterpolationMode +from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image from ._utils import get_image_dimensions diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 85e22aaeb1a..e71be8b5934 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -1,14 +1,14 @@ import collections.abc -from typing import Any, Dict, Union, Tuple, Optional, Sequence, TypeVar +from typing import Any, Dict, Optional, Sequence, Tuple, TypeVar, Union import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, functional as F +from torchvision.prototype.transforms import functional as F, Transform from torchvision.transforms import functional as _F from ._transform import _RandomApplyTransform -from ._utils import is_simple_tensor, get_image_dimensions, query_image +from ._utils import get_image_dimensions, is_simple_tensor, query_image T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image) diff --git a/torchvision/prototype/transforms/_container.py b/torchvision/prototype/transforms/_container.py index e9c72e2e020..fd1f58f3351 100644 --- a/torchvision/prototype/transforms/_container.py +++ b/torchvision/prototype/transforms/_container.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, List, Dict +from typing import Any, Dict, List, Optional import torch from torchvision.prototype.transforms import Transform diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index d4162b2b631..6c511635435 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -2,18 +2,18 @@ import math import numbers import warnings -from typing import Any, Dict, List, Optional, Union, Sequence, Tuple, cast +from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, functional as F -from torchvision.transforms.functional import pil_to_tensor, InterpolationMode -from torchvision.transforms.transforms import _setup_size, _setup_angle, _check_sequence_input +from torchvision.prototype.transforms import functional as F, Transform +from torchvision.transforms.functional import InterpolationMode, pil_to_tensor +from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size from typing_extensions import Literal from ._transform import _RandomApplyTransform -from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor +from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image class RandomHorizontalFlip(_RandomApplyTransform): diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index 6791bbbc69c..fcf0e0db883 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -1,9 +1,9 @@ -from typing import Union, Any, Dict, Optional +from typing import Any, Dict, Optional, Union import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, functional as F +from torchvision.prototype.transforms import functional as F, Transform from torchvision.transforms.functional import convert_image_dtype from ._utils import is_simple_tensor diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 54440ee05a5..769e05809e7 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,8 +1,8 @@ import functools -from typing import Any, List, Type, Callable, Dict +from typing import Any, Callable, Dict, List, Type import torch -from torchvision.prototype.transforms import Transform, functional as F +from torchvision.prototype.transforms import functional as F, Transform class Identity(Transform): diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index 09c071a27e0..9a698aa5e23 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -3,7 +3,7 @@ import numpy as np import PIL.Image from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, functional as F +from torchvision.prototype.transforms import functional as F, Transform from ._utils import is_simple_tensor diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 0517757a758..1344790e633 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,11 +1,11 @@ -from typing import Any, Optional, Tuple, Union, Type, Iterator +from typing import Any, Iterator, Optional, Tuple, Type, Union import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.utils._internal import query_recursively -from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_pil +from .functional._meta import get_dimensions_image_pil, get_dimensions_image_tensor def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index a8c17577a56..19b1c26f2d5 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -5,105 +5,103 @@ convert_image_color_space_pil, ) # usort: skip -from ._augment import ( - erase_image_tensor, -) +from ._augment import erase_image_tensor from ._color import ( adjust_brightness, - adjust_brightness_image_tensor, adjust_brightness_image_pil, + adjust_brightness_image_tensor, adjust_contrast, - adjust_contrast_image_tensor, adjust_contrast_image_pil, + adjust_contrast_image_tensor, + adjust_gamma, + adjust_gamma_image_pil, + adjust_gamma_image_tensor, + adjust_hue, + adjust_hue_image_pil, + adjust_hue_image_tensor, adjust_saturation, - adjust_saturation_image_tensor, adjust_saturation_image_pil, + adjust_saturation_image_tensor, adjust_sharpness, - adjust_sharpness_image_tensor, adjust_sharpness_image_pil, - adjust_hue, - adjust_hue_image_tensor, - adjust_hue_image_pil, - adjust_gamma, - adjust_gamma_image_tensor, - adjust_gamma_image_pil, - posterize, - posterize_image_tensor, - posterize_image_pil, - solarize, - solarize_image_tensor, - solarize_image_pil, + adjust_sharpness_image_tensor, autocontrast, - autocontrast_image_tensor, autocontrast_image_pil, + autocontrast_image_tensor, equalize, - equalize_image_tensor, equalize_image_pil, + equalize_image_tensor, invert, - invert_image_tensor, invert_image_pil, + invert_image_tensor, + posterize, + posterize_image_pil, + posterize_image_tensor, + solarize, + solarize_image_pil, + solarize_image_tensor, ) from ._geometry import ( + affine, + affine_bounding_box, + affine_image_pil, + affine_image_tensor, + affine_segmentation_mask, + center_crop, + center_crop_bounding_box, + center_crop_image_pil, + center_crop_image_tensor, + center_crop_segmentation_mask, + crop, + crop_bounding_box, + crop_image_pil, + crop_image_tensor, + crop_segmentation_mask, + five_crop_image_pil, + five_crop_image_tensor, horizontal_flip, horizontal_flip_bounding_box, - horizontal_flip_image_tensor, horizontal_flip_image_pil, + horizontal_flip_image_tensor, horizontal_flip_segmentation_mask, + pad, + pad_bounding_box, + pad_image_pil, + pad_image_tensor, + pad_segmentation_mask, + perspective, + perspective_bounding_box, + perspective_image_pil, + perspective_image_tensor, + perspective_segmentation_mask, resize, resize_bounding_box, - resize_image_tensor, resize_image_pil, + resize_image_tensor, resize_segmentation_mask, - center_crop, - center_crop_bounding_box, - center_crop_segmentation_mask, - center_crop_image_tensor, - center_crop_image_pil, resized_crop, resized_crop_bounding_box, - resized_crop_image_tensor, resized_crop_image_pil, + resized_crop_image_tensor, resized_crop_segmentation_mask, - affine, - affine_bounding_box, - affine_image_tensor, - affine_image_pil, - affine_segmentation_mask, rotate, rotate_bounding_box, - rotate_image_tensor, rotate_image_pil, + rotate_image_tensor, rotate_segmentation_mask, - pad, - pad_bounding_box, - pad_image_tensor, - pad_image_pil, - pad_segmentation_mask, - crop, - crop_bounding_box, - crop_image_tensor, - crop_image_pil, - crop_segmentation_mask, - perspective, - perspective_bounding_box, - perspective_image_tensor, - perspective_image_pil, - perspective_segmentation_mask, + ten_crop_image_pil, + ten_crop_image_tensor, vertical_flip, - vertical_flip_image_tensor, - vertical_flip_image_pil, vertical_flip_bounding_box, + vertical_flip_image_pil, + vertical_flip_image_tensor, vertical_flip_segmentation_mask, - five_crop_image_tensor, - five_crop_image_pil, - ten_crop_image_tensor, - ten_crop_image_pil, ) -from ._misc import normalize_image_tensor, gaussian_blur_image_tensor +from ._misc import gaussian_blur_image_tensor, normalize_image_tensor from ._type_conversion import ( decode_image_with_pil, decode_video_with_av, label_to_one_hot, - to_image_tensor, to_image_pil, + to_image_tensor, ) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index d5c5d305722..554fb98ae52 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -3,7 +3,7 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP +from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT # shortcut type diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 8d3ed675047..8938b2bf31c 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,19 +1,19 @@ import numbers import warnings -from typing import Tuple, List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Tuple, Union import PIL.Image import torch from torchvision.prototype import features -from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP +from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT from torchvision.transforms.functional import ( - pil_modes_mapping, + _compute_output_size, _get_inverse_affine_matrix, InterpolationMode, - _compute_output_size, + pil_modes_mapping, ) -from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil +from ._meta import convert_bounding_box_format, get_dimensions_image_pil, get_dimensions_image_tensor # shortcut type diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 2386f47b226..db7918558bc 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -1,9 +1,9 @@ -from typing import Tuple, Optional +from typing import Optional, Tuple import PIL.Image import torch from torchvision.prototype.features import BoundingBoxFormat, ColorSpace -from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP +from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT get_dimensions_image_tensor = _FT.get_dimensions get_dimensions_image_pil = _FP.get_dimensions diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 7b7139a5fd9..096ba32f2cf 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import List, Optional import PIL.Image import torch diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py index 37f8f9b70a3..0619852900f 100644 --- a/torchvision/prototype/transforms/functional/_type_conversion.py +++ b/torchvision/prototype/transforms/functional/_type_conversion.py @@ -1,5 +1,5 @@ import unittest.mock -from typing import Dict, Any, Tuple, Union +from typing import Any, Dict, Tuple, Union import numpy as np import PIL.Image diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 233128880e3..fb5c3b83de6 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -3,18 +3,7 @@ import io import mmap import platform -from typing import ( - Any, - BinaryIO, - Callable, - Collection, - Iterator, - Sequence, - Tuple, - TypeVar, - Union, - Optional, -) +from typing import Any, BinaryIO, Callable, Collection, Iterator, Optional, Sequence, Tuple, TypeVar, Union import numpy as np import torch diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index e49912e0f00..33b94d01c9d 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -5,7 +5,7 @@ from typing import Optional, Tuple import torch -from torch import Tensor, nn +from torch import nn, Tensor from . import functional as F, InterpolationMode diff --git a/torchvision/transforms/_transforms_video.py b/torchvision/transforms/_transforms_video.py index 69512af6eb1..1ed6de7612d 100644 --- a/torchvision/transforms/_transforms_video.py +++ b/torchvision/transforms/_transforms_video.py @@ -4,10 +4,7 @@ import random import warnings -from torchvision.transforms import ( - RandomCrop, - RandomResizedCrop, -) +from torchvision.transforms import RandomCrop, RandomResizedCrop from . import _functional_video as F diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 357e5bf250e..9dbbe91e741 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -1,6 +1,6 @@ import math from enum import Enum -from typing import List, Tuple, Optional, Dict +from typing import Dict, List, Optional, Tuple import torch from torch import Tensor diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index eea53a228a9..77d5b33b55a 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -2,7 +2,7 @@ import numbers import warnings from enum import Enum -from typing import List, Tuple, Any, Optional, Union +from typing import Any, List, Optional, Tuple, Union import numpy as np import torch @@ -15,8 +15,7 @@ accimage = None from ..utils import _log_api_usage_once -from . import functional_pil as F_pil -from . import functional_tensor as F_t +from . import functional_pil as F_pil, functional_tensor as F_t class InterpolationMode(Enum): diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 2b0872acf8a..a1e49f5c2d8 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,9 +1,9 @@ import warnings -from typing import Optional, Tuple, List, Union +from typing import List, Optional, Tuple, Union import torch from torch import Tensor -from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad +from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad def _is_tensor_a_torch_image(x: Tensor) -> bool: @@ -247,7 +247,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: if not torch.is_floating_point(img): result = convert_image_dtype(result, torch.float32) - result = (gain * result ** gamma).clamp(0, 1) + result = (gain * result**gamma).clamp(0, 1) result = convert_image_dtype(result, dtype) return result diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index abf01a13360..ae7853ec5ea 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -3,7 +3,7 @@ import random import warnings from collections.abc import Sequence -from typing import Tuple, List, Optional +from typing import List, Optional, Tuple import torch from torch import Tensor @@ -15,7 +15,7 @@ from ..utils import _log_api_usage_once from . import functional as F -from .functional import InterpolationMode, _interpolation_modes_from_int +from .functional import _interpolation_modes_from_int, InterpolationMode __all__ = [ "Compose", diff --git a/torchvision/utils.py b/torchvision/utils.py index abb8e7f0e45..3809a13c049 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -449,7 +449,7 @@ def flow_to_image(flow: torch.Tensor) -> torch.Tensor: if flow.ndim != 4 or flow.shape[1] != 2: raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") - max_norm = torch.sum(flow ** 2, dim=1).sqrt().max() + max_norm = torch.sum(flow**2, dim=1).sqrt().max() epsilon = torch.finfo((flow).dtype).eps normalized_flow = flow / (max_norm + epsilon) img = _normalized_flow_to_image(normalized_flow) @@ -476,7 +476,7 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) colorwheel = _make_colorwheel().to(device) # shape [55x3] num_cols = colorwheel.shape[0] - norm = torch.sum(normalized_flow ** 2, dim=1).sqrt() + norm = torch.sum(normalized_flow**2, dim=1).sqrt() a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi fk = (a + 1) / 2 * (num_cols - 1) k0 = torch.floor(fk).to(torch.long) @@ -542,7 +542,7 @@ def _make_colorwheel() -> torch.Tensor: def _generate_color_palette(num_objects: int): - palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) + palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1]) return [tuple((i * palette) % 255) for i in range(num_objects)]