Skip to content

Commit ebbcce8

Browse files
author
Mmanu Chaturvedi
committed
Add ability to create object matrices
1 parent 6a81dbb commit ebbcce8

File tree

2 files changed

+160
-19
lines changed

2 files changed

+160
-19
lines changed

include/pybind11/eigen.h

+145-19
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#pragma once
1111

1212
#include "numpy.h"
13+
#include "numpy/ndarraytypes.h"
1314

1415
#if defined(__INTEL_COMPILER)
1516
# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
@@ -139,14 +140,19 @@ template <typename Type_> struct EigenProps {
139140
const auto dims = a.ndim();
140141
if (dims < 1 || dims > 2)
141142
return false;
142-
143+
bool is_pyobject = false;
144+
if (npy_format_descriptor<Scalar>::value == npy_api::NPY_OBJECT_)
145+
is_pyobject = true;
146+
ssize_t scalar_size = (is_pyobject ? static_cast<ssize_t>(sizeof(PyObject*)) :
147+
static_cast<ssize_t>(sizeof(Scalar)));
143148
if (dims == 2) { // Matrix type: require exact match (or dynamic)
144149

145150
EigenIndex
146151
np_rows = a.shape(0),
147152
np_cols = a.shape(1),
148-
np_rstride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar)),
149-
np_cstride = a.strides(1) / static_cast<ssize_t>(sizeof(Scalar));
153+
np_rstride = a.strides(0) / scalar_size,
154+
np_cstride = a.strides(1) / scalar_size;
155+
150156
if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
151157
return false;
152158

@@ -156,7 +162,7 @@ template <typename Type_> struct EigenProps {
156162
// Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever
157163
// is used, we want the (single) numpy stride value.
158164
const EigenIndex n = a.shape(0),
159-
stride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar));
165+
stride = a.strides(0) / scalar_size;
160166

161167
if (vector) { // Eigen type is a compile-time vector
162168
if (fixed && size != n)
@@ -207,11 +213,52 @@ template <typename Type_> struct EigenProps {
207213
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
208214
constexpr ssize_t elem_size = sizeof(typename props::Scalar);
209215
array a;
210-
if (props::vector)
211-
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
212-
else
213-
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
214-
src.data(), base);
216+
using Scalar = typename props::Type::Scalar;
217+
bool is_pyoject = npy_format_descriptor<Scalar>::value == npy_api::NPY_OBJECT_;
218+
219+
if (!is_pyoject) {
220+
if (props::vector)
221+
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
222+
else
223+
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
224+
src.data(), base);
225+
}
226+
else {
227+
if (props::vector) {
228+
a = array(
229+
npy_format_descriptor<Scalar>::dtype(),
230+
{ (size_t) src.size() },
231+
nullptr,
232+
base
233+
);
234+
auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy;
235+
for (ssize_t i = 0; i < src.size(); ++i) {
236+
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src(i, 0), policy, base));
237+
if (!value_)
238+
return handle();
239+
auto p = a.mutable_data(i);
240+
PyArray_SETITEM(a.ptr(), p, value_.release().ptr());
241+
}
242+
}
243+
else {
244+
a = array(
245+
npy_format_descriptor<Scalar>::dtype(),
246+
{(size_t) src.rows(), (size_t) src.cols()},
247+
nullptr,
248+
base
249+
);
250+
auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy;
251+
for (ssize_t i = 0; i < src.rows(); ++i) {
252+
for (ssize_t j = 0; j < src.cols(); ++j) {
253+
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src(i, j), policy, base));
254+
if (!value_)
255+
return handle();
256+
auto p = a.mutable_data(i, j);
257+
PyArray_SETITEM(a.ptr(), p, value_.release().ptr());
258+
}
259+
}
260+
}
261+
}
215262

216263
if (!writeable)
217264
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
@@ -265,14 +312,47 @@ struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
265312
auto fits = props::conformable(buf);
266313
if (!fits)
267314
return false;
268-
315+
int result = 0;
269316
// Allocate the new type, then build a numpy reference into it
270317
value = Type(fits.rows, fits.cols);
271-
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
272-
if (dims == 1) ref = ref.squeeze();
273-
else if (ref.ndim() == 1) buf = buf.squeeze();
274-
275-
int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
318+
bool is_pyobject = npy_format_descriptor<Scalar>::value == npy_api::NPY_OBJECT_;
319+
320+
if (!is_pyobject) {
321+
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
322+
if (dims == 1) ref = ref.squeeze();
323+
else if (ref.ndim() == 1) buf = buf.squeeze();
324+
result =
325+
detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
326+
}
327+
else {
328+
if (dims == 1){
329+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
330+
value.resize(buf.shape(0), 1);
331+
}
332+
for (ssize_t i = 0; i < buf.shape(0); ++i) {
333+
auto p = buf.mutable_data(i);
334+
make_caster <Scalar> conv_val;
335+
if (!conv_val.load(PyArray_GETITEM(buf.ptr(), p), convert))
336+
return false;
337+
value(i) = cast_op<Scalar>(conv_val);
338+
}
339+
} else {
340+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
341+
value.resize(buf.shape(0), buf.shape(1));
342+
}
343+
for (ssize_t i = 0; i < buf.shape(0); ++i) {
344+
for (ssize_t j = 0; j < buf.shape(1); ++j) {
345+
// p is the const void pointer to the item
346+
auto p = buf.mutable_data(i, j);
347+
make_caster<Scalar> conv_val;
348+
if (!conv_val.load(PyArray_GETITEM(buf.ptr(), p),
349+
convert))
350+
return false;
351+
value(i,j) = cast_op<Scalar>(conv_val);
352+
}
353+
}
354+
}
355+
}
276356

277357
if (result < 0) { // Copy failed!
278358
PyErr_Clear();
@@ -424,13 +504,19 @@ struct type_caster<
424504
// storage order conversion. (Note that we refuse to use this temporary copy when loading an
425505
// argument for a Ref<M> with M non-const, i.e. a read-write reference).
426506
Array copy_or_ref;
507+
typename std::remove_cv<PlainObjectType>::type val;
427508
public:
428509
bool load(handle src, bool convert) {
429510
// First check whether what we have is already an array of the right type. If not, we can't
430511
// avoid a copy (because the copy is also going to do type conversion).
431512
bool need_copy = !isinstance<Array>(src);
432513

433514
EigenConformable<props::row_major> fits;
515+
bool is_pyobject = false;
516+
if (npy_format_descriptor<Scalar>::value == npy_api::NPY_OBJECT_) {
517+
is_pyobject = true;
518+
need_copy = true;
519+
}
434520
if (!need_copy) {
435521
// We don't need a converting copy, but we also need to check whether the strides are
436522
// compatible with the Ref's stride requirements
@@ -453,15 +539,55 @@ struct type_caster<
453539
// We need to copy: If we need a mutable reference, or we're not supposed to convert
454540
// (either because we're in the no-convert overload pass, or because we're explicitly
455541
// instructed not to copy (via `py::arg().noconvert()`) we have to fail loading.
456-
if (!convert || need_writeable) return false;
542+
if (!is_pyobject && (!convert || need_writeable)) {
543+
return false;
544+
}
457545

458546
Array copy = Array::ensure(src);
459547
if (!copy) return false;
460548
fits = props::conformable(copy);
461-
if (!fits || !fits.template stride_compatible<props>())
549+
if (!fits || !fits.template stride_compatible<props>()) {
462550
return false;
463-
copy_or_ref = std::move(copy);
464-
loader_life_support::add_patient(copy_or_ref);
551+
}
552+
553+
if (!is_pyobject) {
554+
copy_or_ref = std::move(copy);
555+
loader_life_support::add_patient(copy_or_ref);
556+
}
557+
else {
558+
auto dims = copy.ndim();
559+
if (dims == 1){
560+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
561+
val.resize(copy.shape(0), 1);
562+
}
563+
for (ssize_t i = 0; i < copy.shape(0); ++i) {
564+
auto p = copy.mutable_data(i);
565+
make_caster <Scalar> conv_val;
566+
if (!conv_val.load(PyArray_GETITEM(copy.ptr(), p),
567+
convert))
568+
return false;
569+
val(i) = cast_op<Scalar>(conv_val);
570+
571+
}
572+
} else {
573+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
574+
val.resize(copy.shape(0), copy.shape(1));
575+
}
576+
for (ssize_t i = 0; i < copy.shape(0); ++i) {
577+
for (ssize_t j = 0; j < copy.shape(1); ++j) {
578+
// p is the const void pointer to the item
579+
auto p = copy.mutable_data(i, j);
580+
make_caster <Scalar> conv_val;
581+
if (!conv_val.load(PyArray_GETITEM(copy.ptr(), p),
582+
convert))
583+
return false;
584+
val(i, j) = cast_op<Scalar>(conv_val);
585+
}
586+
}
587+
}
588+
ref.reset(new Type(val));
589+
return true;
590+
}
465591
}
466592

467593
ref.reset();

include/pybind11/numpy.h

+15
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,21 @@ template <typename T, typename SFINAE> struct npy_format_descriptor {
12271227
::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
12281228
({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
12291229

1230+
#define PYBIND11_NUMPY_OBJECT_DTYPE(Type) \
1231+
namespace pybind11 { namespace detail { \
1232+
template <> struct npy_format_descriptor<Type> { \
1233+
public: \
1234+
enum { value = npy_api::NPY_OBJECT_ }; \
1235+
static pybind11::dtype dtype() { \
1236+
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) { \
1237+
return reinterpret_borrow<pybind11::dtype>(ptr); \
1238+
} \
1239+
pybind11_fail("Unsupported buffer format!"); \
1240+
} \
1241+
static constexpr auto name = _("object"); \
1242+
}; \
1243+
}}
1244+
12301245
#endif // __CLION_IDE__
12311246

12321247
template <class T>

0 commit comments

Comments
 (0)