Skip to content

Commit 5469c23

Browse files
authored
Adjusting type_caster<std::reference_wrapper<T>> to support const/non-const propagation in cast_op. (#2705)
* Allow type_caster of std::reference_wrapper<T> to be the same as a native reference. Before, both std::reference_wrapper<T> and std::reference_wrapper<const T> would invoke cast_op<type>. This doesn't allow the type_caster<> specialization for T to distinguish reference_wrapper types from value types. After, the type_caster<> specialization invokes cast_op<type&>, which allows reference_wrapper to behave in the same way as a native reference type. * Add tests/examples for std::reference_wrapper<const T> * Add tests which use mutable/immutable variants This test is a chimera; it blends the pybind11 casters with a custom pytype implementation that supports immutable and mutable calls. In order to detect the immutable/mutable state, the cast_op needs to propagate it, even through e.g. std::reference<const T> Note: This is still a work in progress; some things are crashing, which likely means that I have a refcounting bug or something else missing. * Add/finish tests that distinguish const& from & Fixes the bugs in my custom python type implementation, demonstrate test that requires const& and reference_wrapper<const T> being treated differently from Non-const. * Add passing a const to non-const method. * Demonstrate non-const conversion of reference_wrapper in tests. Apply formatting presubmit check. * Fix build errors from presubmit checks. * Try and fix a few more CI errors * More CI fixes. * More CI fixups. * Try and get PyPy to work. * Additional minor fixups. Getting close to CI green. * More ci fixes? * fix clang-tidy warnings from presubmit * fix more clang-tidy warnings * minor comment and consistency cleanups * PyDECREF -> Py_DECREF * copy/move constructors * Resolve codereview comments * more review comment fixes * review comments: remove spurious & * Make the test fail even when the static_assert is commented out. This expands the test_freezable_type_caster a bit by: 1/ adding accessors .is_immutable and .addr to compare identity from python. 2/ Changing the default cast_op of the type_caster<> specialization to return a non-const value. In normal codepaths this is a reasonable default. 3/ adding roundtrip variants to exercise the by reference, by pointer and by reference_wrapper in all call paths. In conjunction with 2/, this demonstrates the failure case of the existing std::reference_wrpper conversion, which now loses const in a similar way that happens when using the default cast_op_type<>. * apply presubmit formatting * Revert inclusion of test_freezable_type_caster There's some concern that this test is a bit unwieldly because of the use of the raw <Python.h> functions. Removing for now. * Add a test that validates const references propagation. This test verifies that cast_op may be used to correctly detect const reference types when used with std::reference_wrapper. * mend * Review comments based changes. 1. std::add_lvalue_reference<type> -> type& 2. Simplify the test a little more; we're never returning the ConstRefCaster type so the class_ definition can be removed. * formatted files again. * Move const_ref_caster test to builtin_casters * Review comments: use cast_op and adjust some comments. * Simplify ConstRefCasted test I like this version better as it moves the assertion that matters back into python.
1 parent 91a6972 commit 5469c23

File tree

3 files changed

+95
-4
lines changed

3 files changed

+95
-4
lines changed

include/pybind11/cast.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -960,9 +960,14 @@ template <typename type> class type_caster<std::reference_wrapper<type>> {
960960
private:
961961
using caster_t = make_caster<type>;
962962
caster_t subcaster;
963-
using subcaster_cast_op_type = typename caster_t::template cast_op_type<type>;
964-
static_assert(std::is_same<typename std::remove_const<type>::type &, subcaster_cast_op_type>::value,
965-
"std::reference_wrapper<T> caster requires T to have a caster with an `T &` operator");
963+
using reference_t = type&;
964+
using subcaster_cast_op_type =
965+
typename caster_t::template cast_op_type<reference_t>;
966+
967+
static_assert(std::is_same<typename std::remove_const<type>::type &, subcaster_cast_op_type>::value ||
968+
std::is_same<reference_t, subcaster_cast_op_type>::value,
969+
"std::reference_wrapper<T> caster requires T to have a caster with an "
970+
"`operator T &()` or `operator const T &()`");
966971
public:
967972
bool load(handle src, bool convert) { return subcaster.load(src, convert); }
968973
static constexpr auto name = caster_t::name;
@@ -973,7 +978,7 @@ template <typename type> class type_caster<std::reference_wrapper<type>> {
973978
return caster_t::cast(&src.get(), policy, parent);
974979
}
975980
template <typename T> using cast_op_type = std::reference_wrapper<type>;
976-
operator std::reference_wrapper<type>() { return subcaster.operator subcaster_cast_op_type&(); }
981+
operator std::reference_wrapper<type>() { return cast_op<type &>(subcaster); }
977982
};
978983

979984
#define PYBIND11_TYPE_CASTER(type, py_name) \

tests/test_builtin_casters.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,49 @@
1515
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
1616
#endif
1717

18+
struct ConstRefCasted {
19+
int tag;
20+
};
21+
22+
PYBIND11_NAMESPACE_BEGIN(pybind11)
23+
PYBIND11_NAMESPACE_BEGIN(detail)
24+
template <>
25+
class type_caster<ConstRefCasted> {
26+
public:
27+
static constexpr auto name = _<ConstRefCasted>();
28+
29+
// Input is unimportant, a new value will always be constructed based on the
30+
// cast operator.
31+
bool load(handle, bool) { return true; }
32+
33+
operator ConstRefCasted&&() { value = {1}; return std::move(value); }
34+
operator ConstRefCasted&() { value = {2}; return value; }
35+
operator ConstRefCasted*() { value = {3}; return &value; }
36+
37+
operator const ConstRefCasted&() { value = {4}; return value; }
38+
operator const ConstRefCasted*() { value = {5}; return &value; }
39+
40+
// custom cast_op to explicitly propagate types to the conversion operators.
41+
template <typename T_>
42+
using cast_op_type =
43+
/// const
44+
conditional_t<
45+
std::is_same<remove_reference_t<T_>, const ConstRefCasted*>::value, const ConstRefCasted*,
46+
conditional_t<
47+
std::is_same<T_, const ConstRefCasted&>::value, const ConstRefCasted&,
48+
/// non-const
49+
conditional_t<
50+
std::is_same<remove_reference_t<T_>, ConstRefCasted*>::value, ConstRefCasted*,
51+
conditional_t<
52+
std::is_same<T_, ConstRefCasted&>::value, ConstRefCasted&,
53+
/* else */ConstRefCasted&&>>>>;
54+
55+
private:
56+
ConstRefCasted value = {0};
57+
};
58+
PYBIND11_NAMESPACE_END(detail)
59+
PYBIND11_NAMESPACE_END(pybind11)
60+
1861
TEST_SUBMODULE(builtin_casters, m) {
1962
// test_simple_string
2063
m.def("string_roundtrip", [](const char *s) { return s; });
@@ -147,6 +190,17 @@ TEST_SUBMODULE(builtin_casters, m) {
147190
// test_reference_wrapper
148191
m.def("refwrap_builtin", [](std::reference_wrapper<int> p) { return 10 * p.get(); });
149192
m.def("refwrap_usertype", [](std::reference_wrapper<UserType> p) { return p.get().value(); });
193+
m.def("refwrap_usertype_const", [](std::reference_wrapper<const UserType> p) { return p.get().value(); });
194+
195+
m.def("refwrap_lvalue", []() -> std::reference_wrapper<UserType> {
196+
static UserType x(1);
197+
return std::ref(x);
198+
});
199+
m.def("refwrap_lvalue_const", []() -> std::reference_wrapper<const UserType> {
200+
static UserType x(1);
201+
return std::cref(x);
202+
});
203+
150204
// Not currently supported (std::pair caster has return-by-value cast operator);
151205
// triggers static_assert failure.
152206
//m.def("refwrap_pair", [](std::reference_wrapper<std::pair<int, int>>) { });
@@ -189,4 +243,14 @@ TEST_SUBMODULE(builtin_casters, m) {
189243
py::object o = py::cast(v);
190244
return py::cast<void *>(o) == v;
191245
});
246+
247+
// Tests const/non-const propagation in cast_op.
248+
m.def("takes", [](ConstRefCasted x) { return x.tag; });
249+
m.def("takes_move", [](ConstRefCasted&& x) { return x.tag; });
250+
m.def("takes_ptr", [](ConstRefCasted* x) { return x->tag; });
251+
m.def("takes_ref", [](ConstRefCasted& x) { return x.tag; });
252+
m.def("takes_ref_wrap", [](std::reference_wrapper<ConstRefCasted> x) { return x.get().tag; });
253+
m.def("takes_const_ptr", [](const ConstRefCasted* x) { return x->tag; });
254+
m.def("takes_const_ref", [](const ConstRefCasted& x) { return x.tag; });
255+
m.def("takes_const_ref_wrap", [](std::reference_wrapper<const ConstRefCasted> x) { return x.get().tag; });
192256
}

tests/test_builtin_casters.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ def test_reference_wrapper():
315315
"""std::reference_wrapper for builtin and user types"""
316316
assert m.refwrap_builtin(42) == 420
317317
assert m.refwrap_usertype(UserType(42)) == 42
318+
assert m.refwrap_usertype_const(UserType(42)) == 42
318319

319320
with pytest.raises(TypeError) as excinfo:
320321
m.refwrap_builtin(None)
@@ -324,6 +325,9 @@ def test_reference_wrapper():
324325
m.refwrap_usertype(None)
325326
assert "incompatible function arguments" in str(excinfo.value)
326327

328+
assert m.refwrap_lvalue().value == 1
329+
assert m.refwrap_lvalue_const().value == 1
330+
327331
a1 = m.refwrap_list(copy=True)
328332
a2 = m.refwrap_list(copy=True)
329333
assert [x.value for x in a1] == [2, 3]
@@ -421,3 +425,21 @@ def test_int_long():
421425

422426
def test_void_caster_2():
423427
assert m.test_void_caster()
428+
429+
430+
def test_const_ref_caster():
431+
"""Verifies that const-ref is propagated through type_caster cast_op.
432+
The returned ConstRefCasted type is a mimimal type that is constructed to
433+
reference the casting mode used.
434+
"""
435+
x = False
436+
assert m.takes(x) == 1
437+
assert m.takes_move(x) == 1
438+
439+
assert m.takes_ptr(x) == 3
440+
assert m.takes_ref(x) == 2
441+
assert m.takes_ref_wrap(x) == 2
442+
443+
assert m.takes_const_ptr(x) == 5
444+
assert m.takes_const_ref(x) == 4
445+
assert m.takes_const_ref_wrap(x) == 4

0 commit comments

Comments
 (0)