diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 4102152447..8952cf0943 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1328,7 +1328,7 @@ struct handle_type_name { }; template <> struct handle_type_name { - static constexpr auto name = const_name("Buffer"); + static constexpr auto name = const_name("collections.abc.Buffer"); }; template <> struct handle_type_name { diff --git a/include/pybind11/stl.h b/include/pybind11/stl.h index 6a148e7402..4258f1584c 100644 --- a/include/pybind11/stl.h +++ b/include/pybind11/stl.h @@ -43,7 +43,7 @@ PYBIND11_NAMESPACE_BEGIN(detail) // Begin: Equivalent of // https://github.com/google/clif/blob/ae4eee1de07cdf115c0c9bf9fec9ff28efce6f6c/clif/python/runtime.cc#L388-L438 /* -The three `PyObjectTypeIsConvertibleTo*()` functions below are +The three `object_is_convertible_to_*()` functions below are the result of converging the behaviors of pybind11 and PyCLIF (http://github.com/google/clif). @@ -69,10 +69,13 @@ to prevent accidents and improve readability: are also fairly commonly used, therefore enforcing explicit conversions would have an unfavorable cost : benefit ratio; more sloppily speaking, such an enforcement would be more annoying than helpful. + +Additional checks have been added to allow types derived from `collections.abc.Set` and +`collections.abc.Mapping` (`collections.abc.Sequence` is already allowed by `PySequence_Check`). 
*/ -inline bool PyObjectIsInstanceWithOneOfTpNames(PyObject *obj, - std::initializer_list tp_names) { +inline bool object_is_instance_with_one_of_tp_names(PyObject *obj, + std::initializer_list tp_names) { if (PyType_Check(obj)) { return false; } @@ -85,37 +88,48 @@ inline bool PyObjectIsInstanceWithOneOfTpNames(PyObject *obj, return false; } -inline bool PyObjectTypeIsConvertibleToStdVector(PyObject *obj) { - if (PySequence_Check(obj) != 0) { - return !PyUnicode_Check(obj) && !PyBytes_Check(obj); +inline bool object_is_convertible_to_std_vector(const handle &src) { + // Allow sequence-like objects, but not (byte-)string-like objects. + if (PySequence_Check(src.ptr()) != 0) { + return !PyUnicode_Check(src.ptr()) && !PyBytes_Check(src.ptr()); } - return (PyGen_Check(obj) != 0) || (PyAnySet_Check(obj) != 0) - || PyObjectIsInstanceWithOneOfTpNames( - obj, {"dict_keys", "dict_values", "dict_items", "map", "zip"}); + // Allow generators, set/frozenset and several common iterable types. + return (PyGen_Check(src.ptr()) != 0) || (PyAnySet_Check(src.ptr()) != 0) + || object_is_instance_with_one_of_tp_names( + src.ptr(), {"dict_keys", "dict_values", "dict_items", "map", "zip"}); } -inline bool PyObjectTypeIsConvertibleToStdSet(PyObject *obj) { - return (PyAnySet_Check(obj) != 0) || PyObjectIsInstanceWithOneOfTpNames(obj, {"dict_keys"}); +inline bool object_is_convertible_to_std_set(const handle &src, bool convert) { + // Allow set/frozenset and dict keys. + // In convert mode: also allow types derived from collections.abc.Set. + return ((PyAnySet_Check(src.ptr()) != 0) + || object_is_instance_with_one_of_tp_names(src.ptr(), {"dict_keys"})) + || (convert && isinstance(src, module_::import("collections.abc").attr("Set"))); } -inline bool PyObjectTypeIsConvertibleToStdMap(PyObject *obj) { - if (PyDict_Check(obj)) { +inline bool object_is_convertible_to_std_map(const handle &src, bool convert) { + // Allow dict. 
+ if (PyDict_Check(src.ptr())) { return true; } - // Implicit requirement in the conditions below: - // A type with `.__getitem__()` & `.items()` methods must implement these - // to be compatible with https://docs.python.org/3/c-api/mapping.html - if (PyMapping_Check(obj) == 0) { - return false; - } - PyObject *items = PyObject_GetAttrString(obj, "items"); - if (items == nullptr) { - PyErr_Clear(); - return false; + // Allow types conforming to Mapping Protocol. + // According to https://docs.python.org/3/c-api/mapping.html, `PyMapping_Check()` checks for + // `__getitem__()` without checking the type of keys. In order to restrict the allowed types + // closer to actual Mapping-like types, we also check for the `items()` method. + if (PyMapping_Check(src.ptr()) != 0) { + PyObject *items = PyObject_GetAttrString(src.ptr(), "items"); + if (items != nullptr) { + bool is_convertible = (PyCallable_Check(items) != 0); + Py_DECREF(items); + if (is_convertible) { + return true; + } + } else { + PyErr_Clear(); + } } - bool is_convertible = (PyCallable_Check(items) != 0); - Py_DECREF(items); - return is_convertible; + // In convert mode: Allow types derived from collections.abc.Mapping + return convert && isinstance(src, module_::import("collections.abc").attr("Mapping")); } // @@ -172,7 +186,7 @@ struct set_caster { public: bool load(handle src, bool convert) { - if (!PyObjectTypeIsConvertibleToStdSet(src.ptr())) { + if (!object_is_convertible_to_std_set(src, convert)) { return false; } if (isinstance(src)) { @@ -203,7 +217,9 @@ struct set_caster { return s.release(); } - PYBIND11_TYPE_CASTER(type, const_name("set[") + key_conv::name + const_name("]")); + PYBIND11_TYPE_CASTER(type, + io_name("collections.abc.Set", "set") + const_name("[") + key_conv::name + + const_name("]")); }; template @@ -234,7 +250,7 @@ struct map_caster { public: bool load(handle src, bool convert) { - if (!PyObjectTypeIsConvertibleToStdMap(src.ptr())) { + if (!object_is_convertible_to_std_map(src, 
convert)) { return false; } if (isinstance(src)) { @@ -274,7 +290,8 @@ struct map_caster { } PYBIND11_TYPE_CASTER(Type, - const_name("dict[") + key_conv::name + const_name(", ") + value_conv::name + io_name("collections.abc.Mapping", "dict") + const_name("[") + + key_conv::name + const_name(", ") + value_conv::name + const_name("]")); }; @@ -283,7 +300,7 @@ struct list_caster { using value_conv = make_caster; bool load(handle src, bool convert) { - if (!PyObjectTypeIsConvertibleToStdVector(src.ptr())) { + if (!object_is_convertible_to_std_vector(src)) { return false; } if (isinstance(src)) { @@ -340,7 +357,9 @@ struct list_caster { return l.release(); } - PYBIND11_TYPE_CASTER(Type, const_name("list[") + value_conv::name + const_name("]")); + PYBIND11_TYPE_CASTER(Type, + io_name("collections.abc.Sequence", "list") + const_name("[") + + value_conv::name + const_name("]")); }; template @@ -416,7 +435,7 @@ struct array_caster { public: bool load(handle src, bool convert) { - if (!PyObjectTypeIsConvertibleToStdVector(src.ptr())) { + if (!object_is_convertible_to_std_vector(src)) { return false; } if (isinstance(src)) { @@ -474,10 +493,12 @@ struct array_caster { using cast_op_type = movable_cast_op_type; static constexpr auto name - = const_name(const_name(""), const_name("Annotated[")) + const_name("list[") - + value_conv::name + const_name("]") - + const_name( - const_name(""), const_name(", FixedSize(") + const_name() + const_name(")]")); + = const_name(const_name(""), const_name("typing.Annotated[")) + + io_name("collections.abc.Sequence", "list") + const_name("[") + value_conv::name + + const_name("]") + + const_name(const_name(""), + const_name(", \"FixedSize(") + const_name() + + const_name(")\"]")); }; template diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 2612edb270..d335b71e96 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -230,7 +230,7 @@ def test_ctypes_from_buffer(): def test_buffer_docstring(): assert ( 
m.get_buffer_info.__doc__.strip() - == "get_buffer_info(arg0: Buffer) -> pybind11_tests.buffers.buffer_info" + == "get_buffer_info(arg0: collections.abc.Buffer) -> pybind11_tests.buffers.buffer_info" ) diff --git a/tests/test_kwargs_and_defaults.py b/tests/test_kwargs_and_defaults.py index a8e19f15bb..b62e4b7412 100644 --- a/tests/test_kwargs_and_defaults.py +++ b/tests/test_kwargs_and_defaults.py @@ -22,7 +22,7 @@ def test_function_signatures(doc): assert doc(m.kw_func3) == "kw_func3(data: str = 'Hello world!') -> None" assert ( doc(m.kw_func4) - == "kw_func4(myList: list[typing.SupportsInt] = [13, 17]) -> str" + == "kw_func4(myList: collections.abc.Sequence[typing.SupportsInt] = [13, 17]) -> str" ) assert ( doc(m.kw_func_udl) diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 697e43965f..b20b098834 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -1254,7 +1254,7 @@ def test_arg_return_type_hints(doc): # std::vector assert ( doc(m.half_of_number_vector) - == "half_of_number_vector(arg0: list[Union[float, int]]) -> list[float]" + == "half_of_number_vector(arg0: collections.abc.Sequence[Union[float, int]]) -> list[float]" ) # Tuple assert ( diff --git a/tests/test_stl.cpp b/tests/test_stl.cpp index 9ddd951e0c..5eff4f5838 100644 --- a/tests/test_stl.cpp +++ b/tests/test_stl.cpp @@ -648,4 +648,19 @@ TEST_SUBMODULE(stl, m) { } return zum; }); + m.def("roundtrip_std_vector_int", [](const std::vector &v) { return v; }); + m.def("roundtrip_std_map_str_int", [](const std::map &m) { return m; }); + m.def("roundtrip_std_set_int", [](const std::set &s) { return s; }); + m.def( + "roundtrip_std_vector_int_noconvert", + [](const std::vector &v) { return v; }, + py::arg("v").noconvert()); + m.def( + "roundtrip_std_map_str_int_noconvert", + [](const std::map &m) { return m; }, + py::arg("m").noconvert()); + m.def( + "roundtrip_std_set_int_noconvert", + [](const std::set &s) { return s; }, + py::arg("s").noconvert()); } diff --git a/tests/test_stl.py 
b/tests/test_stl.py index 29a6bf119f..96c84374ce 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -20,7 +20,10 @@ def test_vector(doc): assert m.load_bool_vector((True, False)) assert doc(m.cast_vector) == "cast_vector() -> list[int]" - assert doc(m.load_vector) == "load_vector(arg0: list[typing.SupportsInt]) -> bool" + assert ( + doc(m.load_vector) + == "load_vector(arg0: collections.abc.Sequence[typing.SupportsInt]) -> bool" + ) # Test regression caused by 936: pointers to stl containers weren't castable assert m.cast_ptr_vector() == ["lvalue", "lvalue"] @@ -42,10 +45,13 @@ def test_array(doc): assert m.load_array(lst) assert m.load_array(tuple(lst)) - assert doc(m.cast_array) == "cast_array() -> Annotated[list[int], FixedSize(2)]" + assert ( + doc(m.cast_array) + == 'cast_array() -> typing.Annotated[list[int], "FixedSize(2)"]' + ) assert ( doc(m.load_array) - == "load_array(arg0: Annotated[list[typing.SupportsInt], FixedSize(2)]) -> bool" + == 'load_array(arg0: typing.Annotated[collections.abc.Sequence[typing.SupportsInt], "FixedSize(2)"]) -> bool' ) @@ -65,7 +71,8 @@ def test_valarray(doc): assert doc(m.cast_valarray) == "cast_valarray() -> list[int]" assert ( - doc(m.load_valarray) == "load_valarray(arg0: list[typing.SupportsInt]) -> bool" + doc(m.load_valarray) + == "load_valarray(arg0: collections.abc.Sequence[typing.SupportsInt]) -> bool" ) @@ -79,7 +86,9 @@ def test_map(doc): assert m.load_map(d) assert doc(m.cast_map) == "cast_map() -> dict[str, str]" - assert doc(m.load_map) == "load_map(arg0: dict[str, str]) -> bool" + assert ( + doc(m.load_map) == "load_map(arg0: collections.abc.Mapping[str, str]) -> bool" + ) def test_set(doc): @@ -91,7 +100,7 @@ def test_set(doc): assert m.load_set(frozenset(s)) assert doc(m.cast_set) == "cast_set() -> set[str]" - assert doc(m.load_set) == "load_set(arg0: set[str]) -> bool" + assert doc(m.load_set) == "load_set(arg0: collections.abc.Set[str]) -> bool" def test_recursive_casting(): @@ -273,7 +282,7 @@ def 
__fspath__(self): assert m.parent_paths(["foo/bar", "foo/baz"]) == [Path("foo"), Path("foo")] assert ( doc(m.parent_paths) - == "parent_paths(arg0: list[Union[os.PathLike, str, bytes]]) -> list[pathlib.Path]" + == "parent_paths(arg0: collections.abc.Sequence[Union[os.PathLike, str, bytes]]) -> list[pathlib.Path]" ) # py::typing::List assert m.parent_paths_list(["foo/bar", "foo/baz"]) == [Path("foo"), Path("foo")] @@ -364,7 +373,7 @@ def test_stl_pass_by_pointer(msg): msg(excinfo.value) == """ stl_pass_by_pointer(): incompatible function arguments. The following argument types are supported: - 1. (v: list[typing.SupportsInt] = None) -> list[int] + 1. (v: collections.abc.Sequence[typing.SupportsInt] = None) -> list[int] Invoked with: """ @@ -376,7 +385,7 @@ def test_stl_pass_by_pointer(msg): msg(excinfo.value) == """ stl_pass_by_pointer(): incompatible function arguments. The following argument types are supported: - 1. (v: list[typing.SupportsInt] = None) -> list[int] + 1. (v: collections.abc.Sequence[typing.SupportsInt] = None) -> list[int] Invoked with: None """ @@ -567,3 +576,145 @@ def gen_invalid(): with pytest.raises(expected_exception): m.pass_std_map_int(FakePyMappingGenObj(gen_obj)) assert not tuple(gen_obj) + + +def test_sequence_caster_protocol(doc): + from collections.abc import Sequence + + # Implements the Sequence protocol without explicitly inheriting from collections.abc.Sequence. + class BareSequenceLike: + def __init__(self, *args): + self.data = tuple(args) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return self.data[index] + + # Implements the Sequence protocol by reusing BareSequenceLike's implementation. + # Additionally, inherits from collections.abc.Sequence. 
+ class FormalSequenceLike(BareSequenceLike, Sequence): + pass + + # convert mode + assert ( + doc(m.roundtrip_std_vector_int) + == "roundtrip_std_vector_int(arg0: collections.abc.Sequence[typing.SupportsInt]) -> list[int]" + ) + assert m.roundtrip_std_vector_int([1, 2, 3]) == [1, 2, 3] + assert m.roundtrip_std_vector_int((1, 2, 3)) == [1, 2, 3] + assert m.roundtrip_std_vector_int(FormalSequenceLike(1, 2, 3)) == [1, 2, 3] + assert m.roundtrip_std_vector_int(BareSequenceLike(1, 2, 3)) == [1, 2, 3] + assert m.roundtrip_std_vector_int([]) == [] + assert m.roundtrip_std_vector_int(()) == [] + assert m.roundtrip_std_vector_int(BareSequenceLike()) == [] + # noconvert mode + assert ( + doc(m.roundtrip_std_vector_int_noconvert) + == "roundtrip_std_vector_int_noconvert(v: list[int]) -> list[int]" + ) + assert m.roundtrip_std_vector_int_noconvert([1, 2, 3]) == [1, 2, 3] + assert m.roundtrip_std_vector_int_noconvert((1, 2, 3)) == [1, 2, 3] + assert m.roundtrip_std_vector_int_noconvert(FormalSequenceLike(1, 2, 3)) == [ + 1, + 2, + 3, + ] + assert m.roundtrip_std_vector_int_noconvert(BareSequenceLike(1, 2, 3)) == [1, 2, 3] + assert m.roundtrip_std_vector_int_noconvert([]) == [] + assert m.roundtrip_std_vector_int_noconvert(()) == [] + assert m.roundtrip_std_vector_int_noconvert(BareSequenceLike()) == [] + + +def test_mapping_caster_protocol(doc): + from collections.abc import Mapping + + # Implements the Mapping protocol without explicitly inheriting from collections.abc.Mapping. + class BareMappingLike: + def __init__(self, **kwargs): + self.data = dict(kwargs) + + def __len__(self): + return len(self.data) + + def __getitem__(self, key): + return self.data[key] + + def __iter__(self): + yield from self.data + + # Implements the Mapping protocol by reusing BareMappingLike's implementation. + # Additionally, inherits from collections.abc.Mapping. 
+ class FormalMappingLike(BareMappingLike, Mapping): + pass + + a1b2c3 = {"a": 1, "b": 2, "c": 3} + # convert mode + assert ( + doc(m.roundtrip_std_map_str_int) + == "roundtrip_std_map_str_int(arg0: collections.abc.Mapping[str, typing.SupportsInt]) -> dict[str, int]" + ) + assert m.roundtrip_std_map_str_int(a1b2c3) == a1b2c3 + assert m.roundtrip_std_map_str_int(FormalMappingLike(**a1b2c3)) == a1b2c3 + assert m.roundtrip_std_map_str_int({}) == {} + assert m.roundtrip_std_map_str_int(FormalMappingLike()) == {} + with pytest.raises(TypeError): + m.roundtrip_std_map_str_int(BareMappingLike(**a1b2c3)) + # noconvert mode + assert ( + doc(m.roundtrip_std_map_str_int_noconvert) + == "roundtrip_std_map_str_int_noconvert(m: dict[str, int]) -> dict[str, int]" + ) + assert m.roundtrip_std_map_str_int_noconvert(a1b2c3) == a1b2c3 + assert m.roundtrip_std_map_str_int_noconvert({}) == {} + with pytest.raises(TypeError): + m.roundtrip_std_map_str_int_noconvert(FormalMappingLike(**a1b2c3)) + with pytest.raises(TypeError): + m.roundtrip_std_map_str_int_noconvert(BareMappingLike(**a1b2c3)) + + +def test_set_caster_protocol(doc): + from collections.abc import Set + + # Implements the Set protocol without explicitly inheriting from collections.abc.Set. + class BareSetLike: + def __init__(self, *args): + self.data = set(args) + + def __len__(self): + return len(self.data) + + def __contains__(self, item): + return item in self.data + + def __iter__(self): + yield from self.data + + # Implements the Set protocol by reusing BareSetLike's implementation. + # Additionally, inherits from collections.abc.Set. 
+ class FormalSetLike(BareSetLike, Set): + pass + + # convert mode + assert ( + doc(m.roundtrip_std_set_int) + == "roundtrip_std_set_int(arg0: collections.abc.Set[typing.SupportsInt]) -> set[int]" + ) + assert m.roundtrip_std_set_int({1, 2, 3}) == {1, 2, 3} + assert m.roundtrip_std_set_int(FormalSetLike(1, 2, 3)) == {1, 2, 3} + assert m.roundtrip_std_set_int(set()) == set() + assert m.roundtrip_std_set_int(FormalSetLike()) == set() + with pytest.raises(TypeError): + m.roundtrip_std_set_int(BareSetLike(1, 2, 3)) + # noconvert mode + assert ( + doc(m.roundtrip_std_set_int_noconvert) + == "roundtrip_std_set_int_noconvert(s: set[int]) -> set[int]" + ) + assert m.roundtrip_std_set_int_noconvert({1, 2, 3}) == {1, 2, 3} + assert m.roundtrip_std_set_int_noconvert(set()) == set() + with pytest.raises(TypeError): + m.roundtrip_std_set_int_noconvert(FormalSetLike(1, 2, 3)) + with pytest.raises(TypeError): + m.roundtrip_std_set_int_noconvert(BareSetLike(1, 2, 3))