Skip to content

Make iterators covariant in element type #1417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/data_repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,6 @@ impl<A> OwnedRepr<A>
self.ptr.as_ptr()
}

pub(crate) fn as_ptr_mut(&self) -> *mut A
{
self.ptr.as_ptr()
}

pub(crate) fn as_nonnull_mut(&mut self) -> NonNull<A>
{
self.ptr
Expand Down
18 changes: 10 additions & 8 deletions src/impl_owned_array.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use core::ptr::NonNull;
use std::mem;
use std::mem::MaybeUninit;

#[allow(unused_imports)]
#[allow(unused_imports)] // Needed for Rust 1.64
use rawpointer::PointerExt;

use crate::imp_prelude::*;
Expand Down Expand Up @@ -435,7 +436,7 @@ where D: Dimension
// "deconstruct" self; the owned repr releases ownership of all elements and we
// carry on with raw view methods
let data_len = self.data.len();
let data_ptr = self.data.as_nonnull_mut().as_ptr();
let data_ptr = self.data.as_nonnull_mut();

unsafe {
// Safety: self.data releases ownership of the elements. Any panics below this point
Expand Down Expand Up @@ -866,8 +867,9 @@ where D: Dimension
///
/// This is an internal function for use by move_into and IntoIter only, safety invariants may need
/// to be upheld across the calls from those implementations.
pub(crate) unsafe fn drop_unreachable_raw<A, D>(mut self_: RawArrayViewMut<A, D>, data_ptr: *mut A, data_len: usize)
where D: Dimension
pub(crate) unsafe fn drop_unreachable_raw<A, D>(
mut self_: RawArrayViewMut<A, D>, data_ptr: NonNull<A>, data_len: usize,
) where D: Dimension
{
let self_len = self_.len();

Expand All @@ -878,7 +880,7 @@ where D: Dimension
}
sort_axes_in_default_order(&mut self_);
// with uninverted axes this is now the element with lowest address
let array_memory_head_ptr = self_.ptr.as_ptr();
let array_memory_head_ptr = self_.ptr;
let data_end_ptr = data_ptr.add(data_len);
debug_assert!(data_ptr <= array_memory_head_ptr);
debug_assert!(array_memory_head_ptr <= data_end_ptr);
Expand Down Expand Up @@ -907,7 +909,7 @@ where D: Dimension

// iter is a raw pointer iterator traversing the array in memory order now with the
// sorted axes.
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
let mut iter = Baseiter::new(self_.ptr, self_.dim, self_.strides);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review note: it's evident that Baseiter is constructed from a NonNull everywhere, which means that its non-null requirement is easily fulfilled.

let mut dropped_elements = 0;

let mut last_ptr = data_ptr;
Expand All @@ -917,7 +919,7 @@ where D: Dimension
// should now be dropped. This interval may be empty, then we just skip this loop.
while last_ptr != elem_ptr {
debug_assert!(last_ptr < data_end_ptr);
std::ptr::drop_in_place(last_ptr);
std::ptr::drop_in_place(last_ptr.as_mut());
last_ptr = last_ptr.add(1);
dropped_elements += 1;
}
Expand All @@ -926,7 +928,7 @@ where D: Dimension
}

while last_ptr < data_end_ptr {
std::ptr::drop_in_place(last_ptr);
std::ptr::drop_in_place(last_ptr.as_mut());
last_ptr = last_ptr.add(1);
dropped_elements += 1;
}
Expand Down
8 changes: 4 additions & 4 deletions src/impl_views/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}
}

Expand All @@ -209,7 +209,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}
}

Expand All @@ -220,7 +220,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}

#[inline]
Expand Down Expand Up @@ -262,7 +262,7 @@ where D: Dimension
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
{
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
}

#[inline]
Expand Down
4 changes: 2 additions & 2 deletions src/iterators/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ impl_iterator! {

fn item(&mut self, ptr) {
unsafe {
ArrayView::new_(
ArrayView::new(
ptr,
self.chunk.clone(),
self.inner_strides.clone())
Expand All @@ -226,7 +226,7 @@ impl_iterator! {

fn item(&mut self, ptr) {
unsafe {
ArrayViewMut::new_(
ArrayViewMut::new(
ptr,
self.chunk.clone(),
self.inner_strides.clone())
Expand Down
9 changes: 4 additions & 5 deletions src/iterators/into_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@ impl<A, D> IntoIter<A, D>
where D: Dimension
{
/// Create a new by-value iterator that consumes `array`
pub(crate) fn new(mut array: Array<A, D>) -> Self
pub(crate) fn new(array: Array<A, D>) -> Self
{
unsafe {
let array_head_ptr = array.ptr;
let ptr = array.as_mut_ptr();
let mut array_data = array.data;
let data_len = array_data.release_all_elements();
debug_assert!(data_len >= array.dim.size());
let has_unreachable_elements = array.dim.size() != data_len;
let inner = Baseiter::new(ptr, array.dim, array.strides);
let inner = Baseiter::new(array_head_ptr, array.dim, array.strides);

IntoIter {
array_data,
Expand All @@ -62,7 +61,7 @@ impl<A, D: Dimension> Iterator for IntoIter<A, D>
#[inline]
fn next(&mut self) -> Option<A>
{
self.inner.next().map(|p| unsafe { p.read() })
self.inner.next().map(|p| unsafe { p.as_ptr().read() })
}

fn size_hint(&self) -> (usize, Option<usize>)
Expand Down Expand Up @@ -92,7 +91,7 @@ where D: Dimension
while let Some(_) = self.next() {}

unsafe {
let data_ptr = self.array_data.as_ptr_mut();
let data_ptr = self.array_data.as_nonnull_mut();
let view = RawArrayViewMut::new(self.array_head_ptr, self.inner.dim.clone(), self.inner.strides.clone());
debug_assert!(self.inner.dim.size() < self.data_len, "data_len {} and dim size {}",
self.data_len, self.inner.dim.size());
Expand Down
56 changes: 33 additions & 23 deletions src/iterators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ use alloc::vec::Vec;
use std::iter::FromIterator;
use std::marker::PhantomData;
use std::ptr;
use std::ptr::NonNull;

#[allow(unused_imports)] // Needed for Rust 1.64
use rawpointer::PointerExt;

use crate::Ix1;

Expand All @@ -34,11 +38,11 @@ use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};

/// Base for iterators over all axes.
///
/// Iterator element type is `*mut A`.
/// Iterator element type is `NonNull<A>`.
#[derive(Debug)]
pub struct Baseiter<A, D>
{
ptr: *mut A,
ptr: NonNull<A>,
dim: D,
strides: D,
index: Option<D>,
Expand All @@ -50,7 +54,7 @@ impl<A, D: Dimension> Baseiter<A, D>
/// to be correct to avoid performing an unsafe pointer offset while
/// iterating.
#[inline]
pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter<A, D>
pub unsafe fn new(ptr: NonNull<A>, len: D, stride: D) -> Baseiter<A, D>
{
Baseiter {
ptr,
Expand All @@ -63,10 +67,10 @@ impl<A, D: Dimension> Baseiter<A, D>

impl<A, D: Dimension> Iterator for Baseiter<A, D>
{
type Item = *mut A;
type Item = NonNull<A>;

#[inline]
fn next(&mut self) -> Option<*mut A>
fn next(&mut self) -> Option<Self::Item>
{
let index = match self.index {
None => return None,
Expand All @@ -84,7 +88,7 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D>
}

fn fold<Acc, G>(mut self, init: Acc, mut g: G) -> Acc
where G: FnMut(Acc, *mut A) -> Acc
where G: FnMut(Acc, Self::Item) -> Acc
{
let ndim = self.dim.ndim();
debug_assert_ne!(ndim, 0);
Expand Down Expand Up @@ -133,28 +137,28 @@ impl<A, D: Dimension> ExactSizeIterator for Baseiter<A, D>
impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
{
#[inline]
fn next_back(&mut self) -> Option<*mut A>
fn next_back(&mut self) -> Option<Self::Item>
{
let index = match self.index {
None => return None,
Some(ix) => ix,
};
self.dim[0] -= 1;
let offset = <_>::stride_offset(&self.dim, &self.strides);
let offset = Ix1::stride_offset(&self.dim, &self.strides);
if index == self.dim {
self.index = None;
}

unsafe { Some(self.ptr.offset(offset)) }
}

fn nth_back(&mut self, n: usize) -> Option<*mut A>
fn nth_back(&mut self, n: usize) -> Option<Self::Item>
{
let index = self.index?;
let len = self.dim[0] - index[0];
if n < len {
self.dim[0] -= n + 1;
let offset = <_>::stride_offset(&self.dim, &self.strides);
let offset = Ix1::stride_offset(&self.dim, &self.strides);
if index == self.dim {
self.index = None;
}
Expand All @@ -166,7 +170,7 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
}

fn rfold<Acc, G>(mut self, init: Acc, mut g: G) -> Acc
where G: FnMut(Acc, *mut A) -> Acc
where G: FnMut(Acc, Self::Item) -> Acc
{
let mut accum = init;
if let Some(index) = self.index {
Expand Down Expand Up @@ -226,7 +230,7 @@ impl<'a, A, D: Dimension> Iterator for ElementsBase<'a, A, D>
#[inline]
fn next(&mut self) -> Option<&'a A>
{
self.inner.next().map(|p| unsafe { &*p })
self.inner.next().map(|p| unsafe { p.as_ref() })
}

fn size_hint(&self) -> (usize, Option<usize>)
Expand All @@ -237,7 +241,7 @@ impl<'a, A, D: Dimension> Iterator for ElementsBase<'a, A, D>
fn fold<Acc, G>(self, init: Acc, mut g: G) -> Acc
where G: FnMut(Acc, Self::Item) -> Acc
{
unsafe { self.inner.fold(init, move |acc, ptr| g(acc, &*ptr)) }
unsafe { self.inner.fold(init, move |acc, ptr| g(acc, ptr.as_ref())) }
}
}

Expand All @@ -246,13 +250,13 @@ impl<'a, A> DoubleEndedIterator for ElementsBase<'a, A, Ix1>
#[inline]
fn next_back(&mut self) -> Option<&'a A>
{
self.inner.next_back().map(|p| unsafe { &*p })
self.inner.next_back().map(|p| unsafe { p.as_ref() })
}

fn rfold<Acc, G>(self, init: Acc, mut g: G) -> Acc
where G: FnMut(Acc, Self::Item) -> Acc
{
unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, &*ptr)) }
unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, ptr.as_ref())) }
}
}

Expand Down Expand Up @@ -646,7 +650,7 @@ impl<'a, A, D: Dimension> Iterator for ElementsBaseMut<'a, A, D>
#[inline]
fn next(&mut self) -> Option<&'a mut A>
{
self.inner.next().map(|p| unsafe { &mut *p })
self.inner.next().map(|mut p| unsafe { p.as_mut() })
}

fn size_hint(&self) -> (usize, Option<usize>)
Expand All @@ -657,7 +661,10 @@ impl<'a, A, D: Dimension> Iterator for ElementsBaseMut<'a, A, D>
fn fold<Acc, G>(self, init: Acc, mut g: G) -> Acc
where G: FnMut(Acc, Self::Item) -> Acc
{
unsafe { self.inner.fold(init, move |acc, ptr| g(acc, &mut *ptr)) }
unsafe {
self.inner
.fold(init, move |acc, mut ptr| g(acc, ptr.as_mut()))
}
}
}

Expand All @@ -666,13 +673,16 @@ impl<'a, A> DoubleEndedIterator for ElementsBaseMut<'a, A, Ix1>
#[inline]
fn next_back(&mut self) -> Option<&'a mut A>
{
self.inner.next_back().map(|p| unsafe { &mut *p })
self.inner.next_back().map(|mut p| unsafe { p.as_mut() })
}

fn rfold<Acc, G>(self, init: Acc, mut g: G) -> Acc
where G: FnMut(Acc, Self::Item) -> Acc
{
unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, &mut *ptr)) }
unsafe {
self.inner
.rfold(init, move |acc, mut ptr| g(acc, ptr.as_mut()))
}
}
}

Expand Down Expand Up @@ -748,7 +758,7 @@ where D: Dimension
{
self.iter
.next()
.map(|ptr| unsafe { ArrayView::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
.map(|ptr| unsafe { ArrayView::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
}

fn size_hint(&self) -> (usize, Option<usize>)
Expand All @@ -772,7 +782,7 @@ impl<'a, A> DoubleEndedIterator for LanesIter<'a, A, Ix1>
{
self.iter
.next_back()
.map(|ptr| unsafe { ArrayView::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
.map(|ptr| unsafe { ArrayView::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
}
}

Expand Down Expand Up @@ -800,7 +810,7 @@ where D: Dimension
{
self.iter
.next()
.map(|ptr| unsafe { ArrayViewMut::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
.map(|ptr| unsafe { ArrayViewMut::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
}

fn size_hint(&self) -> (usize, Option<usize>)
Expand All @@ -824,7 +834,7 @@ impl<'a, A> DoubleEndedIterator for LanesIterMut<'a, A, Ix1>
{
self.iter
.next_back()
.map(|ptr| unsafe { ArrayViewMut::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
.map(|ptr| unsafe { ArrayViewMut::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/iterators/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl_iterator! {

fn item(&mut self, ptr) {
unsafe {
ArrayView::new_(
ArrayView::new(
ptr,
self.window.clone(),
self.strides.clone())
Expand Down
Loading