diff --git a/torchvision/__init__.py b/torchvision/__init__.py
index 5f8137910b4..5624054e7cd 100644
--- a/torchvision/__init__.py
+++ b/torchvision/__init__.py
@@ -1,4 +1,5 @@
 from torchvision import models
 from torchvision import datasets
 from torchvision import transforms
+from torchvision import joint_transforms
 from torchvision import utils
diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py
index a8606a73849..8f1936335b0 100644
--- a/torchvision/datasets/__init__.py
+++ b/torchvision/datasets/__init__.py
@@ -6,6 +6,7 @@
 from .mnist import MNIST
 from .svhn import SVHN
 from .phototour import PhotoTour
+from .camvid import CamVid
 
 __all__ = ('LSUN', 'LSUNClass', 'ImageFolder',
diff --git a/torchvision/datasets/camvid.py b/torchvision/datasets/camvid.py
new file mode 100644
index 00000000000..70f0ddd5622
--- /dev/null
+++ b/torchvision/datasets/camvid.py
@@ -0,0 +1,125 @@
+from __future__ import print_function
+
+import os
+import torch
+import torch.utils.data as data
+import numpy as np
+from PIL import Image
+from .folder import is_image_file, default_loader
+
+
+classes = ['Sky', 'Building', 'Column-Pole', 'Road',
+           'Sidewalk', 'Tree', 'Sign-Symbol', 'Fence', 'Car', 'Pedestrian',
+           'Bicyclist', 'Void']
+
+# Class weights from median frequency balancing, as used in the SegNet paper
+# https://arxiv.org/pdf/1511.00561.pdf
+# The numbers were generated by https://github.com/yandex/segnet-torch/blob/master/datasets/camvid-gen.lua
+class_weight = [0.58872014284134, 0.51052379608154, 2.6966278553009, 0.45021694898605, 1.1785038709641,
+                0.77028578519821, 2.4782588481903, 2.5273461341858, 1.0122526884079, 3.2375309467316,
+                4.1312313079834, 0]
+# per-channel mean and std
+mean = [0.41189489566336, 0.4251328133025, 0.4326707089857]
+std = [0.27413549931506, 0.28506257482912, 0.28284674400252]
+
+class_color = [
+    (128, 128, 128),
+    (128, 0, 0),
+    (192, 192, 128),
+    (128, 64, 128),
+    (0, 0, 192),
+    (128, 128, 0),
+    (192, 128, 128),
+    (64, 64, 128),
+    (64, 0, 128),
+    (64, 64, 0),
+    (0, 128, 192),
+    (0, 0, 0),
+]
+
+
+def _make_dataset(dir):
+    images = []
+    for root, _, fnames in sorted(os.walk(dir)):
+        for fname in fnames:
+            if is_image_file(fname):
+                images.append(os.path.join(root, fname))
+    return images
+
+
+class LabelToLongTensor(object):
+    def __call__(self, pic):
+        if isinstance(pic, np.ndarray):
+            # handle numpy array
+            label = torch.from_numpy(pic).long()
+        else:
+            # handle PIL Image: copy the raw bytes, then reshape to H x W
+            label = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+            label = label.view(pic.size[1], pic.size[0], 1)
+            label = label.transpose(0, 1).transpose(0, 2).squeeze().contiguous().long()
+        return label
+
+
+class LabelTensorToPILImage(object):
+    def __call__(self, label):
+        label = label.unsqueeze(0)
+        colored_label = torch.zeros(3, label.size(1), label.size(2)).byte()
+        for i, color in enumerate(class_color):
+            mask = label.eq(i)
+            for j in range(3):
+                colored_label[j].masked_fill_(mask, color[j])
+        npimg = colored_label.numpy()
+        npimg = np.transpose(npimg, (1, 2, 0))
+        mode = None
+        if npimg.shape[2] == 1:
+            npimg = npimg[:, :, 0]
+            mode = "L"
+
+        return Image.fromarray(npimg, mode=mode)
+
+
+class CamVid(data.Dataset):
+
+    def __init__(self, root, split='train', joint_transform=None,
+                 transform=None, target_transform=LabelToLongTensor(), download=False,
+                 loader=default_loader):
+        self.root = root
+        assert split in ('train', 'val', 'test')
+        self.split = split
+        self.transform = transform
+        self.target_transform = target_transform
+        self.joint_transform = joint_transform
+        self.loader = loader
+        self.class_weight = class_weight
+        self.classes = classes
+        self.mean = mean
+        self.std = std
+
+        if download:
+            self.download()
+
+        self.imgs = _make_dataset(os.path.join(self.root, self.split))
+
+    def __getitem__(self, index):
+        path = self.imgs[index]
+        img = self.loader(path)
+        target = Image.open(path.replace(self.split, self.split + 'annot'))
+
+        if self.joint_transform is not None:
+            img, target = self.joint_transform([img, target])
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        target = self.target_transform(target)
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def download(self):
+        # TODO: please download the dataset from
+        # https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid
+        raise NotImplementedError
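
For reference, a minimal usage sketch of the new dataset (the root path is a placeholder). It assumes the images live under <root>/train and the label maps under <root>/trainannot, which is the layout __getitem__ relies on via path.replace(split, split + 'annot'):

    from torchvision.datasets.camvid import CamVid

    train_set = CamVid('/path/to/CamVid', split='train')
    img, target = train_set[0]
    # img is a PIL.Image loaded by default_loader; target is an HxW LongTensor of
    # class indices produced by the default target_transform, LabelToLongTensor().
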
diff --git a/torchvision/joint_transforms.py b/torchvision/joint_transforms.py
new file mode 100644
index 00000000000..a5ffe3d1358
--- /dev/null
+++ b/torchvision/joint_transforms.py
@@ -0,0 +1,155 @@
+from __future__ import division
+import torch
+import math
+import random
+from PIL import Image, ImageOps
+import numpy as np
+import numbers
+import types
+
+
+class JointScale(object):
+    """Rescales the input PIL.Image to the given 'size'.
+    'size' will be the size of the smaller edge.
+    For example, if height > width, then the image will be
+    rescaled to (size * height / width, size).
+    size: size of the smaller edge
+    interpolation: Default: PIL.Image.BILINEAR
+    """
+
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        self.size = size
+        self.interpolation = interpolation
+
+    def __call__(self, imgs):
+        w, h = imgs[0].size
+        if (w <= h and w == self.size) or (h <= w and h == self.size):
+            return imgs
+        if w < h:
+            ow = self.size
+            oh = int(self.size * h / w)
+            return [img.resize((ow, oh), self.interpolation) for img in imgs]
+        else:
+            oh = self.size
+            ow = int(self.size * w / h)
+            return [img.resize((ow, oh), self.interpolation) for img in imgs]
+
+
+class JointCenterCrop(object):
+    """Crops the given PIL.Image at the center to have a region of
+    the given size. size can be a tuple (target_height, target_width)
+    or an integer, in which case the target will be of a square shape (size, size).
+    """
+
+    def __init__(self, size):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+
+    def __call__(self, imgs):
+        w, h = imgs[0].size
+        th, tw = self.size
+        x1 = int(round((w - tw) / 2.))
+        y1 = int(round((h - th) / 2.))
+        return [img.crop((x1, y1, x1 + tw, y1 + th)) for img in imgs]
+
+
+class JointPad(object):
+    """Pads the given PIL.Image on all sides with the given "pad" value."""
+
+    def __init__(self, padding, fill=0):
+        assert isinstance(padding, numbers.Number)
+        assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple)
+        self.padding = padding
+        self.fill = fill
+
+    def __call__(self, imgs):
+        return [ImageOps.expand(img, border=self.padding, fill=self.fill) for img in imgs]
+
+
+class JointLambda(object):
+    """Applies a lambda as a transform."""
+
+    def __init__(self, lambd):
+        assert isinstance(lambd, types.LambdaType)
+        self.lambd = lambd
+
+    def __call__(self, imgs):
+        return [self.lambd(img) for img in imgs]
+
+
+class JointRandomCrop(object):
+    """Crops the given list of PIL.Image at a random location to have a region of
+    the given size. size can be a tuple (target_height, target_width)
+    or an integer, in which case the target will be of a square shape (size, size).
+    """
+
+    def __init__(self, size, padding=0):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+        self.padding = padding
+
+    def __call__(self, imgs):
+        if self.padding > 0:
+            imgs = [ImageOps.expand(img, border=self.padding, fill=0) for img in imgs]
+
+        w, h = imgs[0].size
+        th, tw = self.size
+        if w == tw and h == th:
+            return imgs
+
+        x1 = random.randint(0, w - tw)
+        y1 = random.randint(0, h - th)
+        return [img.crop((x1, y1, x1 + tw, y1 + th)) for img in imgs]
+
+
+class JointRandomHorizontalFlip(object):
+    """Randomly horizontally flips the given list of PIL.Image with a probability of 0.5."""
+
+    def __call__(self, imgs):
+        if random.random() < 0.5:
+            return [img.transpose(Image.FLIP_LEFT_RIGHT) for img in imgs]
+        return imgs
+
+
+class JointRandomSizedCrop(object):
+    """Randomly crops the given list of PIL.Image to a random size of (0.08 to 1.0) of the original size
+    and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio.
+    This is popularly used to train the Inception networks.
+    size: size of the smaller edge
+    interpolation: Default: PIL.Image.BILINEAR
+    """
+
+    def __init__(self, size, interpolation=Image.BILINEAR):
+        self.size = size
+        self.interpolation = interpolation
+
+    def __call__(self, imgs):
+        for attempt in range(10):
+            area = imgs[0].size[0] * imgs[0].size[1]
+            target_area = random.uniform(0.08, 1.0) * area
+            aspect_ratio = random.uniform(3. / 4, 4. / 3)
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if random.random() < 0.5:
+                w, h = h, w
+
+            if w <= imgs[0].size[0] and h <= imgs[0].size[1]:
+                x1 = random.randint(0, imgs[0].size[0] - w)
+                y1 = random.randint(0, imgs[0].size[1] - h)
+
+                imgs = [img.crop((x1, y1, x1 + w, y1 + h)) for img in imgs]
+                assert(imgs[0].size == (w, h))
+
+                return [img.resize((self.size, self.size), self.interpolation) for img in imgs]
+
+        # Fallback
+        scale = JointScale(self.size, interpolation=self.interpolation)
+        crop = JointCenterCrop(self.size)
+        return crop(scale(imgs))
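
A sketch of how the joint transforms are meant to compose with the dataset (crop size, batch size, and paths are illustrative). The image and its label map go through the joint pipeline together as a list, so both receive the same random crop and flip, while normalization is applied to the image alone via the regular transform argument; the standard transforms.Compose works for chaining joint transforms too, since it simply calls each transform in sequence:

    import torch.utils.data as data
    import torchvision.transforms as transforms
    import torchvision.joint_transforms as joint_transforms
    from torchvision.datasets.camvid import CamVid, mean, std

    joint = transforms.Compose([
        joint_transforms.JointRandomCrop(224),
        joint_transforms.JointRandomHorizontalFlip(),
    ])
    train_set = CamVid('/path/to/CamVid', split='train',
                       joint_transform=joint,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize(mean, std),
                       ]))
    train_loader = data.DataLoader(train_set, batch_size=4, shuffle=True)
    imgs, targets = next(iter(train_loader))
    # imgs: 4x3x224x224 normalized float tensor; targets: 4x224x224 LongTensor of class ids.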