Skip to content

Commit 3f45ee7

Browse files
author
Mmanu Chaturvedi
committed
Add ability to create object matrices
1 parent 6a81dbb commit 3f45ee7

File tree

4 files changed

+239
-22
lines changed

4 files changed

+239
-22
lines changed

include/pybind11/eigen.h

+148-19
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#pragma once
1111

1212
#include "numpy.h"
13+
#include "numpy/ndarraytypes.h"
1314

1415
#if defined(__INTEL_COMPILER)
1516
# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
@@ -139,14 +140,19 @@ template <typename Type_> struct EigenProps {
139140
const auto dims = a.ndim();
140141
if (dims < 1 || dims > 2)
141142
return false;
142-
143+
bool is_pyobject = false;
144+
if (npy_format_descriptor<Scalar>::value == npy_api::NPY_OBJECT_)
145+
is_pyobject = true;
146+
ssize_t scalar_size = (is_pyobject ? static_cast<ssize_t>(sizeof(PyObject*)) :
147+
static_cast<ssize_t>(sizeof(Scalar)));
143148
if (dims == 2) { // Matrix type: require exact match (or dynamic)
144149

145150
EigenIndex
146151
np_rows = a.shape(0),
147152
np_cols = a.shape(1),
148-
np_rstride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar)),
149-
np_cstride = a.strides(1) / static_cast<ssize_t>(sizeof(Scalar));
153+
np_rstride = a.strides(0) / scalar_size,
154+
np_cstride = a.strides(1) / scalar_size;
155+
150156
if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
151157
return false;
152158

@@ -156,7 +162,7 @@ template <typename Type_> struct EigenProps {
156162
// Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever
157163
// is used, we want the (single) numpy stride value.
158164
const EigenIndex n = a.shape(0),
159-
stride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar));
165+
stride = a.strides(0) / scalar_size;
160166

161167
if (vector) { // Eigen type is a compile-time vector
162168
if (fixed && size != n)
@@ -207,11 +213,53 @@ template <typename Type_> struct EigenProps {
207213
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
208214
constexpr ssize_t elem_size = sizeof(typename props::Scalar);
209215
array a;
210-
if (props::vector)
211-
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
212-
else
213-
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
214-
src.data(), base);
216+
using Scalar = typename props::Type::Scalar;
217+
bool is_pyoject = npy_format_descriptor<Scalar>::value == npy_api::NPY_OBJECT_;
218+
219+
if (!is_pyoject) {
220+
if (props::vector)
221+
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
222+
else
223+
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
224+
src.data(), base);
225+
}
226+
else {
227+
if (props::vector) {
228+
a = array(
229+
npy_format_descriptor<Scalar>::dtype(),
230+
{ (size_t) src.size() },
231+
nullptr,
232+
base
233+
);
234+
auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy;
235+
for (ssize_t i = 0; i < src.size(); ++i) {
236+
const Scalar src_val = props::fixed_rows ? src(0, i) : src(i, 0);
237+
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src_val, policy, base));
238+
if (!value_)
239+
return handle();
240+
auto p = a.mutable_data(i);
241+
PyArray_SETITEM(a.ptr(), p, value_.release().ptr());
242+
}
243+
}
244+
else {
245+
a = array(
246+
npy_format_descriptor<Scalar>::dtype(),
247+
{(size_t) src.rows(), (size_t) src.cols()},
248+
nullptr,
249+
base
250+
);
251+
auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy;
252+
for (ssize_t i = 0; i < src.rows(); ++i) {
253+
for (ssize_t j = 0; j < src.cols(); ++j) {
254+
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src(i, j), policy, base));
255+
if (!value_)
256+
return handle();
257+
auto p = a.mutable_data(i, j);
258+
PyArray_SETITEM(a.ptr(), p, value_.release().ptr());
259+
}
260+
}
261+
}
262+
}
215263

216264
if (!writeable)
217265
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
@@ -265,14 +313,49 @@ struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
265313
auto fits = props::conformable(buf);
266314
if (!fits)
267315
return false;
268-
316+
int result = 0;
269317
// Allocate the new type, then build a numpy reference into it
270318
value = Type(fits.rows, fits.cols);
271-
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
272-
if (dims == 1) ref = ref.squeeze();
273-
else if (ref.ndim() == 1) buf = buf.squeeze();
274-
275-
int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
319+
bool is_pyobject = npy_format_descriptor<Scalar>::value == npy_api::NPY_OBJECT_;
320+
321+
if (!is_pyobject) {
322+
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
323+
if (dims == 1) ref = ref.squeeze();
324+
else if (ref.ndim() == 1) buf = buf.squeeze();
325+
result =
326+
detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
327+
}
328+
else {
329+
if (dims == 1){
330+
if (Type::RowsAtCompileTime == Eigen::Dynamic)
331+
value.resize(buf.shape(0), 1);
332+
if (Type::ColsAtCompileTime == Eigen::Dynamic)
333+
value.resize(1, buf.shape(0));
334+
335+
for (ssize_t i = 0; i < buf.shape(0); ++i) {
336+
auto p = buf.mutable_data(i);
337+
make_caster <Scalar> conv_val;
338+
if (!conv_val.load(PyArray_GETITEM(buf.ptr(), p), convert))
339+
return false;
340+
value(i) = cast_op<Scalar>(conv_val);
341+
}
342+
} else {
343+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
344+
value.resize(buf.shape(0), buf.shape(1));
345+
}
346+
for (ssize_t i = 0; i < buf.shape(0); ++i) {
347+
for (ssize_t j = 0; j < buf.shape(1); ++j) {
348+
// p is the const void pointer to the item
349+
auto p = buf.mutable_data(i, j);
350+
make_caster<Scalar> conv_val;
351+
if (!conv_val.load(PyArray_GETITEM(buf.ptr(), p),
352+
convert))
353+
return false;
354+
value(i,j) = cast_op<Scalar>(conv_val);
355+
}
356+
}
357+
}
358+
}
276359

277360
if (result < 0) { // Copy failed!
278361
PyErr_Clear();
@@ -424,13 +507,19 @@ struct type_caster<
424507
// storage order conversion. (Note that we refuse to use this temporary copy when loading an
425508
// argument for a Ref<M> with M non-const, i.e. a read-write reference).
426509
Array copy_or_ref;
510+
typename std::remove_cv<PlainObjectType>::type val;
427511
public:
428512
bool load(handle src, bool convert) {
429513
// First check whether what we have is already an array of the right type. If not, we can't
430514
// avoid a copy (because the copy is also going to do type conversion).
431515
bool need_copy = !isinstance<Array>(src);
432516

433517
EigenConformable<props::row_major> fits;
518+
bool is_pyobject = false;
519+
if (npy_format_descriptor<Scalar>::value == npy_api::NPY_OBJECT_) {
520+
is_pyobject = true;
521+
need_copy = true;
522+
}
434523
if (!need_copy) {
435524
// We don't need a converting copy, but we also need to check whether the strides are
436525
// compatible with the Ref's stride requirements
@@ -453,15 +542,55 @@ struct type_caster<
453542
// We need to copy: If we need a mutable reference, or we're not supposed to convert
454543
// (either because we're in the no-convert overload pass, or because we're explicitly
455544
// instructed not to copy (via `py::arg().noconvert()`) we have to fail loading.
456-
if (!convert || need_writeable) return false;
545+
if (!is_pyobject && (!convert || need_writeable)) {
546+
return false;
547+
}
457548

458549
Array copy = Array::ensure(src);
459550
if (!copy) return false;
460551
fits = props::conformable(copy);
461-
if (!fits || !fits.template stride_compatible<props>())
552+
if (!fits || !fits.template stride_compatible<props>()) {
462553
return false;
463-
copy_or_ref = std::move(copy);
464-
loader_life_support::add_patient(copy_or_ref);
554+
}
555+
556+
if (!is_pyobject) {
557+
copy_or_ref = std::move(copy);
558+
loader_life_support::add_patient(copy_or_ref);
559+
}
560+
else {
561+
auto dims = copy.ndim();
562+
if (dims == 1){
563+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
564+
val.resize(copy.shape(0), 1);
565+
}
566+
for (ssize_t i = 0; i < copy.shape(0); ++i) {
567+
auto p = copy.mutable_data(i);
568+
make_caster <Scalar> conv_val;
569+
if (!conv_val.load(PyArray_GETITEM(copy.ptr(), p),
570+
convert))
571+
return false;
572+
val(i) = cast_op<Scalar>(conv_val);
573+
574+
}
575+
} else {
576+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
577+
val.resize(copy.shape(0), copy.shape(1));
578+
}
579+
for (ssize_t i = 0; i < copy.shape(0); ++i) {
580+
for (ssize_t j = 0; j < copy.shape(1); ++j) {
581+
// p is the const void pointer to the item
582+
auto p = copy.mutable_data(i, j);
583+
make_caster <Scalar> conv_val;
584+
if (!conv_val.load(PyArray_GETITEM(copy.ptr(), p),
585+
convert))
586+
return false;
587+
val(i, j) = cast_op<Scalar>(conv_val);
588+
}
589+
}
590+
}
591+
ref.reset(new Type(val));
592+
return true;
593+
}
465594
}
466595

467596
ref.reset();

include/pybind11/numpy.h

+15
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,21 @@ template <typename T, typename SFINAE> struct npy_format_descriptor {
12271227
::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
12281228
({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
12291229

1230+
#define PYBIND11_NUMPY_OBJECT_DTYPE(Type) \
1231+
namespace pybind11 { namespace detail { \
1232+
template <> struct npy_format_descriptor<Type> { \
1233+
public: \
1234+
enum { value = npy_api::NPY_OBJECT_ }; \
1235+
static pybind11::dtype dtype() { \
1236+
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) { \
1237+
return reinterpret_borrow<pybind11::dtype>(ptr); \
1238+
} \
1239+
pybind11_fail("Unsupported buffer format!"); \
1240+
} \
1241+
static constexpr auto name = _("object"); \
1242+
}; \
1243+
}}
1244+
12301245
#endif // __CLION_IDE__
12311246

12321247
template <class T>

tests/test_eigen.cpp

+37-3
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212
#include <pybind11/eigen.h>
1313
#include <pybind11/stl.h>
1414
#include <Eigen/Cholesky>
15+
#include <unsupported/Eigen/AutoDiff>
1516

1617
using MatrixXdR = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
17-
18-
18+
typedef Eigen::AutoDiffScalar<Eigen::VectorXd> ADScalar;
19+
typedef Eigen::Matrix<ADScalar, Eigen::Dynamic, 1> VectorXADScalar;
20+
typedef Eigen::Matrix<ADScalar, 1, Eigen::Dynamic> VectorXADScalarR;
21+
PYBIND11_NUMPY_OBJECT_DTYPE(ADScalar);
1922

2023
// Sets/resets a testing reference matrix to have values of 10*r + c, where r and c are the
2124
// (1-based) row/column number.
@@ -72,9 +75,13 @@ struct CustomOperatorNew {
7275

7376
TEST_SUBMODULE(eigen, m) {
7477
using FixedMatrixR = Eigen::Matrix<float, 5, 6, Eigen::RowMajor>;
78+
using FixedADScalarMatrixR = Eigen::Matrix<ADScalar, 5, 6, Eigen::RowMajor>;
7579
using FixedMatrixC = Eigen::Matrix<float, 5, 6>;
80+
using FixedADScalarMatrixC = Eigen::Matrix<ADScalar, 5, 6>;
7681
using DenseMatrixR = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
82+
using DenseADScalarMatrixR = Eigen::Matrix<ADScalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
7783
using DenseMatrixC = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>;
84+
using DenseADScalarMatrixC = Eigen::Matrix<ADScalar, Eigen::Dynamic, Eigen::Dynamic>;
7885
using FourRowMatrixC = Eigen::Matrix<float, 4, Eigen::Dynamic>;
7986
using FourColMatrixC = Eigen::Matrix<float, Eigen::Dynamic, 4>;
8087
using FourRowMatrixR = Eigen::Matrix<float, 4, Eigen::Dynamic>;
@@ -86,10 +93,14 @@ TEST_SUBMODULE(eigen, m) {
8693

8794
// various tests
8895
m.def("double_col", [](const Eigen::VectorXf &x) -> Eigen::VectorXf { return 2.0f * x; });
96+
m.def("double_adscalar_col", [](const VectorXADScalar &x) -> VectorXADScalar { return 2.0f * x; });
8997
m.def("double_row", [](const Eigen::RowVectorXf &x) -> Eigen::RowVectorXf { return 2.0f * x; });
98+
m.def("double_adscalar_row", [](const VectorXADScalarR &x) -> VectorXADScalarR { return 2.0f * x; });
9099
m.def("double_complex", [](const Eigen::VectorXcf &x) -> Eigen::VectorXcf { return 2.0f * x; });
91100
m.def("double_threec", [](py::EigenDRef<Eigen::Vector3f> x) { x *= 2; });
101+
m.def("double_adscalarc", [](py::EigenDRef<VectorXADScalar> x) { x *= 2; });
92102
m.def("double_threer", [](py::EigenDRef<Eigen::RowVector3f> x) { x *= 2; });
103+
m.def("double_adscalarr", [](py::EigenDRef<VectorXADScalarR> x) { x *= 2; });
93104
m.def("double_mat_cm", [](Eigen::MatrixXf x) -> Eigen::MatrixXf { return 2.0f * x; });
94105
m.def("double_mat_rm", [](DenseMatrixR x) -> DenseMatrixR { return 2.0f * x; });
95106

@@ -134,6 +145,12 @@ TEST_SUBMODULE(eigen, m) {
134145
return m;
135146
}, py::return_value_policy::reference);
136147

148+
// Increments ADScalar Matrix
149+
m.def("incr_adscalar_matrix", [](Eigen::Ref<DenseADScalarMatrixC> m, double v) {
150+
m += DenseADScalarMatrixC::Constant(m.rows(), m.cols(), v);
151+
return m;
152+
}, py::return_value_policy::reference);
153+
137154
// Same, but accepts a matrix of any strides
138155
m.def("incr_matrix_any", [](py::EigenDRef<Eigen::MatrixXd> m, double v) {
139156
m += Eigen::MatrixXd::Constant(m.rows(), m.cols(), v);
@@ -168,12 +185,16 @@ TEST_SUBMODULE(eigen, m) {
168185
// return value referencing/copying tests:
169186
class ReturnTester {
170187
Eigen::MatrixXd mat = create();
188+
DenseADScalarMatrixR ad_mat = create_ADScalar_mat();
171189
public:
172190
ReturnTester() { print_created(this); }
173191
~ReturnTester() { print_destroyed(this); }
174-
static Eigen::MatrixXd create() { return Eigen::MatrixXd::Ones(10, 10); }
192+
static Eigen::MatrixXd create() { return Eigen::MatrixXd::Ones(10, 10); }
193+
static DenseADScalarMatrixR create_ADScalar_mat() { DenseADScalarMatrixR ad_mat(2, 2);
194+
ad_mat << 1, 2, 3, 7; return ad_mat; }
175195
static const Eigen::MatrixXd createConst() { return Eigen::MatrixXd::Ones(10, 10); }
176196
Eigen::MatrixXd &get() { return mat; }
197+
DenseADScalarMatrixR& get_ADScalarMat() {return ad_mat;}
177198
Eigen::MatrixXd *getPtr() { return &mat; }
178199
const Eigen::MatrixXd &view() { return mat; }
179200
const Eigen::MatrixXd *viewPtr() { return &mat; }
@@ -192,6 +213,7 @@ TEST_SUBMODULE(eigen, m) {
192213
.def_static("create", &ReturnTester::create)
193214
.def_static("create_const", &ReturnTester::createConst)
194215
.def("get", &ReturnTester::get, rvp::reference_internal)
216+
.def("get_ADScalarMat", &ReturnTester::get_ADScalarMat, rvp::reference_internal)
195217
.def("get_ptr", &ReturnTester::getPtr, rvp::reference_internal)
196218
.def("view", &ReturnTester::view, rvp::reference_internal)
197219
.def("view_ptr", &ReturnTester::view, rvp::reference_internal)
@@ -211,6 +233,18 @@ TEST_SUBMODULE(eigen, m) {
211233
.def("corners_const", &ReturnTester::cornersConst, rvp::reference_internal)
212234
;
213235

236+
py::class_<ADScalar>(m, "AutoDiffXd")
237+
.def("__init__",
238+
[](ADScalar & self,
239+
double value,
240+
const Eigen::VectorXd& derivatives) {
241+
new (&self) ADScalar(value, derivatives);
242+
})
243+
.def("value", [](const ADScalar & self) {
244+
return self.value();
245+
})
246+
;
247+
214248
// test_special_matrix_objects
215249
// Returns a DiagonalMatrix with diagonal (1,2,3,...)
216250
m.def("incr_diag", [](int k) {

0 commit comments

Comments
 (0)