Skip to content

Commit 2c7d682

Browse files
kazhang and facebook-github-bot
authored and committed
[fbsync] Add HD1K dataset for optical flow (#4890)
Reviewed By: datumbox Differential Revision: D32470482 fbshipit-source-id: 3eade39532e69471dcec040fccbf2fdf2ca05d4e
1 parent 413848a commit 2c7d682

File tree

4 files changed

+113
-1
lines changed

4 files changed

+113
-1
lines changed

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
4545
Flickr30k
4646
FlyingChairs
4747
FlyingThings3D
48+
HD1K
4849
HMDB51
4950
ImageNet
5051
INaturalist

test/test_datasets.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2126,5 +2126,47 @@ def test_bad_input(self):
21262126
pass
21272127

21282128

2129+
class HD1KTestCase(KittiFlowTestCase):
    """Run the optical-flow dataset checks against ``datasets.HD1K``."""

    DATASET_CLASS = datasets.HD1K

    def inject_fake_data(self, tmpdir, config):
        """Populate ``tmpdir / "hd1k"`` with fake images and flow files.

        Args:
            tmpdir: Directory in which the fake dataset is created.
            config (dict): Test configuration; only ``config["split"]`` is read.

        Returns:
            int: The number of examples the dataset is expected to expose,
            i.e. one per pair of consecutive frames within each sequence.
        """
        hd1k_root = pathlib.Path(tmpdir) / "hd1k"
        is_train = config["split"] == "train"

        num_sequences = 4 if is_train else 3
        frames_per_train_sequence = 3

        for seq_idx in range(num_sequences):
            # Each entry is (parent folder, sub-folder, file-name function, file count).
            # The lambdas are called by ``create_image_folder`` during this very
            # iteration, so closing over ``seq_idx`` is safe here.
            folder_specs = (
                # Training data
                (
                    "hd1k_input",
                    "image_2",
                    lambda image_idx: f"{seq_idx:06d}_{image_idx}.png",
                    frames_per_train_sequence,
                ),
                (
                    "hd1k_flow_gt",
                    "flow_occ",
                    lambda image_idx: f"{seq_idx:06d}_{image_idx}.png",
                    frames_per_train_sequence,
                ),
                # Test data
                ("hd1k_challenge", "image_2", lambda _: f"{seq_idx:06d}_10.png", 1),
                ("hd1k_challenge", "image_2", lambda _: f"{seq_idx:06d}_11.png", 1),
            )
            for parent, folder, name_fn, count in folder_specs:
                datasets_utils.create_image_folder(
                    hd1k_root / parent,
                    name=folder,
                    file_name_fn=name_fn,
                    num_examples=count,
                )

        # The test split has exactly two frames (10 and 11) per sequence.
        frames_per_sequence = frames_per_train_sequence if is_train else 2
        pairs_per_sequence = frames_per_sequence - 1
        return num_sequences * pairs_per_sequence
2169+
2170+
21292171
if __name__ == "__main__":
21302172
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D
1+
from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D, HD1K
22
from .caltech import Caltech101, Caltech256
33
from .celeba import CelebA
44
from .cifar import CIFAR10, CIFAR100
@@ -76,4 +76,5 @@
7676
"Sintel",
7777
"FlyingChairs",
7878
"FlyingThings3D",
79+
"HD1K",
7980
)

torchvision/datasets/_optical_flow.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"Sintel",
2020
"FlyingThings3D",
2121
"FlyingChairs",
22+
"HD1K",
2223
)
2324

2425

@@ -363,6 +364,73 @@ def _read_flow(self, file_name):
363364
return _read_pfm(file_name)
364365

365366

367+
class HD1K(FlowDataset):
    """`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            hd1k
                hd1k_challenge
                    image_2
                hd1k_flow_gt
                    flow_occ
                hd1k_input
                    image_2

    Args:
        root (string): Root directory of the HD1K Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test".
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid`` and returns a transformed version.
    """

    _has_builtin_flow_mask = True

    def __init__(self, root, split="train", transforms=None):
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        base_dir = Path(root) / "hd1k"
        if split == "test":
            # Challenge data: pair every "*10.png" frame with the "*11.png"
            # frame at the same position in sorted order; no flow files here.
            challenge_dir = base_dir / "hd1k_challenge" / "image_2"
            first_frames = sorted(glob(str(challenge_dir / "*10.png")))
            second_frames = sorted(glob(str(challenge_dir / "*11.png")))
            for first, second in zip(first_frames, second_frames):
                self._image_list.append([first, second])
        else:
            flow_dir = base_dir / "hd1k_flow_gt" / "flow_occ"
            image_dir = base_dir / "hd1k_input" / "image_2"
            # There are 36 "sequences" and we don't want seq i to overlap with
            # seq i + 1, so pairs are built one sequence at a time.
            for seq_idx in range(36):
                flows = sorted(glob(str(flow_dir / f"{seq_idx:06d}_*.png")))
                images = sorted(glob(str(image_dir / f"{seq_idx:06d}_*.png")))
                # The last flow file of a sequence is skipped: it has no
                # following frame within the same sequence to pair with.
                for idx, flow in enumerate(flows[:-1]):
                    self._flow_list.append(flow)
                    self._image_list.append([images[idx], images[idx + 1]])

        if not self._image_list:
            raise FileNotFoundError(
                "Could not find the HD1K images. Please make sure the directory structure is correct."
            )

    def _read_flow(self, file_name):
        """Read a flow file via ``_read_16bits_png_with_flow_and_valid_mask``."""
        return _read_16bits_png_with_flow_and_valid_mask(file_name)

    def __getitem__(self, index):
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
            valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
            indicating which flow values are valid. The flow is a numpy array of
            shape (2, H, W) and the images are PIL images. If ``split="test"``, a
            4-tuple with ``(img1, img2, None, None)`` is returned.
        """
        # Overridden only to attach the HD1K-specific docstring above.
        return super().__getitem__(index)
432+
433+
366434
def _read_flo(file_name):
367435
"""Read .flo file in Middlebury format"""
368436
# Code adapted from:

0 commit comments

Comments
 (0)