@@ -183,28 +183,37 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return autocontrast_image_pil(inpt)
 
 
-def _scale_channel(img_chan: torch.Tensor) -> torch.Tensor:
-    # TODO: we should expect bincount to always be faster than histc, but this
-    # isn't always the case. Once
-    # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
-    # block and only use bincount.
-    if img_chan.is_cuda:
-        hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
-    else:
-        hist = torch.bincount(img_chan.view(-1), minlength=256)
-
-    nonzero_hist = hist[hist != 0]
-    step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
-    if step == 0:
-        return img_chan
-
-    lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
-    # Doing inplace clamp and converting lut to uint8 improves perfs
-    lut.clamp_(0, 255)
-    lut = lut.to(torch.uint8)
-    lut = torch.nn.functional.pad(lut[:-1], [1, 0])
-
-    return lut[img_chan.to(torch.int64)]
+def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
+    # input img shape should be [N, H, W]
+    shape = img.shape
+    # Compute image histogram:
+    flat_img = img.flatten(start_dim=1).to(torch.long)  # -> [N, H * W]
+    hist = flat_img.new_zeros(shape[0], 256)
+    hist.scatter_add_(dim=1, index=flat_img, src=flat_img.new_ones(1).expand_as(flat_img))
+
+    # Compute image cdf
+    chist = hist.cumsum_(dim=1)
+    # Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255
+    # Trick: nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax()
+    idx = chist.argmax(dim=1).sub_(1)
+    # If histogram is degenerate (hist of zero image), index is -1
+    neg_idx_mask = idx < 0
+    idx.clamp_(min=0)
+    step = chist.gather(dim=1, index=idx.unsqueeze(1))
+    step[neg_idx_mask] = 0
+    step.div_(255, rounding_mode="floor")
+
+    # Compute batched Look-up-table:
+    # clamped_step is necessary to avoid an integer division by zero, which raises an error
+    clamped_step = step.clamp(min=1)
+    chist.add_(torch.div(step, 2, rounding_mode="floor")).div_(clamped_step, rounding_mode="floor").clamp_(0, 255)
+    lut = chist.to(torch.uint8)  # [N, 256]
+
+    # Pad lut with zeros
+    zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
+    lut = torch.cat([zeros, lut[:, :-1]], dim=1)
+
+    return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).view_as(img))
 
 
 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
@@ -217,10 +226,8 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
 
     if image.numel() == 0:
         return image
-    elif image.ndim == 2:
-        return _scale_channel(image)
-    else:
-        return torch.stack([_scale_channel(x) for x in image.view(-1, height, width)]).view(image.shape)
+
+    return _equalize_image_tensor_vec(image.view(-1, height, width)).view(image.shape)
 
 
 equalize_image_pil = _FP.equalize
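
For readers skimming the diff: the key change is that the per-image histogram is now built with a single batched `scatter_add_` instead of a Python loop over channels. Below is a minimal standalone sketch (my own illustration, not part of this commit) that checks the batched histogram against a per-plane `torch.bincount` loop:

```python
import torch

# Toy batch of uint8 planes, shape [N, H, W] (hypothetical test data).
imgs = torch.randint(0, 256, (4, 8, 8), dtype=torch.uint8)

flat = imgs.flatten(start_dim=1).to(torch.long)  # [N, H * W]
hist = flat.new_zeros(flat.shape[0], 256)
# One histogram per plane: hist[n, v] counts how often value v occurs in plane n.
hist.scatter_add_(dim=1, index=flat, src=flat.new_ones(1).expand_as(flat))

# Reference computed the old way, one bincount per plane.
reference = torch.stack([torch.bincount(row, minlength=256) for row in flat])
assert torch.equal(hist, reference)
```

The rest of `_equalize_image_tensor_vec` turns each cumulative histogram into a per-plane LUT and applies it with a single `gather`; the final `torch.where` returns the input plane unchanged whenever `step == 0` (constant planes), matching the early `return img_chan` of the old `_scale_channel`.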
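
A small usage sketch of the public entry point. The import path below assumes this file is the prototype transforms functional module; adjust it to wherever `equalize_image_tensor` lives in your checkout:

```python
import torch
from torchvision.prototype.transforms import functional as F  # import path is an assumption

# equalize_image_tensor operates on uint8 images (values 0-255).
batch = torch.randint(0, 256, (2, 3, 32, 32), dtype=torch.uint8)
out = F.equalize_image_tensor(batch)
assert out.shape == batch.shape and out.dtype == torch.uint8
```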