Skip to content

Commit 53d2164

Browse files
test_numpy_dtypes: Add test for py::vectorize()
1 parent 8c0cd94 commit 53d2164

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

tests/test_numpy_dtypes.cpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -462,10 +462,16 @@ TEST_SUBMODULE(numpy_dtypes, m) {
462462
m.def("buffer_to_dtype", [](py::buffer& buf) { return py::dtype(buf.request()); });
463463

464464
// test_scalar_conversion
465-
m.def("f_simple", [](SimpleStruct s) { return s.uint_ * 10; });
465+
auto f_simple = [](SimpleStruct s) { return s.uint_ * 10; };
466+
m.def("f_simple", f_simple);
466467
m.def("f_packed", [](PackedStruct s) { return s.uint_ * 10; });
467468
m.def("f_nested", [](NestedStruct s) { return s.a.uint_ * 10; });
468469

470+
// test_vectorize
471+
m.def("f_simple_vectorized", py::vectorize(f_simple));
472+
auto f_simple_pass_thru = [](SimpleStruct s) { return s; };
473+
m.def("f_simple_pass_thru_vectorized", py::vectorize(f_simple_pass_thru));
474+
469475
// test_register_dtype
470476
m.def("register_dtype", []() { PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_); });
471477

tests/test_numpy_dtypes.py

+9
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,15 @@ def test_scalar_conversion():
287287
assert 'incompatible function arguments' in str(excinfo.value)
288288

289289

290+
def test_vectorize():
291+
n = 3
292+
array = m.create_rec_simple(n)
293+
values = m.f_simple_vectorized(array)
294+
np.testing.assert_array_equal(values, [0, 10, 20])
295+
array_2 = m.f_simple_pass_thru_vectorized(array)
296+
np.testing.assert_array_equal(array, array_2)
297+
298+
290299
def test_register_dtype():
291300
with pytest.raises(RuntimeError) as excinfo:
292301
m.register_dtype()

0 commit comments

Comments
 (0)