From 13abbc7bad65468ff1ed54717d534219f4dc15fc Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Thu, 27 Oct 2022 11:39:50 +0200 Subject: [PATCH 1/2] style: Added typing annotations to datasets/_optical_flow --- torchvision/datasets/_optical_flow.py | 64 ++++++++++++++++----------- torchvision/datasets/utils.py | 4 +- 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index bc26f51dc75..86b00627518 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from glob import glob from pathlib import Path +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -28,15 +29,15 @@ class FlowDataset(ABC, VisionDataset): # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be. _has_builtin_flow_mask = False - def __init__(self, root, transforms=None): + def __init__(self, root: str, transforms: Optional[Callable] = None) -> None: super().__init__(root=root) self.transforms = transforms - self._flow_list = [] - self._image_list = [] + self._flow_list: List[str] = [] + self._image_list: List[List[str]] = [] - def _read_img(self, file_name): + def _read_img(self, file_name: Union[str, Path]) -> Image.Image: img = Image.open(file_name) if img.mode != "RGB": img = img.convert("RGB") @@ -47,7 +48,7 @@ def _read_flow(self, file_name): # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True pass - def __getitem__(self, index): + def __getitem__(self, index: int) -> Tuple: img1 = self._read_img(self._image_list[index][0]) img2 = self._read_img(self._image_list[index][1]) @@ -70,10 +71,10 @@ def __getitem__(self, index): else: return img1, img2, flow - def __len__(self): + def __len__(self) -> int: return len(self._image_list) - def __rmul__(self, v): + def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset: return torch.utils.data.ConcatDataset([self] * v) @@ -118,7 +119,13 @@ class Sintel(FlowDataset): return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ - def __init__(self, root, split="train", pass_name="clean", transforms=None): + def __init__( + self, + root: str, + split: str = "train", + pass_name: str = "clean", + transforms: Optional[Callable] = None, + ) -> None: super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -139,7 +146,7 @@ def __init__(self, root, split="train", pass_name="clean", transforms=None): if split == "train": self._flow_list += sorted(glob(str(flow_root / scene / "*.flo"))) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Tuple: """Return example at given index. Args: @@ -154,7 +161,7 @@ def __getitem__(self, index): """ return super().__getitem__(index) - def _read_flow(self, file_name): + def _read_flow(self, file_name: Union[str, Path]) -> np.ndarray: return _read_flo(file_name) @@ -180,7 +187,7 @@ class KittiFlow(FlowDataset): _has_builtin_flow_mask = True - def __init__(self, root, split="train", transforms=None): + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -200,7 +207,7 @@ def __init__(self, root, split="train", transforms=None): if split == "train": self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png"))) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Tuple: """Return example at given index. Args: @@ -215,7 +222,7 @@ def __getitem__(self, index): """ return super().__getitem__(index) - def _read_flow(self, file_name): + def _read_flow(self, file_name: Union[str, Path]) -> Tuple[np.ndarray, np.ndarray]: return _read_16bits_png_with_flow_and_valid_mask(file_name) @@ -245,7 +252,7 @@ class FlyingChairs(FlowDataset): return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ - def __init__(self, root, split="train", transforms=None): + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "val")) @@ -268,7 +275,7 @@ def __init__(self, root, split="train", transforms=None): self._flow_list += [flows[i]] self._image_list += [[images[2 * i], images[2 * i + 1]]] - def __getitem__(self, index): + def __getitem__(self, index: int) -> Tuple: """Return example at given index. Args: @@ -283,7 +290,7 @@ def __getitem__(self, index): """ return super().__getitem__(index) - def _read_flow(self, file_name): + def _read_flow(self, file_name: Union[str, Path]) -> np.ndarray: return _read_flo(file_name) @@ -316,7 +323,14 @@ class FlyingThings3D(FlowDataset): return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ - def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None): + def __init__( + self, + root: str, + split: str = "train", + pass_name: str = "clean", + camera: str = "left", + transforms: Optional[Callable] = None, + ) -> None: super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -359,7 +373,7 @@ def __init__(self, root, split="train", pass_name="clean", camera="left", transf self._image_list += [[images[i + 1], images[i]]] self._flow_list += [flows[i + 1]] - def __getitem__(self, index): + def __getitem__(self, index: int) -> Tuple: """Return example at given index. Args: @@ -374,7 +388,7 @@ def __getitem__(self, index): """ return super().__getitem__(index) - def _read_flow(self, file_name): + def _read_flow(self, file_name: Union[str, Path]) -> np.ndarray: return _read_pfm(file_name) @@ -401,7 +415,7 @@ class HD1K(FlowDataset): _has_builtin_flow_mask = True - def __init__(self, root, split="train", transforms=None): + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -426,10 +440,10 @@ def __init__(self, root, split="train", transforms=None): "Could not find the HD1K images. Please make sure the directory structure is correct." ) - def _read_flow(self, file_name): + def _read_flow(self, file_name: Union[str, Path]) -> Tuple[np.ndarray, np.ndarray]: return _read_16bits_png_with_flow_and_valid_mask(file_name) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Tuple: """Return example at given index. Args: @@ -445,7 +459,7 @@ def __getitem__(self, index): return super().__getitem__(index) -def _read_flo(file_name): +def _read_flo(file_name: Union[str, Path]) -> np.ndarray: """Read .flo file in Middlebury format""" # Code adapted from: # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy @@ -462,9 +476,9 @@ def _read_flo(file_name): return data.reshape(h, w, 2).transpose(2, 0, 1) -def _read_16bits_png_with_flow_and_valid_mask(file_name): +def _read_16bits_png_with_flow_and_valid_mask(file_name: Union[str, Path]) -> Tuple[np.ndarray, np.ndarray]: - flow_and_valid = _read_png_16(file_name).to(torch.float32) + flow_and_valid = _read_png_16(str(file_name)).to(torch.float32) flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :] flow = (flow - 2**15) / 64 # This conversion is explained somewhere on the kitti archive valid_flow_mask = valid_flow_mask.bool() diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index b8aaff3d773..6d4296e1d4f 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -15,7 +15,7 @@ import urllib.request import warnings import zipfile -from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar +from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union from urllib.parse import urlparse import numpy as np @@ -486,7 +486,7 @@ def verify_str_arg( return value -def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray: +def _read_pfm(file_name: Union[str, pathlib.Path], slice_channels: int = 2) -> np.ndarray: """Read file in .pfm format. Might contain either 1 or 3 channels of data. Args: From 084a76ced02daeb54d94eb5e115fcc0fd6929b62 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Tue, 1 Nov 2022 15:56:33 +0100 Subject: [PATCH 2/2] style: Reverted back to str typing --- torchvision/datasets/_optical_flow.py | 36 +++++++++++++++------------ torchvision/datasets/utils.py | 4 +-- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index 86b00627518..c7663258899 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -14,6 +14,10 @@ from .vision import VisionDataset +T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]] +T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]] + + __all__ = ( "KittiFlow", "Sintel", @@ -37,18 +41,18 @@ def __init__(self, root: str, transforms: Optional[Callable] = None) -> None: self._flow_list: List[str] = [] self._image_list: List[List[str]] = [] - def _read_img(self, file_name: Union[str, Path]) -> Image.Image: + def _read_img(self, file_name: str) -> Image.Image: img = Image.open(file_name) if img.mode != "RGB": img = img.convert("RGB") return img @abstractmethod - def _read_flow(self, file_name): + def _read_flow(self, file_name: str): # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True pass - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> Union[T1, T2]: img1 = self._read_img(self._image_list[index][0]) img2 = self._read_img(self._image_list[index][1]) @@ -146,7 +150,7 @@ def __init__( if split == "train": self._flow_list += sorted(glob(str(flow_root / scene / "*.flo"))) - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> Union[T1, T2]: """Return example at given index. Args: @@ -161,7 +165,7 @@ def __getitem__(self, index: int) -> Tuple: """ return super().__getitem__(index) - def _read_flow(self, file_name: Union[str, Path]) -> np.ndarray: + def _read_flow(self, file_name: str) -> np.ndarray: return _read_flo(file_name) @@ -207,7 +211,7 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl if split == "train": self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png"))) - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> Union[T1, T2]: """Return example at given index. Args: @@ -222,7 +226,7 @@ def __getitem__(self, index: int) -> Tuple: """ return super().__getitem__(index) - def _read_flow(self, file_name: Union[str, Path]) -> Tuple[np.ndarray, np.ndarray]: + def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]: return _read_16bits_png_with_flow_and_valid_mask(file_name) @@ -275,7 +279,7 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl self._flow_list += [flows[i]] self._image_list += [[images[2 * i], images[2 * i + 1]]] - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> Union[T1, T2]: """Return example at given index. Args: @@ -290,7 +294,7 @@ def __getitem__(self, index: int) -> Tuple: """ return super().__getitem__(index) - def _read_flow(self, file_name: Union[str, Path]) -> np.ndarray: + def _read_flow(self, file_name: str) -> np.ndarray: return _read_flo(file_name) @@ -373,7 +377,7 @@ def __init__( self._image_list += [[images[i + 1], images[i]]] self._flow_list += [flows[i + 1]] - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> Union[T1, T2]: """Return example at given index. Args: @@ -388,7 +392,7 @@ def __getitem__(self, index: int) -> Tuple: """ return super().__getitem__(index) - def _read_flow(self, file_name: Union[str, Path]) -> np.ndarray: + def _read_flow(self, file_name: str) -> np.ndarray: return _read_pfm(file_name) @@ -440,10 +444,10 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl "Could not find the HD1K images. Please make sure the directory structure is correct." ) - def _read_flow(self, file_name: Union[str, Path]) -> Tuple[np.ndarray, np.ndarray]: + def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]: return _read_16bits_png_with_flow_and_valid_mask(file_name) - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> Union[T1, T2]: """Return example at given index. Args: @@ -459,7 +463,7 @@ def __getitem__(self, index: int) -> Tuple: return super().__getitem__(index) -def _read_flo(file_name: Union[str, Path]) -> np.ndarray: +def _read_flo(file_name: str) -> np.ndarray: """Read .flo file in Middlebury format""" # Code adapted from: # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy @@ -476,9 +480,9 @@ def _read_flo(file_name: Union[str, Path]) -> np.ndarray: return data.reshape(h, w, 2).transpose(2, 0, 1) -def _read_16bits_png_with_flow_and_valid_mask(file_name: Union[str, Path]) -> Tuple[np.ndarray, np.ndarray]: +def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]: - flow_and_valid = _read_png_16(str(file_name)).to(torch.float32) + flow_and_valid = _read_png_16(file_name).to(torch.float32) flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :] flow = (flow - 2**15) / 64 # This conversion is explained somewhere on the kitti archive valid_flow_mask = valid_flow_mask.bool() diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 6d4296e1d4f..b8aaff3d773 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -15,7 +15,7 @@ import urllib.request import warnings import zipfile -from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar from urllib.parse import urlparse import numpy as np @@ -486,7 +486,7 @@ def verify_str_arg( return value -def _read_pfm(file_name: Union[str, pathlib.Path], slice_channels: int = 2) -> np.ndarray: +def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray: """Read file in .pfm format. Might contain either 1 or 3 channels of data. Args: