Skip to content

Commit d63633f

Browse files
committed
Add support for non-converting arguments
This adds support for controlling the `convert` flag of arguments through the py::arg annotation. This then allows arguments to be flagged as non-converting, which the type_caster is able to use to request different behaviour. Currently, AFAICS `convert` is only used for type converters of regular pybind11-registered types; all of the other core type_casters ignore it. We can, however, repurpose it to control internal conversion of converters like Eigen and `array`: most usefully to give callers a way to disable the conversion that would otherwise occur when a `Eigen::Ref<const Eigen::Matrix>` argument is passed a numpy array that requires conversion (either because it has an incompatible stride or the wrong dtype). Specifying a noconvert looks like one of these: m.def("f1", &f, "a"_a.noconvert() = "default"); // Named, default, noconvert m.def("f2", &f, "a"_a.noconvert()); // Named, no default, no converting m.def("f3", &f, py::arg().noconvert()); // Unnamed, no default, no converting (The last part--being able to declare a py::arg without a name--is new: previous py::arg() only accepted named keyword arguments). Such an non-convert argument is then passed `convert = false` by the type caster when loading the argument. Whether this has an effect is up to the type caster itself, but as mentioned above, this would be extremely helpful for the Eigen support to give a nicer way to specify a "no-copy" mode than the custom wrapper in the current PR, and moreover isn't an Eigen-specific hack.
1 parent 0558a9a commit d63633f

File tree

5 files changed

+190
-52
lines changed

5 files changed

+190
-52
lines changed

include/pybind11/attr.h

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,17 @@ struct undefined_t;
6969
template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined_t> struct op_;
7070
template <typename... Args> struct init;
7171
template <typename... Args> struct init_alias;
72-
struct function_call;
7372
inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret);
7473

7574
/// Internal data structure which holds metadata about a keyword argument
7675
struct argument_record {
7776
const char *name; ///< Argument name
7877
const char *descr; ///< Human-readable version of the argument value
7978
handle value; ///< Associated Python object
79+
bool convert : 1; ///< True if the argument is allowed to convert when loading
8080

81-
argument_record(const char *name, const char *descr, handle value)
82-
: name(name), descr(descr), value(value) { }
81+
argument_record(const char *name, const char *descr, handle value, bool convert)
82+
: name(name), descr(descr), value(value), convert(convert) { }
8383
};
8484

8585
/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.)
@@ -131,7 +131,7 @@ struct function_record {
131131
bool is_method : 1;
132132

133133
/// Number of arguments (including py::args and/or py::kwargs, if present)
134-
uint16_t nargs;
134+
std::uint16_t nargs;
135135

136136
/// Python method object
137137
PyMethodDef *def = nullptr;
@@ -222,21 +222,11 @@ struct type_record {
222222
}
223223
};
224224

225-
/// Internal data associated with a single function call
226-
struct function_call {
227-
function_call(function_record &f, handle p) : func(f), parent(p) {
228-
args.reserve(f.nargs);
229-
}
230-
231-
/// The function data:
232-
const function_record &func;
233-
234-
/// Arguments passed to the function:
235-
std::vector<handle> args;
236-
237-
/// The parent, if any
238-
handle parent;
239-
};
225+
inline function_call::function_call(function_record &f, handle p) :
226+
func(f), parent(p) {
227+
args.reserve(f.nargs);
228+
args_convert.reserve(f.nargs);
229+
}
240230

241231
/**
242232
* Partial template specializations to process custom attributes provided to
@@ -300,16 +290,16 @@ template <> struct process_attribute<is_operator> : process_attribute_default<is
300290
template <> struct process_attribute<arg> : process_attribute_default<arg> {
301291
static void init(const arg &a, function_record *r) {
302292
if (r->is_method && r->args.empty())
303-
r->args.emplace_back("self", nullptr, handle());
304-
r->args.emplace_back(a.name, nullptr, handle());
293+
r->args.emplace_back("self", nullptr, handle(), true /*convert*/);
294+
r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert);
305295
}
306296
};
307297

308298
/// Process a keyword argument attribute (*with* a default value)
309299
template <> struct process_attribute<arg_v> : process_attribute_default<arg_v> {
310300
static void init(const arg_v &a, function_record *r) {
311301
if (r->is_method && r->args.empty())
312-
r->args.emplace_back("self", nullptr, handle());
302+
r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/);
313303

314304
if (!a.value) {
315305
#if !defined(NDEBUG)
@@ -330,7 +320,7 @@ template <> struct process_attribute<arg_v> : process_attribute_default<arg_v> {
330320
"Compile in debug mode for more information.");
331321
#endif
332322
}
333-
r->args.emplace_back(a.name, a.descr, a.value.inc_ref());
323+
r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert);
334324
}
335325
};
336326

include/pybind11/cast.h

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,22 +1198,26 @@ template <return_value_policy policy = return_value_policy::automatic_reference,
11981198
}
11991199

12001200
/// \ingroup annotations
1201-
/// Annotation for keyword arguments
1201+
/// Annotation for arguments
12021202
struct arg {
1203-
/// Set the name of the argument
1204-
constexpr explicit arg(const char *name) : name(name) { }
1203+
/// Constructs an argument with the name of the argument; if null or omitted, this is a positional argument.
1204+
constexpr explicit arg(const char *name = nullptr) : name(name), flag_noconvert(false) { }
12051205
/// Assign a value to this argument
12061206
template <typename T> arg_v operator=(T &&value) const;
1207+
/// Indicate that the type should not be converted in the type caster
1208+
arg &noconvert(bool flag = true) { flag_noconvert = flag; return *this; }
12071209

1208-
const char *name;
1210+
const char *name; ///< If non-null, this is a named kwargs argument
1211+
bool flag_noconvert : 1; ///< If set, do not allow conversion (requires a supporting type caster!)
12091212
};
12101213

12111214
/// \ingroup annotations
1212-
/// Annotation for keyword arguments with values
1215+
/// Annotation for arguments with values
12131216
struct arg_v : arg {
1217+
private:
12141218
template <typename T>
1215-
arg_v(const char *name, T &&x, const char *descr = nullptr)
1216-
: arg(name),
1219+
arg_v(arg &&base, T &&x, const char *descr = nullptr)
1220+
: arg(base),
12171221
value(reinterpret_steal<object>(
12181222
detail::make_caster<T>::cast(x, return_value_policy::automatic, {})
12191223
)),
@@ -1223,15 +1227,32 @@ struct arg_v : arg {
12231227
#endif
12241228
{ }
12251229

1230+
public:
1231+
/// Direct construction with name, default, and description
1232+
template <typename T>
1233+
arg_v(const char *name, T &&x, const char *descr = nullptr)
1234+
: arg_v(arg(name), std::forward<T>(x), descr) { }
1235+
1236+
/// Called internally when invoking `py::arg("a") = value`
1237+
template <typename T>
1238+
arg_v(const arg &base, T &&x, const char *descr = nullptr)
1239+
: arg_v(arg(base), std::forward<T>(x), descr) { }
1240+
1241+
/// Same as `arg::noconvert()`, but returns *this as arg_v&, not arg&
1242+
arg_v &noconvert(bool flag = true) { arg::noconvert(flag); return *this; }
1243+
1244+
/// The default value
12261245
object value;
1246+
/// The (optional) description of the default value
12271247
const char *descr;
12281248
#if !defined(NDEBUG)
1249+
/// The C++ type name of the default value (only available when compiled in debug mode)
12291250
std::string type;
12301251
#endif
12311252
};
12321253

12331254
template <typename T>
1234-
arg_v arg::operator=(T &&value) const { return {name, std::forward<T>(value)}; }
1255+
arg_v arg::operator=(T &&value) const { return {std::move(*this), std::forward<T>(value)}; }
12351256

12361257
/// Alias for backward compatibility -- to be removed in version 2.0
12371258
template <typename /*unused*/> using arg_t = arg_v;
@@ -1248,11 +1269,28 @@ NAMESPACE_BEGIN(detail)
12481269
// forward declaration
12491270
struct function_record;
12501271

1272+
/// Internal data associated with a single function call
1273+
struct function_call {
1274+
function_call(function_record &f, handle p); // Implementation in attr.h
1275+
1276+
/// The function data:
1277+
const function_record &func;
1278+
1279+
/// Arguments passed to the function:
1280+
std::vector<handle> args;
1281+
1282+
/// The `convert` value the arguments should be loaded with
1283+
std::vector<bool> args_convert;
1284+
1285+
/// The parent, if any
1286+
handle parent;
1287+
};
1288+
1289+
12511290
/// Helper class which loads arguments for C++ functions called from Python
12521291
template <typename... Args>
12531292
class argument_loader {
12541293
using indices = make_index_sequence<sizeof...(Args)>;
1255-
using function_arguments = const std::vector<handle> &;
12561294

12571295
template <typename Arg> using argument_is_args = std::is_same<intrinsic_t<Arg>, args>;
12581296
template <typename Arg> using argument_is_kwargs = std::is_same<intrinsic_t<Arg>, kwargs>;
@@ -1270,8 +1308,8 @@ class argument_loader {
12701308

12711309
static PYBIND11_DESCR arg_names() { return detail::concat(make_caster<Args>::name()...); }
12721310

1273-
bool load_args(function_arguments args) {
1274-
return load_impl_sequence(args, indices{});
1311+
bool load_args(function_call &call) {
1312+
return load_impl_sequence(call, indices{});
12751313
}
12761314

12771315
template <typename Return, typename Func>
@@ -1287,11 +1325,11 @@ class argument_loader {
12871325

12881326
private:
12891327

1290-
static bool load_impl_sequence(function_arguments, index_sequence<>) { return true; }
1328+
static bool load_impl_sequence(function_call &, index_sequence<>) { return true; }
12911329

12921330
template <size_t... Is>
1293-
bool load_impl_sequence(function_arguments args, index_sequence<Is...>) {
1294-
for (bool r : {std::get<Is>(value).load(args[Is], true)...})
1331+
bool load_impl_sequence(function_call &call, index_sequence<Is...>) {
1332+
for (bool r : {std::get<Is>(value).load(call.args[Is], call.args_convert[Is])...})
12951333
if (!r)
12961334
return false;
12971335
return true;
@@ -1380,6 +1418,13 @@ class unpacking_collector {
13801418
}
13811419

13821420
void process(list &/*args_list*/, arg_v a) {
1421+
if (!a.name)
1422+
#if defined(NDEBUG)
1423+
nameless_argument_error();
1424+
#else
1425+
nameless_argument_error(a.type);
1426+
#endif
1427+
13831428
if (m_kwargs.contains(a.name)) {
13841429
#if defined(NDEBUG)
13851430
multiple_values_error();
@@ -1412,6 +1457,15 @@ class unpacking_collector {
14121457
}
14131458
}
14141459

1460+
[[noreturn]] static void nameless_argument_error() {
1461+
throw type_error("Got kwargs without a name; only named arguments "
1462+
"may be passed via py::arg() to a python function call. "
1463+
"(compile in debug mode for details)");
1464+
}
1465+
[[noreturn]] static void nameless_argument_error(std::string type) {
1466+
throw type_error("Got kwargs without a name of type '" + type + "'; only named "
1467+
"arguments may be passed via py::arg() to a python function call. ");
1468+
}
14151469
[[noreturn]] static void multiple_values_error() {
14161470
throw type_error("Got multiple values for keyword argument "
14171471
"(compile in debug mode for details)");

include/pybind11/pybind11.h

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ class cpp_function : public function {
122122
cast_in args_converter;
123123

124124
/* Try to cast the function arguments into the C++ domain */
125-
if (!args_converter.load_args(call.args))
125+
if (!args_converter.load_args(call))
126126
return PYBIND11_TRY_NEXT_OVERLOAD;
127127

128128
/* Invoke call policy pre-call hook */
@@ -198,7 +198,7 @@ class cpp_function : public function {
198198
if (c == '{') {
199199
// Write arg name for everything except *args, **kwargs and return type.
200200
if (type_depth == 0 && text[char_index] != '*' && arg_index < args) {
201-
if (!rec->args.empty()) {
201+
if (!rec->args.empty() && rec->args[arg_index].name) {
202202
signature += rec->args[arg_index].name;
203203
} else if (arg_index == 0 && rec->is_method) {
204204
signature += "self";
@@ -257,7 +257,7 @@ class cpp_function : public function {
257257
rec->signature = strdup(signature.c_str());
258258
rec->args.shrink_to_fit();
259259
rec->is_constructor = !strcmp(rec->name, "__init__") || !strcmp(rec->name, "__setstate__");
260-
rec->nargs = (uint16_t) args;
260+
rec->nargs = (std::uint16_t) args;
261261

262262
#if PY_MAJOR_VERSION < 3
263263
if (rec->sibling && PyMethod_Check(rec->sibling.ptr()))
@@ -392,8 +392,10 @@ class cpp_function : public function {
392392

393393
handle parent = n_args_in > 0 ? PyTuple_GET_ITEM(args_in, 0) : nullptr,
394394
result = PYBIND11_TRY_NEXT_OVERLOAD;
395+
395396
try {
396397
for (; it != nullptr; it = it->next) {
398+
397399
/* For each overload:
398400
1. Copy all positional arguments we were given, also checking to make sure that
399401
named positional arguments weren't *also* specified via kwarg.
@@ -435,14 +437,15 @@ class cpp_function : public function {
435437
// raise a TypeError like Python does. (We could also continue with the next
436438
// overload, but this seems highly likely to be a caller mistake rather than a
437439
// legitimate overload).
438-
if (kwargs_in && args_copied < it->args.size()) {
439-
handle value = PyDict_GetItemString(kwargs_in, it->args[args_copied].name);
440+
if (kwargs_in && args_copied < func.args.size() && func.args[args_copied].name) {
441+
handle value = PyDict_GetItemString(kwargs_in, func.args[args_copied].name);
440442
if (value)
441-
throw type_error(std::string(it->name) + "(): got multiple values for argument '" +
442-
std::string(it->args[args_copied].name) + "'");
443+
throw type_error(std::string(func.name) + "(): got multiple values for argument '" +
444+
std::string(func.args[args_copied].name) + "'");
443445
}
444446

445447
call.args.push_back(PyTuple_GET_ITEM(args_in, args_copied));
448+
call.args_convert.push_back(args_copied < func.args.size() ? func.args[args_copied].convert : true);
446449
}
447450

448451
// We'll need to copy this if we steal some kwargs for defaults
@@ -453,10 +456,10 @@ class cpp_function : public function {
453456
bool copied_kwargs = false;
454457

455458
for (; args_copied < pos_args; ++args_copied) {
456-
const auto &arg = it->args[args_copied];
459+
const auto &arg = func.args[args_copied];
457460

458461
handle value;
459-
if (kwargs_in)
462+
if (kwargs_in && arg.name)
460463
value = PyDict_GetItemString(kwargs.ptr(), arg.name);
461464

462465
if (value) {
@@ -470,8 +473,10 @@ class cpp_function : public function {
470473
value = arg.value;
471474
}
472475

473-
if (value)
476+
if (value) {
474477
call.args.push_back(value);
478+
call.args_convert.push_back(arg.convert);
479+
}
475480
else
476481
break;
477482
}
@@ -481,12 +486,12 @@ class cpp_function : public function {
481486
}
482487

483488
// 3. Check everything was consumed (unless we have a kwargs arg)
484-
if (kwargs && kwargs.size() > 0 && !it->has_kwargs)
489+
if (kwargs && kwargs.size() > 0 && !func.has_kwargs)
485490
continue; // Unconsumed kwargs, but no py::kwargs argument to accept them
486491

487492
// 4a. If we have a py::args argument, create a new tuple with leftovers
488493
tuple extra_args;
489-
if (it->has_args) {
494+
if (func.has_args) {
490495
if (args_to_copy == 0) {
491496
// We didn't copy out any position arguments from the args_in tuple, so we
492497
// can reuse it directly without copying:
@@ -502,31 +507,34 @@ class cpp_function : public function {
502507
}
503508
}
504509
call.args.push_back(extra_args);
510+
call.args_convert.push_back(false);
505511
}
506512

507513
// 4b. If we have a py::kwargs, pass on any remaining kwargs
508-
if (it->has_kwargs) {
514+
if (func.has_kwargs) {
509515
if (!kwargs.ptr())
510516
kwargs = dict(); // If we didn't get one, send an empty one
511517
call.args.push_back(kwargs);
518+
call.args_convert.push_back(false);
512519
}
513520

514521
// 5. Put everything in a vector. Not technically step 5, we've been building it
515522
// in `call.args` all along.
516523
#if !defined(NDEBUG)
517-
if (call.args.size() != call.func.nargs)
524+
if (call.args.size() != func.nargs || call.args_convert.size() != func.nargs)
518525
pybind11_fail("Internal error: function call dispatcher inserted wrong number of arguments!");
519526
#endif
520527

521528
// 6. Call the function.
522529
try {
523-
result = it->impl(call);
530+
result = func.impl(call);
524531
} catch (reference_cast_error &) {
525532
result = PYBIND11_TRY_NEXT_OVERLOAD;
526533
}
527534

528535
if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD)
529536
break;
537+
530538
}
531539
} catch (error_already_set &e) {
532540
e.restore();

0 commit comments

Comments
 (0)