Skip to content

Commit e1eb827

Browse files
authored
Fix array __abs__ method for complex inputs (#294)
1 parent 553601e commit e1eb827

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

cubed/array_api/array_object.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
_integer_dtypes,
1414
_integer_or_boolean_dtypes,
1515
_numeric_dtypes,
16+
complex64,
17+
complex128,
18+
float32,
19+
float64,
1620
)
1721
from cubed.array_api.linear_algebra_functions import matmul
1822
from cubed.core.array import CoreArray
@@ -337,7 +341,13 @@ def __rrshift__(self, other, /):
337341
def __abs__(self, /):
338342
if self.dtype not in _numeric_dtypes:
339343
raise TypeError("Only numeric dtypes are allowed in __abs__")
340-
return elemwise(np.abs, self, dtype=self.dtype)
344+
if self.dtype == complex64:
345+
dtype = float32
346+
elif self.dtype == complex128:
347+
dtype = float64
348+
else:
349+
dtype = self.dtype
350+
return elemwise(np.abs, self, dtype=dtype)
341351

342352
def __array_namespace__(self, /, *, api_version=None):
343353
if api_version is not None and not api_version.startswith("2021."):

0 commit comments

Comments
 (0)