diff --git a/.circleci/config.yml b/.circleci/config.yml index b616f3bde93..a77a3b5bedd 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -351,7 +351,7 @@ jobs: - install_torchvision - install_prototype_dependencies - pip_install: - args: scipy pycocotools h5py + args: scipy pycocotools h5py av rarfile descr: Install optional dependencies - run: name: Enable prototype tests diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index 76d90274be3..a42f15d3d1c 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -351,7 +351,7 @@ jobs: - install_torchvision - install_prototype_dependencies - pip_install: - args: scipy pycocotools h5py + args: scipy pycocotools h5py av rarfile descr: Install optional dependencies - run: name: Enable prototype tests diff --git a/mypy.ini b/mypy.ini index 6d7863b627e..fc2741e28f4 100644 --- a/mypy.ini +++ b/mypy.ini @@ -155,3 +155,7 @@ ignore_missing_imports = True [mypy-h5py.*] ignore_missing_imports = True + +[mypy-rarfile.*] + +ignore_missing_imports = True diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 123d8f29d3f..71ed8584b5a 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -9,6 +9,7 @@ import pathlib import pickle import random +import unittest.mock import xml.etree.ElementTree as ET from collections import defaultdict, Counter @@ -16,11 +17,10 @@ import PIL.Image import pytest import torch -from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file +from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, create_video_folder from torch.nn.functional import one_hot from torch.testing import make_tensor as _make_tensor from torchvision.prototype.datasets._api import find -from torchvision.prototype.utils._internal import sequence_to_str make_tensor = functools.partial(_make_tensor, device="cpu") make_scalar = functools.partial(make_tensor, ()) @@ -67,14 +67,15 @@ def prepare(self, home, config): mock_info = self._parse_mock_info(self.mock_data_fn(self.info, root, config)) - available_file_names = {path.name for path in root.glob("*")} - required_file_names = {resource.file_name for resource in self.dataset.resources(config)} - missing_file_names = required_file_names - available_file_names - if missing_file_names: - raise pytest.UsageError( - f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} " - f"for {config}, but they were not created by the mock data function." - ) + for resource in self.dataset.resources(config): + with unittest.mock.patch( + "torchvision.prototype.datasets.utils._resource.OnlineResource.download", + side_effect=TypeError( + f"Dataset '{self.name}' requires the file {resource.file_name} for {config}, " + f"but it was not created by the mock data function." 
+ ), + ): + resource.load(root) return mock_info @@ -1344,3 +1345,79 @@ def pcam(info, root, config): compressed_file.write(compressed_data) return num_images + + +@register_mock +def ucf101(info, root, config): + video_folder = root / "UCF101" / "UCF-101" + + categories_and_labels = [ + ("ApplyEyeMakeup", 0), + ("LongJump", 50), + ("YoYo", 100), + ] + + def file_name_fn(cls, idx, clips_per_group=2): + return f"v_{cls}_g{(idx // clips_per_group) + 1:02d}_c{(idx % clips_per_group) + 1:02d}.avi" + + video_files = [ + create_video_folder( + video_folder, category, lambda idx: file_name_fn(category, idx), num_examples=int(torch.randint(1, 6, ())) + ) + for category, _ in categories_and_labels + ] + + splits_folder = root / "ucfTrainTestList" + splits_folder.mkdir() + + with open(splits_folder / "classInd.txt", "w") as file: + file.write("\n".join(f"{label} {category}" for category, label in categories_and_labels) + "\n") + + video_ids = [path.relative_to(video_folder).as_posix() for path in itertools.chain.from_iterable(video_files)] + splits = ("train", "test") + num_samples_map = {} + for fold in range(1, 4): + random.shuffle(video_ids) + for offset, split in enumerate(splits): + video_ids_in_config = video_ids[offset :: len(splits)] + with open(splits_folder / f"{split}list{fold:02d}.txt", "w") as file: + file.write("\n".join(video_ids_in_config) + "\n") + + num_samples_map[info.make_config(split=split, fold=str(fold))] = len(video_ids_in_config) + + make_zip(root, "UCF101TrainTestSplits-RecognitionTask.zip", splits_folder) + + return num_samples_map[config] + + +@register_mock +def hmdb51(info, root, config): + video_folder = root / "hmdb51_org" + + categories = [ + "brush_hair", + "pour", + "wave", + ] + + video_files = { + category: create_video_folder( + video_folder, category, lambda idx: f"{category}_{idx}.avi", num_examples=int(torch.randint(3, 10, ())) + ) + for category in categories + } + + splits_folder = root / "test_train_splits" / "testTrainMulti_7030_splits" + splits_folder.mkdir(parents=True) + + num_samples_map = defaultdict(lambda: 0) + for category, fold in itertools.product(categories, range(1, 4)): + videos = video_files[category] + + with open(splits_folder / f"{category}_test_split{fold}.txt", "w") as file: + file.write("\n".join(f"{path.name} {idx % 3}" for idx, path in enumerate(videos)) + "\n") + + for split, split_id in (("train", 1), ("test", 2)): + num_samples_map[info.make_config(split=split, fold=str(fold))] += len(videos[split_id::3]) + + return num_samples_map[config] diff --git a/test/test_prototype_videoutils.py b/test/test_prototype_videoutils.py new file mode 100644 index 00000000000..a59c453786a --- /dev/null +++ b/test/test_prototype_videoutils.py @@ -0,0 +1,88 @@ +import math +import os + +import pytest +import torch +from torchvision.io import _HAS_VIDEO_DECODER, _HAS_VIDEO_OPT, VideoReader +from torchvision.prototype.features import EncodedData +from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer +from torchvision.prototype.datasets.utils._video import KeyframeDecoder, RandomFrameDecoder +try: + import av +except ImportError: + av = None + +VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos") + + +@pytest.mark.skipif(av is None, reason="PyAV unavailable") +class TestVideoDatasetUtils: + # TODO: atm we separate backends in order to allow for testing on different systems; + # once we have things packaged we should add this as test parametrisation + # (this also applies for GPU decoding as well) + 
+ @pytest.mark.parametrize( + "video_file", + [ + "RATRACE_wave_f_nm_np1_fr_goo_37.avi", + "TrumanShow_wave_f_nm_np1_fr_med_26.avi", + "v_SoccerJuggling_g23_c01.avi", + "v_SoccerJuggling_g24_c01.avi", + "R6llTwEh07w.mp4", + "SOX5yA1l24A.mp4", + "WUzgd7C1pWA.mp4", + ], + ) + def test_random_decoder_av(self, video_file): + """Read a sequence of random frames from a video + Checks that files are valid video frames and no error is thrown during decoding. + """ + video_file = os.path.join(VIDEO_DIR, video_file) + video = ReadOnlyTensorBuffer(EncodedData.from_path(video_file)) + print(next(video)) + pass + + def test_random_decoder_cpu(self, video_file): + """Read a sequence of random frames from a video using CPU backend + Checks that files are valid video frames and no error is thrown during decoding, + and compares them to `pyav` output. + """ + pass + + def test_random_decoder_GPU(self, video_file): + """Read a sequence of random frames from a video using GPU backend + Checks that files are valid video frames and no error is thrown during decoding, + and compares them to `pyav` output. + """ + pass + + def test_keyframe_decoder_av(self, video_file): + """Read all keyframes from a video; + Compare the output to naive keyframe reading with `pyav` + """ + pass + + def test_keyframe_decoder_cpu(self, video_file): + """Read all keyframes from a video using CPU backend; + ATM should raise a warning and default to `pyav` + TODO: should we fail or default to a working backend + """ + pass + + def test_keyframe_decoder_GPU(self, video_file): + """Read all keyframes from a video using CPU backend; + ATM should raise a warning and default to `pyav` + TODO: should we fail or default to a working backend + """ + pass + + def test_clip_decoder(self, video_file): + """ATM very crude test: + check only if fails, or if the clip sampling is correct, + don't bother with the content just yet. + """ + pass + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index dbc9cf2a6b4..9ec4501e3b1 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -1,4 +1,5 @@ import bz2 +import contextlib import gzip import hashlib import itertools @@ -301,6 +302,15 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No ".tgz": (".tar", ".gz"), } +with contextlib.suppress(ImportError): + import rarfile + + def _extract_rar(from_path: str, to_path: str, compression: Optional[str]) -> None: + with rarfile.RarFile(from_path, "r") as rar: + rar.extractall(to_path) + + _ARCHIVE_EXTRACTORS[".rar"] = _extract_rar + def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]: """Detect the archive type and/or compression of a file. 
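Note on the torchvision/datasets/utils.py change above: the new `_extract_rar` helper plugs into the existing `_ARCHIVE_EXTRACTORS` registry, so `.rar` support is picked up by the generic `extract_archive` helper whenever `rarfile` is importable (and is skipped silently otherwise, via `contextlib.suppress(ImportError)`). A minimal sketch of the intended usage, assuming `rarfile` plus one of its extraction backends (e.g. `unrar`) is installed; the paths are illustrative only:

from torchvision.datasets.utils import extract_archive

# ".rar" is now resolved through the _ARCHIVE_EXTRACTORS registry shown above,
# so the generic helper can unpack the HMDB51/UCF101 video archives.
extract_archive("hmdb51_org.rar", "data/hmdb51_org")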
diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index 9fdfca904f5..61f6278ece7 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -7,6 +7,7 @@ from .dtd import DTD from .fer2013 import FER2013 from .gtsrb import GTSRB +from .hmdb51 import HMDB51 from .imagenet import ImageNet from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST from .oxford_iiit_pet import OxfordIITPet @@ -14,4 +15,5 @@ from .sbd import SBD from .semeion import SEMEION from .svhn import SVHN +from .ucf101 import UCF101 from .voc import VOC diff --git a/torchvision/prototype/datasets/_builtin/hmdb51.categories b/torchvision/prototype/datasets/_builtin/hmdb51.categories new file mode 100644 index 00000000000..3217416f524 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/hmdb51.categories @@ -0,0 +1,51 @@ +brush_hair +cartwheel +catch +chew +clap +climb +climb_stairs +dive +draw_sword +dribble +drink +eat +fall_floor +fencing +flic_flac +golf +handstand +hit +hug +jump +kick +kick_ball +kiss +laugh +pick +pour +pullup +punch +push +pushup +ride_bike +ride_horse +run +shake_hands +shoot_ball +shoot_bow +shoot_gun +sit +situp +smile +smoke +somersault +stand +swing_baseball +sword +sword_exercise +talk +throw +turn +walk +wave diff --git a/torchvision/prototype/datasets/_builtin/hmdb51.py b/torchvision/prototype/datasets/_builtin/hmdb51.py new file mode 100644 index 00000000000..ce37a0045b5 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/hmdb51.py @@ -0,0 +1,117 @@ +import functools +import pathlib +import re +from typing import Any, Dict, List, Tuple, BinaryIO + +from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, IterKeyZipper +from torchvision.prototype.datasets.utils import ( + Dataset, + DatasetConfig, + DatasetInfo, + HttpResource, + OnlineResource, +) +from torchvision.prototype.datasets.utils._internal import ( + INFINITE_BUFFER_SIZE, + getitem, + path_accessor, + hint_sharding, + hint_shuffling, +) +from torchvision.prototype.features import EncodedVideo, Label + + +class HMDB51(Dataset): + def _make_info(self) -> DatasetInfo: + return DatasetInfo( + "hmdb51", + homepage="https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/", + dependencies=("rarfile",), + valid_options=dict( + split=("train", "test"), + fold=("1", "2", "3"), + ), + ) + + def _extract_videos_archive(self, path: pathlib.Path) -> pathlib.Path: + folder = OnlineResource._extract(path) + for rar_file in folder.glob("*.rar"): + OnlineResource._extract(rar_file) + rar_file.unlink() + return folder + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + url_root = "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10" + + splits = HttpResource( + f"{url_root}/test_train_splits.rar", + sha256="229c94f845720d01eb3946d39f39292ea962d50a18136484aa47c1eba251d2b7", + ) + videos = HttpResource( + f"{url_root}/hmdb51_org.rar", + sha256="9e714a0d8b76104d76e932764a7ca636f929fff66279cda3f2e326fa912a328e", + ) + videos._preprocess = self._extract_videos_archive + return [splits, videos] + + _SPLIT_FILE_PATTERN = re.compile(r"(?P\w+?)_test_split(?P[1-3])[.]txt") + + def _is_fold(self, data: Tuple[str, Any], *, fold: str) -> bool: + path = pathlib.Path(data[0]) + return self._SPLIT_FILE_PATTERN.match(path.name)["fold"] == fold # type: ignore[index] + + _SPLIT_ID_TO_NAME = { + "1": "train", + "2": "test", + } + + def 
_is_split(self, data: Dict[str, Any], *, split: str) -> bool: + split_id = data["split_id"] + + # In addition to split id 1 and 2 corresponding to the train and test splits, some videos are annotated with + # split id 0, which indicates that the video is not included in either split + if split_id not in self._SPLIT_ID_TO_NAME: + return False + + return self._SPLIT_ID_TO_NAME[split_id] == split + + def _prepare_sample(self, data: Tuple[List[str], Tuple[str, BinaryIO]]) -> Dict[str, Any]: + _, (path, buffer) = data + path = pathlib.Path(path) + return dict( + label=Label.from_category(path.parent.name, categories=self.categories), + video=EncodedVideo.from_file(buffer, path=path), + ) + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + ) -> IterDataPipe[Dict[str, Any]]: + splits_dp, videos_dp = resource_dps + + splits_dp = Filter(splits_dp, functools.partial(self._is_fold, fold=config.fold)) + splits_dp = CSVDictParser(splits_dp, fieldnames=("filename", "split_id"), delimiter=" ") + splits_dp = Filter(splits_dp, functools.partial(self._is_split, split=config.split)) + splits_dp = hint_sharding(splits_dp) + splits_dp = hint_shuffling(splits_dp) + + dp = IterKeyZipper( + splits_dp, + videos_dp, + key_fn=getitem("filename"), + ref_key_fn=path_accessor("name"), + buffer_size=INFINITE_BUFFER_SIZE, + ) + return Mapper(dp, self._prepare_sample) + + def _generate_categories(self, root: pathlib.Path) -> List[str]: + config = self.default_config + resources = self.resources(config) + + dp = resources[0].load(root) + categories = { + self._SPLIT_FILE_PATTERN.match(pathlib.Path(path).name)["category"] for path, _ in dp # type: ignore[index] + } + return sorted(categories) diff --git a/torchvision/prototype/datasets/_builtin/ucf101.categories b/torchvision/prototype/datasets/_builtin/ucf101.categories new file mode 100644 index 00000000000..dd41d095c7c --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/ucf101.categories @@ -0,0 +1,101 @@ +ApplyEyeMakeup +ApplyLipstick +Archery +BabyCrawling +BalanceBeam +BandMarching +BaseballPitch +Basketball +BasketballDunk +BenchPress +Biking +Billiards +BlowDryHair +BlowingCandles +BodyWeightSquats +Bowling +BoxingPunchingBag +BoxingSpeedBag +BreastStroke +BrushingTeeth +CleanAndJerk +CliffDiving +CricketBowling +CricketShot +CuttingInKitchen +Diving +Drumming +Fencing +FieldHockeyPenalty +FloorGymnastics +FrisbeeCatch +FrontCrawl +GolfSwing +Haircut +Hammering +HammerThrow +HandstandPushups +HandstandWalking +HeadMassage +HighJump +HorseRace +HorseRiding +HulaHoop +IceDancing +JavelinThrow +JugglingBalls +JumpingJack +JumpRope +Kayaking +Knitting +LongJump +Lunges +MilitaryParade +Mixing +MoppingFloor +Nunchucks +ParallelBars +PizzaTossing +PlayingCello +PlayingDaf +PlayingDhol +PlayingFlute +PlayingGuitar +PlayingPiano +PlayingSitar +PlayingTabla +PlayingViolin +PoleVault +PommelHorse +PullUps +Punch +PushUps +Rafting +RockClimbingIndoor +RopeClimbing +Rowing +SalsaSpin +ShavingBeard +Shotput +SkateBoarding +Skiing +Skijet +SkyDiving +SoccerJuggling +SoccerPenalty +StillRings +SumoWrestling +Surfing +Swing +TableTennisShot +TaiChi +TennisSwing +ThrowDiscus +TrampolineJumping +Typing +UnevenBars +VolleyballSpiking +WalkingWithDog +WallPushups +WritingOnBoard +YoYo diff --git a/torchvision/prototype/datasets/_builtin/ucf101.py b/torchvision/prototype/datasets/_builtin/ucf101.py new file mode 100644 index 00000000000..dfc7652468b --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/ucf101.py @@ -0,0 +1,94 
@@ +import csv +import pathlib +from typing import Any, Dict, List, Tuple, cast, BinaryIO + +from torch.utils.data import IterDataPipe +from torch.utils.data.datapipes.iter import Filter, Mapper +from torchdata.datapipes.iter import CSVParser, IterKeyZipper +from torchvision.prototype.datasets.utils import ( + Dataset, + DatasetConfig, + DatasetInfo, + HttpResource, + OnlineResource, +) +from torchvision.prototype.datasets.utils._internal import ( + path_accessor, + path_comparator, + hint_sharding, + hint_shuffling, + INFINITE_BUFFER_SIZE, +) +from torchvision.prototype.features import Label, EncodedVideo + +csv.register_dialect("ucf101", delimiter=" ") + + +class UCF101(Dataset): + def _make_info(self) -> DatasetInfo: + return DatasetInfo( + "ucf101", + dependencies=("rarfile",), + valid_options=dict( + split=("train", "test"), + fold=("1", "2", "3"), + ), + homepage="https://www.crcv.ucf.edu/data/UCF101.php", + ) + + def _extract_videos_archive(self, path: pathlib.Path) -> pathlib.Path: + folder = OnlineResource._extract(path) + for rar_file in folder.glob("*.rar"): + OnlineResource._extract(rar_file) + rar_file.unlink() + return folder + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + url_root = "https://www.crcv.ucf.edu/data/UCF101/" + + splits = HttpResource( + f"{url_root}/UCF101TrainTestSplits-RecognitionTask.zip", + sha256="5c0d1a53b8ed364a2ac830a73f405e51bece7d98ce1254fd19ed4a36b224bd27", + ) + + # The SSL certificate of the server is currently invalid, but downloading "unsafe" data is not supported yet + videos = HttpResource( + f"{url_root}/UCF101.rar", + sha256="ca8dfadb4c891cb11316f94d52b6b0ac2a11994e67a0cae227180cd160bd8e55", + ) + videos._preprocess = self._extract_videos_archive + + return [splits, videos] + + def _prepare_sample(self, data: Tuple[Tuple[str, str], Tuple[str, BinaryIO]]) -> Dict[str, Any]: + _, (path, buffer) = data + path = pathlib.Path(path) + return dict( + label=Label.from_category(path.parent.name, categories=self.categories), + video=EncodedVideo.from_file(buffer, path=path), + ) + + def _make_datapipe( + self, resource_dps: List[IterDataPipe], *, config: DatasetConfig + ) -> IterDataPipe[Dict[str, Any]]: + splits_dp, videos_dp = resource_dps + + splits_dp: IterDataPipe[Tuple[str, BinaryIO]] = Filter( + splits_dp, path_comparator("name", f"{config.split}list0{config.fold}.txt") + ) + splits_dp = CSVParser(splits_dp, dialect="ucf101") + splits_dp = hint_sharding(splits_dp) + splits_dp = hint_shuffling(splits_dp) + + dp = IterKeyZipper(splits_dp, videos_dp, path_accessor("name"), buffer_size=INFINITE_BUFFER_SIZE) + return Mapper(dp, self._prepare_sample) + + def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]: + config = self.default_config + resources = self.resources(config) + + dp = resources[0].load(root) + dp: IterDataPipe[Tuple[str, BinaryIO]] = Filter(dp, path_comparator("name", "classInd.txt")) + dp = CSVParser(dp, dialect="ucf101") + _, categories = zip(*dp) + return cast(Tuple[str, ...], categories) diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index 9423b65a8ee..89668a3b3d2 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -2,3 +2,4 @@ from ._dataset import DatasetConfig, DatasetInfo, Dataset from ._query import SampleQuery from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource +from ._video import 
KeyframeDecoder, RandomFrameDecoder, ClipDecoder diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index b2ae175c551..7edd130c414 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -22,6 +22,7 @@ download_file_from_google_drive, _get_redirect_url, _get_google_drive_file_id, + tqdm, ) @@ -129,7 +130,7 @@ def download(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool def _check_sha256(self, path: pathlib.Path, *, chunk_size: int = 1024 * 1024) -> None: hash = hashlib.sha256() with open(path, "rb") as file: - for chunk in iter(lambda: file.read(chunk_size), b""): + for chunk in tqdm(iter(lambda: file.read(chunk_size), b"")): hash.update(chunk) sha256 = hash.hexdigest() if sha256 != self.sha256: diff --git a/torchvision/prototype/datasets/utils/_video.py b/torchvision/prototype/datasets/utils/_video.py new file mode 100644 index 00000000000..37703ec7262 --- /dev/null +++ b/torchvision/prototype/datasets/utils/_video.py @@ -0,0 +1,235 @@ +import random +import warnings +from typing import Any, Dict, Iterator, Optional, Tuple + +import av +import numpy as np +import torch +from torchdata.datapipes.iter import IterDataPipe +from torchvision import get_video_backend +from torchvision.io import video, _video_opt, VideoReader +from torchvision.prototype.features import Image, EncodedVideo +from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer, query_recursively + + +class _VideoDecoder(IterDataPipe): + def __init__(self, datapipe: IterDataPipe, *, inline: bool = True) -> None: + # TODO: add gpu support + self.datapipe = datapipe + self._inline = inline + + def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + raise NotImplementedError + + def _find_encoded_video(self, id: Tuple[Any, ...], obj: Any) -> Optional[Tuple[Any, ...]]: + if isinstance(obj, EncodedVideo): + return id, obj + else: + return None + + def _integrate_data(self, sample: Any, id: Tuple[Any, ...], data: Dict[str, Any]) -> Any: + if not self._inline: + return sample, data + elif not id: + return data + + grand_parent = None + parent = sample + for item in id[:-1]: + grand_parent = parent + parent = parent[item] + + if not isinstance(parent, dict): + raise TypeError( + f"Could not inline the decoded video data, " + f"since the container at item {''.join(str([item]) for item in id[:-1])} " + f"that holds the `EncodedVideo` at item {[id[-1]]} is not a 'dict' but a '{type(parent).__name__}'. " + f"If you don't want to automatically inline the decoded video data, construct the decoder with " + f"{type(self).__name__}(..., inline=False). This will change the return type to a tuple of the input " + f"and the decoded video data for each iteration." 
+ ) + + parent = parent.copy() + del parent[id[-1]] + parent.update(data) + + if not grand_parent: + return parent + + grand_parent[id[-2]] = parent + return sample + + def __iter__(self) -> Iterator[Any]: + for sample in self.datapipe: + ids_and_videos = list(query_recursively(self._find_encoded_video, sample)) + if not ids_and_videos: + raise TypeError("no encoded video") + elif len(ids_and_videos) > 1: + raise ValueError("more than one encoded video") + id, video = ids_and_videos[0] + + for data in self._decode(ReadOnlyTensorBuffer(video), video.meta.copy()): + yield self._integrate_data(sample, id, data) + + +class KeyframeDecoder(_VideoDecoder): + def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + if get_video_backend() == "video_reader": + warnings.warn("Video reader API not implemented for keyframes yet, reverting to PyAV") + + with av.open(buffer, metadata_errors="ignore") as container: + stream = container.streams.video[0] + stream.codec_context.skip_frame = "NONKEY" + for frame in container.decode(stream): + yield dict( + frame=Image.from_pil(frame.to_image()), + pts=frame.pts, + video_meta=dict( + meta, + time_base=float(frame.time_base), + guessed_fps=float(stream.guessed_rate), + ), + ) + + +class RandomFrameDecoder(_VideoDecoder): + def __init__(self, datapipe: IterDataPipe, *, num_samples: int = 1, inline: bool = True) -> None: + super().__init__(datapipe, inline=inline) + self.num_sampler = num_samples + + def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + if get_video_backend() == "video_reader": + vid = VideoReader(buffer, device=self.device) + # seek and return frames + metadata = vid.get_metadata()["video"] + duration = metadata["duration"][0] if self.device == "cpu" else metadata["duration"] + fps = metadata["fps"][0] if self.device == "cpu" else metadata["fps"] + max_seek = duration - (self.clip_len / fps + 0.1) # FIXME: random param + seek_idxs = random.sample(list(range(max_seek)), self.num_samples) + for i in seek_idxs: + vid.seek(i) + frame = vid.next() + yield dict( + frame=frame['data'], + pts = frame['pts'], + video_meta=dict( + guessed_fps=fps, + ), + ) + else: + with av.open(buffer, metadata_errors="ignore") as container: + stream = container.streams.video[0] + # duration is given in time_base units as int + duration = stream.duration + # seek to a random frame + seek_idxs = random.sample(list(range(duration)), self.num_samples) + for i in seek_idxs: + container.seek(i, any_frame=True, stream=stream) + frame = next(container.decode(stream)) + yield dict( + frame=Image.from_pil(frame.to_image()), + pts=frame.pts, + video_meta=dict( + time_base=float(frame.time_base), + guessed_fps=float(stream.guessed_rate), + ), + ) + +class ClipDecoder(_VideoDecoder): + def __init__( + self, + datapipe: IterDataPipe, + *, + num_frames_per_clip: int = 8, + num_clips_per_video: int = 1, + step_between_clips: int = 1, + inline: bool = True, + ) -> None: + super().__init__(datapipe, inline=inline) + self.num_frames_per_clip = num_frames_per_clip + self.num_clips_per_video = num_clips_per_video + self.step_between_clips = step_between_clips + + def _unfold(self, tensor: torch.Tensor, dilation: int = 1) -> torch.Tensor: + """ + similar to tensor.unfold, but with the dilation + and specialized for 1d tensors + Returns all consecutive windows of `self.num_frames_per_clip` elements, with + `self.step_between_clips` between windows. 
The distance between each element + in a window is given by `dilation`. + """ + assert tensor.dim() == 1 + o_stride = tensor.stride(0) + numel = tensor.numel() + new_stride = (self.step_between_clips * o_stride, dilation * o_stride) + new_size = ( + (numel - (dilation * (self.num_frames_per_clip - 1) + 1)) // self.step_between_clips + 1, + self.num_frames_per_clip, + ) + if new_size[0] < 1: + new_size = (0, self.num_frames_per_clip) + return torch.as_strided(tensor, new_size, new_stride) + + def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + if get_video_backend() == "video_reader": + pass + else: + with av.open(buffer, metadata_errors="ignore") as container: + stream = container.streams.video[0] + time_base = stream.time_base + + # duration is given in time_base units as int + duration = stream.duration + + # get video_stream timestramps + # with a tolerance for pyav imprecission + _ptss = torch.arange(duration - 7) + _ptss = self._unfold(_ptss) + # shuffle the clips + perm = torch.randperm(_ptss.size(0)) + idx = perm[: self.num_clips_per_video] + samples = _ptss[idx] + + for clip_pts in samples: + start_pts = clip_pts[0].item() + end_pts = clip_pts[-1].item() + # video_timebase is the default time_base + pts_unit = "pts" + start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(start_pts, end_pts, "pts", time_base) + video_frames = video._read_from_stream( + container, + float(start_pts), + float(end_pts), + pts_unit, + stream, + {"video": 0}, + ) + + vframes_list = [frame.to_ndarray(format="rgb24") for frame in video_frames] + + if vframes_list: + vframes = torch.as_tensor(np.stack(vframes_list)) + # account for rounding errors in conversion + # FIXME: fix this in the code + vframes = vframes[: self.num_frames_per_clip, ...] + + else: + vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8) + print("FAIL") + + # [N,H,W,C] to [N,C,H,W] + vframes = vframes.permute(0, 3, 1, 2) + assert vframes.size(0) == self.num_frames_per_clip + + # TODO: support sampling rates (FPS change) + # TODO: optimization (read all and select) + + yield { + "clip": vframes, + "pts": clip_pts, + "range": (start_pts, end_pts), + "video_meta": { + "time_base": float(stream.time_base), + "guessed_fps": float(stream.guessed_rate), + }, + } diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py index 276aeec2529..d4bc8b550a2 100644 --- a/torchvision/prototype/features/_encoded.py +++ b/torchvision/prototype/features/_encoded.py @@ -1,6 +1,7 @@ import os +import pathlib import sys -from typing import BinaryIO, Tuple, Type, TypeVar, Union, Optional, Any +from typing import BinaryIO, Tuple, Type, TypeVar, Union, Dict, Any, Optional import PIL.Image import torch @@ -9,23 +10,33 @@ from ._feature import _Feature from ._image import Image -D = TypeVar("D", bound="EncodedData") +E = TypeVar("E", bound="EncodedData") class EncodedData(_Feature): - @classmethod - def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor: - # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8? 
- return super()._to_tensor(data, dtype=dtype, device=device) + meta: Dict[str, Any] + + def __new__( + cls: Type[E], + data: Any, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + **meta: Any, + ) -> E: + encoded_data = super().__new__(cls, data, dtype=dtype, device=device) + encoded_data._metadata.update(dict(meta=meta)) + return encoded_data @classmethod - def from_file(cls: Type[D], file: BinaryIO) -> D: - return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder)) + def from_file(cls: Type[E], file: BinaryIO, **meta: Any) -> E: + return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **meta) @classmethod - def from_path(cls: Type[D], path: Union[str, os.PathLike]) -> D: + def from_path(cls: Type[E], path: Union[str, os.PathLike], **meta: Any) -> E: + path = pathlib.Path(path) with open(path, "rb") as file: - return cls.from_file(file) + return cls.from_file(file, path=path, **meta) class EncodedImage(EncodedData): diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 5ecc4cbedb7..cbc8d542121 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -3,9 +3,10 @@ import warnings from typing import Any, Optional, Union, Tuple, cast +import PIL.Image import torch from torchvision.prototype.utils._internal import StrEnum -from torchvision.transforms.functional import to_pil_image +from torchvision.transforms.functional import to_pil_image, pil_to_tensor from torchvision.utils import draw_bounding_boxes from torchvision.utils import make_grid @@ -75,6 +76,12 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace: else: return ColorSpace.OTHER + @classmethod + def from_pil( + cls, image: PIL.Image.Image, *, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None + ) -> "Image": + return cls(pil_to_tensor(image), dtype=dtype, device=device) + def show(self) -> None: # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we # promote this out of the prototype state diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 24d794a2cb4..aaf32a66fae 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Union, Optional, Tuple import PIL.Image import torch @@ -7,7 +7,7 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: - def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]: + def fn(id: Tuple[Any, ...], input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]: if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image): return input diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 2e38471ea65..366a19f2bbc 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -312,15 +312,18 @@ def apply_recursively(fn: Callable, obj: Any) -> Any: return fn(obj) -def query_recursively(fn: Callable[[Any], Optional[D]], obj: Any) -> Iterator[D]: +def query_recursively( + fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] 
= () +) -> Iterator[D]: # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: # "a" == "a"[0][0]... - if (isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str)) or isinstance( - obj, collections.abc.Mapping - ): - for item in obj.values() if isinstance(obj, collections.abc.Mapping) else obj: - yield from query_recursively(fn, item) + if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str): + for idx, item in enumerate(obj): + yield from query_recursively(fn, item, id=(*id, idx)) + elif isinstance(obj, collections.abc.Mapping): + for key, item in obj.items(): + yield from query_recursively(fn, item, id=(*id, key)) else: - result = fn(obj) + result = fn(id, obj) if result is not None: yield result
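End-to-end, the new video datasets and decoder datapipes are intended to compose roughly as in the sketch below. The dataset name, the `split`/`fold` options, and the `RandomFrameDecoder(datapipe, num_samples=...)` signature are taken from this patch; `datasets.load` is assumed to be the existing prototype entry point, and the snippet additionally needs `rarfile` for extraction and PyAV for decoding, so treat it as an illustrative sketch rather than a tested example:

from torchvision.prototype import datasets
from torchvision.prototype.datasets.utils import RandomFrameDecoder

# Samples come out of the dataset as {"video": EncodedVideo, "label": Label}.
dp = datasets.load("hmdb51", split="train", fold="1")

# Decode one random frame per video; with the default inline=True the decoded
# data ("frame", "pts", "video_meta") replaces the "video" entry in each sample.
dp = RandomFrameDecoder(dp, num_samples=1)

for sample in dp:
    print(sample["label"], sample["frame"].shape)
    break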