@@ -302,10 +302,22 @@ is_unionable(PyObject *obj)
302
302
PyObject *
303
303
_Py_union_type_or (PyObject * self , PyObject * other )
304
304
{
305
+ int r = is_unionable (self );
306
+ if (r > 0 ) {
307
+ r = is_unionable (other );
308
+ }
309
+ if (r < 0 ) {
310
+ return NULL ;
311
+ }
312
+ if (!r ) {
313
+ Py_RETURN_NOTIMPLEMENTED ;
314
+ }
315
+
305
316
PyObject * tuple = PyTuple_Pack (2 , self , other );
306
317
if (tuple == NULL ) {
307
318
return NULL ;
308
319
}
320
+
309
321
PyObject * new_union = make_union (tuple );
310
322
Py_DECREF (tuple );
311
323
return new_union ;
@@ -434,6 +446,21 @@ union_getitem(PyObject *self, PyObject *item)
434
446
return NULL ;
435
447
}
436
448
449
+ // Check arguments are unionable.
450
+ Py_ssize_t nargs = PyTuple_GET_SIZE (newargs );
451
+ for (Py_ssize_t iarg = 0 ; iarg < nargs ; iarg ++ ) {
452
+ PyObject * arg = PyTuple_GET_ITEM (newargs , iarg );
453
+ int is_arg_unionable = is_unionable (arg );
454
+ if (is_arg_unionable <= 0 ) {
455
+ Py_DECREF (newargs );
456
+ if (is_arg_unionable == 0 ) {
457
+ PyErr_Format (PyExc_TypeError ,
458
+ "Each union argument must be a type, got %.100R" , arg );
459
+ }
460
+ return NULL ;
461
+ }
462
+ }
463
+
437
464
PyObject * res = make_union (newargs );
438
465
439
466
Py_DECREF (newargs );
@@ -495,21 +522,6 @@ make_union(PyObject *args)
495
522
{
496
523
assert (PyTuple_CheckExact (args ));
497
524
498
- unionobject * result = NULL ;
499
-
500
- // Check arguments are unionable.
501
- Py_ssize_t nargs = PyTuple_GET_SIZE (args );
502
- for (Py_ssize_t iarg = 0 ; iarg < nargs ; iarg ++ ) {
503
- PyObject * arg = PyTuple_GET_ITEM (args , iarg );
504
- int is_arg_unionable = is_unionable (arg );
505
- if (is_arg_unionable < 0 ) {
506
- return NULL ;
507
- }
508
- if (!is_arg_unionable ) {
509
- Py_RETURN_NOTIMPLEMENTED ;
510
- }
511
- }
512
-
513
525
args = dedup_and_flatten_args (args );
514
526
if (args == NULL ) {
515
527
return NULL ;
@@ -521,7 +533,7 @@ make_union(PyObject *args)
521
533
return result1 ;
522
534
}
523
535
524
- result = PyObject_GC_New (unionobject , & _PyUnion_Type );
536
+ unionobject * result = PyObject_GC_New (unionobject , & _PyUnion_Type );
525
537
if (result == NULL ) {
526
538
Py_DECREF (args );
527
539
return NULL ;
0 commit comments