Skip to content

Commit 9a68ff1

Browse files
authored
GH-100805: Support numpy.array() in random.choice(). (GH-100830)
1 parent 87d3bd0 commit 9a68ff1

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

Lib/random.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,10 @@ def randint(self, a, b):
336336

337337
def choice(self, seq):
338338
"""Choose a random element from a non-empty sequence."""
339-
if not seq:
339+
340+
# As an accommodation for NumPy, we don't use "if not seq"
341+
# because bool(numpy.array()) raises a ValueError.
342+
if not len(seq):
340343
raise IndexError('Cannot choose from an empty sequence')
341344
return seq[self._randbelow(len(seq))]
342345

Lib/test/test_random.py

+15
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,21 @@ def test_choice(self):
111111
self.assertEqual(choice([50]), 50)
112112
self.assertIn(choice([25, 75]), [25, 75])
113113

114+
def test_choice_with_numpy(self):
115+
# Accommodation for NumPy arrays which have disabled __bool__().
116+
# See: https://github.com/python/cpython/issues/100805
117+
choice = self.gen.choice
118+
119+
class NA(list):
120+
"Simulate numpy.array() behavior"
121+
def __bool__(self):
122+
raise RuntimeError
123+
124+
with self.assertRaises(IndexError):
125+
choice(NA([]))
126+
self.assertEqual(choice(NA([50])), 50)
127+
self.assertIn(choice(NA([25, 75])), [25, 75])
128+
114129
def test_sample(self):
115130
# For the entire allowable range of 0 <= k <= N, validate that
116131
# the sample is of the correct length and contains only unique items
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Modify :func:`random.choice` implementation to once again work with NumPy
2+
arrays.

0 commit comments

Comments
 (0)