Skip to content

Commit b82c0f0

Browse files
bmerrydean0x7d
authored andcommitted
Allow std::complex field with PYBIND11_NUMPY_DTYPE (#831)
This exposed a few underlying issues: 1. is_pod_struct was too strict to allow this. I've relaxed it to require only trivially copyable and standard layout, rather than POD (which additionally requires a trivial constructor, which std::complex violates). 2. format_descriptor<std::complex<T>>::format() returned numpy format strings instead of PEP3118 format strings, but register_dtype feeds format codes of its fields to _dtype_from_pep3118. I've changed it to return PEP3118 format codes. format_descriptor is a public type, so this may be considered an incompatible change. 3. register_structured_dtype tried to be smart about whether to mark fields as unaligned (with ^). However, it's examining the C++ alignment, rather than what numpy (or possibly PEP3118) thinks the alignment should be. For complex values those are different. I've made it mark all fields as ^ unconditionally, which should always be safe even if they are aligned, because we explicitly mark the padding.
1 parent 8e0d832 commit b82c0f0

File tree

6 files changed

+91
-26
lines changed

6 files changed

+91
-26
lines changed

docs/advanced/pycpp/numpy.rst

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,12 @@ expects the type followed by field names:
198198
/* now both A and B can be used as template arguments to py::array_t */
199199
}
200200
201-
The structure should consist of fundamental arithmetic types, previously
202-
registered substructures, and arrays of any of the above. Both C++ arrays and
203-
``std::array`` are supported.
201+
The structure should consist of fundamental arithmetic types, ``std::complex``,
202+
previously registered substructures, and arrays of any of the above. Both C++
203+
arrays and ``std::array`` are supported. While there is a static assertion to
204+
prevent many types of unsupported structures, it is still the user's
205+
responsibility to use only "plain" structures that can be safely manipulated as
206+
raw memory without violating invariants.
204207

205208
Vectorizing functions
206209
=====================

include/pybind11/common.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -608,14 +608,14 @@ template <typename T> struct is_fmt_numeric<T, enable_if_t<std::is_arithmetic<T>
608608
};
609609
NAMESPACE_END(detail)
610610

611-
template <typename T> struct format_descriptor<T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>> {
612-
static constexpr const char c = "?bBhHiIqQfdgFDG"[detail::is_fmt_numeric<T>::index];
611+
template <typename T> struct format_descriptor<T, detail::enable_if_t<std::is_arithmetic<T>::value>> {
612+
static constexpr const char c = "?bBhHiIqQfdg"[detail::is_fmt_numeric<T>::index];
613613
static constexpr const char value[2] = { c, '\0' };
614614
static std::string format() { return std::string(1, c); }
615615
};
616616

617617
template <typename T> constexpr const char format_descriptor<
618-
T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>>::value[2];
618+
T, detail::enable_if_t<std::is_arithmetic<T>::value>>::value[2];
619619

620620
/// RAII wrapper that temporarily clears any Python error state
621621
struct error_scope {

include/pybind11/complex.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,19 @@
1818
#endif
1919

2020
NAMESPACE_BEGIN(pybind11)
21+
22+
template <typename T> struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
23+
static constexpr const char c = format_descriptor<T>::c;
24+
static constexpr const char value[3] = { 'Z', c, '\0' };
25+
static std::string format() { return std::string(value); }
26+
};
27+
28+
template <typename T> constexpr const char format_descriptor<
29+
std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
30+
2131
NAMESPACE_BEGIN(detail)
2232

23-
// The format codes are already in the string in common.h, we just need to provide a specialization
24-
template <typename T> struct is_fmt_numeric<std::complex<T>> {
33+
template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
2534
static constexpr bool value = true;
2635
static constexpr int index = is_fmt_numeric<T>::index + 3;
2736
};

include/pybind11/numpy.h

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,14 @@ template <typename T, size_t N> struct array_info<T[N]> : array_info<std::array<
287287
template <typename T> using remove_all_extents_t = typename array_info<T>::type;
288288

289289
template <typename T> using is_pod_struct = all_of<
290-
std::is_pod<T>, // since we're accessing directly in memory we need a POD type
290+
std::is_standard_layout<T>, // since we're accessing directly in memory we need a standard layout type
291+
#if !defined(__GNUG__) || defined(__clang__) || __GNUC__ >= 5
292+
std::is_trivially_copyable<T>,
293+
#else
294+
// GCC 4 doesn't implement is_trivially_copyable, so approximate it
295+
std::is_trivially_destructible<T>,
296+
satisfies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>,
297+
#endif
291298
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
292299
>;
293300

@@ -1016,7 +1023,6 @@ struct field_descriptor {
10161023
const char *name;
10171024
ssize_t offset;
10181025
ssize_t size;
1019-
ssize_t alignment;
10201026
std::string format;
10211027
dtype descr;
10221028
};
@@ -1053,13 +1059,15 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
10531059
[](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
10541060
ssize_t offset = 0;
10551061
std::ostringstream oss;
1056-
oss << "T{";
1062+
// mark the structure as unaligned with '^', because numpy and C++ don't
1063+
// always agree about alignment (particularly for complex), and we're
1064+
// explicitly listing all our padding. This depends on none of the fields
1065+
// overriding the endianness. Putting the ^ in front of individual fields
1066+
// isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049
1067+
oss << "^T{";
10571068
for (auto& field : ordered_fields) {
10581069
if (field.offset > offset)
10591070
oss << (field.offset - offset) << 'x';
1060-
// mark unaligned fields with '^' (unaligned native type)
1061-
if (field.offset % field.alignment)
1062-
oss << '^';
10631071
oss << field.format << ':' << field.name << ':';
10641072
offset = field.offset + field.size;
10651073
}
@@ -1121,7 +1129,6 @@ template <typename T, typename SFINAE> struct npy_format_descriptor {
11211129
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
11221130
::pybind11::detail::field_descriptor { \
11231131
Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
1124-
alignof(decltype(std::declval<T>().Field)), \
11251132
::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(), \
11261133
::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
11271134
}

tests/test_numpy_dtypes.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,15 @@ struct StringStruct {
7070
std::array<char, 3> b;
7171
};
7272

73+
struct ComplexStruct {
74+
std::complex<float> cflt;
75+
std::complex<double> cdbl;
76+
};
77+
78+
std::ostream& operator<<(std::ostream& os, const ComplexStruct& v) {
79+
return os << "c:" << v.cflt << "," << v.cdbl;
80+
}
81+
7382
struct ArrayStruct {
7483
char a[3][4];
7584
int32_t b[2];
@@ -219,6 +228,18 @@ py::array_t<EnumStruct, 0> create_enum_array(size_t n) {
219228
return arr;
220229
}
221230

231+
py::array_t<ComplexStruct, 0> create_complex_array(size_t n) {
232+
auto arr = mkarray_via_buffer<ComplexStruct>(n);
233+
auto ptr = (ComplexStruct *) arr.mutable_data();
234+
for (size_t i = 0; i < n; i++) {
235+
ptr[i].cflt.real(float(i));
236+
ptr[i].cflt.imag(float(i) + 0.25f);
237+
ptr[i].cdbl.real(double(i) + 0.5);
238+
ptr[i].cdbl.imag(double(i) + 0.75);
239+
}
240+
return arr;
241+
}
242+
222243
template <typename S>
223244
py::list print_recarray(py::array_t<S, 0> arr) {
224245
const auto req = arr.request();
@@ -241,7 +262,8 @@ py::list print_format_descriptors() {
241262
py::format_descriptor<PartialNestedStruct>::format(),
242263
py::format_descriptor<StringStruct>::format(),
243264
py::format_descriptor<ArrayStruct>::format(),
244-
py::format_descriptor<EnumStruct>::format()
265+
py::format_descriptor<EnumStruct>::format(),
266+
py::format_descriptor<ComplexStruct>::format()
245267
};
246268
auto l = py::list();
247269
for (const auto &fmt : fmts) {
@@ -260,7 +282,8 @@ py::list print_dtypes() {
260282
py::str(py::dtype::of<StringStruct>()),
261283
py::str(py::dtype::of<ArrayStruct>()),
262284
py::str(py::dtype::of<EnumStruct>()),
263-
py::str(py::dtype::of<StructWithUglyNames>())
285+
py::str(py::dtype::of<StructWithUglyNames>()),
286+
py::str(py::dtype::of<ComplexStruct>())
264287
};
265288
auto l = py::list();
266289
for (const auto &s : dtypes) {
@@ -401,6 +424,7 @@ test_initializer numpy_dtypes([](py::module &m) {
401424
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
402425
PYBIND11_NUMPY_DTYPE(ArrayStruct, a, b, c, d);
403426
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
427+
PYBIND11_NUMPY_DTYPE(ComplexStruct, cflt, cdbl);
404428
PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b);
405429
PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z);
406430

@@ -431,6 +455,8 @@ test_initializer numpy_dtypes([](py::module &m) {
431455
m.def("print_array_array", &print_recarray<ArrayStruct>);
432456
m.def("create_enum_array", &create_enum_array);
433457
m.def("print_enum_array", &print_recarray<EnumStruct>);
458+
m.def("create_complex_array", &create_complex_array);
459+
m.def("print_complex_array", &print_recarray<ComplexStruct>);
434460
m.def("test_array_ctors", &test_array_ctors);
435461
m.def("test_dtype_ctors", &test_dtype_ctors);
436462
m.def("test_dtype_methods", &test_dtype_methods);

tests/test_numpy_dtypes.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,21 +73,22 @@ def test_format_descriptors():
7373

7474
ld = np.dtype('longdouble')
7575
ldbl_fmt = ('4x' if ld.alignment > 4 else '') + ld.char
76-
ss_fmt = "T{?:bool_:3xI:uint_:f:float_:" + ldbl_fmt + ":ldbl_:}"
76+
ss_fmt = "^T{?:bool_:3xI:uint_:f:float_:" + ldbl_fmt + ":ldbl_:}"
7777
dbl = np.dtype('double')
78-
partial_fmt = ("T{?:bool_:3xI:uint_:f:float_:" +
78+
partial_fmt = ("^T{?:bool_:3xI:uint_:f:float_:" +
7979
str(4 * (dbl.alignment > 4) + dbl.itemsize + 8 * (ld.alignment > 8)) +
8080
"xg:ldbl_:}")
8181
nested_extra = str(max(8, ld.alignment))
8282
assert print_format_descriptors() == [
8383
ss_fmt,
84-
"T{?:bool_:^I:uint_:^f:float_:^g:ldbl_:}",
85-
"T{" + ss_fmt + ":a:T{?:bool_:^I:uint_:^f:float_:^g:ldbl_:}:b:}",
84+
"^T{?:bool_:I:uint_:f:float_:g:ldbl_:}",
85+
"^T{" + ss_fmt + ":a:^T{?:bool_:I:uint_:f:float_:g:ldbl_:}:b:}",
8686
partial_fmt,
87-
"T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}",
88-
"T{3s:a:3s:b:}",
89-
"T{(3)4s:a:(2)i:b:(3)B:c:1x(4, 2)f:d:}",
90-
'T{q:e1:B:e2:}'
87+
"^T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}",
88+
"^T{3s:a:3s:b:}",
89+
"^T{(3)4s:a:(2)i:b:(3)B:c:1x(4, 2)f:d:}",
90+
'^T{q:e1:B:e2:}',
91+
'^T{Zf:cflt:Zd:cdbl:}'
9192
]
9293

9394

@@ -108,7 +109,8 @@ def test_dtype(simple_dtype):
108109
"'formats':[('S4', (3,)),('<i4', (2,)),('u1', (3,)),('<f4', (4, 2))], " +
109110
"'offsets':[0,12,20,24], 'itemsize':56}}").format(e=e),
110111
"[('e1', '" + e + "i8'), ('e2', 'u1')]",
111-
"[('x', 'i1'), ('y', '" + e + "u8')]"
112+
"[('x', 'i1'), ('y', '" + e + "u8')]",
113+
"[('cflt', '" + e + "c8'), ('cdbl', '" + e + "c16')]"
112114
]
113115

114116
d1 = np.dtype({'names': ['a', 'b'], 'formats': ['int32', 'float64'],
@@ -260,6 +262,24 @@ def test_enum_array():
260262
assert create_enum_array(0).dtype == dtype
261263

262264

265+
def test_complex_array():
266+
from pybind11_tests import create_complex_array, print_complex_array
267+
from sys import byteorder
268+
e = '<' if byteorder == 'little' else '>'
269+
270+
arr = create_complex_array(3)
271+
dtype = arr.dtype
272+
assert dtype == np.dtype([('cflt', e + 'c8'), ('cdbl', e + 'c16')])
273+
assert print_complex_array(arr) == [
274+
"c:(0,0.25),(0.5,0.75)",
275+
"c:(1,1.25),(1.5,1.75)",
276+
"c:(2,2.25),(2.5,2.75)"
277+
]
278+
assert arr['cflt'].tolist() == [0.0 + 0.25j, 1.0 + 1.25j, 2.0 + 2.25j]
279+
assert arr['cdbl'].tolist() == [0.5 + 0.75j, 1.5 + 1.75j, 2.5 + 2.75j]
280+
assert create_complex_array(0).dtype == dtype
281+
282+
263283
def test_signature(doc):
264284
from pybind11_tests import create_rec_nested
265285

0 commit comments

Comments
 (0)