Skip to content

feat: add loader to Omniglot and INaturalist's argument. #8945

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import re
import shutil
import string
import sys
import unittest
import xml.etree.ElementTree as ET
import zipfile
Expand Down Expand Up @@ -1146,6 +1147,7 @@ class OmniglotTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Omniglot

ADDITIONAL_CONFIGS = combinations_grid(background=(True, False))
SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir, config):
target_folder = (
Expand Down Expand Up @@ -1902,6 +1904,7 @@ def test_class_to_idx(self):
assert dataset.class_to_idx == class_to_idx


@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.INaturalist
FEATURE_TYPES = (PIL.Image.Image, (int, tuple))
Expand All @@ -1910,6 +1913,7 @@ class INaturalistTestCase(datasets_utils.ImageDatasetTestCase):
target_type=("kingdom", "full", "genus", ["kingdom", "phylum", "class", "order", "family", "genus", "full"]),
version=("2021_train",),
)
SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir, config):
categories = [
Expand Down
8 changes: 7 additions & 1 deletion torchvision/datasets/inaturalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ class INaturalist(VisionDataset):
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
loader (callable, optional): A function to load an image given its path.
By default, it uses PIL as its image loader, but users could also pass in
``torchvision.io.decode_image`` for decoding image data into tensors directly.
"""

def __init__(
Expand All @@ -72,6 +75,7 @@ def __init__(
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
loader: Optional[Callable[[Union[str, Path]], Any]] = None,
) -> None:
self.version = verify_str_arg(version, "version", DATASET_URLS.keys())

Expand Down Expand Up @@ -109,6 +113,8 @@ def __init__(
for fname in files:
self.index.append((dir_index, fname))

self.loader = loader or Image.open

def _init_2021(self) -> None:
"""Initialize based on 2021 layout"""

Expand Down Expand Up @@ -178,7 +184,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""

cat_id, fname = self.index[index]
img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))
img = self.loader(os.path.join(self.root, self.all_categories[cat_id], fname))

target: Any = []
for t in self.target_type:
Expand Down
7 changes: 6 additions & 1 deletion torchvision/datasets/omniglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class Omniglot(VisionDataset):
download (bool, optional): If true, downloads the dataset zip files from the internet and
puts it in root directory. If the zip files are already downloaded, they are not
downloaded again.
loader (callable, optional): A function to load an image given its path.
By default, it uses PIL as its image loader, but users could also pass in
``torchvision.io.decode_image`` for decoding image data into tensors directly.
"""

folder = "omniglot-py"
Expand All @@ -39,6 +42,7 @@ def __init__(
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
loader: Optional[Callable[[Union[str, Path]], Any]] = None,
) -> None:
super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
self.background = background
Expand All @@ -59,6 +63,7 @@ def __init__(
for idx, character in enumerate(self._characters)
]
self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, [])
self.loader = loader

def __len__(self) -> int:
return len(self._flat_character_images)
Expand All @@ -73,7 +78,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
image_name, character_class = self._flat_character_images[index]
image_path = join(self.target_folder, self._characters[character_class], image_name)
image = Image.open(image_path, mode="r").convert("L")
image = Image.open(image_path, mode="r").convert("L") if self.loader is None else self.loader(image_path)

if self.transform:
image = self.transform(image)
Expand Down