Skip to content

Commit 77e4187

Browse files
authored
Speed up equalize transform: use bincount instead of histc (#3493)
* use bincount instead of hist * only use bincount when on CPU * Added equality test for CPU vs cuda * Fix flake8 and tests * tuple instead of int for size
1 parent 414427d commit 77e4187

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

test/test_functional_tensor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,18 @@ class CUDATester(Tester):
977977
def setUp(self):
978978
self.device = "cuda"
979979

980+
def test_scale_channel(self):
981+
"""Make sure that _scale_channel gives the same results on CPU and GPU as
982+
histc or bincount are used depending on the device.
983+
"""
984+
# TODO: when # https://github.com/pytorch/pytorch/issues/53194 is fixed,
985+
# only use bincount and remove that test.
986+
size = (1_000,)
987+
img_chan = torch.randint(0, 256, size=size).to('cpu')
988+
scaled_cpu = F_t._scale_channel(img_chan)
989+
scaled_cuda = F_t._scale_channel(img_chan.to('cuda'))
990+
self.assertTrue(scaled_cpu.equal(scaled_cuda.to('cpu')))
991+
980992

981993
if __name__ == '__main__':
982994
unittest.main()

torchvision/transforms/functional_tensor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -902,7 +902,14 @@ def autocontrast(img: Tensor) -> Tensor:
902902

903903

904904
def _scale_channel(img_chan):
905-
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
905+
# TODO: we should expect bincount to always be faster than histc, but this
906+
# isn't always the case. Once
907+
# https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
908+
# block and only use bincount.
909+
if img_chan.is_cuda:
910+
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
911+
else:
912+
hist = torch.bincount(img_chan.view(-1), minlength=256)
906913

907914
nonzero_hist = hist[hist != 0]
908915
step = nonzero_hist[:-1].sum() // 255

0 commit comments

Comments
 (0)