Skip to content

Commit 427e4af

Browse files
committed
Fix buffer protocol inheritance
Fixes #878.
1 parent 6d2411f commit 427e4af

File tree

3 files changed

+28
-2
lines changed

3 files changed

+28
-2
lines changed

include/pybind11/class_support.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,11 +447,17 @@ inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) {
447447

448448
/// buffer_protocol: Fill in the view as specified by flags.
449449
extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
450-
auto tinfo = get_type_info(Py_TYPE(obj));
450+
// Look for a `get_buffer` implementation in this type's info or any bases (following MRO).
451+
type_info *tinfo = nullptr;
452+
for (auto type : reinterpret_borrow<tuple>(Py_TYPE(obj)->tp_mro)) {
453+
tinfo = get_type_info((PyTypeObject *) type.ptr());
454+
if (tinfo && tinfo->get_buffer)
455+
break;
456+
}
451457
if (view == nullptr || obj == nullptr || !tinfo || !tinfo->get_buffer) {
452458
if (view)
453459
view->obj = nullptr;
454-
PyErr_SetString(PyExc_BufferError, "generic_type::getbuffer(): Internal error");
460+
PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error");
455461
return -1;
456462
}
457463
memset(view, 0, sizeof(Py_buffer));

tests/test_buffers.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ class Matrix {
7474
float *m_data;
7575
};
7676

77+
class SquareMatrix : public Matrix {
78+
public:
79+
SquareMatrix(ssize_t n) : Matrix(n, n) { }
80+
};
81+
7782
struct PTMFBuffer {
7883
int32_t value = 0;
7984

@@ -141,6 +146,10 @@ test_initializer buffers([](py::module &m) {
141146
})
142147
;
143148

149+
// Derived classes inherit the buffer protocol and the buffer access function
150+
py::class_<SquareMatrix, Matrix>(m, "SquareMatrix")
151+
.def(py::init<ssize_t>());
152+
144153
py::class_<PTMFBuffer>(m, "PTMFBuffer", py::buffer_protocol())
145154
.def(py::init<>())
146155
.def_readwrite("value", &PTMFBuffer::value)

tests/test_buffers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def test_from_python():
3636
@pytest.unsupported_on_pypy
3737
def test_to_python():
3838
m = Matrix(5, 5)
39+
assert memoryview(m).shape == (5, 5)
3940

4041
assert m[2, 3] == 0
4142
m[2, 3] = 4
@@ -63,6 +64,16 @@ def test_to_python():
6364
assert cstats.move_assignments == 0
6465

6566

67+
@pytest.unsupported_on_pypy
68+
def test_inherited_protocol():
69+
"""SquareMatrix is derived from Matrix and inherits the buffer protocol"""
70+
from pybind11_tests import SquareMatrix
71+
72+
matrix = SquareMatrix(5)
73+
assert memoryview(matrix).shape == (5, 5)
74+
assert np.asarray(matrix).shape == (5, 5)
75+
76+
6677
@pytest.unsupported_on_pypy
6778
def test_ptmf():
6879
for cls in [PTMFBuffer, ConstPTMFBuffer, DerivedPTMFBuffer]:

0 commit comments

Comments
 (0)