Skip to content

Commit 2d055ce

Browse files
serhiy-storchakauriyyoFidget-Spinner
authored
[3.10] bpo-44490: Add __parameters__ and __getitem__ to types.Union (GH-26980) (GH-27207)
Co-authored-by: Ken Jin <[email protected]> Co-authored-by: Guido van Rossum <[email protected]>. (cherry picked from commit c45fa1a) Co-authored-by: Yurii Karabas <[email protected]> Co-authored-by: Ken Jin <[email protected]>
1 parent e22e864 commit 2d055ce

File tree

5 files changed

+102
-19
lines changed

5 files changed

+102
-19
lines changed

Include/internal/pycore_unionobject.h

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ PyAPI_FUNC(PyObject *) _Py_Union(PyObject *args);
1212
PyAPI_DATA(PyTypeObject) _Py_UnionType;
1313
PyAPI_FUNC(PyObject *) _Py_union_type_or(PyObject* self, PyObject* param);
1414

15+
PyObject *_Py_subs_parameters(PyObject *, PyObject *, PyObject *, PyObject *);
16+
PyObject *_Py_make_parameters(PyObject *);
17+
1518
#ifdef __cplusplus
1619
}
1720
#endif

Lib/test/test_types.py

+13
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,8 @@ def test_or_type_operator_with_TypeVar(self):
711711
TV = typing.TypeVar('T')
712712
assert TV | str == typing.Union[TV, str]
713713
assert str | TV == typing.Union[str, TV]
714+
self.assertIs((int | TV)[int], int)
715+
self.assertIs((TV | int)[int], int)
714716

715717
def test_union_args(self):
716718
def check(arg, expected):
@@ -742,6 +744,17 @@ def check(arg, expected):
742744
check(x | None, (x, type(None)))
743745
check(None | x, (type(None), x))
744746

747+
def test_union_parameter_chaining(self):
748+
T = typing.TypeVar("T")
749+
S = typing.TypeVar("S")
750+
751+
self.assertEqual((float | list[T])[int], float | list[int])
752+
self.assertEqual(list[int | list[T]].__parameters__, (T,))
753+
self.assertEqual(list[int | list[T]][str], list[int | list[str]])
754+
self.assertEqual((list[T] | list[S]).__parameters__, (T, S))
755+
self.assertEqual((list[T] | list[S])[int, T], list[int] | list[T])
756+
self.assertEqual((list[T] | list[S])[int, int], list[int])
757+
745758
def test_or_type_operator_with_forward(self):
746759
T = typing.TypeVar('T')
747760
ForwardAfter = T | 'Forward'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Add ``__parameters__`` attribute and ``__getitem__``
2+
operator to ``types.Union``. Patch provided by Yurii Karabas.

Objects/genericaliasobject.c

+31-19
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ tuple_add(PyObject *self, Py_ssize_t len, PyObject *item)
198198
return 0;
199199
}
200200

201-
static PyObject *
202-
make_parameters(PyObject *args)
201+
PyObject *
202+
_Py_make_parameters(PyObject *args)
203203
{
204204
Py_ssize_t nargs = PyTuple_GET_SIZE(args);
205205
Py_ssize_t len = nargs;
@@ -294,18 +294,10 @@ subs_tvars(PyObject *obj, PyObject *params, PyObject **argitems)
294294
return obj;
295295
}
296296

297-
static PyObject *
298-
ga_getitem(PyObject *self, PyObject *item)
297+
PyObject *
298+
_Py_subs_parameters(PyObject *self, PyObject *args, PyObject *parameters, PyObject *item)
299299
{
300-
gaobject *alias = (gaobject *)self;
301-
// do a lookup for __parameters__ so it gets populated (if not already)
302-
if (alias->parameters == NULL) {
303-
alias->parameters = make_parameters(alias->args);
304-
if (alias->parameters == NULL) {
305-
return NULL;
306-
}
307-
}
308-
Py_ssize_t nparams = PyTuple_GET_SIZE(alias->parameters);
300+
Py_ssize_t nparams = PyTuple_GET_SIZE(parameters);
309301
if (nparams == 0) {
310302
return PyErr_Format(PyExc_TypeError,
311303
"There are no type variables left in %R",
@@ -320,32 +312,32 @@ ga_getitem(PyObject *self, PyObject *item)
320312
nitems > nparams ? "many" : "few",
321313
self);
322314
}
323-
/* Replace all type variables (specified by alias->parameters)
315+
/* Replace all type variables (specified by parameters)
324316
with corresponding values specified by argitems.
325317
t = list[T]; t[int] -> newargs = [int]
326318
t = dict[str, T]; t[int] -> newargs = [str, int]
327319
t = dict[T, list[S]]; t[str, int] -> newargs = [str, list[int]]
328320
*/
329-
Py_ssize_t nargs = PyTuple_GET_SIZE(alias->args);
321+
Py_ssize_t nargs = PyTuple_GET_SIZE(args);
330322
PyObject *newargs = PyTuple_New(nargs);
331323
if (newargs == NULL) {
332324
return NULL;
333325
}
334326
for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
335-
PyObject *arg = PyTuple_GET_ITEM(alias->args, iarg);
327+
PyObject *arg = PyTuple_GET_ITEM(args, iarg);
336328
int typevar = is_typevar(arg);
337329
if (typevar < 0) {
338330
Py_DECREF(newargs);
339331
return NULL;
340332
}
341333
if (typevar) {
342-
Py_ssize_t iparam = tuple_index(alias->parameters, nparams, arg);
334+
Py_ssize_t iparam = tuple_index(parameters, nparams, arg);
343335
assert(iparam >= 0);
344336
arg = argitems[iparam];
345337
Py_INCREF(arg);
346338
}
347339
else {
348-
arg = subs_tvars(arg, alias->parameters, argitems);
340+
arg = subs_tvars(arg, parameters, argitems);
349341
if (arg == NULL) {
350342
Py_DECREF(newargs);
351343
return NULL;
@@ -354,6 +346,26 @@ ga_getitem(PyObject *self, PyObject *item)
354346
PyTuple_SET_ITEM(newargs, iarg, arg);
355347
}
356348

349+
return newargs;
350+
}
351+
352+
static PyObject *
353+
ga_getitem(PyObject *self, PyObject *item)
354+
{
355+
gaobject *alias = (gaobject *)self;
356+
// Populate __parameters__ if needed.
357+
if (alias->parameters == NULL) {
358+
alias->parameters = _Py_make_parameters(alias->args);
359+
if (alias->parameters == NULL) {
360+
return NULL;
361+
}
362+
}
363+
364+
PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item);
365+
if (newargs == NULL) {
366+
return NULL;
367+
}
368+
357369
PyObject *res = Py_GenericAlias(alias->origin, newargs);
358370

359371
Py_DECREF(newargs);
@@ -550,7 +562,7 @@ ga_parameters(PyObject *self, void *unused)
550562
{
551563
gaobject *alias = (gaobject *)self;
552564
if (alias->parameters == NULL) {
553-
alias->parameters = make_parameters(alias->args);
565+
alias->parameters = _Py_make_parameters(alias->args);
554566
if (alias->parameters == NULL) {
555567
return NULL;
556568
}

Objects/unionobject.c

+53
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
typedef struct {
99
PyObject_HEAD
1010
PyObject *args;
11+
PyObject *parameters;
1112
} unionobject;
1213

1314
static void
@@ -18,6 +19,7 @@ unionobject_dealloc(PyObject *self)
1819
_PyObject_GC_UNTRACK(self);
1920

2021
Py_XDECREF(alias->args);
22+
Py_XDECREF(alias->parameters);
2123
Py_TYPE(self)->tp_free(self);
2224
}
2325

@@ -26,6 +28,7 @@ union_traverse(PyObject *self, visitproc visit, void *arg)
2628
{
2729
unionobject *alias = (unionobject *)self;
2830
Py_VISIT(alias->args);
31+
Py_VISIT(alias->parameters);
2932
return 0;
3033
}
3134

@@ -450,6 +453,53 @@ static PyMethodDef union_methods[] = {
450453
{"__subclasscheck__", union_subclasscheck, METH_O},
451454
{0}};
452455

456+
457+
static PyObject *
458+
union_getitem(PyObject *self, PyObject *item)
459+
{
460+
unionobject *alias = (unionobject *)self;
461+
// Populate __parameters__ if needed.
462+
if (alias->parameters == NULL) {
463+
alias->parameters = _Py_make_parameters(alias->args);
464+
if (alias->parameters == NULL) {
465+
return NULL;
466+
}
467+
}
468+
469+
PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item);
470+
if (newargs == NULL) {
471+
return NULL;
472+
}
473+
474+
PyObject *res = _Py_Union(newargs);
475+
476+
Py_DECREF(newargs);
477+
return res;
478+
}
479+
480+
static PyMappingMethods union_as_mapping = {
481+
.mp_subscript = union_getitem,
482+
};
483+
484+
static PyObject *
485+
union_parameters(PyObject *self, void *Py_UNUSED(unused))
486+
{
487+
unionobject *alias = (unionobject *)self;
488+
if (alias->parameters == NULL) {
489+
alias->parameters = _Py_make_parameters(alias->args);
490+
if (alias->parameters == NULL) {
491+
return NULL;
492+
}
493+
}
494+
Py_INCREF(alias->parameters);
495+
return alias->parameters;
496+
}
497+
498+
static PyGetSetDef union_properties[] = {
499+
{"__parameters__", union_parameters, (setter)NULL, "Type variables in the types.Union.", NULL},
500+
{0}
501+
};
502+
453503
static PyNumberMethods union_as_number = {
454504
.nb_or = _Py_union_type_or, // Add __or__ function
455505
};
@@ -471,8 +521,10 @@ PyTypeObject _Py_UnionType = {
471521
.tp_members = union_members,
472522
.tp_methods = union_methods,
473523
.tp_richcompare = union_richcompare,
524+
.tp_as_mapping = &union_as_mapping,
474525
.tp_as_number = &union_as_number,
475526
.tp_repr = union_repr,
527+
.tp_getset = union_properties,
476528
};
477529

478530
PyObject *
@@ -516,6 +568,7 @@ _Py_Union(PyObject *args)
516568
return NULL;
517569
}
518570

571+
result->parameters = NULL;
519572
result->args = args;
520573
_PyObject_GC_TRACK(result);
521574
return (PyObject*)result;

0 commit comments

Comments
 (0)