Skip to content

Commit 7908b7a

Browse files
committed
enum_: move most functionality to a non-template implementation
This commit addresses an inefficiency in how enums are created in pybind11. Most of the enum_<> implementation is completely generic -- however, being a template class, it ended up instantiating vast amounts of essentially identical code in larger projects with many enums. This commit introduces a generic non-templated helper class that is compatible with any kind of enumeration. enum_ then becomes a thin wrapper around this new class. The new enum_<> API is designed to be 100% compatible with the old one.
1 parent e07478f commit 7908b7a

File tree

1 file changed

+144
-77
lines changed

1 file changed

+144
-77
lines changed

include/pybind11/pybind11.h

Lines changed: 144 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,6 +1360,143 @@ detail::initimpl::pickle_factory<GetState, SetState> pickle(GetState &&g, SetSta
13601360
return {std::forward<GetState>(g), std::forward<SetState>(s)};
13611361
}
13621362

1363+
NAMESPACE_BEGIN(detail)
1364+
struct enum_base {
1365+
enum_base(handle base, handle parent) : m_base(base), m_parent(parent) { }
1366+
1367+
PYBIND11_NOINLINE void init(bool is_arithmetic, bool is_convertible) {
1368+
m_base.attr("__entries") = dict();
1369+
auto property = handle((PyObject *) &PyProperty_Type);
1370+
auto static_property = handle((PyObject *) get_internals().static_property_type);
1371+
1372+
m_base.attr("__repr__") = cpp_function(
1373+
[](handle arg) -> str {
1374+
handle type = arg.get_type();
1375+
object type_name = type.attr("__name__");
1376+
dict entries = type.attr("__entries");
1377+
for (const auto &kv : entries) {
1378+
object other = kv.second[int_(0)];
1379+
if (other.equal(arg))
1380+
return pybind11::str("{}.{}").format(type_name, kv.first);
1381+
}
1382+
return pybind11::str("{}.???").format(type_name);
1383+
}, is_method(m_base)
1384+
);
1385+
1386+
m_base.attr("name") = property(cpp_function(
1387+
[](handle arg) -> str {
1388+
dict entries = arg.get_type().attr("__entries");
1389+
for (const auto &kv : entries) {
1390+
if (handle(kv.second[int_(0)]).equal(arg))
1391+
return pybind11::str(kv.first);
1392+
}
1393+
return "???";
1394+
}, is_method(m_base)
1395+
));
1396+
1397+
m_base.attr("__doc__") = static_property(cpp_function(
1398+
[](handle arg) -> std::string {
1399+
std::string docstring;
1400+
dict entries = arg.attr("__entries");
1401+
if (((PyTypeObject *) arg.ptr())->tp_doc)
1402+
docstring += std::string(((PyTypeObject *) arg.ptr())->tp_doc) + "\n\n";
1403+
docstring += "Members:";
1404+
for (const auto &kv : entries) {
1405+
auto key = std::string(pybind11::str(kv.first));
1406+
auto comment = kv.second[int_(1)];
1407+
docstring += "\n\n " + key;
1408+
if (!comment.is_none())
1409+
docstring += " : " + (std::string) pybind11::str(comment);
1410+
}
1411+
return docstring;
1412+
}
1413+
), none(), none(), "");
1414+
1415+
m_base.attr("__members__") = static_property(cpp_function(
1416+
[](handle arg) -> dict {
1417+
dict entries = arg.attr("__entries"), m;
1418+
for (const auto &kv : entries)
1419+
m[kv.first] = kv.second[int_(0)];
1420+
return m;
1421+
}), none(), none(), ""
1422+
);
1423+
1424+
#define PYBIND11_ENUM_OP_STRICT(op, expr) \
1425+
m_base.attr(op) = cpp_function( \
1426+
[](object a, object b) { \
1427+
if (!a.get_type().is(b.get_type())) \
1428+
throw type_error("Expected an enumeration of matching type!"); \
1429+
return expr; \
1430+
}, \
1431+
is_method(m_base))
1432+
1433+
#define PYBIND11_ENUM_OP_CONV(op, expr) \
1434+
m_base.attr(op) = cpp_function( \
1435+
[](object a_, object b_) { \
1436+
int_ a(a_), b(b_); \
1437+
return expr; \
1438+
}, \
1439+
is_method(m_base))
1440+
1441+
if (is_convertible) {
1442+
PYBIND11_ENUM_OP_CONV("__eq__", !b.is_none() && a.equal(b));
1443+
PYBIND11_ENUM_OP_CONV("__ne__", b.is_none() || !a.equal(b));
1444+
1445+
if (is_arithmetic) {
1446+
PYBIND11_ENUM_OP_CONV("__lt__", a < b);
1447+
PYBIND11_ENUM_OP_CONV("__gt__", a > b);
1448+
PYBIND11_ENUM_OP_CONV("__le__", a <= b);
1449+
PYBIND11_ENUM_OP_CONV("__ge__", a >= b);
1450+
PYBIND11_ENUM_OP_CONV("__and__", a & b);
1451+
PYBIND11_ENUM_OP_CONV("__rand__", a & b);
1452+
PYBIND11_ENUM_OP_CONV("__or__", a | b);
1453+
PYBIND11_ENUM_OP_CONV("__ror__", a | b);
1454+
PYBIND11_ENUM_OP_CONV("__xor__", a ^ b);
1455+
PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b);
1456+
}
1457+
} else {
1458+
PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)));
1459+
PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)));
1460+
1461+
if (is_arithmetic) {
1462+
PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b));
1463+
PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b));
1464+
PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b));
1465+
PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b));
1466+
}
1467+
}
1468+
1469+
#undef PYBIND11_ENUM_OP_CONV
1470+
#undef PYBIND11_ENUM_OP_STRICT
1471+
1472+
m_base.attr("__hash__") = cpp_function(
1473+
[](object arg) { return int_(arg); }, is_method(m_base));
1474+
}
1475+
1476+
PYBIND11_NOINLINE void value(char const* name_, object value, const char *doc = nullptr) {
1477+
dict entries = m_base.attr("__entries");
1478+
str name(name_);
1479+
if (entries.contains(name)) {
1480+
std::string type_name = (std::string) str(m_base.get_type());
1481+
throw value_error(type_name + ": element " + std::string(name_) + " already exists!");
1482+
}
1483+
1484+
entries[name] = std::make_pair(value, doc);
1485+
m_base.attr(name) = value;
1486+
}
1487+
1488+
PYBIND11_NOINLINE void export_values() {
1489+
dict entries = m_base.attr("__entries");
1490+
for (const auto &kv : entries)
1491+
m_parent.attr(kv.first) = kv.second[int_(0)];
1492+
}
1493+
1494+
handle m_base;
1495+
handle m_parent;
1496+
};
1497+
1498+
NAMESPACE_END(detail)
1499+
13631500
/// Binds C++ enumerations and enumeration classes to Python
13641501
template <typename Type> class enum_ : public class_<Type> {
13651502
public:
@@ -1370,106 +1507,36 @@ template <typename Type> class enum_ : public class_<Type> {
13701507

13711508
template <typename... Extra>
13721509
enum_(const handle &scope, const char *name, const Extra&... extra)
1373-
: class_<Type>(scope, name, extra...), m_entries(), m_parent(scope) {
1374-
1510+
: class_<Type>(scope, name, extra...), m_base(*this, scope) {
13751511
constexpr bool is_arithmetic = detail::any_of<std::is_same<arithmetic, Extra>...>::value;
1512+
constexpr bool is_convertible = std::is_convertible<Type, Scalar>::value;
1513+
m_base.init(is_arithmetic, is_convertible);
13761514

1377-
auto m_entries_ptr = m_entries.inc_ref().ptr();
1378-
def("__repr__", [name, m_entries_ptr](Type value) -> pybind11::str {
1379-
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr)) {
1380-
if (pybind11::cast<Type>(kv.second[int_(0)]) == value)
1381-
return pybind11::str("{}.{}").format(name, kv.first);
1382-
}
1383-
return pybind11::str("{}.???").format(name);
1384-
});
1385-
def_property_readonly("name", [m_entries_ptr](Type value) -> pybind11::str {
1386-
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr)) {
1387-
if (pybind11::cast<Type>(kv.second[int_(0)]) == value)
1388-
return pybind11::str(kv.first);
1389-
}
1390-
return pybind11::str("???");
1391-
});
1392-
def_property_readonly_static("__doc__", [m_entries_ptr](handle self_) {
1393-
std::string docstring;
1394-
const char *tp_doc = ((PyTypeObject *) self_.ptr())->tp_doc;
1395-
if (tp_doc)
1396-
docstring += std::string(tp_doc) + "\n\n";
1397-
docstring += "Members:";
1398-
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr)) {
1399-
auto key = std::string(pybind11::str(kv.first));
1400-
auto comment = kv.second[int_(1)];
1401-
docstring += "\n\n " + key;
1402-
if (!comment.is_none())
1403-
docstring += " : " + (std::string) pybind11::str(comment);
1404-
}
1405-
return docstring;
1406-
});
1407-
def_property_readonly_static("__members__", [m_entries_ptr](handle /* self_ */) {
1408-
dict m;
1409-
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr))
1410-
m[kv.first] = kv.second[int_(0)];
1411-
return m;
1412-
}, return_value_policy::copy);
14131515
def(init([](Scalar i) { return static_cast<Type>(i); }));
14141516
def("__int__", [](Type value) { return (Scalar) value; });
14151517
#if PY_MAJOR_VERSION < 3
14161518
def("__long__", [](Type value) { return (Scalar) value; });
14171519
#endif
1418-
def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; });
1419-
def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; });
1420-
if (is_arithmetic) {
1421-
def("__lt__", [](const Type &value, Type *value2) { return value2 && value < *value2; });
1422-
def("__gt__", [](const Type &value, Type *value2) { return value2 && value > *value2; });
1423-
def("__le__", [](const Type &value, Type *value2) { return value2 && value <= *value2; });
1424-
def("__ge__", [](const Type &value, Type *value2) { return value2 && value >= *value2; });
1425-
}
1426-
if (std::is_convertible<Type, Scalar>::value) {
1427-
// Don't provide comparison with the underlying type if the enum isn't convertible,
1428-
// i.e. if Type is a scoped enum, mirroring the C++ behaviour. (NB: we explicitly
1429-
// convert Type to Scalar below anyway because this needs to compile).
1430-
def("__eq__", [](const Type &value, Scalar value2) { return (Scalar) value == value2; });
1431-
def("__ne__", [](const Type &value, Scalar value2) { return (Scalar) value != value2; });
1432-
if (is_arithmetic) {
1433-
def("__lt__", [](const Type &value, Scalar value2) { return (Scalar) value < value2; });
1434-
def("__gt__", [](const Type &value, Scalar value2) { return (Scalar) value > value2; });
1435-
def("__le__", [](const Type &value, Scalar value2) { return (Scalar) value <= value2; });
1436-
def("__ge__", [](const Type &value, Scalar value2) { return (Scalar) value >= value2; });
1437-
def("__invert__", [](const Type &value) { return ~((Scalar) value); });
1438-
def("__and__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; });
1439-
def("__or__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; });
1440-
def("__xor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; });
1441-
def("__rand__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; });
1442-
def("__ror__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; });
1443-
def("__rxor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; });
1444-
def("__and__", [](const Type &value, const Type &value2) { return (Scalar) value & (Scalar) value2; });
1445-
def("__or__", [](const Type &value, const Type &value2) { return (Scalar) value | (Scalar) value2; });
1446-
def("__xor__", [](const Type &value, const Type &value2) { return (Scalar) value ^ (Scalar) value2; });
1447-
}
1448-
}
1449-
def("__hash__", [](const Type &value) { return (Scalar) value; });
1520+
14501521
// Pickling and unpickling -- needed for use with the 'multiprocessing' module
14511522
def(pickle([](const Type &value) { return pybind11::make_tuple((Scalar) value); },
14521523
[](tuple t) { return static_cast<Type>(t[0].cast<Scalar>()); }));
14531524
}
14541525

14551526
/// Export enumeration entries into the parent scope
14561527
enum_& export_values() {
1457-
for (const auto &kv : m_entries)
1458-
m_parent.attr(kv.first) = kv.second[int_(0)];
1528+
m_base.export_values();
14591529
return *this;
14601530
}
14611531

14621532
/// Add an enumeration entry
14631533
enum_& value(char const* name, Type value, const char *doc = nullptr) {
1464-
auto v = pybind11::cast(value, return_value_policy::copy);
1465-
this->attr(name) = v;
1466-
m_entries[pybind11::str(name)] = std::make_pair(v, doc);
1534+
m_base.value(name, pybind11::cast(value, return_value_policy::copy), doc);
14671535
return *this;
14681536
}
14691537

14701538
private:
1471-
dict m_entries;
1472-
handle m_parent;
1539+
detail::enum_base m_base;
14731540
};
14741541

14751542
NAMESPACE_BEGIN(detail)

0 commit comments

Comments
 (0)