diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 633ebab0153..041874bacbc 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -23,7 +23,7 @@ def find_classes(dir): def make_dataset(dir, class_to_idx): images = [] - for target in os.listdir(dir): + for target in class_to_idx.keys(): d = os.path.join(dir, target) if not os.path.isdir(d): continue @@ -43,10 +43,33 @@ def default_loader(path): class ImageFolder(data.Dataset): - - def __init__(self, root, transform=None, target_transform=None, - loader=default_loader): - classes, class_to_idx = find_classes(root) + """ + A class representing a directory of images as a `Dataset`. + + Args: + root (String): The path to the directory. + transform (Object): A callable object that transforms images. See: torchvision/transforms.py. By default no + transformations will be applied to images. + target_transform (Object): A callable object that transforms targets (labels). By default no transformations + will be applied to targets. + loader (Function): Loads an image and returns it in a usable form. By default loads images by + their path and returns a `PIL.Image` instance. + classes (List/Tuple): The sub-directories of `root` that correspond to the classes of this data set. By default + all sub-directories of `root` are used. + + Example: + >>> dataset = folder.ImageFolder('./dataset', transform=transforms.Compose([ + >>> transforms.Scale(size=224), + >>> transforms.RandomCrop(size=224), + >>> transforms.ToTensor() + >>> ]), classes=['cat', 'dog']) + """ + + def __init__(self, root, transform=None, target_transform=None, loader=default_loader, classes=None): + if not classes: + classes, class_to_idx = find_classes(root) + else: + class_to_idx = {classes[i]: i for i in range(len(classes))} imgs = make_dataset(root, class_to_idx) if len(imgs) == 0: raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"