Skip to content

Commit 218bcb3

Browse files
jagermandean0x7d
andcommitted
Override deduced Base class when defining Derived methods
When defining method from a member function pointer (e.g. `.def("f", &Derived::f)`) we run into a problem if `&Derived::f` is actually implemented in some base class `Base` when `Base` isn't pybind-registered. This happens because the class type is deduced from the member function pointer, which then becomes a lambda with first argument this deduced type. For a base class implementation, the deduced type is `Base`, not `Derived`, and so we generate and registered an overload which takes a `Base *` as first argument. Trying to call this fails if `Base` isn't registered (e.g. because it's an implementation detail class that isn't intended to be exposed to Python) because the type caster for an unregistered type always fails. This commit adds a `method_adaptor` function that rebinds a member function to a derived type member function and otherwise (i.e. regular functions/lambda) leaves the argument as-is. This is now used for class definitions so that they are bound with type being registered rather than a potential base type. A closely related fix in this commit is to similarly update the lambdas used for `def_readwrite` (and related) to bind to the class type being registered rather than the deduced type so that registering a property that resolves to a base class member similarly generates a usable function. Fixes #854, #910. Co-Authored-By: Dean Moldovan <[email protected]>
1 parent 6b442ff commit 218bcb3

File tree

4 files changed

+96
-11
lines changed

4 files changed

+96
-11
lines changed

include/pybind11/common.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,11 @@ using exactly_one_t = typename exactly_one<Predicate, Default, Ts...>::type;
621621
template <typename T, typename... /*Us*/> struct deferred_type { using type = T; };
622622
template <typename T, typename... Us> using deferred_t = typename deferred_type<T, Us...>::type;
623623

624+
/// Like is_base_of, but requires a strict base (i.e. `is_strict_base_of<T, T>::value == false`,
625+
/// unlike `std::is_base_of`)
626+
template <typename Base, typename Derived> using is_strict_base_of = bool_constant<
627+
std::is_base_of<Base, Derived>::value && !std::is_same<Base, Derived>::value>;
628+
624629
template <template<typename...> class Base>
625630
struct is_template_base_of_impl {
626631
template <typename... Us> static std::true_type check(Base<Us...> *);

include/pybind11/pybind11.h

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -908,11 +908,24 @@ inline void call_operator_delete(void *p) { ::operator delete(p); }
908908

909909
NAMESPACE_END(detail)
910910

911+
/// Given a pointer to a member function, cast it to its `Derived` version.
912+
/// Forward everything else unchanged.
913+
template <typename /*Derived*/, typename F>
914+
auto method_adaptor(F &&f) -> decltype(std::forward<F>(f)) { return std::forward<F>(f); }
915+
916+
template <typename Derived, typename Return, typename Class, typename... Args,
917+
typename Adapted = Return (Derived::*)(Args...)>
918+
Adapted method_adaptor(Return (Class::*pmf)(Args...)) { return pmf; }
919+
920+
template <typename Derived, typename Return, typename Class, typename... Args,
921+
typename Adapted = Return (Derived::*)(Args...) const>
922+
Adapted method_adaptor(Return (Class::*pmf)(Args...) const) { return pmf; }
923+
911924
template <typename type_, typename... options>
912925
class class_ : public detail::generic_type {
913926
template <typename T> using is_holder = detail::is_holder_type<type_, T>;
914-
template <typename T> using is_subtype = detail::bool_constant<std::is_base_of<type_, T>::value && !std::is_same<T, type_>::value>;
915-
template <typename T> using is_base = detail::bool_constant<std::is_base_of<T, type_>::value && !std::is_same<T, type_>::value>;
927+
template <typename T> using is_subtype = detail::is_strict_base_of<type_, T>;
928+
template <typename T> using is_base = detail::is_strict_base_of<T, type_>;
916929
// struct instead of using here to help MSVC:
917930
template <typename T> struct is_valid_class_option :
918931
detail::any_of<is_holder<T>, is_subtype<T>, is_base<T>> {};
@@ -978,8 +991,8 @@ class class_ : public detail::generic_type {
978991

979992
template <typename Func, typename... Extra>
980993
class_ &def(const char *name_, Func&& f, const Extra&... extra) {
981-
cpp_function cf(std::forward<Func>(f), name(name_), is_method(*this),
982-
sibling(getattr(*this, name_, none())), extra...);
994+
auto cf = cpp_function(method_adaptor<type>(std::forward<Func>(f)), name(name_),
995+
is_method(*this), sibling(getattr(*this, name_, none())), extra...);
983996
attr(cf.name()) = cf;
984997
return *this;
985998
}
@@ -1042,15 +1055,17 @@ class class_ : public detail::generic_type {
10421055

10431056
template <typename C, typename D, typename... Extra>
10441057
class_ &def_readwrite(const char *name, D C::*pm, const Extra&... extra) {
1045-
cpp_function fget([pm](const C &c) -> const D &{ return c.*pm; }, is_method(*this)),
1046-
fset([pm](C &c, const D &value) { c.*pm = value; }, is_method(*this));
1058+
static_assert(std::is_base_of<C, type>::value, "def_readwrite() requires a class member (or base class member)");
1059+
cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this)),
1060+
fset([pm](type &c, const D &value) { c.*pm = value; }, is_method(*this));
10471061
def_property(name, fget, fset, return_value_policy::reference_internal, extra...);
10481062
return *this;
10491063
}
10501064

10511065
template <typename C, typename D, typename... Extra>
10521066
class_ &def_readonly(const char *name, const D C::*pm, const Extra& ...extra) {
1053-
cpp_function fget([pm](const C &c) -> const D &{ return c.*pm; }, is_method(*this));
1067+
static_assert(std::is_base_of<C, type>::value, "def_readonly() requires a class member (or base class member)");
1068+
cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this));
10541069
def_property_readonly(name, fget, return_value_policy::reference_internal, extra...);
10551070
return *this;
10561071
}
@@ -1073,7 +1088,8 @@ class class_ : public detail::generic_type {
10731088
/// Uses return_value_policy::reference_internal by default
10741089
template <typename Getter, typename... Extra>
10751090
class_ &def_property_readonly(const char *name, const Getter &fget, const Extra& ...extra) {
1076-
return def_property_readonly(name, cpp_function(fget), return_value_policy::reference_internal, extra...);
1091+
return def_property_readonly(name, cpp_function(method_adaptor<type>(fget)),
1092+
return_value_policy::reference_internal, extra...);
10771093
}
10781094

10791095
/// Uses cpp_function's return_value_policy by default
@@ -1095,9 +1111,14 @@ class class_ : public detail::generic_type {
10951111
}
10961112

10971113
/// Uses return_value_policy::reference_internal by default
1114+
template <typename Getter, typename Setter, typename... Extra>
1115+
class_ &def_property(const char *name, const Getter &fget, const Setter &fset, const Extra& ...extra) {
1116+
return def_property(name, fget, cpp_function(method_adaptor<type>(fset)), extra...);
1117+
}
10981118
template <typename Getter, typename... Extra>
10991119
class_ &def_property(const char *name, const Getter &fget, const cpp_function &fset, const Extra& ...extra) {
1100-
return def_property(name, cpp_function(fget), fset, return_value_policy::reference_internal, extra...);
1120+
return def_property(name, cpp_function(method_adaptor<type>(fget)), fset,
1121+
return_value_policy::reference_internal, extra...);
11011122
}
11021123

11031124
/// Uses cpp_function's return_value_policy by default

tests/test_methods_and_attributes.cpp

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ template <> struct type_caster<ArgAlwaysConverts> {
159159
};
160160
}}
161161

162-
/// Issue/PR #648: bad arg default debugging output
162+
// Issue/PR #648: bad arg default debugging output
163163
class NotRegistered {};
164164

165165
// Test None-allowed py::arg argument policy
@@ -177,6 +177,23 @@ struct StrIssue {
177177
StrIssue(int i) : val{i} {}
178178
};
179179

180+
// Issues #854, #910: incompatible function args when member function/pointer is in unregistered base class
181+
class UnregisteredBase {
182+
public:
183+
void do_nothing() const {}
184+
void increase_value() { rw_value++; ro_value += 0.25; }
185+
void set_int(int v) { rw_value = v; }
186+
int get_int() const { return rw_value; }
187+
double get_double() const { return ro_value; }
188+
int rw_value = 42;
189+
double ro_value = 1.25;
190+
};
191+
class RegisteredDerived : public UnregisteredBase {
192+
public:
193+
using UnregisteredBase::UnregisteredBase;
194+
double sum() const { return rw_value + ro_value; }
195+
};
196+
180197
test_initializer methods_and_attributes([](py::module &m) {
181198
py::class_<ExampleMandA> emna(m, "ExampleMandA");
182199
emna.def(py::init<>())
@@ -325,7 +342,7 @@ test_initializer methods_and_attributes([](py::module &m) {
325342
m.def("ints_preferred", [](int i) { return i / 2; }, py::arg("i"));
326343
m.def("ints_only", [](int i) { return i / 2; }, py::arg("i").noconvert());
327344

328-
/// Issue/PR #648: bad arg default debugging output
345+
// Issue/PR #648: bad arg default debugging output
329346
#if !defined(NDEBUG)
330347
m.attr("debug_enabled") = true;
331348
#else
@@ -360,4 +377,26 @@ test_initializer methods_and_attributes([](py::module &m) {
360377
.def("__str__", [](const StrIssue &si) {
361378
return "StrIssue[" + std::to_string(si.val) + "]"; }
362379
);
380+
381+
// Issues #854/910: incompatible function args when member function/pointer is in unregistered
382+
// base class The methods and member pointers below actually resolve to members/pointers in
383+
// UnregisteredBase; before this test/fix they would be registered via lambda with a first
384+
// argument of an unregistered type, and thus uncallable.
385+
py::class_<RegisteredDerived>(m, "RegisteredDerived")
386+
.def(py::init<>())
387+
.def("do_nothing", &RegisteredDerived::do_nothing)
388+
.def("increase_value", &RegisteredDerived::increase_value)
389+
.def_readwrite("rw_value", &RegisteredDerived::rw_value)
390+
.def_readonly("ro_value", &RegisteredDerived::ro_value)
391+
// These should trigger a static_assert if uncommented
392+
//.def_readwrite("fails", &SimpleValue::value) // should trigger a static_assert if uncommented
393+
//.def_readonly("fails", &SimpleValue::value) // should trigger a static_assert if uncommented
394+
.def_property("rw_value_prop", &RegisteredDerived::get_int, &RegisteredDerived::set_int)
395+
.def_property_readonly("ro_value_prop", &RegisteredDerived::get_double)
396+
// This one is in the registered class:
397+
.def("sum", &RegisteredDerived::sum)
398+
;
399+
400+
using Adapted = decltype(py::method_adaptor<RegisteredDerived>(&RegisteredDerived::do_nothing));
401+
static_assert(std::is_same<Adapted, void (RegisteredDerived::*)() const>::value, "");
363402
});

tests/test_methods_and_attributes.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,3 +457,23 @@ def test_str_issue(msg):
457457
458458
Invoked with: 'no', 'such', 'constructor'
459459
"""
460+
461+
462+
def test_unregistered_base_implementations():
463+
from pybind11_tests import RegisteredDerived
464+
465+
a = RegisteredDerived()
466+
a.do_nothing()
467+
assert a.rw_value == 42
468+
assert a.ro_value == 1.25
469+
a.rw_value += 5
470+
assert a.sum() == 48.25
471+
a.increase_value()
472+
assert a.rw_value == 48
473+
assert a.ro_value == 1.5
474+
assert a.sum() == 49.5
475+
assert a.rw_value_prop == 48
476+
a.rw_value_prop += 1
477+
assert a.rw_value_prop == 49
478+
a.increase_value()
479+
assert a.ro_value_prop == 1.75

0 commit comments

Comments
 (0)