Skip to content

Commit b76bfb8

Browse files
prabhat00155facebook-github-bot
authored andcommitted
[fbsync] Add FlyingChairs dataset for optical flow (#4860)
Reviewed By: kazhang Differential Revision: D32216687 fbshipit-source-id: 5f59a29d4f066fa84ad6495a15455ee966a23eb8
1 parent 1650d4c commit b76bfb8

File tree

5 files changed

+144
-20
lines changed

5 files changed

+144
-20
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
4343
FashionMNIST
4444
Flickr8k
4545
Flickr30k
46+
FlyingChairs
4647
HMDB51
4748
ImageNet
4849
INaturalist

test/datasets_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import random
99
import shutil
1010
import string
11+
import struct
1112
import tarfile
1213
import unittest
1314
import unittest.mock
@@ -922,3 +923,11 @@ def create_random_string(length: int, *digits: str) -> str:
922923
digits = "".join(itertools.chain(*digits))
923924

924925
return "".join(random.choice(digits) for _ in range(length))
926+
927+
928+
def make_fake_flo_file(h, w, file_name):
929+
"""Creates a fake flow file in .flo format."""
930+
values = list(range(2 * h * w))
931+
content = b"PIEH" + struct.pack("i", w) + struct.pack("i", h) + struct.pack("f" * len(values), *values)
932+
with open(file_name, "wb") as f:
933+
f.write(content)

test/test_datasets.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,11 +1874,9 @@ def _inject_pairs(self, root, num_pairs, same):
18741874
class SintelTestCase(datasets_utils.ImageDatasetTestCase):
18751875
DATASET_CLASS = datasets.Sintel
18761876
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final"))
1877-
# We patch the flow reader, because this would otherwise force us to generate fake (but readable) .flo files,
1878-
# which is something we want to # avoid.
1879-
_FAKE_FLOW = "Fake Flow"
1880-
EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.Sintel._read_flow", return_value=_FAKE_FLOW)}
1881-
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None)))
1877+
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
1878+
1879+
FLOW_H, FLOW_W = 3, 4
18821880

18831881
def inject_fake_data(self, tmpdir, config):
18841882
root = pathlib.Path(tmpdir) / "Sintel"
@@ -1899,14 +1897,13 @@ def inject_fake_data(self, tmpdir, config):
18991897
num_examples=num_images_per_scene,
19001898
)
19011899

1902-
# For the ground truth flow value we just create empty files so that they're properly discovered,
1903-
# see comment above about EXTRA_PATCHES
19041900
flow_root = root / "training" / "flow"
19051901
for scene_id in range(num_scenes):
19061902
scene_dir = flow_root / f"scene_{scene_id}"
19071903
os.makedirs(scene_dir)
19081904
for i in range(num_images_per_scene - 1):
1909-
open(str(scene_dir / f"frame_000{i}.flo"), "a").close()
1905+
file_name = str(scene_dir / f"frame_000{i}.flo")
1906+
datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=file_name)
19101907

19111908
# with e.g. num_images_per_scene = 3, for a single scene with have 3 images
19121909
# which are frame_0000, frame_0001 and frame_0002
@@ -1920,7 +1917,8 @@ def test_flow(self):
19201917
with self.create_dataset(split="train") as (dataset, _):
19211918
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
19221919
for _, _, flow in dataset:
1923-
assert flow == self._FAKE_FLOW
1920+
assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
1921+
np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
19241922

19251923
# Make sure flow is always None for test split
19261924
with self.create_dataset(split="test") as (dataset, _):
@@ -1929,11 +1927,11 @@ def test_flow(self):
19291927
assert flow is None
19301928

19311929
def test_bad_input(self):
1932-
with pytest.raises(ValueError, match="split must be either"):
1930+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
19331931
with self.create_dataset(split="bad"):
19341932
pass
19351933

1936-
with pytest.raises(ValueError, match="pass_name must be either"):
1934+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"):
19371935
with self.create_dataset(pass_name="bad"):
19381936
pass
19391937

@@ -1993,10 +1991,62 @@ def test_flow_and_valid(self):
19931991
assert valid is None
19941992

19951993
def test_bad_input(self):
1996-
with pytest.raises(ValueError, match="split must be either"):
1994+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
19971995
with self.create_dataset(split="bad"):
19981996
pass
19991997

20001998

1999+
class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
2000+
DATASET_CLASS = datasets.FlyingChairs
2001+
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"))
2002+
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
2003+
2004+
FLOW_H, FLOW_W = 3, 4
2005+
2006+
def _make_split_file(self, root, num_examples):
2007+
# We create a fake split file here, but users are asked to download the real one from the authors website
2008+
split_ids = [1] * num_examples["train"] + [2] * num_examples["val"]
2009+
random.shuffle(split_ids)
2010+
with open(str(root / "FlyingChairs_train_val.txt"), "w+") as split_file:
2011+
for split_id in split_ids:
2012+
split_file.write(f"{split_id}\n")
2013+
2014+
def inject_fake_data(self, tmpdir, config):
2015+
root = pathlib.Path(tmpdir) / "FlyingChairs"
2016+
2017+
num_examples = {"train": 5, "val": 3}
2018+
num_examples_total = sum(num_examples.values())
2019+
2020+
datasets_utils.create_image_folder( # img1
2021+
root,
2022+
name="data",
2023+
file_name_fn=lambda image_idx: f"00{image_idx}_img1.ppm",
2024+
num_examples=num_examples_total,
2025+
)
2026+
datasets_utils.create_image_folder( # img2
2027+
root,
2028+
name="data",
2029+
file_name_fn=lambda image_idx: f"00{image_idx}_img2.ppm",
2030+
num_examples=num_examples_total,
2031+
)
2032+
for i in range(num_examples_total):
2033+
file_name = str(root / "data" / f"00{i}_flow.flo")
2034+
datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=file_name)
2035+
2036+
self._make_split_file(root, num_examples)
2037+
2038+
return num_examples[config["split"]]
2039+
2040+
@datasets_utils.test_all_configs
2041+
def test_flow(self, config):
2042+
# Make sure flow always exists, and make sure there are as many flow values as (pairs of) images
2043+
# Also make sure the flow is properly decoded
2044+
with self.create_dataset(config=config) as (dataset, _):
2045+
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
2046+
for _, _, flow in dataset:
2047+
assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
2048+
np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
2049+
2050+
20012051
if __name__ == "__main__":
20022052
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._optical_flow import KittiFlow, Sintel
1+
from ._optical_flow import KittiFlow, Sintel, FlyingChairs
22
from .caltech import Caltech101, Caltech256
33
from .celeba import CelebA
44
from .cifar import CIFAR10, CIFAR100
@@ -74,4 +74,5 @@
7474
"LFWPairs",
7575
"KittiFlow",
7676
"Sintel",
77+
"FlyingChairs",
7778
)

torchvision/datasets/_optical_flow.py

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
from PIL import Image
99

1010
from ..io.image import _read_png_16
11+
from .utils import verify_str_arg
1112
from .vision import VisionDataset
1213

1314

1415
__all__ = (
1516
"KittiFlow",
1617
"Sintel",
18+
"FlyingChairs",
1719
)
1820

1921

@@ -109,11 +111,8 @@ class Sintel(FlowDataset):
109111
def __init__(self, root, split="train", pass_name="clean", transforms=None):
110112
super().__init__(root=root, transforms=transforms)
111113

112-
if split not in ("train", "test"):
113-
raise ValueError("split must be either 'train' or 'test'")
114-
115-
if pass_name not in ("clean", "final"):
116-
raise ValueError("pass_name must be either 'clean' or 'final'")
114+
verify_str_arg(split, "split", valid_values=("train", "test"))
115+
verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final"))
117116

118117
root = Path(root) / "Sintel"
119118

@@ -171,8 +170,7 @@ class KittiFlow(FlowDataset):
171170
def __init__(self, root, split="train", transforms=None):
172171
super().__init__(root=root, transforms=transforms)
173172

174-
if split not in ("train", "test"):
175-
raise ValueError("split must be either 'train' or 'test'")
173+
verify_str_arg(split, "split", valid_values=("train", "test"))
176174

177175
root = Path(root) / "Kitti" / (split + "ing")
178176
images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
@@ -208,6 +206,71 @@ def _read_flow(self, file_name):
208206
return _read_16bits_png_with_flow_and_valid_mask(file_name)
209207

210208

209+
class FlyingChairs(FlowDataset):
210+
"""`FlyingChairs <https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs>`_ Dataset for optical flow.
211+
212+
You will also need to download the FlyingChairs_train_val.txt file from the dataset page.
213+
214+
The dataset is expected to have the following structure: ::
215+
216+
root
217+
FlyingChairs
218+
data
219+
00001_flow.flo
220+
00001_img1.ppm
221+
00001_img2.ppm
222+
...
223+
FlyingChairs_train_val.txt
224+
225+
226+
Args:
227+
root (string): Root directory of the FlyingChairs Dataset.
228+
split (string, optional): The dataset split, either "train" (default) or "val"
229+
transforms (callable, optional): A function/transform that takes in
230+
``img1, img2, flow, valid`` and returns a transformed version.
231+
``valid`` is expected for consistency with other datasets which
232+
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
233+
"""
234+
235+
def __init__(self, root, split="train", transforms=None):
236+
super().__init__(root=root, transforms=transforms)
237+
238+
verify_str_arg(split, "split", valid_values=("train", "val"))
239+
240+
root = Path(root) / "FlyingChairs"
241+
images = sorted(glob(str(root / "data" / "*.ppm")))
242+
flows = sorted(glob(str(root / "data" / "*.flo")))
243+
244+
split_file_name = "FlyingChairs_train_val.txt"
245+
246+
if not os.path.exists(root / split_file_name):
247+
raise FileNotFoundError(
248+
"The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
249+
)
250+
251+
split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32)
252+
for i in range(len(flows)):
253+
split_id = split_list[i]
254+
if (split == "train" and split_id == 1) or (split == "val" and split_id == 2):
255+
self._flow_list += [flows[i]]
256+
self._image_list += [[images[2 * i], images[2 * i + 1]]]
257+
258+
def __getitem__(self, index):
259+
"""Return example at given index.
260+
261+
Args:
262+
index(int): The index of the example to retrieve
263+
264+
Returns:
265+
tuple: A 3-tuple with ``(img1, img2, flow)``.
266+
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
267+
"""
268+
return super().__getitem__(index)
269+
270+
def _read_flow(self, file_name):
271+
return _read_flo(file_name)
272+
273+
211274
def _read_flo(file_name):
212275
"""Read .flo file in Middlebury format"""
213276
# Code adapted from:

0 commit comments

Comments
 (0)