Skip to content

Commit 2c67e2a

Browse files
author
Mmanu Chaturvedi
committed
Add ability to create object matrices
1 parent 6a81dbb commit 2c67e2a

File tree

2 files changed

+163
-18
lines changed

2 files changed

+163
-18
lines changed

include/pybind11/eigen.h

+148-18
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#pragma once
1111

1212
#include "numpy.h"
13+
#include <numpy/ndarraytypes.h>
1314

1415
#if defined(__INTEL_COMPILER)
1516
# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
@@ -139,14 +140,19 @@ template <typename Type_> struct EigenProps {
139140
const auto dims = a.ndim();
140141
if (dims < 1 || dims > 2)
141142
return false;
142-
143+
bool is_pyobject = false;
144+
if (npy_format_descriptor<Scalar>::value == npy_api::NPY_OBJECT_)
145+
is_pyobject = true;
146+
ssize_t scalar_size = (is_pyobject ? static_cast<ssize_t>(sizeof(PyObject*)) :
147+
static_cast<ssize_t>(sizeof(Scalar)));
143148
if (dims == 2) { // Matrix type: require exact match (or dynamic)
144149

145150
EigenIndex
146151
np_rows = a.shape(0),
147152
np_cols = a.shape(1),
148-
np_rstride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar)),
149-
np_cstride = a.strides(1) / static_cast<ssize_t>(sizeof(Scalar));
153+
np_rstride = a.strides(0) / scalar_size,
154+
np_cstride = a.strides(1) / scalar_size;
155+
150156
if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
151157
return false;
152158

@@ -156,7 +162,7 @@ template <typename Type_> struct EigenProps {
156162
// Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever
157163
// is used, we want the (single) numpy stride value.
158164
const EigenIndex n = a.shape(0),
159-
stride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar));
165+
stride = a.strides(0) / scalar_size;
160166

161167
if (vector) { // Eigen type is a compile-time vector
162168
if (fixed && size != n)
@@ -207,11 +213,56 @@ template <typename Type_> struct EigenProps {
207213
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
208214
constexpr ssize_t elem_size = sizeof(typename props::Scalar);
209215
array a;
210-
if (props::vector)
211-
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
212-
else
213-
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
214-
src.data(), base);
216+
using Scalar = typename props::Type::Scalar;
217+
if (props::vector) {
218+
if (npy_format_descriptor<Scalar>::value == npy_api::NPY_OBJECT_) {
219+
a = array(
220+
npy_format_descriptor<Scalar>::dtype(),
221+
{ (size_t) src.size() },
222+
nullptr,
223+
base
224+
);
225+
auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy;
226+
for (ssize_t i = 0; i < src.size(); ++i) {
227+
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src(i, 0), policy, base));
228+
if (!value_)
229+
return handle();
230+
auto p = a.mutable_data(i);
231+
PyArray_SETITEM(a.ptr(), p, value_.release().ptr());
232+
}
233+
}
234+
else
235+
a = array({src.size()},
236+
{elem_size * src.innerStride()},
237+
src.data(),
238+
base);
239+
240+
}
241+
else {
242+
if (npy_format_descriptor<Scalar>::value == npy_api::NPY_OBJECT_) {
243+
a = array(
244+
npy_format_descriptor<Scalar>::dtype(),
245+
{(size_t) src.rows(), (size_t) src.cols()},
246+
nullptr,
247+
base
248+
);
249+
auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy;
250+
for (ssize_t i = 0; i < src.rows(); ++i) {
251+
for (ssize_t j = 0; j < src.cols(); ++j) {
252+
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src(i, j), policy, base));
253+
if (!value_)
254+
return handle();
255+
auto p = a.mutable_data(i, j);
256+
PyArray_SETITEM(a.ptr(), p, value_.release().ptr());
257+
}
258+
}
259+
}
260+
else
261+
a = array({src.rows(), src.cols()},
262+
{elem_size * src.rowStride(), elem_size * src.colStride()},
263+
src.data(),
264+
base);
265+
}
215266

216267
if (!writeable)
217268
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
@@ -265,14 +316,47 @@ struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
265316
auto fits = props::conformable(buf);
266317
if (!fits)
267318
return false;
268-
319+
int result = 0;
269320
// Allocate the new type, then build a numpy reference into it
270321
value = Type(fits.rows, fits.cols);
271-
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
272-
if (dims == 1) ref = ref.squeeze();
273-
else if (ref.ndim() == 1) buf = buf.squeeze();
274322

275-
int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
323+
if (npy_format_descriptor<Scalar>::value == npy_api::NPY_OBJECT_) {
324+
if (dims == 1){
325+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
326+
value.resize(buf.shape(0), 1);
327+
}
328+
for (ssize_t i = 0; i < buf.shape(0); ++i) {
329+
auto p = buf.mutable_data(i);
330+
make_caster <Scalar> conv_val;
331+
if (!conv_val.load(PyArray_GETITEM(buf.ptr(), p), convert))
332+
return false;
333+
value(i) = cast_op<Scalar>(conv_val);
334+
}
335+
} else {
336+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
337+
value.resize(buf.shape(0), buf.shape(1));
338+
}
339+
for (ssize_t i = 0; i < buf.shape(0); ++i) {
340+
for (ssize_t j = 0; j < buf.shape(1); ++j) {
341+
// p is the const void pointer to the item
342+
auto p = buf.mutable_data(i, j);
343+
make_caster<Scalar> conv_val;
344+
if (!conv_val.load(PyArray_GETITEM(buf.ptr(), p),
345+
convert))
346+
return false;
347+
value(i,j) = cast_op<Scalar>(conv_val);
348+
}
349+
}
350+
}
351+
}
352+
else
353+
{
354+
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
355+
if (dims == 1) ref = ref.squeeze();
356+
else if (ref.ndim() == 1) buf = buf.squeeze();
357+
result =
358+
detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
359+
}
276360

277361
if (result < 0) { // Copy failed!
278362
PyErr_Clear();
@@ -424,13 +508,19 @@ struct type_caster<
424508
// storage order conversion. (Note that we refuse to use this temporary copy when loading an
425509
// argument for a Ref<M> with M non-const, i.e. a read-write reference).
426510
Array copy_or_ref;
511+
typename std::remove_cv<PlainObjectType>::type val;
427512
public:
428513
bool load(handle src, bool convert) {
429514
// First check whether what we have is already an array of the right type. If not, we can't
430515
// avoid a copy (because the copy is also going to do type conversion).
431516
bool need_copy = !isinstance<Array>(src);
432517

433518
EigenConformable<props::row_major> fits;
519+
bool is_pyobject = false;
520+
if (npy_format_descriptor<Scalar>::value == npy_api::NPY_OBJECT_) {
521+
is_pyobject = true;
522+
need_copy = true;
523+
}
434524
if (!need_copy) {
435525
// We don't need a converting copy, but we also need to check whether the strides are
436526
// compatible with the Ref's stride requirements
@@ -453,15 +543,55 @@ struct type_caster<
453543
// We need to copy: If we need a mutable reference, or we're not supposed to convert
454544
// (either because we're in the no-convert overload pass, or because we're explicitly
455545
// instructed not to copy (via `py::arg().noconvert()`) we have to fail loading.
456-
if (!convert || need_writeable) return false;
546+
if (!is_pyobject && (!convert || need_writeable)) {
547+
return false;
548+
}
457549

458550
Array copy = Array::ensure(src);
459551
if (!copy) return false;
460552
fits = props::conformable(copy);
461-
if (!fits || !fits.template stride_compatible<props>())
553+
if (!fits || !fits.template stride_compatible<props>()) {
462554
return false;
463-
copy_or_ref = std::move(copy);
464-
loader_life_support::add_patient(copy_or_ref);
555+
}
556+
557+
if (npy_format_descriptor<Scalar>::value == npy_api::NPY_OBJECT_) {
558+
auto dims = copy.ndim();
559+
if (dims == 1){
560+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
561+
val.resize(copy.shape(0), 1);
562+
}
563+
for (ssize_t i = 0; i < copy.shape(0); ++i) {
564+
auto p = copy.mutable_data(i);
565+
make_caster <Scalar> conv_val;
566+
if (!conv_val.load(PyArray_GETITEM(copy.ptr(), p),
567+
convert))
568+
return false;
569+
val(i) = cast_op<Scalar>(conv_val);
570+
571+
}
572+
} else {
573+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
574+
val.resize(copy.shape(0), copy.shape(1));
575+
}
576+
for (ssize_t i = 0; i < copy.shape(0); ++i) {
577+
for (ssize_t j = 0; j < copy.shape(1); ++j) {
578+
// p is the const void pointer to the item
579+
auto p = copy.mutable_data(i, j);
580+
make_caster <Scalar> conv_val;
581+
if (!conv_val.load(PyArray_GETITEM(copy.ptr(), p),
582+
convert))
583+
return false;
584+
val(i, j) = cast_op<Scalar>(conv_val);
585+
}
586+
}
587+
}
588+
ref.reset(new Type(val));
589+
return true;
590+
}
591+
else {
592+
copy_or_ref = std::move(copy);
593+
loader_life_support::add_patient(copy_or_ref);
594+
}
465595
}
466596

467597
ref.reset();

include/pybind11/numpy.h

+15
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,21 @@ template <typename T, typename SFINAE> struct npy_format_descriptor {
12271227
::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
12281228
({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
12291229

1230+
#define PYBIND11_NUMPY_OBJECT_DTYPE(Type) \
1231+
namespace pybind11 { namespace detail { \
1232+
template <> struct npy_format_descriptor<Type> { \
1233+
public: \
1234+
enum { value = npy_api::NPY_OBJECT_ }; \
1235+
static pybind11::dtype dtype() { \
1236+
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) { \
1237+
return reinterpret_borrow<pybind11::dtype>(ptr); \
1238+
} \
1239+
pybind11_fail("Unsupported buffer format!"); \
1240+
} \
1241+
static constexpr auto name = _("object"); \
1242+
}; \
1243+
}}
1244+
12301245
#endif // __CLION_IDE__
12311246

12321247
template <class T>

0 commit comments

Comments
 (0)