@@ -32,17 +32,52 @@ def is_image_file(filename: str) -> bool:
32
32
return has_file_allowed_extension (filename , IMG_EXTENSIONS )
33
33
34
34
35
+ def find_classes (directory : str ) -> Tuple [List [str ], Dict [str , int ]]:
36
+ """Finds the class folders in a dataset structured as follows:
37
+
38
+ .. code::
39
+
40
+ directory/
41
+ ├── class_x
42
+ │ ├── xxx.ext
43
+ │ ├── xxy.ext
44
+ │ └── ...
45
+ │ └── xxz.ext
46
+ └── class_y
47
+ ├── 123.ext
48
+ ├── nsdf3.ext
49
+ └── ...
50
+ └── asd932_.ext
51
+
52
+ Args:
53
+ directory (str): Root directory path.
54
+
55
+ Raises:
56
+ FileNotFoundError: If ``directory`` has no class folders.
57
+
58
+ Returns:
59
+ (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
60
+ """
61
+ classes = sorted (entry .name for entry in os .scandir (directory ) if entry .is_dir ())
62
+ if not classes :
63
+ raise FileNotFoundError (f"Couldn't find any class folder in { directory } ." )
64
+
65
+ class_to_idx = {cls_name : i for i , cls_name in enumerate (classes )}
66
+ return classes , class_to_idx
67
+
68
+
35
69
def make_dataset (
36
70
directory : str ,
37
- class_to_idx : Dict [str , int ],
71
+ class_to_idx : Optional [ Dict [str , int ]] = None ,
38
72
extensions : Optional [Tuple [str , ...]] = None ,
39
73
is_valid_file : Optional [Callable [[str ], bool ]] = None ,
40
74
) -> List [Tuple [str , int ]]:
41
75
"""Generates a list of samples of a form (path_to_sample, class).
42
76
43
77
Args:
44
78
directory (str): root dataset directory
45
- class_to_idx (Dict[str, int]): dictionary mapping class name to class index
79
+ class_to_idx (Optional[Dict[str, int]]): Dictionary mapping class name to class index. If omitted, is generated
80
+ by :func:`find_classes`.
46
81
extensions (optional): A list of allowed extensions.
47
82
Either extensions or is_valid_file should be passed. Defaults to None.
48
83
is_valid_file (optional): A function that takes path of a file
@@ -51,21 +86,34 @@ def make_dataset(
51
86
is_valid_file should not be passed. Defaults to None.
52
87
53
88
Raises:
89
+ ValueError: In case ``class_to_idx`` is empty.
54
90
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
91
+ FileNotFoundError: In case no valid file was found for any class.
55
92
56
93
Returns:
57
94
List[Tuple[str, int]]: samples of a form (path_to_sample, class)
58
95
"""
59
- instances = []
60
96
directory = os .path .expanduser (directory )
97
+
98
+ if class_to_idx is None :
99
+ _ , class_to_idx = find_classes (directory )
100
+ elif not class_to_idx :
101
+ raise ValueError ("'class_to_index' must have at least one entry to collect any samples." )
102
+
61
103
both_none = extensions is None and is_valid_file is None
62
104
both_something = extensions is not None and is_valid_file is not None
63
105
if both_none or both_something :
64
106
raise ValueError ("Both extensions and is_valid_file cannot be None or not None at the same time" )
107
+
65
108
if extensions is not None :
109
+
66
110
def is_valid_file (x : str ) -> bool :
67
111
return has_file_allowed_extension (x , cast (Tuple [str , ...], extensions ))
112
+
68
113
is_valid_file = cast (Callable [[str ], bool ], is_valid_file )
114
+
115
+ instances = []
116
+ available_classes = set ()
69
117
for target_class in sorted (class_to_idx .keys ()):
70
118
class_index = class_to_idx [target_class ]
71
119
target_dir = os .path .join (directory , target_class )
@@ -77,6 +125,17 @@ def is_valid_file(x: str) -> bool:
77
125
if is_valid_file (path ):
78
126
item = path , class_index
79
127
instances .append (item )
128
+
129
+ if target_class not in available_classes :
130
+ available_classes .add (target_class )
131
+
132
+ empty_classes = available_classes - set (class_to_idx .keys ())
133
+ if empty_classes :
134
+ msg = f"Found no valid file for the classes { ', ' .join (sorted (empty_classes ))} . "
135
+ if extensions is not None :
136
+ msg += f"Supported extensions are: { ', ' .join (extensions )} "
137
+ raise FileNotFoundError (msg )
138
+
80
139
return instances
81
140
82
141
@@ -125,11 +184,6 @@ def __init__(
125
184
target_transform = target_transform )
126
185
classes , class_to_idx = self ._find_classes (self .root )
127
186
samples = self .make_dataset (self .root , class_to_idx , extensions , is_valid_file )
128
- if len (samples ) == 0 :
129
- msg = "Found 0 files in subfolders of: {}\n " .format (self .root )
130
- if extensions is not None :
131
- msg += "Supported extensions are: {}" .format ("," .join (extensions ))
132
- raise RuntimeError (msg )
133
187
134
188
self .loader = loader
135
189
self .extensions = extensions
@@ -148,23 +202,9 @@ def make_dataset(
148
202
) -> List [Tuple [str , int ]]:
149
203
return make_dataset (directory , class_to_idx , extensions = extensions , is_valid_file = is_valid_file )
150
204
151
- def _find_classes (self , dir : str ) -> Tuple [List [str ], Dict [str , int ]]:
152
- """
153
- Finds the class folders in a dataset.
154
-
155
- Args:
156
- dir (string): Root directory path.
157
-
158
- Returns:
159
- tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
160
-
161
- Ensures:
162
- No class is a subdirectory of another.
163
- """
164
- classes = [d .name for d in os .scandir (dir ) if d .is_dir ()]
165
- classes .sort ()
166
- class_to_idx = {cls_name : i for i , cls_name in enumerate (classes )}
167
- return classes , class_to_idx
205
+ @staticmethod
206
+ def _find_classes (dir : str ) -> Tuple [List [str ], Dict [str , int ]]:
207
+ return find_classes (dir )
168
208
169
209
def __getitem__ (self , index : int ) -> Tuple [Any , Any ]:
170
210
"""
0 commit comments