Skip to content

Commit c07168a

Browse files
EricCousineau-TRIBetsyMcPhail
authored andcommitted
Add unittest checking drake#11424
1 parent 403d7c4 commit c07168a

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

tests/test_class.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,24 @@ TEST_SUBMODULE(class_, m) {
367367
.def(py::init<>())
368368
.def("ptr", &Aligned::ptr);
369369
#endif
370+
371+
// Test #1922 (drake#11424).
372+
class ExampleVirt2 {
373+
public:
374+
virtual ~ExampleVirt2() {}
375+
virtual std::string get_name() const { return "ExampleVirt2"; }
376+
};
377+
class PyExampleVirt2 : public ExampleVirt2 {
378+
public:
379+
std::string get_name() const override {
380+
PYBIND11_OVERLOAD(std::string, ExampleVirt2, get_name, );
381+
}
382+
};
383+
py::class_<ExampleVirt2, PyExampleVirt2>(m, "ExampleVirt2")
384+
.def(py::init())
385+
.def("get_name", &ExampleVirt2::get_name);
386+
m.def("example_virt2_get_name",
387+
[](const ExampleVirt2& obj) { return obj.get_name(); });
370388
}
371389

372390
template <int N> class BreaksBase { public: virtual ~BreaksBase() = default; };

tests/test_class.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import weakref
23

34
from pybind11_tests import class_ as m
45
from pybind11_tests import UserType, ConstructorStats
@@ -289,3 +290,38 @@ def test_aligned():
289290
if hasattr(m, "Aligned"):
290291
p = m.Aligned().ptr()
291292
assert p % 1024 == 0
293+
294+
295+
@pytest.mark.skip(reason="Generally reproducible in CPython, Python 3, non-debug, on Linux.\
296+
However, it is hard to pin this down for CI.")
297+
def test_1922():
298+
# Test #1922 (drake#11424).
299+
# Define a derived class which *does not* overload the method.
300+
# WARNING: The reproduction of this failure may be platform-specific, and
301+
# seems to depend on the order of definition and/or the name of the classes
302+
# defined. For example, trying to place this and the C++ code in
303+
# `test_virtual_functions` makes `assert id_1 == id_2` below fail.
304+
class Child1(m.ExampleVirt2):
305+
pass
306+
307+
id_1 = id(Child1)
308+
assert m.example_virt2_get_name(m.ExampleVirt2()) == "ExampleVirt2"
309+
assert m.example_virt2_get_name(Child1()) == "ExampleVirt2"
310+
311+
# Now delete everything (and ensure it's deleted).
312+
wref = weakref.ref(Child1)
313+
del Child1
314+
pytest.gc_collect()
315+
assert wref() is None
316+
317+
# Define a derived class which *does* define an overload.
318+
class Child2(m.ExampleVirt2):
319+
def get_name(self):
320+
return "Child2"
321+
322+
id_2 = id(Child2)
323+
assert id_1 == id_2 # This happens in CPython; not sure about PyPy.
324+
assert m.example_virt2_get_name(m.ExampleVirt2()) == "ExampleVirt2"
325+
# THIS WILL FAIL: This is using the cached `ExampleVirt2.get_name`, rather
326+
# than re-inspect the Python dictionary.
327+
assert m.example_virt2_get_name(Child2()) == "Child2"

0 commit comments

Comments
 (0)