From dc79d8c66f90f46f144ad925f31561fb2aceebdb Mon Sep 17 00:00:00 2001 From: Vinay Sajip Date: Thu, 26 Sep 2019 12:06:32 +0100 Subject: [PATCH 1/5] bpo-16575: Add checks for unions passed by value to functions. --- Lib/ctypes/test/test_structures.py | 47 ++++++++++++++++++++++++++++++ Modules/_ctypes/_ctypes.c | 19 ++++++++++++ Modules/_ctypes/_ctypes_test.c | 38 ++++++++++++++++++++++++ Modules/_ctypes/ctypes.h | 2 ++ Modules/_ctypes/stgdict.c | 8 +++++ 5 files changed, 114 insertions(+) diff --git a/Lib/ctypes/test/test_structures.py b/Lib/ctypes/test/test_structures.py index cdf4a9182439f8..4890a31ee8004f 100644 --- a/Lib/ctypes/test/test_structures.py +++ b/Lib/ctypes/test/test_structures.py @@ -576,6 +576,53 @@ class U(Union): self.assertEqual(f2, [0x4567, 0x0123, 0xcdef, 0x89ab, 0x3210, 0x7654, 0xba98, 0xfedc]) + def test_union_by_value(self): + # See bpo-16575 + + # These should mirror the structures in Modules/_ctypes/_ctypes_test.c + + class Nested1(Structure): + _fields = [ + ('an_int', c_int), + ('another_int', c_int), + ] + + class Test4(Union): + _fields_ = [ + ('a_long', c_long), + ('a_struct', Nested1), + ] + + class Nested2(Structure): + _fields = [ + ('an_int', c_int), + ('a_union', Test4), + ] + + class Test5(Structure): + _fields = [ + ('an_int', c_int), + ('nested', Nested2), + ] + + test4 = Test4() + dll = CDLL(_ctypes_test.__file__) + with self.assertRaises(TypeError) as ctx: + func = dll._testfunc_union_by_value1 + func.restype = c_long + func.argtypes = (Test4,) + result = func(test4) + self.assertEqual(ctx.exception.args[0], 'item 1 in _argtypes_ passes ' + 'a union by value, which is unsupported.') + test5 = Test5() + with self.assertRaises(TypeError) as ctx: + func = dll._testfunc_union_by_value2 + func.restype = c_long + func.argtypes = (Test5,) + result = func(test5) + self.assertEqual(ctx.exception.args[0], 'item 1 in _argtypes_ passes ' + 'a union by value, which is unsupported.') + class PointerMemberTestCase(unittest.TestCase): def test(self): diff --git a/Modules/_ctypes/_ctypes.c b/Modules/_ctypes/_ctypes.c index 16a0cfe8dd4dc7..986e3b1a861e41 100644 --- a/Modules/_ctypes/_ctypes.c +++ b/Modules/_ctypes/_ctypes.c @@ -2383,6 +2383,25 @@ converters_from_argtypes(PyObject *ob) for (i = 0; i < nArgs; ++i) { PyObject *cnv; PyObject *tp = PyTuple_GET_ITEM(ob, i); + StgDictObject *stgdict = PyType_stgdict(tp); + + if (stgdict != NULL) { + if (stgdict->flags & TYPEFLAG_HASUNION) { + Py_DECREF(converters); + Py_DECREF(ob); + if (!PyErr_Occurred()) { + PyErr_Format(PyExc_TypeError, + "item %zd in _argtypes_ passes a union by " + "value, which is unsupported.", + i + 1); + } + return NULL; + } + if (stgdict->flags & TYPEFLAG_HASBITFIELD) { + printf("found stgdict with bitfield\n"); + } + } + if (_PyObject_LookupAttrId(tp, &PyId_from_param, &cnv) <= 0) { Py_DECREF(converters); Py_DECREF(ob); diff --git a/Modules/_ctypes/_ctypes_test.c b/Modules/_ctypes/_ctypes_test.c index 40da652620271a..d34dcedd06efde 100644 --- a/Modules/_ctypes/_ctypes_test.c +++ b/Modules/_ctypes/_ctypes_test.c @@ -131,6 +131,44 @@ _testfunc_array_in_struct2a(Test3B in) return result; } +typedef union { + long a_long; + struct { + int an_int; + int another_int; + } a_struct; +} Test4; + +typedef struct { + int an_int; + struct { + int an_int; + Test4 a_union; + } nested; +} Test5; + +EXPORT(long) +_testfunc_union_by_value1(Test4 in) { + long result = in.a_long + in.a_struct.an_int + in.a_struct.another_int; + + /* As the union/struct are passed by value, changes to them shouldn't be + * reflected in the caller. + */ + memset(&in, 0, sizeof(in)); + return result; +} + +EXPORT(long) +_testfunc_union_by_value2(Test5 in) { + long result = in.an_int + in.nested.an_int; + + /* As the union/struct are passed by value, changes to them shouldn't be + * reflected in the caller. + */ + memset(&in, 0, sizeof(in)); + return result; +} + EXPORT(void)testfunc_array(int values[4]) { printf("testfunc_array %d %d %d %d\n", diff --git a/Modules/_ctypes/ctypes.h b/Modules/_ctypes/ctypes.h index 5d3b9663385f59..e58f85233cb71e 100644 --- a/Modules/_ctypes/ctypes.h +++ b/Modules/_ctypes/ctypes.h @@ -288,6 +288,8 @@ PyObject *_ctypes_callproc(PPROC pProc, #define TYPEFLAG_ISPOINTER 0x100 #define TYPEFLAG_HASPOINTER 0x200 +#define TYPEFLAG_HASUNION 0x400 +#define TYPEFLAG_HASBITFIELD 0x800 #define DICTFLAG_FINAL 0x1000 diff --git a/Modules/_ctypes/stgdict.c b/Modules/_ctypes/stgdict.c index 97bcf5539aeeab..14e514d12fcd69 100644 --- a/Modules/_ctypes/stgdict.c +++ b/Modules/_ctypes/stgdict.c @@ -440,6 +440,13 @@ PyCStructUnionType_update_stgdict(PyObject *type, PyObject *fields, int isStruct PyMem_Free(stgdict->ffi_type_pointer.elements); basedict = PyType_stgdict((PyObject *)((PyTypeObject *)type)->tp_base); + if (basedict) { + stgdict->flags |= (basedict->flags & + (TYPEFLAG_HASUNION | TYPEFLAG_HASBITFIELD)); + } + if (!isStruct) { + stgdict->flags |= TYPEFLAG_HASUNION; + } if (basedict && !use_broken_old_ctypes_semantics) { size = offset = basedict->size; align = basedict->align; @@ -517,6 +524,7 @@ PyCStructUnionType_update_stgdict(PyObject *type, PyObject *fields, int isStruct stgdict->flags |= TYPEFLAG_HASPOINTER; dict->flags |= DICTFLAG_FINAL; /* mark field type final */ if (PyTuple_Size(pair) == 3) { /* bits specified */ + stgdict->flags |= TYPEFLAG_HASBITFIELD; switch(dict->ffi_type_pointer.type) { case FFI_TYPE_UINT8: case FFI_TYPE_UINT16: From 72f4cfd5ba93a4630ec0cece3cfc284c8788aaa8 Mon Sep 17 00:00:00 2001 From: Vinay Sajip Date: Thu, 26 Sep 2019 16:57:32 +0100 Subject: [PATCH 2/5] Updated the logic for propagation of flags. --- Modules/_ctypes/_ctypes.c | 3 +++ Modules/_ctypes/stgdict.c | 1 + 2 files changed, 4 insertions(+) diff --git a/Modules/_ctypes/_ctypes.c b/Modules/_ctypes/_ctypes.c index 986e3b1a861e41..1290daa6cdfc52 100644 --- a/Modules/_ctypes/_ctypes.c +++ b/Modules/_ctypes/_ctypes.c @@ -504,6 +504,9 @@ StructUnionType_new(PyTypeObject *type, PyObject *args, PyObject *kwds, int isSt Py_DECREF(result); return NULL; } + if (!isStruct) { + dict->flags |= TYPEFLAG_HASUNION; + } /* replace the class dict by our updated stgdict, which holds info about storage requirements of the instances */ if (-1 == PyDict_Update((PyObject *)dict, result->tp_dict)) { diff --git a/Modules/_ctypes/stgdict.c b/Modules/_ctypes/stgdict.c index 14e514d12fcd69..1d45ade5efd903 100644 --- a/Modules/_ctypes/stgdict.c +++ b/Modules/_ctypes/stgdict.c @@ -522,6 +522,7 @@ PyCStructUnionType_update_stgdict(PyObject *type, PyObject *fields, int isStruct stgdict->ffi_type_pointer.elements[ffi_ofs + i] = &dict->ffi_type_pointer; if (dict->flags & (TYPEFLAG_ISPOINTER | TYPEFLAG_HASPOINTER)) stgdict->flags |= TYPEFLAG_HASPOINTER; + stgdict->flags |= dict->flags & (TYPEFLAG_HASUNION | TYPEFLAG_HASBITFIELD); dict->flags |= DICTFLAG_FINAL; /* mark field type final */ if (PyTuple_Size(pair) == 3) { /* bits specified */ stgdict->flags |= TYPEFLAG_HASBITFIELD; From 7fe0eb7c7ce2b34273f1cab641285ae5e0136ce9 Mon Sep 17 00:00:00 2001 From: Vinay Sajip Date: Thu, 26 Sep 2019 16:58:47 +0100 Subject: [PATCH 3/5] Corrected typos in test. --- Lib/ctypes/test/test_structures.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Lib/ctypes/test/test_structures.py b/Lib/ctypes/test/test_structures.py index 4890a31ee8004f..d543701680de0b 100644 --- a/Lib/ctypes/test/test_structures.py +++ b/Lib/ctypes/test/test_structures.py @@ -582,7 +582,7 @@ def test_union_by_value(self): # These should mirror the structures in Modules/_ctypes/_ctypes_test.c class Nested1(Structure): - _fields = [ + _fields_ = [ ('an_int', c_int), ('another_int', c_int), ] @@ -594,13 +594,13 @@ class Test4(Union): ] class Nested2(Structure): - _fields = [ + _fields_ = [ ('an_int', c_int), ('a_union', Test4), ] class Test5(Structure): - _fields = [ + _fields_ = [ ('an_int', c_int), ('nested', Nested2), ] From d83fc65d994ee3c22b01242321a36d8d303c24b9 Mon Sep 17 00:00:00 2001 From: Vinay Sajip Date: Thu, 26 Sep 2019 17:09:18 +0100 Subject: [PATCH 4/5] Removed debugging code. --- Modules/_ctypes/_ctypes.c | 3 --- 1 file changed, 3 deletions(-) diff --git a/Modules/_ctypes/_ctypes.c b/Modules/_ctypes/_ctypes.c index 1290daa6cdfc52..876fc85f63bc32 100644 --- a/Modules/_ctypes/_ctypes.c +++ b/Modules/_ctypes/_ctypes.c @@ -2400,9 +2400,6 @@ converters_from_argtypes(PyObject *ob) } return NULL; } - if (stgdict->flags & TYPEFLAG_HASBITFIELD) { - printf("found stgdict with bitfield\n"); - } } if (_PyObject_LookupAttrId(tp, &PyId_from_param, &cnv) <= 0) { From a3cf72d7172219528208d87b4f8100a28f05cdd6 Mon Sep 17 00:00:00 2001 From: Vinay Sajip Date: Thu, 26 Sep 2019 17:40:24 +0100 Subject: [PATCH 5/5] Improved tests. --- Lib/ctypes/test/test_structures.py | 33 ++++++++++++++++++++++++++++++ Modules/_ctypes/_ctypes_test.c | 25 ++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/Lib/ctypes/test/test_structures.py b/Lib/ctypes/test/test_structures.py index d543701680de0b..ab23206df8a762 100644 --- a/Lib/ctypes/test/test_structures.py +++ b/Lib/ctypes/test/test_structures.py @@ -603,6 +603,7 @@ class Test5(Structure): _fields_ = [ ('an_int', c_int), ('nested', Nested2), + ('another_int', c_int), ] test4 = Test4() @@ -623,6 +624,38 @@ class Test5(Structure): self.assertEqual(ctx.exception.args[0], 'item 1 in _argtypes_ passes ' 'a union by value, which is unsupported.') + # passing by reference should be OK + test4.a_long = 12345; + func = dll._testfunc_union_by_reference1 + func.restype = c_long + func.argtypes = (POINTER(Test4),) + result = func(byref(test4)) + self.assertEqual(result, 12345) + self.assertEqual(test4.a_long, 0) + self.assertEqual(test4.a_struct.an_int, 0) + self.assertEqual(test4.a_struct.another_int, 0) + test4.a_struct.an_int = 0x12340000 + test4.a_struct.another_int = 0x5678 + func = dll._testfunc_union_by_reference2 + func.restype = c_long + func.argtypes = (POINTER(Test4),) + result = func(byref(test4)) + self.assertEqual(result, 0x12345678) + self.assertEqual(test4.a_long, 0) + self.assertEqual(test4.a_struct.an_int, 0) + self.assertEqual(test4.a_struct.another_int, 0) + test5.an_int = 0x12000000 + test5.nested.an_int = 0x345600 + test5.another_int = 0x78 + func = dll._testfunc_union_by_reference3 + func.restype = c_long + func.argtypes = (POINTER(Test5),) + result = func(byref(test5)) + self.assertEqual(result, 0x12345678) + self.assertEqual(test5.an_int, 0) + self.assertEqual(test5.nested.an_int, 0) + self.assertEqual(test5.another_int, 0) + class PointerMemberTestCase(unittest.TestCase): def test(self): diff --git a/Modules/_ctypes/_ctypes_test.c b/Modules/_ctypes/_ctypes_test.c index d34dcedd06efde..49ed82004f6048 100644 --- a/Modules/_ctypes/_ctypes_test.c +++ b/Modules/_ctypes/_ctypes_test.c @@ -145,6 +145,7 @@ typedef struct { int an_int; Test4 a_union; } nested; + int another_int; } Test5; EXPORT(long) @@ -169,6 +170,30 @@ _testfunc_union_by_value2(Test5 in) { return result; } +EXPORT(long) +_testfunc_union_by_reference1(Test4 *in) { + long result = in->a_long; + + memset(in, 0, sizeof(Test4)); + return result; +} + +EXPORT(long) +_testfunc_union_by_reference2(Test4 *in) { + long result = in->a_struct.an_int + in->a_struct.another_int; + + memset(in, 0, sizeof(Test4)); + return result; +} + +EXPORT(long) +_testfunc_union_by_reference3(Test5 *in) { + long result = in->an_int + in->nested.an_int + in->another_int; + + memset(in, 0, sizeof(Test5)); + return result; +} + EXPORT(void)testfunc_array(int values[4]) { printf("testfunc_array %d %d %d %d\n",