diff --git a/test/test_datasets.py b/test/test_datasets.py index 407a6bfb338..d1efc385a5f 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1092,5 +1092,35 @@ def _create_split_files(self, root, video_files, fold, train): return num_train_videos if train else (num_videos - num_train_videos) +class OmniglotTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Omniglot + + CONFIGS = datasets_utils.combinations_grid(background=(True, False)) + + def inject_fake_data(self, tmpdir, config): + target_folder = ( + pathlib.Path(tmpdir) / "omniglot-py" / f"images_{'background' if config['background'] else 'evaluation'}" + ) + os.makedirs(target_folder) + + num_images = 0 + for name in ("Alphabet_of_the_Magi", "Tifinagh"): + num_images += self._create_alphabet_folder(target_folder, name) + + return num_images + + def _create_alphabet_folder(self, root, name): + num_images_total = 0 + for idx in range(torch.randint(1, 4, size=()).item()): + num_images = torch.randint(1, 4, size=()).item() + num_images_total += num_images + + datasets_utils.create_image_folder( + root / name, f"character{idx:02d}", lambda image_idx: f"{image_idx:02d}.png", num_images + ) + + return num_images_total + + if __name__ == "__main__": unittest.main()