diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py index cd06cfe1cab..02a79019967 100644 --- a/torchvision/datasets/_stereo_matching.py +++ b/torchvision/datasets/_stereo_matching.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from glob import glob from pathlib import Path -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, cast, List, Optional, Tuple, Union import numpy as np from PIL import Image @@ -14,6 +14,9 @@ from .utils import _read_pfm, download_and_extract_archive, verify_str_arg from .vision import VisionDataset +T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], np.ndarray] +T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]] + __all__ = () _read_pfm_file = functools.partial(_read_pfm, slice_channels=1) @@ -24,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset): _has_built_in_disparity_mask = False - def __init__(self, root: str, transforms: Optional[Callable] = None): + def __init__(self, root: str, transforms: Optional[Callable] = None) -> None: """ Args: root(str): Root directory of the dataset. @@ -58,7 +61,11 @@ def _read_img(self, file_path: Union[str, Path]) -> Image.Image: img = img.convert("RGB") return img - def _scan_pairs(self, paths_left_pattern: str, paths_right_pattern: Optional[str] = None): + def _scan_pairs( + self, + paths_left_pattern: str, + paths_right_pattern: Optional[str] = None, + ) -> List[Tuple[str, Optional[str]]]: left_paths = list(sorted(glob(paths_left_pattern))) @@ -85,11 +92,11 @@ def _scan_pairs(self, paths_left_pattern: str, paths_right_pattern: Optional[str return paths @abstractmethod - def _read_disparity(self, file_path: str) -> Tuple: + def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: # function that returns a disparity map and an occlusion map pass - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> Union[T1, T2]: """Return example at given index. Args: @@ -120,7 +127,7 @@ def __getitem__(self, index: int) -> Tuple: ) = self.transforms(imgs, dsp_maps, valid_masks) if self._has_built_in_disparity_mask or valid_masks[0] is not None: - return imgs[0], imgs[1], dsp_maps[0], valid_masks[0] + return imgs[0], imgs[1], dsp_maps[0], cast(np.ndarray, valid_masks[0]) else: return imgs[0], imgs[1], dsp_maps[0] @@ -156,7 +163,7 @@ class CarlaStereo(StereoMatchingDataset): transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. """ - def __init__(self, root: str, transforms: Optional[Callable] = None): + def __init__(self, root: str, transforms: Optional[Callable] = None) -> None: super().__init__(root, transforms) root = Path(root) / "carla-highres" @@ -171,13 +178,13 @@ def __init__(self, root: str, transforms: Optional[Callable] = None): disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) self._disparities = disparities - def _read_disparity(self, file_path: str) -> Tuple: + def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]: disparity_map = _read_pfm_file(file_path) disparity_map = np.abs(disparity_map) # ensure that the disparity is positive valid_mask = None return disparity_map, valid_mask - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> T1: """Return example at given index. Args: @@ -189,7 +196,7 @@ def __getitem__(self, index: int) -> Tuple: If a ``valid_mask`` is generated within the ``transforms`` parameter, a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. """ - return super().__getitem__(index) + return cast(T1, super().__getitem__(index)) class Kitti2012Stereo(StereoMatchingDataset): @@ -233,7 +240,7 @@ class Kitti2012Stereo(StereoMatchingDataset): _has_built_in_disparity_mask = True - def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root, transforms) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -250,7 +257,7 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl else: self._disparities = list((None, None) for _ in self._images) - def _read_disparity(self, file_path: str) -> Tuple: + def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]: # test split has no disparity maps if file_path is None: return None, None @@ -261,7 +268,7 @@ def _read_disparity(self, file_path: str) -> Tuple: valid_mask = None return disparity_map, valid_mask - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> T1: """Return example at given index. Args: @@ -274,7 +281,7 @@ def __getitem__(self, index: int) -> Tuple: generate a valid mask. Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. """ - return super().__getitem__(index) + return cast(T1, super().__getitem__(index)) class Kitti2015Stereo(StereoMatchingDataset): @@ -321,7 +328,7 @@ class Kitti2015Stereo(StereoMatchingDataset): _has_built_in_disparity_mask = True - def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root, transforms) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -338,7 +345,7 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl else: self._disparities = list((None, None) for _ in self._images) - def _read_disparity(self, file_path: str) -> Tuple: + def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]: # test split has no disparity maps if file_path is None: return None, None @@ -349,7 +356,7 @@ def _read_disparity(self, file_path: str) -> Tuple: valid_mask = None return disparity_map, valid_mask - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> T1: """Return example at given index. Args: @@ -362,7 +369,7 @@ def __getitem__(self, index: int) -> Tuple: generate a valid mask. Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. """ - return super().__getitem__(index) + return cast(T1, super().__getitem__(index)) class Middlebury2014Stereo(StereoMatchingDataset): @@ -479,7 +486,7 @@ def __init__( use_ambient_views: bool = False, transforms: Optional[Callable] = None, download: bool = False, - ): + ) -> None: super().__init__(root, transforms) verify_str_arg(split, "split", valid_values=("train", "test", "additional")) @@ -558,7 +565,7 @@ def _read_img(self, file_path: Union[str, Path]) -> Image.Image: file_path = random.choice(ambient_file_paths) # type: ignore return super()._read_img(file_path) - def _read_disparity(self, file_path: str) -> Tuple: + def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]: # test split has not disparity maps if file_path is None: return None, None @@ -569,7 +576,7 @@ def _read_disparity(self, file_path: str) -> Tuple: valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities return disparity_map, valid_mask - def _download_dataset(self, root: str): + def _download_dataset(self, root: str) -> None: base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip" # train and additional splits have 2 different calibration settings root = Path(root) / "Middlebury2014" @@ -608,7 +615,7 @@ def _download_dataset(self, root: str): # cleanup MiddEval3 directory shutil.rmtree(str(root / "MiddEval3")) - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> T2: """Return example at given index. Args: @@ -619,7 +626,7 @@ def __getitem__(self, index: int) -> Tuple: The disparity is a numpy array of shape (1, H, W) and the images are PIL images. ``valid_mask`` is implicitly ``None`` for `split=test`. """ - return super().__getitem__(index) + return cast(T2, super().__getitem__(index)) class CREStereo(StereoMatchingDataset): @@ -670,7 +677,7 @@ def __init__( self, root: str, transforms: Optional[Callable] = None, - ): + ) -> None: super().__init__(root, transforms) root = Path(root) / "CREStereo" @@ -688,14 +695,14 @@ def __init__( disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) self._disparities += disparities - def _read_disparity(self, file_path: str) -> Tuple: + def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]: disparity_map = np.asarray(Image.open(file_path), dtype=np.float32) # unsqueeze the disparity map into (C, H, W) format disparity_map = disparity_map[None, :, :] / 32.0 valid_mask = None return disparity_map, valid_mask - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> T1: """Return example at given index. Args: @@ -707,7 +714,7 @@ def __getitem__(self, index: int) -> Tuple: ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not generate a valid mask. """ - return super().__getitem__(index) + return cast(T1, super().__getitem__(index)) class FallingThingsStereo(StereoMatchingDataset): @@ -755,7 +762,7 @@ class FallingThingsStereo(StereoMatchingDataset): transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. """ - def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None): + def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None: super().__init__(root, transforms) root = Path(root) / "FallingThings" @@ -782,7 +789,7 @@ def __init__(self, root: str, variant: str = "single", transforms: Optional[Call right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png") self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) - def _read_disparity(self, file_path: str) -> Tuple: + def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]: # (H, W) image depth = np.asarray(Image.open(file_path)) # as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt @@ -799,7 +806,7 @@ def _read_disparity(self, file_path: str) -> Tuple: valid_mask = None return disparity_map, valid_mask - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> T1: """Return example at given index. Args: @@ -811,7 +818,7 @@ def __getitem__(self, index: int) -> Tuple: If a ``valid_mask`` is generated within the ``transforms`` parameter, a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. """ - return super().__getitem__(index) + return cast(T1, super().__getitem__(index)) class SceneFlowStereo(StereoMatchingDataset): @@ -874,7 +881,7 @@ def __init__( variant: str = "FlyingThings3D", pass_name: str = "clean", transforms: Optional[Callable] = None, - ): + ) -> None: super().__init__(root, transforms) root = Path(root) / "SceneFlow" @@ -905,13 +912,13 @@ def __init__( right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm") self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) - def _read_disparity(self, file_path: str) -> Tuple: + def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]: disparity_map = _read_pfm_file(file_path) disparity_map = np.abs(disparity_map) # ensure that the disparity is positive valid_mask = None return disparity_map, valid_mask - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> T1: """Return example at given index. Args: @@ -923,7 +930,7 @@ def __getitem__(self, index: int) -> Tuple: If a ``valid_mask`` is generated within the ``transforms`` parameter, a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. """ - return super().__getitem__(index) + return cast(T1, super().__getitem__(index)) class SintelStereo(StereoMatchingDataset): @@ -973,7 +980,7 @@ class SintelStereo(StereoMatchingDataset): _has_built_in_disparity_mask = True - def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None): + def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None: super().__init__(root, transforms) verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both")) @@ -1014,7 +1021,7 @@ def _get_occlussion_mask_paths(self, file_path: str) -> Tuple[str, str]: return occlusion_path, outofframe_path - def _read_disparity(self, file_path: str) -> Tuple: + def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]: if file_path is None: return None, None @@ -1034,7 +1041,7 @@ def _read_disparity(self, file_path: str) -> Tuple: valid_mask = np.logical_and(off_mask, valid_mask) return disparity_map, valid_mask - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> T2: """Return example at given index. Args: @@ -1045,7 +1052,7 @@ def __getitem__(self, index: int) -> Tuple: The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst the valid_mask is a numpy array of shape (H, W). """ - return super().__getitem__(index) + return cast(T2, super().__getitem__(index)) class InStereo2k(StereoMatchingDataset): @@ -1080,7 +1087,7 @@ class InStereo2k(StereoMatchingDataset): transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. """ - def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root, transforms) root = Path(root) / "InStereo2k" / split @@ -1095,14 +1102,14 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl right_disparity_pattern = str(root / "*" / "right_disp.png") self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) - def _read_disparity(self, file_path: str) -> Tuple: + def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]: disparity_map = np.asarray(Image.open(file_path), dtype=np.float32) # unsqueeze disparity to (C, H, W) disparity_map = disparity_map[None, :, :] / 1024.0 valid_mask = None return disparity_map, valid_mask - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> T1: """Return example at given index. Args: @@ -1114,7 +1121,7 @@ def __getitem__(self, index: int) -> Tuple: If a ``valid_mask`` is generated within the ``transforms`` parameter, a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. """ - return super().__getitem__(index) + return cast(T1, super().__getitem__(index)) class ETH3DStereo(StereoMatchingDataset): @@ -1169,7 +1176,7 @@ class ETH3DStereo(StereoMatchingDataset): _has_built_in_disparity_mask = True - def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root, transforms) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -1189,7 +1196,7 @@ def __init__(self, root: str, split: str = "train", transforms: Optional[Callabl disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm") self._disparities = self._scan_pairs(disparity_pattern, None) - def _read_disparity(self, file_path: str) -> Tuple: + def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]: # test split has no disparity maps if file_path is None: return None, None @@ -1201,7 +1208,7 @@ def _read_disparity(self, file_path: str) -> Tuple: valid_mask = np.asarray(valid_mask).astype(bool) return disparity_map, valid_mask - def __getitem__(self, index: int) -> Tuple: + def __getitem__(self, index: int) -> T2: """Return example at given index. Args: @@ -1214,4 +1221,4 @@ def __getitem__(self, index: int) -> Tuple: generate a valid mask. Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. """ - return super().__getitem__(index) + return cast(T2, super().__getitem__(index))