Skip to content

Commit cf0dcce

Browse files
committed
Allow binding factory functions as constructors
This allows you to use: cls.def_static("__init__", &factory_function); where `factory_function` is some pointer or holder-generating factory function of the type that `cls` binds. Internally, this still results in a method, but we handle the C++-static-factory <-> python method translation by calling the factory function, then stealing the resulting object internal pointer and holder.
1 parent 2d14c1c commit cf0dcce

File tree

7 files changed

+337
-34
lines changed

7 files changed

+337
-34
lines changed

docs/advanced/classes.rst

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,30 @@ In other words, :func:`init` creates an anonymous function that invokes an
366366
in-place constructor. Memory allocation etc. is already take care of beforehand
367367
within pybind11.
368368

369+
It is also possible to bind C++ factory functions as Python constructors
370+
(instead of or in addition to standard constructors) by using ``def_static``:
371+
372+
.. code-block:: cpp
373+
374+
py::class_<Example>(m, "Example")
375+
// Bind an existing factory function which returns a new Example pointer:
376+
.def_static("__init__", &Example::create)
377+
// Similar, but returns using an existing holder:
378+
.def_static("__init__", []() {
379+
return std::unique_ptr<Example>(new Example(arg, "another arg"));
380+
})
381+
// Can mix/overload these with regular constructors, too:
382+
.def(py::init<double>())
383+
;
384+
385+
.. note::
386+
387+
Unlike other named functions declared using ``def_static``, these
388+
constructors are *not* static on the Python side, only on the C++ side.
389+
Pybind internally converts the (non-static) Python constructor to a static
390+
C++ call when one of these constructors is invoked, then sets up the Python
391+
instance appropriately.
392+
369393
.. _classes_with_non_public_destructors:
370394

371395
Non-public destructors

include/pybind11/attr.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ struct argument_record {
131131
/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.)
132132
struct function_record {
133133
function_record()
134-
: is_constructor(false), is_stateless(false), is_operator(false),
135-
has_args(false), has_kwargs(false), is_method(false) { }
134+
: is_constructor(false), is_factory_constructor(false), is_stateless(false),
135+
is_operator(false), has_args(false), has_kwargs(false), is_method(false) { }
136136

137137
/// Function name
138138
char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
@@ -161,6 +161,9 @@ struct function_record {
161161
/// True if name == '__init__'
162162
bool is_constructor : 1;
163163

164+
/// True if name == '__init__' and this was a `def_static` (factory function exposed as a constructor)
165+
bool is_factory_constructor : 1;
166+
164167
/// True if this is a stateless function pointer
165168
bool is_stateless : 1;
166169

@@ -216,7 +219,7 @@ struct type_record {
216219
void *(*operator_new)(size_t) = ::operator new;
217220

218221
/// Function pointer to class_<..>::init_holder
219-
void (*init_holder)(PyObject *, const void *) = nullptr;
222+
void (*init_holder)(PyObject *, const void *, PyObject *) = nullptr;
220223

221224
/// Function pointer to class_<..>::dealloc
222225
void (*dealloc)(PyObject *) = nullptr;

include/pybind11/cast.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct type_info {
2626
PyTypeObject *type;
2727
size_t type_size;
2828
void *(*operator_new)(size_t);
29-
void (*init_holder)(PyObject *, const void *);
29+
void (*init_holder)(PyObject *, const void *, PyObject *);
3030
void (*dealloc)(PyObject *);
3131
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
3232
std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
@@ -352,7 +352,7 @@ class type_caster_generic {
352352
throw cast_error("unhandled return_value_policy: should not happen!");
353353
}
354354

355-
tinfo->init_holder(inst.ptr(), existing_holder);
355+
tinfo->init_holder(inst.ptr(), existing_holder, nullptr);
356356

357357
internals.registered_instances.emplace(wrapper->value, inst.ptr());
358358

include/pybind11/class_support.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,36 @@ extern "C" inline void pybind11_object_dealloc(PyObject *self) {
244244
Py_TYPE(self)->tp_free(self);
245245
}
246246

247+
/// Swaps the pybind internals of one instance into another. `a` and `b` must be allocated,
248+
/// registered instances of the same type. The caller should typically ensure that the instances are
249+
/// uniquely referenced (foreign references are cannot be updated).
250+
///
251+
/// The holder transfer must be done separately after the call.
252+
inline void instance_swap(instance_essentials<void> *from, instance_essentials<void> *to) {
253+
auto type = Py_TYPE(to);
254+
if (type != Py_TYPE(from))
255+
pybind11_fail("instance_swap(): Cannot swap instances of different types");
256+
if (!to->value || !from->value)
257+
pybind11_fail("instance_swap(): Cannot swap unallocated instances");
258+
259+
auto &registered_instances = get_internals().registered_instances;
260+
std::pair<const void *const, void *> *to_reg = nullptr, *from_reg = nullptr;
261+
auto range = registered_instances.equal_range(to->value);
262+
for (auto it = range.first; it != range.second; ++it)
263+
if (type == Py_TYPE(it->second)) { to_reg = &*it; break; }
264+
range = registered_instances.equal_range(from->value);
265+
for (auto it = range.first; it != range.second; ++it)
266+
if (type == Py_TYPE(it->second)) { from_reg = &*it; break; }
267+
if (!to_reg || !from_reg)
268+
pybind11_fail("instance_swap(): Cannot swap unregistered instances");
269+
270+
std::swap(to->value, from->value);
271+
std::swap(to->weakrefs, from->weakrefs);
272+
if (type->tp_dictoffset != 0)
273+
std::swap(*_PyObject_GetDictPtr((PyObject *) to), *_PyObject_GetDictPtr((PyObject *) from));
274+
std::swap(to_reg->second, from_reg->second);
275+
}
276+
247277
/** Create a type which can be used as a common base for all classes with the same
248278
instance size, i.e. all classes with the same `sizeof(holder_type)`. This is
249279
needed in order to satisfy Python's requirements for multiple inheritance.

include/pybind11/pybind11.h

Lines changed: 99 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,13 @@ class cpp_function : public function {
268268
rec->is_constructor = !strcmp(rec->name, "__init__") || !strcmp(rec->name, "__setstate__");
269269
rec->nargs = (std::uint16_t) args;
270270

271+
if (rec->is_constructor && !rec->is_method) {
272+
// Handle a def_static(__init__) constructor: we expose this to Python as an instance
273+
// method, then deal with the static call internally.
274+
rec->is_factory_constructor = true;
275+
rec->is_method = true;
276+
}
277+
271278
#if PY_MAJOR_VERSION < 3
272279
if (rec->sibling && PyMethod_Check(rec->sibling.ptr()))
273280
rec->sibling = PyMethod_GET_FUNCTION(rec->sibling.ptr());
@@ -399,8 +406,10 @@ class cpp_function : public function {
399406
using namespace detail;
400407

401408
/* Iterator over the list of potentially admissible overloads */
402-
function_record *overloads = (function_record *) PyCapsule_GetPointer(self, nullptr),
403-
*it = overloads;
409+
auto func_capsule = reinterpret_borrow<capsule>(self);
410+
function_record *overloads = func_capsule,
411+
*it = overloads,
412+
*winner; // Stores the one we actually use
404413

405414
/* Need to know how many arguments + keyword arguments there are to pick the right overload */
406415
const size_t n_args_in = (size_t) PyTuple_GET_SIZE(args_in);
@@ -444,15 +453,21 @@ class cpp_function : public function {
444453
if (func.has_args) --pos_args; // (but don't count py::args
445454
if (func.has_kwargs) --pos_args; // or py::kwargs)
446455

447-
if (!func.has_args && n_args_in > pos_args)
456+
// If this overload is a factory function masquerading as a constructor we need to
457+
// skip the initial (uninitialized) self argument.
458+
bool skip_first = func.is_factory_constructor;
459+
460+
const size_t n_args = n_args_in - skip_first;
461+
462+
if (!func.has_args && n_args > pos_args)
448463
continue; // Too many arguments for this overload
449464

450-
if (n_args_in < pos_args && func.args.size() < pos_args)
465+
if (n_args < pos_args && func.args.size() < pos_args)
451466
continue; // Not enough arguments given, and not enough defaults to fill in the blanks
452467

453468
function_call call(func, parent);
454469

455-
size_t args_to_copy = std::min(pos_args, n_args_in);
470+
size_t args_to_copy = std::min(pos_args, n_args);
456471
size_t args_copied = 0;
457472

458473
// 1. Copy any position arguments given.
@@ -464,7 +479,7 @@ class cpp_function : public function {
464479
break;
465480
}
466481

467-
call.args.push_back(PyTuple_GET_ITEM(args_in, args_copied));
482+
call.args.push_back(PyTuple_GET_ITEM(args_in, args_copied + skip_first));
468483
call.args_convert.push_back(args_copied < func.args.size() ? func.args[args_copied].convert : true);
469484
}
470485
if (bad_kwarg)
@@ -524,7 +539,7 @@ class cpp_function : public function {
524539
size_t args_size = n_args_in - args_copied;
525540
extra_args = tuple(args_size);
526541
for (size_t i = 0; i < args_size; ++i) {
527-
handle item = PyTuple_GET_ITEM(args_in, args_copied + i);
542+
handle item = PyTuple_GET_ITEM(args_in, args_copied + i + skip_first);
528543
extra_args[i] = item.inc_ref().ptr();
529544
}
530545
}
@@ -563,8 +578,10 @@ class cpp_function : public function {
563578
result = PYBIND11_TRY_NEXT_OVERLOAD;
564579
}
565580

566-
if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD)
581+
if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD) {
582+
winner = &func;
567583
break;
584+
}
568585

569586
if (overloaded) {
570587
// The (overloaded) call failed; if the call has at least one argument that
@@ -591,8 +608,10 @@ class cpp_function : public function {
591608
result = PYBIND11_TRY_NEXT_OVERLOAD;
592609
}
593610

594-
if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD)
611+
if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD) {
612+
winner = const_cast<function_record *>(&call.func);
595613
break;
614+
}
596615
}
597616
}
598617
} catch (error_already_set &e) {
@@ -638,20 +657,32 @@ class cpp_function : public function {
638657
msg += " "+ std::to_string(++ctr) + ". ";
639658

640659
bool wrote_sig = false;
641-
if (overloads->is_constructor) {
642-
// For a constructor, rewrite `(self: Object, arg0, ...) -> NoneType` as `Object(arg0, ...)`
660+
if (it2->is_constructor) {
643661
std::string sig = it2->signature;
644-
size_t start = sig.find('(') + 7; // skip "(self: "
645-
if (start < sig.size()) {
646-
// End at the , for the next argument
647-
size_t end = sig.find(", "), next = end + 2;
662+
if (!it2->is_factory_constructor) {
663+
// For a constructor, rewrite `(self: Object, arg0, ...) -> NoneType` as `Object(arg0, ...)`
664+
size_t start = sig.find('(') + 7; // skip "(self: "
665+
if (start < sig.size()) {
666+
// End at the , for the next argument
667+
size_t end = sig.find(", "), next = end + 2;
668+
size_t ret = sig.rfind(" -> ");
669+
// Or the ), if there is no comma:
670+
if (end >= sig.size()) next = end = sig.find(')');
671+
if (start < end && next < sig.size()) {
672+
msg.append(sig, start, end - start);
673+
msg += '(';
674+
msg.append(sig, next, ret - next);
675+
wrote_sig = true;
676+
}
677+
}
678+
}
679+
else {
680+
// A factory function masquerading as a constructor; rewrite
681+
// `(arg0: whatever...) -> ClassName` as `ClassName(arg0: whatever...)`
648682
size_t ret = sig.rfind(" -> ");
649-
// Or the ), if there is no comma:
650-
if (end >= sig.size()) next = end = sig.find(')');
651-
if (start < end && next < sig.size()) {
652-
msg.append(sig, start, end - start);
653-
msg += '(';
654-
msg.append(sig, next, ret - next);
683+
if (ret < sig.size()) {
684+
msg.append(sig, ret + 4, sig.npos);
685+
msg.append(sig, 0, ret);
655686
wrote_sig = true;
656687
}
657688
}
@@ -690,12 +721,46 @@ class cpp_function : public function {
690721
msg += it->signature;
691722
PyErr_SetString(PyExc_TypeError, msg.c_str());
692723
return nullptr;
693-
} else {
694-
if (overloads->is_constructor) {
695-
/* When a constructor ran successfully, the corresponding
696-
holder type (e.g. std::unique_ptr) must still be initialized. */
724+
} else { // Call succeeded
725+
if (winner->is_constructor) {
697726
auto tinfo = get_type_info(Py_TYPE(parent.ptr()));
698-
tinfo->init_holder(parent.ptr(), nullptr);
727+
if (!winner->is_factory_constructor) {
728+
/* When an ordinary constructor ran successfully, the corresponding
729+
holder type (e.g. std::unique_ptr) must still be initialized. */
730+
tinfo->init_holder(parent.ptr(), nullptr, nullptr);
731+
}
732+
else {
733+
/* For a factory function exposed as a constructor, the corresponding pointer
734+
and holder must be transferred from the returned object into the allocated
735+
instance */
736+
auto *result_inst = (detail::instance_essentials<void> *) result.ptr(),
737+
*parent_inst = (detail::instance_essentials<void> *) parent.ptr();
738+
std::string failure;
739+
// Make sure the factory function gave us exactly the right type:
740+
if (Py_TYPE(result.ptr()) != tinfo->type)
741+
failure = std::string("static __init__() should return '") + tinfo->type->tp_name +
742+
"', not '" + Py_TYPE(result.ptr())->tp_name + "'";
743+
// The factory function must give back a unique reference:
744+
else if (result.ref_count() != 1)
745+
failure = "static __init__() returned an object with multiple references";
746+
// Guard against accidentally specifying a reference r.v. policy or similar:
747+
else if (!result_inst->holder_constructed && !result_inst->owned)
748+
failure = "static __init__() failed: cannot construct from an unowned reference";
749+
750+
if (!failure.empty()) {
751+
result.dec_ref();
752+
PyErr_SetString(PyExc_TypeError, failure.c_str());
753+
return nullptr;
754+
}
755+
756+
// Swap the pointer and other internals, then transfer the holder:
757+
detail::instance_swap(result_inst, parent_inst);
758+
tinfo->init_holder(parent.ptr(), nullptr, result.ptr());
759+
// We transfered the value out of result, so let it be destroyed:
760+
result.dec_ref();
761+
762+
result = none().release();
763+
}
699764
}
700765
return result.ptr();
701766
}
@@ -1131,10 +1196,15 @@ class class_ : public detail::generic_type {
11311196
}
11321197
}
11331198

1134-
/// Initialize holder object of an instance, possibly given a pointer to an existing holder
1135-
static void init_holder(PyObject *inst_, const void *holder_ptr) {
1199+
/// Initialize holder object of an instance, possibly given a pointer to an existing holder or
1200+
/// an alternative instance to transfer the holder from
1201+
static void init_holder(PyObject *inst_, const void *holder_in, PyObject *holder_from) {
11361202
auto inst = (instance_type *) inst_;
1137-
init_holder_helper(inst, (const holder_type *) holder_ptr, inst->value);
1203+
const holder_type *holder_ptr = (const holder_type *) holder_in;
1204+
if (!holder_ptr && holder_from)
1205+
holder_ptr = &((instance_type *) holder_from)->holder;
1206+
1207+
init_holder_helper(inst, holder_ptr, inst->value);
11381208
}
11391209

11401210
static void dealloc(PyObject *inst_) {

0 commit comments

Comments
 (0)