-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Improve error handling in make_dataset #3496
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6a16b81
dbb9d9a
bacf2a6
87f151a
9fa4a7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,17 +32,52 @@ def is_image_file(filename: str) -> bool: | |
return has_file_allowed_extension(filename, IMG_EXTENSIONS) | ||
|
||
|
||
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset structured as follows:

    .. code::

        directory/
        ├── class_x
        │   ├── xxx.ext
        │   ├── xxy.ext
        │   └── ...
        │       └── xxz.ext
        └── class_y
            ├── 123.ext
            ├── nsdf3.ext
            └── ...
            └── asd932_.ext

    Only the immediate sub-directories of ``directory`` are treated as classes;
    plain files at the top level are ignored. Class indices follow the sorted
    order of the folder names.

    Args:
        directory (str): Root directory path.

    Raises:
        FileNotFoundError: If ``directory`` has no class folders.

    Returns:
        (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
    """
    # Use the scandir iterator as a context manager so the underlying directory
    # handle is released deterministically, even if inspecting an entry raises.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
|
||
|
||
def make_dataset( | ||
directory: str, | ||
class_to_idx: Dict[str, int], | ||
class_to_idx: Optional[Dict[str, int]] = None, | ||
extensions: Optional[Tuple[str, ...]] = None, | ||
is_valid_file: Optional[Callable[[str], bool]] = None, | ||
) -> List[Tuple[str, int]]: | ||
"""Generates a list of samples of a form (path_to_sample, class). | ||
|
||
Args: | ||
directory (str): root dataset directory | ||
class_to_idx (Dict[str, int]): dictionary mapping class name to class index | ||
class_to_idx (Optional[Dict[str, int]]): Dictionary mapping class name to class index. If omitted, is generated | ||
by :func:`find_classes`. | ||
extensions (optional): A list of allowed extensions. | ||
Either extensions or is_valid_file should be passed. Defaults to None. | ||
is_valid_file (optional): A function that takes path of a file | ||
|
@@ -51,21 +86,34 @@ def make_dataset( | |
is_valid_file should not be passed. Defaults to None. | ||
|
||
Raises: | ||
ValueError: In case ``class_to_idx`` is empty. | ||
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. | ||
FileNotFoundError: In case no valid file was found for any class. | ||
|
||
Returns: | ||
List[Tuple[str, int]]: samples of a form (path_to_sample, class) | ||
""" | ||
instances = [] | ||
directory = os.path.expanduser(directory) | ||
|
||
if class_to_idx is None: | ||
_, class_to_idx = find_classes(directory) | ||
elif not class_to_idx: | ||
raise ValueError("'class_to_index' must have at least one entry to collect any samples.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is BC breaking, but in a "right" way. Before There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this makes sense, especially for the folder dataset. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Nothing built-in. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good to me then There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @pmeier
The problem is that users can override Should we make There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To better explain myself: imagine a scenario where someone has a If they call Does that make sense? To avoid such issues I'm suggesting to force the user to pass |
||
|
||
both_none = extensions is None and is_valid_file is None | ||
both_something = extensions is not None and is_valid_file is not None | ||
if both_none or both_something: | ||
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") | ||
|
||
if extensions is not None: | ||
|
||
def is_valid_file(x: str) -> bool: | ||
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) | ||
|
||
is_valid_file = cast(Callable[[str], bool], is_valid_file) | ||
|
||
instances = [] | ||
available_classes = set() | ||
for target_class in sorted(class_to_idx.keys()): | ||
class_index = class_to_idx[target_class] | ||
target_dir = os.path.join(directory, target_class) | ||
|
@@ -77,6 +125,17 @@ def is_valid_file(x: str) -> bool: | |
if is_valid_file(path): | ||
item = path, class_index | ||
instances.append(item) | ||
|
||
if target_class not in available_classes: | ||
available_classes.add(target_class) | ||
|
||
# Classes that were requested in ``class_to_idx`` but for which no valid file
# was collected. The difference must be "requested minus populated":
# ``available_classes`` only ever contains keys of ``class_to_idx``, so the
# reverse difference is always empty and the error below could never fire.
empty_classes = set(class_to_idx.keys()) - available_classes
if empty_classes:
    msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
    if extensions is not None:
        msg += f"Supported extensions are: {', '.join(extensions)}"
    raise FileNotFoundError(msg)
|
||
return instances | ||
|
||
|
||
|
@@ -125,11 +184,6 @@ def __init__( | |
target_transform=target_transform) | ||
classes, class_to_idx = self._find_classes(self.root) | ||
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file) | ||
if len(samples) == 0: | ||
msg = "Found 0 files in subfolders of: {}\n".format(self.root) | ||
if extensions is not None: | ||
msg += "Supported extensions are: {}".format(",".join(extensions)) | ||
raise RuntimeError(msg) | ||
|
||
self.loader = loader | ||
self.extensions = extensions | ||
|
@@ -148,23 +202,9 @@ def make_dataset( | |
) -> List[Tuple[str, int]]: | ||
return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) | ||
|
||
def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: | ||
""" | ||
Finds the class folders in a dataset. | ||
|
||
Args: | ||
dir (string): Root directory path. | ||
|
||
Returns: | ||
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. | ||
|
||
Ensures: | ||
No class is a subdirectory of another. | ||
""" | ||
classes = [d.name for d in os.scandir(dir) if d.is_dir()] | ||
classes.sort() | ||
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} | ||
return classes, class_to_idx | ||
@staticmethod | ||
def _find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: | ||
Comment on lines
+205
to
+206
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: I wouldn't have made this a `staticmethod`. But this is not really important, so let's move forward with this PR and get this merged. |
||
return find_classes(dir) | ||
|
||
def __getitem__(self, index: int) -> Tuple[Any, Any]: | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is BC-breaking if someone relies on `DatasetFolder` or `ImageFolder` to raise a `RuntimeError` if no samples are found. Since the error type was never documented anywhere, I think this is fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd say this is fine but might be worth checking if it breaks something in FBCODE.