diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index bc26f51dc75..c7663258899 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from glob import glob from pathlib import Path +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -13,6 +14,10 @@ from .vision import VisionDataset +T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]] +T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]] + + __all__ = ( "KittiFlow", "Sintel", @@ -28,26 +33,26 @@ class FlowDataset(ABC, VisionDataset): # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be. _has_builtin_flow_mask = False - def __init__(self, root, transforms=None): + def __init__(self, root: str, transforms: Optional[Callable] = None) -> None: super().__init__(root=root) self.transforms = transforms - self._flow_list = [] - self._image_list = [] + self._flow_list: List[str] = [] + self._image_list: List[List[str]] = [] - def _read_img(self, file_name): + def _read_img(self, file_name: str) -> Image.Image: img = Image.open(file_name) if img.mode != "RGB": img = img.convert("RGB") return img @abstractmethod - def _read_flow(self, file_name): + def _read_flow(self, file_name: str): # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True pass - def __getitem__(self, index): + def __getitem__(self, index: int) -> Union[T1, T2]: img1 = self._read_img(self._image_list[index][0]) img2 = self._read_img(self._image_list[index][1]) @@ -70,10 +75,10 @@ def __getitem__(self, index): else: return img1, img2, flow - def __len__(self): + def __len__(self) -> int: return len(self._image_list) - def __rmul__(self, v): + def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset: return torch.utils.data.ConcatDataset([self] * v) @@ -118,7 +123,13 @@ class Sintel(FlowDataset): return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ - def __init__(self, root, split="train", pass_name="clean", transforms=None): + def __init__( + self, + root: str, + split: str = "train", + pass_name: str = "clean", + transforms: Optional[Callable] = None, + ) -> None: super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -139,7 +150,7 @@ def __init__(self, root, split="train", pass_name="clean", transforms=None): if split == "train": self._flow_list += sorted(glob(str(flow_root / scene / "*.flo"))) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Union[T1, T2]: """Return example at given index. Args: @@ -154,7 +165,7 @@ def __getitem__(self, index): """ return super().__getitem__(index) - def _read_flow(self, file_name): + def _read_flow(self, file_name: str) -> np.ndarray: return _read_flo(file_name) @@ -180,7 +191,7 @@ class KittiFlow(FlowDataset): _has_builtin_flow_mask = True - def __init__(self, root, split="train", transforms=None): + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -200,7 +211,7 @@ def __init__(self, root, split="train", transforms=None): if split == "train": self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png"))) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Union[T1, T2]: """Return example at given index. Args: @@ -215,7 +226,7 @@ def __getitem__(self, index): """ return super().__getitem__(index) - def _read_flow(self, file_name): + def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]: return _read_16bits_png_with_flow_and_valid_mask(file_name) @@ -245,7 +256,7 @@ class FlyingChairs(FlowDataset): return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ - def __init__(self, root, split="train", transforms=None): + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "val")) @@ -268,7 +279,7 @@ def __init__(self, root, split="train", transforms=None): self._flow_list += [flows[i]] self._image_list += [[images[2 * i], images[2 * i + 1]]] - def __getitem__(self, index): + def __getitem__(self, index: int) -> Union[T1, T2]: """Return example at given index. Args: @@ -283,7 +294,7 @@ def __getitem__(self, index): """ return super().__getitem__(index) - def _read_flow(self, file_name): + def _read_flow(self, file_name: str) -> np.ndarray: return _read_flo(file_name) @@ -316,7 +327,14 @@ class FlyingThings3D(FlowDataset): return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ - def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None): + def __init__( + self, + root: str, + split: str = "train", + pass_name: str = "clean", + camera: str = "left", + transforms: Optional[Callable] = None, + ) -> None: super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -359,7 +377,7 @@ def __init__(self, root, split="train", pass_name="clean", camera="left", transf self._image_list += [[images[i + 1], images[i]]] self._flow_list += [flows[i + 1]] - def __getitem__(self, index): + def __getitem__(self, index: int) -> Union[T1, T2]: """Return example at given index. Args: @@ -374,7 +392,7 @@ def __getitem__(self, index): """ return super().__getitem__(index) - def _read_flow(self, file_name): + def _read_flow(self, file_name: str) -> np.ndarray: return _read_pfm(file_name) @@ -401,7 +419,7 @@ class HD1K(FlowDataset): _has_builtin_flow_mask = True - def __init__(self, root, split="train", transforms=None): + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -426,10 +444,10 @@ def __init__(self, root, split="train", transforms=None): "Could not find the HD1K images. Please make sure the directory structure is correct." ) - def _read_flow(self, file_name): + def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]: return _read_16bits_png_with_flow_and_valid_mask(file_name) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Union[T1, T2]: """Return example at given index. Args: @@ -445,7 +463,7 @@ def __getitem__(self, index): return super().__getitem__(index) -def _read_flo(file_name): +def _read_flo(file_name: str) -> np.ndarray: """Read .flo file in Middlebury format""" # Code adapted from: # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy @@ -462,7 +480,7 @@ def _read_flo(file_name): return data.reshape(h, w, 2).transpose(2, 0, 1) -def _read_16bits_png_with_flow_and_valid_mask(file_name): +def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]: flow_and_valid = _read_png_16(file_name).to(torch.float32) flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]