Skip to content

Commit febacd6

Browse files
committed
WIP: Check dynamic type of individual elements to enable arrays holding Py<T> with custom types T.
1 parent 4021797 commit febacd6

File tree

2 files changed

+53
-27
lines changed

2 files changed

+53
-27
lines changed

src/array.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,9 @@ impl<T, D> IntoPy<PyObject> for PyArray<T, D> {
127127
impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
128128
// here we do type-check three times
129129
// 1. Checks if the object is PyArray
130-
// 2. Checks if the data type of the array is T
131-
// 3. Checks if the dimension is same as D
130+
// 2. Checks if the dimension is same as D
131+
// 3. Checks if the data type of the array is T
132+
// 4. Optionally checks if the elements of the array match T
132133
fn extract(ob: &'a PyAny) -> PyResult<Self> {
133134
let array = unsafe {
134135
if npyffi::PyArray_Check(ob.as_ptr()) == 0 {
@@ -137,19 +138,21 @@ impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
137138
&*(ob as *const PyAny as *const PyArray<T, D>)
138139
};
139140

140-
let src_dtype = array.dtype();
141-
let dst_dtype = T::get_dtype(ob.py());
142-
if !src_dtype.is_equiv_to(dst_dtype) {
143-
return Err(TypeError::new(src_dtype, dst_dtype).into());
144-
}
145-
146141
let src_ndim = array.shape().len();
147142
if let Some(dst_ndim) = D::NDIM {
148143
if src_ndim != dst_ndim {
149144
return Err(DimensionalityError::new(src_ndim, dst_ndim).into());
150145
}
151146
}
152147

148+
let src_dtype = array.dtype();
149+
let dst_dtype = T::get_dtype(ob.py());
150+
if !src_dtype.is_equiv_to(dst_dtype) {
151+
return Err(TypeError::new(src_dtype, dst_dtype).into());
152+
}
153+
154+
T::check_element_types(ob.py(), array)?;
155+
153156
Ok(array)
154157
}
155158
}

src/dtype.rs

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
1+
use std::any::{type_name, TypeId};
12
use std::mem::size_of;
23
use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort};
34

5+
use ndarray::Dimension;
46
use num_traits::{Bounded, Zero};
5-
use pyo3::{ffi, prelude::*, pyobject_native_type_core, types::PyType, AsPyPointer, PyNativeType};
7+
use pyo3::{
8+
ffi, prelude::*, pyobject_native_type_core, type_object::PyTypeObject, types::PyType,
9+
AsPyPointer, PyDowncastError, PyNativeType, PyTypeInfo,
10+
};
611

12+
use crate::array::PyArray;
713
use crate::npyffi::{NpyTypes, PyArray_Descr, NPY_TYPES, PY_ARRAY_API};
14+
use crate::NpySingleIterBuilder;
815

916
pub use num_complex::{Complex32, Complex64};
1017

@@ -132,26 +139,12 @@ impl PyArrayDescr {
132139
/// #[pyclass]
133140
/// pub struct CustomElement;
134141
///
135-
/// // The transparent wrapper is necessary as one cannot implement
136-
/// // a foreign trait (`Element`) on a foreign type (`Py`) directly.
137-
/// #[derive(Clone)]
138-
/// #[repr(transparent)]
139-
/// pub struct Wrapper(pub Py<CustomElement>);
140-
///
141-
/// unsafe impl Element for Wrapper {
142-
/// const IS_COPY: bool = false;
143-
///
144-
/// fn get_dtype(py: Python) -> &PyArrayDescr {
145-
/// PyArrayDescr::object(py)
146-
/// }
147-
/// }
148-
///
149142
/// Python::with_gil(|py| {
150-
/// let array = Array2::<Wrapper>::from_shape_fn((2, 3), |(_i, _j)| {
151-
/// Wrapper(Py::new(py, CustomElement).unwrap())
143+
/// let array = Array2::<Py<CustomElement>>::from_shape_fn((2, 3), |(_i, _j)| {
144+
/// Py::new(py, CustomElement).unwrap()
152145
/// });
153146
///
154-
/// let _array: &PyArray<Wrapper, _> = array.to_pyarray(py);
147+
/// let _array: &PyArray<Py<CustomElement>, _> = array.to_pyarray(py);
155148
/// });
156149
/// ```
157150
pub unsafe trait Element: Clone + Send {
@@ -164,6 +157,9 @@ pub unsafe trait Element: Clone + Send {
164157
/// that contain object-type fields.
165158
const IS_COPY: bool;
166159

160+
/// TODO
161+
fn check_element_types<D: Dimension>(py: Python, array: &PyArray<Self, D>) -> PyResult<()>;
162+
167163
/// Returns the associated array descriptor ("dtype") for the given type.
168164
fn get_dtype(py: Python) -> &PyArrayDescr;
169165
}
@@ -218,6 +214,12 @@ macro_rules! impl_element_scalar {
218214
$(#[$meta])*
219215
unsafe impl Element for $ty {
220216
const IS_COPY: bool = true;
217+
218+
fn check_element_types<D: Dimension>(_py: Python, _array: &PyArray<Self, D>) -> PyResult<()> {
219+
// For scalar types, checking the dtype is sufficient.
220+
Ok(())
221+
}
222+
221223
fn get_dtype(py: Python) -> &PyArrayDescr {
222224
PyArrayDescr::from_npy_type(py, $npy_type)
223225
}
@@ -244,9 +246,30 @@ impl_element_scalar!(Complex64 => NPY_CDOUBLE,
244246
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
245247
impl_element_scalar!(usize, isize);
246248

247-
unsafe impl Element for PyObject {
249+
unsafe impl<T> Element for Py<T>
250+
where
251+
T: PyTypeInfo + 'static,
252+
{
248253
const IS_COPY: bool = false;
249254

255+
fn check_element_types<D: Dimension>(py: Python, array: &PyArray<Self, D>) -> PyResult<()> {
256+
// `PyAny` can represent any Python object.
257+
if TypeId::of::<PyAny>() == TypeId::of::<T>() {
258+
return Ok(());
259+
}
260+
261+
let type_object = T::type_object(py);
262+
let iterator = NpySingleIterBuilder::readwrite(array).build()?;
263+
264+
for element in iterator {
265+
if !type_object.is_instance(element)? {
266+
return Err(PyDowncastError::new(array, type_name::<T>()).into());
267+
}
268+
}
269+
270+
Ok(())
271+
}
272+
250273
fn get_dtype(py: Python) -> &PyArrayDescr {
251274
PyArrayDescr::object(py)
252275
}

0 commit comments

Comments
 (0)