[Feature] Extend transformations to handle multi-band non-uint8 images #422

Closed
@vfdev-5

Here I would like to discuss how to extend ToPILImage, ToTensor, and the transformations that take a PIL image as input so that they work with multi-band non-uint8 images. Today, the following code does not work:

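import numpy as np
from torchvision.transforms import Compose, ToPILImage, RandomCrop

# 3-band float32 image in (H, W, C) layout
data = np.random.randint(0, 100, size=(10, 10, 3)).astype(np.float32)
t = Compose([
    ToPILImage(),
    RandomCrop(size=(5, 5))
])
print(t(data))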

ToPILImage raises TypeError: Input type float32 is not supported, because PIL has no multi-band image mode with float32 depth.

As the PIL documentation suggests (ref):

if you need to handle band combinations that are not listed above, use a sequence of Image objects.

we can modify F.to_pil_image so that it determines the mode for a single band and then produces a list of PIL images, one per band:

if mode is None:
    expected_mode = _get_expected_mode(npimg.dtype)
    if expected_mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))
    
    return [Image.fromarray(npimg[:, :, c], mode=expected_mode) for c in range(npimg.shape[2])]

Next we need to change every transformation that takes a PIL image as input:

  • either each transformation checks whether img is a PIL image or a list
  • or, with the code factored out, all transformations that take a PIL image as input inherit from
class TransformPIL(object):
    @staticmethod
    def handle(img, transformation, **kwargs):
        if isinstance(img, (list, tuple)):
            return [transformation(band, **kwargs) for band in img]
        else:
            return transformation(img, **kwargs)

and then, for example, Resize is

class Resize(TransformPIL):
    def __init__(self, size, interpolation=Image.BILINEAR):
        # ... AS BEFORE ...

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """        
        return self.handle(img, F.resize, size=self.size, interpolation=self.interpolation)
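The dispatch in `handle` can be tried without PIL at all. The sketch below repeats the `TransformPIL` class from above and uses `fake_resize`, a hypothetical stand-in for a functional transform such as `F.resize`, with strings standing in for images.

```python
class TransformPIL(object):
    @staticmethod
    def handle(img, transformation, **kwargs):
        # A list or tuple is treated as a sequence of single-band images.
        if isinstance(img, (list, tuple)):
            return [transformation(band, **kwargs) for band in img]
        return transformation(img, **kwargs)


def fake_resize(img, size):
    # Hypothetical stand-in: records which "image" was resized to which size.
    return "{}@{}".format(img, size)


single = TransformPIL.handle("band0", fake_resize, size=(5, 5))
multi = TransformPIL.handle(["band0", "band1"], fake_resize, size=(5, 5))
```

Here `single` is one transformed result and `multi` is a list with one result per band. One point to keep in mind with this design: for random transformations such as RandomCrop, the random parameters would presumably have to be sampled once and passed in through kwargs, otherwise each band could receive a different crop.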

Finally, the ToTensor class will gather the list of PIL images into a single tensor.
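A rough sketch of that gathering step, using NumPy as a stand-in for the tensor so the example stays self-contained. `bands_to_array` is a hypothetical name; in torchvision the result would presumably be passed on to `torch.from_numpy`.

```python
import numpy as np
from PIL import Image


def bands_to_array(bands):
    """Stack single-band PIL images into one (C, H, W) array."""
    # Accept a single image as well, so both input kinds share one code path.
    if not isinstance(bands, (list, tuple)):
        bands = [bands]
    return np.stack([np.asarray(band) for band in bands], axis=0)


# Build a small (H, W, C) float32 image and split it into per-band images.
data = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)
bands = [Image.fromarray(np.ascontiguousarray(data[:, :, c])) for c in range(3)]

# Gather the bands back into channel-first layout.
chw = bands_to_array(bands)
```

The round trip preserves dtype and values: `chw` matches `data` with the channel axis moved to the front.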

What do you think about this approach? Thanks!
