🐛 Bug
The docstring for `transforms.ColorJitter().get_params(...)` states that it accepts float and tuple inputs, just like `ColorJitter`'s `__init__(...)`. However, the current implementation supports only tuple inputs: unlike the constructor, it does not accept floats.
To Reproduce
Steps to reproduce the behavior:
```python
from torchvision import transforms as T

t = T.ColorJitter()
# TypeError: 'float' object is not subscriptable
t = t.get_params(0.4, 0.4, 0.4, 0.2)
```
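The failure comes from the sampling step: the implementation indexes each argument as a `(min, max)` pair before drawing a random factor. The stand-alone sketch below mimics only that indexing (the function name `sample_factor_like_get_params` is hypothetical, not torchvision code) and reproduces the same `TypeError` with a bare float:

```python
import random


def sample_factor_like_get_params(value):
    # Hypothetical stand-in: like get_params, it assumes value is a (min, max) pair
    return random.uniform(value[0], value[1])


# A (min, max) pair works
factor = sample_factor_like_get_params((0.6, 1.4))

# A bare float fails with the same error as get_params(0.4, ...)
try:
    sample_factor_like_get_params(0.4)
except TypeError as exc:
    print(exc)  # 'float' object is not subscriptable
```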
Expected behavior
`.get_params(...)` should accept float inputs, as the constructor does.
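A minimal sketch of the expected behavior, assuming the float semantics documented for the `ColorJitter` constructor: a float `x` means `[max(0, 1 - x), 1 + x]` for brightness, contrast and saturation, and `[-x, x]` (with `0 <= x <= 0.5`) for hue. The names `to_range` and `get_params_accepting_floats` are hypothetical; this is not torchvision's implementation, and it returns the sampled factors rather than a transform.

```python
import random
from numbers import Real
from typing import Optional, Sequence, Tuple, Union

RangeArg = Union[None, float, Sequence[float]]


def to_range(value: RangeArg, center: float, bound: Tuple[float, float],
             clip_first_value: bool = True) -> Optional[Tuple[float, float]]:
    """Hypothetical helper: normalize a float or (min, max) pair to a (min, max) tuple."""
    if value is None:
        return None
    if isinstance(value, Real):
        if value < 0:
            raise ValueError("a single value must be non-negative")
        lo, hi = center - value, center + value
        if clip_first_value:
            lo = max(lo, 0.0)  # brightness/contrast/saturation factors cannot be negative
    else:
        lo, hi = float(value[0]), float(value[1])
    if not bound[0] <= lo <= hi <= bound[1]:
        raise ValueError(f"range {(lo, hi)} is outside {bound}")
    # A range collapsed onto the center means "leave this property unchanged"
    if lo == hi == center:
        return None
    return (lo, hi)


def get_params_accepting_floats(brightness=None, contrast=None,
                                saturation=None, hue=None, rng=random):
    """Hypothetical get_params variant: floats and (min, max) pairs both work."""
    inf = float("inf")
    ranges = (
        to_range(brightness, 1.0, (0.0, inf)),
        to_range(contrast, 1.0, (0.0, inf)),
        to_range(saturation, 1.0, (0.0, inf)),
        to_range(hue, 0.0, (-0.5, 0.5), clip_first_value=False),
    )
    # One sampled factor per property; None means that property is not adjusted
    return tuple(None if r is None else rng.uniform(r[0], r[1]) for r in ranges)


factors = get_params_accepting_floats(0.4, 0.4, 0.4, 0.2)
```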
PyTorch / torchvision Version (e.g., 1.0 / 0.4.0): 1.4.0
OS (e.g., Linux): Linux
How you installed PyTorch / torchvision (conda, pip, source): pip
Build command you used (if compiling from source): N/A
Yes, this is a bug in version 0.4.0. However, on master `ColorJitter.get_params()` is no longer used inside `__call__`/`forward` and is kept only for backward compatibility: #2298 (comment). Restoring `ColorJitter.get_params` is a low-priority issue, and even once it is done, the torch.jit (TorchScript) limitation on `Union` types means we cannot make it accept both plain floats and lists of floats...
My use case is that I want to apply the same transform (color jitter) to two images.
Here is a simple workaround for those interested:
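```python
from torchvision import transforms as T

cj = T.ColorJitter(0.4, 0.4, 0.4, 0.2)
# The constructor has already turned each float into a (min, max) range,
# so these attributes are valid get_params inputs
t = cj.get_params(cj.brightness, cj.contrast, cj.saturation, cj.hue)
# t is now frozen (no randomness when called) and can be used to transform both images with the same transform.
```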
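The workaround works because the random factors and the application order are drawn once and then reused. Recent torchvision documentation describes `ColorJitter.get_params` as returning that random order plus the four sampled factors instead of a ready-made transform, in which case the frozen transform has to be rebuilt by hand. A library-agnostic sketch of that idea, under stated assumptions: `FrozenJitter`, `freeze_jitter` and `toy_ops` are hypothetical names, and the toy ops act on a single float "pixel"; with real images they would be calls such as `torchvision.transforms.functional.adjust_brightness`, `adjust_contrast`, `adjust_saturation` and `adjust_hue`.

```python
import random
from dataclasses import dataclass
from typing import Callable, Optional, Sequence, Tuple

# Each op takes (image, factor) and returns the adjusted image
Op = Callable[[float, float], float]


@dataclass(frozen=True)
class FrozenJitter:
    order: Tuple[int, ...]                 # order in which the ops are applied
    factors: Tuple[Optional[float], ...]   # brightness, contrast, saturation, hue
    ops: Tuple[Op, ...]

    def __call__(self, img: float) -> float:
        # No randomness here: every call replays the same order and factors
        for i in self.order:
            factor = self.factors[i]
            if factor is not None:
                img = self.ops[i](img, factor)
        return img


def freeze_jitter(ranges: Sequence[Optional[Tuple[float, float]]],
                  ops: Sequence[Op], rng=random) -> FrozenJitter:
    """Sample the application order and the factors once."""
    if len(ranges) != len(ops):
        raise ValueError("need exactly one range per op")
    order = list(range(len(ops)))
    rng.shuffle(order)
    factors = tuple(None if r is None else rng.uniform(r[0], r[1]) for r in ranges)
    return FrozenJitter(tuple(order), factors, tuple(ops))


# Toy single-channel stand-ins for the real adjust_* functions
toy_ops = (
    lambda px, f: px * f,                 # "brightness": scale the value
    lambda px, f: 0.5 + (px - 0.5) * f,   # "contrast": stretch around a fixed 0.5
    lambda px, f: px,                     # "saturation": no-op on one channel
    lambda px, f: px,                     # "hue": no-op on one channel
)
jitter = freeze_jitter([(0.6, 1.4), (0.6, 1.4), (0.6, 1.4), (-0.2, 0.2)],
                       toy_ops, rng=random.Random(0))
print(jitter(0.25) == jitter(0.25))  # True: both "images" get the same jitter
```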
cc @vfdev-5