
Support for customized channel TIF input image #514


Open
YubinXie opened this issue May 24, 2018 · 7 comments


@YubinXie

YubinXie commented May 24, 2018

I am currently working on medical images, where one field of view can have more than 3 channels (up to 50+). The current loader does not support a customizable number of channels: if you give it a stacked TIF image, it uses the convert('RGB') function, which turns it into something else. If I change the loader to output the n-d array of the input image, it raises further errors because the transforms assume a PIL image as input. Any solutions? Thanks.

@fmassa
Member

fmassa commented May 24, 2018

I think that in order to properly support this, it would be better to refactor our to_tensor / to_pil_image to be more generic.

As a workaround, do you need to perform any data augmentations in your code (rotation, flip, etc.)?
If not, you can just use the Lambda transform to convert your array into a torch tensor.

Also, do you want to normalize the image to be in a specific range?

It would be great if we could have some pointers to reference images and to the tools used to load them, so that we can take that into account when improving these transforms.

@YubinXie
Author

YubinXie commented May 25, 2018

It works if I have no transformations, but I do need transformation and normalization.

One thing that feels odd is that the current transformations are done via PIL, which is why the PIL format is required. I think it would make more sense if transformations could be applied to np.array or tensor formats.
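Per-channel normalization of the kind mentioned here is straightforward to do directly on a tensor, whatever the channel count. A sketch; the normalize helper below is hypothetical, not a torchvision API:

```python
import torch

def normalize(img, mean, std):
    # img: (C, H, W) tensor; mean/std: one value per channel,
    # reshaped to broadcast over the spatial dimensions
    mean = torch.as_tensor(mean).view(-1, 1, 1)
    std = torch.as_tensor(std).view(-1, 1, 1)
    return (img - mean) / std

img = torch.rand(5, 8, 8)  # 5-channel image
out = normalize(img, mean=img.mean(dim=(1, 2)), std=img.std(dim=(1, 2)))
print(out.mean(dim=(1, 2)))  # per-channel means, approximately zero
```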

@fmassa
Member

fmassa commented May 25, 2018

Yes, requiring the image to be a PIL image has some limitations.
That's why I'm thinking about modifying the existing transforms and creating a wrapper Image class that could hold other formats as well.

But for that to work, we would need to understand a bit better how the other formats are structured and what representation is used. Could you give us some pointers?

@YubinXie
Author

YubinXie commented May 31, 2018

The majority of images are grayscale or RGB. The only images I know of that contain more than 4 channels are stacked TIF images, so I would suggest adding support for stacked TIF images at the read-in stage. Once an image is read in, convert it to a numpy array or a PyTorch tensor, and apply all the transformations to those.
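Simple augmentations like flips can indeed be applied directly to tensors, independent of the channel count. A sketch; the random_hflip helper is hypothetical:

```python
import torch

def random_hflip(img, p=0.5):
    # img: (C, H, W) tensor; flip along the width axis with probability p
    if torch.rand(1).item() < p:
        return torch.flip(img, dims=[2])
    return img

img = torch.rand(50, 64, 64)  # 50-channel image
out = random_hflip(img, p=1.0)  # p=1.0 forces the flip for demonstration
print(out.shape)  # torch.Size([50, 64, 64])
```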

@vadimkantorov

vadimkantorov commented Jun 18, 2018

@fmassa It would be nice if PyTorch supported running without Pillow (or Pillow-SIMD, which can have installation issues, since it requires uninstalling the pillow package, which other pip packages may then reinstall).

With nn.functional.upsample and .flip in core, is image file loading the only thing for which PIL is indispensable?

@fmassa
Member

fmassa commented Jul 5, 2018

@vadimkantorov almost.
Note that for the moment, upsample only works on floating-point data, but I'd expect us to add support for other data types in the near future.
Also, there are a few other operations that require a bit of work to get running (like rotations), even though they could be implemented using grid_sample (for floating-point types only).

One thing I was considering is to eventually extend grid_sample to also perform bicubic interpolation; I think it could be useful for a few applications.
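For floating-point tensors, a rotation via grid_sample as described above might be sketched like this. The rotate helper is hypothetical, and align_corners is passed explicitly (newer PyTorch versions require being explicit about it):

```python
import math
import torch
import torch.nn.functional as F

def rotate(img, angle):
    # img: (C, H, W) float tensor; angle in radians
    c, h, w = img.shape
    theta = torch.tensor([
        [math.cos(angle), -math.sin(angle), 0.0],
        [math.sin(angle),  math.cos(angle), 0.0],
    ]).unsqueeze(0)  # (1, 2, 3) affine matrix
    grid = F.affine_grid(theta, size=(1, c, h, w), align_corners=False)
    out = F.grid_sample(img.unsqueeze(0), grid, align_corners=False)
    return out.squeeze(0)

img = torch.rand(5, 32, 32)
rot = rotate(img, math.pi / 4)  # 45-degree rotation, bilinear sampling
print(rot.shape)  # torch.Size([5, 32, 32])
```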

@mayaghaei

I just discovered a way to do this. I am not sure it solves your problem, but I'll share it in case it is useful to others.
In my case I had multi-channel TIFF images that I wanted to classify with CNNs in PyTorch. I honestly gave up on data augmentation with PyTorch's transforms and performed it offline (so my input folders contain the original data as well as the augmented data).
The game changer, however, is defining your own loader and taking advantage of the tifffile library in Python. This is how I did it for my training set (val and test should be the same):

import tifffile
from torchvision import datasets, transforms

# Custom loader: returns the image as a numpy array with all of its
# channels, bypassing PIL and its convert('RGB') step.
def my_tiff_loader(filename):
    return tifffile.imread(filename)

train_transform = transforms.Compose([transforms.ToTensor()])

train_data = datasets.ImageFolder('PATH TO TRAINSET', loader=my_tiff_loader,
                                  transform=train_transform)
