Skip to content

Commit ab321fc

Browse files
committed
WIP on adding gray images support for adjust_contrast (#4477)
[ghstack-poisoned]
1 parent c2a4e9f commit ab321fc

File tree

3 files changed

+21
-9
lines changed

3 files changed

+21
-9
lines changed

test/common_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,12 @@ def needs_cuda(test_func):
128128
def _create_data(height=3, width=3, channels=3, device="cpu"):
129129
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
130130
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
131-
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
131+
data = tensor.permute(1, 2, 0).contiguous().cpu().numpy()
132+
mode = "RGB"
133+
if channels == 1:
134+
mode = "L"
135+
data = data[..., 0]
136+
pil_img = Image.fromarray(data, mode=mode)
132137
return tensor, pil_img
133138

134139

test/test_functional_tensor.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -641,12 +641,14 @@ def backward(ctx, grad_output):
641641
assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
642642

643643

644-
def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
644+
def check_functional_vs_PIL_vs_scripted(
645+
fn, fn_pil, fn_t, config, device, dtype, channels=3, tol=2.0 + 1e-10, agg_method="max"
646+
):
645647

646648
script_fn = torch.jit.script(fn)
647649
torch.manual_seed(15)
648-
tensor, pil_img = _create_data(26, 34, device=device)
649-
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
650+
tensor, pil_img = _create_data(26, 34, channels=channels, device=device)
651+
batch_tensors = _create_data_batch(16, 18, num_samples=4, channels=channels, device=device)
650652

651653
if dtype is not None:
652654
tensor = F.convert_image_dtype(tensor, dtype)
@@ -798,14 +800,16 @@ def test_equalize(device):
798800
@pytest.mark.parametrize('device', cpu_and_gpu())
799801
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
800802
@pytest.mark.parametrize('config', [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
801-
def test_adjust_contrast(device, dtype, config):
803+
@pytest.mark.parametrize('channels', [1, 3])
804+
def test_adjust_contrast(device, dtype, config, channels):
802805
check_functional_vs_PIL_vs_scripted(
803806
F.adjust_contrast,
804807
F_pil.adjust_contrast,
805808
F_t.adjust_contrast,
806809
config,
807810
device,
808-
dtype
811+
dtype,
812+
channels=channels
809813
)
810814

811815

torchvision/transforms/functional_tensor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,13 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
169169

170170
_assert_image_tensor(img)
171171

172-
_assert_channels(img, [3])
173-
172+
_assert_channels(img, [3, 1])
173+
c = get_image_num_channels(img)
174174
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
175-
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
175+
if c == 3:
176+
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
177+
else:
178+
mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)
176179

177180
return _blend(img, mean, contrast_factor)
178181

0 commit comments

Comments
 (0)