@@ -305,14 +305,16 @@ class dtype : public object {
305
305
306
306
class array : public buffer {
307
307
public:
308
- PYBIND11_OBJECT_DEFAULT (array, buffer, detail::npy_api::get().PyArray_Check_)
308
+ PYBIND11_OBJECT_CVT (array, buffer, detail::npy_api::get().PyArray_Check_, raw_array )
309
309
310
310
enum {
311
311
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
312
312
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
313
313
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
314
314
};
315
315
316
+ array () : array(0 , static_cast <const double *>(nullptr )) {}
317
+
316
318
array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
317
319
const std::vector<size_t > &strides, const void *ptr = nullptr ,
318
320
handle base = handle()) {
@@ -478,10 +480,12 @@ class array : public buffer {
478
480
}
479
481
480
482
// / Ensure that the argument is a NumPy array
481
- static array ensure (object input, int ExtraFlags = 0 ) {
482
- auto & api = detail::npy_api::get ();
483
- return reinterpret_steal<array>(api.PyArray_FromAny_ (
484
- input.release ().ptr (), nullptr , 0 , 0 , detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr ));
483
+ // / In case of an error, nullptr is returned and the Python error is cleared.
484
+ static array ensure (handle h, int ExtraFlags = 0 ) {
485
+ auto result = reinterpret_steal<array>(raw_array (h.ptr (), ExtraFlags));
486
+ if (!result)
487
+ PyErr_Clear ();
488
+ return result;
485
489
}
486
490
487
491
protected:
@@ -520,8 +524,6 @@ class array : public buffer {
520
524
return strides;
521
525
}
522
526
523
- protected:
524
-
525
527
template <typename ... Ix> void check_dimensions (Ix... index) const {
526
528
check_dimensions_impl (size_t (0 ), shape (), size_t (index )...);
527
529
}
@@ -536,15 +538,31 @@ class array : public buffer {
536
538
}
537
539
check_dimensions_impl (axis + 1 , shape + 1 , index ...);
538
540
}
541
+
542
+ // / Create array from any object -- always returns a new reference
543
+ static PyObject *raw_array (PyObject *ptr, int ExtraFlags = 0 ) {
544
+ if (ptr == nullptr )
545
+ return nullptr ;
546
+ return detail::npy_api::get ().PyArray_FromAny_ (
547
+ ptr, nullptr , 0 , 0 , detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr );
548
+ }
539
549
};
540
550
541
551
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
542
552
public:
543
- array_t () : array() { }
553
+ array_t () : array(0 , static_cast <const T *>(nullptr )) {}
554
+ array_t (handle h, borrowed_t ) : array(h, borrowed) { }
555
+ array_t (handle h, stolen_t ) : array(h, stolen) { }
544
556
545
- array_t (handle h, bool is_borrowed) : array(h, is_borrowed) { m_ptr = ensure_ (m_ptr); }
557
+ PYBIND11_DEPRECATED (" Use array_t<T>::ensure() instead" )
558
+ array_t (handle h, bool is_borrowed) : array(raw_array_t (h.ptr()), stolen) {
559
+ if (!m_ptr) PyErr_Clear ();
560
+ if (!is_borrowed) Py_XDECREF (h.ptr ());
561
+ }
546
562
547
- array_t (const object &o) : array(o) { m_ptr = ensure_ (m_ptr); }
563
+ array_t (const object &o) : array(raw_array_t (o.ptr()), stolen) {
564
+ if (!m_ptr) throw error_already_set ();
565
+ }
548
566
549
567
explicit array_t (const buffer_info& info) : array(info) { }
550
568
@@ -590,17 +608,30 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
590
608
return *(static_cast <T*>(array::mutable_data ()) + byte_offset (size_t (index )...) / itemsize ());
591
609
}
592
610
593
- static PyObject *ensure_ (PyObject *ptr) {
594
- if (ptr == nullptr )
595
- return nullptr ;
596
- auto & api = detail::npy_api::get ();
597
- PyObject *result = api.PyArray_FromAny_ (ptr, pybind11::dtype::of<T>().release ().ptr (), 0 , 0 ,
598
- detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr );
611
+ // / Ensure that the argument is a NumPy array of the correct dtype.
612
+ // / In case of an error, nullptr is returned and the Python error is cleared.
613
+ static array_t ensure (handle h) {
614
+ auto result = reinterpret_steal<array_t >(raw_array_t (h.ptr ()));
599
615
if (!result)
600
616
PyErr_Clear ();
601
- Py_DECREF (ptr);
602
617
return result;
603
618
}
619
+
620
+ static bool _check (handle h) {
621
+ const auto &api = detail::npy_api::get ();
622
+ return api.PyArray_Check_ (h.ptr ())
623
+ && api.PyArray_EquivTypes_ (PyArray_GET_ (h.ptr (), descr), dtype::of<T>().ptr ());
624
+ }
625
+
626
+ protected:
627
+ // / Create array from any object -- always returns a new reference
628
+ static PyObject *raw_array_t (PyObject *ptr) {
629
+ if (ptr == nullptr )
630
+ return nullptr ;
631
+ return detail::npy_api::get ().PyArray_FromAny_ (
632
+ ptr, dtype::of<T>().release ().ptr (), 0 , 0 ,
633
+ detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr );
634
+ }
604
635
};
605
636
606
637
template <typename T>
@@ -631,7 +662,7 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
631
662
using type = array_t <T, ExtraFlags>;
632
663
633
664
bool load (handle src, bool /* convert */ ) {
634
- value = type (src, true );
665
+ value = type::ensure (src);
635
666
return static_cast <bool >(value);
636
667
}
637
668
0 commit comments