Skip to content

Commit 1ac40e4

Browse files
committed
Address review
1 parent 314a7b3 commit 1ac40e4

File tree

2 files changed

+96
-53
lines changed

2 files changed

+96
-53
lines changed

Lib/test/test_capi/test_set.py

+90-53
Original file line numberDiff line numberDiff line change
@@ -5,54 +5,57 @@
55
# Skip this test if the _testcapi module isn't available.
66
_testcapi = import_helper.import_module('_testcapi')
77

8-
class set_child(set):
8+
class set_subclass(set):
99
pass
1010

11-
class frozenset_child(frozenset):
11+
class frozenset_subclass(frozenset):
1212
pass
1313

1414

1515
class TestSetCAPI(unittest.TestCase):
1616
def assertImmutable(self, action, *args):
1717
self.assertRaises(SystemError, action, frozenset(), *args)
1818
self.assertRaises(SystemError, action, frozenset({1}), *args)
19-
self.assertRaises(SystemError, action, frozenset_child(), *args)
20-
self.assertRaises(SystemError, action, frozenset_child({1}), *args)
19+
self.assertRaises(SystemError, action, frozenset_subclass(), *args)
20+
self.assertRaises(SystemError, action, frozenset_subclass({1}), *args)
2121

2222
def test_set_check(self):
2323
check = _testcapi.set_check
2424
self.assertTrue(check(set()))
2525
self.assertTrue(check({1, 2}))
2626
self.assertFalse(check(frozenset()))
27-
self.assertTrue(check(set_child()))
28-
self.assertFalse(check(frozenset_child()))
27+
self.assertTrue(check(set_subclass()))
28+
self.assertFalse(check(frozenset_subclass()))
2929
self.assertFalse(check(object()))
30+
# CRASHES: check(NULL)
3031

3132
def test_set_check_exact(self):
3233
check = _testcapi.set_checkexact
3334
self.assertTrue(check(set()))
3435
self.assertTrue(check({1, 2}))
3536
self.assertFalse(check(frozenset()))
36-
self.assertFalse(check(set_child()))
37-
self.assertFalse(check(frozenset_child()))
37+
self.assertFalse(check(set_subclass()))
38+
self.assertFalse(check(frozenset_subclass()))
3839
self.assertFalse(check(object()))
40+
# CRASHES: check(NULL)
3941

4042
def test_frozenset_check(self):
4143
check = _testcapi.frozenset_check
4244
self.assertFalse(check(set()))
4345
self.assertTrue(check(frozenset()))
4446
self.assertTrue(check(frozenset({1, 2})))
45-
self.assertFalse(check(set_child()))
46-
self.assertTrue(check(frozenset_child()))
47+
self.assertFalse(check(set_subclass()))
48+
self.assertTrue(check(frozenset_subclass()))
4749
self.assertFalse(check(object()))
50+
# CRASHES: check(NULL)
4851

4952
def test_frozenset_check_exact(self):
5053
check = _testcapi.frozenset_checkexact
5154
self.assertFalse(check(set()))
5255
self.assertTrue(check(frozenset()))
5356
self.assertTrue(check(frozenset({1, 2})))
54-
self.assertFalse(check(set_child()))
55-
self.assertFalse(check(frozenset_child()))
57+
self.assertFalse(check(set_subclass()))
58+
self.assertFalse(check(frozenset_subclass()))
5659
self.assertFalse(check(object()))
5760

5861
def test_anyset_check(self):
@@ -61,83 +64,103 @@ def test_anyset_check(self):
6164
self.assertTrue(check({1, 2}))
6265
self.assertTrue(check(frozenset()))
6366
self.assertTrue(check(frozenset({1, 2})))
64-
self.assertTrue(check(set_child()))
65-
self.assertTrue(check(frozenset_child()))
67+
self.assertTrue(check(set_subclass()))
68+
self.assertTrue(check(frozenset_subclass()))
6669
self.assertFalse(check(object()))
70+
# CRASHES: check(NULL)
6771

6872
def test_anyset_check_exact(self):
6973
check = _testcapi.anyset_checkexact
7074
self.assertTrue(check(set()))
7175
self.assertTrue(check({1, 2}))
7276
self.assertTrue(check(frozenset()))
7377
self.assertTrue(check(frozenset({1, 2})))
74-
self.assertFalse(check(set_child()))
75-
self.assertFalse(check(frozenset_child()))
78+
self.assertFalse(check(set_subclass()))
79+
self.assertFalse(check(frozenset_subclass()))
7680
self.assertFalse(check(object()))
81+
# CRASHES: check(NULL)
7782

7883
def test_set_new(self):
79-
new = _testcapi.set_new
80-
self.assertEqual(new().__class__, set)
81-
self.assertEqual(new(), set())
82-
self.assertEqual(new((1, 1, 2)), {1, 2})
84+
set_new = _testcapi.set_new
85+
self.assertEqual(set_new().__class__, set)
86+
self.assertEqual(set_new(), set())
87+
self.assertEqual(set_new((1, 1, 2)), {1, 2})
8388
with self.assertRaisesRegex(TypeError, 'object is not iterable'):
84-
new(object())
89+
set_new(object())
90+
with self.assertRaisesRegex(TypeError, 'object is not iterable'):
91+
set_new(None)
8592
with self.assertRaisesRegex(TypeError, "unhashable type: 'dict'"):
86-
new((1, {}))
93+
set_new((1, {}))
8794

8895
def test_frozenset_new(self):
89-
new = _testcapi.frozenset_new
90-
self.assertEqual(new().__class__, frozenset)
91-
self.assertEqual(new(), frozenset())
92-
self.assertEqual(new((1, 1, 2)), frozenset({1, 2}))
96+
frozenset_new = _testcapi.frozenset_new
97+
self.assertEqual(frozenset_new().__class__, frozenset)
98+
self.assertEqual(frozenset_new(), frozenset())
99+
self.assertEqual(frozenset_new((1, 1, 2)), frozenset({1, 2}))
100+
with self.assertRaisesRegex(TypeError, 'object is not iterable'):
101+
frozenset_new(object())
93102
with self.assertRaisesRegex(TypeError, 'object is not iterable'):
94-
new(object())
103+
frozenset_new(None)
95104
with self.assertRaisesRegex(TypeError, "unhashable type: 'dict'"):
96-
new((1, {}))
105+
frozenset_new((1, {}))
97106

98107
def test_set_size(self):
99-
l = _testcapi.set_size
100-
self.assertEqual(l(set()), 0)
101-
self.assertEqual(l(frozenset()), 0)
102-
self.assertEqual(l({1, 1, 2}), 2)
103-
self.assertEqual(l(frozenset({1, 1, 2})), 2)
104-
self.assertEqual(l(set_child((1, 2, 3))), 3)
105-
self.assertEqual(l(frozenset_child((1, 2, 3))), 3)
108+
get_size = _testcapi.set_size
109+
self.assertEqual(get_size(set()), 0)
110+
self.assertEqual(get_size(frozenset()), 0)
111+
self.assertEqual(get_size({1, 1, 2}), 2)
112+
self.assertEqual(get_size(frozenset({1, 1, 2})), 2)
113+
self.assertEqual(get_size(set_subclass((1, 2, 3))), 3)
114+
self.assertEqual(get_size(frozenset_subclass((1, 2, 3))), 3)
106115
with self.assertRaises(SystemError):
107-
l([])
116+
get_size([])
117+
# CRASHES: get_size(NULL)
108118

109119
def test_set_get_size(self):
110-
l = _testcapi.set_get_size
111-
self.assertEqual(l(set()), 0)
112-
self.assertEqual(l(frozenset()), 0)
113-
self.assertEqual(l({1, 1, 2}), 2)
114-
self.assertEqual(l(frozenset({1, 1, 2})), 2)
115-
self.assertEqual(l(set_child((1, 2, 3))), 3)
116-
self.assertEqual(l(frozenset_child((1, 2, 3))), 3)
117-
# CRASHES: l([])
120+
get_size = _testcapi.set_get_size
121+
self.assertEqual(get_size(set()), 0)
122+
self.assertEqual(get_size(frozenset()), 0)
123+
self.assertEqual(get_size({1, 1, 2}), 2)
124+
self.assertEqual(get_size(frozenset({1, 1, 2})), 2)
125+
self.assertEqual(get_size(set_subclass((1, 2, 3))), 3)
126+
self.assertEqual(get_size(frozenset_subclass((1, 2, 3))), 3)
127+
# CRASHES: get_size(NULL)
128+
# CRASHES: get_size(object())
118129

119130
def test_set_contains(self):
120-
c = _testcapi.set_contains
121-
for cls in (set, frozenset, set_child, frozenset_child):
131+
contains = _testcapi.set_contains
132+
for cls in (set, frozenset, set_subclass, frozenset_subclass):
122133
with self.subTest(cls=cls):
123134
instance = cls((1, 2))
124-
self.assertTrue(c(instance, 1))
125-
self.assertFalse(c(instance, 'missing'))
135+
self.assertTrue(contains(instance, 1))
136+
self.assertFalse(contains(instance, 'missing'))
137+
with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"):
138+
contains(instance, [])
139+
# CRASHES: contains(instance, NULL)
140+
# CRASHES: contains(NULL, object())
141+
# CRASHES: contains(NULL, NULL)
126142

127143
def test_add(self):
128144
add = _testcapi.set_add
129-
for cls in (set, set_child):
145+
for cls in (set, set_subclass):
130146
with self.subTest(cls=cls):
131147
instance = cls((1, 2))
132148
self.assertEqual(add(instance, 1), 0)
133149
self.assertEqual(instance, {1, 2})
134150
self.assertEqual(add(instance, 3), 0)
135151
self.assertEqual(instance, {1, 2, 3})
152+
with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"):
153+
add(instance, [])
154+
# CRASHES: add(NULL, object())
155+
# CRASHES: add(instance, NULL)
156+
# CRASHES: add(NULL, NULL)
157+
with self.assertRaises(SystemError):
158+
add(object(), 1)
136159
self.assertImmutable(add, 1)
137160

138161
def test_discard(self):
139162
discard = _testcapi.set_discard
140-
for cls in (set, set_child):
163+
for cls in (set, set_subclass):
141164
with self.subTest(cls=cls):
142165
instance = cls((1, 2))
143166
self.assertEqual(discard(instance, 3), 0)
@@ -146,15 +169,21 @@ def test_discard(self):
146169
self.assertEqual(instance, {2})
147170
self.assertEqual(discard(instance, 2), 1)
148171
self.assertEqual(instance, set())
149-
# Discarding from empty set works
150172
self.assertEqual(discard(instance, 2), 0)
151173
self.assertEqual(instance, set())
174+
with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"):
175+
discard(instance, [])
176+
# CRASHES: discard(NULL, object())
177+
# CRASHES: discard(instance, NULL)
178+
# CRASHES: discard(NULL, NULL)
179+
with self.assertRaises(SystemError):
180+
discard(object(), 1)
152181
self.assertImmutable(discard, 1)
153182

154183
def test_pop(self):
155184
pop = _testcapi.set_pop
156185
orig = (1, 2)
157-
for cls in (set, set_child):
186+
for cls in (set, set_subclass):
158187
with self.subTest(cls=cls):
159188
instance = cls(orig)
160189
self.assertIn(pop(instance), orig)
@@ -163,13 +192,21 @@ def test_pop(self):
163192
self.assertEqual(len(instance), 0)
164193
with self.assertRaises(KeyError):
165194
pop(instance)
195+
# CRASHES: pop(NULL)
196+
with self.assertRaises(SystemError):
197+
pop(object())
166198
self.assertImmutable(pop)
167199

168200
def test_clear(self):
169201
clear = _testcapi.set_clear
170-
for cls in (set, set_child):
202+
for cls in (set, set_subclass):
171203
with self.subTest(cls=cls):
172204
instance = cls((1, 2))
173205
self.assertEqual(clear(instance), 0)
174206
self.assertEqual(instance, set())
207+
self.assertEqual(clear(instance), 0)
208+
self.assertEqual(instance, set())
209+
# CRASHES: clear(NULL)
210+
with self.assertRaises(SystemError):
211+
clear(object())
175212
self.assertImmutable(clear)

Modules/_testcapi/set.c

+6
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,42 @@
66
static PyObject *
77
set_check(PyObject *self, PyObject *obj)
88
{
9+
NULLABLE(obj);
910
RETURN_INT(PySet_Check(obj));
1011
}
1112

1213
static PyObject *
1314
set_checkexact(PyObject *self, PyObject *obj)
1415
{
16+
NULLABLE(obj);
1517
RETURN_INT(PySet_CheckExact(obj));
1618
}
1719

1820
static PyObject *
1921
frozenset_check(PyObject *self, PyObject *obj)
2022
{
23+
NULLABLE(obj);
2124
RETURN_INT(PyFrozenSet_Check(obj));
2225
}
2326

2427
static PyObject *
2528
frozenset_checkexact(PyObject *self, PyObject *obj)
2629
{
30+
NULLABLE(obj);
2731
RETURN_INT(PyFrozenSet_CheckExact(obj));
2832
}
2933

3034
static PyObject *
3135
anyset_check(PyObject *self, PyObject *obj)
3236
{
37+
NULLABLE(obj);
3338
RETURN_INT(PyAnySet_Check(obj));
3439
}
3540

3641
static PyObject *
3742
anyset_checkexact(PyObject *self, PyObject *obj)
3843
{
44+
NULLABLE(obj);
3945
RETURN_INT(PyAnySet_CheckExact(obj));
4046
}
4147

0 commit comments

Comments
 (0)