Skip to content

Commit 9feda28

Browse files
committed
Address review
1 parent 314a7b3 commit 9feda28

File tree

2 files changed

+97
-53
lines changed

2 files changed

+97
-53
lines changed

Lib/test/test_capi/test_set.py

+91-53
Original file line numberDiff line numberDiff line change
@@ -5,139 +5,163 @@
55
# Skip this test if the _testcapi module isn't available.
66
_testcapi = import_helper.import_module('_testcapi')
77

8-
class set_child(set):
8+
class set_subclass(set):
99
pass
1010

11-
class frozenset_child(frozenset):
11+
class frozenset_subclass(frozenset):
1212
pass
1313

1414

1515
class TestSetCAPI(unittest.TestCase):
1616
def assertImmutable(self, action, *args):
1717
self.assertRaises(SystemError, action, frozenset(), *args)
1818
self.assertRaises(SystemError, action, frozenset({1}), *args)
19-
self.assertRaises(SystemError, action, frozenset_child(), *args)
20-
self.assertRaises(SystemError, action, frozenset_child({1}), *args)
19+
self.assertRaises(SystemError, action, frozenset_subclass(), *args)
20+
self.assertRaises(SystemError, action, frozenset_subclass({1}), *args)
2121

2222
def test_set_check(self):
2323
check = _testcapi.set_check
2424
self.assertTrue(check(set()))
2525
self.assertTrue(check({1, 2}))
2626
self.assertFalse(check(frozenset()))
27-
self.assertTrue(check(set_child()))
28-
self.assertFalse(check(frozenset_child()))
27+
self.assertTrue(check(set_subclass()))
28+
self.assertFalse(check(frozenset_subclass()))
2929
self.assertFalse(check(object()))
30+
# CRASHES: check(NULL)
3031

3132
def test_set_check_exact(self):
3233
check = _testcapi.set_checkexact
3334
self.assertTrue(check(set()))
3435
self.assertTrue(check({1, 2}))
3536
self.assertFalse(check(frozenset()))
36-
self.assertFalse(check(set_child()))
37-
self.assertFalse(check(frozenset_child()))
37+
self.assertFalse(check(set_subclass()))
38+
self.assertFalse(check(frozenset_subclass()))
3839
self.assertFalse(check(object()))
40+
# CRASHES: check(NULL)
3941

4042
def test_frozenset_check(self):
4143
check = _testcapi.frozenset_check
4244
self.assertFalse(check(set()))
4345
self.assertTrue(check(frozenset()))
4446
self.assertTrue(check(frozenset({1, 2})))
45-
self.assertFalse(check(set_child()))
46-
self.assertTrue(check(frozenset_child()))
47+
self.assertFalse(check(set_subclass()))
48+
self.assertTrue(check(frozenset_subclass()))
4749
self.assertFalse(check(object()))
50+
# CRASHES: check(NULL)
4851

4952
def test_frozenset_check_exact(self):
5053
check = _testcapi.frozenset_checkexact
5154
self.assertFalse(check(set()))
5255
self.assertTrue(check(frozenset()))
5356
self.assertTrue(check(frozenset({1, 2})))
54-
self.assertFalse(check(set_child()))
55-
self.assertFalse(check(frozenset_child()))
57+
self.assertFalse(check(set_subclass()))
58+
self.assertFalse(check(frozenset_subclass()))
5659
self.assertFalse(check(object()))
60+
# CRASHES: check(NULL)
5761

5862
def test_anyset_check(self):
5963
check = _testcapi.anyset_check
6064
self.assertTrue(check(set()))
6165
self.assertTrue(check({1, 2}))
6266
self.assertTrue(check(frozenset()))
6367
self.assertTrue(check(frozenset({1, 2})))
64-
self.assertTrue(check(set_child()))
65-
self.assertTrue(check(frozenset_child()))
68+
self.assertTrue(check(set_subclass()))
69+
self.assertTrue(check(frozenset_subclass()))
6670
self.assertFalse(check(object()))
71+
# CRASHES: check(NULL)
6772

6873
def test_anyset_check_exact(self):
6974
check = _testcapi.anyset_checkexact
7075
self.assertTrue(check(set()))
7176
self.assertTrue(check({1, 2}))
7277
self.assertTrue(check(frozenset()))
7378
self.assertTrue(check(frozenset({1, 2})))
74-
self.assertFalse(check(set_child()))
75-
self.assertFalse(check(frozenset_child()))
79+
self.assertFalse(check(set_subclass()))
80+
self.assertFalse(check(frozenset_subclass()))
7681
self.assertFalse(check(object()))
82+
# CRASHES: check(NULL)
7783

7884
def test_set_new(self):
79-
new = _testcapi.set_new
80-
self.assertEqual(new().__class__, set)
81-
self.assertEqual(new(), set())
82-
self.assertEqual(new((1, 1, 2)), {1, 2})
85+
set_new = _testcapi.set_new
86+
self.assertEqual(set_new().__class__, set)
87+
self.assertEqual(set_new(), set())
88+
self.assertEqual(set_new((1, 1, 2)), {1, 2})
8389
with self.assertRaisesRegex(TypeError, 'object is not iterable'):
84-
new(object())
90+
set_new(object())
91+
with self.assertRaisesRegex(TypeError, 'object is not iterable'):
92+
set_new(None)
8593
with self.assertRaisesRegex(TypeError, "unhashable type: 'dict'"):
86-
new((1, {}))
94+
set_new((1, {}))
8795

8896
def test_frozenset_new(self):
89-
new = _testcapi.frozenset_new
90-
self.assertEqual(new().__class__, frozenset)
91-
self.assertEqual(new(), frozenset())
92-
self.assertEqual(new((1, 1, 2)), frozenset({1, 2}))
97+
frozenset_new = _testcapi.frozenset_new
98+
self.assertEqual(frozenset_new().__class__, frozenset)
99+
self.assertEqual(frozenset_new(), frozenset())
100+
self.assertEqual(frozenset_new((1, 1, 2)), frozenset({1, 2}))
101+
with self.assertRaisesRegex(TypeError, 'object is not iterable'):
102+
frozenset_new(object())
93103
with self.assertRaisesRegex(TypeError, 'object is not iterable'):
94-
new(object())
104+
frozenset_new(None)
95105
with self.assertRaisesRegex(TypeError, "unhashable type: 'dict'"):
96-
new((1, {}))
106+
frozenset_new((1, {}))
97107

98108
def test_set_size(self):
99-
l = _testcapi.set_size
100-
self.assertEqual(l(set()), 0)
101-
self.assertEqual(l(frozenset()), 0)
102-
self.assertEqual(l({1, 1, 2}), 2)
103-
self.assertEqual(l(frozenset({1, 1, 2})), 2)
104-
self.assertEqual(l(set_child((1, 2, 3))), 3)
105-
self.assertEqual(l(frozenset_child((1, 2, 3))), 3)
109+
get_size = _testcapi.set_size
110+
self.assertEqual(get_size(set()), 0)
111+
self.assertEqual(get_size(frozenset()), 0)
112+
self.assertEqual(get_size({1, 1, 2}), 2)
113+
self.assertEqual(get_size(frozenset({1, 1, 2})), 2)
114+
self.assertEqual(get_size(set_subclass((1, 2, 3))), 3)
115+
self.assertEqual(get_size(frozenset_subclass((1, 2, 3))), 3)
106116
with self.assertRaises(SystemError):
107-
l([])
117+
get_size([])
118+
# CRASHES: get_size(NULL)
108119

109120
def test_set_get_size(self):
110-
l = _testcapi.set_get_size
111-
self.assertEqual(l(set()), 0)
112-
self.assertEqual(l(frozenset()), 0)
113-
self.assertEqual(l({1, 1, 2}), 2)
114-
self.assertEqual(l(frozenset({1, 1, 2})), 2)
115-
self.assertEqual(l(set_child((1, 2, 3))), 3)
116-
self.assertEqual(l(frozenset_child((1, 2, 3))), 3)
117-
# CRASHES: l([])
121+
get_size = _testcapi.set_get_size
122+
self.assertEqual(get_size(set()), 0)
123+
self.assertEqual(get_size(frozenset()), 0)
124+
self.assertEqual(get_size({1, 1, 2}), 2)
125+
self.assertEqual(get_size(frozenset({1, 1, 2})), 2)
126+
self.assertEqual(get_size(set_subclass((1, 2, 3))), 3)
127+
self.assertEqual(get_size(frozenset_subclass((1, 2, 3))), 3)
128+
# CRASHES: get_size(NULL)
129+
# CRASHES: get_size(object())
118130

119131
def test_set_contains(self):
120-
c = _testcapi.set_contains
121-
for cls in (set, frozenset, set_child, frozenset_child):
132+
contains = _testcapi.set_contains
133+
for cls in (set, frozenset, set_subclass, frozenset_subclass):
122134
with self.subTest(cls=cls):
123135
instance = cls((1, 2))
124-
self.assertTrue(c(instance, 1))
125-
self.assertFalse(c(instance, 'missing'))
136+
self.assertTrue(contains(instance, 1))
137+
self.assertFalse(contains(instance, 'missing'))
138+
with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"):
139+
contains(instance, [])
140+
# CRASHES: contains(instance, NULL)
141+
# CRASHES: contains(NULL, object())
142+
# CRASHES: contains(NULL, NULL)
126143

127144
def test_add(self):
128145
add = _testcapi.set_add
129-
for cls in (set, set_child):
146+
for cls in (set, set_subclass):
130147
with self.subTest(cls=cls):
131148
instance = cls((1, 2))
132149
self.assertEqual(add(instance, 1), 0)
133150
self.assertEqual(instance, {1, 2})
134151
self.assertEqual(add(instance, 3), 0)
135152
self.assertEqual(instance, {1, 2, 3})
153+
with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"):
154+
add(instance, [])
155+
# CRASHES: add(NULL, object())
156+
# CRASHES: add(instance, NULL)
157+
# CRASHES: add(NULL, NULL)
158+
with self.assertRaises(SystemError):
159+
add(object(), 1)
136160
self.assertImmutable(add, 1)
137161

138162
def test_discard(self):
139163
discard = _testcapi.set_discard
140-
for cls in (set, set_child):
164+
for cls in (set, set_subclass):
141165
with self.subTest(cls=cls):
142166
instance = cls((1, 2))
143167
self.assertEqual(discard(instance, 3), 0)
@@ -146,15 +170,21 @@ def test_discard(self):
146170
self.assertEqual(instance, {2})
147171
self.assertEqual(discard(instance, 2), 1)
148172
self.assertEqual(instance, set())
149-
# Discarding from empty set works
150173
self.assertEqual(discard(instance, 2), 0)
151174
self.assertEqual(instance, set())
175+
with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"):
176+
discard(instance, [])
177+
# CRASHES: discard(NULL, object())
178+
# CRASHES: discard(instance, NULL)
179+
# CRASHES: discard(NULL, NULL)
180+
with self.assertRaises(SystemError):
181+
discard(object(), 1)
152182
self.assertImmutable(discard, 1)
153183

154184
def test_pop(self):
155185
pop = _testcapi.set_pop
156186
orig = (1, 2)
157-
for cls in (set, set_child):
187+
for cls in (set, set_subclass):
158188
with self.subTest(cls=cls):
159189
instance = cls(orig)
160190
self.assertIn(pop(instance), orig)
@@ -163,13 +193,21 @@ def test_pop(self):
163193
self.assertEqual(len(instance), 0)
164194
with self.assertRaises(KeyError):
165195
pop(instance)
196+
# CRASHES: pop(NULL)
197+
with self.assertRaises(SystemError):
198+
pop(object())
166199
self.assertImmutable(pop)
167200

168201
def test_clear(self):
169202
clear = _testcapi.set_clear
170-
for cls in (set, set_child):
203+
for cls in (set, set_subclass):
171204
with self.subTest(cls=cls):
172205
instance = cls((1, 2))
173206
self.assertEqual(clear(instance), 0)
174207
self.assertEqual(instance, set())
208+
self.assertEqual(clear(instance), 0)
209+
self.assertEqual(instance, set())
210+
# CRASHES: clear(NULL)
211+
with self.assertRaises(SystemError):
212+
clear(object())
175213
self.assertImmutable(clear)

Modules/_testcapi/set.c

+6
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,42 @@
66
static PyObject *
77
set_check(PyObject *self, PyObject *obj)
88
{
9+
NULLABLE(obj);
910
RETURN_INT(PySet_Check(obj));
1011
}
1112

1213
static PyObject *
1314
set_checkexact(PyObject *self, PyObject *obj)
1415
{
16+
NULLABLE(obj);
1517
RETURN_INT(PySet_CheckExact(obj));
1618
}
1719

1820
static PyObject *
1921
frozenset_check(PyObject *self, PyObject *obj)
2022
{
23+
NULLABLE(obj);
2124
RETURN_INT(PyFrozenSet_Check(obj));
2225
}
2326

2427
static PyObject *
2528
frozenset_checkexact(PyObject *self, PyObject *obj)
2629
{
30+
NULLABLE(obj);
2731
RETURN_INT(PyFrozenSet_CheckExact(obj));
2832
}
2933

3034
static PyObject *
3135
anyset_check(PyObject *self, PyObject *obj)
3236
{
37+
NULLABLE(obj);
3338
RETURN_INT(PyAnySet_Check(obj));
3439
}
3540

3641
static PyObject *
3742
anyset_checkexact(PyObject *self, PyObject *obj)
3843
{
44+
NULLABLE(obj);
3945
RETURN_INT(PyAnySet_CheckExact(obj));
4046
}
4147

0 commit comments

Comments
 (0)