Skip to content

Commit c45fa1a

Browse files
uriyyoFidget-Spinnergvanrossum
authored
bpo-44490: Add __parameters__ and __getitem__ to types.Union (GH-26980)
Co-authored-by: Ken Jin <[email protected]> Co-authored-by: Guido van Rossum <[email protected]>
1 parent 8b849ea commit c45fa1a

File tree

5 files changed

+101
-19
lines changed

5 files changed

+101
-19
lines changed

Include/genericaliasobject.h

+5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
extern "C" {
66
#endif
77

8+
#ifndef Py_LIMITED_API
9+
PyAPI_FUNC(PyObject *) _Py_subs_parameters(PyObject *, PyObject *, PyObject *, PyObject *);
10+
PyAPI_FUNC(PyObject *) _Py_make_parameters(PyObject *);
11+
#endif
12+
813
PyAPI_FUNC(PyObject *) Py_GenericAlias(PyObject *, PyObject *);
914
PyAPI_DATA(PyTypeObject) Py_GenericAliasType;
1015

Lib/test/test_types.py

+10
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,16 @@ def test_or_type_operator_with_TypeVar(self):
666666
assert TV | str == typing.Union[TV, str]
667667
assert str | TV == typing.Union[str, TV]
668668

669+
def test_union_parameter_chaining(self):
670+
T = typing.TypeVar("T")
671+
S = typing.TypeVar("S")
672+
673+
self.assertEqual((float | list[T])[int], float | list[int])
674+
self.assertEqual(list[int | list[T]].__parameters__, (T,))
675+
self.assertEqual(list[int | list[T]][str], list[int | list[str]])
676+
self.assertEqual((list[T] | list[S]).__parameters__, (T, S))
677+
self.assertEqual((list[T] | list[S])[int, T], list[int] | list[T])
678+
669679
def test_or_type_operator_with_forward(self):
670680
T = typing.TypeVar('T')
671681
ForwardAfter = T | 'Forward'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Add ``__parameters__`` attribute and ``__getitem__``
2+
operator to ``types.Union``. Patch provided by Yurii Karabas.

Objects/genericaliasobject.c

+31-19
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ tuple_add(PyObject *self, Py_ssize_t len, PyObject *item)
198198
return 0;
199199
}
200200

201-
static PyObject *
202-
make_parameters(PyObject *args)
201+
PyObject *
202+
_Py_make_parameters(PyObject *args)
203203
{
204204
Py_ssize_t nargs = PyTuple_GET_SIZE(args);
205205
Py_ssize_t len = nargs;
@@ -294,18 +294,10 @@ subs_tvars(PyObject *obj, PyObject *params, PyObject **argitems)
294294
return obj;
295295
}
296296

297-
static PyObject *
298-
ga_getitem(PyObject *self, PyObject *item)
297+
PyObject *
298+
_Py_subs_parameters(PyObject *self, PyObject *args, PyObject *parameters, PyObject *item)
299299
{
300-
gaobject *alias = (gaobject *)self;
301-
// do a lookup for __parameters__ so it gets populated (if not already)
302-
if (alias->parameters == NULL) {
303-
alias->parameters = make_parameters(alias->args);
304-
if (alias->parameters == NULL) {
305-
return NULL;
306-
}
307-
}
308-
Py_ssize_t nparams = PyTuple_GET_SIZE(alias->parameters);
300+
Py_ssize_t nparams = PyTuple_GET_SIZE(parameters);
309301
if (nparams == 0) {
310302
return PyErr_Format(PyExc_TypeError,
311303
"There are no type variables left in %R",
@@ -320,32 +312,32 @@ ga_getitem(PyObject *self, PyObject *item)
320312
nitems > nparams ? "many" : "few",
321313
self);
322314
}
323-
/* Replace all type variables (specified by alias->parameters)
315+
/* Replace all type variables (specified by parameters)
324316
with corresponding values specified by argitems.
325317
t = list[T]; t[int] -> newargs = [int]
326318
t = dict[str, T]; t[int] -> newargs = [str, int]
327319
t = dict[T, list[S]]; t[str, int] -> newargs = [str, list[int]]
328320
*/
329-
Py_ssize_t nargs = PyTuple_GET_SIZE(alias->args);
321+
Py_ssize_t nargs = PyTuple_GET_SIZE(args);
330322
PyObject *newargs = PyTuple_New(nargs);
331323
if (newargs == NULL) {
332324
return NULL;
333325
}
334326
for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
335-
PyObject *arg = PyTuple_GET_ITEM(alias->args, iarg);
327+
PyObject *arg = PyTuple_GET_ITEM(args, iarg);
336328
int typevar = is_typevar(arg);
337329
if (typevar < 0) {
338330
Py_DECREF(newargs);
339331
return NULL;
340332
}
341333
if (typevar) {
342-
Py_ssize_t iparam = tuple_index(alias->parameters, nparams, arg);
334+
Py_ssize_t iparam = tuple_index(parameters, nparams, arg);
343335
assert(iparam >= 0);
344336
arg = argitems[iparam];
345337
Py_INCREF(arg);
346338
}
347339
else {
348-
arg = subs_tvars(arg, alias->parameters, argitems);
340+
arg = subs_tvars(arg, parameters, argitems);
349341
if (arg == NULL) {
350342
Py_DECREF(newargs);
351343
return NULL;
@@ -354,6 +346,26 @@ ga_getitem(PyObject *self, PyObject *item)
354346
PyTuple_SET_ITEM(newargs, iarg, arg);
355347
}
356348

349+
return newargs;
350+
}
351+
352+
static PyObject *
353+
ga_getitem(PyObject *self, PyObject *item)
354+
{
355+
gaobject *alias = (gaobject *)self;
356+
// Populate __parameters__ if needed.
357+
if (alias->parameters == NULL) {
358+
alias->parameters = _Py_make_parameters(alias->args);
359+
if (alias->parameters == NULL) {
360+
return NULL;
361+
}
362+
}
363+
364+
PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item);
365+
if (newargs == NULL) {
366+
return NULL;
367+
}
368+
357369
PyObject *res = Py_GenericAlias(alias->origin, newargs);
358370

359371
Py_DECREF(newargs);
@@ -550,7 +562,7 @@ ga_parameters(PyObject *self, void *unused)
550562
{
551563
gaobject *alias = (gaobject *)self;
552564
if (alias->parameters == NULL) {
553-
alias->parameters = make_parameters(alias->args);
565+
alias->parameters = _Py_make_parameters(alias->args);
554566
if (alias->parameters == NULL) {
555567
return NULL;
556568
}

Objects/unionobject.c

+53
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
typedef struct {
99
PyObject_HEAD
1010
PyObject *args;
11+
PyObject *parameters;
1112
} unionobject;
1213

1314
static void
@@ -18,6 +19,7 @@ unionobject_dealloc(PyObject *self)
1819
_PyObject_GC_UNTRACK(self);
1920

2021
Py_XDECREF(alias->args);
22+
Py_XDECREF(alias->parameters);
2123
Py_TYPE(self)->tp_free(self);
2224
}
2325

@@ -26,6 +28,7 @@ union_traverse(PyObject *self, visitproc visit, void *arg)
2628
{
2729
unionobject *alias = (unionobject *)self;
2830
Py_VISIT(alias->args);
31+
Py_VISIT(alias->parameters);
2932
return 0;
3033
}
3134

@@ -435,6 +438,53 @@ static PyMethodDef union_methods[] = {
435438
{"__subclasscheck__", union_subclasscheck, METH_O},
436439
{0}};
437440

441+
442+
static PyObject *
443+
union_getitem(PyObject *self, PyObject *item)
444+
{
445+
unionobject *alias = (unionobject *)self;
446+
// Populate __parameters__ if needed.
447+
if (alias->parameters == NULL) {
448+
alias->parameters = _Py_make_parameters(alias->args);
449+
if (alias->parameters == NULL) {
450+
return NULL;
451+
}
452+
}
453+
454+
PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item);
455+
if (newargs == NULL) {
456+
return NULL;
457+
}
458+
459+
PyObject *res = _Py_Union(newargs);
460+
461+
Py_DECREF(newargs);
462+
return res;
463+
}
464+
465+
static PyMappingMethods union_as_mapping = {
466+
.mp_subscript = union_getitem,
467+
};
468+
469+
static PyObject *
470+
union_parameters(PyObject *self, void *Py_UNUSED(unused))
471+
{
472+
unionobject *alias = (unionobject *)self;
473+
if (alias->parameters == NULL) {
474+
alias->parameters = _Py_make_parameters(alias->args);
475+
if (alias->parameters == NULL) {
476+
return NULL;
477+
}
478+
}
479+
Py_INCREF(alias->parameters);
480+
return alias->parameters;
481+
}
482+
483+
static PyGetSetDef union_properties[] = {
484+
{"__parameters__", union_parameters, (setter)NULL, "Type variables in the types.Union.", NULL},
485+
{0}
486+
};
487+
438488
static PyNumberMethods union_as_number = {
439489
.nb_or = _Py_union_type_or, // Add __or__ function
440490
};
@@ -456,8 +506,10 @@ PyTypeObject _Py_UnionType = {
456506
.tp_members = union_members,
457507
.tp_methods = union_methods,
458508
.tp_richcompare = union_richcompare,
509+
.tp_as_mapping = &union_as_mapping,
459510
.tp_as_number = &union_as_number,
460511
.tp_repr = union_repr,
512+
.tp_getset = union_properties,
461513
};
462514

463515
PyObject *
@@ -489,6 +541,7 @@ _Py_Union(PyObject *args)
489541
return NULL;
490542
}
491543

544+
result->parameters = NULL;
492545
result->args = dedup_and_flatten_args(args);
493546
_PyObject_GC_TRACK(result);
494547
if (result->args == NULL) {

0 commit comments

Comments
 (0)