Skip to content

Commit 85b5829

Browse files
bpo-44633: Fix parameter substitution of the union type with wrong types. (GH-27218) (GH-27224)
A TypeError is now raised instead of returning NotImplemented. (cherry picked from commit 3ea5332) Co-authored-by: Serhiy Storchaka <[email protected]>
1 parent 03aad30 commit 85b5829

File tree

3 files changed

+36
-16
lines changed

3 files changed

+36
-16
lines changed

Lib/test/test_types.py

+6
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,12 @@ def test_union_parameter_chaining(self):
772772
self.assertEqual((list[T] | list[S])[int, T], list[int] | list[T])
773773
self.assertEqual((list[T] | list[S])[int, int], list[int])
774774

775+
def test_union_parameter_substitution_errors(self):
776+
T = typing.TypeVar("T")
777+
x = int | T
778+
with self.assertRaises(TypeError):
779+
x[42]
780+
775781
def test_or_type_operator_with_forward(self):
776782
T = typing.TypeVar('T')
777783
ForwardAfter = T | 'Forward'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Parameter substitution of the union type with wrong types now raises
2+
``TypeError`` instead of returning ``NotImplemented``.

Objects/unionobject.c

+28-16
Original file line numberDiff line numberDiff line change
@@ -302,10 +302,22 @@ is_unionable(PyObject *obj)
302302
PyObject *
303303
_Py_union_type_or(PyObject* self, PyObject* other)
304304
{
305+
int r = is_unionable(self);
306+
if (r > 0) {
307+
r = is_unionable(other);
308+
}
309+
if (r < 0) {
310+
return NULL;
311+
}
312+
if (!r) {
313+
Py_RETURN_NOTIMPLEMENTED;
314+
}
315+
305316
PyObject *tuple = PyTuple_Pack(2, self, other);
306317
if (tuple == NULL) {
307318
return NULL;
308319
}
320+
309321
PyObject *new_union = make_union(tuple);
310322
Py_DECREF(tuple);
311323
return new_union;
@@ -434,6 +446,21 @@ union_getitem(PyObject *self, PyObject *item)
434446
return NULL;
435447
}
436448

449+
// Check arguments are unionable.
450+
Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
451+
for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
452+
PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
453+
int is_arg_unionable = is_unionable(arg);
454+
if (is_arg_unionable <= 0) {
455+
Py_DECREF(newargs);
456+
if (is_arg_unionable == 0) {
457+
PyErr_Format(PyExc_TypeError,
458+
"Each union argument must be a type, got %.100R", arg);
459+
}
460+
return NULL;
461+
}
462+
}
463+
437464
PyObject *res = make_union(newargs);
438465

439466
Py_DECREF(newargs);
@@ -495,21 +522,6 @@ make_union(PyObject *args)
495522
{
496523
assert(PyTuple_CheckExact(args));
497524

498-
unionobject* result = NULL;
499-
500-
// Check arguments are unionable.
501-
Py_ssize_t nargs = PyTuple_GET_SIZE(args);
502-
for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
503-
PyObject *arg = PyTuple_GET_ITEM(args, iarg);
504-
int is_arg_unionable = is_unionable(arg);
505-
if (is_arg_unionable < 0) {
506-
return NULL;
507-
}
508-
if (!is_arg_unionable) {
509-
Py_RETURN_NOTIMPLEMENTED;
510-
}
511-
}
512-
513525
args = dedup_and_flatten_args(args);
514526
if (args == NULL) {
515527
return NULL;
@@ -521,7 +533,7 @@ make_union(PyObject *args)
521533
return result1;
522534
}
523535

524-
result = PyObject_GC_New(unionobject, &_PyUnion_Type);
536+
unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
525537
if (result == NULL) {
526538
Py_DECREF(args);
527539
return NULL;

0 commit comments

Comments
 (0)