Skip to content

Commit d72e906

Browse files
authored
[prototype] Speed up adjust_hue_image_tensor (#6938)
* Performance optimization on adjust_hue_image_tensor
* Handle ints
* In-place logical ops
* Remove unnecessary casting
* Fix linter
1 parent 70edf96 commit d72e906

File tree

3 files changed, with 14 additions and 11 deletions.

torchvision/prototype/transforms/functional/_color.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -208,11 +208,10 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
208208

209209
mask_maxc_neq_r = maxc != r
210210
mask_maxc_eq_g = maxc == g
211-
mask_maxc_neq_g = ~mask_maxc_eq_g
212211

213-
hr = (bc - gc).mul_(~mask_maxc_neq_r)
214-
hg = (2.0 + rc).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r)
215-
hb = (4.0 + gc).sub_(rc).mul_(mask_maxc_neq_g & mask_maxc_neq_r)
212+
hg = rc.add(2.0).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r)
213+
hr = bc.sub_(gc).mul_(~mask_maxc_neq_r)
214+
hb = gc.add_(4.0).sub_(rc).mul_(mask_maxc_neq_r.logical_and_(mask_maxc_eq_g.logical_not_()))
216215

217216
h = hr.add_(hg).add_(hb)
218217
h = h.mul_(1.0 / 6.0).add_(1.0).fmod_(1.0)
@@ -221,14 +220,16 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
221220

222221
def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
223222
h, s, v = img.unbind(dim=-3)
224-
h6 = h * 6
223+
h6 = h.mul(6)
225224
i = torch.floor(h6)
226-
f = h6 - i
225+
f = h6.sub_(i)
227226
i = i.to(dtype=torch.int32)
228227

229-
p = (v * (1.0 - s)).clamp_(0.0, 1.0)
230-
q = (v * (1.0 - s * f)).clamp_(0.0, 1.0)
231-
t = (v * (1.0 - s * (1.0 - f))).clamp_(0.0, 1.0)
228+
sxf = s * f
229+
one_minus_s = 1.0 - s
230+
q = (1.0 - sxf).mul_(v).clamp_(0.0, 1.0)
231+
t = sxf.add_(one_minus_s).mul_(v).clamp_(0.0, 1.0)
232+
p = one_minus_s.mul_(v).clamp_(0.0, 1.0)
232233
i.remainder_(6)
233234

234235
mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
@@ -238,7 +239,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
238239
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
239240
a4 = torch.stack((a1, a2, a3), dim=-4)
240241

241-
return (a4.mul_(mask.to(dtype=img.dtype).unsqueeze(dim=-4))).sum(dim=-3)
242+
return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3)
242243

243244

244245
def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -164,6 +164,7 @@ def convert_format_bounding_box(
164164
if new_format == old_format:
165165
return bounding_box
166166

167+
# TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance
167168
if old_format == BoundingBoxFormat.XYWH:
168169
bounding_box = _xywh_to_xyxy(bounding_box, inplace)
169170
elif old_format == BoundingBoxFormat.CXCYWH:

torchvision/prototype/transforms/functional/_type_conversion.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
import unittest.mock
21
from typing import Any, Dict, Tuple, Union
32

43
import numpy as np
@@ -20,6 +19,8 @@ def decode_image_with_pil(encoded_image: torch.Tensor) -> features.Image:
2019

2120
@torch.jit.unused
2221
def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
22+
import unittest.mock
23+
2324
with unittest.mock.patch("torchvision.io.video.os.path.exists", return_value=True):
2425
return read_video(ReadOnlyTensorBuffer(encoded_video)) # type: ignore[arg-type]
2526

0 commit comments

Comments
 (0)