diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a52bb31f..83daf1432 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ - Unreleased - Add dynamic borrow checking to safely construct references into the interior of NumPy arrays. ([#274](https://github.com/PyO3/rust-numpy/pull/274)) + - The deprecated iterator builders `NpySingleIterBuilder::{readonly,readwrite}` and `NpyMultiIterBuilder::add_{readonly,readwrite}` now take referencces to `PyReadonlyArray` and `PyReadwriteArray` instead of consuming them. + - The destructive `PyArray::resize` method is now unsafe if used without an instance of `PyReadwriteArray`. ([#302](https://github.com/PyO3/rust-numpy/pull/302)) - The `inner`, `dot` and `einsum` functions can also return a scalar instead of a zero-dimensional array to match NumPy's types ([#285](https://github.com/PyO3/rust-numpy/pull/285)) - Deprecate `PyArray::from_exact_iter` after optimizing `PyArray::from_iter`. ([#292](https://github.com/PyO3/rust-numpy/pull/292)) diff --git a/src/array.rs b/src/array.rs index 4ee5279ca..b5aed2b43 100644 --- a/src/array.rs +++ b/src/array.rs @@ -23,7 +23,7 @@ use crate::borrow::{PyReadonlyArray, PyReadwriteArray}; use crate::cold; use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; use crate::dtype::{Element, PyArrayDescr}; -use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError}; +use crate::error::{BorrowError, DimensionalityError, FromVecError, NotContiguousError, TypeError}; use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API}; use crate::slice_container::PySliceContainer; @@ -846,13 +846,33 @@ impl PyArray { } /// Get an immutable borrow of the NumPy array + pub fn try_readonly(&self) -> Result, BorrowError> { + PyReadonlyArray::try_new(self) + } + + /// Get an immutable borrow of the NumPy array + /// + /// # Panics + /// + /// Panics if the allocation backing the array is currently mutably borrowed. + /// For a non-panicking variant, use [`try_readonly`][Self::try_readonly]. pub fn readonly(&self) -> PyReadonlyArray<'_, T, D> { - PyReadonlyArray::try_new(self).unwrap() + self.try_readonly().unwrap() } /// Get a mutable borrow of the NumPy array + pub fn try_readwrite(&self) -> Result, BorrowError> { + PyReadwriteArray::try_new(self) + } + + /// Get a mutable borrow of the NumPy array + /// + /// # Panics + /// + /// Panics if the allocation backing the array is currently borrowed. + /// For a non-panicking variant, use [`try_readwrite`][Self::try_readwrite]. pub fn readwrite(&self) -> PyReadwriteArray<'_, T, D> { - PyReadwriteArray::try_new(self).unwrap() + self.try_readwrite().unwrap() } /// Returns the internal array as [`ArrayView`]. @@ -1057,19 +1077,30 @@ impl PyArray { data.into_pyarray(py) } - /// Extends or trancates the length of 1 dimension PyArray. + /// Extends or truncates the length of a one-dimensional array. + /// + /// # Safety + /// + /// There should be no outstanding references (shared or exclusive) into the array + /// as this method might re-allocate it and thereby invalidate all pointers into it. /// /// # Example + /// /// ``` /// use numpy::PyArray; - /// pyo3::Python::with_gil(|py| { + /// use pyo3::Python; + /// + /// Python::with_gil(|py| { /// let pyarray = PyArray::arange(py, 0, 10, 1); /// assert_eq!(pyarray.len(), 10); - /// pyarray.resize(100).unwrap(); + /// + /// unsafe { + /// pyarray.resize(100).unwrap(); + /// } /// assert_eq!(pyarray.len(), 100); /// }); /// ``` - pub fn resize(&self, new_elems: usize) -> PyResult<()> { + pub unsafe fn resize(&self, new_elems: usize) -> PyResult<()> { self.resize_(self.py(), [new_elems], 1, NPY_ORDER::NPY_ANYORDER) } diff --git a/src/borrow.rs b/src/borrow.rs index 76a13e364..1680c9d1b 100644 --- a/src/borrow.rs +++ b/src/borrow.rs @@ -170,6 +170,86 @@ impl BorrowFlags { unsafe fn get(&self) -> &mut HashMap { (*self.0.get()).get_or_insert_with(HashMap::new) } + + fn acquire(&self, array: &PyArray) -> Result<(), BorrowError> { + let address = base_address(array); + + // SAFETY: Access to `&PyArray` implies holding the GIL + // and we are not calling into user code which might re-enter this function. + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + + match borrow_flags.entry(address) { + Entry::Occupied(entry) => { + let readers = entry.into_mut(); + + let new_readers = readers.wrapping_add(1); + + if new_readers <= 0 { + cold(); + return Err(BorrowError::AlreadyBorrowed); + } + + *readers = new_readers; + } + Entry::Vacant(entry) => { + entry.insert(1); + } + } + + Ok(()) + } + + fn release(&self, array: &PyArray) { + let address = base_address(array); + + // SAFETY: Access to `&PyArray` implies holding the GIL + // and we are not calling into user code which might re-enter this function. + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + + let readers = borrow_flags.get_mut(&address).unwrap(); + + *readers -= 1; + + if *readers == 0 { + borrow_flags.remove(&address).unwrap(); + } + } + + fn acquire_mut(&self, array: &PyArray) -> Result<(), BorrowError> { + let address = base_address(array); + + // SAFETY: Access to `&PyArray` implies holding the GIL + // and we are not calling into user code which might re-enter this function. + let borrow_flags = unsafe { BORROW_FLAGS.get() }; + + match borrow_flags.entry(address) { + Entry::Occupied(entry) => { + let writers = entry.into_mut(); + + if *writers != 0 { + cold(); + return Err(BorrowError::AlreadyBorrowed); + } + + *writers = -1; + } + Entry::Vacant(entry) => { + entry.insert(-1); + } + } + + Ok(()) + } + + fn release_mut(&self, array: &PyArray) { + let address = base_address(array); + + // SAFETY: Access to `&PyArray` implies holding the GIL + // and we are not calling into user code which might re-enter this function. + let borrow_flags = unsafe { self.get() }; + + borrow_flags.remove(&address).unwrap(); + } } static BORROW_FLAGS: BorrowFlags = BorrowFlags::new(); @@ -224,29 +304,7 @@ where D: Dimension, { pub(crate) fn try_new(array: &'py PyArray) -> Result { - let address = base_address(array); - - // SAFETY: Access to a `&'py PyArray` implies holding the GIL - // and we are not calling into user code which might re-enter this function. - let borrow_flags = unsafe { BORROW_FLAGS.get() }; - - match borrow_flags.entry(address) { - Entry::Occupied(entry) => { - let readers = entry.into_mut(); - - let new_readers = readers.wrapping_add(1); - - if new_readers <= 0 { - cold(); - return Err(BorrowError::AlreadyBorrowed); - } - - *readers = new_readers; - } - Entry::Vacant(entry) => { - entry.insert(1); - } - } + BORROW_FLAGS.acquire(array)?; Ok(Self(array)) } @@ -275,21 +333,19 @@ where } } +impl<'a, T, D> Clone for PyReadonlyArray<'a, T, D> +where + T: Element, + D: Dimension, +{ + fn clone(&self) -> Self { + Self::try_new(self.0).unwrap() + } +} + impl<'a, T, D> Drop for PyReadonlyArray<'a, T, D> { fn drop(&mut self) { - let address = base_address(self.0); - - // SAFETY: Access to a `&'py PyArray` implies holding the GIL - // and we are not calling into user code which might re-enter this function. - let borrow_flags = unsafe { BORROW_FLAGS.get() }; - - let readers = borrow_flags.get_mut(&address).unwrap(); - - *readers -= 1; - - if *readers == 0 { - borrow_flags.remove(&address).unwrap(); - } + BORROW_FLAGS.release(self.0); } } @@ -348,27 +404,7 @@ where return Err(BorrowError::NotWriteable); } - let address = base_address(array); - - // SAFETY: Access to a `&'py PyArray` implies holding the GIL - // and we are not calling into user code which might re-enter this function. - let borrow_flags = unsafe { BORROW_FLAGS.get() }; - - match borrow_flags.entry(address) { - Entry::Occupied(entry) => { - let writers = entry.into_mut(); - - if *writers != 0 { - cold(); - return Err(BorrowError::AlreadyBorrowed); - } - - *writers = -1; - } - Entry::Vacant(entry) => { - entry.insert(-1); - } - } + BORROW_FLAGS.acquire_mut(array)?; Ok(Self(array)) } @@ -397,15 +433,44 @@ where } } -impl<'a, T, D> Drop for PyReadwriteArray<'a, T, D> { - fn drop(&mut self) { - let address = base_address(self.0); +impl<'py, T> PyReadwriteArray<'py, T, Ix1> +where + T: Element, +{ + /// Extends or truncates the length of a one-dimensional array. + /// + /// # Example + /// + /// ``` + /// use numpy::PyArray; + /// use pyo3::Python; + /// + /// Python::with_gil(|py| { + /// let pyarray = PyArray::arange(py, 0, 10, 1); + /// assert_eq!(pyarray.len(), 10); + /// + /// let pyarray = pyarray.readwrite(); + /// let pyarray = pyarray.resize(100).unwrap(); + /// assert_eq!(pyarray.len(), 100); + /// }); + /// ``` + pub fn resize(self, new_elems: usize) -> PyResult { + BORROW_FLAGS.release_mut(self.0); + + // SAFETY: Ownership of `self` proves exclusive access to the interior of the array. + unsafe { + self.0.resize(new_elems)?; + } - // SAFETY: Access to a `&'py PyArray` implies holding the GIL - // and we are not calling into user code which might re-enter this function. - let borrow_flags = unsafe { BORROW_FLAGS.get() }; + BORROW_FLAGS.acquire_mut(self.0)?; - borrow_flags.remove(&address).unwrap(); + Ok(self) + } +} + +impl<'a, T, D> Drop for PyReadwriteArray<'a, T, D> { + fn drop(&mut self) { + BORROW_FLAGS.release_mut(self.0); } } diff --git a/src/convert.rs b/src/convert.rs index cffc3f9ab..c107a1262 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -29,7 +29,9 @@ use crate::sealed::Sealed; /// assert_eq!(py_array.readonly().as_slice().unwrap(), &[1, 2, 3]); /// /// // Array cannot be resized when its data is owned by Rust. -/// assert!(py_array.resize(100).is_err()); +/// unsafe { +/// assert!(py_array.resize(100).is_err()); +/// } /// }); /// ``` pub trait IntoPyArray { diff --git a/tests/borrow.rs b/tests/borrow.rs index c090b5852..dac90b87d 100644 --- a/tests/borrow.rs +++ b/tests/borrow.rs @@ -115,6 +115,19 @@ fn borrows_span_threads() { }); } +#[test] +fn shared_borrows_can_be_cloned() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, (1, 2, 3), false); + + let shared1 = array.readonly(); + let shared2 = shared1.clone(); + + assert_eq!(shared2.shape(), [1, 2, 3]); + assert_eq!(shared1.shape(), [1, 2, 3]); + }); +} + #[test] #[should_panic(expected = "AlreadyBorrowed")] fn overlapping_views_conflict() { @@ -235,3 +248,17 @@ fn readwrite_as_array_slice() { assert_eq!(*array.get_mut([0, 1, 2]).unwrap(), 0.0); }); } + +#[test] +fn resize_using_exclusive_borrow() { + Python::with_gil(|py| { + let array = PyArray::::zeros(py, 3, false); + assert_eq!(array.shape(), [3]); + + let mut array = array.readwrite(); + assert_eq!(array.as_slice_mut().unwrap(), &[0.0; 3]); + + let mut array = array.resize(5).unwrap(); + assert_eq!(array.as_slice_mut().unwrap(), &[0.0; 5]); + }); +} diff --git a/tests/to_py.rs b/tests/to_py.rs index 4f31904a5..57392b4db 100644 --- a/tests/to_py.rs +++ b/tests/to_py.rs @@ -161,7 +161,9 @@ fn into_pyarray_cannot_resize() { Python::with_gil(|py| { let arr = vec![1, 2, 3].into_pyarray(py); - assert!(arr.resize(100).is_err()) + unsafe { + assert!(arr.resize(100).is_err()); + } }); }