Skip to content

Improve error handling in make_dataset #3496

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def test_imagefolder(self):

def test_imagefolder_empty(self):
with get_tmp_dir() as root:
with self.assertRaises(RuntimeError):
with self.assertRaises(FileNotFoundError):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This BC breaking if someone relies on DatasetFolder or ImageFolder to raise a RuntimeError if no samples are found. Since the error type was never documented anywhere, I think this is fine.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say this is fine but might be worth checking if it breaks something in FBCODE.

torchvision.datasets.ImageFolder(root, loader=lambda x: x)

with self.assertRaises(RuntimeError):
with self.assertRaises(FileNotFoundError):
torchvision.datasets.ImageFolder(
root, loader=lambda x: x, is_valid_file=lambda x: False
)
Expand Down Expand Up @@ -1092,9 +1092,6 @@ def inject_fake_data(self, tmpdir, config):

return num_videos_per_class * len(classes)

def test_not_found_or_corrupted(self):
self.skipTest("Dataset currently does not handle the case of no found videos.")


class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.HMDB51
Expand Down
90 changes: 65 additions & 25 deletions torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,52 @@ def is_image_file(filename: str) -> bool:
return has_file_allowed_extension(filename, IMG_EXTENSIONS)


def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folders in a dataset structured as follows:

.. code::

directory/
├── class_x
│ ├── xxx.ext
│ ├── xxy.ext
│ └── ...
│ └── xxz.ext
└── class_y
├── 123.ext
├── nsdf3.ext
└── ...
└── asd932_.ext

Args:
directory (str): Root directory path.

Raises:
FileNotFoundError: If ``directory`` has no class folders.

Returns:
(Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
"""
classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
if not classes:
raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx


def make_dataset(
directory: str,
class_to_idx: Dict[str, int],
class_to_idx: Optional[Dict[str, int]] = None,
extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).

Args:
directory (str): root dataset directory
class_to_idx (Dict[str, int]): dictionary mapping class name to class index
class_to_idx (Optional[Dict[str, int]]): Dictionary mapping class name to class index. If omitted, is generated
by :func:`find_classes`.
extensions (optional): A list of allowed extensions.
Either extensions or is_valid_file should be passed. Defaults to None.
is_valid_file (optional): A function that takes path of a file
Expand All @@ -51,21 +86,34 @@ def make_dataset(
is_valid_file should not be passed. Defaults to None.

Raises:
ValueError: In case ``class_to_idx`` is empty.
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
FileNotFoundError: In case no valid file was found for any class.

Returns:
List[Tuple[str, int]]: samples of a form (path_to_sample, class)
"""
instances = []
directory = os.path.expanduser(directory)

if class_to_idx is None:
_, class_to_idx = find_classes(directory)
elif not class_to_idx:
raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is BC breaking, but in a "right" way. Before make_dataset silently returned an empty list in case class_to_idx was empty. IMO it is a reasonable assumption that no user calls make_dataset without having at least a single class. If you disagree with this assumption, we can get BC back by return [] here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this makes sense, especially for the folder dataset.
Is there any other dataset that inherits this other than videodatasets?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any other dataset that inherits this other than videodatasets?

Nothing built-in.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me then

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pmeier DatasetFolder.make_dataset is public and the docstring is:

class_to_idx (Optional[Dict[str, int]]): Dictionary mapping class name to class index. If omitted, is generated by :func:`find_classes`

The problem is that users can override DatasetFolder.find_classes too, so there's a conflict with the find_classes() function: DatasetFolder.find_classes will rely on the function rather than on the method.

Should we make class_to_idx a mandatory parameter in DatasetFolder.make_dataset to avoid any potential issue?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To better explain myself: imagine a scenario where someone has a class MyCoolNewDataset(DatasetFolder) and they override MyCoolNewDataset.find_classes with a custom class_to_idx logic.

If they call MyCoolNewDataset.make_dataset while passing None to the class_to_idx parameter, what will be used is the class_to_idx logic from the find_classes function, instead of using the logic from the find_classes method - which is different since they overrode it.

Does that make sense?

To avoid such issues I'm suggesting to force the user to pass class_to_idx in DatasetFolder.make_dataset, or more accurately to raise an error if None is passed in DatasetFolder.make_dataset


both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None
if both_none or both_something:
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

if extensions is not None:

def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

is_valid_file = cast(Callable[[str], bool], is_valid_file)

instances = []
available_classes = set()
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class)
Expand All @@ -77,6 +125,17 @@ def is_valid_file(x: str) -> bool:
if is_valid_file(path):
item = path, class_index
instances.append(item)

if target_class not in available_classes:
available_classes.add(target_class)

empty_classes = available_classes - set(class_to_idx.keys())
if empty_classes:
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
if extensions is not None:
msg += f"Supported extensions are: {', '.join(extensions)}"
raise FileNotFoundError(msg)

return instances


Expand Down Expand Up @@ -125,11 +184,6 @@ def __init__(
target_transform=target_transform)
classes, class_to_idx = self._find_classes(self.root)
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
if len(samples) == 0:
msg = "Found 0 files in subfolders of: {}\n".format(self.root)
if extensions is not None:
msg += "Supported extensions are: {}".format(",".join(extensions))
raise RuntimeError(msg)

self.loader = loader
self.extensions = extensions
Expand All @@ -148,23 +202,9 @@ def make_dataset(
) -> List[Tuple[str, int]]:
return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
"""
Finds the class folders in a dataset.

Args:
dir (string): Root directory path.

Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

Ensures:
No class is a subdirectory of another.
"""
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
@staticmethod
def _find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]:
Comment on lines +205 to +206
Copy link
Member

@fmassa fmassa Mar 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I wouldn't have made this a staticmethod, as other instantiations of this dataset could rely on self for generating a custom set of classes and class ids

But this is not really important, so let's move forward with this PR and get this merged

return find_classes(dir)

def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Expand Down
6 changes: 2 additions & 4 deletions torchvision/datasets/hmdb51.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os

from .utils import list_dir
from .folder import make_dataset
from .folder import find_classes, make_dataset
from .video_utils import VideoClips
from .vision import VisionDataset

Expand Down Expand Up @@ -62,8 +62,7 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
raise ValueError("fold should be between 1 and 3, got {}".format(fold))

extensions = ('avi',)
classes = sorted(list_dir(root))
class_to_idx = {class_: i for (i, class_) in enumerate(classes)}
self.classes, class_to_idx = find_classes(self.root)
self.samples = make_dataset(
self.root,
class_to_idx,
Expand All @@ -89,7 +88,6 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
self.full_video_clips = video_clips
self.fold = fold
self.train = train
self.classes = classes
self.indices = self._select_fold(video_paths, annotation_path, fold, train)
self.video_clips = video_clips.subset(self.indices)
self.transform = transform
Expand Down
6 changes: 2 additions & 4 deletions torchvision/datasets/kinetics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .utils import list_dir
from .folder import make_dataset
from .folder import find_classes, make_dataset
from .video_utils import VideoClips
from .vision import VisionDataset

Expand Down Expand Up @@ -56,10 +56,8 @@ def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
_video_min_dimension=0, _audio_samples=0, _audio_channels=0):
super(Kinetics400, self).__init__(root)

classes = list(sorted(list_dir(root)))
class_to_idx = {classes[i]: i for i in range(len(classes))}
self.classes, class_to_idx = find_classes(self.root)
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes
video_list = [x[0] for x in self.samples]
self.video_clips = VideoClips(
video_list,
Expand Down
6 changes: 2 additions & 4 deletions torchvision/datasets/ucf101.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

from .utils import list_dir
from .folder import make_dataset
from .folder import find_classes, make_dataset
from .video_utils import VideoClips
from .vision import VisionDataset

Expand Down Expand Up @@ -55,10 +55,8 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
self.fold = fold
self.train = train

classes = list(sorted(list_dir(root)))
class_to_idx = {classes[i]: i for i in range(len(classes))}
self.classes, class_to_idx = find_classes(self.root)
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes
video_list = [x[0] for x in self.samples]
video_clips = VideoClips(
video_list,
Expand Down