OpenCV transforms and functional code #591

Closed
wants to merge 1 commit into from

Conversation

jbohnslav

Added two files to 'transforms'. Doesn't interfere with any existing code.

I reimplemented transforms.py and functional.py using OpenCV 3.4.1. The code should produce identical augmentations to the current Pillow-based implementation, with the exception of a few operations such as affine warping. I added some benchmarks comparing single-threaded performance in this GitHub repository.

Usage:
import opencv_transforms as transforms

Everything else should work. There are a few flags, like resample, that currently don't do anything.
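
For illustration, a hypothetical pipeline (this assumes the module mirrors the usual torchvision.transforms names and consumes numpy arrays from cv2.imread instead of PIL images):

import cv2
import opencv_transforms as transforms  # the module added in this PR

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

image = cv2.imread('example.jpg')  # numpy array, HWC, BGR
tensor = transform(image)          # torch.Tensor, CHW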

@soumith
Member

soumith commented Aug 30, 2018

Hey Jim, this is awesome.

To help users be confident in switching between these transforms, we should compare OpenCV transforms against the PIL transforms and unit-test them, checking that the sum of pixel-wise differences is not too large.
Would you be up for adding those tests?
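
For example, something along these lines (a sketch only; the opencv_transforms.functional import path and the tolerance are assumptions, not agreed values):

import numpy as np
import torchvision.transforms.functional as F_pil
import opencv_transforms.functional as F_cv  # hypothetical import path for this PR's code

def test_resize_matches_pil(pil_image, cv_image):
    # resize with both backends and bound the mean absolute pixel-wise difference
    out_pil = np.asarray(F_pil.resize(pil_image, (224, 224)))
    out_cv = F_cv.resize(cv_image, (224, 224))
    diff = np.abs(out_pil.astype(np.float32) - out_cv.astype(np.float32))
    assert diff.mean() < 5.0  # tolerance chosen arbitrarily for this sketch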

@soumith
Member

soumith commented Aug 30, 2018

The other thing that'd be great is to integrate the OpenCV transforms the way we did with accimage, i.e. just have it be a "backend" while user code remains exactly the same.

@soumith
Member

soumith commented Aug 30, 2018

accimage, for example, can be enabled with torchvision.set_image_backend('accimage').
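
Something similar for OpenCV could look like the following (a sketch only; 'opencv' is not a value torchvision currently accepts):

import torchvision

# existing behaviour: the backend is chosen once, globally, and user code is unchanged
torchvision.set_image_backend('accimage')
print(torchvision.get_image_backend())  # -> 'accimage'

# the idea in this thread would be to also accept
# torchvision.set_image_backend('opencv')  # not supported today, shown for illustration
# and have the transforms dispatch internally on get_image_backend().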

@fmassa
Member

fmassa commented Aug 31, 2018

Hi @jbohnslav,

Thanks for the PR, this is looking good!
I second @soumith's comment about making sure that we have the same results. One thing that needs modification is that the to_tensor function needs to perform the BGR -> RGB conversion, and similarly for to_pil_image (which doesn't exist yet).
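
Something like this, for instance (a sketch of the suggested behaviour, not the PR's actual code):

import cv2
import torch
from PIL import Image

def to_tensor(img_bgr):
    # OpenCV HWC/BGR uint8 image -> CHW/RGB float tensor in [0, 1]
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    return torch.from_numpy(rgb).permute(2, 0, 1).float().div(255)

def to_pil_image(img_bgr):
    # OpenCV HWC/BGR uint8 image -> PIL RGB image (the missing counterpart)
    return Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))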

Along those lines, I'm thinking about a unified interface for torch images which internally supports PIL, OpenCV, or other backends. They would use the same underlying data type internally (maybe numpy arrays, maybe torch tensors) and convert on the fly to the specific backend representation when a method is called. So we could use the implementation that you have here to fill in the API for such a class.

I haven't finished drafting something up yet, but here is the gist of it. It is very incomplete and doesn't yet bring any advantage over what we currently have, but my goal with this would be to support a number of cases that are currently lacking in torchvision, such as #567, #514, #422, and #546.

import numpy as np
import torch


class TorchImage(object):
    """
    This class holds data representing images.
    Internally, the data is stored as a torch.Tensor,
    which can be seamlessly converted to a PIL image or
    a numpy array (for OpenCV).

    Warning: PIL I;16 is by default uint16, but PyTorch only supports
    int16.
    """
    @classmethod
    def open(cls, file_path):
        pass

    @classmethod
    def from_pil(cls, image):
        data = np.asarray(image)
        # FIXME: I;16 is by default uint16, but PyTorch doesn't support it
        data = torch.as_tensor(data)
        return cls(data, backend="PIL")

    @classmethod
    def from_opencv(cls, image):
        return cls(torch.as_tensor(image), format="BGR", backend="opencv")

    def __init__(self, data, **kwargs):
        assert isinstance(data, torch.Tensor)
        self.data = data
        self.opts = kwargs

    def to_tensor(self):
        # perform necessary conversions / normalization / BGR -> RGB / flipping
        # depending on the data type and shape
        tensor = self.data.permute(2, 0, 1)  # HWC -> CHW
        if self.opts.get("format", None) == "BGR":
            tensor = tensor[[2, 1, 0]]  # reorder channels BGR -> RGB
        # normalization by the dtype maximum; get_max_tensor_element is still to be defined
        return tensor / get_max_tensor_element(tensor)

    def resize(self, size):
        # either use a bunch of conditionals, or inherit PILImage / OpenCVImage
        # classes from TorchImage that implement those functions, to be defined
        pass
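
A hypothetical call sequence against this draft could look like the following (illustrative only; the draft is incomplete, e.g. get_max_tensor_element is still undefined):

import cv2

img = TorchImage.from_opencv(cv2.imread('example.jpg'))  # HWC, BGR, uint8
tensor = img.to_tensor()                                  # CHW, RGB, normalized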

@vfdev-5
Collaborator

vfdev-5 commented Aug 31, 2018

If I could add my two cents, especially on handling multi-input data: I tried to prototype some concepts here: https://github.com/vfdev-5/VisionTransforms
The idea is to add multiple backends (numpy, Pillow, torch.Tensor) and handle transforms (à la nn.Module) on multiple inputs. @fmassa I once posted it in the Slack, so maybe you already saw it...
Anyway, it would be great to arrive at a more concrete decision on how to cover all those problems :)
Thanks
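
For concreteness, a minimal sketch of an nn.Module-style transform applied to several inputs at once (e.g. an image and its mask); the names are illustrative, not taken from the linked prototype:

import torch
import torch.nn as nn

class RandomHorizontalFlip(nn.Module):
    # applies the same random decision to every input, so image and mask stay aligned
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, *inputs):
        if torch.rand(1).item() < self.p:
            return tuple(x.flip(-1) for x in inputs)  # flip the width dimension
        return inputs

# usage: flip = RandomHorizontalFlip(); img_t, mask_t = flip(img_t, mask_t)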

@vadimkantorov

These folks (who also wrote an OpenCV-based transform library) seem to have done such tests: https://github.com/albu/albumentations#benchmarking-results

@guohengkai

Hi, how is it going?

@malfet malfet deleted the branch pytorch:master September 20, 2021 14:34
@malfet malfet closed this Sep 20, 2021