Skip to content

Commit 1c025f1

Browse files
committed
port UCF101
1 parent 9c66ddc commit 1c025f1

File tree

5 files changed

+199
-2
lines changed

5 files changed

+199
-2
lines changed

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@
1515
from .sbd import SBD
1616
from .semeion import SEMEION
1717
from .svhn import SVHN
18+
from .ucf101 import UCF101
1819
from .voc import VOC

torchvision/prototype/datasets/_builtin/hmdb51.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def _make_info(self) -> DatasetInfo:
2626
return DatasetInfo(
2727
"hmdb51",
2828
homepage="https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/",
29+
dependencies=("rarfile",),
2930
valid_options=dict(
3031
split=("train", "test"),
3132
split_number=("1", "2", "3"),
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
ApplyEyeMakeup
2+
ApplyLipstick
3+
Archery
4+
BabyCrawling
5+
BalanceBeam
6+
BandMarching
7+
BaseballPitch
8+
Basketball
9+
BasketballDunk
10+
BenchPress
11+
Biking
12+
Billiards
13+
BlowDryHair
14+
BlowingCandles
15+
BodyWeightSquats
16+
Bowling
17+
BoxingPunchingBag
18+
BoxingSpeedBag
19+
BreastStroke
20+
BrushingTeeth
21+
CleanAndJerk
22+
CliffDiving
23+
CricketBowling
24+
CricketShot
25+
CuttingInKitchen
26+
Diving
27+
Drumming
28+
Fencing
29+
FieldHockeyPenalty
30+
FloorGymnastics
31+
FrisbeeCatch
32+
FrontCrawl
33+
GolfSwing
34+
Haircut
35+
Hammering
36+
HammerThrow
37+
HandstandPushups
38+
HandstandWalking
39+
HeadMassage
40+
HighJump
41+
HorseRace
42+
HorseRiding
43+
HulaHoop
44+
IceDancing
45+
JavelinThrow
46+
JugglingBalls
47+
JumpingJack
48+
JumpRope
49+
Kayaking
50+
Knitting
51+
LongJump
52+
Lunges
53+
MilitaryParade
54+
Mixing
55+
MoppingFloor
56+
Nunchucks
57+
ParallelBars
58+
PizzaTossing
59+
PlayingCello
60+
PlayingDaf
61+
PlayingDhol
62+
PlayingFlute
63+
PlayingGuitar
64+
PlayingPiano
65+
PlayingSitar
66+
PlayingTabla
67+
PlayingViolin
68+
PoleVault
69+
PommelHorse
70+
PullUps
71+
Punch
72+
PushUps
73+
Rafting
74+
RockClimbingIndoor
75+
RopeClimbing
76+
Rowing
77+
SalsaSpin
78+
ShavingBeard
79+
Shotput
80+
SkateBoarding
81+
Skiing
82+
Skijet
83+
SkyDiving
84+
SoccerJuggling
85+
SoccerPenalty
86+
StillRings
87+
SumoWrestling
88+
Surfing
89+
Swing
90+
TableTennisShot
91+
TaiChi
92+
TennisSwing
93+
ThrowDiscus
94+
TrampolineJumping
95+
Typing
96+
UnevenBars
97+
VolleyballSpiking
98+
WalkingWithDog
99+
WallPushups
100+
WritingOnBoard
101+
YoYo
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import csv
2+
import pathlib
3+
from typing import Any, Dict, List, Tuple, cast, BinaryIO
4+
5+
from torch.utils.data import IterDataPipe
6+
from torch.utils.data.datapipes.iter import Filter, Mapper
7+
from torchdata.datapipes.iter import CSVParser, IterKeyZipper
8+
from torchvision.prototype.datasets.utils import (
9+
Dataset,
10+
DatasetConfig,
11+
DatasetInfo,
12+
HttpResource,
13+
OnlineResource,
14+
)
15+
from torchvision.prototype.datasets.utils._internal import (
16+
path_accessor,
17+
path_comparator,
18+
hint_sharding,
19+
hint_shuffling,
20+
INFINITE_BUFFER_SIZE,
21+
)
22+
from torchvision.prototype.features import Label, EncodedVideo
23+
24+
csv.register_dialect("ucf101", delimiter=" ")
25+
26+
27+
class UCF101(Dataset):
28+
def _make_info(self) -> DatasetInfo:
29+
return DatasetInfo(
30+
"ucf101",
31+
dependencies=("rarfile",),
32+
valid_options=dict(
33+
split=("train", "test"),
34+
fold=("1", "2", "3"),
35+
),
36+
homepage="https://www.crcv.ucf.edu/data/UCF101.php",
37+
)
38+
39+
def _extract_videos_archive(self, path: pathlib.Path) -> pathlib.Path:
40+
folder = OnlineResource._extract(path)
41+
for rar_file in folder.glob("*.rar"):
42+
OnlineResource._extract(rar_file)
43+
rar_file.unlink()
44+
return folder
45+
46+
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
47+
url_root = "https://www.crcv.ucf.edu/data/UCF101/"
48+
49+
splits = HttpResource(
50+
f"{url_root}/UCF101TrainTestSplits-RecognitionTask.zip",
51+
sha256="5c0d1a53b8ed364a2ac830a73f405e51bece7d98ce1254fd19ed4a36b224bd27",
52+
)
53+
54+
videos = HttpResource(
55+
f"{url_root}/UCF101.rar",
56+
sha256="ca8dfadb4c891cb11316f94d52b6b0ac2a11994e67a0cae227180cd160bd8e55",
57+
extract=True,
58+
)
59+
videos._preprocess = self._extract_videos_archive
60+
61+
return [splits, videos]
62+
63+
def _prepare_sample(self, data: Tuple[Tuple[str, str], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
64+
_, (path, buffer) = data
65+
path = pathlib.Path(path)
66+
return dict(
67+
label=Label.from_category(path.parent.name, categories=self.categories),
68+
video=EncodedVideo.from_file(buffer, path=path),
69+
)
70+
71+
def _make_datapipe(
72+
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
73+
) -> IterDataPipe[Dict[str, Any]]:
74+
splits_dp, images_dp = resource_dps
75+
76+
splits_dp: IterDataPipe[Tuple[str, BinaryIO]] = Filter(
77+
splits_dp, path_comparator("name", f"{config.split}list0{config.fold}.txt")
78+
)
79+
splits_dp = CSVParser(splits_dp, dialect="ucf101")
80+
splits_dp = hint_sharding(splits_dp)
81+
splits_dp = hint_shuffling(splits_dp)
82+
83+
dp = IterKeyZipper(splits_dp, images_dp, path_accessor("name"), buffer_size=INFINITE_BUFFER_SIZE)
84+
return Mapper(dp, self._prepare_sample)
85+
86+
def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
87+
config = self.default_config
88+
resources = self.resources(config)
89+
90+
dp = resources[0].load(root)
91+
dp = Filter(dp, path_comparator("name", "classInd.txt"))
92+
dp = CSVParser(dp, dialect="ucf101")
93+
_, categories = zip(*dp)
94+
return cast(Tuple[str, ...], categories)

torchvision/prototype/features/_encoded.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ def from_file(cls: Type[E], file: BinaryIO, **meta: Any) -> E:
3333
return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **meta)
3434

3535
@classmethod
36-
def from_path(cls: Type[E], path: Union[str, os.PathLike]) -> E:
36+
def from_path(cls: Type[E], path: Union[str, os.PathLike], **meta: Any) -> E:
3737
path = pathlib.Path(path)
3838
with open(path, "rb") as file:
39-
return cls.from_file(file, path=path)
39+
return cls.from_file(file, path=path, **meta)
4040

4141

4242
class EncodedImage(EncodedData):

0 commit comments

Comments
 (0)