Skip to content

Commit 09f0829

Browse files
sizmailovwjakob
authored andcommitted
Avoid conversion to int_ rhs argument of enum eq/ne (#1912)
* fix: Avoid conversion to `int_` rhs argument of enum eq/ne * test: compare unscoped enum with strings * suppress comparison to None warning * test unscoped enum arithmetic and comparision with unsupported type
1 parent f6c4c10 commit 09f0829

File tree

3 files changed

+62
-13
lines changed

3 files changed

+62
-13
lines changed

include/pybind11/pybind11.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1469,9 +1469,17 @@ struct enum_base {
14691469
}, \
14701470
is_method(m_base))
14711471

1472+
#define PYBIND11_ENUM_OP_CONV_LHS(op, expr) \
1473+
m_base.attr(op) = cpp_function( \
1474+
[](object a_, object b) { \
1475+
int_ a(a_); \
1476+
return expr; \
1477+
}, \
1478+
is_method(m_base))
1479+
14721480
if (is_convertible) {
1473-
PYBIND11_ENUM_OP_CONV("__eq__", !b.is_none() && a.equal(b));
1474-
PYBIND11_ENUM_OP_CONV("__ne__", b.is_none() || !a.equal(b));
1481+
PYBIND11_ENUM_OP_CONV_LHS("__eq__", !b.is_none() && a.equal(b));
1482+
PYBIND11_ENUM_OP_CONV_LHS("__ne__", b.is_none() || !a.equal(b));
14751483

14761484
if (is_arithmetic) {
14771485
PYBIND11_ENUM_OP_CONV("__lt__", a < b);
@@ -1501,6 +1509,7 @@ struct enum_base {
15011509
}
15021510
}
15031511

1512+
#undef PYBIND11_ENUM_OP_CONV_LHS
15041513
#undef PYBIND11_ENUM_OP_CONV
15051514
#undef PYBIND11_ENUM_OP_STRICT
15061515

tests/test_enum.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ TEST_SUBMODULE(enums, m) {
1313
// test_unscoped_enum
1414
enum UnscopedEnum {
1515
EOne = 1,
16-
ETwo
16+
ETwo,
17+
EThree
1718
};
1819
py::enum_<UnscopedEnum>(m, "UnscopedEnum", py::arithmetic(), "An unscoped enumeration")
1920
.value("EOne", EOne, "Docstring for EOne")
2021
.value("ETwo", ETwo, "Docstring for ETwo")
22+
.value("EThree", EThree, "Docstring for EThree")
2123
.export_values();
2224

2325
// test_scoped_enum

tests/test_enum.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,38 +21,65 @@ def test_unscoped_enum():
2121

2222
# __members__ property
2323
assert m.UnscopedEnum.__members__ == \
24-
{"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo}
24+
{"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo, "EThree": m.UnscopedEnum.EThree}
2525
# __members__ readonly
2626
with pytest.raises(AttributeError):
2727
m.UnscopedEnum.__members__ = {}
2828
# __members__ returns a copy
2929
foo = m.UnscopedEnum.__members__
3030
foo["bar"] = "baz"
3131
assert m.UnscopedEnum.__members__ == \
32-
{"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo}
32+
{"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo, "EThree": m.UnscopedEnum.EThree}
3333

34-
assert m.UnscopedEnum.__doc__ == \
35-
'''An unscoped enumeration
34+
for docstring_line in '''An unscoped enumeration
3635
3736
Members:
3837
3938
EOne : Docstring for EOne
4039
41-
ETwo : Docstring for ETwo''' or m.UnscopedEnum.__doc__ == \
42-
'''An unscoped enumeration
43-
44-
Members:
45-
4640
ETwo : Docstring for ETwo
4741
48-
EOne : Docstring for EOne'''
42+
EThree : Docstring for EThree'''.split('\n'):
43+
assert docstring_line in m.UnscopedEnum.__doc__
4944

5045
# Unscoped enums will accept ==/!= int comparisons
5146
y = m.UnscopedEnum.ETwo
5247
assert y == 2
5348
assert 2 == y
5449
assert y != 3
5550
assert 3 != y
51+
# Compare with None
52+
assert (y != None) # noqa: E711
53+
assert not (y == None) # noqa: E711
54+
# Compare with an object
55+
assert (y != object())
56+
assert not (y == object())
57+
# Compare with string
58+
assert y != "2"
59+
assert "2" != y
60+
assert not ("2" == y)
61+
assert not (y == "2")
62+
63+
with pytest.raises(TypeError):
64+
y < object()
65+
66+
with pytest.raises(TypeError):
67+
y <= object()
68+
69+
with pytest.raises(TypeError):
70+
y > object()
71+
72+
with pytest.raises(TypeError):
73+
y >= object()
74+
75+
with pytest.raises(TypeError):
76+
y | object()
77+
78+
with pytest.raises(TypeError):
79+
y & object()
80+
81+
with pytest.raises(TypeError):
82+
y ^ object()
5683

5784
assert int(m.UnscopedEnum.ETwo) == 2
5885
assert str(m.UnscopedEnum(2)) == "UnscopedEnum.ETwo"
@@ -71,6 +98,11 @@ def test_unscoped_enum():
7198
assert not (m.UnscopedEnum.ETwo < m.UnscopedEnum.EOne)
7299
assert not (2 < m.UnscopedEnum.EOne)
73100

101+
# arithmetic
102+
assert m.UnscopedEnum.EOne & m.UnscopedEnum.EThree == m.UnscopedEnum.EOne
103+
assert m.UnscopedEnum.EOne | m.UnscopedEnum.ETwo == m.UnscopedEnum.EThree
104+
assert m.UnscopedEnum.EOne ^ m.UnscopedEnum.EThree == m.UnscopedEnum.ETwo
105+
74106

75107
def test_scoped_enum():
76108
assert m.test_scoped_enum(m.ScopedEnum.Three) == "ScopedEnum::Three"
@@ -82,6 +114,12 @@ def test_scoped_enum():
82114
assert not 3 == z
83115
assert z != 3
84116
assert 3 != z
117+
# Compare with None
118+
assert (z != None) # noqa: E711
119+
assert not (z == None) # noqa: E711
120+
# Compare with an object
121+
assert (z != object())
122+
assert not (z == object())
85123
# Scoped enums will *NOT* accept >, <, >= and <= int comparisons (Will throw exceptions)
86124
with pytest.raises(TypeError):
87125
z > 3

0 commit comments

Comments
 (0)