Skip to content

Commit 0fe6005

Browse files
committed
Added template constructors to buffer_info that can deduce the item size, format string, and number of dimensions from the pointer type and the shape container
1 parent 731a9f6 commit 0fe6005

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

include/pybind11/buffer_info.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,18 @@ struct buffer_info {
3434
for (size_t i = 0; i < (size_t) ndim; ++i)
3535
size *= shape[i];
3636
}
37+
38+
template <typename T>
39+
buffer_info(T *ptr, detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
40+
// Brace-initialization of the base class ensures left-to-right evaluation order of parameters (i.e. getting ndim before
41+
: buffer_info{ptr, sizeof(T), format_descriptor<T>::format(), static_cast<ssize_t>(shape_in->size()), std::move(shape_in), std::move(strides_in)} { } moving the shape container)
3742

3843
buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size)
3944
: buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { }
45+
46+
template <typename T>
47+
buffer_info(T *ptr, ssize_t size)
48+
: buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) { }
4049

4150
explicit buffer_info(Py_buffer *view, bool ownview = true)
4251
: buffer_info(view->buf, view->itemsize, view->format, view->ndim,

tests/test_buffers.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,6 @@ test_initializer buffers([](py::module &m) {
105105
.def_buffer([](Matrix &m) -> py::buffer_info {
106106
return py::buffer_info(
107107
m.data(), /* Pointer to buffer */
108-
sizeof(float), /* Size of one scalar */
109-
py::format_descriptor<float>::format(), /* Python struct-style format descriptor */
110-
2, /* Number of dimensions */
111108
{ m.rows(), m.cols() }, /* Buffer dimensions */
112109
{ sizeof(float) * size_t(m.rows()), /* Strides (in bytes) for each index */
113110
sizeof(float) }

0 commit comments

Comments
 (0)