Skip to content

Commit b4f767b

Browse files
committed
Also use PyObjectTypeIsConvertibleToStdVector() in array_caster.
1 parent 3631886 commit b4f767b

File tree

3 files changed

+29
-10
lines changed

3 files changed

+29
-10
lines changed

include/pybind11/stl.h

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,8 @@ struct array_caster {
324324
return size == Size;
325325
}
326326

327-
public:
328-
bool load(handle src, bool convert) {
329-
if (!isinstance<sequence>(src)) {
330-
return false;
331-
}
332-
auto l = reinterpret_borrow<sequence>(src);
327+
bool convert_elements(handle seq, bool convert) {
328+
auto l = reinterpret_borrow<sequence>(seq);
333329
if (!require_size(l.size())) {
334330
return false;
335331
}
@@ -344,6 +340,25 @@ struct array_caster {
344340
return true;
345341
}
346342

343+
public:
344+
bool load(handle src, bool convert) {
345+
if (!PyObjectTypeIsConvertibleToStdVector(src.ptr())) {
346+
return false;
347+
}
348+
if (isinstance<sequence>(src)) {
349+
return convert_elements(src, convert);
350+
}
351+
if (!convert) {
352+
return false;
353+
}
354+
// Designed to be behavior-equivalent to passing tuple(src) from Python:
355+
// The conversion to a tuple will first exhaust the generator object, to ensure that
356+
// the generator is not left in an unpredictable (to the caller) partially-consumed
357+
// state.
358+
assert(isinstance<iterable>(src));
359+
return convert_elements(tuple(reinterpret_borrow<iterable>(src)), convert);
360+
}
361+
347362
template <typename T>
348363
static handle cast(T &&src, return_value_policy policy, handle parent) {
349364
list l(src.size());

tests/test_stl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,8 @@ TEST_SUBMODULE(stl, m) {
553553
m.def("pass_std_vector_pair_int", [](const std::vector<std::pair<int, int>> &vec_pair_int) {
554554
return vec_pair_int.size();
555555
});
556+
m.def("pass_std_array_int_2",
557+
[](const std::array<int, 2> &arr_int) { return arr_int.size(); });
556558
m.def("pass_std_set_int", [](const std::set<int> &set_int) { return set_int.size(); });
557559
m.def("pass_std_map_int", [](const std::map<int, int> &map_int) { return map_int.size(); });
558560
}

tests/test_stl.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -381,15 +381,17 @@ def test_return_vector_bool_raw_ptr():
381381
assert len(v) == 4513
382382

383383

384-
def test_pass_std_vector_int():
385-
fn = m.pass_std_vector_int
384+
@pytest.mark.parametrize("fn", [m.pass_std_vector_int, m.pass_std_array_int_2])
385+
def test_pass_std_vector_int(fn):
386386
assert fn([1, 2]) == 2
387387
assert fn((1, 2)) == 2
388388
assert fn({1, 2}) == 2
389389
assert fn({"x": 1, "y": 2}.values()) == 2
390390
assert fn({1: None, 2: None}.keys()) == 2
391-
assert fn(i for i in range(3)) == 3
392-
assert fn(map(lambda i: i, range(4))) == 4 # noqa: C417
391+
assert fn(i for i in range(2)) == 2
392+
assert fn(map(lambda i: i, range(2))) == 2 # noqa: C417
393+
with pytest.raises(TypeError):
394+
fn({1: 2, 3: 4})
393395
with pytest.raises(TypeError):
394396
fn({})
395397

0 commit comments

Comments
 (0)