From 5120204a4187edd8c669ec6f8af35ba8339cf9ac Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 21 Nov 2024 15:41:40 +0200 Subject: [PATCH 1/3] ENH: test can_cast(complex dtypes) --- array_api_tests/test_data_type_functions.py | 22 ++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 34c40024..0fa49753 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -19,10 +19,18 @@ def non_complex_dtypes(): return xps.boolean_dtypes() | hh.real_dtypes +def numeric_dtypes(): + return xps.boolean_dtypes() | hh.real_dtypes | hh.complex_dtypes + + def float32(n: Union[int, float]) -> float: return struct.unpack("!f", struct.pack("!f", float(n)))[0] +def _float_match_complex(complex_dtype): + return xp.float32 if complex_dtype == xp.complex64 else xp.float64 + + @given( x_dtype=non_complex_dtypes(), dtype=non_complex_dtypes(), @@ -107,7 +115,7 @@ def test_broadcast_to(x, data): # TODO: test values -@given(_from=non_complex_dtypes(), to=non_complex_dtypes(), data=st.data()) +@given(_from=numeric_dtypes(), to=numeric_dtypes(), data=st.data()) def test_can_cast(_from, to, data): from_ = data.draw( st.just(_from) | hh.arrays(dtype=_from, shape=hh.shapes()), label="from_" @@ -127,8 +135,15 @@ def test_can_cast(_from, to, data): break assert same_family is not None # sanity check if same_family: - from_min, from_max = dh.dtype_ranges[_from] - to_min, to_max = dh.dtype_ranges[to] + from_dtype = (_float_match_complex(_from) + if _from in (xp.complex64, xp.complex128) + else _from) + to_dtype = (_float_match_complex(to) + if to in (xp.complex64, xp.complex128) + else to) + + from_min, from_max = dh.dtype_ranges[from_dtype] + to_min, to_max = dh.dtype_ranges[to_dtype] expected = from_min >= to_min and from_max <= to_max else: expected = False @@ -139,6 +154,7 @@ def test_can_cast(_from, to, data): assert out == expected, f"{out=}, but should be {expected} {f_func}" + @pytest.mark.parametrize("dtype", dh.real_float_dtypes) def test_finfo(dtype): out = xp.finfo(dtype) From 9f28ec809c2eb84cb1d86a489f9dfe60f0643fe3 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 23 Nov 2024 13:04:59 +0200 Subject: [PATCH 2/3] Simplify test_can_cast 1) test all dtypes 2) check against the promotion table 3) remove checking the limits (value-based casting?) --- array_api_tests/test_data_type_functions.py | 49 ++++----------------- 1 file changed, 9 insertions(+), 40 deletions(-) diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 0fa49753..32f15ec2 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -19,18 +19,10 @@ def non_complex_dtypes(): return xps.boolean_dtypes() | hh.real_dtypes -def numeric_dtypes(): - return xps.boolean_dtypes() | hh.real_dtypes | hh.complex_dtypes - - def float32(n: Union[int, float]) -> float: return struct.unpack("!f", struct.pack("!f", float(n)))[0] -def _float_match_complex(complex_dtype): - return xp.float32 if complex_dtype == xp.complex64 else xp.float64 - - @given( x_dtype=non_complex_dtypes(), dtype=non_complex_dtypes(), @@ -115,46 +107,23 @@ def test_broadcast_to(x, data): # TODO: test values -@given(_from=numeric_dtypes(), to=numeric_dtypes(), data=st.data()) -def test_can_cast(_from, to, data): - from_ = data.draw( - st.just(_from) | hh.arrays(dtype=_from, shape=hh.shapes()), label="from_" - ) +@given(_from=hh.all_dtypes, to=hh.all_dtypes) +def test_can_cast(_from, to): + out = xp.can_cast(_from, to) - out = xp.can_cast(from_, to) + expected = False + for other in dh.all_dtypes: + if dh.promotion_table.get((_from, other)) == to: + expected = True + break f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]" - assert isinstance(out, bool), f"{type(out)=}, but should be bool {f_func}" - if _from == xp.bool: - expected = to == xp.bool - else: - same_family = None - for dtypes in [dh.all_int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]: - if _from in dtypes: - same_family = to in dtypes - break - assert same_family is not None # sanity check - if same_family: - from_dtype = (_float_match_complex(_from) - if _from in (xp.complex64, xp.complex128) - else _from) - to_dtype = (_float_match_complex(to) - if to in (xp.complex64, xp.complex128) - else to) - - from_min, from_max = dh.dtype_ranges[from_dtype] - to_min, to_max = dh.dtype_ranges[to_dtype] - expected = from_min >= to_min and from_max <= to_max - else: - expected = False if expected: # cross-kind casting is not explicitly disallowed. We can only test - # the cases where it should return True. TODO: if expected=False, - # check that the array library actually allows such casts. + # the cases where it should return True. assert out == expected, f"{out=}, but should be {expected} {f_func}" - @pytest.mark.parametrize("dtype", dh.real_float_dtypes) def test_finfo(dtype): out = xp.finfo(dtype) From f1288680e642b59462c34b24d065ee5966b385ba Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 26 Nov 2024 17:41:12 +0200 Subject: [PATCH 3/3] restore a TODO comment --- array_api_tests/test_data_type_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 32f15ec2..c69e4143 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -120,7 +120,8 @@ def test_can_cast(_from, to): f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]" if expected: # cross-kind casting is not explicitly disallowed. We can only test - # the cases where it should return True. + # the cases where it should return True. TODO: if expected=False, + # check that the array library actually allows such casts. assert out == expected, f"{out=}, but should be {expected} {f_func}"