From f5df5fc2a6a6c72589fc3beab47bcf0ce84f8da9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 3 Nov 2021 11:27:34 +0000 Subject: [PATCH 01/15] Add Kitti and Sintel --- torchvision/datasets/__init__.py | 3 + torchvision/datasets/_optical_flow.py | 149 ++++++++++++++++++++++++++ 2 files changed, 152 insertions(+) create mode 100644 torchvision/datasets/_optical_flow.py diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 72a73d1d51b..5edcd1bc584 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,3 +1,4 @@ +from ._optical_flow import KittiFlow, Sintel from .caltech import Caltech101, Caltech256 from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 @@ -71,4 +72,6 @@ "INaturalist", "LFWPeople", "LFWPairs", + "KittiFlow", + "Sintel", ) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py new file mode 100644 index 00000000000..b979790ccaa --- /dev/null +++ b/torchvision/datasets/_optical_flow.py @@ -0,0 +1,149 @@ +import os +from abc import ABC, abstractmethod +from glob import glob +from pathlib import Path + +import numpy as np +import torch +from PIL import Image + +from ..io.image import _read_png_16 +from .vision import VisionDataset + + +__all__ = ( + "KittiFlow", + "Sintel", +) + + +class FlowDataset(ABC, VisionDataset): + def __init__(self, root, transforms=None): + + super().__init__(root=root) + self.transforms = transforms + + self._flow_list = [] + self._image_list = [] + + def _read_img(self, file_name): + return Image.open(file_name) + + @abstractmethod + def _read_flow(self, file_name): + # Return the flow or a tuple (flow, valid) for datasets where the valid mask is built-in + pass + + def __getitem__(self, index): + # Some datasets like Kitti have a built-in valid mask, indicating which flow values are valid + # For those we return (img1, img2, flow, valid), and for the rest we return (img1, img2, flow), + # and it's up to whatever consumes the dataset to decide what `valid` should be. + + img1 = self._read_img(self._image_list[index][0]) + img2 = self._read_img(self._image_list[index][1]) + flow = self._read_flow(self._flow_list[index]) if self._flow_list else None + + if isinstance(flow, tuple): + flow, valid = flow + else: + valid = None + + if self.transforms is not None: + img1, img2, flow, valid = self.transforms(img1, img2, flow, valid) + + if valid is None: + return img1, img2, flow + else: + return img1, img2, flow, valid + + def __len__(self): + return len(self._image_list) + + +class Sintel(FlowDataset): + def __init__( + self, + root, + split="train", + dstype="clean", + transforms=None, + ): + + super().__init__(root=root, transforms=transforms) + + if split not in ("train", "test"): + raise ValueError("split must be either 'train' or 'test'") + + if dstype not in ("clean", "final"): + raise ValueError("dstype must be either 'clean' or 'final'") + + split_dir = "training" if split == "train" else split + flow_root = Path(root) / split_dir / "flow" + image_root = Path(root) / split_dir / dstype + + for scene in os.listdir(image_root): + image_list = sorted(glob(str(image_root / scene / "*.png"))) + for i in range(len(image_list) - 1): + self._image_list += [[image_list[i], image_list[i + 1]]] + + if split == "train": + self._flow_list += sorted(glob(str(flow_root / scene / "*.flo"))) + + def _read_flow(self, file_name): + return _read_flo(file_name) + + +class KittiFlow(FlowDataset): + def __init__( + self, + root, + split="train", + transforms=None, + ): + super().__init__(root=root, transforms=transforms) + + if split not in ("train", "test"): + raise ValueError("split must be either 'train' or 'test'") + + root = Path(root) / ("training" if split == "train" else split) + images1 = sorted(glob(str(root / "image_2" / "*_10.png"))) + images2 = sorted(glob(str(root / "image_2" / "*_11.png"))) + + for img1, img2 in zip(images1, images2): + self._image_list += [[img1, img2]] + + if split == "train": + self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png"))) + + def _read_flow(self, file_name): + return _read_16bits_png_with_flow_and_valid_mask(file_name) + + +def _read_flo(file_name): + """Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(file_name, "rb") as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + raise ValueError("Magic number incorrect. Invalid .flo file") + + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + + +def _read_16bits_png_with_flow_and_valid_mask(file_name): + + flow_and_valid = _read_png_16(file_name).to(torch.float32) + flow, valid = flow_and_valid[:2, :, :], flow_and_valid[2, :, :] + flow = (flow - 2 ** 15) / 64 # This conversion is explained somewhere on the kitti archive + + return flow, valid From c3dd41b7318db1e9f8b4f2633258ff6474eac110 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 3 Nov 2021 18:40:47 +0000 Subject: [PATCH 02/15] Add tests --- test/datasets_utils.py | 8 +- test/test_datasets.py | 127 ++++++++++++++++++++++++++ torchvision/datasets/_optical_flow.py | 53 +++++++---- 3 files changed, 168 insertions(+), 20 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 3fb89a6d3da..6c3124ae9e7 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -198,6 +198,7 @@ class DatasetTestCase(unittest.TestCase): ``transforms``, or ``download``. - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not available, the tests are skipped. + - EXTRA_PATCHES(set): Additional patches to add for each test, to e.g. mock a specific function Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on. The fake data should resemble the original data as close as necessary, while containing only few examples. During @@ -249,6 +250,8 @@ def test_baz(self): ADDITIONAL_CONFIGS = None REQUIRED_PACKAGES = None + EXTRA_PATCHES = None + # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS. _TRANSFORM_KWARGS = { "transform", @@ -374,6 +377,9 @@ def create_dataset( if patch_checks: patchers.update(self._patch_checks()) + if self.EXTRA_PATCHES is not None: + patchers.update(self.EXTRA_PATCHES) + with get_tmp_dir() as tmpdir: args = self.dataset_args(tmpdir, complete_config) info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None @@ -381,7 +387,7 @@ def create_dataset( with self._maybe_apply_patches(patchers), disable_console_output(): dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs) - yield dataset, info + yield dataset, info @classmethod def setUpClass(cls): diff --git a/test/test_datasets.py b/test/test_datasets.py index 575e5ccb811..02de9d4e0d8 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1871,5 +1871,132 @@ def _inject_pairs(self, root, num_pairs, same): datasets_utils.create_image_folder(root, name2, lambda _: f"{name2}_{no2:04d}.jpg", 1, 250) +class SintelTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Sintel + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final")) + # We patch the flow reader, because this would otherwise force us to generate fake (but readable) .flo files, + # which is something we want to # avoid. + _FAKE_FLOW = "Fake Flow" + EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.Sintel._read_flow", return_value=_FAKE_FLOW)} + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None))) + + def inject_fake_data(self, tmpdir, config): + root = pathlib.Path(tmpdir) / "Sintel" + + num_images_per_scene = 3 if config["split"] == "train" else 4 + num_scenes = 2 + + for split_dir in ("training", "test"): + for pass_name in ("clean", "final"): + image_root = root / split_dir / pass_name + + for scene_id in range(num_scenes): + scene_dir = image_root / f"scene_{scene_id}" + datasets_utils.create_image_folder( + image_root, + name=str(scene_dir), + file_name_fn=lambda image_idx: f"frame_000{image_idx}.png", + num_examples=num_images_per_scene, + ) + + # For the ground truth flow value we just create empty files so that they're properly discovered, + # see comment above about EXTRA_PATCHES + flow_root = root / "training" / "flow" + for scene_id in range(num_scenes): + scene_dir = flow_root / f"scene_{scene_id}" + os.makedirs(scene_dir) + for i in range(num_images_per_scene - 1): + open(str(scene_dir / f"frame_000{i}.flo"), "a").close() + + # with e.g. num_images_per_scene = 3, for a single scene with have 3 images + # which are frame_0000, frame_0001 and frame_0002 + # They will be consecutively paired as (frame_0000, frame_0001), (frame_0001, frame_0002), + # that is 3 - 1 = 2 examples. Hence the formula below + num_examples = (num_images_per_scene - 1) * num_scenes + return num_examples + + def test_flow(self): + # Make sure flow exists for train split, and make sure there are as many flow values as (pairs of) images + with self.create_dataset(split="train") as (dataset, _): + assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list) + for _, _, flow in dataset: + assert flow == self._FAKE_FLOW + + # Make sure flow is always None for test split + with self.create_dataset(split="test") as (dataset, _): + assert dataset._image_list and not dataset._flow_list + for _, _, flow in dataset: + assert flow is None + + def test_bad_input(self): + with pytest.raises(ValueError, match="split must be either"): + with self.create_dataset(split="bad"): + pass + + with pytest.raises(ValueError, match="pass_name must be either"): + with self.create_dataset(pass_name="bad"): + pass + + +class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.KittiFlow + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) + + def inject_fake_data(self, tmpdir, config): + root = pathlib.Path(tmpdir) / "Kitti" + + num_examples = 2 if config["split"] == "train" else 3 + for split_dir in ("training", "test"): + + datasets_utils.create_image_folder( + root / split_dir, + name="image_2", + file_name_fn=lambda image_idx: f"{image_idx}_10.png", + num_examples=num_examples, + ) + datasets_utils.create_image_folder( + root / split_dir, + name="image_2", + file_name_fn=lambda image_idx: f"{image_idx}_11.png", + num_examples=num_examples, + ) + + # For kitti the ground truth flows are encoded as 16-bits pngs. + # create_image_folder() will actually create 8-bits pngs, but it doesn't + # matter much: the flow reader will still be able to read the files, it + # will just be garbage flow value - but we don't care about that here. + datasets_utils.create_image_folder( + root / "training", + name="flow_occ", + file_name_fn=lambda image_idx: f"{image_idx}_10.png", + num_examples=num_examples, + ) + + return num_examples + + def test_flow_and_valid(self): + # Make sure flow exists for train split, and make sure there are as many flow values as (pairs of) images + # Also assert flow and valid are of the expected shape + with self.create_dataset(split="train") as (dataset, _): + assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list) + for _, _, flow, valid in dataset: + two, h, w = flow.shape + assert two == 2 + assert valid.shape == (h, w) + + # Make sure flow and valid are always None for test split + with self.create_dataset(split="test") as (dataset, _): + assert dataset._image_list and not dataset._flow_list + for _, _, flow, valid in dataset: + assert flow is None + assert valid is None + + def test_bad_input(self): + with pytest.raises(ValueError, match="split must be either"): + with self.create_dataset(split="bad"): + pass + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index b979790ccaa..4d26ee0d68c 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -18,6 +18,11 @@ class FlowDataset(ABC, VisionDataset): + # Some datasets like Kitti have a built-in valid mask, indicating which flow values are valid + # For those we return (img1, img2, flow, valid), and for the rest we return (img1, img2, flow), + # and it's up to whatever consumes the dataset to decide what `valid` should be. + _has_builtin_flow_mask = False + def __init__(self, root, transforms=None): super().__init__(root=root) @@ -31,30 +36,30 @@ def _read_img(self, file_name): @abstractmethod def _read_flow(self, file_name): - # Return the flow or a tuple (flow, valid) for datasets where the valid mask is built-in + # Return the flow or a tuple with the flow and the valid mask if _has_builtin_flow_mask is True pass def __getitem__(self, index): - # Some datasets like Kitti have a built-in valid mask, indicating which flow values are valid - # For those we return (img1, img2, flow, valid), and for the rest we return (img1, img2, flow), - # and it's up to whatever consumes the dataset to decide what `valid` should be. img1 = self._read_img(self._image_list[index][0]) img2 = self._read_img(self._image_list[index][1]) - flow = self._read_flow(self._flow_list[index]) if self._flow_list else None - if isinstance(flow, tuple): - flow, valid = flow + if self._flow_list: # it will be empty for some dataset when split="test" + flow = self._read_flow(self._flow_list[index]) + if self._has_builtin_flow_mask: + flow, valid = flow + else: + valid = None else: - valid = None + flow = valid = None if self.transforms is not None: img1, img2, flow, valid = self.transforms(img1, img2, flow, valid) - if valid is None: - return img1, img2, flow - else: + if self._has_builtin_flow_mask: return img1, img2, flow, valid + else: + return img1, img2, flow def __len__(self): return len(self._image_list) @@ -65,7 +70,7 @@ def __init__( self, root, split="train", - dstype="clean", + pass_name="clean", transforms=None, ): @@ -74,12 +79,14 @@ def __init__( if split not in ("train", "test"): raise ValueError("split must be either 'train' or 'test'") - if dstype not in ("clean", "final"): - raise ValueError("dstype must be either 'clean' or 'final'") + if pass_name not in ("clean", "final"): + raise ValueError("pass_name must be either 'clean' or 'final'") + + root = Path(root) / "Sintel" split_dir = "training" if split == "train" else split - flow_root = Path(root) / split_dir / "flow" - image_root = Path(root) / split_dir / dstype + image_root = root / split_dir / pass_name + flow_root = root / "training" / "flow" for scene in os.listdir(image_root): image_list = sorted(glob(str(image_root / scene / "*.png"))) @@ -94,6 +101,8 @@ def _read_flow(self, file_name): class KittiFlow(FlowDataset): + _has_builtin_flow_mask = True + def __init__( self, root, @@ -105,10 +114,15 @@ def __init__( if split not in ("train", "test"): raise ValueError("split must be either 'train' or 'test'") - root = Path(root) / ("training" if split == "train" else split) + root = Path(root) / "Kitti" / ("training" if split == "train" else split) images1 = sorted(glob(str(root / "image_2" / "*_10.png"))) images2 = sorted(glob(str(root / "image_2" / "*_11.png"))) + if not images1 or not images2: + raise FileNotFoundError( + "Could not find the Kitti flow images. Please make sure the directory structure is correct." + ) + for img1, img2 in zip(images1, images2): self._image_list += [[img1, img2]] @@ -137,7 +151,7 @@ def _read_flo(file_name): data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) # Reshape data into 3D array (columns, rows, bands) # The reshape here is for visualization, the original code is (w,h,2) - return np.resize(data, (int(h), int(w), 2)) + return np.resize(data, (2, int(h), int(w))) def _read_16bits_png_with_flow_and_valid_mask(file_name): @@ -146,4 +160,5 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name): flow, valid = flow_and_valid[:2, :, :], flow_and_valid[2, :, :] flow = (flow - 2 ** 15) / 64 # This conversion is explained somewhere on the kitti archive - return flow, valid + # For consistency with other datasets, we convert to numpy + return flow.numpy(), valid.numpy() From 8c766021600353ec055be2b1197067664435e85c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 4 Nov 2021 09:40:59 +0000 Subject: [PATCH 03/15] Add some docs --- torchvision/datasets/_optical_flow.py | 51 ++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index 4d26ee0d68c..2c27f0c1af3 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -101,6 +101,29 @@ def _read_flow(self, file_name): class KittiFlow(FlowDataset): + """Kitti Dataset for optical flow (2015) + + The dataset can be downloaded `from here + `_. + + The dataset is expected to have the following structure: :: + + root + Kitti + testing + image_2 + training + image_2 + flow_occ + + + Args: + root (string): Root directory of the KittiFlow Dataset. + split (string, optional): The dataset split, either "train" (default) or "test" + transforms (callable, optional): A function/transform that takes in + ``img1, img2, flow, valid`` and returns a transformed version. + """ + _has_builtin_flow_mask = True def __init__( @@ -129,6 +152,21 @@ def __init__( if split == "train": self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png"))) + def __getitem__(self, index): + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow, + valid)`` where ``valid`` is a numpy boolean mask of shape (H, W) + indicating which flow values are valid. The flow is a numpy array of + shape (2, H, W) and the images are PIL images. If `split="test"`, a + 4-tuple with ``(img1, img2, None, None)`` is returned. + """ + return super().__getitem__(index) + def _read_flow(self, file_name): return _read_16bits_png_with_flow_and_valid_mask(file_name) @@ -137,21 +175,16 @@ def _read_flo(file_name): """Read .flo file in Middlebury format""" # Code adapted from: # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy - # WARNING: this will work on little-endian architectures (eg Intel x86) only! - # print 'fn = %s'%(fn) with open(file_name, "rb") as f: magic = np.fromfile(f, np.float32, count=1) if 202021.25 != magic: raise ValueError("Magic number incorrect. Invalid .flo file") - w = np.fromfile(f, np.int32, count=1) - h = np.fromfile(f, np.int32, count=1) - # print 'Reading %d x %d flo file\n' % (w, h) - data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) - # Reshape data into 3D array (columns, rows, bands) - # The reshape here is for visualization, the original code is (w,h,2) - return np.resize(data, (2, int(h), int(w))) + w = int(np.fromfile(f, np.int32, count=1)) + h = int(np.fromfile(f, np.int32, count=1)) + data = np.fromfile(f, np.float32, count=2 * w * h) + return data.reshape(2, h, w) def _read_16bits_png_with_flow_and_valid_mask(file_name): From e6ecc4ef7f70428650557643b42790ba68d615f9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 4 Nov 2021 09:47:20 +0000 Subject: [PATCH 04/15] More docs --- docs/source/datasets.rst | 2 + torchvision/datasets/_optical_flow.py | 53 ++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index fdf01eb8ffa..89dfe7e08d8 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -48,6 +48,7 @@ You can also create your own datasets using the provided :ref:`base classes `_. + + The dataset is expected to have the following structure: :: + + root + Sintel + testing + clean + scene_1 + scene_2 + ... + final + scene_1 + scene_2 + ... + training + clean + scene_1 + scene_2 + ... + final + scene_1 + scene_2 + ... + flow + scene_1 + scene_2 + ... + + Args: + root (string): Root directory of the Sintel Dataset. + split (string, optional): The dataset split, either "train" (default) or "test" + transforms (callable, optional): A function/transform that takes in + ``img1, img2, flow, valid`` and returns a transformed version. + ``valid`` is expected for consistency with other datasets which + return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. + """ def __init__( self, root, @@ -96,6 +135,19 @@ def __init__( if split == "train": self._flow_list += sorted(glob(str(flow_root / scene / "*.flo"))) + def __getitem__(self, index): + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: If ``split="train"`` a 3-tuple with ``(img1, img2, flow). + The flow is a numpy array of shape (2, H, W) and the images are PIL images. If `split="test"`, a + 4-tuple with ``(img1, img2, None)`` is returned. + """ + return super().__getitem__(index) + def _read_flow(self, file_name): return _read_flo(file_name) @@ -116,7 +168,6 @@ class KittiFlow(FlowDataset): image_2 flow_occ - Args: root (string): Root directory of the KittiFlow Dataset. split (string, optional): The dataset split, either "train" (default) or "test" From 721b94b686c00c6fd4d3e4829d07849f7274abfc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 4 Nov 2021 09:52:05 +0000 Subject: [PATCH 05/15] more docs --- torchvision/datasets/_optical_flow.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index ad15d591617..51d522fbddd 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -66,9 +66,7 @@ def __len__(self): class Sintel(FlowDataset): - """Sintel Dataset for optical flow. - - The dataset can be downloaded `from here `_. + """`Sintel `_ Dataset for optical flow. The dataset is expected to have the following structure: :: @@ -100,11 +98,14 @@ class Sintel(FlowDataset): Args: root (string): Root directory of the Sintel Dataset. split (string, optional): The dataset split, either "train" (default) or "test" + pass_name (string, optional): The pass to use, either "clean" (default) or "final". See link above for + details on the different passes. transforms (callable, optional): A function/transform that takes in ``img1, img2, flow, valid`` and returns a transformed version. ``valid`` is expected for consistency with other datasets which return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ + def __init__( self, root, @@ -144,7 +145,7 @@ def __getitem__(self, index): Returns: tuple: If ``split="train"`` a 3-tuple with ``(img1, img2, flow). The flow is a numpy array of shape (2, H, W) and the images are PIL images. If `split="test"`, a - 4-tuple with ``(img1, img2, None)`` is returned. + 3-tuple with ``(img1, img2, None)`` is returned. """ return super().__getitem__(index) @@ -153,10 +154,7 @@ def _read_flow(self, file_name): class KittiFlow(FlowDataset): - """Kitti Dataset for optical flow (2015) - - The dataset can be downloaded `from here - `_. + """`Kitti `_ dataset for optical flow (2015). The dataset is expected to have the following structure: :: From 6f95da0bc5cd3dc2503bd5cf7793076d4b043e03 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 4 Nov 2021 10:38:58 +0000 Subject: [PATCH 06/15] test -> testing for Kitti --- test/test_datasets.py | 2 +- torchvision/datasets/_optical_flow.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 02de9d4e0d8..57c2a80181a 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1947,7 +1947,7 @@ def inject_fake_data(self, tmpdir, config): root = pathlib.Path(tmpdir) / "Kitti" num_examples = 2 if config["split"] == "train" else 3 - for split_dir in ("training", "test"): + for split_dir in ("training", "testing"): datasets_utils.create_image_folder( root / split_dir, diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index aef41852c73..dd699d80fe2 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -186,7 +186,7 @@ def __init__( if split not in ("train", "test"): raise ValueError("split must be either 'train' or 'test'") - root = Path(root) / "Kitti" / ("training" if split == "train" else split) + root = Path(root) / "Kitti" / (split + "ing") images1 = sorted(glob(str(root / "image_2" / "*_10.png"))) images2 = sorted(glob(str(root / "image_2" / "*_11.png"))) From 3b8ba30c73df4ae5fcc55d39cbbc1434acff66ce Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 4 Nov 2021 10:45:50 +0000 Subject: [PATCH 07/15] less vert space --- torchvision/datasets/_optical_flow.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index dd699d80fe2..7cb19e8d8c4 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -106,14 +106,7 @@ class Sintel(FlowDataset): return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ - def __init__( - self, - root, - split="train", - pass_name="clean", - transforms=None, - ): - + def __init__(self, root, split="train", pass_name="clean", transforms=None): super().__init__(root=root, transforms=transforms) if split not in ("train", "test"): @@ -175,12 +168,7 @@ class KittiFlow(FlowDataset): _has_builtin_flow_mask = True - def __init__( - self, - root, - split="train", - transforms=None, - ): + def __init__(self, root, split="train", transforms=None): super().__init__(root=root, transforms=transforms) if split not in ("train", "test"): From b57885eeec77694fed2d1872cc3155dda0d75208 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 4 Nov 2021 13:45:29 +0000 Subject: [PATCH 08/15] Add FlyingThings3D dataset --- test/test_datasets.py | 74 +++++++++++++++++ torchvision/datasets/__init__.py | 3 +- torchvision/datasets/_optical_flow.py | 113 ++++++++++++++++++++++++++ 3 files changed, 189 insertions(+), 1 deletion(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 57c2a80181a..ebccd2aecfa 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1998,5 +1998,79 @@ def test_bad_input(self): pass +class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.FlyingThings3D + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( + split=("train", "test"), pass_name=("clean", "final", "both"), camera=("left", "right", "both") + ) + # We patch the flow reader, because this would otherwise force us to generate fake (but readable) .PFM files, + # which is something we want to avoid. + _FAKE_FLOW = "Fake Flow" + EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.FlyingThings3D._read_flow", return_value=_FAKE_FLOW)} + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None))) + + def inject_fake_data(self, tmpdir, config): + root = pathlib.Path(tmpdir) / "FlyingThings3D" + + num_images_per_camera = 3 if config["split"] == "train" else 4 + passes = ("frames_cleanpass", "frames_finalpass") + splits = ("TRAIN", "TEST") + letters = ("A", "B", "C") + subfolders = ("0000", "0001") + cameras = ("left", "right") + for pass_name in passes: + for split in splits: + for letter in letters: + for subfolder in subfolders: + current_folder = root / pass_name / split / letter / subfolder + for camera in cameras: + datasets_utils.create_image_folder( + current_folder, + name=camera, + file_name_fn=lambda image_idx: f"00{image_idx}.png", + num_examples=num_images_per_camera, + ) + + # For the ground truth flow value we just create empty files so that they're properly discovered, + # see comment above about EXTRA_PATCHES + directions = ("into_future", "into_past") + for split in splits: + for letter in letters: + for subfolder in subfolders: + for direction in directions: + current_folder = root / "optical_flow" / split / letter / subfolder / direction + for camera in cameras: + os.makedirs(str(current_folder / camera)) + for i in range(num_images_per_camera): + open(str(current_folder / camera / f"{i}.pfm"), "a").close() + + num_cameras = 2 if config["camera"] == "both" else 1 + num_passes = 2 if config["pass_name"] == "both" else 1 + num_examples = ( + (num_images_per_camera - 1) * num_cameras * len(subfolders) * len(letters) * len(splits) * num_passes + ) + return num_examples + + @datasets_utils.test_all_configs + def test_flow(self, config): + with self.create_dataset(config=config) as (dataset, _): + assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list) + for _, _, flow in dataset: + assert flow == self._FAKE_FLOW + + def test_bad_input(self): + with pytest.raises(ValueError, match="split must be either"): + with self.create_dataset(split="bad"): + pass + + with pytest.raises(ValueError, match="pass_name must be either"): + with self.create_dataset(pass_name="bad"): + pass + + with pytest.raises(ValueError, match="camera must be either"): + with self.create_dataset(camera="bad"): + pass + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 5edcd1bc584..3de7f9c7224 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,4 +1,4 @@ -from ._optical_flow import KittiFlow, Sintel +from ._optical_flow import KittiFlow, Sintel, FlyingThings3D from .caltech import Caltech101, Caltech256 from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 @@ -74,4 +74,5 @@ "LFWPairs", "KittiFlow", "Sintel", + "FlyingThings3D", ) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index 7cb19e8d8c4..5ba1af2605e 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -1,4 +1,5 @@ import os +import re from abc import ABC, abstractmethod from glob import glob from pathlib import Path @@ -14,6 +15,7 @@ __all__ = ( "KittiFlow", "Sintel", + "FlyingThings3D", ) @@ -208,6 +210,89 @@ def _read_flow(self, file_name): return _read_16bits_png_with_flow_and_valid_mask(file_name) +class FlyingThings3D(FlowDataset): + """`FlyingThings3D `_ dataset for optical flow. + + The dataset is expected to have the following structure: :: + + root + FlyingThings3D + + Args: + root (string): Root directory of the Sintel Dataset. + split (string, optional): The dataset split, either "train" (default) or "test" + pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for + details on the different passes. + camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both". + transforms (callable, optional): A function/transform that takes in + ``img1, img2, flow, valid`` and returns a transformed version. + ``valid`` is expected for consistency with other datasets which + return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. + """ + + def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None): + super().__init__(root=root, transforms=transforms) + + if split not in ("train", "test"): + raise ValueError("split must be either 'train' or 'test'") + split = split.upper() + + if pass_name not in ("clean", "final", "both"): + raise ValueError("pass_name must be either 'clean', 'final', or 'both'") + passes = { + "clean": ["frames_cleanpass"], + "final": ["frames_finalpass"], + "both": ["frames_cleanpass", "frames_finalpass"], + }[pass_name] + + if camera not in ("left", "right", "both"): + raise ValueError("camera must be either 'left', 'right', or 'both'") + cameras = ["left", "right"] if camera == "both" else [camera] + + root = Path(root) / "FlyingThings3D" + + for pass_name in passes: + for camera in cameras: + for direction in ["into_future", "into_past"]: + image_dirs = sorted(glob(str(root / pass_name / split / "*/*"))) + image_dirs = sorted([Path(image_dir) / camera for image_dir in image_dirs]) + + flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*"))) + flow_dirs = sorted([Path(flow_dir) / direction / camera for flow_dir in flow_dirs]) + + if not image_dirs or not flow_dirs: + raise FileNotFoundError( + "Could not find the FlyingThings3D flow images. " + "Please make sure the directory structure is correct." + ) + + for image_dir, flow_dir in zip(image_dirs, flow_dirs): + images = sorted(glob(str(image_dir / "*.png"))) + flows = sorted(glob(str(flow_dir / "*.pfm"))) + for i in range(len(flows) - 1): + if direction == "into_future": + self._image_list += [[images[i], images[i + 1]]] + self._flow_list += [flows[i]] + elif direction == "into_past": + self._image_list += [[images[i + 1], images[i]]] + self._flow_list += [flows[i + 1]] + + def __getitem__(self, index): + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 3-tuple with ``(img1, img2, flow)``. + The flow is a numpy array of shape (2, H, W) and the images are PIL images. + """ + return super().__getitem__(index) + + def _read_flow(self, file_name): + return _read_pfm(file_name) + + def _read_flo(file_name): """Read .flo file in Middlebury format""" # Code adapted from: @@ -232,3 +317,31 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name): # For consistency with other datasets, we convert to numpy return flow.numpy(), valid.numpy() + + +def _read_pfm(file_name): + """Read flow in .pfm format""" + + with open(file_name, "rb") as f: + header = f.readline().rstrip() + if header != b"PF": + raise ValueError("Invalid PFM file") + + dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline()) + if not dim_match: + raise Exception("Malformed PFM header.") + w, h = (int(dim) for dim in dim_match.groups()) + + scale = float(f.readline().rstrip()) + if scale < 0: # little-endian + endian = "<" + scale = -scale + else: + endian = ">" # big-endian + + data = np.fromfile(f, dtype=endian + "f") + + data = data.reshape(h, w, 3).transpose(2, 0, 1) + data = np.flip(data, axis=1) # flip on h dimension + data = data[:2, :, :] + return data.astype(np.float32) From bb1b8971a04a1f9ea573ede95fd003f069703ea3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 4 Nov 2021 17:37:00 +0000 Subject: [PATCH 09/15] Add fake pfm file generation for more robust testing --- test/datasets_utils.py | 9 +++++++++ test/test_datasets.py | 17 ++++++++--------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 9906d42e012..3e3ce029c5f 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -8,6 +8,7 @@ import random import shutil import string +import struct import tarfile import unittest import unittest.mock @@ -922,3 +923,11 @@ def create_random_string(length: int, *digits: str) -> str: digits = "".join(itertools.chain(*digits)) return "".join(random.choice(digits) for _ in range(length)) + + +def make_fake_pfm_file(h, w, file_name): + values = list(range(3 * h * w)) + # Note: we pack everything in little endian: -1, and "<" + content = f"PF \n{w} {h} \n-1.0\n".encode() + struct.pack("<" + "f" * len(values), *values) + with open(file_name, "wb") as f: + f.write(content) diff --git a/test/test_datasets.py b/test/test_datasets.py index ebccd2aecfa..de9457d3a72 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2003,11 +2003,9 @@ class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase): ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( split=("train", "test"), pass_name=("clean", "final", "both"), camera=("left", "right", "both") ) - # We patch the flow reader, because this would otherwise force us to generate fake (but readable) .PFM files, - # which is something we want to avoid. - _FAKE_FLOW = "Fake Flow" - EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.FlyingThings3D._read_flow", return_value=_FAKE_FLOW)} - FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None))) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) + + FLOW_H, FLOW_W = 3, 4 def inject_fake_data(self, tmpdir, config): root = pathlib.Path(tmpdir) / "FlyingThings3D" @@ -2031,8 +2029,6 @@ def inject_fake_data(self, tmpdir, config): num_examples=num_images_per_camera, ) - # For the ground truth flow value we just create empty files so that they're properly discovered, - # see comment above about EXTRA_PATCHES directions = ("into_future", "into_past") for split in splits: for letter in letters: @@ -2042,7 +2038,9 @@ def inject_fake_data(self, tmpdir, config): for camera in cameras: os.makedirs(str(current_folder / camera)) for i in range(num_images_per_camera): - open(str(current_folder / camera / f"{i}.pfm"), "a").close() + datasets_utils.make_fake_pfm_file( + self.FLOW_H, self.FLOW_W, file_name=str(current_folder / camera / f"{i}.pfm") + ) num_cameras = 2 if config["camera"] == "both" else 1 num_passes = 2 if config["pass_name"] == "both" else 1 @@ -2056,7 +2054,8 @@ def test_flow(self, config): with self.create_dataset(config=config) as (dataset, _): assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list) for _, _, flow in dataset: - assert flow == self._FAKE_FLOW + assert flow.shape == (2, self.FLOW_H, self.FLOW_W) + # We don't check the values because the reshaping and flipping makes it hard to figure out def test_bad_input(self): with pytest.raises(ValueError, match="split must be either"): From 9affa21958e43051fc952294e5fd6d13fcc0e869 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 5 Nov 2021 09:50:01 +0000 Subject: [PATCH 10/15] minor update --- docs/source/datasets.rst | 1 + test/test_datasets.py | 6 +++--- torchvision/datasets/_optical_flow.py | 19 +++++++++++++------ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 89dfe7e08d8..f9f6c02dd99 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -43,6 +43,7 @@ You can also create your own datasets using the provided :ref:`base classes Date: Fri, 5 Nov 2021 09:56:54 +0000 Subject: [PATCH 11/15] Address comments --- test/test_datasets.py | 38 +++++++++++++++----------------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 5ab63e39109..57b55e3b6c8 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2016,31 +2016,23 @@ def inject_fake_data(self, tmpdir, config): letters = ("A", "B", "C") subfolders = ("0000", "0001") cameras = ("left", "right") - for pass_name in passes: - for split in splits: - for letter in letters: - for subfolder in subfolders: - current_folder = root / pass_name / split / letter / subfolder - for camera in cameras: - datasets_utils.create_image_folder( - current_folder, - name=camera, - file_name_fn=lambda image_idx: f"00{image_idx}.png", - num_examples=num_images_per_camera, - ) + for pass_name, split, letter, subfolder, camera in itertools.product(passes, splits, letters, subfolders, cameras): + current_folder = root / pass_name / split / letter / subfolder + datasets_utils.create_image_folder( + current_folder, + name=camera, + file_name_fn=lambda image_idx: f"00{image_idx}.png", + num_examples=num_images_per_camera, + ) directions = ("into_future", "into_past") - for split in splits: - for letter in letters: - for subfolder in subfolders: - for direction in directions: - current_folder = root / "optical_flow" / split / letter / subfolder / direction - for camera in cameras: - os.makedirs(str(current_folder / camera)) - for i in range(num_images_per_camera): - datasets_utils.make_fake_pfm_file( - self.FLOW_H, self.FLOW_W, file_name=str(current_folder / camera / f"{i}.pfm") - ) + for split, letter, subfolder, direction, camera in itertools.product(splits, letters, subfolders, directions, cameras): + current_folder = root / "optical_flow" / split / letter / subfolder / direction / camera + os.makedirs(str(current_folder), exist_ok=True) + for i in range(num_images_per_camera): + datasets_utils.make_fake_pfm_file( + self.FLOW_H, self.FLOW_W, file_name=str(current_folder / f"{i}.pfm") + ) num_cameras = 2 if config["camera"] == "both" else 1 num_passes = 2 if config["pass_name"] == "both" else 1 From 8b4731babf588c6054f02edf8d6ee42ffbd0ed5e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 5 Nov 2021 09:57:59 +0000 Subject: [PATCH 12/15] Typo --- test/test_datasets.py | 12 +++++++----- torchvision/datasets/_optical_flow.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 57b55e3b6c8..49af749193a 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2016,7 +2016,9 @@ def inject_fake_data(self, tmpdir, config): letters = ("A", "B", "C") subfolders = ("0000", "0001") cameras = ("left", "right") - for pass_name, split, letter, subfolder, camera in itertools.product(passes, splits, letters, subfolders, cameras): + for pass_name, split, letter, subfolder, camera in itertools.product( + passes, splits, letters, subfolders, cameras + ): current_folder = root / pass_name / split / letter / subfolder datasets_utils.create_image_folder( current_folder, @@ -2026,13 +2028,13 @@ def inject_fake_data(self, tmpdir, config): ) directions = ("into_future", "into_past") - for split, letter, subfolder, direction, camera in itertools.product(splits, letters, subfolders, directions, cameras): + for split, letter, subfolder, direction, camera in itertools.product( + splits, letters, subfolders, directions, cameras + ): current_folder = root / "optical_flow" / split / letter / subfolder / direction / camera os.makedirs(str(current_folder), exist_ok=True) for i in range(num_images_per_camera): - datasets_utils.make_fake_pfm_file( - self.FLOW_H, self.FLOW_W, file_name=str(current_folder / f"{i}.pfm") - ) + datasets_utils.make_fake_pfm_file(self.FLOW_H, self.FLOW_W, file_name=str(current_folder / f"{i}.pfm")) num_cameras = 2 if config["camera"] == "both" else 1 num_passes = 2 if config["pass_name"] == "both" else 1 diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index 89603b8c971..84df3d3ccd5 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -229,7 +229,7 @@ class FlyingThings3D(FlowDataset): TRAIN Args: - root (string): Root directory of the Sintel Dataset. + root (string): Root directory of the intel FlyingThings3D Dataset. split (string, optional): The dataset split, either "train" (default) or "test" pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for details on the different passes. From dd14728c0f7cf7ad35faec68d23eb0b5dcfe89eb Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 5 Nov 2021 10:28:50 +0000 Subject: [PATCH 13/15] formatting --- torchvision/datasets/_optical_flow.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index 9ee86ea4fcc..d042faef3c9 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -233,6 +233,7 @@ class FlyingChairs(FlowDataset): ``valid`` is expected for consistency with other datasets which return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ + def __init__(self, root, split="train", transforms=None): super().__init__(root=root, transforms=transforms) @@ -300,6 +301,7 @@ class FlyingThings3D(FlowDataset): ``valid`` is expected for consistency with other datasets which return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ + def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None): super().__init__(root=root, transforms=transforms) @@ -343,6 +345,7 @@ def __init__(self, root, split="train", pass_name="clean", camera="left", transf elif direction == "into_past": self._image_list += [[images[i + 1], images[i]]] self._flow_list += [flows[i + 1]] + def __getitem__(self, index): """Return example at given index. From ff4af143a40b1ca75f6e953e447d27a29803f0d9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 5 Nov 2021 10:30:16 +0000 Subject: [PATCH 14/15] remove changes related to EXTRA_PATCHES which we don't need anymore --- test/datasets_utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index aba654ded00..5210a1512b1 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -204,7 +204,6 @@ class DatasetTestCase(unittest.TestCase): ``transforms``, or ``download``. - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not available, the tests are skipped. - - EXTRA_PATCHES(set): Additional patches to add for each test, to e.g. mock a specific function Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on. The fake data should resemble the original data as close as necessary, while containing only few examples. During @@ -256,8 +255,6 @@ def test_baz(self): ADDITIONAL_CONFIGS = None REQUIRED_PACKAGES = None - EXTRA_PATCHES = None - # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS. _TRANSFORM_KWARGS = { "transform", @@ -383,9 +380,6 @@ def create_dataset( if patch_checks: patchers.update(self._patch_checks()) - if self.EXTRA_PATCHES is not None: - patchers.update(self.EXTRA_PATCHES) - with get_tmp_dir() as tmpdir: args = self.dataset_args(tmpdir, complete_config) info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None @@ -393,7 +387,7 @@ def create_dataset( with self._maybe_apply_patches(patchers), disable_console_output(): dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs) - yield dataset, info + yield dataset, info @classmethod def setUpClass(cls): From ba4cfd6e2325ccf1ff69c95a4670929af6ae09b9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 5 Nov 2021 10:33:24 +0000 Subject: [PATCH 15/15] itertools.prodcut --- torchvision/datasets/_optical_flow.py | 50 +++++++++++++-------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index d042faef3c9..6ff49395a0a 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -1,3 +1,4 @@ +import itertools import os import re from abc import ABC, abstractmethod @@ -320,31 +321,30 @@ def __init__(self, root, split="train", pass_name="clean", camera="left", transf root = Path(root) / "FlyingThings3D" - for pass_name in passes: - for camera in cameras: - for direction in ["into_future", "into_past"]: - image_dirs = sorted(glob(str(root / pass_name / split / "*/*"))) - image_dirs = sorted([Path(image_dir) / camera for image_dir in image_dirs]) - - flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*"))) - flow_dirs = sorted([Path(flow_dir) / direction / camera for flow_dir in flow_dirs]) - - if not image_dirs or not flow_dirs: - raise FileNotFoundError( - "Could not find the FlyingThings3D flow images. " - "Please make sure the directory structure is correct." - ) - - for image_dir, flow_dir in zip(image_dirs, flow_dirs): - images = sorted(glob(str(image_dir / "*.png"))) - flows = sorted(glob(str(flow_dir / "*.pfm"))) - for i in range(len(flows) - 1): - if direction == "into_future": - self._image_list += [[images[i], images[i + 1]]] - self._flow_list += [flows[i]] - elif direction == "into_past": - self._image_list += [[images[i + 1], images[i]]] - self._flow_list += [flows[i + 1]] + directions = ("into_future", "into_past") + for pass_name, camera, direction in itertools.product(passes, cameras, directions): + image_dirs = sorted(glob(str(root / pass_name / split / "*/*"))) + image_dirs = sorted([Path(image_dir) / camera for image_dir in image_dirs]) + + flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*"))) + flow_dirs = sorted([Path(flow_dir) / direction / camera for flow_dir in flow_dirs]) + + if not image_dirs or not flow_dirs: + raise FileNotFoundError( + "Could not find the FlyingThings3D flow images. " + "Please make sure the directory structure is correct." + ) + + for image_dir, flow_dir in zip(image_dirs, flow_dirs): + images = sorted(glob(str(image_dir / "*.png"))) + flows = sorted(glob(str(flow_dir / "*.pfm"))) + for i in range(len(flows) - 1): + if direction == "into_future": + self._image_list += [[images[i], images[i + 1]]] + self._flow_list += [flows[i]] + elif direction == "into_past": + self._image_list += [[images[i + 1], images[i]]] + self._flow_list += [flows[i + 1]] def __getitem__(self, index): """Return example at given index.