Skip to content

Commit b48d4a0

Browse files
committed
Added py::args ref counting tests
1 parent 367d723 commit b48d4a0

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

tests/test_kwargs_and_defaults.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
*/
99

1010
#include "pybind11_tests.h"
11+
#include "constructor_stats.h"
1112
#include <pybind11/stl.h>
1213

1314
TEST_SUBMODULE(kwargs_and_defaults, m) {
@@ -53,6 +54,34 @@ TEST_SUBMODULE(kwargs_and_defaults, m) {
5354
m.def("mixed_plus_args_kwargs_defaults", mixed_plus_both,
5455
py::arg("i") = 1, py::arg("j") = 3.14159);
5556

57+
// test_args_refcount
58+
// PyPy needs a garbage collection to get the reference count values to match CPython's behaviour
59+
#ifdef PYPY_VERSION
60+
#define GC_IF_NEEDED ConstructorStats::gc()
61+
#else
62+
#define GC_IF_NEEDED
63+
#endif
64+
m.def("arg_refcount_h", [](py::handle h) { GC_IF_NEEDED; return h.ref_count(); });
65+
m.def("arg_refcount_h", [](py::handle h, py::handle, py::handle) { GC_IF_NEEDED; return h.ref_count(); });
66+
m.def("arg_refcount_o", [](py::object o) { GC_IF_NEEDED; return o.ref_count(); });
67+
m.def("args_refcount", [](py::args a) {
68+
GC_IF_NEEDED;
69+
py::tuple t(a.size());
70+
for (size_t i = 0; i < a.size(); i++)
71+
// Use raw Python API here to avoid an extra, intermediate incref on the tuple item:
72+
t[i] = (int) Py_REFCNT(PyTuple_GET_ITEM(a.ptr(), static_cast<ssize_t>(i)));
73+
return t;
74+
});
75+
m.def("mixed_args_refcount", [](py::object o, py::args a) {
76+
GC_IF_NEEDED;
77+
py::tuple t(a.size() + 1);
78+
t[0] = o.ref_count();
79+
for (size_t i = 0; i < a.size(); i++)
80+
// Use raw Python API here to avoid an extra, intermediate incref on the tuple item:
81+
t[i + 1] = (int) Py_REFCNT(PyTuple_GET_ITEM(a.ptr(), static_cast<ssize_t>(i)));
82+
return t;
83+
});
84+
5685
// pybind11 won't allow these to be bound: args and kwargs, if present, must be at the end.
5786
// Uncomment these to test that the static_assert is indeed working:
5887
// m.def("bad_args1", [](py::args, int) {});

tests/test_kwargs_and_defaults.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,43 @@ def test_mixed_args_and_kwargs(msg):
105105
106106
Invoked with: 1, 2; kwargs: j=1
107107
""" # noqa: E501 line too long
108+
109+
110+
def test_args_refcount():
111+
"""Issue/PR #1216 - py::args elements get double-inc_ref()ed when combined with regular
112+
arguments"""
113+
refcount = m.arg_refcount_h
114+
115+
myval = 54321
116+
expected = refcount(myval)
117+
assert m.arg_refcount_h(myval) == expected
118+
assert m.arg_refcount_o(myval) == expected + 1
119+
assert m.arg_refcount_h(myval) == expected
120+
assert refcount(myval) == expected
121+
122+
assert m.mixed_plus_args(1, 2.0, "a", myval) == (1, 2.0, ("a", myval))
123+
assert refcount(myval) == expected
124+
125+
assert m.mixed_plus_kwargs(3, 4.0, a=1, b=myval) == (3, 4.0, {"a": 1, "b": myval})
126+
assert refcount(myval) == expected
127+
128+
assert m.args_function(-1, myval) == (-1, myval)
129+
assert refcount(myval) == expected
130+
131+
assert m.mixed_plus_args_kwargs(5, 6.0, myval, a=myval) == (5, 6.0, (myval,), {"a": myval})
132+
assert refcount(myval) == expected
133+
134+
assert m.args_kwargs_function(7, 8, myval, a=1, b=myval) == \
135+
((7, 8, myval), {"a": 1, "b": myval})
136+
assert refcount(myval) == expected
137+
138+
exp3 = refcount(myval, myval, myval)
139+
assert m.args_refcount(myval, myval, myval) == (exp3, exp3, exp3)
140+
assert refcount(myval) == expected
141+
142+
# This function takes the first arg as a `py::object` and the rest as a `py::args`. Unlike the
143+
# previous case, when we have both positional and `py::args` we need to construct a new tuple
144+
# for the `py::args`; in the previous case, we could simply inc_ref and pass on Python's input
145+
# tuple without having to inc_ref the individual elements, but here we can't, hence the extra
146+
# refs.
147+
assert m.mixed_args_refcount(myval, myval, myval) == (exp3 + 3, exp3 + 3, exp3 + 3)

0 commit comments

Comments
 (0)