Skip to content

Commit 342cc67

Browse files
Add unittest checking drake#11424
1 parent 403d7c4 commit 342cc67

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

tests/test_class.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,24 @@ TEST_SUBMODULE(class_, m) {
367367
.def(py::init<>())
368368
.def("ptr", &Aligned::ptr);
369369
#endif
370+
371+
// Test #1922 (drake#11424).
372+
class ExampleVirt2 {
373+
public:
374+
virtual ~ExampleVirt2() {}
375+
virtual std::string get_name() const { return "ExampleVirt2"; }
376+
};
377+
class PyExampleVirt2 : public ExampleVirt2 {
378+
public:
379+
std::string get_name() const override {
380+
PYBIND11_OVERLOAD(std::string, ExampleVirt2, get_name, );
381+
}
382+
};
383+
py::class_<ExampleVirt2, PyExampleVirt2>(m, "ExampleVirt2")
384+
.def(py::init())
385+
.def("get_name", &ExampleVirt2::get_name);
386+
m.def("example_virt2_get_name",
387+
[](const ExampleVirt2& obj) { return obj.get_name(); });
370388
}
371389

372390
template <int N> class BreaksBase { public: virtual ~BreaksBase() = default; };

tests/test_class.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import weakref
23

34
from pybind11_tests import class_ as m
45
from pybind11_tests import UserType, ConstructorStats
@@ -289,3 +290,34 @@ def test_aligned():
289290
if hasattr(m, "Aligned"):
290291
p = m.Aligned().ptr()
291292
assert p % 1024 == 0
293+
294+
295+
def test_1922():
296+
# Test #1922 (drake#11424).
297+
# Define a derived class which *does not* overload the method.
298+
# WARNING: The reproduction of this failure may be platform-specific, and
299+
# seems to depend on the order of definition and/or the name of the classes
300+
# defined. For example, trying to place this and the C++ code in
301+
# `test_virtual_functions` makes `assert id_1 == id_2` below fail.
302+
class Child1(m.ExampleVirt2): pass
303+
304+
id_1 = id(Child1)
305+
assert m.example_virt2_get_name(m.ExampleVirt2()) == "ExampleVirt2"
306+
assert m.example_virt2_get_name(Child1()) == "ExampleVirt2"
307+
308+
# Now delete everything (and ensure it's deleted).
309+
wref = weakref.ref(Child1)
310+
del Child1
311+
pytest.gc_collect()
312+
assert wref() == None
313+
314+
# Define a derived class which *does* define an overload.
315+
class Child2(m.ExampleVirt2):
316+
def get_name(self): return "Child2"
317+
318+
id_2 = id(Child2)
319+
assert id_1 == id_2 # This happens in CPython; not sure about PyPy.
320+
assert m.example_virt2_get_name(m.ExampleVirt2()) == "ExampleVirt2"
321+
# THIS WILL FAIL: This is using the cached `ExampleVirt2.get_name`, rather
322+
# than re-inspect the Python dictionary.
323+
assert m.example_virt2_get_name(Child2()) == "Child2"

0 commit comments

Comments
 (0)