diff --git a/test/test_datasets.py b/test/test_datasets.py index 60bc7613e53..407a6bfb338 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1032,5 +1032,65 @@ def test_not_found_or_corrupted(self): self.skipTest("Dataset currently does not handle the case of no found videos.") +class HMDB51TestCase(datasets_utils.VideoDatasetTestCase): + DATASET_CLASS = datasets.HMDB51 + + CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False)) + + _VIDEO_FOLDER = "videos" + _SPLITS_FOLDER = "splits" + _CLASSES = ("brush_hair", "wave") + + def dataset_args(self, tmpdir, config): + tmpdir = pathlib.Path(tmpdir) + root = tmpdir / self._VIDEO_FOLDER + annotation_path = tmpdir / self._SPLITS_FOLDER + return root, annotation_path + + def inject_fake_data(self, tmpdir, config): + tmpdir = pathlib.Path(tmpdir) + + video_folder = tmpdir / self._VIDEO_FOLDER + os.makedirs(video_folder) + video_files = self._create_videos(video_folder) + + splits_folder = tmpdir / self._SPLITS_FOLDER + os.makedirs(splits_folder) + num_examples = self._create_split_files(splits_folder, video_files, config["fold"], config["train"]) + + return num_examples + + def _create_videos(self, root, num_examples_per_class=3): + def file_name_fn(cls, idx, clips_per_group=2): + return f"{cls}_{(idx // clips_per_group) + 1:d}_{(idx % clips_per_group) + 1:d}.avi" + + return [ + ( + cls, + datasets_utils.create_video_folder( + root, + cls, + lambda idx: file_name_fn(cls, idx), + num_examples_per_class, + ), + ) + for cls in self._CLASSES + ] + + def _create_split_files(self, root, video_files, fold, train): + num_videos = num_train_videos = 0 + + for cls, videos in video_files: + num_videos += len(videos) + + train_videos = set(random.sample(videos, random.randrange(1, len(videos) - 1))) + num_train_videos += len(train_videos) + + with open(pathlib.Path(root) / f"{cls}_test_split{fold}.txt", "w") as fh: + fh.writelines(f"{file.name} {1 if file in train_videos else 2}\n" for file in videos) + + return num_train_videos if train else (num_videos - num_train_videos) + + if __name__ == "__main__": unittest.main()