Skip to content

Commit 9563e3e

Browse files
authored
Add allow_empty parameter to ImageFolder and related utils (#8311)
1 parent e00f4e6 commit 9563e3e

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

test/test_datasets.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1620,6 +1620,10 @@ def inject_fake_data(self, tmpdir, config):
16201620
num_examples_total += num_examples
16211621
classes.append(cls)
16221622

1623+
if config.pop("make_empty_class", False):
1624+
os.makedirs(pathlib.Path(tmpdir) / "empty_class")
1625+
classes.append("empty_class")
1626+
16231627
return dict(num_examples=num_examples_total, classes=classes)
16241628

16251629
def _file_name_fn(self, cls, ext, idx):
@@ -1644,6 +1648,23 @@ def test_classes(self, config):
16441648
assert len(dataset.classes) == len(info["classes"])
16451649
assert all([a == b for a, b in zip(dataset.classes, info["classes"])])
16461650

1651+
def test_allow_empty(self):
1652+
config = {
1653+
"extensions": self._EXTENSIONS,
1654+
"make_empty_class": True,
1655+
}
1656+
1657+
config["allow_empty"] = True
1658+
with self.create_dataset(config) as (dataset, info):
1659+
assert "empty_class" in dataset.classes
1660+
assert len(dataset.classes) == len(info["classes"])
1661+
assert all([a == b for a, b in zip(dataset.classes, info["classes"])])
1662+
1663+
config["allow_empty"] = False
1664+
with pytest.raises(FileNotFoundError, match="Found no valid file"):
1665+
with self.create_dataset(config) as (dataset, info):
1666+
pass
1667+
16471668

16481669
class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
16491670
DATASET_CLASS = datasets.ImageFolder

torchvision/datasets/folder.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def make_dataset(
5050
class_to_idx: Optional[Dict[str, int]] = None,
5151
extensions: Optional[Union[str, Tuple[str, ...]]] = None,
5252
is_valid_file: Optional[Callable[[str], bool]] = None,
53+
allow_empty: bool = False,
5354
) -> List[Tuple[str, int]]:
5455
"""Generates a list of samples of a form (path_to_sample, class).
5556
@@ -95,7 +96,7 @@ def is_valid_file(x: str) -> bool:
9596
available_classes.add(target_class)
9697

9798
empty_classes = set(class_to_idx.keys()) - available_classes
98-
if empty_classes:
99+
if empty_classes and not allow_empty:
99100
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
100101
if extensions is not None:
101102
msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
@@ -123,6 +124,8 @@ class DatasetFolder(VisionDataset):
123124
is_valid_file (callable, optional): A function that takes path of a file
124125
and checks if the file is a valid file (used to check for corrupt files)
125126
both extensions and is_valid_file should not be passed.
127+
allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
128+
An error is raised on empty folders if False (default).
126129
127130
Attributes:
128131
classes (list): List of the class names sorted alphabetically.
@@ -139,10 +142,17 @@ def __init__(
139142
transform: Optional[Callable] = None,
140143
target_transform: Optional[Callable] = None,
141144
is_valid_file: Optional[Callable[[str], bool]] = None,
145+
allow_empty: bool = False,
142146
) -> None:
143147
super().__init__(root, transform=transform, target_transform=target_transform)
144148
classes, class_to_idx = self.find_classes(self.root)
145-
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
149+
samples = self.make_dataset(
150+
self.root,
151+
class_to_idx=class_to_idx,
152+
extensions=extensions,
153+
is_valid_file=is_valid_file,
154+
allow_empty=allow_empty,
155+
)
146156

147157
self.loader = loader
148158
self.extensions = extensions
@@ -158,6 +168,7 @@ def make_dataset(
158168
class_to_idx: Dict[str, int],
159169
extensions: Optional[Tuple[str, ...]] = None,
160170
is_valid_file: Optional[Callable[[str], bool]] = None,
171+
allow_empty: bool = False,
161172
) -> List[Tuple[str, int]]:
162173
"""Generates a list of samples of a form (path_to_sample, class).
163174
@@ -172,6 +183,8 @@ def make_dataset(
172183
and checks if the file is a valid file
173184
(used to check for corrupt files) both extensions and
174185
is_valid_file should not be passed. Defaults to None.
186+
allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
187+
An error is raised on empty folders if False (default).
175188
176189
Raises:
177190
ValueError: In case ``class_to_idx`` is empty.
@@ -186,7 +199,9 @@ def make_dataset(
186199
# find_classes() function, instead of using that of the find_classes() method, which
187200
# is potentially overridden and thus could have a different logic.
188201
raise ValueError("The class_to_idx parameter cannot be None.")
189-
return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
202+
return make_dataset(
203+
directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file, allow_empty=allow_empty
204+
)
190205

191206
def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
192207
"""Find the class folders in a dataset structured as follows::
@@ -291,6 +306,8 @@ class ImageFolder(DatasetFolder):
291306
loader (callable, optional): A function to load an image given its path.
292307
is_valid_file (callable, optional): A function that takes path of an Image file
293308
and checks if the file is a valid file (used to check for corrupt files)
309+
allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
310+
An error is raised on empty folders if False (default).
294311
295312
Attributes:
296313
classes (list): List of the class names sorted alphabetically.
@@ -305,6 +322,7 @@ def __init__(
305322
target_transform: Optional[Callable] = None,
306323
loader: Callable[[str], Any] = default_loader,
307324
is_valid_file: Optional[Callable[[str], bool]] = None,
325+
allow_empty: bool = False,
308326
):
309327
super().__init__(
310328
root,
@@ -313,5 +331,6 @@ def __init__(
313331
transform=transform,
314332
target_transform=target_transform,
315333
is_valid_file=is_valid_file,
334+
allow_empty=allow_empty,
316335
)
317336
self.imgs = self.samples

0 commit comments

Comments
 (0)