Skip to content

Alternative approach to #3807 #4612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 10, 2023
66 changes: 48 additions & 18 deletions include/pybind11/detail/smart_holder_type_casters.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct is_smart_holder_type<smart_holder> : std::true_type {};
// SMART_HOLDER_WIP: Needs refactoring of existing pybind11 code.
inline void register_instance(instance *self, void *valptr, const type_info *tinfo);
inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo);
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *);

// Replace all occurrences of substrings in a string.
inline void replace_all(std::string &str, const std::string &from, const std::string &to) {
Expand All @@ -49,20 +50,52 @@ inline void replace_all(std::string &str, const std::string &from, const std::st
}
}

inline bool type_is_pybind11_class_(PyTypeObject *type_obj) {
#if defined(PYPY_VERSION)
auto &internals = get_internals();
return bool(internals.registered_types_py.find(type_obj)
!= internals.registered_types_py.end());
#else
return bool(type_obj->tp_new == pybind11_object_new);
#endif
}

inline bool is_instance_method_of_type(PyTypeObject *type_obj, PyObject *attr_name) {
PyObject *descr = _PyType_Lookup(type_obj, attr_name);
return bool((descr != nullptr) && PyInstanceMethod_Check(descr));
}

inline object try_get_as_capsule_method(PyObject *obj, PyObject *attr_name) {
if (PyType_Check(obj)) {
return object();
}
PyTypeObject *type_obj = Py_TYPE(obj);
bool known_callable = false;
if (type_is_pybind11_class_(type_obj)) {
if (!is_instance_method_of_type(type_obj, attr_name)) {
return object();
}
known_callable = true;
}
PyObject *method = PyObject_GetAttr(obj, attr_name);
if (method == nullptr) {
PyErr_Clear();
return object();
}
if (!known_callable && PyCallable_Check(method) == 0) {
Py_DECREF(method);
return object();
}
return reinterpret_steal<object>(method);
}

inline void *try_as_void_ptr_capsule_get_pointer(handle src, const char *typeid_name) {
std::string type_name = typeid_name;
detail::clean_type_id(type_name);

// Convert `a::b::c` to `a_b_c`.
replace_all(type_name, "::", "_");
// Remove all `*` in the type name.
replace_all(type_name, "*", "");

std::string as_void_ptr_function_name("as_");
as_void_ptr_function_name += type_name;
if (hasattr(src, as_void_ptr_function_name.c_str())) {
auto as_void_ptr_function = function(src.attr(as_void_ptr_function_name.c_str()));
auto void_ptr_capsule = as_void_ptr_function();
std::string suffix = clean_type_id(typeid_name);
replace_all(suffix, "::", "_"); // Convert `a::b::c` to `a_b_c`.
replace_all(suffix, "*", "");
object as_capsule_method = try_get_as_capsule_method(src.ptr(), str("as_" + suffix).ptr());
if (as_capsule_method) {
object void_ptr_capsule = as_capsule_method();
if (isinstance<capsule>(void_ptr_capsule)) {
return reinterpret_borrow<capsule>(void_ptr_capsule).get_pointer();
}
Expand Down Expand Up @@ -304,11 +337,8 @@ class modified_type_caster_generic_load_impl {
loaded_v_h = value_and_holder();
return true;
}
if (convert && cpptype) {
const auto &bases = all_type_info(srctype);
if (bases.empty() && try_as_void_ptr_capsule(src)) {
return true;
}
if (convert && cpptype && try_as_void_ptr_capsule(src)) {
return true;
}
return false;
}
Expand Down
40 changes: 31 additions & 9 deletions tests/test_class_sh_void_ptr_capsule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,33 @@ struct Derived2 : Base12 {
int bar() const { return 2; }
};

struct UnspecBase {
virtual ~UnspecBase() = default;
virtual int Get() const { return 100; }
};

int PassUnspecBase(const UnspecBase &sb) { return sb.Get() + 30; }

struct UnspecDerived : UnspecBase {
int Get() const override { return 200; }
};

} // namespace class_sh_void_ptr_capsule
} // namespace pybind11_tests

PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_void_ptr_capsule::Valid)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_void_ptr_capsule::TypeWithGetattr)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_void_ptr_capsule::Base1)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_void_ptr_capsule::Base2)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_void_ptr_capsule::Base12)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_void_ptr_capsule::Derived1)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_void_ptr_capsule::Derived2)
using namespace pybind11_tests::class_sh_void_ptr_capsule;

TEST_SUBMODULE(class_sh_void_ptr_capsule, m) {
using namespace pybind11_tests::class_sh_void_ptr_capsule;
PYBIND11_SMART_HOLDER_TYPE_CASTERS(Valid)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(TypeWithGetattr)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(Base1)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(Base2)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(Base12)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(Derived1)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(Derived2)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(UnspecBase)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(UnspecDerived)

TEST_SUBMODULE(class_sh_void_ptr_capsule, m) {
py::classh<Valid>(m, "Valid");

m.def("get_from_valid_capsule", &get_from_valid_capsule);
Expand Down Expand Up @@ -102,4 +115,13 @@ TEST_SUBMODULE(class_sh_void_ptr_capsule, m) {
py::classh<Derived1, Base12>(m, "Derived1").def(py::init<>()).def("bar", &Derived1::bar);

py::classh<Derived2, Base12>(m, "Derived2").def(py::init<>()).def("bar", &Derived2::bar);

py::classh<UnspecBase>(m, "UnspecBase");
m.def("PassUnspecBase", PassUnspecBase);
py::classh<UnspecDerived>(m, "UnspecDerived") // UnspecBase NOT specified as base here.
.def(py::init<>())
.def("as_pybind11_tests_class_sh_void_ptr_capsule_UnspecBase", [](UnspecDerived *self) {
return py::reinterpret_steal<py::object>(
PyCapsule_New(static_cast<void *>(self), nullptr, nullptr));
});
}
4 changes: 4 additions & 0 deletions tests/test_class_sh_void_ptr_capsule.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,7 @@ def test_multiple_inheritance_getattr():
assert d2.foo() == 0
assert d2.bar() == 2
assert d2.prop2 == "Base GetAttr: prop2"


def test_pass_unspecified_base():
assert m.PassUnspecBase(m.UnspecDerived()) == 230