Skip to content

Add FlyingThings3D dataset for optical flow #4858

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Nov 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
Flickr8k
Flickr30k
FlyingChairs
FlyingThings3D
HMDB51
ImageNet
INaturalist
Expand Down
16 changes: 9 additions & 7 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ class DatasetTestCase(unittest.TestCase):
``transforms``, or ``download``.
- REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
available, the tests are skipped.
- EXTRA_PATCHES(set): Additional patches to add for each test, to e.g. mock a specific function

Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on.
The fake data should resemble the original data as close as necessary, while containing only few examples. During
Expand Down Expand Up @@ -256,8 +255,6 @@ def test_baz(self):
ADDITIONAL_CONFIGS = None
REQUIRED_PACKAGES = None

EXTRA_PATCHES = None

# These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
_TRANSFORM_KWARGS = {
"transform",
Expand Down Expand Up @@ -383,17 +380,14 @@ def create_dataset(
if patch_checks:
patchers.update(self._patch_checks())

if self.EXTRA_PATCHES is not None:
patchers.update(self.EXTRA_PATCHES)

with get_tmp_dir() as tmpdir:
args = self.dataset_args(tmpdir, complete_config)
info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None

with self._maybe_apply_patches(patchers), disable_console_output():
dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)

yield dataset, info
yield dataset, info

@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -925,6 +919,14 @@ def create_random_string(length: int, *digits: str) -> str:
return "".join(random.choice(digits) for _ in range(length))


def make_fake_pfm_file(h, w, file_name):
values = list(range(3 * h * w))
# Note: we pack everything in little endian: -1.0, and "<"
content = f"PF \n{w} {h} \n-1.0\n".encode() + struct.pack("<" + "f" * len(values), *values)
with open(file_name, "wb") as f:
f.write(content)


def make_fake_flo_file(h, w, file_name):
"""Creates a fake flow file in .flo format."""
values = list(range(2 * h * w))
Expand Down
67 changes: 67 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2048,5 +2048,72 @@ def test_flow(self, config):
np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))


class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FlyingThings3D
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
split=("train", "test"), pass_name=("clean", "final", "both"), camera=("left", "right", "both")
)
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

FLOW_H, FLOW_W = 3, 4

def inject_fake_data(self, tmpdir, config):
root = pathlib.Path(tmpdir) / "FlyingThings3D"

num_images_per_camera = 3 if config["split"] == "train" else 4
passes = ("frames_cleanpass", "frames_finalpass")
splits = ("TRAIN", "TEST")
letters = ("A", "B", "C")
subfolders = ("0000", "0001")
cameras = ("left", "right")
for pass_name, split, letter, subfolder, camera in itertools.product(
passes, splits, letters, subfolders, cameras
):
current_folder = root / pass_name / split / letter / subfolder
datasets_utils.create_image_folder(
current_folder,
name=camera,
file_name_fn=lambda image_idx: f"00{image_idx}.png",
num_examples=num_images_per_camera,
)

directions = ("into_future", "into_past")
for split, letter, subfolder, direction, camera in itertools.product(
splits, letters, subfolders, directions, cameras
):
current_folder = root / "optical_flow" / split / letter / subfolder / direction / camera
os.makedirs(str(current_folder), exist_ok=True)
for i in range(num_images_per_camera):
datasets_utils.make_fake_pfm_file(self.FLOW_H, self.FLOW_W, file_name=str(current_folder / f"{i}.pfm"))

num_cameras = 2 if config["camera"] == "both" else 1
num_passes = 2 if config["pass_name"] == "both" else 1
num_examples = (
(num_images_per_camera - 1) * num_cameras * len(subfolders) * len(letters) * len(splits) * num_passes
)
return num_examples

@datasets_utils.test_all_configs
def test_flow(self, config):
with self.create_dataset(config=config) as (dataset, _):
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
for _, _, flow in dataset:
assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
# We don't check the values because the reshaping and flipping makes it hard to figure out

def test_bad_input(self):
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
with self.create_dataset(split="bad"):
pass

with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"):
with self.create_dataset(pass_name="bad"):
pass

with pytest.raises(ValueError, match="Unknown value 'bad' for argument camera"):
with self.create_dataset(camera="bad"):
pass


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ._optical_flow import KittiFlow, Sintel, FlyingChairs
from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
Expand Down Expand Up @@ -75,4 +75,5 @@
"KittiFlow",
"Sintel",
"FlyingChairs",
"FlyingThings3D",
)
119 changes: 119 additions & 0 deletions torchvision/datasets/_optical_flow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import itertools
import os
import re
from abc import ABC, abstractmethod
from glob import glob
from pathlib import Path
Expand All @@ -15,6 +17,7 @@
__all__ = (
"KittiFlow",
"Sintel",
"FlyingThings3D",
"FlyingChairs",
)

Expand Down Expand Up @@ -271,6 +274,94 @@ def _read_flow(self, file_name):
return _read_flo(file_name)


class FlyingThings3D(FlowDataset):
"""`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.

The dataset is expected to have the following structure: ::

root
FlyingThings3D
frames_cleanpass
TEST
TRAIN
frames_finalpass
TEST
TRAIN
optical_flow
TEST
TRAIN

Args:
root (string): Root directory of the intel FlyingThings3D Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
details on the different passes.
camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""

def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None):
super().__init__(root=root, transforms=transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
split = split.upper()

verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
passes = {
"clean": ["frames_cleanpass"],
"final": ["frames_finalpass"],
"both": ["frames_cleanpass", "frames_finalpass"],
}[pass_name]

verify_str_arg(camera, "camera", valid_values=("left", "right", "both"))
cameras = ["left", "right"] if camera == "both" else [camera]

root = Path(root) / "FlyingThings3D"

directions = ("into_future", "into_past")
for pass_name, camera, direction in itertools.product(passes, cameras, directions):
image_dirs = sorted(glob(str(root / pass_name / split / "*/*")))
image_dirs = sorted([Path(image_dir) / camera for image_dir in image_dirs])

flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
flow_dirs = sorted([Path(flow_dir) / direction / camera for flow_dir in flow_dirs])

if not image_dirs or not flow_dirs:
raise FileNotFoundError(
"Could not find the FlyingThings3D flow images. "
"Please make sure the directory structure is correct."
)

for image_dir, flow_dir in zip(image_dirs, flow_dirs):
images = sorted(glob(str(image_dir / "*.png")))
flows = sorted(glob(str(flow_dir / "*.pfm")))
for i in range(len(flows) - 1):
if direction == "into_future":
self._image_list += [[images[i], images[i + 1]]]
self._flow_list += [flows[i]]
elif direction == "into_past":
self._image_list += [[images[i + 1], images[i]]]
self._flow_list += [flows[i + 1]]

def __getitem__(self, index):
"""Return example at given index.

Args:
index(int): The index of the example to retrieve

Returns:
tuple: A 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
"""
return super().__getitem__(index)

def _read_flow(self, file_name):
return _read_pfm(file_name)


def _read_flo(file_name):
"""Read .flo file in Middlebury format"""
# Code adapted from:
Expand All @@ -295,3 +386,31 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name):

# For consistency with other datasets, we convert to numpy
return flow.numpy(), valid.numpy()


def _read_pfm(file_name):
"""Read flow in .pfm format"""

with open(file_name, "rb") as f:
header = f.readline().rstrip()
if header != b"PF":
raise ValueError("Invalid PFM file")

dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
if not dim_match:
raise Exception("Malformed PFM header.")
w, h = (int(dim) for dim in dim_match.groups())

scale = float(f.readline().rstrip())
if scale < 0: # little-endian
endian = "<"
scale = -scale
else:
endian = ">" # big-endian

data = np.fromfile(f, dtype=endian + "f")

data = data.reshape(h, w, 3).transpose(2, 0, 1)
data = np.flip(data, axis=1) # flip on h dimension
data = data[:2, :, :]
return data.astype(np.float32)