Skip to content

Eigen/numpy referencing #610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 24, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 288 additions & 28 deletions docs/advanced/cast/eigen.rst

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions docs/advanced/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ with the basics of binding functions and classes, as explained in :doc:`/basics`
and :doc:`/classes`. The following guide is applicable to both free and member
functions, i.e. *methods* in Python.

.. _return_value_policies:

Return value policies
=====================

Expand Down Expand Up @@ -319,6 +321,8 @@ like so:
py::class_<MyClass>("MyClass")
.def("myFunction", py::arg("arg") = (SomeType *) nullptr);

.. _nonconverting_arguments:

Non-converting arguments
========================

Expand Down
11 changes: 11 additions & 0 deletions include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,17 @@ template <typename type> using cast_is_temporary_value_reference = bool_constant
!std::is_base_of<type_caster_generic, make_caster<type>>::value
>;

// When a value returned from a C++ function is being cast back to Python, we almost always want to
// force `policy = move`, regardless of the return value policy the function/method was declared
// with. Some classes (most notably Eigen::Ref and related) need to avoid this, and so can do so by
// specializing this struct.
template <typename Return, typename SFINAE = void> struct return_value_policy_override {
static return_value_policy policy(return_value_policy p) {
return !std::is_lvalue_reference<Return>::value && !std::is_pointer<Return>::value
? return_value_policy::move : p;
}
};

// Basic python -> C++ casting; throws if casting fails
template <typename T, typename SFINAE> type_caster<T, SFINAE> &load_type(type_caster<T, SFINAE> &conv, const handle &handle) {
if (!conv.load(handle, true)) {
Expand Down
6 changes: 4 additions & 2 deletions include/pybind11/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,11 @@ inline internals &get_internals();
#ifdef PYBIND11_CPP14
using std::enable_if_t;
using std::conditional_t;
using std::remove_cv_t;
#else
template <bool B, typename T = void> using enable_if_t = typename std::enable_if<B, T>::type;
template <bool B, typename T, typename F> using conditional_t = typename std::conditional<B, T, F>::type;
template <typename T> using remove_cv_t = typename std::remove_cv<T>::type;
#endif

/// Index sequences
Expand Down Expand Up @@ -504,9 +506,9 @@ struct is_template_base_of_impl {
/// `is_template_base_of<Base, T>` is true if `struct T : Base<U> {}` where U can be anything
template <template<typename...> class Base, typename T>
#if !defined(_MSC_VER)
using is_template_base_of = decltype(is_template_base_of_impl<Base>::check((T*)nullptr));
using is_template_base_of = decltype(is_template_base_of_impl<Base>::check((remove_cv_t<T>*)nullptr));
#else // MSVC2015 has trouble with decltype in template aliases
struct is_template_base_of : decltype(is_template_base_of_impl<Base>::check((T*)nullptr)) { };
struct is_template_base_of : decltype(is_template_base_of_impl<Base>::check((remove_cv_t<T>*)nullptr)) { };
#endif

/// Check if T is std::shared_ptr<U> where U can be anything
Expand Down
540 changes: 441 additions & 99 deletions include/pybind11/eigen.h

Large diffs are not rendered by default.

27 changes: 15 additions & 12 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ inline numpy_internals& get_numpy_internals() {

struct npy_api {
enum constants {
NPY_C_CONTIGUOUS_ = 0x0001,
NPY_F_CONTIGUOUS_ = 0x0002,
NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
NPY_ARRAY_OWNDATA_ = 0x0004,
NPY_ARRAY_FORCECAST_ = 0x0010,
NPY_ENSURE_ARRAY_ = 0x0040,
NPY_ARRAY_ENSUREARRAY_ = 0x0040,
NPY_ARRAY_ALIGNED_ = 0x0100,
NPY_ARRAY_WRITEABLE_ = 0x0400,
NPY_BOOL_ = 0,
Expand Down Expand Up @@ -154,6 +154,7 @@ struct npy_api {
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
Py_ssize_t *, PyObject **, PyObject *);
PyObject *(*PyArray_Squeeze_)(PyObject *);
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
private:
enum functions {
API_PyArray_Type = 2,
Expand All @@ -168,7 +169,8 @@ struct npy_api {
API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278,
API_PyArray_Squeeze = 136
API_PyArray_Squeeze = 136,
API_PyArray_SetBaseObject = 282
};

static npy_api lookup() {
Expand All @@ -194,6 +196,7 @@ struct npy_api {
DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
DECL_NPY_API(PyArray_Squeeze);
DECL_NPY_API(PyArray_SetBaseObject);
#undef DECL_NPY_API
return api;
}
Expand Down Expand Up @@ -330,8 +333,8 @@ class array : public buffer {
PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)

enum {
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
};

Expand Down Expand Up @@ -365,7 +368,7 @@ class array : public buffer {
pybind11_fail("NumPy: unable to create array!");
if (ptr) {
if (base) {
detail::array_proxy(tmp.ptr())->base = base.inc_ref().ptr();
api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
} else {
tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
}
Expand Down Expand Up @@ -533,7 +536,7 @@ class array : public buffer {

void check_writeable() const {
if (!writeable())
throw std::runtime_error("array is not writeable");
throw std::domain_error("array is not writeable");
}

static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
Expand Down Expand Up @@ -568,7 +571,7 @@ class array : public buffer {
if (ptr == nullptr)
return nullptr;
return detail::npy_api::get().PyArray_FromAny_(
ptr, nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
}
};

Expand Down Expand Up @@ -632,8 +635,8 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
}

/// Ensure that the argument is a NumPy array of the correct dtype.
/// In case of an error, nullptr is returned and the Python error is cleared.
/// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
/// it). In case of an error, nullptr is returned and the Python error is cleared.
static array_t ensure(handle h) {
auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
if (!result)
Expand All @@ -654,7 +657,7 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
return nullptr;
return detail::npy_api::get().PyArray_FromAny_(
ptr, dtype::of<T>().release().ptr(), 0, 0,
detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
}
};

Expand Down
6 changes: 2 additions & 4 deletions include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,8 @@ class cpp_function : public function {
? &call.func.data : call.func.data[0]);
capture *cap = const_cast<capture *>(reinterpret_cast<const capture *>(data));

/* Override policy for rvalues -- always move */
constexpr auto is_rvalue = !std::is_pointer<Return>::value
&& !std::is_lvalue_reference<Return>::value;
const auto policy = is_rvalue ? return_value_policy::move : call.func.policy;
/* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */
const auto policy = detail::return_value_policy_override<Return>::policy(call.func.policy);

/* Perform the function call */
handle result = cast_out::cast(args_converter.template call<Return>(cap->f),
Expand Down
Loading