Skip to content

Numpy improvements #383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ set(PYBIND11_PYTHON_VERSION "" CACHE STRING "Python version to use for compiling
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/tools")
set(Python_ADDITIONAL_VERSIONS 3.4 3.5 3.6 3.7)
find_package(PythonLibsNew ${PYBIND11_PYTHON_VERSION} REQUIRED)
find_package(Numpy)
if(PYTHON_NUMPY_FOUND)
list(APPEND PYTHON_INCLUDE_DIRS ${PYTHON_NUMPY_INCLUDE_DIR})
endif()

include(CheckCXXCompilerFlag)

Expand Down
291 changes: 260 additions & 31 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@
#include <string>
#include <initializer_list>

#ifndef PYBIND_DONT_INCLUDE_NUMPY
#include <Python.h>
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <numpy/ndarrayobject.h>
#include <numpy/ndarraytypes.h>
#undef NPY_NO_DEPRECATED_API
#endif

#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
Expand Down Expand Up @@ -153,14 +161,82 @@ class dtype : public object {
return attr("itemsize").cast<size_t>();
}

bool has_fields() const {
return attr("fields").cast<object>().ptr() != Py_None;
}

std::string kind() const {
return (std::string) attr("kind").cast<pybind11::str>();
}

bool check_scalar() const{
return PyArray_CheckScalar(m_ptr);
}

bool is_any_scalar() const{
return PyArray_CheckScalar(m_ptr);
}

bool is_unsigned() const{
return PyDataType_ISUNSIGNED(m_ptr);
}

bool is_signed() const{
return PyDataType_ISSIGNED(m_ptr);
}

bool is_integer() const{
return PyDataType_ISINTEGER(m_ptr);
}

bool is_float() const{
return PyDataType_ISFLOAT(m_ptr);
}

bool is_complex() const{
return PyDataType_ISCOMPLEX(m_ptr);
}

bool is_number() const{
return PyDataType_ISNUMBER(m_ptr);
}

bool is_string() const{
return PyDataType_ISSTRING(m_ptr);
}

bool is_python() const{
return PyDataType_ISPYTHON(m_ptr);
}

bool is_flexible() const{
return PyDataType_ISFLEXIBLE(m_ptr);
}

bool is_userdef() const{
return PyDataType_ISUSERDEF(m_ptr);
}

bool is_extended() const{
return PyDataType_ISEXTENDED(m_ptr);
}

bool is_object() const{
return PyDataType_ISOBJECT(m_ptr);
}

bool is_bool() const{
return PyTypeNum_ISBOOL(((PyArray_Descr*)(m_ptr))->type_num);//bool has bitrotted, see https://mail.scipy.org/pipermail/numpy-discussion/2013-August/067549.html
}

bool has_fields() const{
return PyDataType_HASFIELDS(m_ptr);
}

bool is_swapped() const{
return PyDataType_ISBYTESWAPPED(m_ptr);
}

bool operator==(const dtype &rhs) const{
return PyArray_EquivTypes((PyArray_Descr *) m_ptr, (PyArray_Descr *) rhs.m_ptr);
}

private:
static object _dtype_from_pep3118() {
static PyObject *obj = module::import("numpy.core._internal")
Expand Down Expand Up @@ -204,67 +280,175 @@ class dtype : public object {
}
};

class dtype_record_builder {
list names;
list offsets;
list types;
size_t itemsize;

dtype_record_builder(size_t itemsize):itemsize(itemsize){}

dtype_record_builder& add(const std::string &name, size_t offset, dtype &type){
names.append(str(name));
offsets.append(int_(offset));
types.append(type);
return *this;
}

dtype build(){
return dtype(names, types, offsets, itemsize);
}
};

class array : public buffer {
public:
PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_)

enum {
enum : int{
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
};

array(const pybind11::dtype& dt, const std::vector<size_t>& shape,
const std::vector<size_t>& strides, void *ptr = nullptr) {
const std::vector<size_t>& strides, void *ptr = nullptr, bool copy = true, int flags = 0) {
auto& api = detail::npy_api::get();
auto ndim = shape.size();
if (shape.size() != strides.size())
pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
auto descr = dt;
object tmp(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(),
(Py_intptr_t *) strides.data(), ptr, 0, nullptr), false);
(Py_intptr_t *) strides.data(), ptr, flags, nullptr), false);
if (!tmp)
pybind11_fail("NumPy: unable to create array!");
if (ptr)
if (ptr && copy)
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
m_ptr = tmp.release().ptr();
}

array(const pybind11::dtype& dt, const std::vector<size_t>& shape, void *ptr = nullptr)
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr) { }
array(const pybind11::dtype& dt, const std::vector<size_t>& shape, void *ptr = nullptr, bool copy = true, int flags = 0)
: array(dt, shape, default_strides(shape, dt.itemsize(), flags), ptr, copy, flags) { }

array(const pybind11::dtype& dt, size_t count, void *ptr = nullptr)
: array(dt, std::vector<size_t> { count }, ptr) { }
array(const pybind11::dtype& dt, size_t count, void *ptr = nullptr, bool copy = true, int flags = 0)
: array(dt, std::vector<size_t> { count }, ptr, copy, flags) { }

template<typename T> array(const std::vector<size_t>& shape,
const std::vector<size_t>& strides, T* ptr)
: array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr) { }
const std::vector<size_t>& strides, T* ptr, bool copy = true, int flags = 0)
: array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, copy, flags) { }

template<typename T> array(const std::vector<size_t>& shape, T* ptr)
: array(shape, default_strides(shape, sizeof(T)), ptr) { }
template<typename T> array(const std::vector<size_t>& shape, T* ptr, bool copy = true, int flags = 0)
: array(shape, default_strides(shape, sizeof(T), flags), ptr, copy, flags) { }

template<typename T> array(size_t size, T* ptr)
: array(std::vector<size_t> { size }, ptr) { }
template<typename T> array(size_t size, T* ptr, bool copy = true, int flags = 0)
: array(std::vector<size_t> { size }, ptr, copy, flags) { }

array(const buffer_info &info)
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }

pybind11::dtype dtype() {
return attr("dtype").cast<pybind11::dtype>();
return object((PyObject *) PyArray_DTYPE((PyArrayObject *) m_ptr), true).cast<pybind11::dtype>();
}
int ndims() const{
return PyArray_NDIM((PyArrayObject *) m_ptr);
}
npy_intp dim(int d) const{
return PyArray_DIM((PyArrayObject *) m_ptr, d);
}
const npy_intp* shape() const{
return PyArray_SHAPE((PyArrayObject *) m_ptr);
}
const npy_intp* strides() const{
return PyArray_STRIDES((PyArrayObject *) m_ptr);
}
npy_intp stride(int d) const{
return PyArray_STRIDE((PyArrayObject *) m_ptr, d);
}
uint8_t *data() {
return reinterpret_cast<uint8_t *>(PyArray_DATA((PyArrayObject *) m_ptr));
}
uint8_t *data() const{
return reinterpret_cast<uint8_t *>(PyArray_DATA((PyArrayObject *) m_ptr));
}
object base() const{
return object(PyArray_BASE((PyArrayObject *) m_ptr), true);
}
int type() const{
return PyArray_TYPE((PyArrayObject *) m_ptr);
}
int flags() const{
return PyArray_FLAGS((PyArrayObject *) m_ptr);
}
size_t itemsize() const{
return PyArray_ITEMSIZE((PyArrayObject *) m_ptr);
}
size_t size() const{
return PyArray_SIZE((PyArrayObject *) m_ptr);
}
size_t nbytes() const{
return PyArray_NBYTES((PyArrayObject *) m_ptr);
}

template<typename... Args>
uint8_t* data_at(Args&&... indices){
const int nterms = sizeof...(indices);
const uint32_t _indices[] = { uint32_t(indices)... };
auto *strides = this->strides();
assert(nterms < this->ndims());
uint8_t *p = this->data();
for(int i = 0; i < nterms; ++i){
p += _indices[i] * strides[i];
}
return p;
}
template<typename... Args>
const uint8_t* data_at(Args&&... indices) const{
return const_cast<array *>(this)->data_at(indices...);
}

template<int N>
static array empty(const std::array<int, N> &shape, const pybind11::dtype &type, int order = 'C'){
npy_intp _shape[N];
for(int i = 0; i < N; ++i){
_shape[i] = shape[i];
}
object result(PyArray_Empty(N, _shape, type, order == 'f' || order == 'F' ), false);
if (!result)
pybind11_fail("NumPy: unable to create array!");
return result;
}

template<int N>
static array zeros(const std::array<int, N> &shape, const pybind11::dtype &type, int order = 'C'){
npy_intp _shape[N];
for(int i = 0; i < N; ++i){
_shape[i] = shape[i];
}
object result(PyArray_Zeros(N, _shape, type, order == 'f' || order == 'F' ), false);
if (!result)
pybind11_fail("NumPy: unable to create array!");
return result;
}

protected:
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;

static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
auto ndim = shape.size();
static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize, int flags) {
const int ndim = (int) shape.size();
std::vector<size_t> strides(ndim);
if (ndim) {
std::fill(strides.begin(), strides.end(), itemsize);
for (size_t i = 0; i < ndim - 1; i++)
for (size_t j = 0; j < ndim - 1 - i; j++)
strides[j] *= shape[ndim - 1 - i];
size_t cumprod = itemsize;
if(flags & c_style){
for(int i = ndim - 1; i >= 0; --i){
strides[i] = cumprod;
cumprod *= shape[i];
}
}else{
for(int i = 0; i < ndim; ++i){
strides[i] = cumprod;
cumprod *= shape[i];
}
}
}
return strides;
}
Expand All @@ -278,14 +462,59 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public

array_t(const buffer_info& info) : array(info) { }

array_t(const std::vector<size_t>& shape, const std::vector<size_t>& strides, T* ptr = nullptr)
: array(shape, strides, ptr) { }
array_t(const std::vector<size_t>& shape, const std::vector<size_t>& strides, T* ptr = nullptr, bool copy = true, int flags = 0)
: array(shape, strides, ptr, copy, flags) { }

array_t(const std::vector<size_t>& shape, T* ptr = nullptr)
: array(shape, ptr) { }
array_t(const std::vector<size_t>& shape, T* ptr = nullptr, bool copy = true, int flags = 0)
: array(shape, ptr, copy, flags) { }

array_t(size_t size, T* ptr = nullptr)
: array(size, ptr) { }
array_t(size_t size, T* ptr = nullptr, bool copy = true, int flags = 0)
: array(size, ptr, copy, flags) { }

T *data() {
return reinterpret_cast<T *>(PyArray_DATA((PyArrayObject *) m_ptr));
}
const T *data() const{
return reinterpret_cast<T *>(PyArray_DATA((PyArrayObject *) m_ptr));
}

template<typename... Args>
T* data_at(Args&&... indices){
const int nterms = sizeof...(indices);
const uint32_t _indices[] = { uint32_t(indices)... };
auto *strides = this->strides();
assert(nterms < this->ndims());
uint8_t *p = this->data();
for(int i = 0; i < nterms; ++i){
p += _indices[i] * strides[i];
}
return p;
}
template<typename... Args>
T* data_at(Args&&... indices) const{
return const_cast<array *>(this)->data_at(indices...);
}

template<typename... Args>
T* at(Args&&... indices){
assert(sizeof...(indices) == this->ndims());
return *this->data_at(indices...);
}
template<typename... Args>
const T* at(Args&&... indices) const{
assert(sizeof...(indices) == this->ndims());
return const_cast<array_t<T, ExtraFlags> *>(this)->at(indices...);
}

template<int N>
static array_t<T> empty(const std::array<int, N> &shape, int order = 'C'){
return empty(shape, dtype::of<T>(), order);
}

template<int N>
static array_t<T> zeros(const std::array<int, N> &shape, int order = 'C'){
return zeros(shape, dtype::of<T>(), order);
}

static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }

Expand Down
5 changes: 5 additions & 0 deletions tests/pybind11_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
#include "pybind11_tests.h"
#include "constructor_stats.h"

#include <Python.h>
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <numpy/ndarrayobject.h>

void init_ex_methods_and_attributes(py::module &);
void init_ex_python_types(py::module &);
void init_ex_operator_overloading(py::module &);
Expand Down Expand Up @@ -50,6 +54,7 @@ void bind_ConstructorStats(py::module &m) {
}

PYBIND11_PLUGIN(pybind11_tests) {
import_array(); //import numpy
py::module m("pybind11_tests", "pybind example plugin");

bind_ConstructorStats(m);
Expand Down
Loading