Skip to content

Commit fbf250b

Browse files
committed
BUG: fix nout=2 out handling (divmod)
1 parent f30f4c4 commit fbf250b

File tree

3 files changed

+60
-13
lines changed

3 files changed

+60
-13
lines changed

torch_np/_normalizations.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def normalize_ndarray(arg, name=None):
6868
from ._ndarray import ndarray
6969

7070
if not isinstance(arg, ndarray):
71-
raise TypeError("'out' must be an array")
71+
raise TypeError(f"'{name}' must be an array")
7272
return arg.tensor
7373

7474

@@ -123,7 +123,9 @@ def maybe_copy_to(out, result, promote_scalar_result=False):
123123
out.tensor.copy_(result)
124124
return out
125125
elif isinstance(result, (tuple, list)):
126-
return type(result)(map(copy_to, zip(result, out)))
126+
return type(result)(
127+
maybe_copy_to(o, r, promote_scalar_result) for o, r in zip(out, result)
128+
)
127129
else:
128130
assert False # We should never hit this path
129131

@@ -180,9 +182,6 @@ def wrapped(*args, **kwds):
180182

181183
if "out" in params:
182184
out = sig.bind(*args, **kwds).arguments.get("out")
183-
184-
### if out is not None: breakpoint()
185-
186185
result = maybe_copy_to(out, result, promote_scalar_result)
187186
result = wrap_tensors(result)
188187

torch_np/_ufuncs.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _ufunc_postprocess(result, out, casting):
3131
_binary = [
3232
name
3333
for name in dir(_binary_ufuncs_impl)
34-
if not name.startswith("_") and name not in ["torch", "matmul"]
34+
if not name.startswith("_") and name not in ["torch", "matmul", "divmod"]
3535
]
3636

3737

@@ -106,6 +106,7 @@ def matmul(
106106
#
107107
# nin=2, nout=2
108108
#
109+
@normalizer
109110
def divmod(
110111
x1: ArrayLike,
111112
x2: ArrayLike,
@@ -122,15 +123,19 @@ def divmod(
122123
signature=None,
123124
extobj=None,
124125
):
125-
num_outs = sum(x is None for x in [out1, out2])
126-
if sum_outs == 1:
126+
# make sure we either have no out arrays at all, or there is either
127+
# out1, out2, or out=tuple, but not both
128+
out1t, out2t = out
129+
num_outs = sum(x is not None for x in [out1, out2])
130+
if num_outs == 1:
127131
raise ValueError("both out1 and out2 need to be provided")
128-
if sum_outs != 0 and out != (None, None):
129-
raise ValueError("Either provide out1 and out2, or out.")
130-
if out is not None:
132+
else:
131133
out1, out2 = out
132-
if out1.shape != out2.shape or out1.dtype != out2.dtype:
133-
raise ValueError("out1, out2 must be compatible")
134+
if num_outs == 2:
135+
if out1 is not None or out2 is not None:
136+
raise TypeError(
137+
"cannot specify 'out' as both a positional and keyword argument"
138+
)
134139

135140
tensors = _ufunc_preprocess(
136141
(x1, x2), True, casting, order, dtype, subok, signature, extobj

torch_np/tests/test_basic.py

+43
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import torch_np as w
99
import torch_np._ufuncs as _ufuncs
10+
from torch_np.testing import assert_equal
1011

1112
# These function receive one array_like arg and return one array_like result
1213
one_arg_funcs = [
@@ -445,3 +446,45 @@ def test_typecast(self):
445446
# force the type cast
446447
w.copyto(dst, src, casting="unsafe")
447448
assert (dst == src).all()
449+
450+
451+
class TestDivmod:
452+
def test_divmod_out(self):
453+
x1 = w.arange(8, 15)
454+
x2 = w.arange(4, 11)
455+
456+
out = (w.empty_like(x1), w.empty_like(x1))
457+
458+
quot, rem = w.divmod(x1, x2, out=out)
459+
460+
assert_equal(quot, x1 // x2)
461+
assert_equal(rem, x1 % x2)
462+
463+
out1, out2 = out
464+
assert quot is out[0]
465+
assert rem is out[1]
466+
467+
def test_divmod_out_list(self):
468+
x1 = [4, 5, 6]
469+
x2 = [2, 1, 2]
470+
471+
out = (w.empty_like(x1), w.empty_like(x1))
472+
473+
quot, rem = w.divmod(x1, x2, out=out)
474+
475+
assert quot is out[0]
476+
assert rem is out[1]
477+
478+
def test_divmod_no_out(self):
479+
# check that the out= machinery handles no out at all
480+
x1 = w.array([4, 5, 6])
481+
x2 = w.array([2, 1, 2])
482+
quot, rem = w.divmod(x1, x2)
483+
484+
assert_equal(quot, x1 // x2)
485+
assert_equal(rem, x1 % x2)
486+
487+
def test_divmod_out_both_pos_and_kw(self):
488+
o = w.empty(1)
489+
with assert_raises(TypeError):
490+
w.divmod(1, 2, o, o, out=(o, o))

0 commit comments

Comments
 (0)