From 228ec10885cfbb9b93f3457d95c940d30b16093a Mon Sep 17 00:00:00 2001 From: Ruth Fong Date: Mon, 9 Apr 2018 17:27:15 +0100 Subject: [PATCH 1/7] add FGVCAircraft --- torchvision/datasets/__init__.py | 3 +- torchvision/datasets/aircraft.py | 128 +++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 1 deletion(-) create mode 100644 torchvision/datasets/aircraft.py diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index e2d2801216a..476aff118cf 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,3 +1,4 @@ +from .aircraft import FGVCAircraft from .lsun import LSUN, LSUNClass from .folder import ImageFolder, DatasetFolder from .coco import CocoCaptions, CocoDetection @@ -10,7 +11,7 @@ from .semeion import SEMEION from .omniglot import Omniglot -__all__ = ('LSUN', 'LSUNClass', +__all__ = ('FGVCAircraft', 'LSUN', 'LSUNClass', 'ImageFolder', 'DatasetFolder', 'FakeData', 'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', diff --git a/torchvision/datasets/aircraft.py b/torchvision/datasets/aircraft.py new file mode 100644 index 00000000000..50eff4e7778 --- /dev/null +++ b/torchvision/datasets/aircraft.py @@ -0,0 +1,128 @@ +from __future__ import print_function +import torch.utils.data as data +from torchvision.datasets.folder import pil_loader, accimage_loader, default_loader +from PIL import Image +import os +import numpy as np + +AIRPLANE_CLASS_TYPES = ['variant', 'family', 'manufacturer'] + + +def make_dataset(dir, image_ids, targets): + assert(len(image_ids) == len(targets)) + images = [] + dir = os.path.expanduser(dir) + for i in range(len(image_ids)): + item = (os.path.join(dir, 'data', 'images', + '%s.jpg' % image_ids[i]), targets[i]) + images.append(item) + return images + + +def find_classes(classes_file): + # read classes file, separating out image IDs and class names + image_ids = [] + targets = [] + f = open(classes_file, 'r') + for line in f: + split_line = line.split(' ') + image_ids.append(split_line[0]) + targets.append(' '.join(split_line[1:])) + f.close() + + # index class names + classes = np.unique(targets) + class_to_idx = {classes[i]: i for i in range(len(classes))} + targets = [class_to_idx[c] for c in targets] + + return (image_ids, targets, classes, class_to_idx) + +class FGVCAircraft(data.Dataset): + """`FGVC-Aircraft `_ Dataset. + + Args: + root (string): Root directory path to dataset. + class_type (string, optional): The level of FGVC-Aircraft fine-grain classification + to label data with (i.e., ``variant``, ``family``, or ``manufacturer``). + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g. ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + loader (callable, optional): A function to load an image given its path. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in the root directory. If dataset is already downloaded, it is not + downloaded again. + """ + url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' + + def __init__(self, root, class_type='variant', train=True, transform=None, + target_transform=None, loader=default_loader, download=False): + assert(class_type in AIRPLANE_CLASS_TYPES) + self.root = os.path.expanduser(root) + self.class_type = class_type + self.split = 'train' if train else 'val' + self.classes_file = os.path.join(root, 'data', + 'images_%s_%s.txt' % (self.class_type, self.split)) + + if download: + self.download() + + (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) + samples = make_dataset(root, image_ids, targets) + + self.transform = transform + self.target_transform = target_transform + self.loader = loader + self.train = train + + self.samples = samples + self.classes = classes + self.class_to_idx = class_to_idx + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (sample, target) where target is class_index of the target class. + """ + + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target + + def __len__(self): + return len(self.samples) + + def __repr__(self): + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) + fmt_str += ' Root Location: {}\n'.format(self.root) + tmp = ' Transforms (if any): ' + fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + tmp = ' Target Transforms (if any): ' + fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + return fmt_str + + def _check_exists(self): + return os.path.exists(os.path.join(self.root, 'data', 'images')) and \ + os.path.exists(self.classes_file) + + def download(self): + """Download the FGVC-Aircraft data if it doesn't exist already.""" + raise NotImplementedError + from six.moves import urllib + import tarfile + + if self._check_exists(): + return + print('Downloading ' + url) + tmp_file = urlretrieve(url, filename=None)[0] + tar = tarfile.open(file_tmp) + From f042e9e5a6a8120797e69a0f991be04f0d1a9f9c Mon Sep 17 00:00:00 2001 From: Ruth Fong Date: Mon, 9 Apr 2018 23:56:28 +0100 Subject: [PATCH 2/7] finish download function --- torchvision/datasets/aircraft.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/torchvision/datasets/aircraft.py b/torchvision/datasets/aircraft.py index 50eff4e7778..b5839d5c730 100644 --- a/torchvision/datasets/aircraft.py +++ b/torchvision/datasets/aircraft.py @@ -5,6 +5,7 @@ import os import numpy as np + AIRPLANE_CLASS_TYPES = ['variant', 'family', 'manufacturer'] @@ -37,6 +38,7 @@ def find_classes(classes_file): return (image_ids, targets, classes, class_to_idx) + class FGVCAircraft(data.Dataset): """`FGVC-Aircraft `_ Dataset. @@ -116,13 +118,35 @@ def _check_exists(self): def download(self): """Download the FGVC-Aircraft data if it doesn't exist already.""" - raise NotImplementedError from six.moves import urllib import tarfile if self._check_exists(): return - print('Downloading ' + url) - tmp_file = urlretrieve(url, filename=None)[0] - tar = tarfile.open(file_tmp) + # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz + print('Downloading %s ... (may take a few minutes)' % self.url) + parent_dir = os.path.abspath(os.path.join(self.root, os.pardir)) + tar_name = self.url.rpartition('/')[-1] + tar_path = os.path.join(parent_dir, tar_name) + data = urllib.request.urlopen(self.url) + + # download .tar.gz file + with open(tar_path, 'wb') as f: + f.write(data.read()) + + # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b + print('Extracting %s... (may take a few minutes)' % tar_path) + tar = tarfile.open(tar_path) + tar.extractall(parent_dir) + + # rename data folder to self.root + tmp_folder = tar_path.strip('.tar.gz') + print('Rename data folder %s to %s' % (tmp_folder, self.root)) + os.rename(tmp_folder, self.root) + + # delete .tar.gz file + print('Delete .tar.gz file %s' % tar_path) + os.remove(tar_path) + + print('Done!') From e672b5ec611af846d5f54130cdb685c608399a23 Mon Sep 17 00:00:00 2001 From: Ruth Fong Date: Tue, 10 Apr 2018 00:01:43 +0100 Subject: [PATCH 3/7] fix root reference bug --- torchvision/datasets/aircraft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/datasets/aircraft.py b/torchvision/datasets/aircraft.py index b5839d5c730..b4e0e38ad2d 100644 --- a/torchvision/datasets/aircraft.py +++ b/torchvision/datasets/aircraft.py @@ -63,14 +63,14 @@ def __init__(self, root, class_type='variant', train=True, transform=None, self.root = os.path.expanduser(root) self.class_type = class_type self.split = 'train' if train else 'val' - self.classes_file = os.path.join(root, 'data', + self.classes_file = os.path.join(self.root, 'data', 'images_%s_%s.txt' % (self.class_type, self.split)) if download: self.download() (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) - samples = make_dataset(root, image_ids, targets) + samples = make_dataset(self.root, image_ids, targets) self.transform = transform self.target_transform = target_transform From 8ce4da7b9beb7030c5ec8f5516b334273995aeb3 Mon Sep 17 00:00:00 2001 From: Ruth Fong Date: Tue, 10 Apr 2018 00:10:05 +0100 Subject: [PATCH 4/7] update comments --- torchvision/datasets/aircraft.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torchvision/datasets/aircraft.py b/torchvision/datasets/aircraft.py index b4e0e38ad2d..fce491ca3a0 100644 --- a/torchvision/datasets/aircraft.py +++ b/torchvision/datasets/aircraft.py @@ -136,17 +136,18 @@ def download(self): f.write(data.read()) # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b - print('Extracting %s... (may take a few minutes)' % tar_path) + data_folder = tar_path.strip('.tar.gz') + print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder)) tar = tarfile.open(tar_path) tar.extractall(parent_dir) - # rename data folder to self.root - tmp_folder = tar_path.strip('.tar.gz') - print('Rename data folder %s to %s' % (tmp_folder, self.root)) - os.rename(tmp_folder, self.root) + # if necessary, rename data folder to self.root + if not os.path.samefile(data_folder, self.root): + print('Renaming %s to %s ...' % (data_folder, self.root)) + os.rename(data_folder, self.root) # delete .tar.gz file - print('Delete .tar.gz file %s' % tar_path) + print('Deleting %s ...' % tar_path) os.remove(tar_path) print('Done!') From 2c5b549ce90ec6beb73de902564af1193f7c8dd2 Mon Sep 17 00:00:00 2001 From: Ruth Fong Date: Tue, 10 Apr 2018 00:22:38 +0100 Subject: [PATCH 5/7] fix flake errors --- torchvision/datasets/aircraft.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torchvision/datasets/aircraft.py b/torchvision/datasets/aircraft.py index fce491ca3a0..d13d5a4ef91 100644 --- a/torchvision/datasets/aircraft.py +++ b/torchvision/datasets/aircraft.py @@ -14,7 +14,7 @@ def make_dataset(dir, image_ids, targets): images = [] dir = os.path.expanduser(dir) for i in range(len(image_ids)): - item = (os.path.join(dir, 'data', 'images', + item = (os.path.join(dir, 'data', 'images', '%s.jpg' % image_ids[i]), targets[i]) images.append(item) return images @@ -48,7 +48,7 @@ class FGVCAircraft(data.Dataset): to label data with (i.e., ``variant``, ``family``, or ``manufacturer``). transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g. ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform that takes in the + target_transform (callable, optional): A function/transform that takes in the target and transforms it. loader (callable, optional): A function to load an image given its path. download (bool, optional): If true, downloads the dataset from the internet and @@ -57,14 +57,14 @@ class FGVCAircraft(data.Dataset): """ url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' - def __init__(self, root, class_type='variant', train=True, transform=None, + def __init__(self, root, class_type='variant', train=True, transform=None, target_transform=None, loader=default_loader, download=False): assert(class_type in AIRPLANE_CLASS_TYPES) self.root = os.path.expanduser(root) self.class_type = class_type self.split = 'train' if train else 'val' - self.classes_file = os.path.join(self.root, 'data', - 'images_%s_%s.txt' % (self.class_type, self.split)) + self.classes_file = os.path.join(self.root, 'data', + 'images_%s_%s.txt' % (self.class_type, self.split)) if download: self.download() @@ -76,10 +76,10 @@ def __init__(self, root, class_type='variant', train=True, transform=None, self.target_transform = target_transform self.loader = loader self.train = train - + self.samples = samples self.classes = classes - self.class_to_idx = class_to_idx + self.class_to_idx = class_to_idx def __getitem__(self, index): """ @@ -100,7 +100,7 @@ def __getitem__(self, index): return sample, target def __len__(self): - return len(self.samples) + return len(self.samples) def __repr__(self): fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' From 91f62ad08993d6ba0600ef36e22126af2f687f1e Mon Sep 17 00:00:00 2001 From: Ruth Fong Date: Tue, 10 Apr 2018 00:30:37 +0100 Subject: [PATCH 6/7] add error checking like that in STL10 --- torchvision/datasets/aircraft.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/torchvision/datasets/aircraft.py b/torchvision/datasets/aircraft.py index d13d5a4ef91..54adc5c01ce 100644 --- a/torchvision/datasets/aircraft.py +++ b/torchvision/datasets/aircraft.py @@ -6,7 +6,6 @@ import numpy as np -AIRPLANE_CLASS_TYPES = ['variant', 'family', 'manufacturer'] def make_dataset(dir, image_ids, targets): @@ -56,13 +55,22 @@ class FGVCAircraft(data.Dataset): downloaded again. """ url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' + class_types = ('variant', 'family', 'manufacturer') + splits = ('train', 'val', 'trainval', 'test') - def __init__(self, root, class_type='variant', train=True, transform=None, + def __init__(self, root, class_type='variant', split='train', transform=None, target_transform=None, loader=default_loader, download=False): - assert(class_type in AIRPLANE_CLASS_TYPES) + if split not in self.splits: + raise ValueError('Split "{}" not found. Valid splits are: {}'.format( + split, ', '.join(self.splits), + )) + if class_type not in self.class_types: + raise ValueError('Class type "{}" not found. Valid class types are: {}'.format( + class_type, ', '.join(self.class_types), + )) self.root = os.path.expanduser(root) self.class_type = class_type - self.split = 'train' if train else 'val' + self.split = split self.classes_file = os.path.join(self.root, 'data', 'images_%s_%s.txt' % (self.class_type, self.split)) @@ -75,7 +83,6 @@ def __init__(self, root, class_type='variant', train=True, transform=None, self.transform = transform self.target_transform = target_transform self.loader = loader - self.train = train self.samples = samples self.classes = classes From b9aa402bc59371e5208d3fc458423d1507d883cc Mon Sep 17 00:00:00 2001 From: Ruth Fong Date: Tue, 10 Apr 2018 00:31:49 +0100 Subject: [PATCH 7/7] delete extra blank lines --- torchvision/datasets/aircraft.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchvision/datasets/aircraft.py b/torchvision/datasets/aircraft.py index 54adc5c01ce..ab5918971a8 100644 --- a/torchvision/datasets/aircraft.py +++ b/torchvision/datasets/aircraft.py @@ -6,8 +6,6 @@ import numpy as np - - def make_dataset(dir, image_ids, targets): assert(len(image_ids) == len(targets)) images = []