Skip to content

Commit 4b861d5

Browse files
committed
ENH: test bitwise_{left,right}_shift with scalars
1 parent 8ff22d4 commit 4b861d5

File tree

2 files changed

+28
-5
lines changed

2 files changed

+28
-5
lines changed

array_api_tests/hypothesis_helpers.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,11 @@ def scalars(draw, dtypes, finite=False, **kwds):
457457
"""
458458
dtype = draw(dtypes)
459459
if dh.is_int_dtype(dtype):
460-
m, M = dh.dtype_ranges[dtype]
460+
mM = kwds.pop('mM', None)
461+
if mM is None:
462+
m, M = dh.dtype_ranges[dtype]
463+
else:
464+
m, M = mM
461465
return draw(integers(m, M))
462466
elif dtype == bool_dtype:
463467
return draw(booleans())
@@ -588,18 +592,20 @@ def two_mutual_arrays(
588592

589593

590594
@composite
591-
def array_and_py_scalar(draw, dtypes):
595+
def array_and_py_scalar(draw, dtypes, mM=None, positive=False):
592596
"""Draw a pair: (array, scalar) or (scalar, array)."""
593597
dtype = draw(sampled_from(dtypes))
594598

595-
scalar_var = draw(scalars(just(dtype), finite=True,
596-
**{'min_value': 1/ (2<<5), 'max_value': 2<<5}
597-
))
599+
scalar_var = draw(scalars(just(dtype), finite=True, mM=mM))
600+
if positive:
601+
assume (scalar_var > 0)
598602

599603
elements={}
600604
if dtype in dh.real_float_dtypes:
601605
elements = {'allow_nan': False, 'allow_infinity': False,
602606
'min_value': 1.0 / (2<<5), 'max_value': 2<<5}
607+
if positive:
608+
elements = {'min_value': 0}
603609
array_var = draw(arrays(dtype, shape=shapes(min_dims=1), elements=elements))
604610

605611
if draw(booleans()):

array_api_tests/test_operators_and_elementwise_functions.py

+17
Original file line numberDiff line numberDiff line change
@@ -1881,3 +1881,20 @@ def test_binary_with_scalars_bitwise(func_data, x1x2):
18811881
refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 )
18821882
_check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)
18831883

1884+
1885+
@pytest.mark.min_version("2024.12")
1886+
@pytest.mark.parametrize('func_data',
1887+
# func_name, refimpl, kwargs, expected_dtype
1888+
[
1889+
("bitwise_left_shift", operator.lshift, {}, None),
1890+
("bitwise_right_shift", operator.rshift, {}, None),
1891+
],
1892+
ids=lambda func_data: func_data[0] # use names for test IDs
1893+
)
1894+
@given(x1x2=hh.array_and_py_scalar([xp.int32], positive=True, mM=(1, 3)))
1895+
def test_binary_with_scalars_bitwise_shifts(func_data, x1x2):
1896+
func_name, refimpl, kwargs, expected = func_data
1897+
# repack the refimpl
1898+
refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 )
1899+
_check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)
1900+

0 commit comments

Comments
 (0)