Skip to content

Commit f3d84b3

Browse files
vishwakftwRob Kunkle
authored and
Rob Kunkle
committed
Fix bincount for empty input (pytorch#9757)
Summary: Added tests too. Fixes pytorch#9756 . Pull Request resolved: pytorch#9757 Differential Revision: D8966879 Pulled By: soumith fbshipit-source-id: 9f08a9d5d5d037db16319141d7a227a5efa23869
1 parent 3f94b5d commit f3d84b3

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

test/test_torch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8335,6 +8335,12 @@ def _test_bincount(self, device):
83358335
with self.assertRaisesRegex(RuntimeError, 'same length'):
83368336
torch.bincount(torch.tensor([1, 0], device=device),
83378337
torch.tensor([1., 0.3, 0.5], device=device))
8338+
# 1-d input with no elements and default minlength
8339+
self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)),
8340+
torch.zeros(0, dtype=torch.long, device=device))
8341+
# 1-d input with no elements and specified minlength
8342+
self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10),
8343+
torch.zeros(10, dtype=torch.long, device=device))
83388344

83398345
# test tensor method without weights
83408346
long_counts = torch.tensor(

torch/_torch_docs.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -607,19 +607,22 @@ def parse_kwargs(desc):
607607
Count the frequency of each value in an array of non-negative ints.
608608
609609
The number of bins (size 1) is one larger than the largest value in
610-
:attr:`input`. If :attr:`minlength` is specified, the number of bins is at least
611-
:attr:`minlength`. If ``n`` is the value at position ``i``,
610+
:attr:`input` unless :attr:`input` is empty, in which case the result is a
611+
tensor of size 0. If :attr:`minlength` is specified, the number of bins is at least
612+
:attr:`minlength` and if :attr:`input` is empty, then the result is tensor of size
613+
:attr:`minlength` filled with zeros. If ``n`` is the value at position ``i``,
612614
:math:`out[n] += weights[i]` if :attr:`weights` is specified else
613615
:math:`out[n] += 1`.
614616
615617
Arguments:
616618
input (Tensor): 1-d int tensor
617619
weights (Tensor): optional, weight for each value in the input tensor.
618620
Should be of same size as input tensor.
619-
minlength (int): optional, min number of bins. Should be non-negative.
621+
minlength (int): optional, minimum number of bins. Should be non-negative.
620622
621623
Shape:
622-
output (Tensor): ``Size([max(input) + 1])``
624+
output (Tensor): ``Size([max(input) + 1])`` if :attr:`input` is non-empty, else
625+
``Size(0)``
623626
624627
Example::
625628

0 commit comments

Comments
 (0)