Skip to content

Fix argument passing from trampoline methods #2911

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion include/pybind11/detail/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,11 @@ enum class return_value_policy : uint8_t {
collected while Python is still using the child. More advanced
variations of this scheme are also possible using combinations of
return_value_policy::reference and the keep_alive call policy */
reference_internal
reference_internal,

/* This internally-only used policy applies to C++ arguments passed
to virtual methods overridden in Python to allow reference passing. */
automatic_override
};

PYBIND11_NAMESPACE_BEGIN(detail)
Expand Down
23 changes: 15 additions & 8 deletions include/pybind11/detail/smart_holder_type_casters.h
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,10 @@ struct smart_holder_type_caster : smart_holder_type_caster_load<T>,
}

static handle cast(T const &src, return_value_policy policy, handle parent) {
return cast(const_cast<T &>(src), policy, parent);
}

static handle cast(T &src, return_value_policy policy, handle parent) {
// type_caster_base BEGIN
// clang-format off
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
Expand All @@ -481,11 +485,13 @@ struct smart_holder_type_caster : smart_holder_type_caster_load<T>,
// type_caster_base END
}

static handle cast(T &src, return_value_policy policy, handle parent) {
return cast(const_cast<T const &>(src), policy, parent); // Mutbl2Const
static handle cast(T const *src, return_value_policy policy, handle parent) {
return cast(const_cast<T *>(src), policy, parent);
}

static handle cast(T const *src, return_value_policy policy, handle parent) {
static handle cast(T *src, return_value_policy policy, handle parent) {
if (policy == return_value_policy::automatic_override)
policy = return_value_policy::reference;
auto st = type_caster_base<T>::src_and_type(src);
return cast_const_raw_ptr( // Originally type_caster_generic::cast.
st.first,
Expand All @@ -496,10 +502,6 @@ struct smart_holder_type_caster : smart_holder_type_caster_load<T>,
make_constructor::make_move_constructor(src));
}

static handle cast(T *src, return_value_policy policy, handle parent) {
return cast(const_cast<T const *>(src), policy, parent); // Mutbl2Const
}

#if defined(_MSC_VER) && _MSC_VER < 1910
// Working around MSVC 2015 bug. const-correctness is lost.
// SMART_HOLDER_WIP: IMPROVABLE: make common code work with MSVC 2015.
Expand Down Expand Up @@ -723,14 +725,19 @@ struct smart_holder_type_caster<std::unique_ptr<T, D>> : smart_holder_type_caste
return none().release();
if (policy == return_value_policy::automatic)
policy = return_value_policy::reference_internal;
if (policy != return_value_policy::reference_internal)
else if (policy == return_value_policy::automatic_override)
;
else if (policy != return_value_policy::reference_internal)
throw cast_error("Invalid return_value_policy for unique_ptr&");
return smart_holder_type_caster<T>::cast(src.get(), policy, parent);
}

template <typename>
using cast_op_type = std::unique_ptr<T, D>;

// TODO: This always returns a new, moving unique_ptr instance to the raw pointer,
// even if argument should be passed as reference.
// See test_class_sh_basic.py::test_unique_ptr_cref_roundtrip
operator std::unique_ptr<T, D>() { return this->template loaded_as_unique_ptr<D>(); }
};

Expand Down
2 changes: 2 additions & 0 deletions include/pybind11/detail/type_caster_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,8 @@ template <typename type> class type_caster_base : public type_caster_generic {
}

static handle cast(const itype *src, return_value_policy policy, handle parent) {
if (policy == return_value_policy::automatic_override)
policy = return_value_policy::reference;
auto st = src_and_type(src);
return type_caster_generic::cast(
st.first, policy, parent, st.second,
Expand Down
4 changes: 2 additions & 2 deletions include/pybind11/pytypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ class object_api : public pyobject_tag {
function will throw a `cast_error` exception. When the Python function
call fails, a `error_already_set` exception is thrown.
\endrst */
template <return_value_policy policy = return_value_policy::automatic_reference, typename... Args>
template <return_value_policy policy = return_value_policy::automatic_override, typename... Args>
object operator()(Args &&...args) const;
template <return_value_policy policy = return_value_policy::automatic_reference, typename... Args>
template <return_value_policy policy = return_value_policy::automatic_override, typename... Args>
PYBIND11_DEPRECATED("call(...) was deprecated in favor of operator()(...)")
object call(Args&&... args) const;

Expand Down
50 changes: 34 additions & 16 deletions tests/test_class_sh_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <pybind11/smart_holder.h>

#include <cstdint>
#include <memory>
#include <string>
#include <vector>
Expand All @@ -17,15 +18,27 @@ struct atyp { // Short for "any type".
atyp(atyp &&other) { mtxt = other.mtxt + "_MvCtor"; }
};

struct uconsumer { // unique_ptr consumer
struct consumer { // unique_ptr consumer
std::unique_ptr<atyp> held;
bool valid() const { return static_cast<bool>(held); }

void pass_valu(std::unique_ptr<atyp> obj) { held = std::move(obj); }
void pass_rref(std::unique_ptr<atyp> &&obj) { held = std::move(obj); }
std::unique_ptr<atyp> rtrn_valu() { return std::move(held); }
std::unique_ptr<atyp> &rtrn_lref() { return held; }
const std::unique_ptr<atyp> &rtrn_cref() { return held; }
std::string pass_uq_valu(std::unique_ptr<atyp> obj) {
held = std::move(obj);
return held->mtxt;
}
std::string pass_uq_rref(std::unique_ptr<atyp> &&obj) {
held = std::move(obj);
return held->mtxt;
}
std::string pass_uq_cref(const std::unique_ptr<atyp> &obj) { return obj->mtxt; }
std::string pass_cptr(const atyp *obj) { return obj->mtxt; }
std::string pass_cref(const atyp &obj) { return obj.mtxt; }

std::unique_ptr<atyp> rtrn_uq_valu() { return std::move(held); }
std::unique_ptr<atyp> &rtrn_uq_lref() { return held; }
const std::unique_ptr<atyp> &rtrn_uq_cref() { return held; }
const atyp *rtrn_cptr() { return held.get(); }
const atyp &rtrn_cref() { return *held; }
};

// clang-format off
Expand Down Expand Up @@ -68,7 +81,7 @@ std::string pass_udcp(std::unique_ptr<atyp const, sddc> obj) { return "pass_udcp

// Helpers for testing.
std::string get_mtxt(atyp const &obj) { return obj.mtxt; }
std::ptrdiff_t get_ptr(atyp const &obj) { return reinterpret_cast<std::ptrdiff_t>(&obj); }
std::uintptr_t get_ptr(atyp const &obj) { return reinterpret_cast<std::uintptr_t>(&obj); }

std::unique_ptr<atyp> unique_ptr_roundtrip(std::unique_ptr<atyp> obj) { return obj; }
const std::unique_ptr<atyp> &unique_ptr_cref_roundtrip(const std::unique_ptr<atyp> &obj) {
Expand All @@ -84,7 +97,7 @@ struct SharedPtrStash {
} // namespace pybind11_tests

PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_basic::atyp)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_basic::uconsumer)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_basic::consumer)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_basic::SharedPtrStash)

namespace pybind11_tests {
Expand All @@ -102,7 +115,7 @@ TEST_SUBMODULE(class_sh_basic, m) {
m.def("rtrn_valu", rtrn_valu);
m.def("rtrn_rref", rtrn_rref);
m.def("rtrn_cref", rtrn_cref);
m.def("rtrn_mref", rtrn_mref);
m.def("rtrn_mref", rtrn_mref, py::return_value_policy::reference);
m.def("rtrn_cptr", rtrn_cptr);
m.def("rtrn_mptr", rtrn_mptr);

Expand Down Expand Up @@ -130,14 +143,19 @@ TEST_SUBMODULE(class_sh_basic, m) {
m.def("pass_udmp", pass_udmp);
m.def("pass_udcp", pass_udcp);

py::classh<uconsumer>(m, "uconsumer")
py::classh<consumer>(m, "consumer")
.def(py::init<>())
.def("valid", &uconsumer::valid)
.def("pass_valu", &uconsumer::pass_valu)
.def("pass_rref", &uconsumer::pass_rref)
.def("rtrn_valu", &uconsumer::rtrn_valu)
.def("rtrn_lref", &uconsumer::rtrn_lref)
.def("rtrn_cref", &uconsumer::rtrn_cref);
.def("valid", &consumer::valid)
.def("pass_uq_valu", &consumer::pass_uq_valu)
.def("pass_uq_rref", &consumer::pass_uq_rref)
.def("pass_uq_cref", &consumer::pass_uq_cref)
.def("pass_cptr", &consumer::pass_cptr)
.def("pass_cref", &consumer::pass_cref)
.def("rtrn_uq_valu", &consumer::rtrn_uq_valu)
.def("rtrn_uq_lref", &consumer::rtrn_uq_lref)
.def("rtrn_uq_cref", &consumer::rtrn_uq_cref)
.def("rtrn_cptr", &consumer::rtrn_cptr, py::return_value_policy::reference_internal)
.def("rtrn_cref", &consumer::rtrn_cref, py::return_value_policy::reference_internal);

// Helpers for testing.
// These require selected functions above to work first, as indicated:
Expand Down
127 changes: 86 additions & 41 deletions tests/test_class_sh_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,19 @@ def test_atyp_constructors():
assert obj.__class__.__name__ == "atyp"


def check_regex(expected, actual):
result = re.match(expected + "$", actual)
if result is None:
pytest.fail("expected: '{}' != actual: '{}'".format(expected, actual))


@pytest.mark.parametrize(
"rtrn_f, expected",
[
(m.rtrn_valu, "rtrn_valu(_MvCtor)*_MvCtor"),
(m.rtrn_rref, "rtrn_rref(_MvCtor)*_MvCtor"),
(m.rtrn_cref, "rtrn_cref(_MvCtor)*_CpCtor"),
(m.rtrn_mref, "rtrn_mref(_MvCtor)*_CpCtor"),
(m.rtrn_valu, "rtrn_valu(_MvCtor){1,3}"),
(m.rtrn_rref, "rtrn_rref(_MvCtor){1}"),
(m.rtrn_cref, "rtrn_cref_CpCtor"),
(m.rtrn_mref, "rtrn_mref"),
(m.rtrn_cptr, "rtrn_cptr"),
(m.rtrn_mptr, "rtrn_mptr"),
(m.rtrn_shmp, "rtrn_shmp"),
Expand All @@ -34,25 +40,25 @@ def test_atyp_constructors():
],
)
def test_cast(rtrn_f, expected):
assert re.match(expected, m.get_mtxt(rtrn_f()))
check_regex(expected, m.get_mtxt(rtrn_f()))


@pytest.mark.parametrize(
"pass_f, mtxt, expected",
[
(m.pass_valu, "Valu", "pass_valu:Valu(_MvCtor)*_CpCtor"),
(m.pass_cref, "Cref", "pass_cref:Cref(_MvCtor)*_MvCtor"),
(m.pass_mref, "Mref", "pass_mref:Mref(_MvCtor)*_MvCtor"),
(m.pass_cptr, "Cptr", "pass_cptr:Cptr(_MvCtor)*_MvCtor"),
(m.pass_mptr, "Mptr", "pass_mptr:Mptr(_MvCtor)*_MvCtor"),
(m.pass_shmp, "Shmp", "pass_shmp:Shmp(_MvCtor)*_MvCtor"),
(m.pass_shcp, "Shcp", "pass_shcp:Shcp(_MvCtor)*_MvCtor"),
(m.pass_uqmp, "Uqmp", "pass_uqmp:Uqmp(_MvCtor)*_MvCtor"),
(m.pass_uqcp, "Uqcp", "pass_uqcp:Uqcp(_MvCtor)*_MvCtor"),
(m.pass_valu, "Valu", "pass_valu:Valu(_MvCtor){1,2}_CpCtor"),
(m.pass_cref, "Cref", "pass_cref:Cref(_MvCtor){1,2}"),
(m.pass_mref, "Mref", "pass_mref:Mref(_MvCtor){1,2}"),
(m.pass_cptr, "Cptr", "pass_cptr:Cptr(_MvCtor){1,2}"),
(m.pass_mptr, "Mptr", "pass_mptr:Mptr(_MvCtor){1,2}"),
(m.pass_shmp, "Shmp", "pass_shmp:Shmp(_MvCtor){1,2}"),
(m.pass_shcp, "Shcp", "pass_shcp:Shcp(_MvCtor){1,2}"),
(m.pass_uqmp, "Uqmp", "pass_uqmp:Uqmp(_MvCtor){1,2}"),
(m.pass_uqcp, "Uqcp", "pass_uqcp:Uqcp(_MvCtor){1,2}"),
],
)
def test_load_with_mtxt(pass_f, mtxt, expected):
assert re.match(expected, pass_f(m.atyp(mtxt)))
check_regex(expected, pass_f(m.atyp(mtxt)))


@pytest.mark.parametrize(
Expand Down Expand Up @@ -111,53 +117,92 @@ def test_unique_ptr_roundtrip(num_round_trips=1000):
for _ in range(num_round_trips):
id_orig = id(recycled)
recycled = m.unique_ptr_roundtrip(recycled)
assert re.match("passenger(_MvCtor)*_MvCtor", m.get_mtxt(recycled))
check_regex("passenger(_MvCtor){1,2}", m.get_mtxt(recycled))
id_rtrn = id(recycled)
# Ensure the returned object is a different Python instance.
assert id_rtrn != id_orig
id_orig = id_rtrn


# This currently fails, because a unique_ptr is always loaded by value
# due to pybind11/detail/smart_holder_type_casters.h:689
# I think, we need to provide more cast operators.
@pytest.mark.skip
def test_unique_ptr_cref_roundtrip(num_round_trips=1000):
orig = m.atyp("passenger")
id_orig = id(orig)
mtxt_orig = m.get_mtxt(orig)

recycled = m.unique_ptr_cref_roundtrip(orig)
assert m.get_mtxt(orig) == mtxt_orig
assert m.get_mtxt(recycled) == mtxt_orig
assert id(recycled) == id_orig


@pytest.mark.parametrize(
"pass_f, rtrn_f, moved_out, moved_in",
[
(m.uconsumer.pass_valu, m.uconsumer.rtrn_valu, True, True),
(m.uconsumer.pass_rref, m.uconsumer.rtrn_valu, True, True),
(m.uconsumer.pass_valu, m.uconsumer.rtrn_lref, True, False),
(m.uconsumer.pass_valu, m.uconsumer.rtrn_cref, True, False),
(m.consumer.pass_uq_valu, m.consumer.rtrn_uq_valu, True, True),
(m.consumer.pass_uq_rref, m.consumer.rtrn_uq_valu, True, True),
(m.consumer.pass_uq_valu, m.consumer.rtrn_uq_lref, True, False),
(m.consumer.pass_uq_valu, m.consumer.rtrn_uq_cref, True, False),
],
)
def test_unique_ptr_consumer_roundtrip(pass_f, rtrn_f, moved_out, moved_in):
c = m.uconsumer()
assert not c.valid()
c = m.consumer()
recycled = m.atyp("passenger")
mtxt_orig = m.get_mtxt(recycled)
assert re.match("passenger_(MvCtor){1,2}", mtxt_orig)
ptr_orig = m.get_ptr(recycled)
check_regex("passenger(_MvCtor){1,2}", mtxt_orig)

pass_f(c, recycled)
if moved_out:
pass_f(c, recycled) # pass object to C++ consumer c
if moved_out: # if moved (always), ensure it is flagged as disowned
with pytest.raises(ValueError) as excinfo:
m.get_mtxt(recycled)
assert "Python instance was disowned" in str(excinfo.value)

recycled = rtrn_f(c)
assert c.valid() != moved_in
assert c.valid() != moved_in # consumer gave up ownership?
assert m.get_ptr(recycled) == ptr_orig # underlying C++ object never changes
assert m.get_mtxt(recycled) == mtxt_orig # object was not moved or copied


@pytest.mark.parametrize(
"rtrn_f",
[m.consumer.rtrn_uq_cref, m.consumer.rtrn_cref, m.consumer.rtrn_cptr],
)
@pytest.mark.parametrize(
"pass_f",
[
# This fails with: ValueError: Cannot disown non-owning holder (loaded_as_unique_ptr).
#
# smart_holder_type_caster_load<T>::loaded_as_unique_ptr() attempts to pass
# the not-owned cref as a new unique_ptr, which would eventually destroy the object,
# and is thus (correctly) suppressed.
# To fix this, smart_holder would need to store the (original) unique_ptr reference,
# e.g. using a union of unique_ptr + shared_ptr.
pytest.param(m.consumer.pass_uq_cref, marks=pytest.mark.xfail),
m.consumer.pass_cptr,
m.consumer.pass_cref,
],
)
def test_unique_ptr_cref_consumer_roundtrip(rtrn_f, pass_f):
c = m.consumer()
passenger = m.atyp("passenger")
mtxt_orig = m.get_mtxt(passenger)
ptr_orig = m.get_ptr(passenger)

c.pass_uq_valu(passenger) # moves passenger to C++ (checked above)

for _ in range(10):
cref = rtrn_f(c) # fetches const reference, should keep-alive parent c
assert pass_f(c, cref) == mtxt_orig
assert m.get_ptr(cref) == ptr_orig


# This fails with: ValueError: Missing value for wrapped C++ type: Python instance was disowned
# when accessing the orig object after passing it into m.unique_ptr_cref_roundtrip().
# This is because smart_holder_type_caster_load<T>::loaded_as_unique_ptr() always moves.
@pytest.mark.xfail
def test_unique_ptr_cref_roundtrip():
orig = m.atyp("passenger")
id_orig = id(orig)
ptr_orig = m.get_ptr(orig)
mtxt_orig = m.get_mtxt(orig)

recycled = m.unique_ptr_cref_roundtrip(orig)
# passing by reference shouldn't change pointer
assert m.get_ptr(orig) == ptr_orig
assert m.get_ptr(recycled) == ptr_orig
# nor apply any copy or move construction
assert m.get_mtxt(orig) == mtxt_orig
assert m.get_mtxt(recycled) == mtxt_orig
assert id(recycled) == id_orig


def test_py_type_handle_of_atyp():
Expand Down
Loading