Skip to content

Commit 0818c68

Browse files
authored
Improve error handling in make_dataset (#3496)
* factor out find_classes * use find_classes in video datasets * adapt old tests
1 parent 19ad0bb commit 0818c68

File tree

5 files changed

+73
-42
lines changed

5 files changed

+73
-42
lines changed

test/test_datasets.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,10 @@ def test_imagefolder(self):
111111

112112
def test_imagefolder_empty(self):
113113
with get_tmp_dir() as root:
114-
with self.assertRaises(RuntimeError):
114+
with self.assertRaises(FileNotFoundError):
115115
torchvision.datasets.ImageFolder(root, loader=lambda x: x)
116116

117-
with self.assertRaises(RuntimeError):
117+
with self.assertRaises(FileNotFoundError):
118118
torchvision.datasets.ImageFolder(
119119
root, loader=lambda x: x, is_valid_file=lambda x: False
120120
)
@@ -1092,9 +1092,6 @@ def inject_fake_data(self, tmpdir, config):
10921092

10931093
return num_videos_per_class * len(classes)
10941094

1095-
def test_not_found_or_corrupted(self):
1096-
self.skipTest("Dataset currently does not handle the case of no found videos.")
1097-
10981095

10991096
class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
11001097
DATASET_CLASS = datasets.HMDB51

torchvision/datasets/folder.py

Lines changed: 65 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,52 @@ def is_image_file(filename: str) -> bool:
3232
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
3333

3434

35+
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
36+
"""Finds the class folders in a dataset structured as follows:
37+
38+
.. code::
39+
40+
directory/
41+
├── class_x
42+
│ ├── xxx.ext
43+
│ ├── xxy.ext
44+
│ └── ...
45+
│ └── xxz.ext
46+
└── class_y
47+
├── 123.ext
48+
├── nsdf3.ext
49+
└── ...
50+
└── asd932_.ext
51+
52+
Args:
53+
directory (str): Root directory path.
54+
55+
Raises:
56+
FileNotFoundError: If ``directory`` has no class folders.
57+
58+
Returns:
59+
(Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
60+
"""
61+
classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
62+
if not classes:
63+
raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
64+
65+
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
66+
return classes, class_to_idx
67+
68+
3569
def make_dataset(
3670
directory: str,
37-
class_to_idx: Dict[str, int],
71+
class_to_idx: Optional[Dict[str, int]] = None,
3872
extensions: Optional[Tuple[str, ...]] = None,
3973
is_valid_file: Optional[Callable[[str], bool]] = None,
4074
) -> List[Tuple[str, int]]:
4175
"""Generates a list of samples of a form (path_to_sample, class).
4276
4377
Args:
4478
directory (str): root dataset directory
45-
class_to_idx (Dict[str, int]): dictionary mapping class name to class index
79+
class_to_idx (Optional[Dict[str, int]]): Dictionary mapping class name to class index. If omitted, is generated
80+
by :func:`find_classes`.
4681
extensions (optional): A list of allowed extensions.
4782
Either extensions or is_valid_file should be passed. Defaults to None.
4883
is_valid_file (optional): A function that takes path of a file
@@ -51,21 +86,34 @@ def make_dataset(
5186
is_valid_file should not be passed. Defaults to None.
5287
5388
Raises:
89+
ValueError: In case ``class_to_idx`` is empty.
5490
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
91+
FileNotFoundError: In case no valid file was found for any class.
5592
5693
Returns:
5794
List[Tuple[str, int]]: samples of a form (path_to_sample, class)
5895
"""
59-
instances = []
6096
directory = os.path.expanduser(directory)
97+
98+
if class_to_idx is None:
99+
_, class_to_idx = find_classes(directory)
100+
elif not class_to_idx:
101+
raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
102+
61103
both_none = extensions is None and is_valid_file is None
62104
both_something = extensions is not None and is_valid_file is not None
63105
if both_none or both_something:
64106
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
107+
65108
if extensions is not None:
109+
66110
def is_valid_file(x: str) -> bool:
67111
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
112+
68113
is_valid_file = cast(Callable[[str], bool], is_valid_file)
114+
115+
instances = []
116+
available_classes = set()
69117
for target_class in sorted(class_to_idx.keys()):
70118
class_index = class_to_idx[target_class]
71119
target_dir = os.path.join(directory, target_class)
@@ -77,6 +125,17 @@ def is_valid_file(x: str) -> bool:
77125
if is_valid_file(path):
78126
item = path, class_index
79127
instances.append(item)
128+
129+
if target_class not in available_classes:
130+
available_classes.add(target_class)
131+
132+
empty_classes = available_classes - set(class_to_idx.keys())
133+
if empty_classes:
134+
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
135+
if extensions is not None:
136+
msg += f"Supported extensions are: {', '.join(extensions)}"
137+
raise FileNotFoundError(msg)
138+
80139
return instances
81140

82141

@@ -125,11 +184,6 @@ def __init__(
125184
target_transform=target_transform)
126185
classes, class_to_idx = self._find_classes(self.root)
127186
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
128-
if len(samples) == 0:
129-
msg = "Found 0 files in subfolders of: {}\n".format(self.root)
130-
if extensions is not None:
131-
msg += "Supported extensions are: {}".format(",".join(extensions))
132-
raise RuntimeError(msg)
133187

134188
self.loader = loader
135189
self.extensions = extensions
@@ -148,23 +202,9 @@ def make_dataset(
148202
) -> List[Tuple[str, int]]:
149203
return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
150204

151-
def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
152-
"""
153-
Finds the class folders in a dataset.
154-
155-
Args:
156-
dir (string): Root directory path.
157-
158-
Returns:
159-
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
160-
161-
Ensures:
162-
No class is a subdirectory of another.
163-
"""
164-
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
165-
classes.sort()
166-
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
167-
return classes, class_to_idx
205+
@staticmethod
206+
def _find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]:
207+
return find_classes(dir)
168208

169209
def __getitem__(self, index: int) -> Tuple[Any, Any]:
170210
"""

torchvision/datasets/hmdb51.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33

44
from .utils import list_dir
5-
from .folder import make_dataset
5+
from .folder import find_classes, make_dataset
66
from .video_utils import VideoClips
77
from .vision import VisionDataset
88

@@ -62,8 +62,7 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
6262
raise ValueError("fold should be between 1 and 3, got {}".format(fold))
6363

6464
extensions = ('avi',)
65-
classes = sorted(list_dir(root))
66-
class_to_idx = {class_: i for (i, class_) in enumerate(classes)}
65+
self.classes, class_to_idx = find_classes(self.root)
6766
self.samples = make_dataset(
6867
self.root,
6968
class_to_idx,
@@ -89,7 +88,6 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
8988
self.full_video_clips = video_clips
9089
self.fold = fold
9190
self.train = train
92-
self.classes = classes
9391
self.indices = self._select_fold(video_paths, annotation_path, fold, train)
9492
self.video_clips = video_clips.subset(self.indices)
9593
self.transform = transform

torchvision/datasets/kinetics.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .utils import list_dir
2-
from .folder import make_dataset
2+
from .folder import find_classes, make_dataset
33
from .video_utils import VideoClips
44
from .vision import VisionDataset
55

@@ -56,10 +56,8 @@ def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
5656
_video_min_dimension=0, _audio_samples=0, _audio_channels=0):
5757
super(Kinetics400, self).__init__(root)
5858

59-
classes = list(sorted(list_dir(root)))
60-
class_to_idx = {classes[i]: i for i in range(len(classes))}
59+
self.classes, class_to_idx = find_classes(self.root)
6160
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
62-
self.classes = classes
6361
video_list = [x[0] for x in self.samples]
6462
self.video_clips = VideoClips(
6563
video_list,

torchvision/datasets/ucf101.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22

33
from .utils import list_dir
4-
from .folder import make_dataset
4+
from .folder import find_classes, make_dataset
55
from .video_utils import VideoClips
66
from .vision import VisionDataset
77

@@ -55,10 +55,8 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
5555
self.fold = fold
5656
self.train = train
5757

58-
classes = list(sorted(list_dir(root)))
59-
class_to_idx = {classes[i]: i for i in range(len(classes))}
58+
self.classes, class_to_idx = find_classes(self.root)
6059
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
61-
self.classes = classes
6260
video_list = [x[0] for x in self.samples]
6361
video_clips = VideoClips(
6462
video_list,

0 commit comments

Comments
 (0)