Skip to content

Commit b16dec1

Browse files
authored
[proto] Performance improvements for equalize op (#6757)
* [proto] Performance improvements for equalize op * Added tests
1 parent 54a2d4e commit b16dec1

File tree

2 files changed

+44
-26
lines changed

2 files changed

+44
-26
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,3 +1037,14 @@ def test_to_image_pil(inpt, mode):
10371037
assert isinstance(output, PIL.Image.Image)
10381038

10391039
assert np.asarray(inpt).sum() == np.asarray(output).sum()
1040+
1041+
1042+
def test_equalize_image_tensor_edge_cases():
1043+
inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
1044+
output = F.equalize_image_tensor(inpt)
1045+
torch.testing.assert_close(inpt, output)
1046+
1047+
inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
1048+
inpt[..., 100:, 100:] = 1
1049+
output = F.equalize_image_tensor(inpt)
1050+
assert output.unique().tolist() == [0, 255]

torchvision/prototype/transforms/functional/_color.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -183,28 +183,37 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
183183
return autocontrast_image_pil(inpt)
184184

185185

186-
def _scale_channel(img_chan: torch.Tensor) -> torch.Tensor:
187-
# TODO: we should expect bincount to always be faster than histc, but this
188-
# isn't always the case. Once
189-
# https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
190-
# block and only use bincount.
191-
if img_chan.is_cuda:
192-
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
193-
else:
194-
hist = torch.bincount(img_chan.view(-1), minlength=256)
195-
196-
nonzero_hist = hist[hist != 0]
197-
step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
198-
if step == 0:
199-
return img_chan
200-
201-
lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
202-
# Doing inplace clamp and converting lut to uint8 improves perfs
203-
lut.clamp_(0, 255)
204-
lut = lut.to(torch.uint8)
205-
lut = torch.nn.functional.pad(lut[:-1], [1, 0])
206-
207-
return lut[img_chan.to(torch.int64)]
186+
def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
187+
# input img shape should be [N, H, W]
188+
shape = img.shape
189+
# Compute image histogram:
190+
flat_img = img.flatten(start_dim=1).to(torch.long) # -> [N, H * W]
191+
hist = flat_img.new_zeros(shape[0], 256)
192+
hist.scatter_add_(dim=1, index=flat_img, src=flat_img.new_ones(1).expand_as(flat_img))
193+
194+
# Compute image cdf
195+
chist = hist.cumsum_(dim=1)
196+
# Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255
197+
# Trick: nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax()
198+
idx = chist.argmax(dim=1).sub_(1)
199+
# If histogram is degenerate (hist of zero image), index is -1
200+
neg_idx_mask = idx < 0
201+
idx.clamp_(min=0)
202+
step = chist.gather(dim=1, index=idx.unsqueeze(1))
203+
step[neg_idx_mask] = 0
204+
step.div_(255, rounding_mode="floor")
205+
206+
# Compute batched Look-up-table:
207+
# Necessary to avoid an integer division by zero, which raises
208+
clamped_step = step.clamp(min=1)
209+
chist.add_(torch.div(step, 2, rounding_mode="floor")).div_(clamped_step, rounding_mode="floor").clamp_(0, 255)
210+
lut = chist.to(torch.uint8) # [N, 256]
211+
212+
# Pad lut with zeros
213+
zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
214+
lut = torch.cat([zeros, lut[:, :-1]], dim=1)
215+
216+
return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).view_as(img))
208217

209218

210219
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
@@ -217,10 +226,8 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
217226

218227
if image.numel() == 0:
219228
return image
220-
elif image.ndim == 2:
221-
return _scale_channel(image)
222-
else:
223-
return torch.stack([_scale_channel(x) for x in image.view(-1, height, width)]).view(image.shape)
229+
230+
return _equalize_image_tensor_vec(image.view(-1, height, width)).view(image.shape)
224231

225232

226233
equalize_image_pil = _FP.equalize

0 commit comments

Comments
 (0)