From f35c0efce5468493973e3d141cdfa8c04b6b66c8 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 28 Feb 2025 22:32:40 +0800 Subject: [PATCH 1/2] feat: add loader to `Omniglot` and `INaturalist`. --- test/test_datasets.py | 2 ++ torchvision/datasets/inaturalist.py | 8 +++++++- torchvision/datasets/omniglot.py | 7 ++++++- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 1413d2c312d..02fce47f7e3 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1146,6 +1146,7 @@ class OmniglotTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.Omniglot ADDITIONAL_CONFIGS = combinations_grid(background=(True, False)) + SUPPORT_TV_IMAGE_DECODE = True def inject_fake_data(self, tmpdir, config): target_folder = ( @@ -1910,6 +1911,7 @@ class INaturalistTestCase(datasets_utils.ImageDatasetTestCase): target_type=("kingdom", "full", "genus", ["kingdom", "phylum", "class", "order", "family", "genus", "full"]), version=("2021_train",), ) + SUPPORT_TV_IMAGE_DECODE = True def inject_fake_data(self, tmpdir, config): categories = [ diff --git a/torchvision/datasets/inaturalist.py b/torchvision/datasets/inaturalist.py index e041d41f4a2..8713bc041db 100644 --- a/torchvision/datasets/inaturalist.py +++ b/torchvision/datasets/inaturalist.py @@ -62,6 +62,9 @@ class INaturalist(VisionDataset): download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ def __init__( @@ -72,6 +75,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Optional[Callable[[Union[str, Path]], Any]] = None, ) -> None: self.version = verify_str_arg(version, "version", DATASET_URLS.keys()) @@ -109,6 +113,8 @@ def __init__( for fname in files: self.index.append((dir_index, fname)) + self.loader = loader or Image.open + def _init_2021(self) -> None: """Initialize based on 2021 layout""" @@ -178,7 +184,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: """ cat_id, fname = self.index[index] - img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname)) + img = self.loader(os.path.join(self.root, self.all_categories[cat_id], fname)) target: Any = [] for t in self.target_type: diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index c3434a72456..f8d182cdb25 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -23,6 +23,9 @@ class Omniglot(VisionDataset): download (bool, optional): If true, downloads the dataset zip files from the internet and puts it in root directory. If the zip files are already downloaded, they are not downloaded again. + loader (callable, optional): A function to load an image given its path. + By default, it uses PIL as its image loader, but users could also pass in + ``torchvision.io.decode_image`` for decoding image data into tensors directly. """ folder = "omniglot-py" @@ -39,6 +42,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, + loader: Optional[Callable[[Union[str, Path]], Any]] = None, ) -> None: super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform) self.background = background @@ -59,6 +63,7 @@ def __init__( for idx, character in enumerate(self._characters) ] self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, []) + self.loader = loader def __len__(self) -> int: return len(self._flat_character_images) @@ -73,7 +78,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: """ image_name, character_class = self._flat_character_images[index] image_path = join(self.target_folder, self._characters[character_class], image_name) - image = Image.open(image_path, mode="r").convert("L") + image = Image.open(image_path, mode="r").convert("L") if self.loader is None else self.loader(image_path) if self.transform: image = self.transform(image) From 0e8a85eeb5919df0f69c7af80380d62a52b07002 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 26 Mar 2025 15:52:11 +0000 Subject: [PATCH 2/2] skip on windows --- test/test_datasets.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_datasets.py b/test/test_datasets.py index 84f243388aa..feaabd7acd2 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -11,6 +11,7 @@ import re import shutil import string +import sys import unittest import xml.etree.ElementTree as ET import zipfile @@ -1903,6 +1904,7 @@ def test_class_to_idx(self): assert dataset.class_to_idx == class_to_idx +@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows") class INaturalistTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.INaturalist FEATURE_TYPES = (PIL.Image.Image, (int, tuple))