
Add CamVid dataset for segmentation #90


Open · wants to merge 6 commits into main
1 change: 1 addition & 0 deletions torchvision/__init__.py
@@ -1,4 +1,5 @@
from torchvision import models
from torchvision import datasets
from torchvision import transforms
from torchvision import joint_transforms
from torchvision import utils
1 change: 1 addition & 0 deletions torchvision/datasets/__init__.py
@@ -6,6 +6,7 @@
from .mnist import MNIST
from .svhn import SVHN
from .phototour import PhotoTour
from .camvid import CamVid

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder',
125 changes: 125 additions & 0 deletions torchvision/datasets/camvid.py
@@ -0,0 +1,125 @@
from __future__ import print_function

import os
import torch
import torch.utils.data as data
import numpy as np
from PIL import Image
from .folder import is_image_file, default_loader


classes = ['Sky', 'Building', 'Column-Pole', 'Road',
           'Sidewalk', 'Tree', 'Sign-Symbol', 'Fence', 'Car', 'Pedestrian',
           'Bicyclist', 'Void']

# Class weights computed with median frequency balancing, as used in the SegNet paper
# https://arxiv.org/pdf/1511.00561.pdf
# The numbers were generated by https://github.com/yandex/segnet-torch/blob/master/datasets/camvid-gen.lua
class_weight = [0.58872014284134, 0.51052379608154, 2.6966278553009, 0.45021694898605, 1.1785038709641,
0.77028578519821, 2.4782588481903, 2.5273461341858, 1.0122526884079, 3.2375309467316,
4.1312313079834, 0]
# mean and std
mean = [0.41189489566336, 0.4251328133025, 0.4326707089857]
std = [0.27413549931506, 0.28506257482912, 0.28284674400252]

class_color = [
(128, 128, 128),
(128, 0, 0),
(192, 192, 128),
(128, 64, 128),
(0, 0, 192),
(128, 128, 0),
(192, 128, 128),
(64, 64, 128),
(64, 0, 128),
(64, 64, 0),
(0, 128, 192),
(0, 0, 0),
]


def _make_dataset(dir):
images = []
for root, _, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
item = path
images.append(item)
return images


class LabelToLongTensor(object):
def __call__(self, pic):
if isinstance(pic, np.ndarray):
# handle numpy array
label = torch.from_numpy(pic).long()
else:
label = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
label = label.view(pic.size[1], pic.size[0], 1)
label = label.transpose(0, 1).transpose(0, 2).squeeze().contiguous().long()
return label


class LabelTensorToPILImage(object):
def __call__(self, label):
label = label.unsqueeze(0)
colored_label = torch.zeros(3, label.size(1), label.size(2)).byte()
for i, color in enumerate(class_color):
mask = label.eq(i)
for j in range(3):
colored_label[j].masked_fill_(mask, color[j])
npimg = colored_label.numpy()
npimg = np.transpose(npimg, (1, 2, 0))
mode = None
if npimg.shape[2] == 1:
npimg = npimg[:, :, 0]
mode = "L"

return Image.fromarray(npimg, mode=mode)


class CamVid(data.Dataset):

def __init__(self, root, split='train', joint_transform=None,
transform=None, target_transform=LabelToLongTensor(), download=False,
loader=default_loader):
self.root = root
assert split in ('train', 'val', 'test')
self.split = split
self.transform = transform
self.target_transform = target_transform
self.joint_transform = joint_transform
self.loader = loader
        self.class_weight = class_weight
        self.classes = classes
self.mean = mean
self.std = std

if download:
self.download()

self.imgs = _make_dataset(os.path.join(self.root, self.split))

def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
target = Image.open(path.replace(self.split, self.split + 'annot'))

if self.joint_transform is not None:
img, target = self.joint_transform([img, target])

if self.transform is not None:
img = self.transform(img)

target = self.target_transform(target)
return img, target

def __len__(self):
return len(self.imgs)

def download(self):
# TODO: please download the dataset from
# https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid
raise NotImplementedError
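
For reference, a minimal usage sketch of the new dataset; the root path, crop size, and batch size below are illustrative and not part of this PR:

import torch
import torchvision.transforms as transforms
import torchvision.joint_transforms as joint_transforms
from torchvision.datasets.camvid import CamVid, mean, std

# A joint transform receives [image, label] and applies the same random
# parameters to both, so the crop window stays aligned with the annotation map.
joint = joint_transforms.JointRandomCrop(352)

# Image-only transform; normalization uses the CamVid statistics defined above.
input_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# '/data/CamVid' (illustrative path) must contain train/, trainannot/, val/,
# valannot/, test/ and testannot/ as in the SegNet-Tutorial layout.
train_set = CamVid('/data/CamVid', split='train',
                   joint_transform=joint, transform=input_transform)

loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True)
img, target = train_set[0]  # img: 3xHxW float tensor; target: HxW LongTensor of class indices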
155 changes: 155 additions & 0 deletions torchvision/joint_transforms.py
@@ -0,0 +1,155 @@
from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps
import numpy as np
import numbers
import types


class JointScale(object):
"""Rescales the input PIL.Image to the given 'size'.
'size' will be the size of the smaller edge.
For example, if height > width, then image will be
rescaled to (size * height / width, size)
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""

def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation

def __call__(self, imgs):
w, h = imgs[0].size
if (w <= h and w == self.size) or (h <= w and h == self.size):
return imgs
if w < h:
ow = self.size
oh = int(self.size * h / w)
return [img.resize((ow, oh), self.interpolation) for img in imgs]
else:
oh = self.size
ow = int(self.size * w / h)
return [img.resize((ow, oh), self.interpolation) for img in imgs]


class JointCenterCrop(object):
"""Crops the given PIL.Image at the center to have a region of
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)
"""

def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size

def __call__(self, imgs):
w, h = imgs[0].size
th, tw = self.size
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
return [img.crop((x1, y1, x1 + tw, y1 + th)) for img in imgs]


class JointPad(object):
"""Pads the given PIL.Image on all sides with the given "pad" value"""

def __init__(self, padding, fill=0):
assert isinstance(padding, numbers.Number)
assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple)
self.padding = padding
self.fill = fill

def __call__(self, imgs):
return [ImageOps.expand(img, border=self.padding, fill=self.fill) for img in imgs]


class JointLambda(object):
"""Applies a lambda as a transform."""

def __init__(self, lambd):
assert isinstance(lambd, types.LambdaType)
self.lambd = lambd

def __call__(self, imgs):
return [self.lambd(img) for img in imgs]


class JointRandomCrop(object):
"""Crops the given list of PIL.Image at a random location to have a region of
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)
"""

def __init__(self, size, padding=0):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
self.padding = padding

def __call__(self, imgs):
if self.padding > 0:
imgs = [ImageOps.expand(img, border=self.padding, fill=0) for img in imgs]

w, h = imgs[0].size
th, tw = self.size
if w == tw and h == th:
return imgs

x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
return [img.crop((x1, y1, x1 + tw, y1 + th)) for img in imgs]


class JointRandomHorizontalFlip(object):
"""Randomly horizontally flips the given list of PIL.Image with a probability of 0.5
"""

def __call__(self, imgs):
if random.random() < 0.5:
return [img.transpose(Image.FLIP_LEFT_RIGHT) for img in imgs]
return imgs


class JointRandomSizedCrop(object):
"""Random crop the given list of PIL.Image to a random size of (0.08 to 1.0) of the original size
and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
This is popularly used to train the Inception networks
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""

def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation

def __call__(self, imgs):
for attempt in range(10):
area = imgs[0].size[0] * imgs[0].size[1]
target_area = random.uniform(0.08, 1.0) * area
aspect_ratio = random.uniform(3. / 4, 4. / 3)

w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))

if random.random() < 0.5:
w, h = h, w

if w <= imgs[0].size[0] and h <= imgs[0].size[1]:
x1 = random.randint(0, imgs[0].size[0] - w)
y1 = random.randint(0, imgs[0].size[1] - h)

imgs = [img.crop((x1, y1, x1 + w, y1 + h)) for img in imgs]
assert(imgs[0].size == (w, h))

return [img.resize((self.size, self.size), self.interpolation) for img in imgs]

# Fallback
scale = JointScale(self.size, interpolation=self.interpolation)
crop = JointCenterCrop(self.size)
return crop(scale(imgs))
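
A short sketch of how these joint transforms are meant to be invoked on an image/label pair; the file names below are hypothetical. Note that this file has no JointCompose, so multiple joint transforms are chained by hand:

from PIL import Image
import torchvision.joint_transforms as joint_transforms

img = Image.open('frame.png')          # RGB frame (hypothetical file)
label = Image.open('frame_annot.png')  # per-pixel class index map, same size

# Every Joint* transform takes a list of images and draws its random
# parameters once, so the flip decision and crop window are identical
# for the frame and its label map.
crop = joint_transforms.JointRandomCrop(256)
flip = joint_transforms.JointRandomHorizontalFlip()

img, label = flip(crop([img, label]))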