Skip to content

Commit 94d0a9f

Browse files
committed
Improve constructor resolution in variant_caster
Currently, `py::int_(1).cast<variant<double, int>>()` fills the `double` slot of the variant. This commit switches the loader to a 2-pass scheme in order to correctly fill the `int` slot.
1 parent 93e3eac commit 94d0a9f

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

include/pybind11/stl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,12 @@ struct variant_caster<V<Ts...>> {
315315
bool load_alternative(handle, bool, type_list<>) { return false; }
316316

317317
bool load(handle src, bool convert) {
318+
// Do a first pass without conversions to improve constructor resolution.
319+
// E.g. `py::int_(1).cast<variant<double, int>>()` needs to fill the `int`
320+
// slot of the variant. Without two-pass loading `double` would be filled
321+
// because it appears first and a conversion is possible.
322+
if (convert && load_alternative(src, false, type_list<Ts...>{}))
323+
return true;
318324
return load_alternative(src, convert, type_list<Ts...>{});
319325
}
320326

tests/test_python_types.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,10 @@ test_initializer python_types([](py::module &m) {
366366
return std::visit(visitor(), v);
367367
});
368368

369+
m.def("load_variant_2pass", [](std::variant<double, int> v) {
370+
return std::visit(visitor(), v);
371+
});
372+
369373
m.def("cast_variant", []() {
370374
using V = std::variant<int, std::string>;
371375
return py::make_tuple(V(5), V("Hello"));

tests/test_python_types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,12 +373,16 @@ def test_exp_optional():
373373

374374
@pytest.mark.skipif(not hasattr(pybind11_tests, "load_variant"), reason='no <variant>')
375375
def test_variant(doc):
376-
from pybind11_tests import load_variant, cast_variant
376+
from pybind11_tests import load_variant, load_variant_2pass, cast_variant
377377

378378
assert load_variant(1) == "int"
379379
assert load_variant("1") == "std::string"
380380
assert load_variant(1.0) == "double"
381381
assert load_variant(None) == "std::nullptr_t"
382+
383+
assert load_variant_2pass(1) == "int"
384+
assert load_variant_2pass(1.0) == "double"
385+
382386
assert cast_variant() == (5, "Hello")
383387

384388
assert doc(load_variant) == "load_variant(arg0: Union[int, str, float, None]) -> str"

0 commit comments

Comments
 (0)