Skip to content

Commit e1657ea

Browse files
author
Mmanu Chaturvedi
committed
Add ability to create object matrices
1 parent 6a81dbb commit e1657ea

File tree

4 files changed

+237
-22
lines changed

4 files changed

+237
-22
lines changed

include/pybind11/eigen.h

+144-19
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ struct eigen_extract_stride<Eigen::Map<PlainObjectType, MapOptions, StrideType>>
107107
template <typename PlainObjectType, int Options, typename StrideType>
108108
struct eigen_extract_stride<Eigen::Ref<PlainObjectType, Options, StrideType>> { using type = StrideType; };
109109

110+
template <typename Scalar> bool is_pyobject_() {
111+
return static_cast<pybind11::detail::npy_api::constants>(npy_format_descriptor<Scalar>::value) == npy_api::NPY_OBJECT_;
112+
}
113+
110114
// Helper struct for extracting information from an Eigen type
111115
template <typename Type_> struct EigenProps {
112116
using Type = Type_;
@@ -139,14 +143,19 @@ template <typename Type_> struct EigenProps {
139143
const auto dims = a.ndim();
140144
if (dims < 1 || dims > 2)
141145
return false;
142-
146+
bool is_pyobject = false;
147+
if (is_pyobject_<Scalar>())
148+
is_pyobject = true;
149+
ssize_t scalar_size = (is_pyobject ? static_cast<ssize_t>(sizeof(PyObject*)) :
150+
static_cast<ssize_t>(sizeof(Scalar)));
143151
if (dims == 2) { // Matrix type: require exact match (or dynamic)
144152

145153
EigenIndex
146154
np_rows = a.shape(0),
147155
np_cols = a.shape(1),
148-
np_rstride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar)),
149-
np_cstride = a.strides(1) / static_cast<ssize_t>(sizeof(Scalar));
156+
np_rstride = a.strides(0) / scalar_size,
157+
np_cstride = a.strides(1) / scalar_size;
158+
150159
if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
151160
return false;
152161

@@ -156,7 +165,7 @@ template <typename Type_> struct EigenProps {
156165
// Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever
157166
// is used, we want the (single) numpy stride value.
158167
const EigenIndex n = a.shape(0),
159-
stride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar));
168+
stride = a.strides(0) / scalar_size;
160169

161170
if (vector) { // Eigen type is a compile-time vector
162171
if (fixed && size != n)
@@ -207,11 +216,51 @@ template <typename Type_> struct EigenProps {
207216
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
208217
constexpr ssize_t elem_size = sizeof(typename props::Scalar);
209218
array a;
210-
if (props::vector)
211-
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
212-
else
213-
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
214-
src.data(), base);
219+
using Scalar = typename props::Type::Scalar;
220+
bool is_pyoject = static_cast<pybind11::detail::npy_api::constants>(npy_format_descriptor<Scalar>::value) == npy_api::NPY_OBJECT_;
221+
222+
if (!is_pyoject) {
223+
if (props::vector)
224+
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
225+
else
226+
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
227+
src.data(), base);
228+
}
229+
else {
230+
if (props::vector) {
231+
a = array(
232+
npy_format_descriptor<Scalar>::dtype(),
233+
{ (size_t) src.size() },
234+
nullptr,
235+
base
236+
);
237+
auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy;
238+
for (ssize_t i = 0; i < src.size(); ++i) {
239+
const Scalar src_val = props::fixed_rows ? src(0, i) : src(i, 0);
240+
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src_val, policy, base));
241+
if (!value_)
242+
return handle();
243+
a.attr("itemset")(i, value_);
244+
}
245+
}
246+
else {
247+
a = array(
248+
npy_format_descriptor<Scalar>::dtype(),
249+
{(size_t) src.rows(), (size_t) src.cols()},
250+
nullptr,
251+
base
252+
);
253+
auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy;
254+
for (ssize_t i = 0; i < src.rows(); ++i) {
255+
for (ssize_t j = 0; j < src.cols(); ++j) {
256+
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src(i, j), policy, base));
257+
if (!value_)
258+
return handle();
259+
a.attr("itemset")(i, j, value_);
260+
}
261+
}
262+
}
263+
}
215264

216265
if (!writeable)
217266
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
@@ -265,14 +314,46 @@ struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
265314
auto fits = props::conformable(buf);
266315
if (!fits)
267316
return false;
268-
317+
int result = 0;
269318
// Allocate the new type, then build a numpy reference into it
270319
value = Type(fits.rows, fits.cols);
271-
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
272-
if (dims == 1) ref = ref.squeeze();
273-
else if (ref.ndim() == 1) buf = buf.squeeze();
274-
275-
int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
320+
bool is_pyobject = is_pyobject_<Scalar>();
321+
322+
if (!is_pyobject) {
323+
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
324+
if (dims == 1) ref = ref.squeeze();
325+
else if (ref.ndim() == 1) buf = buf.squeeze();
326+
result =
327+
detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
328+
}
329+
else {
330+
if (dims == 1) {
331+
if (Type::RowsAtCompileTime == Eigen::Dynamic)
332+
value.resize(buf.shape(0), 1);
333+
if (Type::ColsAtCompileTime == Eigen::Dynamic)
334+
value.resize(1, buf.shape(0));
335+
336+
for (ssize_t i = 0; i < buf.shape(0); ++i) {
337+
make_caster <Scalar> conv_val;
338+
if (!conv_val.load(buf.attr("item")(i).cast<pybind11::object>(), convert))
339+
return false;
340+
value(i) = cast_op<Scalar>(conv_val);
341+
}
342+
} else {
343+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
344+
value.resize(buf.shape(0), buf.shape(1));
345+
}
346+
for (ssize_t i = 0; i < buf.shape(0); ++i) {
347+
for (ssize_t j = 0; j < buf.shape(1); ++j) {
348+
// p is the const void pointer to the item
349+
make_caster<Scalar> conv_val;
350+
if (!conv_val.load(buf.attr("item")(i,j).cast<pybind11::object>(), convert))
351+
return false;
352+
value(i,j) = cast_op<Scalar>(conv_val);
353+
}
354+
}
355+
}
356+
}
276357

277358
if (result < 0) { // Copy failed!
278359
PyErr_Clear();
@@ -424,13 +505,19 @@ struct type_caster<
424505
// storage order conversion. (Note that we refuse to use this temporary copy when loading an
425506
// argument for a Ref<M> with M non-const, i.e. a read-write reference).
426507
Array copy_or_ref;
508+
typename std::remove_cv<PlainObjectType>::type val;
427509
public:
428510
bool load(handle src, bool convert) {
429511
// First check whether what we have is already an array of the right type. If not, we can't
430512
// avoid a copy (because the copy is also going to do type conversion).
431513
bool need_copy = !isinstance<Array>(src);
432514

433515
EigenConformable<props::row_major> fits;
516+
bool is_pyobject = false;
517+
if (is_pyobject_<Scalar>()) {
518+
is_pyobject = true;
519+
need_copy = true;
520+
}
434521
if (!need_copy) {
435522
// We don't need a converting copy, but we also need to check whether the strides are
436523
// compatible with the Ref's stride requirements
@@ -453,15 +540,53 @@ struct type_caster<
453540
// We need to copy: If we need a mutable reference, or we're not supposed to convert
454541
// (either because we're in the no-convert overload pass, or because we're explicitly
455542
// instructed not to copy (via `py::arg().noconvert()`) we have to fail loading.
456-
if (!convert || need_writeable) return false;
543+
if (!is_pyobject && (!convert || need_writeable)) {
544+
return false;
545+
}
457546

458547
Array copy = Array::ensure(src);
459548
if (!copy) return false;
460549
fits = props::conformable(copy);
461-
if (!fits || !fits.template stride_compatible<props>())
550+
if (!fits || !fits.template stride_compatible<props>()) {
462551
return false;
463-
copy_or_ref = std::move(copy);
464-
loader_life_support::add_patient(copy_or_ref);
552+
}
553+
554+
if (!is_pyobject) {
555+
copy_or_ref = std::move(copy);
556+
loader_life_support::add_patient(copy_or_ref);
557+
}
558+
else {
559+
auto dims = copy.ndim();
560+
if (dims == 1) {
561+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
562+
val.resize(copy.shape(0), 1);
563+
}
564+
for (ssize_t i = 0; i < copy.shape(0); ++i) {
565+
make_caster <Scalar> conv_val;
566+
if (!conv_val.load(copy.attr("item")(i).template cast<pybind11::object>(),
567+
convert))
568+
return false;
569+
val(i) = cast_op<Scalar>(conv_val);
570+
571+
}
572+
} else {
573+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
574+
val.resize(copy.shape(0), copy.shape(1));
575+
}
576+
for (ssize_t i = 0; i < copy.shape(0); ++i) {
577+
for (ssize_t j = 0; j < copy.shape(1); ++j) {
578+
// p is the const void pointer to the item
579+
make_caster <Scalar> conv_val;
580+
if (!conv_val.load(copy.attr("item")(i, j).template cast<pybind11::object>(),
581+
convert))
582+
return false;
583+
val(i, j) = cast_op<Scalar>(conv_val);
584+
}
585+
}
586+
}
587+
ref.reset(new Type(val));
588+
return true;
589+
}
465590
}
466591

467592
ref.reset();

include/pybind11/numpy.h

+15
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,21 @@ template <typename T, typename SFINAE> struct npy_format_descriptor {
12271227
::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
12281228
({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
12291229

1230+
#define PYBIND11_NUMPY_OBJECT_DTYPE(Type) \
1231+
namespace pybind11 { namespace detail { \
1232+
template <> struct npy_format_descriptor<Type> { \
1233+
public: \
1234+
enum { value = npy_api::NPY_OBJECT_ }; \
1235+
static pybind11::dtype dtype() { \
1236+
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) { \
1237+
return reinterpret_borrow<pybind11::dtype>(ptr); \
1238+
} \
1239+
pybind11_fail("Unsupported buffer format!"); \
1240+
} \
1241+
static constexpr auto name = _("object"); \
1242+
}; \
1243+
}}
1244+
12301245
#endif // __CLION_IDE__
12311246

12321247
template <class T>

tests/test_eigen.cpp

+36-3
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@
1212
#include <pybind11/eigen.h>
1313
#include <pybind11/stl.h>
1414
#include <Eigen/Cholesky>
15+
#include <unsupported/Eigen/AutoDiff>
16+
#include "src/Core/util/DisableStupidWarnings.h"
1517

1618
using MatrixXdR = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
17-
18-
19+
typedef Eigen::AutoDiffScalar<Eigen::VectorXd> ADScalar;
20+
typedef Eigen::Matrix<ADScalar, Eigen::Dynamic, 1> VectorXADScalar;
21+
typedef Eigen::Matrix<ADScalar, 1, Eigen::Dynamic> VectorXADScalarR;
22+
PYBIND11_NUMPY_OBJECT_DTYPE(ADScalar);
1923

2024
// Sets/resets a testing reference matrix to have values of 10*r + c, where r and c are the
2125
// (1-based) row/column number.
@@ -74,7 +78,9 @@ TEST_SUBMODULE(eigen, m) {
7478
using FixedMatrixR = Eigen::Matrix<float, 5, 6, Eigen::RowMajor>;
7579
using FixedMatrixC = Eigen::Matrix<float, 5, 6>;
7680
using DenseMatrixR = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
81+
using DenseADScalarMatrixR = Eigen::Matrix<ADScalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
7782
using DenseMatrixC = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>;
83+
using DenseADScalarMatrixC = Eigen::Matrix<ADScalar, Eigen::Dynamic, Eigen::Dynamic>;
7884
using FourRowMatrixC = Eigen::Matrix<float, 4, Eigen::Dynamic>;
7985
using FourColMatrixC = Eigen::Matrix<float, Eigen::Dynamic, 4>;
8086
using FourRowMatrixR = Eigen::Matrix<float, 4, Eigen::Dynamic>;
@@ -86,10 +92,14 @@ TEST_SUBMODULE(eigen, m) {
8692

8793
// various tests
8894
m.def("double_col", [](const Eigen::VectorXf &x) -> Eigen::VectorXf { return 2.0f * x; });
95+
m.def("double_adscalar_col", [](const VectorXADScalar &x) -> VectorXADScalar { return 2.0f * x; });
8996
m.def("double_row", [](const Eigen::RowVectorXf &x) -> Eigen::RowVectorXf { return 2.0f * x; });
97+
m.def("double_adscalar_row", [](const VectorXADScalarR &x) -> VectorXADScalarR { return 2.0f * x; });
9098
m.def("double_complex", [](const Eigen::VectorXcf &x) -> Eigen::VectorXcf { return 2.0f * x; });
9199
m.def("double_threec", [](py::EigenDRef<Eigen::Vector3f> x) { x *= 2; });
100+
m.def("double_adscalarc", [](py::EigenDRef<VectorXADScalar> x) { x *= 2; });
92101
m.def("double_threer", [](py::EigenDRef<Eigen::RowVector3f> x) { x *= 2; });
102+
m.def("double_adscalarr", [](py::EigenDRef<VectorXADScalarR> x) { x *= 2; });
93103
m.def("double_mat_cm", [](Eigen::MatrixXf x) -> Eigen::MatrixXf { return 2.0f * x; });
94104
m.def("double_mat_rm", [](DenseMatrixR x) -> DenseMatrixR { return 2.0f * x; });
95105

@@ -134,6 +144,12 @@ TEST_SUBMODULE(eigen, m) {
134144
return m;
135145
}, py::return_value_policy::reference);
136146

147+
// Increments ADScalar Matrix
148+
m.def("incr_adscalar_matrix", [](Eigen::Ref<DenseADScalarMatrixC> m, double v) {
149+
m += DenseADScalarMatrixC::Constant(m.rows(), m.cols(), v);
150+
return m;
151+
}, py::return_value_policy::reference);
152+
137153
// Same, but accepts a matrix of any strides
138154
m.def("incr_matrix_any", [](py::EigenDRef<Eigen::MatrixXd> m, double v) {
139155
m += Eigen::MatrixXd::Constant(m.rows(), m.cols(), v);
@@ -168,12 +184,16 @@ TEST_SUBMODULE(eigen, m) {
168184
// return value referencing/copying tests:
169185
class ReturnTester {
170186
Eigen::MatrixXd mat = create();
187+
DenseADScalarMatrixR ad_mat = create_ADScalar_mat();
171188
public:
172189
ReturnTester() { print_created(this); }
173190
~ReturnTester() { print_destroyed(this); }
174-
static Eigen::MatrixXd create() { return Eigen::MatrixXd::Ones(10, 10); }
191+
static Eigen::MatrixXd create() { return Eigen::MatrixXd::Ones(10, 10); }
192+
static DenseADScalarMatrixR create_ADScalar_mat() { DenseADScalarMatrixR ad_mat(2, 2);
193+
ad_mat << 1, 2, 3, 7; return ad_mat; }
175194
static const Eigen::MatrixXd createConst() { return Eigen::MatrixXd::Ones(10, 10); }
176195
Eigen::MatrixXd &get() { return mat; }
196+
DenseADScalarMatrixR& get_ADScalarMat() {return ad_mat;}
177197
Eigen::MatrixXd *getPtr() { return &mat; }
178198
const Eigen::MatrixXd &view() { return mat; }
179199
const Eigen::MatrixXd *viewPtr() { return &mat; }
@@ -192,6 +212,7 @@ TEST_SUBMODULE(eigen, m) {
192212
.def_static("create", &ReturnTester::create)
193213
.def_static("create_const", &ReturnTester::createConst)
194214
.def("get", &ReturnTester::get, rvp::reference_internal)
215+
.def("get_ADScalarMat", &ReturnTester::get_ADScalarMat, rvp::reference_internal)
195216
.def("get_ptr", &ReturnTester::getPtr, rvp::reference_internal)
196217
.def("view", &ReturnTester::view, rvp::reference_internal)
197218
.def("view_ptr", &ReturnTester::view, rvp::reference_internal)
@@ -211,6 +232,18 @@ TEST_SUBMODULE(eigen, m) {
211232
.def("corners_const", &ReturnTester::cornersConst, rvp::reference_internal)
212233
;
213234

235+
py::class_<ADScalar>(m, "AutoDiffXd")
236+
.def("__init__",
237+
[](ADScalar & self,
238+
double value,
239+
const Eigen::VectorXd& derivatives) {
240+
new (&self) ADScalar(value, derivatives);
241+
})
242+
.def("value", [](const ADScalar & self) {
243+
return self.value();
244+
})
245+
;
246+
214247
// test_special_matrix_objects
215248
// Returns a DiagonalMatrix with diagonal (1,2,3,...)
216249
m.def("incr_diag", [](int k) {

0 commit comments

Comments
 (0)