diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 2be81389c2c..eca4d9cada8 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -208,3 +208,102 @@ def __call__(self, img): scale = Scale(self.size, interpolation=self.interpolation) crop = CenterCrop(self.size) return crop(scale(img)) + + + +class Lighting(object): + """Lighting noise(AlexNet - style PCA - based noise)""" + + def __init__(self, alphastd, eigval, eigvec): + self.alphastd = alphastd + self.eigval = eigval + self.eigvec = eigvec + + def __call__(self, img): + if self.alphastd == 0: + return img + + alpha = torch.Tensor(3).normal_(0, self.alphastd) + rgb = self.eigvec.clone()\ + .mul(alpha.view(1, 3).expand(3, 3))\ + .mul(self.eigval.view(1, 3).expand(3, 3))\ + .sum(1).squeeze() + + img.add_(rgb.view(3, 1, 1).expand_as(img)) + return img + + +class Grayscale(object): + + def __call__(self, img): + img[0].mul_(0.299).add_(0.587, img[1]).add_(0.114, img[2]) + img[1].copy_(img[0]) + img[2].copy_(img[0]) + return img + + +class Saturation(object): + + def __init__(self, var): + self.var = var + + def __call__(self, img): + gs = Grayscale()(img.clone()) + alpha = random.uniform(0, self.var) + img.lerp_(gs, alpha) + + return img + + +class Brightness(object): + + def __init__(self, var): + self.var = var + + def __call__(self, img): + gs = img.new().resize_as_(img).zero_() + alpha = random.uniform(0, self.var) + img.lerp_(gs, alpha) + + return img + + +class Contrast(object): + + def __init__(self, var): + self.var = var + + def __call__(self, img): + gs = Grayscale()(img.clone()) + gs.fill_(gs.mean()) + alpha = random.uniform(0, self.var) + img.lerp_(gs, alpha) + return img + + +class RandomOrder(object): + """ Composes several transforms together in random order. + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img): + if self.transforms is None: + return img + order = torch.randperm(len(self.transforms)) + for i in order: + img = self.transforms[i](img) + return img + + +class ColorJitter(RandomOrder): + + def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): + self.transforms = [] + if brightness != 0: + self.transforms.append(Brightness(brightness)) + if contrast != 0: + self.transforms.append(Contrast(contrast)) + if saturation != 0: + self.transforms.append(Saturation(saturation))