|
9 | 9 | from math import log, exp, pi, fsum, sin, factorial
|
10 | 10 | from test import support
|
11 | 11 | from fractions import Fraction
|
12 |
| - |
| 12 | +from collections import Counter |
13 | 13 |
|
14 | 14 | class TestBasicOps:
|
15 | 15 | # Superclass with tests common to all generators.
|
@@ -161,6 +161,77 @@ def test_sample_on_sets(self):
|
161 | 161 | population = {10, 20, 30, 40, 50, 60, 70}
|
162 | 162 | self.gen.sample(population, k=5)
|
163 | 163 |
|
| 164 | + def test_sample_with_counts(self): |
| 165 | + sample = self.gen.sample |
| 166 | + |
| 167 | + # General case |
| 168 | + colors = ['red', 'green', 'blue', 'orange', 'black', 'brown', 'amber'] |
| 169 | + counts = [500, 200, 20, 10, 5, 0, 1 ] |
| 170 | + k = 700 |
| 171 | + summary = Counter(sample(colors, counts=counts, k=k)) |
| 172 | + self.assertEqual(sum(summary.values()), k) |
| 173 | + for color, weight in zip(colors, counts): |
| 174 | + self.assertLessEqual(summary[color], weight) |
| 175 | + self.assertNotIn('brown', summary) |
| 176 | + |
| 177 | + # Case that exhausts the population |
| 178 | + k = sum(counts) |
| 179 | + summary = Counter(sample(colors, counts=counts, k=k)) |
| 180 | + self.assertEqual(sum(summary.values()), k) |
| 181 | + for color, weight in zip(colors, counts): |
| 182 | + self.assertLessEqual(summary[color], weight) |
| 183 | + self.assertNotIn('brown', summary) |
| 184 | + |
| 185 | + # Case with population size of 1 |
| 186 | + summary = Counter(sample(['x'], counts=[10], k=8)) |
| 187 | + self.assertEqual(summary, Counter(x=8)) |
| 188 | + |
| 189 | + # Case with all counts equal. |
| 190 | + nc = len(colors) |
| 191 | + summary = Counter(sample(colors, counts=[10]*nc, k=10*nc)) |
| 192 | + self.assertEqual(summary, Counter(10*colors)) |
| 193 | + |
| 194 | + # Test error handling |
| 195 | + with self.assertRaises(TypeError): |
| 196 | + sample(['red', 'green', 'blue'], counts=10, k=10) # counts not iterable |
| 197 | + with self.assertRaises(ValueError): |
| 198 | + sample(['red', 'green', 'blue'], counts=[-3, -7, -8], k=2) # counts are negative |
| 199 | + with self.assertRaises(ValueError): |
| 200 | + sample(['red', 'green', 'blue'], counts=[0, 0, 0], k=2) # counts are zero |
| 201 | + with self.assertRaises(ValueError): |
| 202 | + sample(['red', 'green'], counts=[10, 10], k=21) # population too small |
| 203 | + with self.assertRaises(ValueError): |
| 204 | + sample(['red', 'green', 'blue'], counts=[1, 2], k=2) # too few counts |
| 205 | + with self.assertRaises(ValueError): |
| 206 | + sample(['red', 'green', 'blue'], counts=[1, 2, 3, 4], k=2) # too many counts |
| 207 | + |
| 208 | + def test_sample_counts_equivalence(self): |
| 209 | + # Test the documented strong equivalence to a sample with repeated elements. |
| 210 | + # We run this test on random.Random() which makes deterministic selections |
| 211 | + # for a given seed value. |
| 212 | + sample = random.sample |
| 213 | + seed = random.seed |
| 214 | + |
| 215 | + colors = ['red', 'green', 'blue', 'orange', 'black', 'amber'] |
| 216 | + counts = [500, 200, 20, 10, 5, 1 ] |
| 217 | + k = 700 |
| 218 | + seed(8675309) |
| 219 | + s1 = sample(colors, counts=counts, k=k) |
| 220 | + seed(8675309) |
| 221 | + expanded = [color for (color, count) in zip(colors, counts) for i in range(count)] |
| 222 | + self.assertEqual(len(expanded), sum(counts)) |
| 223 | + s2 = sample(expanded, k=k) |
| 224 | + self.assertEqual(s1, s2) |
| 225 | + |
| 226 | + pop = 'abcdefghi' |
| 227 | + counts = [10, 9, 8, 7, 6, 5, 4, 3, 2] |
| 228 | + seed(8675309) |
| 229 | + s1 = ''.join(sample(pop, counts=counts, k=30)) |
| 230 | + expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)]) |
| 231 | + seed(8675309) |
| 232 | + s2 = ''.join(sample(expanded, k=30)) |
| 233 | + self.assertEqual(s1, s2) |
| 234 | + |
164 | 235 | def test_choices(self):
|
165 | 236 | choices = self.gen.choices
|
166 | 237 | data = ['red', 'green', 'blue', 'yellow']
|
|
0 commit comments