Skip to content

Commit 2cbf77a

Browse files
committed
Addressing reviewer requests.
1 parent 47e35e9 commit 2cbf77a

5 files changed

+85
-88
lines changed

tests/CMakeLists.txt

-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ set(PYBIND11_TEST_FILES
123123
test_opaque_types.cpp
124124
test_operator_overloading.cpp
125125
test_pickling.cpp
126-
test_pickling_trampoline.cpp
127126
test_pytypes.cpp
128127
test_sequences_and_iterators.cpp
129128
test_smart_ptr.cpp

tests/test_pickling.cpp

+53
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,65 @@
1+
// clang-format off
12
/*
23
tests/test_pickling.cpp -- pickle support
34
45
Copyright (c) 2016 Wenzel Jakob <[email protected]>
6+
Copyright (c) 2021 The Pybind Development Team.
57
68
All rights reserved. Use of this source code is governed by a
79
BSD-style license that can be found in the LICENSE file.
810
*/
911

1012
#include "pybind11_tests.h"
1113

14+
// clang-format on
15+
16+
#include <memory>
17+
#include <stdexcept>
18+
#include <utility>
19+
20+
namespace exercise_trampoline {
21+
22+
struct SimpleBase {
23+
int num = 0;
24+
virtual ~SimpleBase() = default;
25+
26+
// For compatibility with old clang versions:
27+
SimpleBase() = default;
28+
SimpleBase(const SimpleBase &) = default;
29+
};
30+
31+
struct SimpleBaseTrampoline : SimpleBase {};
32+
33+
struct SimpleCppDerived : SimpleBase {};
34+
35+
void wrap(py::module m) {
36+
py::class_<SimpleBase, SimpleBaseTrampoline>(m, "SimpleBase")
37+
.def(py::init<>())
38+
.def_readwrite("num", &SimpleBase::num)
39+
.def(py::pickle(
40+
[](py::object self) {
41+
py::dict d;
42+
if (py::hasattr(self, "__dict__"))
43+
d = self.attr("__dict__");
44+
return py::make_tuple(self.attr("num"), d);
45+
},
46+
[](py::tuple t) {
47+
if (t.size() != 2)
48+
throw std::runtime_error("Invalid state!");
49+
auto cpp_state = std::unique_ptr<SimpleBase>(new SimpleBaseTrampoline);
50+
cpp_state->num = t[0].cast<int>();
51+
auto py_state = t[1].cast<py::dict>();
52+
return std::make_pair(std::move(cpp_state), py_state);
53+
}));
54+
55+
m.def("make_SimpleCppDerivedAsBase",
56+
[]() { return std::unique_ptr<SimpleBase>(new SimpleCppDerived); });
57+
}
58+
59+
} // namespace exercise_trampoline
60+
61+
// clang-format off
62+
1263
TEST_SUBMODULE(pickling, m) {
1364
// test_roundtrip
1465
class Pickleable {
@@ -130,4 +181,6 @@ TEST_SUBMODULE(pickling, m) {
130181
return std::make_pair(cpp_state, py_state);
131182
}));
132183
#endif
184+
185+
exercise_trampoline::wrap(m);
133186
}

tests/test_pickling.py

+32
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,35 @@ def test_enum_pickle():
4545

4646
data = pickle.dumps(e.EOne, 2)
4747
assert e.EOne == pickle.loads(data)
48+
49+
50+
#
51+
# exercise_trampoline
52+
#
53+
class SimplePyDerived(m.SimpleBase):
54+
pass
55+
56+
57+
def test_roundtrip_simple_py_derived():
58+
p = SimplePyDerived()
59+
p.num = 202
60+
p.stored_in_dict = 303
61+
data = pickle.dumps(p, pickle.HIGHEST_PROTOCOL)
62+
p2 = pickle.loads(data)
63+
assert isinstance(p2, SimplePyDerived)
64+
assert p2.num == 202
65+
assert p2.stored_in_dict == 303
66+
67+
68+
def test_roundtrip_simple_cpp_derived():
69+
p = m.make_SimpleCppDerivedAsBase()
70+
p.num = 404
71+
if not env.PYPY:
72+
# To ensure that this unit test is not accidentally invalidated.
73+
with pytest.raises(AttributeError):
74+
# Mimics the `setstate` C++ implementation.
75+
setattr(p, "__dict__", {}) # noqa: B010
76+
data = pickle.dumps(p, pickle.HIGHEST_PROTOCOL)
77+
p2 = pickle.loads(data)
78+
assert isinstance(p2, m.SimpleBase)
79+
assert p2.num == 404

tests/test_pickling_trampoline.cpp

-50
This file was deleted.

tests/test_pickling_trampoline.py

-37
This file was deleted.

0 commit comments

Comments
 (0)