Skip to content

Support paddle #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 142 additions & 16 deletions array_api_compat/paddle/_aliases.py
Original file line number Diff line number Diff line change
@@ -462,6 +462,20 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
out = paddle.unsqueeze(out, a)
return out

_NP_2_PADDLE_DTYPE = {
"BOOL": 'bool',
"UINT8": 'uint8',
"INT8": 'int8',
"INT16": 'int16',
"INT32": 'int32',
"INT64": 'int64',
"FLOAT16": 'float16',
"BFLOAT16": 'bfloat16',
"FLOAT32": 'float32',
"FLOAT64": 'float64',
"COMPLEX128": 'complex128',
"COMPLEX64": 'complex64',
}

def prod(
x: array,
@@ -476,7 +490,36 @@ def prod(
x = paddle.to_tensor(x)
ndim = x.ndim

# below because it still needs to upcast.
# fix reducing on the zero dimension
if x.numel() == 0:
if dtype is not None:
output_dtype = _NP_2_PADDLE_DTYPE[dtype.name]
else:
if x.dtype == paddle.bool:
output_dtype = paddle.int64
else:
output_dtype = x.dtype

if axis is None:
return paddle.to_tensor(1, dtype=output_dtype)

if keepdims:
output_shape = list(x.shape)
if isinstance(axis, int):
axis = (axis,)
for ax in axis:
output_shape[ax] = 1
else:
output_shape = [dim for i, dim in enumerate(x.shape) if i not in (axis if isinstance(axis, tuple) else [axis])]
if not output_shape:
return paddle.to_tensor(1, dtype=output_dtype)

return paddle.ones(output_shape, dtype=output_dtype)


if dtype is not None:
dtype = _NP_2_PADDLE_DTYPE[dtype.name]

if axis == ():
if dtype is None:
# We can't upcast uint8 according to the spec because there is no
@@ -492,13 +535,17 @@ def prod(
return _reduce_multiple_axes(
paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
)


if axis is None:
# paddle doesn't support keepdims with axis=None
if dtype is None and x.dtype == paddle.int32:
dtype = 'int64'
res = paddle.prod(x, dtype=dtype, **kwargs)
res = _axis_none_keepdims(res, ndim, keepdims)
return res

return paddle.prod(x, axis, dtype=dtype, keepdim=keepdims, **kwargs)
return paddle.prod(x, axis=axis, keepdims=keepdims, dtype=dtype, **kwargs)


def sum(
@@ -747,7 +794,17 @@ def roll(
def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
if x.ndim == 0:
raise ValueError("nonzero() does not support zero-dimensional arrays")
return paddle.nonzero(x, as_tuple=True, **kwargs)

if paddle.is_floating_point(x) or paddle.is_complex(x) :
# Use paddle.isclose() to determine which elements are
# "close enough" to zero.
zero_tensor = paddle.zeros(shape=x.shape ,dtype=x.dtype)
is_zero_mask = paddle.isclose(x, zero_tensor)
is_nonzero_mask = paddle.logical_not(is_zero_mask)
return paddle.nonzero(is_nonzero_mask, as_tuple=True, **kwargs)

else:
return paddle.nonzero(x, as_tuple=True, **kwargs)


def where(condition: array, x1: array, x2: array, /) -> array:
@@ -832,7 +889,7 @@ def eye(
if n_cols is None:
n_cols = n_rows
z = paddle.zeros([n_rows, n_cols], dtype=dtype, **kwargs).to(device)
if abs(k) <= n_rows + n_cols:
if n_rows > 0 and n_cols > 0 and abs(k) <= n_rows + n_cols:
z.diagonal(k).fill_(1)
return z

@@ -867,7 +924,11 @@ def full(
) -> array:
if isinstance(shape, int):
shape = (shape,)

if dtype is None :
if isinstance(fill_value, (bool)):
dtype = "bool"
elif isinstance(fill_value, int):
dtype = 'int64'
return paddle.full(shape, fill_value, dtype=dtype, **kwargs).to(device)


@@ -914,6 +975,8 @@ def triu(x: array, /, *, k: int = 0) -> array:


def expand_dims(x: array, /, *, axis: int = 0) -> array:
if axis < -x.ndim - 1 or axis > x.ndim:
raise IndexError(f"Axis {axis} is out of bounds for array of dimension { x.ndim}")
return paddle.unsqueeze(x, axis)


@@ -973,6 +1036,22 @@ def unique_values(x: array) -> array:

def matmul(x1: array, x2: array, /, **kwargs) -> array:
# paddle.matmul doesn't type promote (but differently from _fix_promotion)
d1 = x1.ndim
d2 = x2.ndim

if d1 == 0 or d2 == 0:
raise ValueError("matmul does not support 0-D (scalar) inputs.")

k1 = x1.shape[-1]

if d2 == 1:
k2 = x2.shape[0]
else:
k2 = x2.shape[-2]

if k1 != k2:
raise ValueError(f"Shapes {x1.shape} and {x2.shape} are not aligned for matmul: "
f"{k1} (dim -1 of x1) != {k2} (dim -2 of x2)")
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return paddle.matmul(x1, x2, **kwargs)

@@ -988,7 +1067,36 @@ def meshgrid(*arrays: array, indexing: str = "xy") -> List[array]:


def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
shape1 = x1.shape
shape2 = x2.shape
rank1 = len(shape1)
rank2 = len(shape2)
if rank1 == 0 or rank2 == 0:
raise ValueError(
f"Vector dot product requires non-scalar inputs (rank > 0). "
f"Got ranks {rank1} and {rank2} for shapes {shape1} and {shape2}."
)
try:
norm_axis1 = axis if axis >= 0 else rank1 + axis
if not (0 <= norm_axis1 < rank1):
raise IndexError # Axis out of bounds for x1
norm_axis2 = axis if axis >= 0 else rank2 + axis
if not (0 <= norm_axis2 < rank2):
raise IndexError # Axis out of bounds for x2
size1 = shape1[norm_axis1]
size2 = shape2[norm_axis2]
except IndexError:
raise ValueError(
f"Axis {axis} is out of bounds for input shapes {shape1} (rank {rank1}) "
f"and/or {shape2} (rank {rank2})."
)

if size1 != size2:
raise ValueError(
f"Inputs must have the same dimension size along the reduction axis ({axis}). "
f"Got shapes {shape1} and {shape2}, with sizes {size1} and {size2} "
f"along the normalized axis {norm_axis1} and {norm_axis2} respectively."
)
return paddle.linalg.vecdot(x1, x2, axis=axis)


@@ -1063,21 +1171,39 @@ def is_complex(dtype):


def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
if axis is None:
_axis = axis
if _axis is None:
if x.ndim != 1:
raise ValueError("axis must be specified when ndim > 1")
axis = 0
return paddle.index_select(x, axis, indices, **kwargs)
raise ValueError("axis must be specified when x.ndim > 1")
_axis = 0
elif not isinstance(_axis, int):
raise TypeError(f"axis must be an integer, but received {type(_axis)}")

if not (-x.ndim <= _axis < x.ndim):
raise IndexError(f"axis {_axis} is out of bounds for tensor of dimension {x.ndim}")

if isinstance(indices, paddle.Tensor):
indices_tensor = indices
elif isinstance(indices, int):
indices_tensor = paddle.to_tensor([indices], dtype='int64')
else:
# Otherwise (e.g., list, tuple), convert directly
indices_tensor = paddle.to_tensor(indices, dtype='int64')
# Ensure indices is a 1D tensor
if indices_tensor.ndim == 0:
indices_tensor = indices_tensor.reshape([1])
elif indices_tensor.ndim > 1:
raise ValueError(f"indices must be a 1D tensor, but received a {indices_tensor.ndim}D tensor")

return paddle.index_select(x, index=indices_tensor, axis=_axis)


def sign(x: array, /) -> array:
# paddle sign() does not support complex numbers and does not propagate
# nans. See https://github.com/data-apis/array-api-compat/issues/136
if paddle.is_complex(x):
out = x / paddle.abs(x)
# sign(0) = 0 but the above formula would give nan
out[x == 0 + 0j] = 0 + 0j
return out
if paddle.is_complex(x) and x.ndim == 0 and x.item() == 0j:
# Handle 0-D complex zero explicitly
return paddle.zeros_like(x, dtype=x.dtype)
else:
out = paddle.sign(x)
if paddle.is_floating_point(x):
47 changes: 40 additions & 7 deletions array_api_compat/paddle/linalg.py
Original file line number Diff line number Diff line change
@@ -84,6 +84,8 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> arr
# Use our wrapped sum to make sure it does upcasting correctly
return sum(paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1), axis=-1, dtype=dtype)

def diagonal(x: ndarray, / , *, offset: int = 0, **kwargs) -> ndarray:
return paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)

def vector_norm(
x: array,
@@ -123,24 +125,52 @@ def matrix_norm(
keepdims: bool = False,
ord: Optional[Union[int, float, Literal["fro", "nuc"]]] = "fro",
) -> array:
return paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims)

res = paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims)
if res.dtype == paddle.complex64 :
res = paddle.cast(res, "float32")
if res.dtype == paddle.complex128:
res = paddle.cast(res, "float64")
return res

def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
if rtol is None:
return paddle.linalg.pinv(x)

# change rtol shape
if isinstance(rtol, (int, float)):
rtol = paddle.to_tensor(rtol, dtype=x.dtype)

# broadcast rtol to [..., 1]
if rtol.ndim > 0:
rtol = rtol.unsqueeze(-1)

return paddle.linalg.pinv(x, rcond=rtol)


def slogdet(x: array):
det = paddle.linalg.det(x)
sign = paddle.sign(det)
log_det = paddle.log(det)
return tuple_to_namedtuple(paddle.linalg.slogdet(x), ["sign", "logabsdet"])

def tuple_to_namedtuple(data, fields):
nt_class = namedtuple('DynamicNameTuple', fields)
return nt_class(*data)

def eigh(x: array):
return tuple_to_namedtuple(paddle.linalg.eigh(x), ['eigenvalues', 'eigenvectors'])

def qr(x: array, mode: Optional[str] = None) -> array:
if mode is None:
return tuple_to_namedtuple(paddle.linalg.qr(x), ['Q', 'R'])

return tuple_to_namedtuple(paddle.linalg.qr(x, mode), ['Q', 'R'])


slotdet = namedtuple("slotdet", ["sign", "logabsdet"])
return slotdet(sign, log_det)
def svd(x: array, full_matrices: Optional[bool]= None) -> array:
if full_matrices is None :
return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices=True), ['U', 'S', 'Vh'])
return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices), ['U', 'S', 'Vh'])

def svdvals(x: array) -> array:
return paddle.linalg.svd(x)[1]

__all__ = linalg_all + [
"outer",
@@ -154,6 +184,9 @@ def slogdet(x: array):
"trace",
"vector_norm",
"slogdet",
"eigh",
"diagonal",
"svdvals"
]

_all_ignore = ["paddle_linalg", "sum"]
59 changes: 59 additions & 0 deletions paddle-xfails.txt
Original file line number Diff line number Diff line change
@@ -106,3 +106,62 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_ceil
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
array_api_tests/test_searching_functions.py::test_where

array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[pow(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_hypot
array_api_tests/test_operators_and_elementwise_functions.py::test_copysign
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
array_api_tests/test_linalg.py::test_outer
array_api_tests/test_linalg.py::test_vecdot
array_api_tests/test_operators_and_elementwise_functions.py::test_clip
array_api_tests/test_manipulation_functions.py::test_stack
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide

# do not pass
array_api_tests/test_has_names[array_attribute-device]
array_api_tests/test_signatures.py::test_func_signature[meshgrid]

array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)]
array_api_tests/test_indexing_functions.py::test_take
array_api_tests/test_linalg.py::test_linalg_vecdot
array_api_tests/test_creation_functions.py::test_asarray_arrays

array_api_tests/test_linalg.py::test_qr

array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor

# test exceeds the deadline of 800ms
array_api_tests/test_linalg.py::test_pinv
array_api_tests/test_linalg.py::test_det

# only supports access to dimension 0 to 9, but received dimension is 10.
array_api_tests/test_linalg.py::test_tensordot
array_api_tests/test_linalg.py::test_linalg_tensordot