Skip to content

Add optional classes param to ImageFolder and add docs. #100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def find_classes(dir):

def make_dataset(dir, class_to_idx):
images = []
for target in os.listdir(dir):
for target in class_to_idx.keys():
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
Expand All @@ -43,10 +43,33 @@ def default_loader(path):


class ImageFolder(data.Dataset):

def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
classes, class_to_idx = find_classes(root)
"""
A class representing a directory of images as a `Dataset`.

Args:
root (String): The path to the directory.
transform (Object): A callable object that transforms images. See: torchvision/transforms.py. By default no
transformations will be applied to images.
target_transform (Object): A callable object that transforms targets (labels). By default no transformations
will be applied to targets.
loader (Function): Loads an image and returns it in a usable form. By default loads images by
their path and returns a `PIL.Image` instance.
classes (List/Tuple): The sub-directories of `root` that correspond to the classes of this data set. By default
all sub-directories of `root` are used.

Example:
>>> dataset = folder.ImageFolder('./dataset', transform=transforms.Compose([
>>> transforms.Scale(size=224),
>>> transforms.RandomCrop(size=224),
>>> transforms.ToTensor()
>>> ]), classes=['cat', 'dog'])
"""

def __init__(self, root, transform=None, target_transform=None, loader=default_loader, classes=None):
if not classes:
classes, class_to_idx = find_classes(root)
else:
class_to_idx = {classes[i]: i for i in range(len(classes))}
imgs = make_dataset(root, class_to_idx)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
Expand Down