1
+ use std:: any:: { type_name, TypeId } ;
1
2
use std:: mem:: size_of;
2
3
use std:: os:: raw:: { c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort} ;
3
4
5
+ use ndarray:: Dimension ;
4
6
use num_traits:: { Bounded , Zero } ;
5
- use pyo3:: { ffi, prelude:: * , pyobject_native_type_core, types:: PyType , AsPyPointer , PyNativeType } ;
7
+ use pyo3:: {
8
+ ffi, prelude:: * , pyobject_native_type_core, type_object:: PyTypeObject , types:: PyType ,
9
+ AsPyPointer , PyDowncastError , PyNativeType , PyTypeInfo ,
10
+ } ;
6
11
12
+ use crate :: array:: PyArray ;
7
13
use crate :: npyffi:: { NpyTypes , PyArray_Descr , NPY_TYPES , PY_ARRAY_API } ;
14
+ use crate :: NpySingleIterBuilder ;
8
15
9
16
pub use num_complex:: { Complex32 , Complex64 } ;
10
17
@@ -132,26 +139,12 @@ impl PyArrayDescr {
132
139
/// #[pyclass]
133
140
/// pub struct CustomElement;
134
141
///
135
- /// // The transparent wrapper is necessary as one cannot implement
136
- /// // a foreign trait (`Element`) on a foreign type (`Py`) directly.
137
- /// #[derive(Clone)]
138
- /// #[repr(transparent)]
139
- /// pub struct Wrapper(pub Py<CustomElement>);
140
- ///
141
- /// unsafe impl Element for Wrapper {
142
- /// const IS_COPY: bool = false;
143
- ///
144
- /// fn get_dtype(py: Python) -> &PyArrayDescr {
145
- /// PyArrayDescr::object(py)
146
- /// }
147
- /// }
148
- ///
149
142
/// Python::with_gil(|py| {
150
- /// let array = Array2::<Wrapper >::from_shape_fn((2, 3), |(_i, _j)| {
151
- /// Wrapper( Py::new(py, CustomElement).unwrap() )
143
+ /// let array = Array2::<Py<CustomElement> >::from_shape_fn((2, 3), |(_i, _j)| {
144
+ /// Py::new(py, CustomElement).unwrap()
152
145
/// });
153
146
///
154
- /// let _array: &PyArray<Wrapper , _> = array.to_pyarray(py);
147
+ /// let _array: &PyArray<Py<CustomElement> , _> = array.to_pyarray(py);
155
148
/// });
156
149
/// ```
157
150
pub unsafe trait Element : Clone + Send {
@@ -164,6 +157,9 @@ pub unsafe trait Element: Clone + Send {
164
157
/// that contain object-type fields.
165
158
const IS_COPY : bool ;
166
159
160
+ /// TODO
161
+ fn check_element_types < D : Dimension > ( py : Python , array : & PyArray < Self , D > ) -> PyResult < ( ) > ;
162
+
167
163
/// Returns the associated array descriptor ("dtype") for the given type.
168
164
fn get_dtype ( py : Python ) -> & PyArrayDescr ;
169
165
}
@@ -218,6 +214,12 @@ macro_rules! impl_element_scalar {
218
214
$( #[ $meta] ) *
219
215
unsafe impl Element for $ty {
220
216
const IS_COPY : bool = true ;
217
+
218
+ fn check_element_types<D : Dimension >( _py: Python , _array: & PyArray <Self , D >) -> PyResult <( ) > {
219
+ // For scalar types, checking the dtype is sufficient.
220
+ Ok ( ( ) )
221
+ }
222
+
221
223
fn get_dtype( py: Python ) -> & PyArrayDescr {
222
224
PyArrayDescr :: from_npy_type( py, $npy_type)
223
225
}
@@ -244,9 +246,30 @@ impl_element_scalar!(Complex64 => NPY_CDOUBLE,
244
246
#[ cfg( any( target_pointer_width = "32" , target_pointer_width = "64" ) ) ]
245
247
impl_element_scalar ! ( usize , isize ) ;
246
248
247
- unsafe impl Element for PyObject {
249
+ unsafe impl < T > Element for Py < T >
250
+ where
251
+ T : PyTypeInfo + ' static ,
252
+ {
248
253
const IS_COPY : bool = false ;
249
254
255
+ fn check_element_types < D : Dimension > ( py : Python , array : & PyArray < Self , D > ) -> PyResult < ( ) > {
256
+ // `PyAny` can represent any Python object.
257
+ if TypeId :: of :: < PyAny > ( ) == TypeId :: of :: < T > ( ) {
258
+ return Ok ( ( ) ) ;
259
+ }
260
+
261
+ let type_object = T :: type_object ( py) ;
262
+ let iterator = NpySingleIterBuilder :: readwrite ( array) . build ( ) ?;
263
+
264
+ for element in iterator {
265
+ if !type_object. is_instance ( element) ? {
266
+ return Err ( PyDowncastError :: new ( array, type_name :: < T > ( ) ) . into ( ) ) ;
267
+ }
268
+ }
269
+
270
+ Ok ( ( ) )
271
+ }
272
+
250
273
fn get_dtype ( py : Python ) -> & PyArrayDescr {
251
274
PyArrayDescr :: object ( py)
252
275
}
0 commit comments