
Commit 04b65df

jotsif authored and facebook-github-bot committed
Issue 14984: Remove divide by zero error in index_put_ (#14986)
Summary: No check for a zero-size index tensor was done in the accumulate=True (serial) case in the new TensorIterator code since #13420. Fixes #14984.

Pull Request resolved: #14986
Differential Revision: D13417861
Pulled By: colesbury
fbshipit-source-id: e6ed1af8f708b53a35803fc157ed1f043169ec89
1 parent 109c8d2 commit 04b65df
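
For context, the failure mode can be reproduced with a short Python snippet that mirrors the regression test added below (this is a sketch of the reported bug, not part of the commit): with an all-zero byte mask, index_put_ with accumulate=True selects zero elements, and the serial TensorIterator path divided by zero instead of performing a no-op.

import torch

# Repro sketch for #14984 (mirrors test_byte_mask_accumulate below).
# Before this commit, the index_put_ call could crash with an integer
# division by zero in TensorIterator; after it, the call is a no-op.
mask = torch.zeros(10, dtype=torch.uint8)  # byte mask selecting no elements
y = torch.ones(10, 10)
y.index_put_((mask,), y[mask], accumulate=True)  # accumulate=True takes the serial path
assert y.equal(torch.ones(10, 10))  # y is unchanged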

2 files changed: 13 additions & 2 deletions

aten/src/ATen/native/TensorIterator.cpp

Lines changed: 7 additions & 2 deletions
@@ -379,6 +379,9 @@ void TensorIterator::serial_for_each(const loop_t& loop, Range range) const {
 }
 
 void TensorIterator::serial_for_each(const loop2d_t& loop, Range range) const {
+  if (range.size() == 0) {
+    return;
+  }
   auto strides = get_strides();
   while (strides.size() < 2 * ntensors()) {
     strides.push_back(0);
@@ -682,8 +685,10 @@ DimCounter::DimCounter(IntList shape, Range range)
   int64_t ndim = values.size();
   for (int dim = 0; dim < ndim; dim++) {
     int64_t size = shape[dim];
-    values[dim] = linear_offset % size;
-    linear_offset /= size;
+    if (size > 0) {
+      values[dim] = linear_offset % size;
+      linear_offset /= size;
+    }
   }
   AT_ASSERT(linear_offset == 0);
 }
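
The first hunk makes serial_for_each return early when the iteration range is empty; the second guards DimCounter, which decomposes a linear offset into per-dimension counters. When a dimension has size 0, linear_offset % size and linear_offset /= size are integer divisions by zero (undefined behavior, typically a SIGFPE). A rough Python analogue of the decomposition, for illustration only (a hypothetical helper, not the PyTorch source):

# Hypothetical Python analogue of DimCounter's constructor logic,
# for illustration only (the real code is the C++ above).
def dim_counters(shape, linear_offset):
    values = [0] * len(shape)
    for dim, size in enumerate(shape):
        if size > 0:  # the fix: size == 0 would make '%' and '//=' divide by zero
            values[dim] = linear_offset % size
            linear_offset //= size
    assert linear_offset == 0  # mirrors the AT_ASSERT in the C++ code
    return values

print(dim_counters([0, 4], 0))  # [0, 0]; a zero-size dimension no longer crashes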

test/test_indexing.py

Lines changed: 6 additions & 0 deletions
@@ -45,6 +45,12 @@ def test_byte_mask(self):
         v = torch.tensor([1.])
         self.assertEqual(v[v == 0], torch.tensor([]))
 
+    def test_byte_mask_accumulate(self):
+        mask = torch.zeros(size=(10, ), dtype=torch.uint8)
+        y = torch.ones(size=(10, 10))
+        y.index_put_((mask, ), y[mask], accumulate=True)
+        self.assertEqual(y, torch.ones(size=(10, 10)))
+
     def test_multiple_byte_mask(self):
         v = torch.randn(5, 7, 3)
         # note: these broadcast together and are transposed to the first dim
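
The regression test drives exactly this path: an all-zero byte mask selects no elements, so index_put_ with accumulate=True must leave y untouched, where before the fix this call raised the divide-by-zero error instead of being a no-op.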
