Skip to content

Commit 2894c92

Browse files
nicholasjngwjakob
authored andcommitted
Port nb::bind_map from pybind11 (#114)
This commit adds a port of nanobind's `nb::bind_map<T>` feature to create bindings of STL map types (`map`, `unordered_map`). The implementation contains the following simplifications: 1. The C++17 constexpr feature was used to considerably reduce the size of the implementation. 2. The key/value/item views are simple wrappers without the need for polymorphism or STL unique pointers. They are created once per map type. The commit also includes a port of the associated pybind11 test suite parts. Co-authored by: Nicholas Junge <[email protected]> Co-authored by: Wenzel Jakob <[email protected]>
1 parent 83ec1ff commit 2894c92

File tree

6 files changed

+377
-2
lines changed

6 files changed

+377
-2
lines changed

include/nanobind/nb_class.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,9 @@ template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined
239239
template <typename T, typename SFINAE = int>
240240
struct is_copy_constructible : std::is_copy_constructible<T> { };
241241

242+
template <typename T>
243+
constexpr bool is_copy_constructible_v = is_copy_constructible<T>::value;
244+
242245
NAMESPACE_END(detail)
243246

244247
template <typename T, typename... Ts>
@@ -275,7 +278,7 @@ class class_ : public object {
275278
if constexpr (!std::is_same_v<Alias, T>)
276279
d.flags |= (uint32_t) detail::type_flags::is_trampoline;
277280

278-
if constexpr (detail::is_copy_constructible<T>::value) {
281+
if constexpr (detail::is_copy_constructible_v<T>) {
279282
d.flags |= (uint32_t) detail::type_flags::is_copy_constructible;
280283

281284
if constexpr (!std::is_trivially_copy_constructible_v<T>) {

include/nanobind/stl/bind_map.h

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
nanobind/stl/bind_map.h: Automatic creation of bindings for map-style containers
3+
4+
All rights reserved. Use of this source code is governed by a
5+
BSD-style license that can be found in the LICENSE file.
6+
*/
7+
8+
#pragma once
9+
10+
#include <nanobind/nanobind.h>
11+
#include <nanobind/make_iterator.h>
12+
#include <nanobind/stl/detail/traits.h>
13+
14+
NAMESPACE_BEGIN(NB_NAMESPACE)
15+
16+
template <typename Map, typename... Args>
17+
class_<Map> bind_map(handle scope, const char *name, Args &&...args) {
18+
using Key = typename Map::key_type;
19+
using Value = typename Map::mapped_type;
20+
21+
auto cl = class_<Map>(scope, name, std::forward<Args>(args)...)
22+
.def(init<>())
23+
24+
.def("__len__", &Map::size)
25+
26+
.def("__bool__",
27+
[](const Map &m) { return !m.empty(); },
28+
"Check whether the map is nonempty")
29+
30+
.def("__contains__",
31+
[](const Map &m, const Key &k) { return m.find(k) != m.end(); })
32+
33+
.def("__contains__", // fallback for incompatible types
34+
[](const Map &, handle) { return false; })
35+
36+
.def("__iter__",
37+
[](Map &m) {
38+
return make_key_iterator(type<Map>(), "KeyIterator",
39+
m.begin(), m.end());
40+
},
41+
keep_alive<0, 1>())
42+
43+
.def("__getitem__",
44+
[](Map &m, const Key &k) -> Value & {
45+
auto it = m.find(k);
46+
if (it == m.end())
47+
throw key_error();
48+
return it->second;
49+
},
50+
rv_policy::reference_internal
51+
)
52+
53+
.def("__delitem__",
54+
[](Map &m, const Key &k) {
55+
auto it = m.find(k);
56+
if (it == m.end())
57+
throw key_error();
58+
m.erase(it);
59+
}
60+
);
61+
62+
// Assignment operator for copy-assignable/copy-constructible types
63+
if constexpr (detail::is_copy_assignable_v<Value> ||
64+
detail::is_copy_constructible_v<Value>) {
65+
cl.def("__setitem__", [](Map &m, const Key &k, const Value &v) {
66+
if constexpr (detail::is_copy_assignable_v<Value>) {
67+
m[k] = v;
68+
} else {
69+
auto r = m.emplace(k, v);
70+
if (!r.second) {
71+
// Value is not copy-assignable. Erase and retry
72+
m.erase(r.first);
73+
m.emplace(k, v);
74+
}
75+
}
76+
});
77+
}
78+
79+
// Item, value, and key views
80+
struct KeyView { Map &map; };
81+
struct ValueView { Map &map; };
82+
struct ItemView { Map &map; };
83+
84+
class_<ItemView>(cl, "ItemView")
85+
.def("__len__", [](ItemView &v) { return v.map.size(); })
86+
.def("__iter__",
87+
[](ItemView &v) {
88+
return make_iterator(type<Map>(), "ItemIterator",
89+
v.map.begin(), v.map.end());
90+
},
91+
keep_alive<0, 1>());
92+
93+
class_<KeyView>(cl, "KeyView")
94+
.def("__contains__", [](KeyView &v, const Key &k) { return v.map.find(k) != v.map.end(); })
95+
.def("__contains__", [](KeyView &, handle) { return false; })
96+
.def("__len__", [](KeyView &v) { return v.map.size(); })
97+
.def("__iter__",
98+
[](KeyView &v) {
99+
return make_key_iterator(type<Map>(), "KeyIterator",
100+
v.map.begin(), v.map.end());
101+
},
102+
keep_alive<0, 1>());
103+
104+
class_<ValueView>(cl, "ValueView")
105+
.def("__len__", [](ValueView &v) { return v.map.size(); })
106+
.def("__iter__",
107+
[](ValueView &v) {
108+
return make_value_iterator(type<Map>(), "ValueIterator",
109+
v.map.begin(), v.map.end());
110+
},
111+
keep_alive<0, 1>());
112+
113+
cl.def("keys", [](Map &m) { return new KeyView{m}; }, keep_alive<0, 1>());
114+
cl.def("values", [](Map &m) { return new ValueView{m}; }, keep_alive<0, 1>());
115+
cl.def("items", [](Map &m) { return new ItemView{m}; }, keep_alive<0, 1>());
116+
117+
return cl;
118+
}
119+
120+
NAMESPACE_END(NB_NAMESPACE)

include/nanobind/stl/detail/traits.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,36 @@ struct is_copy_constructible<
3131
is_copy_constructible<typename T::value_type>::value;
3232
};
3333

34+
// std::pair is copy-constructible <=> both constituents are copy-constructible
3435
template <typename T1, typename T2>
3536
struct is_copy_constructible<std::pair<T1, T2>> {
3637
static constexpr bool value =
37-
is_copy_constructible<T1>::value ||
38+
is_copy_constructible<T1>::value &&
3839
is_copy_constructible<T2>::value;
3940
};
4041

42+
// Analogous template for checking copy-assignability
43+
template <typename T, typename SFINAE = int>
44+
struct is_copy_assignable : std::is_copy_assignable<T> { };
45+
46+
template <typename T>
47+
struct is_copy_assignable<T,
48+
enable_if_t<std::is_copy_assignable_v<T> &&
49+
std::is_same_v<typename T::value_type &,
50+
typename T::reference>>> {
51+
static constexpr bool value = is_copy_assignable<typename T::value_type>::value;
52+
};
53+
54+
template <typename T1, typename T2>
55+
struct is_copy_assignable<std::pair<T1, T2>> {
56+
static constexpr bool value =
57+
is_copy_assignable<T1>::value &&
58+
is_copy_assignable<T2>::value;
59+
};
60+
61+
template <typename T>
62+
constexpr bool is_copy_assignable_v = is_copy_assignable<T>::value;
63+
4164
NAMESPACE_END(detail)
4265
NAMESPACE_END(NB_NAMESPACE)
4366

tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ nanobind_add_module(test_functions_ext test_functions.cpp ${NB_EXTRA_ARGS})
1313
nanobind_add_module(test_classes_ext test_classes.cpp ${NB_EXTRA_ARGS})
1414
nanobind_add_module(test_holders_ext test_holders.cpp ${NB_EXTRA_ARGS})
1515
nanobind_add_module(test_stl_ext test_stl.cpp ${NB_EXTRA_ARGS})
16+
nanobind_add_module(test_bind_map_ext test_stl_bind_map.cpp ${NB_EXTRA_ARGS})
1617
nanobind_add_module(test_enum_ext test_enum.cpp ${NB_EXTRA_ARGS})
1718
nanobind_add_module(test_tensor_ext test_tensor.cpp ${NB_EXTRA_ARGS})
1819
nanobind_add_module(test_intrusive_ext test_intrusive.cpp object.cpp object.h ${NB_EXTRA_ARGS})
@@ -40,6 +41,7 @@ set(TEST_FILES
4041
test_classes.py
4142
test_holders.py
4243
test_stl.py
44+
test_stl_bind_map.py
4345
test_enum.py
4446
test_tensor.py
4547
test_intrusive.py

tests/test_stl_bind_map.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#include <map>
2+
#include <string>
3+
#include <unordered_map>
4+
#include <vector>
5+
6+
#include <nanobind/stl/bind_map.h>
7+
#include <nanobind/stl/string.h>
8+
#include <nanobind/stl/vector.h>
9+
10+
namespace nb = nanobind;
11+
12+
// testing for insertion of non-copyable class
13+
class E_nc {
14+
public:
15+
explicit E_nc(int i) : value{i} {}
16+
E_nc(const E_nc &) = delete;
17+
E_nc &operator=(const E_nc &) = delete;
18+
E_nc(E_nc &&) = default;
19+
E_nc &operator=(E_nc &&) = default;
20+
21+
int value;
22+
};
23+
24+
template <class Map>
25+
Map *times_ten(int n) {
26+
auto *m = new Map();
27+
for (int i = 1; i <= n; i++) {
28+
m->emplace(int(i), E_nc(10 * i));
29+
}
30+
return m;
31+
}
32+
33+
template <class NestMap>
34+
NestMap *times_hundred(int n) {
35+
auto *m = new NestMap();
36+
for (int i = 1; i <= n; i++) {
37+
for (int j = 1; j <= n; j++) {
38+
(*m)[i].emplace(int(j * 10), E_nc(100 * j));
39+
}
40+
}
41+
return m;
42+
}
43+
44+
NB_MODULE(test_bind_map_ext, m) {
45+
// test_map_string_double
46+
nb::bind_map<std::map<std::string, double>>(m, "MapStringDouble");
47+
nb::bind_map<std::unordered_map<std::string, double>>(m, "UnorderedMapStringDouble");
48+
// test_map_string_double_const
49+
nb::bind_map<std::map<std::string, double const>>(m, "MapStringDoubleConst");
50+
nb::bind_map<std::unordered_map<std::string, double const>>(m,
51+
"UnorderedMapStringDoubleConst");
52+
53+
nb::class_<E_nc>(m, "ENC").def(nb::init<int>()).def_readwrite("value", &E_nc::value);
54+
55+
nb::bind_map<std::map<int, E_nc>>(m, "MapENC");
56+
m.def("get_mnc", &times_ten<std::map<int, E_nc>>);
57+
nb::bind_map<std::unordered_map<int, E_nc>>(m, "UmapENC");
58+
m.def("get_umnc", &times_ten<std::unordered_map<int, E_nc>>);
59+
// Issue #1885: binding nested std::map<X, Container<E>> with E non-copyable
60+
nb::bind_map<std::map<int, std::vector<E_nc>>>(m, "MapVecENC");
61+
m.def("get_nvnc", [](int n) {
62+
auto *m = new std::map<int, std::vector<E_nc>>();
63+
for (int i = 1; i <= n; i++) {
64+
for (int j = 1; j <= n; j++) {
65+
(*m)[i].emplace_back(j);
66+
}
67+
}
68+
return m;
69+
});
70+
71+
nb::bind_map<std::map<int, std::map<int, E_nc>>>(m, "MapMapENC");
72+
m.def("get_nmnc", &times_hundred<std::map<int, std::map<int, E_nc>>>);
73+
nb::bind_map<std::unordered_map<int, std::unordered_map<int, E_nc>>>(m, "UmapUmapENC");
74+
m.def("get_numnc", &times_hundred<std::unordered_map<int, std::unordered_map<int, E_nc>>>);
75+
76+
}

0 commit comments

Comments
 (0)