Skip to content

Commit 72de19b

Browse files
committed
ENH: test bitwise_{left,right}_shift with scalars
1 parent a3f2555 commit 72de19b

File tree

2 files changed

+28
-5
lines changed

2 files changed

+28
-5
lines changed

array_api_tests/hypothesis_helpers.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -456,8 +456,12 @@ def scalars(draw, dtypes, finite=False, **kwds):
456456
dtypes should be one of the shared_* dtypes strategies.
457457
"""
458458
dtype = draw(dtypes)
459+
mM = kwds.pop('mM', None)
459460
if dh.is_int_dtype(dtype):
460-
m, M = dh.dtype_ranges[dtype]
461+
if mM is None:
462+
m, M = dh.dtype_ranges[dtype]
463+
else:
464+
m, M = mM
461465
return draw(integers(m, M))
462466
elif dtype == bool_dtype:
463467
return draw(booleans())
@@ -588,18 +592,20 @@ def two_mutual_arrays(
588592

589593

590594
@composite
591-
def array_and_py_scalar(draw, dtypes):
595+
def array_and_py_scalar(draw, dtypes, mM=None, positive=False):
592596
"""Draw a pair: (array, scalar) or (scalar, array)."""
593597
dtype = draw(sampled_from(dtypes))
594598

595-
scalar_var = draw(scalars(just(dtype), finite=True,
596-
**{'min_value': 1/ (2<<5), 'max_value': 2<<5}
597-
))
599+
scalar_var = draw(scalars(just(dtype), finite=True, mM=mM))
600+
if positive:
601+
assume (scalar_var > 0)
598602

599603
elements={}
600604
if dtype in dh.real_float_dtypes:
601605
elements = {'allow_nan': False, 'allow_infinity': False,
602606
'min_value': 1.0 / (2<<5), 'max_value': 2<<5}
607+
if positive:
608+
elements = {'min_value': 0}
603609
array_var = draw(arrays(dtype, shape=shapes(min_dims=1), elements=elements))
604610

605611
if draw(booleans()):

array_api_tests/test_operators_and_elementwise_functions.py

+17
Original file line numberDiff line numberDiff line change
@@ -1883,3 +1883,20 @@ def test_binary_with_scalars_bitwise(func_data, x1x2):
18831883
refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 )
18841884
_check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)
18851885

1886+
1887+
@pytest.mark.min_version("2024.12")
1888+
@pytest.mark.parametrize('func_data',
1889+
# func_name, refimpl, kwargs, expected_dtype
1890+
[
1891+
("bitwise_left_shift", operator.lshift, {}, None),
1892+
("bitwise_right_shift", operator.rshift, {}, None),
1893+
],
1894+
ids=lambda func_data: func_data[0] # use names for test IDs
1895+
)
1896+
@given(x1x2=hh.array_and_py_scalar([xp.int32], positive=True, mM=(1, 3)))
1897+
def test_binary_with_scalars_bitwise_shifts(func_data, x1x2):
1898+
func_name, refimpl, kwargs, expected = func_data
1899+
# repack the refimpl
1900+
refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 )
1901+
_check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)
1902+

0 commit comments

Comments
 (0)