
[Feature] Extend transformations to handle multi-band non-uint8 images #422


Closed
vfdev-5 opened this issue Feb 15, 2018 · 2 comments

Comments

Collaborator

vfdev-5 commented Feb 15, 2018

Here I would like to discuss how to extend ToPILImage, ToTensor, and the transformations that take a PIL image as input so that they work with multi-band non-uint8 images. Currently, the following code does not work:

import numpy as np
from torchvision.transforms import Compose, ToPILImage, RandomCrop

data = np.random.randint(0, 100, size=(10, 10, 3)).astype(np.float32)
t = Compose([
    ToPILImage(),
    RandomCrop(size=(5, 5))
])
print(t(data))

ToPILImage raises TypeError: Input type float32 is not supported, because PIL has no mode for a multi-band image with float32 depth.

As the PIL documentation suggests (ref):

if you need to handle band combinations that are not listed above, use a sequence of Image objects.

we can modify F.to_pil_image so that it recognizes the mode for a single band and then produces a list of PIL images:

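if mode is None:
    # _get_expected_mode (not shown) maps a numpy dtype to the PIL mode of a
    # single band, e.g. float32 -> 'F'; unsupported dtypes map to None
    expected_mode = _get_expected_mode(npimg.dtype)
    if expected_mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))

    # one single-band PIL image per channel of the H x W x C array
    return [Image.fromarray(npimg[:, :, c], mode=expected_mode) for c in range(npimg.shape[2])]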
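A self-contained sketch of this splitting step follows. `to_pil_bands` and `SUPPORTED_DTYPES` are hypothetical names, not torchvision API, and restricting support to uint8, int32 and float32 is an assumption. Instead of a dtype-to-mode table, the sketch lets `Image.fromarray` infer the single-band mode ("L", "I" and "F" respectively). Each channel is copied to a contiguous array before being passed to PIL, because `npimg[:, :, c]` is a strided view.

```python
import numpy as np
from PIL import Image

# Assumed set of supported dtypes. Image.fromarray infers the single-band
# mode for each: uint8 -> "L", int32 -> "I", float32 -> "F".
SUPPORTED_DTYPES = (np.dtype(np.uint8), np.dtype(np.int32), np.dtype(np.float32))


def to_pil_bands(npimg):
    """Hypothetical helper: split an H x W x C array into C single-band PIL images."""
    if npimg.ndim == 2:
        npimg = npimg[:, :, np.newaxis]  # a 2-D array is treated as one band
    if npimg.dtype not in SUPPORTED_DTYPES:
        raise TypeError("Input type {} is not supported".format(npimg.dtype))
    # Each channel slice is strided; copy it to a contiguous 2-D array for PIL.
    return [Image.fromarray(np.ascontiguousarray(npimg[:, :, c]))
            for c in range(npimg.shape[2])]


data = np.random.randint(0, 100, size=(10, 10, 3)).astype(np.float32)
bands = to_pil_bands(data)
print(len(bands), bands[0].mode, bands[0].size)  # 3 F (10, 10)
```

The float32 array from the failing example above becomes three mode "F" images instead of raising, while float64 input still raises TypeError.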

Next, we need to change every transformation that takes a PIL image as input:

  • either each transformation checks whether img is a PIL image or a list,
  • or, with factorized code, all transformations that take a PIL image as input inherit from
class TransformPIL(object):
    @staticmethod
    def handle(img, transformation, **kwargs):
        if isinstance(img, (list, tuple)):
            return [transformation(band, **kwargs) for band in img]
        else:
            return transformation(img, **kwargs)

and then Resize, for example, becomes:

class Resize(TransformPIL):
    def __init__(self, size, interpolation=Image.BILINEAR):
        ...  # AS BEFORE

    def __call__(self, img):
        """
        Args:
            img (PIL Image or sequence of PIL Images): Image to be scaled.

        Returns:
            PIL Image or list of PIL Images: Rescaled image.
        """
        return self.handle(img, F.resize, size=self.size, interpolation=self.interpolation)
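For random transformations such as the RandomCrop in the failing example, all bands must receive the same random parameters, otherwise they no longer line up. torchvision's RandomCrop already draws its parameters once per call and then calls F.crop, so it fits the `handle` pattern. The sketch below shows that pattern. `BandRandomCrop` and `crop` are hypothetical stand-ins, not torchvision code, and the crop size is assumed to be no larger than the image.

```python
import random

import numpy as np
from PIL import Image


class TransformPIL(object):
    # The dispatch helper proposed above: a list/tuple is a multi-band image.
    @staticmethod
    def handle(img, transformation, **kwargs):
        if isinstance(img, (list, tuple)):
            return [transformation(band, **kwargs) for band in img]
        return transformation(img, **kwargs)


def crop(img, i, j, h, w):
    # Stand-in for a functional crop: top-left corner (i, j), height h, width w.
    return img.crop((j, i, j + w, i + h))


class BandRandomCrop(TransformPIL):
    """Hypothetical RandomCrop variant accepting a PIL image or a list of bands."""

    def __init__(self, size):
        self.size = size  # (height, width)

    def __call__(self, img):
        first = img[0] if isinstance(img, (list, tuple)) else img
        width, height = first.size  # PIL reports (width, height)
        h, w = self.size
        # Draw the crop position once so every band is cropped at the same place.
        i = random.randint(0, height - h)
        j = random.randint(0, width - w)
        return self.handle(img, crop, i=i, j=j, h=h, w=w)


data = np.random.randint(0, 100, size=(10, 10, 3)).astype(np.float32)
bands = [Image.fromarray(np.ascontiguousarray(data[:, :, c])) for c in range(3)]
out = BandRandomCrop(size=(5, 5))(bands)
print([b.size for b in out])  # [(5, 5), (5, 5), (5, 5)]
```

If the random draw happened inside the per-band function instead, each band would get its own crop window.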

Finally, the ToTensor class would gather the list of PIL images into a single tensor.
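The gathering step can be sketched with NumPy alone. `bands_to_array` is a hypothetical helper that stacks equally sized single-band images into a C x H x W array. Passing the result to `torch.from_numpy` would give the tensor. As an assumption, no value rescaling is done here, whereas ToTensor divides uint8 input by 255.

```python
import numpy as np
from PIL import Image


def bands_to_array(bands):
    """Hypothetical sketch: stack a list of single-band PIL images into C x H x W."""
    sizes = {b.size for b in bands}
    if len(sizes) != 1:
        raise ValueError("all bands must have the same size, got {}".format(sorted(sizes)))
    # np.asarray on a single-band image gives an H x W array (float32 for mode "F").
    return np.stack([np.asarray(b) for b in bands], axis=0)


bands = [Image.new("F", (4, 3), color=float(c)) for c in range(5)]
chw = bands_to_array(bands)
print(chw.shape, chw.dtype)  # (5, 3, 4) float32
```

Note the axis order: PIL sizes are (width, height), so a (4, 3) image becomes a 3 x 4 band.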

What do you think about this approach? Thanks!

Collaborator Author

vfdev-5 commented Feb 16, 2018

Here is a way to implement the above idea.

EDIT: All transformations are extended to support sequences of images.

Member

fmassa commented Nov 30, 2020

This has been implemented in torchvision 0.8.0.

fmassa closed this as completed Nov 30, 2020