From e312085f88649e729f47b1fe9a8bef1bd2d950e6 Mon Sep 17 00:00:00 2001 From: "ernest.parke" Date: Mon, 4 Jun 2018 14:48:15 -0400 Subject: [PATCH 1/2] Addresses issue #145 as per @fmessa's suggestion. --- torchvision/datasets/folder.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 4352b120d07..6b75e5d4102 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -32,17 +32,11 @@ def is_image_file(filename): return has_file_allowed_extension(filename, IMG_EXTENSIONS) -def find_classes(dir): - classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] - classes.sort() - class_to_idx = {classes[i]: i for i in range(len(classes))} - return classes, class_to_idx - def make_dataset(dir, class_to_idx, extensions): images = [] dir = os.path.expanduser(dir) - for target in sorted(os.listdir(dir)): + for target in sorted(class_to_idx.keys()): d = os.path.join(dir, target) if not os.path.isdir(d): continue @@ -86,7 +80,7 @@ class DatasetFolder(data.Dataset): """ def __init__(self, root, loader, extensions, transform=None, target_transform=None): - classes, class_to_idx = find_classes(root) + classes, class_to_idx = self._find_classes(root) samples = make_dataset(root, class_to_idx, extensions) if len(samples) == 0: raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" @@ -104,6 +98,24 @@ def __init__(self, root, loader, extensions, transform=None, target_transform=No self.transform = transform self.target_transform = target_transform + def _find_classes(self, dir): + """ + Finds the class folders in a dataset. + + Args: + dir (string): Root directory path. + + Returns: + tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. + + Ensures: + No class is a subdirectory of another. + """ + classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] + classes.sort() + class_to_idx = {classes[i]: i for i in range(len(classes))} + return classes, class_to_idx + def __getitem__(self, index): """ Args: From 38b05b5fa1aeac4b8f8c1836d18b6c636ac9b06a Mon Sep 17 00:00:00 2001 From: "ernest.parke" Date: Mon, 4 Jun 2018 16:01:44 -0400 Subject: [PATCH 2/2] Removed blank line for styling. --- torchvision/datasets/folder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 6b75e5d4102..52362ca198f 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -32,7 +32,6 @@ def is_image_file(filename): return has_file_allowed_extension(filename, IMG_EXTENSIONS) - def make_dataset(dir, class_to_idx, extensions): images = [] dir = os.path.expanduser(dir)