Draft
Changes from all commits
Commits
32 commits
8e5cc94
add paddle support in array-api-compat
HydrogenSulfate Nov 26, 2024
7118894
update README
HydrogenSulfate Nov 26, 2024
85dc3ba
update promotion table and can_cast table
HydrogenSulfate Nov 26, 2024
c5b82db
update doc
HydrogenSulfate Nov 26, 2024
7b99449
restore code
HydrogenSulfate Nov 26, 2024
bb40851
update docstring
HydrogenSulfate Nov 26, 2024
a7163f9
refine more code
HydrogenSulfate Nov 26, 2024
ec46178
add suffix for test_python_scalars and add paddle index-url in rqeuir…
HydrogenSulfate Nov 26, 2024
dfd4485
update paddle code
HydrogenSulfate Dec 3, 2024
5ae8ec8
fix
HydrogenSulfate Dec 3, 2024
b10273b
update code
HydrogenSulfate Dec 10, 2024
8d2425e
fix moveaxis
HydrogenSulfate Dec 14, 2024
7b8555e
fix default floating dtype of paddle.assaray
HydrogenSulfate Jan 8, 2025
603c852
use default_dtype only when dtype is None
HydrogenSulfate Jan 9, 2025
742792f
add floor and ceil with same return dtype
HydrogenSulfate Jan 9, 2025
fd6eea0
update code
HydrogenSulfate Jan 9, 2025
6f32d63
update
Mar 31, 2025
37785d4
Add broadcast_tensors alias, modify result_type
cangtianhuang Apr 1, 2025
0651731
refine
cangtianhuang Apr 1, 2025
2d4e571
Merge pull request #1 from cangtianhuang/support_paddle
HydrogenSulfate Apr 1, 2025
888966f
Merge branch 'HydrogenSulfate:support_paddle' into support_paddle
aquagull Apr 1, 2025
372283a
update
Apr 1, 2025
46f81c7
Merge branch 'support_paddle' of https://github.com/aquagull/array-ap…
Apr 1, 2025
b946e82
update
Apr 3, 2025
34af083
update
Apr 7, 2025
0dbf7dd
update
May 1, 2025
912fe3e
add paddle skip and xfail files
HydrogenSulfate May 12, 2025
add32c9
Merge branch 'HydrogenSulfate:support_paddle' into support_paddle
aquagull May 13, 2025
13e2782
update
Jun 7, 2025
e6cf011
update
Jun 7, 2025
67aa9ef
updat
Jun 7, 2025
41a93be
Merge pull request #2 from aquagull/support_paddle
HydrogenSulfate Jun 9, 2025
11 changes: 11 additions & 0 deletions .github/workflows/array-api-tests-paddle.yml
@@ -0,0 +1,11 @@
name: Array API Tests (Paddle Latest)

on: [push, pull_request]

jobs:
array-api-tests-paddle:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: paddle
extra-env-vars: |
ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64
4 changes: 2 additions & 2 deletions README.md
@@ -2,8 +2,8 @@

This is a small wrapper around common array libraries that is compatible with
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want
NumPy, CuPy, PyTorch, Dask, JAX, ndonnx, `sparse` and Paddle are supported. If you want
support for other array libraries, or if you encounter any issues, please [open
an issue](https://github.com/data-apis/array-api-compat/issues).

See the documentation for more details https://data-apis.org/array-api-compat/
See the documentation for more details <https://data-apis.org/array-api-compat/>
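A minimal sketch of what this enables, in the spirit of the README's existing examples (hypothetical snippet; assumes this PR's wrappers are installed):

```py
import paddle
from array_api_compat import array_namespace

def demean(x):
    # Resolves to array_api_compat.paddle for Paddle tensors, and to the
    # matching wrapper for NumPy, PyTorch, and the other supported backends.
    xp = array_namespace(x)
    return x - xp.mean(x)

print(demean(paddle.to_tensor([1.0, 2.0, 3.0])))
```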
78 changes: 78 additions & 0 deletions array_api_compat/common/_helpers.py
@@ -120,6 +120,32 @@ def is_torch_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)

def is_paddle_array(x):
"""
Return True if `x` is a Paddle tensor.
This function does not import Paddle if it has not already been imported
and is therefore cheap to use.
See Also
--------
array_namespace
is_array_api_obj
is_numpy_array
is_cupy_array
is_dask_array
is_jax_array
is_pydata_sparse_array
"""
# Avoid importing paddle if it isn't already
if 'paddle' not in sys.modules:
return False

import paddle

return paddle.is_tensor(x)

def is_ndonnx_array(x):
"""
Return True if `x` is a ndonnx Array.
@@ -252,6 +278,7 @@ def is_array_api_obj(x):
or is_dask_array(x) \
or is_jax_array(x) \
or is_pydata_sparse_array(x) \
or is_paddle_array(x) \
or hasattr(x, '__array_namespace__')

def _compat_module_name():
@@ -319,6 +346,27 @@ def is_torch_namespace(xp) -> bool:
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}


def is_paddle_namespace(xp) -> bool:
"""
Returns True if `xp` is a Paddle namespace.
This includes both Paddle itself and the version wrapped by array-api-compat.
See Also
--------
array_namespace
is_numpy_namespace
is_cupy_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'paddle', _compat_module_name() + '.paddle'}


def is_ndonnx_namespace(xp):
"""
Returns True if `xp` is an NDONNX namespace.
@@ -543,6 +591,14 @@ def your_function(x, y):
else:
import jax.experimental.array_api as jnp
namespaces.add(jnp)
elif is_paddle_array(x):
if _use_compat:
_check_api_version(api_version)
from .. import paddle as paddle_namespace
namespaces.add(paddle_namespace)
else:
import paddle
namespaces.add(paddle)
elif is_pydata_sparse_array(x):
if use_compat is True:
_check_api_version(api_version)
@@ -660,6 +716,16 @@ def device(x: Array, /) -> Device:
return "cpu"
# Return the device of the constituent array
return device(inner)
elif is_paddle_array(x):
raw_place_str = str(x.place)
if "gpu_pinned" in raw_place_str:
return "cpu"
elif "cpu" in raw_place_str:
return "cpu"
elif "gpu" in raw_place_str:
return "gpu"
raise ValueError(f"Unsupported Paddle device: {x.place}")

return x.device

# Prevent shadowing, used below
@@ -709,6 +775,14 @@ def _torch_to_device(x, device, /, stream=None):
raise NotImplementedError
return x.to(device)

def _paddle_to_device(x, device, /, stream=None):
if stream is not None:
raise NotImplementedError(
"paddle.Tensor.to() do not support stream argument yet"
)
return x.to(device)


def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
"""
Copy the array from the device on which it currently resides to the specified ``device``.
@@ -781,6 +855,8 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
# In JAX v0.4.31 and older, this import adds to_device method to x.
import jax.experimental.array_api # noqa: F401
return x.to_device(device, stream=stream)
elif is_paddle_array(x):
return _paddle_to_device(x, device, stream=stream)
elif is_pydata_sparse_array(x) and device == _device(x):
# Perform trivial check to return the same array if
# device is same instead of err-ing.
@@ -819,6 +895,8 @@ def size(x):
"is_torch_namespace",
"is_ndonnx_array",
"is_ndonnx_namespace",
"is_paddle_array",
"is_paddle_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"size",
22 changes: 22 additions & 0 deletions array_api_compat/paddle/__init__.py
@@ -0,0 +1,22 @@
from paddle import * # noqa: F403

# Several names are not included in the above import *
import paddle

for n in dir(paddle):
if n.startswith("_") or n.endswith("_") or "gpu" in n or "cpu" in n or "backward" in n:
continue
exec(f"{n} = paddle.{n}")


# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403

# See the comment in the numpy __init__.py
__import__(__package__ + ".linalg")

__import__(__package__ + ".fft")

from ..common._helpers import * # noqa: F403

__array_api_version__ = "2023.12"
1,487 changes: 1,487 additions & 0 deletions array_api_compat/paddle/_aliases.py

Large diffs are not rendered by default.

380 changes: 380 additions & 0 deletions array_api_compat/paddle/_info.py
@@ -0,0 +1,380 @@
"""
Array API Inspection namespace
This is the namespace for inspection functions as defined by the array API
standard. See
https://data-apis.org/array-api/latest/API_specification/inspection.html for
more details.
"""

import paddle

from functools import cache


class __array_namespace_info__:
"""
Get the array API inspection namespace for Paddle.
The array API inspection namespace defines the following functions:
- capabilities()
- default_device()
- default_dtypes()
- dtypes()
- devices()
See
https://data-apis.org/array-api/latest/API_specification/inspection.html
for more details.
Returns
-------
info : ModuleType
The array API inspection namespace for Paddle.
Examples
--------
>>> info = paddle.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': paddle.float32,
'complex floating': paddle.complex64,
'integral': paddle.int64,
'indexing': paddle.int64}
"""

__module__ = "paddle"

def capabilities(self):
"""
Return a dictionary of array API library capabilities.
The resulting dictionary has the following keys:
- **"boolean indexing"**: boolean indicating whether an array library
supports boolean indexing. Always ``True`` for Paddle.
- **"data-dependent shapes"**: boolean indicating whether an array
library supports data-dependent output shapes. Always ``True`` for
Paddle.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
for more details.
See Also
--------
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
capabilities : dict
A dictionary of array API library capabilities.
Examples
--------
>>> info = paddle.__array_namespace_info__()
>>> info.capabilities()
{'boolean indexing': True,
'data-dependent shapes': True}
"""
return {
"boolean indexing": True,
"data-dependent shapes": True,
# 'max rank' will be part of the 2024.12 standard
# "max rank": 64,
}

def default_device(self):
"""
The default device used for new Paddle arrays.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
device : str
The default device used for new Paddle arrays.
Examples
--------
>>> info = paddle.__array_namespace_info__()
>>> info.default_device()
'cpu'
"""
return paddle.device.get_device()

def default_dtypes(self, *, device=None):
"""
The default data types used for new Paddle arrays.
Parameters
----------
device : str, optional
The device to get the default data types for. For Paddle, only
``'cpu'`` is allowed.
Returns
-------
dtypes : dict
A dictionary describing the default data types used for new Paddle
arrays.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = paddle.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': paddle.float32,
'complex floating': paddle.complex64,
'integral': paddle.int64,
'indexing': paddle.int64}
"""
# Note: if the default is set to float64, the devices like MPS that
# don't support float64 will error. We still return the default_dtype
# value here because this error doesn't represent a different default
# per-device.
default_floating = paddle.get_default_dtype()
if default_floating in ["float16", "float32", "float64", "bfloat16"]:
default_floating = getattr(paddle, default_floating)
else:
raise ValueError(f"Unsupported default floating: {default_floating}")
default_complex = (
paddle.complex64
if default_floating == paddle.float32
else paddle.complex128
)
default_integral = paddle.int64
return {
"real floating": default_floating,
"complex floating": default_complex,
"integral": default_integral,
"indexing": default_integral,
}

def _dtypes(self, kind):
bool = paddle.bool
int8 = paddle.int8
int16 = paddle.int16
int32 = paddle.int32
int64 = paddle.int64
uint8 = paddle.uint8
# uint16, uint32, and uint64 are not fully supported in paddle,
# we omit them from this function.
float32 = paddle.float32
float64 = paddle.float64
complex64 = paddle.complex64
complex128 = paddle.complex128

if kind is None:
return {
"bool": bool,
"int8": int8,
"int16": int16,
"int32": int32,
"int64": int64,
"uint8": uint8,
"float32": float32,
"float64": float64,
"complex64": complex64,
"complex128": complex128,
}
if kind == "bool":
return {"bool": bool}
if kind == "signed integer":
return {
"int8": int8,
"int16": int16,
"int32": int32,
"int64": int64,
}
if kind == "unsigned integer":
return {
"uint8": uint8,
}
if kind == "integral":
return {
"int8": int8,
"int16": int16,
"int32": int32,
"int64": int64,
"uint8": uint8,
}
if kind == "real floating":
return {
"float32": float32,
"float64": float64,
}
if kind == "complex floating":
return {
"complex64": complex64,
"complex128": complex128,
}
if kind == "numeric":
return {
"int8": int8,
"int16": int16,
"int32": int32,
"int64": int64,
"uint8": uint8,
"float32": float32,
"float64": float64,
"complex64": complex64,
"complex128": complex128,
}
if isinstance(kind, tuple):
res = {}
for k in kind:
res.update(self.dtypes(kind=k))
return res
raise ValueError(f"unsupported kind: {kind!r}")

@cache
def dtypes(self, *, device=None, kind=None):
"""
The array API data types supported by Paddle.
Note that this function only returns data types that are defined by
the array API.
Parameters
----------
device : str, optional
The device to get the data types for.
kind : str or tuple of str, optional
The kind of data types to return. If ``None``, all data types are
returned. If a string, only data types of that kind are returned.
If a tuple, a dictionary containing the union of the given kinds
is returned. The following kinds are supported:
- ``'bool'``: boolean data types (i.e., ``bool``).
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
``int16``, ``int32``, ``int64``).
- ``'unsigned integer'``: unsigned integer data types (i.e.,
``uint8``, ``uint16``, ``uint32``, ``uint64``).
- ``'integral'``: integer data types. Shorthand for ``('signed
integer', 'unsigned integer')``.
- ``'real floating'``: real-valued floating-point data types
(i.e., ``float32``, ``float64``).
- ``'complex floating'``: complex floating-point data types (i.e.,
``complex64``, ``complex128``).
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
'real floating', 'complex floating')``.
Returns
-------
dtypes : dict
A dictionary mapping the names of data types to the corresponding
Paddle data types.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = paddle.__array_namespace_info__()
>>> info.dtypes(kind='signed integer')
{'int8': paddle.int8,
'int16': paddle.int16,
'int32': paddle.int32,
'int64': paddle.int64}
"""
res = self._dtypes(kind)
for k, v in res.copy().items():
try:
# paddle.empty does not accept a device argument; dtype support
# does not differ per device in Paddle, so only the dtype is checked.
paddle.empty((0,), dtype=v)
except Exception:
del res[k]
return res

@cache
def devices(self):
"""
The devices supported by Paddle.
Returns
-------
devices : list of str
The devices supported by Paddle.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes
Examples
--------
>>> info = paddle.__array_namespace_info__()
>>> info.devices()
[Place(cpu), Place(gpu:0)]
"""
# Paddle doesn't have a straightforward way to get the list of all
# currently supported devices. To do this, we first parse the error
# message of paddle.device to get the list of all possible types of
# device:
try:
paddle.set_device("notadevice")
except ValueError as e:
# The error message is something like:
# ValueError: The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'npu', 'npu:x'
devices_names = (
e.args[0]
.split("The device must be a string which is like ")[1]
.split(", ")
)
devices_names = [
name.strip("'") for name in devices_names if ":" not in name
]

# Next we need to check for different indices for different devices.
# device(device_name, index=index) doesn't actually check if the
# device name or index is valid. We have to try to create a tensor
# with it (which is why this function is cached).
devices = []
for device_name in devices_names:
i = 0
while True:
try:
if device_name == "cpu":
a = paddle.empty((0,), place=paddle.CPUPlace())
elif device_name == "gpu":
a = paddle.empty((0,), place=paddle.CUDAPlace(i))
elif device_name == "xpu":
a = paddle.empty((0,), place=paddle.XPUPlace())
else:
raise
if a.place in devices:
break
devices.append(a.device)
except:
break
i += 1

return devices
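A sketch of the inspection namespace in use (assuming `__array_namespace_info__` is re-exported by the package, as in the torch wrapper; outputs follow the docstrings above):

```py
import array_api_compat.paddle as xp

info = xp.__array_namespace_info__()
print(info.capabilities())    # {'boolean indexing': True, 'data-dependent shapes': True}
print(info.default_dtypes())  # real floating defaults to paddle.float32
print(info.dtypes(kind=("bool", "real floating")))
```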
121 changes: 121 additions & 0 deletions array_api_compat/paddle/fft.py
@@ -0,0 +1,121 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
import paddle
from ..common._typing import Device

array = paddle.Tensor
from typing import Optional, Union, Sequence, Literal

from paddle.fft import * # noqa: F403
import paddle.fft


def fftn(
x: array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
return paddle.fft.fftn(x, s=s, axes=axes, norm=norm, **kwargs)


def ifftn(
x: array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
return paddle.fft.ifftn(x, s=s, axes=axes, norm=norm, **kwargs)


def rfftn(
x: array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
return paddle.fft.rfftn(x, s=s, axes=axes, norm=norm, **kwargs)


def irfftn(
x: array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
return paddle.fft.irfftn(x, s=s, axes=axes, norm=norm, **kwargs)


def fftshift(
x: array,
/,
*,
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> array:
return paddle.fft.fftshift(x, axes=axes, **kwargs)


def ifftshift(
x: array,
/,
*,
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> array:
return paddle.fft.ifftshift(x, axes=axes, **kwargs)


def fftfreq(
n: int,
/,
*,
d: float = 1.0,
device: Optional[Device] = None,
) -> array:
out = paddle.fft.fftfreq(n, d)
if device is not None:
out = out.to(device)
return out


def rfftfreq(
n: int,
/,
*,
d: float = 1.0,
device: Optional[Device] = None,
) -> array:
out = paddle.fft.rfftfreq(n, d)
if device is not None:
out = out.to(device)
return out


__all__ = paddle.fft.__all__ + [
"fftn",
"ifftn",
"rfftn",
"irfftn",
"fftshift",
"ifftshift",
"fftfreq",
"rfftfreq",
]

_all_ignore = ["paddle"]
194 changes: 194 additions & 0 deletions array_api_compat/paddle/linalg.py
@@ -0,0 +1,194 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
import paddle

array = paddle.Tensor
from paddle import dtype as Dtype
from typing import Optional, Union, Tuple, Literal

inf = float("inf")

from ._aliases import _fix_promotion, sum
from collections import namedtuple

import paddle
from paddle.linalg import * # noqa: F403

# paddle.linalg doesn't define __all__
# from paddle.linalg import __all__ as linalg_all
from paddle import linalg as paddle_linalg

linalg_all = [i for i in dir(paddle_linalg) if not i.startswith("_")]

# outer is implemented in paddle but isn't in the linalg namespace
from paddle import outer
import paddle

# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot

# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3)


# paddle.cross also does not support broadcasting when it would add new
# dimensions
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")

if not (x1.shape[axis] == x2.shape[axis] == 3):
raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")

x1, x2 = paddle.broadcast_tensors([x1, x2])
return paddle_linalg.cross(x1, x2, axis=axis)


def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
from ._aliases import isdtype

x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

# paddle.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
if x1.shape[axis] != x2.shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")

# paddle.linalg.vecdot doesn't support integer dtypes
if isdtype(x1.dtype, "integral") or isdtype(x2.dtype, "integral"):
if kwargs:
raise RuntimeError("vecdot kwargs not supported for integral dtypes")

x1_ = paddle.moveaxis(x1, axis, -1)
x2_ = paddle.moveaxis(x2, axis, -1)
x1_, x2_ = paddle.broadcast_tensors([x1_, x2_])

res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]
return paddle.linalg.vecdot(x1, x2, axis=axis, **kwargs)


def solve(x1: array, x2: array, /, **kwargs) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
x2 = x2[None]
return paddle.linalg.solve(x1, x2, **kwargs)


# paddle.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
# Use our wrapped sum to make sure it does upcasting correctly
return sum(paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1), axis=-1, dtype=dtype)

def diagonal(x: array, /, *, offset: int = 0, **kwargs) -> array:
return paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)

def vector_norm(
x: array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
ord: Union[int, float, Literal[inf, -inf]] = 2,
**kwargs,
) -> array:
# paddle.vector_norm incorrectly treats axis=() the same as axis=None
if axis == ():
out = kwargs.get("out")
if out is None:
dtype = None
if x.dtype == paddle.complex64:
dtype = paddle.float32
elif x.dtype == paddle.complex128:
dtype = paddle.float64

out = paddle.zeros_like(x, dtype=dtype)

# The norm of a single scalar works out to abs(x) in every case except
# for ord=0, which is x != 0.
if ord == 0:
out[:] = x != 0
else:
out[:] = paddle.abs(x)
return out
return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)


def matrix_norm(
x: array,
/,
*,
keepdims: bool = False,
ord: Optional[Union[int, float, Literal["fro", "nuc"]]] = "fro",
) -> array:
res = paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims)
if res.dtype == paddle.complex64:
res = paddle.cast(res, "float32")
if res.dtype == paddle.complex128:
res = paddle.cast(res, "float64")
return res

def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
if rtol is None:
return paddle.linalg.pinv(x)

# change rtol shape
if isinstance(rtol, (int, float)):
rtol = paddle.to_tensor(rtol, dtype=x.dtype)

# broadcast rtol to [..., 1]
if rtol.ndim > 0:
rtol = rtol.unsqueeze(-1)

return paddle.linalg.pinv(x, rcond=rtol)


def slogdet(x: array):
return tuple_to_namedtuple(paddle.linalg.slogdet(x), ["sign", "logabsdet"])

def tuple_to_namedtuple(data, fields):
nt_class = namedtuple('DynamicNameTuple', fields)
return nt_class(*data)

def eigh(x: array):
return tuple_to_namedtuple(paddle.linalg.eigh(x), ['eigenvalues', 'eigenvectors'])

def qr(x: array, mode: Optional[str] = None) -> array:
if mode is None:
return tuple_to_namedtuple(paddle.linalg.qr(x), ['Q', 'R'])

return tuple_to_namedtuple(paddle.linalg.qr(x, mode), ['Q', 'R'])


def svd(x: array, full_matrices: Optional[bool] = None) -> array:
if full_matrices is None:
return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices=True), ['U', 'S', 'Vh'])
return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices), ['U', 'S', 'Vh'])

def svdvals(x: array) -> array:
return paddle.linalg.svd(x)[1]

__all__ = linalg_all + [
"outer",
"matmul",
"matrix_transpose",
"matrix_norm",
"tensordot",
"cross",
"vecdot",
"solve",
"trace",
"vector_norm",
"slogdet",
"eigh",
"diagonal",
"svdvals"
]

_all_ignore = ["paddle_linalg", "sum"]

del linalg_all
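The `tuple_to_namedtuple` shim gives the decomposition wrappers the field names the standard requires; a sketch:

```py
import paddle
import array_api_compat.paddle.linalg as linalg

x = paddle.to_tensor([[4.0, 1.0], [1.0, 3.0]])
res = linalg.eigh(x)   # fields: eigenvalues, eigenvectors
print(res.eigenvalues)
qr = linalg.qr(x)      # fields: Q, R
print(qr.Q.shape, qr.R.shape)
```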
4 changes: 4 additions & 0 deletions docs/index.md
@@ -60,6 +60,10 @@ import array_api_compat.torch as torch
import array_api_compat.dask as da
```

```py
import array_api_compat.paddle as paddle
```

```{note}
There are no `array_api_compat` submodules for JAX, sparse, or ndonnx. Support
for these libraries is contained in the libraries themselves (JAX
22 changes: 22 additions & 0 deletions docs/supported-array-libraries.md
@@ -137,3 +137,25 @@ The minimum supported Dask version is 2023.12.0.
## [Sparse](https://sparse.pydata.org/en/stable/)

Similar to JAX, `sparse` Array API support is contained directly in `sparse`.

## [Paddle](https://www.paddlepaddle.org.cn/)

- Like NumPy/CuPy, we do not wrap the `paddle.Tensor` object. It is missing the
`__array_namespace__` and `to_device` methods, so the corresponding helper
functions {func}`~.array_namespace()` and {func}`~.to_device()` in this
library should be used instead.

- Paddle does not have unsigned integer types other than `uint8`, and no
attempt is made to implement them here.

- [`std()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html#array_api.std)
and
[`var()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html#array_api.var)
do not support floating-point `correction` except for `0.0` and `1.0`.

- The `stream` argument of the {func}`~.to_device()` helper is not supported.

- As with NumPy, type annotations and positional-only arguments may not
exactly match the spec for functions that are not wrapped at all.

The minimum supported Paddle version is 3.0.0.
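A sketch of the helper-function pattern the first bullet prescribes, since `paddle.Tensor` lacks the corresponding methods:

```py
import paddle
from array_api_compat import array_namespace, to_device

x = paddle.to_tensor([1.0, 2.0, 3.0])
xp = array_namespace(x)   # paddle.Tensor has no __array_namespace__ method
x2 = to_device(x, "cpu")  # ...and no to_device method; stream= is unsupported
print(xp.std(x, correction=1.0))  # correction must be 0.0 or 1.0, per the note above
```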
6 changes: 6 additions & 0 deletions paddle-skips.txt
@@ -0,0 +1,6 @@
array_api_tests/test_array_object.py::test_getitem_masking
array_api_tests/test_data_type_functions.py::test_result_type
array_api_tests/test_data_type_functions.py::test_broadcast_arrays
array_api_tests/test_manipulation_functions.py::test_roll
array_api_tests/test_data_type_functions.py::test_broadcast_to
array_api_tests/test_linalg.py::test_cholesky
167 changes: 167 additions & 0 deletions paddle-xfails.txt
@@ -0,0 +1,167 @@
# Skip 'copy=...'
array_api_tests/test_array_object.py::test_setitem
array_api_tests/test_array_object.py::test_setitem_masking
# array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x1, x2)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__iand__(x1, x2)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__iand__(x, s)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x1, x2)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x, s)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__ior__(x1, x2)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__ior__(x, s)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x1, x2)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x, s)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ixor__(x1, x2)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ixor__(x, s)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x1, x2)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x, s)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x1, x2)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x1, x2)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x1, x2)]
# array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)]

# Skip promotion test for 'Scalar op Tensor'
array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)]

# these also fail for torch
array_api_tests/test_creation_functions.py::test_asarray_scalars
array_api_tests/test_creation_functions.py::test_asarray_arrays
array_api_tests/test_creation_functions.py::test_empty_like
array_api_tests/test_creation_functions.py::test_eye
array_api_tests/test_creation_functions.py::test_full
array_api_tests/test_creation_functions.py::test_full_like
array_api_tests/test_creation_functions.py::test_linspace
array_api_tests/test_creation_functions.py::test_ones
array_api_tests/test_creation_functions.py::test_ones_like
array_api_tests/test_creation_functions.py::test_zeros
array_api_tests/test_creation_functions.py::test_zeros_like
array_api_tests/test_fft.py::test_fft
array_api_tests/test_fft.py::test_ifft
array_api_tests/test_fft.py::test_fftn
array_api_tests/test_fft.py::test_ifftn
array_api_tests/test_fft.py::test_rfft
array_api_tests/test_fft.py::test_irfft
array_api_tests/test_fft.py::test_rfftn
array_api_tests/test_fft.py::test_hfft
array_api_tests/test_fft.py::test_ihfft
array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
array_api_tests/test_indexing_functions.py::test_take
array_api_tests/test_linalg.py::test_linalg_matmul
array_api_tests/test_linalg.py::test_qr
array_api_tests/test_linalg.py::test_solve
array_api_tests/test_manipulation_functions.py::test_concat
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_round
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]
array_api_tests/test_set_functions.py::test_unique_all
array_api_tests/test_set_functions.py::test_unique_counts
array_api_tests/test_set_functions.py::test_unique_inverse
array_api_tests/test_set_functions.py::test_unique_values
array_api_tests/test_signatures.py::test_func_signature[astype]
array_api_tests/test_signatures.py::test_func_signature[repeat]
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
array_api_tests/test_signatures.py::test_array_method_signature[to_device]
array_api_tests/test_sorting_functions.py::test_argsort
array_api_tests/test_sorting_functions.py::test_sort
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[bitwise_and(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[bitwise_left_shift(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[bitwise_or(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[bitwise_right_shift(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)]

# dtype promotion related
array_api_tests/test_operators_and_elementwise_functions.py::test_floor
array_api_tests/test_operators_and_elementwise_functions.py::test_ceil
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
array_api_tests/test_searching_functions.py::test_where

array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[pow(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_hypot
array_api_tests/test_operators_and_elementwise_functions.py::test_copysign
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
array_api_tests/test_linalg.py::test_outer
array_api_tests/test_linalg.py::test_vecdot
array_api_tests/test_operators_and_elementwise_functions.py::test_clip
array_api_tests/test_manipulation_functions.py::test_stack
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide

# do not pass
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
array_api_tests/test_signatures.py::test_func_signature[meshgrid]

array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)]
array_api_tests/test_indexing_functions.py::test_take
array_api_tests/test_linalg.py::test_linalg_vecdot
array_api_tests/test_creation_functions.py::test_asarray_arrays

array_api_tests/test_linalg.py::test_qr

array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor

# tests exceed the deadline of 800ms
array_api_tests/test_linalg.py::test_pinv
array_api_tests/test_linalg.py::test_det

# only supports access to dimension 0 to 9, but received dimension is 10.
array_api_tests/test_linalg.py::test_tensordot
array_api_tests/test_linalg.py::test_linalg_tensordot
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -4,5 +4,6 @@ jax[cpu]
numpy
pytest
torch
paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
sparse >=0.15.1
ndonnx
7 changes: 6 additions & 1 deletion tests/_helpers.py
@@ -3,7 +3,7 @@

import pytest

wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array", "paddle"]
all_libraries = wrapped_libraries + ["jax.numpy"]

# `sparse` added array API support as of Python 3.10.
@@ -25,4 +25,9 @@ def import_(library, wrapper=False):
else:
library = 'array_api_compat.' + library

if library == 'paddle':
xp = import_module(library)
xp.asarray = xp.to_tensor
return xp

return import_module(library)
25 changes: 24 additions & 1 deletion tests/test_array_namespace.py
@@ -6,6 +6,7 @@
import numpy as np
import pytest
import torch
import paddle

import array_api_compat
from array_api_compat import array_namespace
@@ -91,6 +92,12 @@ def test_array_namespace_errors_torch():
x = np.asarray([1, 2])
pytest.raises(TypeError, lambda: array_namespace(x, y))


def test_array_namespace_errors_paddle():
y = paddle.to_tensor([1, 2])
x = np.asarray([1, 2])
pytest.raises(TypeError, lambda: array_namespace(x, y))

def test_api_version():
x = torch.asarray([1, 2])
torch_ = import_("torch", wrapper=True)
@@ -115,7 +122,7 @@ def test_get_namespace():
# Backwards compatible wrapper
assert array_api_compat.get_namespace is array_api_compat.array_namespace

def test_python_scalars():
def test_python_scalars_torch():
a = torch.asarray([1, 2])
xp = import_("torch", wrapper=True)

@@ -130,3 +137,19 @@ def test_python_scalars():
assert array_namespace(a, 1j) == xp
assert array_namespace(a, True) == xp
assert array_namespace(a, None) == xp

def test_python_scalars_paddle():
a = paddle.to_tensor([1, 2])
xp = import_("paddle", wrapper=True)

pytest.raises(TypeError, lambda: array_namespace(1))
pytest.raises(TypeError, lambda: array_namespace(1.0))
pytest.raises(TypeError, lambda: array_namespace(1j))
pytest.raises(TypeError, lambda: array_namespace(True))
pytest.raises(TypeError, lambda: array_namespace(None))

assert array_namespace(a, 1) == xp
assert array_namespace(a, 1.0) == xp
assert array_namespace(a, 1j) == xp
assert array_namespace(a, True) == xp
assert array_namespace(a, None) == xp
15 changes: 13 additions & 2 deletions tests/test_common.py
@@ -1,8 +1,8 @@
from array_api_compat import ( # noqa: F401
is_numpy_array, is_cupy_array, is_torch_array,
is_numpy_array, is_cupy_array, is_torch_array, is_paddle_array,
is_dask_array, is_jax_array, is_pydata_sparse_array,
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, is_paddle_namespace,
)

from array_api_compat import is_array_api_obj, device, to_device
@@ -21,6 +21,7 @@
'dask.array': 'is_dask_array',
'jax.numpy': 'is_jax_array',
'sparse': 'is_pydata_sparse_array',
'paddle': 'is_paddle_array',
}

is_namespace_functions = {
@@ -30,6 +31,7 @@
'dask.array': 'is_dask_namespace',
'jax.numpy': 'is_jax_namespace',
'sparse': 'is_pydata_sparse_namespace',
'paddle': 'is_paddle_namespace',
}


@@ -101,6 +103,13 @@ def test_asarray_cross_library(source_library, target_library, request):
if source_library == "cupy" and target_library != "cupy":
# cupy explicitly disallows implicit conversions to CPU
pytest.skip(reason="cupy does not support implicit conversion to CPU")
if source_library == "paddle" or target_library == "paddle":
pytest.skip(
reason=(
"paddle does not support implicit conversion from/to other framework "
"via 'asarray', dlpack is recommend now."
)
)
elif source_library == "sparse" and target_library != "sparse":
pytest.skip(reason="`sparse` does not allow implicit densification")
src_lib = import_(source_library, wrapper=True)
@@ -114,6 +123,8 @@ def test_asarray_cross_library(source_library, target_library, request):

@pytest.mark.parametrize("library", wrapped_libraries)
def test_asarray_copy(library):
if library == 'paddle':
pytest.skip("Paddle does not support explicit copies")
# Note, we have this test here because the test suite currently doesn't
# test the copy flag to asarray() very rigorously. Once
# https://github.com/data-apis/array-api-tests/issues/241 is fixed we
2 changes: 1 addition & 1 deletion tests/test_isdtype.py
@@ -10,7 +10,7 @@
# Check the known dtypes by their string names

def _spec_dtypes(library):
if library == 'torch':
if library in ['torch', 'paddle']:
# torch and paddle do not have unsigned integer dtypes
return {
'bool',
8 changes: 6 additions & 2 deletions tests/test_no_dependencies.py
@@ -49,8 +49,12 @@ def _test_dependency(mod):
# TODO: Test that wrapper for library X doesn't depend on wrappers for library
# Y (except most array libraries actually do themselves depend on numpy).

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array",
"jax.numpy", "sparse", "array_api_strict"])
@pytest.mark.parametrize("library",
[
"numpy", "cupy", "numpy", "torch", "dask.array",
"jax.numpy", "sparse", "paddle", "array_api_strict"
]
)
def test_numpy_dependency(library):
# This import is here because it imports numpy
from ._helpers import import_
6 changes: 6 additions & 0 deletions tests/test_vendoring.py
@@ -24,3 +24,9 @@ def test_vendoring_torch():
def test_vendoring_dask():
from vendor_test import uses_dask
uses_dask._test_dask()


def test_vendoring_paddle():
from vendor_test import uses_paddle

uses_paddle._test_paddle()
30 changes: 30 additions & 0 deletions vendor_test/uses_paddle.py
@@ -0,0 +1,30 @@
# Basic test that vendoring works

from .vendored._compat import (
is_paddle_array,
is_paddle_namespace,
paddle as paddle_compat,
)

import paddle

def _test_paddle():
a = paddle_compat.to_tensor([1., 2., 3.])
b = paddle_compat.arange(3, dtype=paddle_compat.float64)
assert a.dtype == paddle_compat.float32 == paddle.float32
assert b.dtype == paddle_compat.float64 == paddle.float64

# paddle.expand_dims does not exist. Update this to use something else if it is added
res = paddle_compat.expand_dims(a, axis=0)
assert res.dtype == paddle_compat.float32 == paddle.float32
assert res.shape == [1, 3]
assert isinstance(res.shape, list)
assert isinstance(a, paddle.Tensor)
assert isinstance(b, paddle.Tensor)
assert isinstance(res, paddle.Tensor)

assert paddle.allclose(res, paddle.to_tensor([[1., 2., 3.]]))

assert is_paddle_array(res)
assert is_paddle_namespace(paddle) and is_paddle_namespace(paddle_compat)