Skip to content

Commit 7303340

Browse files
committed
Add a scope guard call policy
`m.def("foo", foo, py::call_guard<T>());` is equivalent to: ```c++ Return wrap_foo(Args...) { T guard{}; return foo(std::forward<Args>(args)...); } ```
1 parent dfd89a6 commit 7303340

File tree

6 files changed

+81
-9
lines changed

6 files changed

+81
-9
lines changed

include/pybind11/attr.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,17 @@ struct metaclass {
6767
/// Annotation to mark enums as an arithmetic type
6868
struct arithmetic { };
6969

70+
/// A call policy which places a guard variable (of type T) around the function call
71+
template <typename T>
72+
struct call_guard {
73+
static_assert(std::is_default_constructible<T>::value,
74+
"The guard type must be default constructible");
75+
76+
using type = T;
77+
};
78+
79+
template <> struct call_guard<void> { using type = detail::void_type; };
80+
7081
/// @} annotations
7182

7283
NAMESPACE_BEGIN(detail)
@@ -371,6 +382,9 @@ struct process_attribute<metaclass> : process_attribute_default<metaclass> {
371382
template <>
372383
struct process_attribute<arithmetic> : process_attribute_default<arithmetic> {};
373384

385+
template <typename T>
386+
struct process_attribute<call_guard<T>> : process_attribute_default<call_guard<T>> {};
387+
374388
/***
375389
* Process a keep_alive call policy -- invokes keep_alive_impl during the
376390
* pre-call handler if both Nurse, Patient != 0 and use the post-call handler
@@ -407,6 +421,13 @@ template <typename... Args> struct process_attributes {
407421
}
408422
};
409423

424+
template <typename T>
425+
using is_call_guard = is_<call_guard, T>;
426+
427+
/// Extract the `T` from the first `call_guard<T>` in `Extras...` (or `void_type` if none found)
428+
template <typename... Extra>
429+
using extract_call_guard_t = typename first_of_t<is_call_guard, call_guard<void>, Extra...>::type;
430+
410431
/// Check the number of named arguments at compile time
411432
template <typename... Extra,
412433
size_t named = constexpr_sum(std::is_base_of<arg, Extra>::value...),

include/pybind11/cast.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,14 +1378,14 @@ class argument_loader {
13781378
return load_impl_sequence(call, indices{});
13791379
}
13801380

1381-
template <typename Return, typename Func>
1381+
template <typename Return, typename Guard, typename Func>
13821382
enable_if_t<!std::is_void<Return>::value, Return> call(Func &&f) {
1383-
return call_impl<Return>(std::forward<Func>(f), indices{});
1383+
return call_impl<Return>(std::forward<Func>(f), indices{}, Guard{});
13841384
}
13851385

1386-
template <typename Return, typename Func>
1386+
template <typename Return, typename Guard, typename Func>
13871387
enable_if_t<std::is_void<Return>::value, void_type> call(Func &&f) {
1388-
call_impl<Return>(std::forward<Func>(f), indices{});
1388+
call_impl<Return>(std::forward<Func>(f), indices{}, Guard{});
13891389
return void_type();
13901390
}
13911391

@@ -1401,8 +1401,8 @@ class argument_loader {
14011401
return true;
14021402
}
14031403

1404-
template <typename Return, typename Func, size_t... Is>
1405-
Return call_impl(Func &&f, index_sequence<Is...>) {
1404+
template <typename Return, typename Func, size_t... Is, typename Guard>
1405+
Return call_impl(Func &&f, index_sequence<Is...>, Guard &&) {
14061406
return std::forward<Func>(f)(cast_op<Args>(std::get<Is>(value))...);
14071407
}
14081408

include/pybind11/common.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,9 +506,13 @@ using is_template_base_of = decltype(is_template_base_of_impl<Base>::check((remo
506506
struct is_template_base_of : decltype(is_template_base_of_impl<Base>::check((remove_cv_t<T>*)nullptr)) { };
507507
#endif
508508

509+
/// Check if T is an instantiation of the template `Class`. For example:
510+
/// `is_<std::shared_ptr, T>` is true if `T == std::shared_ptr<U>` where U can be anything.
511+
template <template<typename...> class Class, typename T> struct is_ : std::false_type { };
512+
template <template<typename...> class C, typename... Us> struct is_<C, C<Us...>> : std::true_type { };
513+
509514
/// Check if T is std::shared_ptr<U> where U can be anything
510-
template <typename T> struct is_shared_ptr : std::false_type { };
511-
template <typename U> struct is_shared_ptr<std::shared_ptr<U>> : std::true_type { };
515+
template <typename T> using is_shared_ptr = is_<std::shared_ptr, T>;
512516

513517
/// Ignore that a variable is unused in compiler warnings
514518
inline void ignore_unused(const int *) { }

include/pybind11/pybind11.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,11 @@ class cpp_function : public function {
143143
/* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */
144144
const auto policy = detail::return_value_policy_override<Return>::policy(call.func.policy);
145145

146+
/* Function scope guard -- defaults to the compile-to-nothing `void_type` */
147+
using Guard = detail::extract_call_guard_t<Extra...>;
148+
146149
/* Perform the function call */
147-
handle result = cast_out::cast(args_converter.template call<Return>(cap->f),
150+
handle result = cast_out::cast(args_converter.template call<Return, Guard>(cap->f),
148151
policy, call.parent);
149152

150153
/* Invoke call policy post-call hook */

tests/test_keep_alive.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,36 @@ test_initializer keep_alive([](py::module &m) {
3838
py::class_<Child>(m, "Child")
3939
.def(py::init<>());
4040
});
41+
42+
struct CustomGuard {
43+
static bool enabled;
44+
45+
CustomGuard() { enabled = true; }
46+
~CustomGuard() { enabled = false; }
47+
48+
static const char *report_status() { return enabled ? "guarded" : "unguarded"; }
49+
};
50+
51+
bool CustomGuard::enabled = false;
52+
53+
test_initializer call_guard([](py::module &pm) {
54+
auto m = pm.def_submodule("call_policies");
55+
56+
m.def("unguarded_call", &CustomGuard::report_status);
57+
m.def("guarded_call", &CustomGuard::report_status, py::call_guard<CustomGuard>());
58+
59+
#if defined(WITH_THREAD) && !defined(PYPY_VERSION)
60+
// `py::call_guard<py::gil_scoped_release>()` should work in PyPy as well,
61+
// but it's unclear how to test it without `PyGILState_GetThisThreadState`.
62+
auto report_gil_status = []() {
63+
auto is_gil_held = false;
64+
if (auto tstate = py::detail::get_thread_state_unchecked())
65+
is_gil_held = (tstate == PyGILState_GetThisThreadState());
66+
67+
return is_gil_held ? "GIL held" : "GIL released";
68+
};
69+
70+
m.def("with_gil", report_gil_status);
71+
m.def("without_gil", report_gil_status, py::call_guard<py::gil_scoped_release>());
72+
#endif
73+
});

tests/test_keep_alive.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,14 @@ def test_return_none(capture):
9595
del p
9696
pytest.gc_collect()
9797
assert capture == "Releasing parent."
98+
99+
100+
def test_call_guard():
101+
from pybind11_tests import call_policies
102+
103+
assert call_policies.unguarded_call() == "unguarded"
104+
assert call_policies.guarded_call() == "guarded"
105+
106+
if hasattr(call_policies, "with_gil"):
107+
assert call_policies.with_gil() == "GIL held"
108+
assert call_policies.without_gil() == "GIL released"

0 commit comments

Comments
 (0)