diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py
index 4352b120d07..52362ca198f 100644
--- a/torchvision/datasets/folder.py
+++ b/torchvision/datasets/folder.py
@@ -32,17 +32,10 @@ def is_image_file(filename):
     return has_file_allowed_extension(filename, IMG_EXTENSIONS)
 
 
-def find_classes(dir):
-    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
-    classes.sort()
-    class_to_idx = {classes[i]: i for i in range(len(classes))}
-    return classes, class_to_idx
-
-
 def make_dataset(dir, class_to_idx, extensions):
     images = []
     dir = os.path.expanduser(dir)
-    for target in sorted(os.listdir(dir)):
+    for target in sorted(class_to_idx.keys()):
         d = os.path.join(dir, target)
         if not os.path.isdir(d):
             continue
@@ -86,7 +79,7 @@ class DatasetFolder(data.Dataset):
     """
 
     def __init__(self, root, loader, extensions, transform=None, target_transform=None):
-        classes, class_to_idx = find_classes(root)
+        classes, class_to_idx = self._find_classes(root)
         samples = make_dataset(root, class_to_idx, extensions)
         if len(samples) == 0:
             raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
@@ -104,6 +97,24 @@ def __init__(self, root, loader, extensions, transform=None, target_transform=No
         self.transform = transform
         self.target_transform = target_transform
 
+    def _find_classes(self, dir):
+        """
+        Finds the class folders in a dataset.
+
+        Args:
+            dir (string): Root directory path.
+
+        Returns:
+            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
+
+        Ensures:
+            No class is a subdirectory of another.
+        """
+        classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
+        classes.sort()
+        class_to_idx = {classes[i]: i for i in range(len(classes))}
+        return classes, class_to_idx
+
     def __getitem__(self, index):
         """
         Args: