Skip to content

Add broadcast_tensors alias, modify result_type #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions array_api_compat/paddle/_aliases.py
Original file line number Diff line number Diff line change
@@ -112,25 +112,32 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
raise TypeError("At least one array or dtype must be provided")
if len(arrays_and_dtypes) == 1:
x = arrays_and_dtypes[0]
if isinstance(x, paddle.dtype):
return x
return x.dtype
return x if isinstance(x, paddle.dtype) else x.dtype
if len(arrays_and_dtypes) > 2:
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))

x, y = arrays_and_dtypes
xdt = x.dtype if not isinstance(x, paddle.dtype) else x
ydt = y.dtype if not isinstance(y, paddle.dtype) else y
xdt = x if isinstance(x, paddle.dtype) else x.dtype
ydt = y if isinstance(y, paddle.dtype) else y.dtype

if (xdt, ydt) in _promotion_table:
return _promotion_table[xdt, ydt]

# This doesn't result_type(dtype, dtype) for non-array API dtypes
# because paddle.result_type only accepts tensors. This does however, allow
# cross-kind promotion.
x = paddle.to_tensor([], dtype=x) if isinstance(x, paddle.dtype) else x
y = paddle.to_tensor([], dtype=y) if isinstance(y, paddle.dtype) else y
return paddle.result_type(x, y)
return _promotion_table[(xdt, ydt)]

type_order = {
paddle.bool: 0,
paddle.int8: 1,
paddle.uint8: 2,
paddle.int16: 3,
paddle.int32: 4,
paddle.int64: 5,
paddle.float16: 6,
paddle.float32: 7,
paddle.float64: 8,
paddle.complex64: 9,
paddle.complex128: 10
}

return xdt if type_order.get(xdt, 0) > type_order.get(ydt, 0) else ydt


def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
@@ -922,7 +929,15 @@ def astype(


def broadcast_arrays(*arrays: array) -> List[array]:
return paddle.broadcast_tensors(arrays)
original_dtypes = [arr.dtype for arr in arrays]
if len(set(original_dtypes)) == 1:
return paddle.broadcast_tensors(arrays)
target_dtype = result_type(*arrays)
casted_arrays = [arr.astype(target_dtype) if arr.dtype != target_dtype else arr
for arr in arrays]
broadcasted = paddle.broadcast_tensors(casted_arrays)
result = [arr.astype(original_dtype) for arr, original_dtype in zip(broadcasted, original_dtypes)]
return result


# Note that these named tuples aren't actually part of the standard namespace,