Skip to content

Commit 80b04cb

Browse files
committed
Add support for positional args with args/kwargs
This commit rewrites the function dispatcher code to support mixing regular arguments with py::args/py::kwargs arguments. It also simplifies the argument loader noticeably as it no longer has to worry about args/kwargs: all of that is now sorted out in the dispatcher, which now simply appends a tuple/dict if the function takes py::args/py::kwargs, then passes all the arguments in a single tuple. Some (intentional) restrictions: - you may not bind a function that has args/kwargs somewhere other than the end (this somewhat matches Python, and keeps the dispatch code a little cleaner by being able to not worry about where to inject the args/kwargs in the argument list). - If you specify an argument both positionally and via a keyword argument, you get a TypeError (as you do in Python).
1 parent 7830e85 commit 80b04cb

File tree

7 files changed

+237
-81
lines changed

7 files changed

+237
-81
lines changed

docs/advanced/functions.rst

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -256,16 +256,21 @@ Such functions can also be created using pybind11:
256256
m.def("generic", &generic);
257257
258258
The class ``py::args`` derives from ``py::tuple`` and ``py::kwargs`` derives
259-
from ``py::dict``. Note that the ``kwargs`` argument is invalid if no keyword
260-
arguments were actually provided. Please refer to the other examples for
261-
details on how to iterate over these, and on how to cast their entries into
262-
C++ objects. A demonstration is also available in
263-
``tests/test_kwargs_and_defaults.cpp``.
259+
from ``py::dict``.
264260

265-
.. warning::
261+
You may also use just one or the other, and may combine these with other
262+
arguments as long as the ``py::args`` and ``py::kwargs`` arguments are the last
263+
arguments accepted by the function.
264+
265+
Please refer to the other examples for details on how to iterate over these,
266+
and on how to cast their entries into C++ objects. A demonstration is also
267+
available in ``tests/test_kwargs_and_defaults.cpp``.
268+
269+
.. note::
266270

267-
Unlike Python, pybind11 does not allow combining normal parameters with the
268-
``args`` / ``kwargs`` special parameters.
271+
When combining \*args or \*\*kwargs with :ref:`keyword_args` you should
272+
*not* include ``py::arg`` tags for the ``py::args`` and ``py::kwargs``
273+
arguments.
269274

270275
Default arguments revisited
271276
===========================

include/pybind11/attr.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ struct function_record {
9595
std::vector<argument_record> args;
9696

9797
/// Pointer to lambda function which converts arguments and performs the actual call
98-
handle (*impl) (function_record *, handle, handle, handle) = nullptr;
98+
handle (*impl) (function_record *, handle, handle) = nullptr;
9999

100100
/// Storage for the wrapped function pointer and captured data, if any
101101
void *data[3] = { };
@@ -124,7 +124,7 @@ struct function_record {
124124
/// True if this is a method
125125
bool is_method : 1;
126126

127-
/// Number of arguments
127+
/// Number of arguments (including py::args and/or py::kwargs, if present)
128128
uint16_t nargs;
129129

130130
/// Python method object
@@ -378,8 +378,8 @@ template <typename... Args> struct process_attributes {
378378
template <typename... Extra,
379379
size_t named = constexpr_sum(std::is_base_of<arg, Extra>::value...),
380380
size_t self = constexpr_sum(std::is_same<is_method, Extra>::value...)>
381-
constexpr bool expected_num_args(size_t nargs) {
382-
return named == 0 || (self + named) == nargs;
381+
constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) {
382+
return named == 0 || (self + named + has_args + has_kwargs) == nargs;
383383
}
384384

385385
NAMESPACE_END(detail)

include/pybind11/cast.h

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,22 +1217,40 @@ constexpr arg operator"" _a(const char *name, size_t) { return arg(name); }
12171217

12181218
NAMESPACE_BEGIN(detail)
12191219

1220+
// forward declaration
1221+
struct function_record;
1222+
1223+
// Helper struct to only allow py::args and/or py::kwargs at the end of the function arguments
1224+
template <bool args, bool kwargs, bool args_kwargs_are_last> struct assert_args_kwargs_must_be_last {
1225+
static constexpr bool has_args = args, has_kwargs = kwargs;
1226+
static_assert(args_kwargs_are_last, "py::args/py::kwargs are only permitted as the last argument(s) of a function");
1227+
};
1228+
template <typename... T> struct args_kwargs_must_be_last;
1229+
template <typename T1, typename... Tmore> struct args_kwargs_must_be_last<T1, Tmore...>
1230+
: args_kwargs_must_be_last<Tmore...> {};
1231+
template <typename... T> struct args_kwargs_must_be_last<args, T...>
1232+
: assert_args_kwargs_must_be_last<true, false, sizeof...(T) == 0> {};
1233+
template <typename... T> struct args_kwargs_must_be_last<kwargs, T...>
1234+
: assert_args_kwargs_must_be_last<false, true, sizeof...(T) == 0> {};
1235+
template <typename... T> struct args_kwargs_must_be_last<args, kwargs, T...>
1236+
: assert_args_kwargs_must_be_last<true, true, sizeof...(T) == 0> {};
1237+
template <> struct args_kwargs_must_be_last<> : assert_args_kwargs_must_be_last<false, false, true> {};
1238+
12201239
/// Helper class which loads arguments for C++ functions called from Python
12211240
template <typename... Args>
12221241
class argument_loader {
1223-
using itypes = type_list<intrinsic_t<Args>...>;
12241242
using indices = make_index_sequence<sizeof...(Args)>;
12251243

1226-
public:
1227-
argument_loader() : value() {} // Helps gcc-7 properly initialize value
1244+
using check_args_kwargs = args_kwargs_must_be_last<intrinsic_t<Args>...>;
12281245

1229-
static constexpr auto has_kwargs = std::is_same<itypes, type_list<args, kwargs>>::value;
1230-
static constexpr auto has_args = has_kwargs || std::is_same<itypes, type_list<args>>::value;
1246+
public:
1247+
static constexpr bool has_kwargs = check_args_kwargs::has_kwargs;
1248+
static constexpr bool has_args = check_args_kwargs::has_args;
12311249

12321250
static PYBIND11_DESCR arg_names() { return detail::concat(make_caster<Args>::name()...); }
12331251

1234-
bool load_args(handle args, handle kwargs) {
1235-
return load_impl(args, kwargs, itypes{});
1252+
bool load_args(handle args) {
1253+
return load_impl_sequence(args, indices{});
12361254
}
12371255

12381256
template <typename Return, typename Func>
@@ -1247,26 +1265,12 @@ class argument_loader {
12471265
}
12481266

12491267
private:
1250-
bool load_impl(handle args_, handle, type_list<args>) {
1251-
std::get<0>(value).load(args_, true);
1252-
return true;
1253-
}
1254-
1255-
bool load_impl(handle args_, handle kwargs_, type_list<args, kwargs>) {
1256-
std::get<0>(value).load(args_, true);
1257-
std::get<1>(value).load(kwargs_, true);
1258-
return true;
1259-
}
1260-
1261-
bool load_impl(handle args, handle, ... /* anything else */) {
1262-
return load_impl_sequence(args, indices{});
1263-
}
12641268

12651269
static bool load_impl_sequence(handle, index_sequence<>) { return true; }
12661270

12671271
template <size_t... Is>
1268-
bool load_impl_sequence(handle src, index_sequence<Is...>) {
1269-
for (bool r : {std::get<Is>(value).load(PyTuple_GET_ITEM(src.ptr(), Is), true)...})
1272+
bool load_impl_sequence(handle args, index_sequence<Is...>) {
1273+
for (bool r : {std::get<Is>(value).load(PyTuple_GET_ITEM(args.ptr(), Is), true)...})
12701274
if (!r)
12711275
return false;
12721276
return true;

include/pybind11/pybind11.h

Lines changed: 123 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ class cpp_function : public function {
8282
/// Special internal constructor for functors, lambda functions, etc.
8383
template <typename Func, typename Return, typename... Args, typename... Extra /*,*/ PYBIND11_NOEXCEPT_TPL_ARG>
8484
void initialize(Func &&f, Return (*)(Args...) PYBIND11_NOEXCEPT_SPECIFIER, const Extra&... extra) {
85-
static_assert(detail::expected_num_args<Extra...>(sizeof...(Args)),
85+
constexpr bool have_args = detail::any_of<std::is_same<args, detail::intrinsic_t<Args>>...>::value,
86+
have_kwargs = detail::any_of<std::is_same<kwargs, detail::intrinsic_t<Args>>...>::value;
87+
static_assert(detail::expected_num_args<Extra...>(sizeof...(Args), have_args, have_kwargs),
8688
"The number of named arguments does not match the function signature");
8789

8890
struct capture { typename std::remove_reference<Func>::type f; };
@@ -117,11 +119,11 @@ class cpp_function : public function {
117119
>;
118120

119121
/* Dispatch code which converts function arguments and performs the actual function call */
120-
rec->impl = [](detail::function_record *rec, handle args, handle kwargs, handle parent) -> handle {
122+
rec->impl = [](detail::function_record *rec, handle args, handle parent) -> handle {
121123
cast_in args_converter;
122124

123125
/* Try to cast the function arguments into the C++ domain */
124-
if (!args_converter.load_args(args, kwargs))
126+
if (!args_converter.load_args(args))
125127
return PYBIND11_TRY_NEXT_OVERLOAD;
126128

127129
/* Invoke call policy pre-call hook */
@@ -379,66 +381,144 @@ class cpp_function : public function {
379381
}
380382

381383
/// Main dispatch logic for calls to functions bound using pybind11
382-
static PyObject *dispatcher(PyObject *self, PyObject *args, PyObject *kwargs) {
384+
static PyObject *dispatcher(PyObject *self, PyObject *args_in, PyObject *kwargs_in) {
383385
/* Iterator over the list of potentially admissible overloads */
384386
detail::function_record *overloads = (detail::function_record *) PyCapsule_GetPointer(self, nullptr),
385387
*it = overloads;
386388

387389
/* Need to know how many arguments + keyword arguments there are to pick the right overload */
388-
size_t nargs = (size_t) PyTuple_GET_SIZE(args),
389-
nkwargs = kwargs ? (size_t) PyDict_Size(kwargs) : 0;
390+
const size_t n_args_in = (size_t) PyTuple_GET_SIZE(args_in);
390391

391-
handle parent = nargs > 0 ? PyTuple_GET_ITEM(args, 0) : nullptr,
392+
handle parent = n_args_in > 0 ? PyTuple_GET_ITEM(args_in, 0) : nullptr,
392393
result = PYBIND11_TRY_NEXT_OVERLOAD;
393394
try {
394395
for (; it != nullptr; it = it->next) {
395-
auto args_ = reinterpret_borrow<tuple>(args);
396-
size_t kwargs_consumed = 0;
397-
398396
/* For each overload:
399-
1. If the required list of arguments is longer than the
400-
actually provided amount, create a copy of the argument
401-
list and fill in any available keyword/default arguments.
402-
2. Ensure that all keyword arguments were "consumed"
403-
3. Call the function call dispatcher (function_record::impl)
397+
1. Copy all positional arguments we were given, also checking to make sure that
398+
named positional arguments weren't *also* specified via kwarg.
399+
2. If we weren't given enough, try to make up the ommitted ones by checking
400+
whether they were provided by a kwarg matching the `py::arg("name")` name. If
401+
so, use it (and remove it from kwargs; if not, see if the function binding
402+
provided a default that we can use.
403+
3. Ensure that either all keyword arguments were "consumed", or that the function
404+
takes a kwargs argument to accept unconsumed kwargs.
405+
4. Any positional arguments still left get put into a tuple (for args), and any
406+
leftover kwargs get put into a dict.
407+
5. Pack everything into a tuple; if we have py::args or py::kwargs, they are an
408+
extra tuple or dict at the end of the positional arguments.
409+
6. Call the function call dispatcher (function_record::impl)
410+
411+
If one of these fail, move on to the next overload and keep trying until we get a
412+
result other than PYBIND11_TRY_NEXT_OVERLOAD.
404413
*/
405-
size_t nargs_ = nargs;
406-
if (nargs < it->args.size()) {
407-
nargs_ = it->args.size();
408-
args_ = tuple(nargs_);
409-
for (size_t i = 0; i < nargs; ++i) {
410-
handle item = PyTuple_GET_ITEM(args, i);
411-
PyTuple_SET_ITEM(args_.ptr(), i, item.inc_ref().ptr());
412-
}
413414

414-
int arg_ctr = 0;
415-
for (auto const &it2 : it->args) {
416-
int index = arg_ctr++;
417-
if (PyTuple_GET_ITEM(args_.ptr(), index))
418-
continue;
415+
size_t pos_args = it->nargs; // Number of positional arguments that we need
416+
if (it->has_args) --pos_args; // (but don't count py::args
417+
if (it->has_kwargs) --pos_args; // or py::kwargs)
419418

420-
handle value;
421-
if (kwargs)
422-
value = PyDict_GetItemString(kwargs, it2.name);
419+
if (!it->has_args && n_args_in > pos_args)
420+
continue; // Too many arguments for this overload
421+
422+
if (n_args_in < pos_args && it->args.size() < pos_args)
423+
continue; // Not enough arguments given, and not enough defaults to fill in the blanks
423424

425+
tuple pass_args(it->nargs);
426+
427+
size_t args_to_copy = std::min(pos_args, n_args_in);
428+
size_t args_copied = 0;
429+
430+
// 1. Copy any position arguments given.
431+
for (; args_copied < args_to_copy; ++args_copied) {
432+
// If we find a given positional argument that also has a named kwargs argument,
433+
// raise a TypeError like Python does. (We could also continue with the next
434+
// overload, but this seems highly likely to be a caller mistake rather than a
435+
// legitimate overload).
436+
if (kwargs_in && args_copied < it->args.size()) {
437+
handle value = PyDict_GetItemString(kwargs_in, it->args[args_copied].name);
424438
if (value)
425-
kwargs_consumed++;
426-
else if (it2.value)
427-
value = it2.value;
439+
throw type_error(std::string(it->name) + "(): got multiple values for argument '" +
440+
std::string(it->args[args_copied].name) + "'");
441+
}
442+
443+
handle item = PyTuple_GET_ITEM(args_in, args_copied);
444+
PyTuple_SET_ITEM(pass_args.ptr(), args_copied, item.inc_ref().ptr());
445+
}
446+
447+
// We'll need to copy this if we steal some kwargs for defaults
448+
dict kwargs = reinterpret_borrow<dict>(kwargs_in);
449+
450+
// 2. Check kwargs and, failing that, defaults that may help complete the list
451+
if (args_copied < pos_args) {
452+
bool copied_kwargs = false;
453+
454+
for (; args_copied < pos_args; ++args_copied) {
455+
const auto &arg = it->args[args_copied];
456+
457+
handle value;
458+
if (kwargs_in)
459+
value = PyDict_GetItemString(kwargs.ptr(), arg.name);
428460

429461
if (value) {
430-
PyTuple_SET_ITEM(args_.ptr(), index, value.inc_ref().ptr());
431-
} else {
432-
kwargs_consumed = (size_t) -1; /* definite failure */
462+
// Consume a kwargs value
463+
if (!copied_kwargs) {
464+
kwargs = reinterpret_steal<dict>(PyDict_Copy(kwargs.ptr()));
465+
copied_kwargs = true;
466+
}
467+
PyDict_DelItemString(kwargs.ptr(), arg.name);
468+
}
469+
else if (arg.value) {
470+
value = arg.value;
471+
}
472+
473+
if (value)
474+
PyTuple_SET_ITEM(pass_args.ptr(), args_copied, value.inc_ref().ptr());
475+
else
433476
break;
477+
}
478+
479+
if (args_copied < pos_args)
480+
continue; // Not enough arguments, defaults, or kwargs to fill the positional arguments
481+
}
482+
483+
// 3. Check everything was consumed (unless we have a kwargs arg)
484+
if (kwargs && kwargs.size() > 0 && !it->has_kwargs)
485+
continue; // Unconsumed kwargs, but no py::kwargs argument to accept them
486+
487+
// 4a. If we have a py::args argument, create a new tuple with leftovers
488+
if (it->has_args) {
489+
tuple extra_args;
490+
if (args_to_copy == 0) {
491+
// We didn't copy out any position arguments from the args_in tuple, so we
492+
// can use it directly:
493+
extra_args = reinterpret_borrow<tuple>(args_in);
494+
}
495+
else if (args_copied >= n_args_in) {
496+
extra_args = tuple(0);
497+
}
498+
else {
499+
size_t args_size = n_args_in - args_copied;
500+
extra_args = tuple(args_size);
501+
for (size_t i = 0; i < args_size; ++i) {
502+
handle item = PyTuple_GET_ITEM(args_in, args_copied + i);
503+
extra_args[i] = item.inc_ref().ptr();
434504
}
435505
}
506+
PyTuple_SET_ITEM(pass_args.ptr(), args_copied++, extra_args.release().ptr());
436507
}
437508

509+
// 4b. If we have a py::kwargs, pass on any remaining kwargs
510+
if (it->has_kwargs) {
511+
if (!kwargs.ptr())
512+
kwargs = dict(); // If we didn't get one, send an empty one
513+
PyTuple_SET_ITEM(pass_args.ptr(), args_copied++, kwargs.inc_ref().ptr());
514+
}
515+
516+
// 5. Put everything in a big tuple. Not technically step 5, we've been building it
517+
// in `pass_args` all along.
518+
519+
// 6. Call the function.
438520
try {
439-
if ((kwargs_consumed == nkwargs || it->has_kwargs) &&
440-
(nargs_ == it->nargs || it->has_args))
441-
result = it->impl(it, args_, kwargs, parent);
521+
result = it->impl(it, pass_args, parent);
442522
} catch (reference_cast_error &) {
443523
result = PYBIND11_TRY_NEXT_OVERLOAD;
444524
}
@@ -512,7 +592,7 @@ class cpp_function : public function {
512592
msg += "\n";
513593
}
514594
msg += "\nInvoked with: ";
515-
auto args_ = reinterpret_borrow<tuple>(args);
595+
auto args_ = reinterpret_borrow<tuple>(args_in);
516596
for (size_t ti = overloads->is_constructor ? 1 : 0; ti < args_.size(); ++ti) {
517597
msg += pybind11::repr(args_[ti]);
518598
if ((ti + 1) != args_.size() )
@@ -530,9 +610,8 @@ class cpp_function : public function {
530610
if (overloads->is_constructor) {
531611
/* When a constructor ran successfully, the corresponding
532612
holder type (e.g. std::unique_ptr) must still be initialized. */
533-
PyObject *inst = PyTuple_GET_ITEM(args, 0);
534-
auto tinfo = detail::get_type_info(Py_TYPE(inst));
535-
tinfo->init_holder(inst, nullptr);
613+
auto tinfo = detail::get_type_info(Py_TYPE(parent.ptr()));
614+
tinfo->init_holder(parent.ptr(), nullptr);
536615
}
537616
return result.ptr();
538617
}

include/pybind11/pytypes.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ class handle : public detail::object_api<handle> {
8585

8686
PyObject *ptr() const { return m_ptr; }
8787
PyObject *&ptr() { return m_ptr; }
88-
const handle& inc_ref() const { Py_XINCREF(m_ptr); return *this; }
89-
const handle& dec_ref() const { Py_XDECREF(m_ptr); return *this; }
88+
const handle& inc_ref() const & { Py_XINCREF(m_ptr); return *this; }
89+
const handle& dec_ref() const & { Py_XDECREF(m_ptr); return *this; }
9090

9191
template <typename T> T cast() const;
9292
explicit operator bool() const { return m_ptr != nullptr; }

0 commit comments

Comments
 (0)