diff --git a/array_api_compat/paddle/_aliases.py b/array_api_compat/paddle/_aliases.py index 88f71e7d..df45aef5 100644 --- a/array_api_compat/paddle/_aliases.py +++ b/array_api_compat/paddle/_aliases.py @@ -462,6 +462,20 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs): out = paddle.unsqueeze(out, a) return out +_NP_2_PADDLE_DTYPE = { + "BOOL": 'bool', + "UINT8": 'uint8', + "INT8": 'int8', + "INT16": 'int16', + "INT32": 'int32', + "INT64": 'int64', + "FLOAT16": 'float16', + "BFLOAT16": 'bfloat16', + "FLOAT32": 'float32', + "FLOAT64": 'float64', + "COMPLEX128": 'complex128', + "COMPLEX64": 'complex64', +} def prod( x: array, @@ -476,7 +490,36 @@ def prod( x = paddle.to_tensor(x) ndim = x.ndim - # below because it still needs to upcast. + # fix reducing on the zero dimension + if x.numel() == 0: + if dtype is not None: + output_dtype = _NP_2_PADDLE_DTYPE[dtype.name] + else: + if x.dtype == paddle.bool: + output_dtype = paddle.int64 + else: + output_dtype = x.dtype + + if axis is None: + return paddle.to_tensor(1, dtype=output_dtype) + + if keepdims: + output_shape = list(x.shape) + if isinstance(axis, int): + axis = (axis,) + for ax in axis: + output_shape[ax] = 1 + else: + output_shape = [dim for i, dim in enumerate(x.shape) if i not in (axis if isinstance(axis, tuple) else [axis])] + if not output_shape: + return paddle.to_tensor(1, dtype=output_dtype) + + return paddle.ones(output_shape, dtype=output_dtype) + + + if dtype is not None: + dtype = _NP_2_PADDLE_DTYPE[dtype.name] + if axis == (): if dtype is None: # We can't upcast uint8 according to the spec because there is no @@ -492,13 +535,17 @@ def prod( return _reduce_multiple_axes( paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs ) + + if axis is None: # paddle doesn't support keepdims with axis=None + if dtype is None and x.dtype == paddle.int32: + dtype = 'int64' res = paddle.prod(x, dtype=dtype, **kwargs) res = _axis_none_keepdims(res, ndim, keepdims) return res - - return paddle.prod(x, axis, dtype=dtype, keepdim=keepdims, **kwargs) + + return paddle.prod(x, axis=axis, keepdims=keepdims, dtype=dtype, **kwargs) def sum( @@ -747,7 +794,17 @@ def roll( def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") - return paddle.nonzero(x, as_tuple=True, **kwargs) + + if paddle.is_floating_point(x) or paddle.is_complex(x) : + # Use paddle.isclose() to determine which elements are + # "close enough" to zero. + zero_tensor = paddle.zeros(shape=x.shape ,dtype=x.dtype) + is_zero_mask = paddle.isclose(x, zero_tensor) + is_nonzero_mask = paddle.logical_not(is_zero_mask) + return paddle.nonzero(is_nonzero_mask, as_tuple=True, **kwargs) + + else: + return paddle.nonzero(x, as_tuple=True, **kwargs) def where(condition: array, x1: array, x2: array, /) -> array: @@ -832,7 +889,7 @@ def eye( if n_cols is None: n_cols = n_rows z = paddle.zeros([n_rows, n_cols], dtype=dtype, **kwargs).to(device) - if abs(k) <= n_rows + n_cols: + if n_rows > 0 and n_cols > 0 and abs(k) <= n_rows + n_cols: z.diagonal(k).fill_(1) return z @@ -867,7 +924,11 @@ def full( ) -> array: if isinstance(shape, int): shape = (shape,) - + if dtype is None : + if isinstance(fill_value, (bool)): + dtype = "bool" + elif isinstance(fill_value, int): + dtype = 'int64' return paddle.full(shape, fill_value, dtype=dtype, **kwargs).to(device) @@ -914,6 +975,8 @@ def triu(x: array, /, *, k: int = 0) -> array: def expand_dims(x: array, /, *, axis: int = 0) -> array: + if axis < -x.ndim - 1 or axis > x.ndim: + raise IndexError(f"Axis {axis} is out of bounds for array of dimension { x.ndim}") return paddle.unsqueeze(x, axis) @@ -973,6 +1036,22 @@ def unique_values(x: array) -> array: def matmul(x1: array, x2: array, /, **kwargs) -> array: # paddle.matmul doesn't type promote (but differently from _fix_promotion) + d1 = x1.ndim + d2 = x2.ndim + + if d1 == 0 or d2 == 0: + raise ValueError("matmul does not support 0-D (scalar) inputs.") + + k1 = x1.shape[-1] + + if d2 == 1: + k2 = x2.shape[0] + else: + k2 = x2.shape[-2] + + if k1 != k2: + raise ValueError(f"Shapes {x1.shape} and {x2.shape} are not aligned for matmul: " + f"{k1} (dim -1 of x1) != {k2} (dim -2 of x2)") x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return paddle.matmul(x1, x2, **kwargs) @@ -988,7 +1067,36 @@ def meshgrid(*arrays: array, indexing: str = "xy") -> List[array]: def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + shape1 = x1.shape + shape2 = x2.shape + rank1 = len(shape1) + rank2 = len(shape2) + if rank1 == 0 or rank2 == 0: + raise ValueError( + f"Vector dot product requires non-scalar inputs (rank > 0). " + f"Got ranks {rank1} and {rank2} for shapes {shape1} and {shape2}." + ) + try: + norm_axis1 = axis if axis >= 0 else rank1 + axis + if not (0 <= norm_axis1 < rank1): + raise IndexError # Axis out of bounds for x1 + norm_axis2 = axis if axis >= 0 else rank2 + axis + if not (0 <= norm_axis2 < rank2): + raise IndexError # Axis out of bounds for x2 + size1 = shape1[norm_axis1] + size2 = shape2[norm_axis2] + except IndexError: + raise ValueError( + f"Axis {axis} is out of bounds for input shapes {shape1} (rank {rank1}) " + f"and/or {shape2} (rank {rank2})." + ) + + if size1 != size2: + raise ValueError( + f"Inputs must have the same dimension size along the reduction axis ({axis}). " + f"Got shapes {shape1} and {shape2}, with sizes {size1} and {size2} " + f"along the normalized axis {norm_axis1} and {norm_axis2} respectively." + ) return paddle.linalg.vecdot(x1, x2, axis=axis) @@ -1063,21 +1171,39 @@ def is_complex(dtype): def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array: - if axis is None: + _axis = axis + if _axis is None: if x.ndim != 1: - raise ValueError("axis must be specified when ndim > 1") - axis = 0 - return paddle.index_select(x, axis, indices, **kwargs) + raise ValueError("axis must be specified when x.ndim > 1") + _axis = 0 + elif not isinstance(_axis, int): + raise TypeError(f"axis must be an integer, but received {type(_axis)}") + + if not (-x.ndim <= _axis < x.ndim): + raise IndexError(f"axis {_axis} is out of bounds for tensor of dimension {x.ndim}") + + if isinstance(indices, paddle.Tensor): + indices_tensor = indices + elif isinstance(indices, int): + indices_tensor = paddle.to_tensor([indices], dtype='int64') + else: + # Otherwise (e.g., list, tuple), convert directly + indices_tensor = paddle.to_tensor(indices, dtype='int64') + # Ensure indices is a 1D tensor + if indices_tensor.ndim == 0: + indices_tensor = indices_tensor.reshape([1]) + elif indices_tensor.ndim > 1: + raise ValueError(f"indices must be a 1D tensor, but received a {indices_tensor.ndim}D tensor") + + return paddle.index_select(x, index=indices_tensor, axis=_axis) def sign(x: array, /) -> array: # paddle sign() does not support complex numbers and does not propagate # nans. See https://github.com/data-apis/array-api-compat/issues/136 - if paddle.is_complex(x): - out = x / paddle.abs(x) - # sign(0) = 0 but the above formula would give nan - out[x == 0 + 0j] = 0 + 0j - return out + if paddle.is_complex(x) and x.ndim == 0 and x.item() == 0j: + # Handle 0-D complex zero explicitly + return paddle.zeros_like(x, dtype=x.dtype) else: out = paddle.sign(x) if paddle.is_floating_point(x): diff --git a/array_api_compat/paddle/linalg.py b/array_api_compat/paddle/linalg.py index 7dd1a266..aa091c81 100644 --- a/array_api_compat/paddle/linalg.py +++ b/array_api_compat/paddle/linalg.py @@ -84,6 +84,8 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> arr # Use our wrapped sum to make sure it does upcasting correctly return sum(paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1), axis=-1, dtype=dtype) +def diagonal(x: ndarray, / , *, offset: int = 0, **kwargs) -> ndarray: + return paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) def vector_norm( x: array, @@ -123,24 +125,52 @@ def matrix_norm( keepdims: bool = False, ord: Optional[Union[int, float, Literal["fro", "nuc"]]] = "fro", ) -> array: - return paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims) - + res = paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims) + if res.dtype == paddle.complex64 : + res = paddle.cast(res, "float32") + if res.dtype == paddle.complex128: + res = paddle.cast(res, "float64") + return res def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array: if rtol is None: return paddle.linalg.pinv(x) + + # change rtol shape + if isinstance(rtol, (int, float)): + rtol = paddle.to_tensor(rtol, dtype=x.dtype) + + # broadcast rtol to [..., 1] + if rtol.ndim > 0: + rtol = rtol.unsqueeze(-1) return paddle.linalg.pinv(x, rcond=rtol) def slogdet(x: array): - det = paddle.linalg.det(x) - sign = paddle.sign(det) - log_det = paddle.log(det) + return tuple_to_namedtuple(paddle.linalg.slogdet(x), ["sign", "logabsdet"]) + +def tuple_to_namedtuple(data, fields): + nt_class = namedtuple('DynamicNameTuple', fields) + return nt_class(*data) + +def eigh(x: array): + return tuple_to_namedtuple(paddle.linalg.eigh(x), ['eigenvalues', 'eigenvectors']) + +def qr(x: array, mode: Optional[str] = None) -> array: + if mode is None: + return tuple_to_namedtuple(paddle.linalg.qr(x), ['Q', 'R']) + + return tuple_to_namedtuple(paddle.linalg.qr(x, mode), ['Q', 'R']) + - slotdet = namedtuple("slotdet", ["sign", "logabsdet"]) - return slotdet(sign, log_det) +def svd(x: array, full_matrices: Optional[bool]= None) -> array: + if full_matrices is None : + return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices=True), ['U', 'S', 'Vh']) + return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices), ['U', 'S', 'Vh']) +def svdvals(x: array) -> array: + return paddle.linalg.svd(x)[1] __all__ = linalg_all + [ "outer", @@ -154,6 +184,9 @@ def slogdet(x: array): "trace", "vector_norm", "slogdet", + "eigh", + "diagonal", + "svdvals" ] _all_ignore = ["paddle_linalg", "sum"] diff --git a/paddle-xfails.txt b/paddle-xfails.txt index 6998f374..b92267c2 100644 --- a/paddle-xfails.txt +++ b/paddle-xfails.txt @@ -106,3 +106,62 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] array_api_tests/test_searching_functions.py::test_where + +array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_pow[pow(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_hypot +array_api_tests/test_operators_and_elementwise_functions.py::test_copysign +array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] +array_api_tests/test_linalg.py::test_outer +array_api_tests/test_linalg.py::test_vecdot +array_api_tests/test_operators_and_elementwise_functions.py::test_clip +array_api_tests/test_manipulation_functions.py::test_stack +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide + +# do not pass +array_api_tests/test_has_names[array_attribute-device] +array_api_tests/test_signatures.py::test_func_signature[meshgrid] + +array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)] +array_api_tests/test_indexing_functions.py::test_take +array_api_tests/test_linalg.py::test_linalg_vecdot +array_api_tests/test_creation_functions.py::test_asarray_arrays + +array_api_tests/test_linalg.py::test_qr + +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor + +# test exceeds the deadline of 800ms +array_api_tests/test_linalg.py::test_pinv +array_api_tests/test_linalg.py::test_det + +# only supports access to dimension 0 to 9, but received dimension is 10. +array_api_tests/test_linalg.py::test_tensordot +array_api_tests/test_linalg.py::test_linalg_tensordot \ No newline at end of file