diff --git a/torch_np/_normalizations.py b/torch_np/_normalizations.py index d8f9d376..d2757440 100644 --- a/torch_np/_normalizations.py +++ b/torch_np/_normalizations.py @@ -80,7 +80,7 @@ def normalize_ndarray(arg, name=None): from ._ndarray import ndarray if not isinstance(arg, ndarray): - raise TypeError("'out' must be an array") + raise TypeError(f"'{name}' must be an array") return arg.tensor @@ -135,8 +135,9 @@ def maybe_copy_to(out, result, promote_scalar_result=False): out.tensor.copy_(result) return out elif isinstance(result, (tuple, list)): - # FIXME: this is broken (there is no copy_to) - return type(result)(map(copy_to, zip(result, out))) + return type(result)( + maybe_copy_to(o, r, promote_scalar_result) for o, r in zip(out, result) + ) else: assert False # We should never hit this path diff --git a/torch_np/_ufuncs.py b/torch_np/_ufuncs.py index 57144fae..8695bb8c 100644 --- a/torch_np/_ufuncs.py +++ b/torch_np/_ufuncs.py @@ -31,7 +31,7 @@ def _ufunc_postprocess(result, out, casting): _binary = [ name for name in dir(_binary_ufuncs_impl) - if not name.startswith("_") and name not in ["torch", "matmul"] + if not name.startswith("_") and name not in ["torch", "matmul", "divmod"] ] @@ -106,6 +106,7 @@ def matmul( # # nin=2, nout=2 # +@normalizer def divmod( x1: ArrayLike, x2: ArrayLike, @@ -122,15 +123,19 @@ def divmod( signature=None, extobj=None, ): - num_outs = sum(x is None for x in [out1, out2]) - if sum_outs == 1: + # make sure we either have no out arrays at all, or there is either + # out1, out2, or out=tuple, but not both + num_outs = sum(x is not None for x in [out1, out2]) + if num_outs == 1: raise ValueError("both out1 and out2 need to be provided") - if sum_outs != 0 and out != (None, None): - raise ValueError("Either provide out1 and out2, or out.") - if out is not None: + elif num_outs == 2: + o1, o2 = out + if o1 is not None or o2 is not None: + raise TypeError( + "cannot specify 'out' as both a positional and keyword argument" + ) + else: out1, out2 = out - if out1.shape != out2.shape or out1.dtype != out2.dtype: - raise ValueError("out1, out2 must be compatible") tensors = _ufunc_preprocess( (x1, x2), True, casting, order, dtype, subok, signature, extobj diff --git a/torch_np/tests/test_basic.py b/torch_np/tests/test_basic.py index 7cb8d3bc..e77a8a07 100644 --- a/torch_np/tests/test_basic.py +++ b/torch_np/tests/test_basic.py @@ -7,6 +7,7 @@ import torch_np as w import torch_np._ufuncs as _ufuncs +from torch_np.testing import assert_equal # These function receive one array_like arg and return one array_like result one_arg_funcs = [ @@ -445,3 +446,57 @@ def test_typecast(self): # force the type cast w.copyto(dst, src, casting="unsafe") assert (dst == src).all() + + +class TestDivmod: + def test_divmod_out(self): + x1 = w.arange(8, 15) + x2 = w.arange(4, 11) + + out = (w.empty_like(x1), w.empty_like(x1)) + + quot, rem = w.divmod(x1, x2, out=out) + + assert_equal(quot, x1 // x2) + assert_equal(rem, x1 % x2) + + out1, out2 = out + assert quot is out[0] + assert rem is out[1] + + def test_divmod_out_list(self): + x1 = [4, 5, 6] + x2 = [2, 1, 2] + + out = (w.empty_like(x1), w.empty_like(x1)) + + quot, rem = w.divmod(x1, x2, out=out) + + assert quot is out[0] + assert rem is out[1] + + @pytest.mark.xfail(reason="out1, out2 not implemented") + def test_divmod_pos_only(self): + x1 = [4, 5, 6] + x2 = [2, 1, 2] + + out1, out2 = w.empty_like(x1), w.empty_like(x1) + + quot, rem = w.divmod(x1, x2, out1, out2) + + assert quot is out1 + assert rem is out2 + + def test_divmod_no_out(self): + # check that the out= machinery handles no out at all + x1 = w.array([4, 5, 6]) + x2 = w.array([2, 1, 2]) + quot, rem = w.divmod(x1, x2) + + assert_equal(quot, x1 // x2) + assert_equal(rem, x1 % x2) + + def test_divmod_out_both_pos_and_kw(self): + o = w.empty(1) + with assert_raises(TypeError): + w.divmod(1, 2, o, o, out=(o, o))