
transforms.ColorJitter().get_params(...) does not support float inputs #2669


Open
sagadre opened this issue Sep 11, 2020 · 2 comments · May be fixed by #2672

Comments

@sagadre

sagadre commented Sep 11, 2020

🐛 Bug

The docstring for transforms.ColorJitter().get_params(...) states that it accepts both float and tuple inputs, just like ColorJitter's __init__(...). However, the current implementation supports only tuple inputs: unlike the constructor, it does not accept floats.

To Reproduce

Steps to reproduce the behavior:

```
from torchvision import transforms as T

t = T.ColorJitter()
# Fails: get_params only accepts (min, max) tuples, not plain floats.
t = t.get_params(0.4, 0.4, 0.4, 0.2)
```

Expected behavior

.get_params(...) should accept float inputs, as the constructor does.
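A minimal sketch of how float support could work, assuming get_params should mirror the constructor's conversion: a non-negative float v becomes the range [max(0, 1 - v), 1 + v] for brightness, contrast, and saturation, and [-v, v] for hue. The helper name `to_jitter_range` is hypothetical; it only approximates torchvision's private `_check_input` and is not part of the library API.

```python
import numbers
from typing import Optional, Sequence, Tuple, Union

Range = Tuple[float, float]


def to_jitter_range(
    value: Union[float, Sequence[float]],
    name: str,
    center: float = 1.0,
    bound: Range = (0.0, float("inf")),
    clip_first_on_zero: bool = True,
) -> Optional[Range]:
    """Hypothetical helper: turn a float or a (min, max) pair into a
    (min, max) range, roughly the way ColorJitter's constructor does."""
    if isinstance(value, numbers.Number):
        if value < 0:
            raise ValueError(f"If {name} is a single number, it must be non-negative.")
        low, high = center - float(value), center + float(value)
        if clip_first_on_zero:
            # Brightness/contrast/saturation factors cannot go below 0.
            low = max(low, 0.0)
    elif isinstance(value, (tuple, list)) and len(value) == 2:
        low, high = float(value[0]), float(value[1])
    else:
        raise TypeError(f"{name} should be a single number or a (min, max) pair.")

    if not bound[0] <= low <= high <= bound[1]:
        raise ValueError(f"{name} values should be within {bound}, got {(low, high)}.")
    # A range collapsed onto the center means "no jitter" for this component.
    if low == high == center:
        return None
    return (low, high)


if __name__ == "__main__":
    print(to_jitter_range(0.4, "brightness"))
    print(to_jitter_range(0.2, "hue", center=0.0, bound=(-0.5, 0.5), clip_first_on_zero=False))
```

The normalized ranges could then be passed to get_params in place of the raw floats, so callers would not need to know about the range convention.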

  • PyTorch / torchvision Version (e.g., 1.0 / 0.4.0): 1.4.0
  • OS (e.g., Linux): Linux
  • How you installed PyTorch / torchvision (conda, pip, source): pip
  • Build command you used (if compiling from source): N/A
  • Python version: 3.8
  • CUDA/cuDNN version: 10.1
  • GPU models and configuration: GeForce RTX 2080 Ti
  • Any other relevant information: N/A

cc @vfdev-5

@vfdev-5
Collaborator

vfdev-5 commented Sep 14, 2020

@sagadre thanks for reporting!

Yes, this is a bug in version 0.4.0. However, in the current master ColorJitter.get_params() is no longer used inside __call__/forward and is kept only for backward compatibility: #2298 (comment). Restoring the old behaviour of ColorJitter.get_params is a low-priority issue, and even once it is done, the torch.jit limitation on Union types means we cannot make it accept both plain floats and lists of floats...

@sagadre
Author

sagadre commented Sep 16, 2020

@vfdev-5 I see. Thanks for the explanation!

My use case is that I want to apply the same transform (color jitter) to two images.

Here is a simple workaround for those interested:

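```
from torchvision import transforms as T

cj = T.ColorJitter(0.4, 0.4, 0.4, 0.2)
# The constructor has already converted the floats into (min, max) ranges,
# so these attributes can be passed straight to get_params.
t = cj.get_params(cj.brightness, cj.contrast, cj.saturation, cj.hue)

# t is now frozen (no randomness when called) and can be used to transform
# both images with the same transform.
```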
