Skip to content

Commit 9c66ddc

Browse files
committed
add hmdb51 dataset and prototype for new style video decoding
1 parent 7cffef6 commit 9c66ddc

File tree

9 files changed

+436
-11
lines changed

9 files changed

+436
-11
lines changed

torchvision/datasets/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import bz2
2+
import contextlib
23
import gzip
34
import hashlib
45
import itertools
@@ -301,6 +302,15 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No
301302
".tgz": (".tar", ".gz"),
302303
}
303304

305+
with contextlib.suppress(ImportError):
    # RAR support is optional: when `rarfile` cannot be imported, the ImportError
    # is swallowed and the ".rar" extractor below is never registered.
    import rarfile

    def _extract_rar(from_path: str, to_path: str, compression: Optional[str]) -> None:
        """Unpack the RAR archive at ``from_path`` into the directory ``to_path``.

        ``compression`` is not used here; it is accepted so that this function
        shares the signature of the other archive extractors.
        """
        with rarfile.RarFile(from_path, "r") as archive:
            archive.extractall(to_path)

    # Make the extractor discoverable through the extension lookup table.
    _ARCHIVE_EXTRACTORS[".rar"] = _extract_rar
313+
304314

305315
def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
306316
"""Detect the archive type and/or compression of a file.

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .dtd import DTD
88
from .fer2013 import FER2013
99
from .gtsrb import GTSRB
10+
from .hmdb51 import HMDB51
1011
from .imagenet import ImageNet
1112
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
1213
from .oxford_iiit_pet import OxfordIITPet
torchvision/prototype/datasets/_builtin/hmdb51.categories

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
brush_hair
2+
cartwheel
3+
catch
4+
chew
5+
clap
6+
climb
7+
climb_stairs
8+
dive
9+
draw_sword
10+
dribble
11+
drink
12+
eat
13+
fall_floor
14+
fencing
15+
flic_flac
16+
golf
17+
handstand
18+
hit
19+
hug
20+
jump
21+
kick
22+
kick_ball
23+
kiss
24+
laugh
25+
pick
26+
pour
27+
pullup
28+
punch
29+
push
30+
pushup
31+
ride_bike
32+
ride_horse
33+
run
34+
shake_hands
35+
shoot_ball
36+
shoot_bow
37+
shoot_gun
38+
sit
39+
situp
40+
smile
41+
smoke
42+
somersault
43+
stand
44+
swing_baseball
45+
sword
46+
sword_exercise
47+
talk
48+
throw
49+
turn
50+
walk
51+
wave
torchvision/prototype/datasets/_builtin/hmdb51.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import functools
2+
import pathlib
3+
import re
4+
from typing import Any, Dict, List, Tuple, BinaryIO
5+
6+
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, IterKeyZipper
7+
from torchvision.prototype.datasets.utils import (
8+
Dataset,
9+
DatasetConfig,
10+
DatasetInfo,
11+
HttpResource,
12+
OnlineResource,
13+
)
14+
from torchvision.prototype.datasets.utils._internal import (
15+
INFINITE_BUFFER_SIZE,
16+
getitem,
17+
path_accessor,
18+
hint_sharding,
19+
hint_shuffling,
20+
)
21+
from torchvision.prototype.features import EncodedVideo, Label
22+
23+
24+
class HMDB51(Dataset):
    """Prototype datapipe-based implementation of the HMDB51 video dataset.

    Samples are dictionaries with a ``label`` (derived from the name of the
    folder a video lives in) and the still encoded ``video``. Which videos are
    yielded is controlled by the ``split`` ("train" / "test") and
    ``split_number`` ("1" / "2" / "3") options.
    """

    def _make_info(self) -> DatasetInfo:
        """Return the static dataset description and its valid config options."""
        return DatasetInfo(
            "hmdb51",
            homepage="https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/",
            valid_options=dict(
                split=("train", "test"),
                split_number=("1", "2", "3"),
            ),
        )

    def _extract_videos_archive(self, path: pathlib.Path) -> pathlib.Path:
        """Extract the videos archive and every ``.rar`` file found inside it.

        Each nested ``.rar`` file at the top level of the extracted folder is
        extracted and then deleted, so only the extracted content remains.

        Returns:
            The folder the outer archive was extracted to.
        """
        folder = OnlineResource._extract(path)
        for rar_file in folder.glob("*.rar"):
            OnlineResource._extract(rar_file)
            rar_file.unlink()
        return folder

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the split definitions and the videos resource, in that order."""
        url_root = "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10"

        splits = HttpResource(
            f"{url_root}/test_train_splits.rar",
            sha256="229c94f845720d01eb3946d39f39292ea962d50a18136484aa47c1eba251d2b7",
        )
        videos = HttpResource(
            f"{url_root}/hmdb51_org.rar",
            sha256="9e714a0d8b76104d76e932764a7ca636f929fff66279cda3f2e326fa912a328e",
        )
        # The videos archive contains nested archives, so it needs the custom
        # two-stage extraction instead of the default preprocessing.
        videos._preprocess = self._extract_videos_archive
        return [splits, videos]

    # Split files are named "<category>_test_split<1|2|3>.txt".
    _SPLIT_FILE_PATTERN = re.compile(r"(?P<category>\w+?)_test_split(?P<split_number>[1-3])[.]txt")

    def _is_split_number(self, data: Tuple[str, Any], *, split_number: str) -> bool:
        """Tell whether ``data`` is a split file belonging to ``split_number``.

        Entries whose file name does not follow the split file naming scheme
        are rejected instead of raising.
        """
        path = pathlib.Path(data[0])
        match = self._SPLIT_FILE_PATTERN.match(path.name)
        if match is None:
            return False
        return match["split_number"] == split_number

    _SPLIT_ID_TO_NAME = {
        "1": "train",
        "2": "test",
    }

    def _is_split(self, data: Dict[str, Any], *, split: str) -> bool:
        """Tell whether the split file row ``data`` belongs to ``split``."""
        split_id = data["split_id"]

        # Per the dataset's split file convention, every row carries an id:
        # "1" marks a training video, "2" a test video and "0" a video that is
        # not used in this particular split. Anything that is not a known id is
        # therefore part of neither "train" nor "test".
        if split_id not in self._SPLIT_ID_TO_NAME:
            return False

        return self._SPLIT_ID_TO_NAME[split_id] == split

    def _prepare_sample(self, data: Tuple[Dict[str, Any], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
        """Turn a matched (split row, video file) pair into a sample dictionary.

        The split row was only needed for matching and is discarded. The label
        is taken from the name of the directory that contains the video.
        """
        _, (path, buffer) = data
        path = pathlib.Path(path)
        return dict(
            label=Label.from_category(path.parent.name, categories=self.categories),
            video=EncodedVideo.from_file(buffer, path=path),
        )

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
    ) -> IterDataPipe[Dict[str, Any]]:
        """Build the sample datapipe for ``config``.

        The split files are narrowed down to the requested split number, parsed
        into rows, filtered for the requested split and finally joined with the
        videos on the video file name.
        """
        splits_dp, videos_dp = resource_dps

        splits_dp = Filter(splits_dp, functools.partial(self._is_split_number, split_number=config.split_number))
        splits_dp = CSVDictParser(splits_dp, fieldnames=("filename", "split_id"), delimiter=" ")
        splits_dp = Filter(splits_dp, functools.partial(self._is_split, split=config.split))
        splits_dp = hint_sharding(splits_dp)
        splits_dp = hint_shuffling(splits_dp)

        dp = IterKeyZipper(
            splits_dp,
            videos_dp,
            key_fn=getitem("filename"),
            ref_key_fn=path_accessor("name"),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._prepare_sample)

    def _generate_categories(self, root: pathlib.Path) -> List[str]:
        """Collect the sorted category names from the split file names.

        Entries whose file name does not follow the split file naming scheme
        are skipped instead of raising.
        """
        config = self.default_config
        resources = self.resources(config)

        dp = resources[0].load(root)
        categories = set()
        for path, _ in dp:
            match = self._SPLIT_FILE_PATTERN.match(pathlib.Path(path).name)
            if match is not None:
                categories.add(match["category"])
        return sorted(categories)

torchvision/prototype/datasets/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from ._dataset import DatasetConfig, DatasetInfo, Dataset
33
from ._query import SampleQuery
44
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource
5+
from ._video import KeyframeDecoder, RandomFrameDecoder, ClipDecoder

0 commit comments

Comments
 (0)