From 6a16b8132fa6697ed6e07bfb506ac1bb460576df Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 3 Mar 2021 16:29:54 +0100 Subject: [PATCH 1/3] factor out find_classes --- torchvision/datasets/folder.py | 90 ++++++++++++++++++++++++---------- 1 file changed, 65 insertions(+), 25 deletions(-) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index ef3ae7af896..fb4861e637a 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -32,9 +32,43 @@ def is_image_file(filename: str) -> bool: return has_file_allowed_extension(filename, IMG_EXTENSIONS) +def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: + """Finds the class folders in a dataset structured as follows: + + .. code:: + + directory/ + ├── class_x + │ ├── xxx.ext + │ ├── xxy.ext + │ └── ... + │ └── xxz.ext + └── class_y + ├── 123.ext + ├── nsdf3.ext + └── ... + └── asd932_.ext + + Args: + directory (str): Root directory path. + + Raises: + FileNotFoundError: If ``directory`` has no class folders. + + Returns: + (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index. + """ + classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) + if not classes: + raise FileNotFoundError(f"Couldn't find any class folder in {directory}.") + + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + return classes, class_to_idx + + def make_dataset( directory: str, - class_to_idx: Dict[str, int], + class_to_idx: Optional[Dict[str, int]] = None, extensions: Optional[Tuple[str, ...]] = None, is_valid_file: Optional[Callable[[str], bool]] = None, ) -> List[Tuple[str, int]]: @@ -42,7 +76,8 @@ def make_dataset( Args: directory (str): root dataset directory - class_to_idx (Dict[str, int]): dictionary mapping class name to class index + class_to_idx (Optional[Dict[str, int]]): Dictionary mapping class name to class index. If omitted, is generated + by :func:`find_classes`. extensions (optional): A list of allowed extensions. Either extensions or is_valid_file should be passed. Defaults to None. is_valid_file (optional): A function that takes path of a file @@ -51,21 +86,34 @@ def make_dataset( is_valid_file should not be passed. Defaults to None. Raises: + ValueError: In case ``class_to_idx`` is empty. ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. + FileNotFoundError: In case no valid file was found for any class. Returns: List[Tuple[str, int]]: samples of a form (path_to_sample, class) """ - instances = [] directory = os.path.expanduser(directory) + + if class_to_idx is None: + _, class_to_idx = find_classes(directory) + elif not class_to_idx: + raise ValueError("'class_to_index' must have at least one entry to collect any samples.") + both_none = extensions is None and is_valid_file is None both_something = extensions is not None and is_valid_file is not None if both_none or both_something: raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") + if extensions is not None: + def is_valid_file(x: str) -> bool: return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) + is_valid_file = cast(Callable[[str], bool], is_valid_file) + + instances = [] + available_classes = set() for target_class in sorted(class_to_idx.keys()): class_index = class_to_idx[target_class] target_dir = os.path.join(directory, target_class) @@ -77,6 +125,17 @@ def is_valid_file(x: str) -> bool: if is_valid_file(path): item = path, class_index instances.append(item) + + if target_class not in available_classes: + available_classes.add(target_class) + + empty_classes = available_classes - set(class_to_idx.keys()) + if empty_classes: + msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " + if extensions is not None: + msg += f"Supported extensions are: {', '.join(extensions)}" + raise FileNotFoundError(msg) + return instances @@ -125,11 +184,6 @@ def __init__( target_transform=target_transform) classes, class_to_idx = self._find_classes(self.root) samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file) - if len(samples) == 0: - msg = "Found 0 files in subfolders of: {}\n".format(self.root) - if extensions is not None: - msg += "Supported extensions are: {}".format(",".join(extensions)) - raise RuntimeError(msg) self.loader = loader self.extensions = extensions @@ -148,23 +202,9 @@ def make_dataset( ) -> List[Tuple[str, int]]: return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) - def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: - """ - Finds the class folders in a dataset. - - Args: - dir (string): Root directory path. - - Returns: - tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. - - Ensures: - No class is a subdirectory of another. - """ - classes = [d.name for d in os.scandir(dir) if d.is_dir()] - classes.sort() - class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} - return classes, class_to_idx + @staticmethod + def _find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: + return find_classes(dir) def __getitem__(self, index: int) -> Tuple[Any, Any]: """ From dbb9d9a63fedc7c7ed9a94f49d7a479c2340d27a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 3 Mar 2021 16:31:51 +0100 Subject: [PATCH 2/3] use find_classes in video datasets --- test/test_datasets.py | 3 --- torchvision/datasets/hmdb51.py | 6 ++---- torchvision/datasets/kinetics.py | 6 ++---- torchvision/datasets/ucf101.py | 6 ++---- 4 files changed, 6 insertions(+), 15 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 859419df2b0..87f3ae22edd 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1031,9 +1031,6 @@ def inject_fake_data(self, tmpdir, config): return num_videos_per_class * len(classes) - def test_not_found_or_corrupted(self): - self.skipTest("Dataset currently does not handle the case of no found videos.") - class HMDB51TestCase(datasets_utils.VideoDatasetTestCase): DATASET_CLASS = datasets.HMDB51 diff --git a/torchvision/datasets/hmdb51.py b/torchvision/datasets/hmdb51.py index 621630cf264..113186b71b9 100644 --- a/torchvision/datasets/hmdb51.py +++ b/torchvision/datasets/hmdb51.py @@ -2,7 +2,7 @@ import os from .utils import list_dir -from .folder import make_dataset +from .folder import find_classes, make_dataset from .video_utils import VideoClips from .vision import VisionDataset @@ -62,8 +62,7 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1, raise ValueError("fold should be between 1 and 3, got {}".format(fold)) extensions = ('avi',) - classes = sorted(list_dir(root)) - class_to_idx = {class_: i for (i, class_) in enumerate(classes)} + self.classes, class_to_idx = find_classes(self.root) self.samples = make_dataset( self.root, class_to_idx, @@ -89,7 +88,6 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1, self.full_video_clips = video_clips self.fold = fold self.train = train - self.classes = classes self.indices = self._select_fold(video_paths, annotation_path, fold, train) self.video_clips = video_clips.subset(self.indices) self.transform = transform diff --git a/torchvision/datasets/kinetics.py b/torchvision/datasets/kinetics.py index e977fc42ba7..000c10cb37c 100644 --- a/torchvision/datasets/kinetics.py +++ b/torchvision/datasets/kinetics.py @@ -1,5 +1,5 @@ from .utils import list_dir -from .folder import make_dataset +from .folder import find_classes, make_dataset from .video_utils import VideoClips from .vision import VisionDataset @@ -55,10 +55,8 @@ def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None, _video_min_dimension=0, _audio_samples=0, _audio_channels=0): super(Kinetics400, self).__init__(root) - classes = list(sorted(list_dir(root))) - class_to_idx = {classes[i]: i for i in range(len(classes))} + self.classes, class_to_idx = find_classes(self.root) self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None) - self.classes = classes video_list = [x[0] for x in self.samples] self.video_clips = VideoClips( video_list, diff --git a/torchvision/datasets/ucf101.py b/torchvision/datasets/ucf101.py index e5cf11d7fa2..709151c2fcb 100644 --- a/torchvision/datasets/ucf101.py +++ b/torchvision/datasets/ucf101.py @@ -1,7 +1,7 @@ import os from .utils import list_dir -from .folder import make_dataset +from .folder import find_classes, make_dataset from .video_utils import VideoClips from .vision import VisionDataset @@ -55,10 +55,8 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1, self.fold = fold self.train = train - classes = list(sorted(list_dir(root))) - class_to_idx = {classes[i]: i for i in range(len(classes))} + self.classes, class_to_idx = find_classes(self.root) self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None) - self.classes = classes video_list = [x[0] for x in self.samples] video_clips = VideoClips( video_list, From bacf2a6824c75ca449e521896ecf9839c8b84cd5 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 3 Mar 2021 16:46:17 +0100 Subject: [PATCH 3/3] adapt old tests --- test/test_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 87f3ae22edd..299cf1e61c3 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -111,10 +111,10 @@ def test_imagefolder(self): def test_imagefolder_empty(self): with get_tmp_dir() as root: - with self.assertRaises(RuntimeError): + with self.assertRaises(FileNotFoundError): torchvision.datasets.ImageFolder(root, loader=lambda x: x) - with self.assertRaises(RuntimeError): + with self.assertRaises(FileNotFoundError): torchvision.datasets.ImageFolder( root, loader=lambda x: x, is_valid_file=lambda x: False )