Skip to content

BUG: fix nout=2 ufuncs #109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions torch_np/_normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def normalize_ndarray(arg, name=None):
from ._ndarray import ndarray

if not isinstance(arg, ndarray):
raise TypeError("'out' must be an array")
raise TypeError(f"'{name}' must be an array")
return arg.tensor


Expand Down Expand Up @@ -135,8 +135,9 @@ def maybe_copy_to(out, result, promote_scalar_result=False):
out.tensor.copy_(result)
return out
elif isinstance(result, (tuple, list)):
# FIXME: this is broken (there is no copy_to)
return type(result)(map(copy_to, zip(result, out)))
return type(result)(
maybe_copy_to(o, r, promote_scalar_result) for o, r in zip(out, result)
)
else:
assert False # We should never hit this path

Expand Down
21 changes: 13 additions & 8 deletions torch_np/_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _ufunc_postprocess(result, out, casting):
_binary = [
name
for name in dir(_binary_ufuncs_impl)
if not name.startswith("_") and name not in ["torch", "matmul"]
if not name.startswith("_") and name not in ["torch", "matmul", "divmod"]
]


Expand Down Expand Up @@ -106,6 +106,7 @@ def matmul(
#
# nin=2, nout=2
#
@normalizer
def divmod(
x1: ArrayLike,
x2: ArrayLike,
Expand All @@ -122,15 +123,19 @@ def divmod(
signature=None,
extobj=None,
):
num_outs = sum(x is None for x in [out1, out2])
if sum_outs == 1:
# make sure we either have no out arrays at all, or there is either
# out1, out2, or out=tuple, but not both
num_outs = sum(x is not None for x in [out1, out2])
if num_outs == 1:
raise ValueError("both out1 and out2 need to be provided")
if sum_outs != 0 and out != (None, None):
raise ValueError("Either provide out1 and out2, or out.")
if out is not None:
elif num_outs == 2:
o1, o2 = out
if o1 is not None or o2 is not None:
raise TypeError(
"cannot specify 'out' as both a positional and keyword argument"
)
else:
out1, out2 = out
if out1.shape != out2.shape or out1.dtype != out2.dtype:
raise ValueError("out1, out2 must be compatible")

tensors = _ufunc_preprocess(
(x1, x2), True, casting, order, dtype, subok, signature, extobj
Expand Down
55 changes: 55 additions & 0 deletions torch_np/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch_np as w
import torch_np._ufuncs as _ufuncs
from torch_np.testing import assert_equal

# These function receive one array_like arg and return one array_like result
one_arg_funcs = [
Expand Down Expand Up @@ -445,3 +446,57 @@ def test_typecast(self):
# force the type cast
w.copyto(dst, src, casting="unsafe")
assert (dst == src).all()


class TestDivmod:
def test_divmod_out(self):
x1 = w.arange(8, 15)
x2 = w.arange(4, 11)

out = (w.empty_like(x1), w.empty_like(x1))

quot, rem = w.divmod(x1, x2, out=out)

assert_equal(quot, x1 // x2)
assert_equal(rem, x1 % x2)

out1, out2 = out
assert quot is out[0]
assert rem is out[1]

def test_divmod_out_list(self):
x1 = [4, 5, 6]
x2 = [2, 1, 2]

out = (w.empty_like(x1), w.empty_like(x1))

quot, rem = w.divmod(x1, x2, out=out)

assert quot is out[0]
assert rem is out[1]

@pytest.mark.xfail(reason="out1, out2 not implemented")
def test_divmod_pos_only(self):
x1 = [4, 5, 6]
x2 = [2, 1, 2]

out1, out2 = w.empty_like(x1), w.empty_like(x1)

quot, rem = w.divmod(x1, x2, out1, out2)

assert quot is out1
assert rem is out2

def test_divmod_no_out(self):
# check that the out= machinery handles no out at all
x1 = w.array([4, 5, 6])
x2 = w.array([2, 1, 2])
quot, rem = w.divmod(x1, x2)

assert_equal(quot, x1 // x2)
assert_equal(rem, x1 % x2)

def test_divmod_out_both_pos_and_kw(self):
o = w.empty(1)
with assert_raises(TypeError):
w.divmod(1, 2, o, o, out=(o, o))