Skip to content

Commit f7e4750

Browse files
authored
Merge pull request #1417 from rust-ndarray/baseiter-covariant
Make iterators covariant in element type
2 parents 84fe611 + 00e1546 commit f7e4750

File tree

8 files changed

+85
-51
lines changed

8 files changed

+85
-51
lines changed

src/data_repr.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,6 @@ impl<A> OwnedRepr<A>
5959
self.ptr.as_ptr()
6060
}
6161

62-
pub(crate) fn as_ptr_mut(&self) -> *mut A
63-
{
64-
self.ptr.as_ptr()
65-
}
66-
6762
pub(crate) fn as_nonnull_mut(&mut self) -> NonNull<A>
6863
{
6964
self.ptr

src/impl_owned_array.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#[cfg(not(feature = "std"))]
22
use alloc::vec::Vec;
3+
use core::ptr::NonNull;
34
use std::mem;
45
use std::mem::MaybeUninit;
56

6-
#[allow(unused_imports)]
7+
#[allow(unused_imports)] // Needed for Rust 1.64
78
use rawpointer::PointerExt;
89

910
use crate::imp_prelude::*;
@@ -435,7 +436,7 @@ where D: Dimension
435436
// "deconstruct" self; the owned repr releases ownership of all elements and we
436437
// carry on with raw view methods
437438
let data_len = self.data.len();
438-
let data_ptr = self.data.as_nonnull_mut().as_ptr();
439+
let data_ptr = self.data.as_nonnull_mut();
439440

440441
unsafe {
441442
// Safety: self.data releases ownership of the elements. Any panics below this point
@@ -866,8 +867,9 @@ where D: Dimension
866867
///
867868
/// This is an internal function for use by move_into and IntoIter only, safety invariants may need
868869
/// to be upheld across the calls from those implementations.
869-
pub(crate) unsafe fn drop_unreachable_raw<A, D>(mut self_: RawArrayViewMut<A, D>, data_ptr: *mut A, data_len: usize)
870-
where D: Dimension
870+
pub(crate) unsafe fn drop_unreachable_raw<A, D>(
871+
mut self_: RawArrayViewMut<A, D>, data_ptr: NonNull<A>, data_len: usize,
872+
) where D: Dimension
871873
{
872874
let self_len = self_.len();
873875

@@ -878,7 +880,7 @@ where D: Dimension
878880
}
879881
sort_axes_in_default_order(&mut self_);
880882
// with uninverted axes this is now the element with lowest address
881-
let array_memory_head_ptr = self_.ptr.as_ptr();
883+
let array_memory_head_ptr = self_.ptr;
882884
let data_end_ptr = data_ptr.add(data_len);
883885
debug_assert!(data_ptr <= array_memory_head_ptr);
884886
debug_assert!(array_memory_head_ptr <= data_end_ptr);
@@ -907,7 +909,7 @@ where D: Dimension
907909

908910
// iter is a raw pointer iterator traversing the array in memory order now with the
909911
// sorted axes.
910-
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
912+
let mut iter = Baseiter::new(self_.ptr, self_.dim, self_.strides);
911913
let mut dropped_elements = 0;
912914

913915
let mut last_ptr = data_ptr;
@@ -917,7 +919,7 @@ where D: Dimension
917919
// should now be dropped. This interval may be empty, then we just skip this loop.
918920
while last_ptr != elem_ptr {
919921
debug_assert!(last_ptr < data_end_ptr);
920-
std::ptr::drop_in_place(last_ptr);
922+
std::ptr::drop_in_place(last_ptr.as_mut());
921923
last_ptr = last_ptr.add(1);
922924
dropped_elements += 1;
923925
}
@@ -926,7 +928,7 @@ where D: Dimension
926928
}
927929

928930
while last_ptr < data_end_ptr {
929-
std::ptr::drop_in_place(last_ptr);
931+
std::ptr::drop_in_place(last_ptr.as_mut());
930932
last_ptr = last_ptr.add(1);
931933
dropped_elements += 1;
932934
}

src/impl_views/conversions.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ where D: Dimension
199199
#[inline]
200200
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
201201
{
202-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
202+
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
203203
}
204204
}
205205

@@ -209,7 +209,7 @@ where D: Dimension
209209
#[inline]
210210
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
211211
{
212-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
212+
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
213213
}
214214
}
215215

@@ -220,7 +220,7 @@ where D: Dimension
220220
#[inline]
221221
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
222222
{
223-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
223+
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
224224
}
225225

226226
#[inline]
@@ -262,7 +262,7 @@ where D: Dimension
262262
#[inline]
263263
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
264264
{
265-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
265+
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
266266
}
267267

268268
#[inline]

src/iterators/chunks.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ impl_iterator! {
204204

205205
fn item(&mut self, ptr) {
206206
unsafe {
207-
ArrayView::new_(
207+
ArrayView::new(
208208
ptr,
209209
self.chunk.clone(),
210210
self.inner_strides.clone())
@@ -226,7 +226,7 @@ impl_iterator! {
226226

227227
fn item(&mut self, ptr) {
228228
unsafe {
229-
ArrayViewMut::new_(
229+
ArrayViewMut::new(
230230
ptr,
231231
self.chunk.clone(),
232232
self.inner_strides.clone())

src/iterators/into_iter.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,15 @@ impl<A, D> IntoIter<A, D>
3333
where D: Dimension
3434
{
3535
/// Create a new by-value iterator that consumes `array`
36-
pub(crate) fn new(mut array: Array<A, D>) -> Self
36+
pub(crate) fn new(array: Array<A, D>) -> Self
3737
{
3838
unsafe {
3939
let array_head_ptr = array.ptr;
40-
let ptr = array.as_mut_ptr();
4140
let mut array_data = array.data;
4241
let data_len = array_data.release_all_elements();
4342
debug_assert!(data_len >= array.dim.size());
4443
let has_unreachable_elements = array.dim.size() != data_len;
45-
let inner = Baseiter::new(ptr, array.dim, array.strides);
44+
let inner = Baseiter::new(array_head_ptr, array.dim, array.strides);
4645

4746
IntoIter {
4847
array_data,
@@ -62,7 +61,7 @@ impl<A, D: Dimension> Iterator for IntoIter<A, D>
6261
#[inline]
6362
fn next(&mut self) -> Option<A>
6463
{
65-
self.inner.next().map(|p| unsafe { p.read() })
64+
self.inner.next().map(|p| unsafe { p.as_ptr().read() })
6665
}
6766

6867
fn size_hint(&self) -> (usize, Option<usize>)
@@ -92,7 +91,7 @@ where D: Dimension
9291
while let Some(_) = self.next() {}
9392

9493
unsafe {
95-
let data_ptr = self.array_data.as_ptr_mut();
94+
let data_ptr = self.array_data.as_nonnull_mut();
9695
let view = RawArrayViewMut::new(self.array_head_ptr, self.inner.dim.clone(), self.inner.strides.clone());
9796
debug_assert!(self.inner.dim.size() < self.data_len, "data_len {} and dim size {}",
9897
self.data_len, self.inner.dim.size());

src/iterators/mod.rs

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ use alloc::vec::Vec;
1919
use std::iter::FromIterator;
2020
use std::marker::PhantomData;
2121
use std::ptr;
22+
use std::ptr::NonNull;
23+
24+
#[allow(unused_imports)] // Needed for Rust 1.64
25+
use rawpointer::PointerExt;
2226

2327
use crate::Ix1;
2428

@@ -34,11 +38,11 @@ use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
3438

3539
/// Base for iterators over all axes.
3640
///
37-
/// Iterator element type is `*mut A`.
41+
/// Iterator element type is `NonNull<A>`.
3842
#[derive(Debug)]
3943
pub struct Baseiter<A, D>
4044
{
41-
ptr: *mut A,
45+
ptr: NonNull<A>,
4246
dim: D,
4347
strides: D,
4448
index: Option<D>,
@@ -50,7 +54,7 @@ impl<A, D: Dimension> Baseiter<A, D>
5054
/// to be correct to avoid performing an unsafe pointer offset while
5155
/// iterating.
5256
#[inline]
53-
pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter<A, D>
57+
pub unsafe fn new(ptr: NonNull<A>, len: D, stride: D) -> Baseiter<A, D>
5458
{
5559
Baseiter {
5660
ptr,
@@ -63,10 +67,10 @@ impl<A, D: Dimension> Baseiter<A, D>
6367

6468
impl<A, D: Dimension> Iterator for Baseiter<A, D>
6569
{
66-
type Item = *mut A;
70+
type Item = NonNull<A>;
6771

6872
#[inline]
69-
fn next(&mut self) -> Option<*mut A>
73+
fn next(&mut self) -> Option<Self::Item>
7074
{
7175
let index = match self.index {
7276
None => return None,
@@ -84,7 +88,7 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D>
8488
}
8589

8690
fn fold<Acc, G>(mut self, init: Acc, mut g: G) -> Acc
87-
where G: FnMut(Acc, *mut A) -> Acc
91+
where G: FnMut(Acc, Self::Item) -> Acc
8892
{
8993
let ndim = self.dim.ndim();
9094
debug_assert_ne!(ndim, 0);
@@ -133,28 +137,28 @@ impl<A, D: Dimension> ExactSizeIterator for Baseiter<A, D>
133137
impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
134138
{
135139
#[inline]
136-
fn next_back(&mut self) -> Option<*mut A>
140+
fn next_back(&mut self) -> Option<Self::Item>
137141
{
138142
let index = match self.index {
139143
None => return None,
140144
Some(ix) => ix,
141145
};
142146
self.dim[0] -= 1;
143-
let offset = <_>::stride_offset(&self.dim, &self.strides);
147+
let offset = Ix1::stride_offset(&self.dim, &self.strides);
144148
if index == self.dim {
145149
self.index = None;
146150
}
147151

148152
unsafe { Some(self.ptr.offset(offset)) }
149153
}
150154

151-
fn nth_back(&mut self, n: usize) -> Option<*mut A>
155+
fn nth_back(&mut self, n: usize) -> Option<Self::Item>
152156
{
153157
let index = self.index?;
154158
let len = self.dim[0] - index[0];
155159
if n < len {
156160
self.dim[0] -= n + 1;
157-
let offset = <_>::stride_offset(&self.dim, &self.strides);
161+
let offset = Ix1::stride_offset(&self.dim, &self.strides);
158162
if index == self.dim {
159163
self.index = None;
160164
}
@@ -166,7 +170,7 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
166170
}
167171

168172
fn rfold<Acc, G>(mut self, init: Acc, mut g: G) -> Acc
169-
where G: FnMut(Acc, *mut A) -> Acc
173+
where G: FnMut(Acc, Self::Item) -> Acc
170174
{
171175
let mut accum = init;
172176
if let Some(index) = self.index {
@@ -226,7 +230,7 @@ impl<'a, A, D: Dimension> Iterator for ElementsBase<'a, A, D>
226230
#[inline]
227231
fn next(&mut self) -> Option<&'a A>
228232
{
229-
self.inner.next().map(|p| unsafe { &*p })
233+
self.inner.next().map(|p| unsafe { p.as_ref() })
230234
}
231235

232236
fn size_hint(&self) -> (usize, Option<usize>)
@@ -237,7 +241,7 @@ impl<'a, A, D: Dimension> Iterator for ElementsBase<'a, A, D>
237241
fn fold<Acc, G>(self, init: Acc, mut g: G) -> Acc
238242
where G: FnMut(Acc, Self::Item) -> Acc
239243
{
240-
unsafe { self.inner.fold(init, move |acc, ptr| g(acc, &*ptr)) }
244+
unsafe { self.inner.fold(init, move |acc, ptr| g(acc, ptr.as_ref())) }
241245
}
242246
}
243247

@@ -246,13 +250,13 @@ impl<'a, A> DoubleEndedIterator for ElementsBase<'a, A, Ix1>
246250
#[inline]
247251
fn next_back(&mut self) -> Option<&'a A>
248252
{
249-
self.inner.next_back().map(|p| unsafe { &*p })
253+
self.inner.next_back().map(|p| unsafe { p.as_ref() })
250254
}
251255

252256
fn rfold<Acc, G>(self, init: Acc, mut g: G) -> Acc
253257
where G: FnMut(Acc, Self::Item) -> Acc
254258
{
255-
unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, &*ptr)) }
259+
unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, ptr.as_ref())) }
256260
}
257261
}
258262

@@ -646,7 +650,7 @@ impl<'a, A, D: Dimension> Iterator for ElementsBaseMut<'a, A, D>
646650
#[inline]
647651
fn next(&mut self) -> Option<&'a mut A>
648652
{
649-
self.inner.next().map(|p| unsafe { &mut *p })
653+
self.inner.next().map(|mut p| unsafe { p.as_mut() })
650654
}
651655

652656
fn size_hint(&self) -> (usize, Option<usize>)
@@ -657,7 +661,10 @@ impl<'a, A, D: Dimension> Iterator for ElementsBaseMut<'a, A, D>
657661
fn fold<Acc, G>(self, init: Acc, mut g: G) -> Acc
658662
where G: FnMut(Acc, Self::Item) -> Acc
659663
{
660-
unsafe { self.inner.fold(init, move |acc, ptr| g(acc, &mut *ptr)) }
664+
unsafe {
665+
self.inner
666+
.fold(init, move |acc, mut ptr| g(acc, ptr.as_mut()))
667+
}
661668
}
662669
}
663670

@@ -666,13 +673,16 @@ impl<'a, A> DoubleEndedIterator for ElementsBaseMut<'a, A, Ix1>
666673
#[inline]
667674
fn next_back(&mut self) -> Option<&'a mut A>
668675
{
669-
self.inner.next_back().map(|p| unsafe { &mut *p })
676+
self.inner.next_back().map(|mut p| unsafe { p.as_mut() })
670677
}
671678

672679
fn rfold<Acc, G>(self, init: Acc, mut g: G) -> Acc
673680
where G: FnMut(Acc, Self::Item) -> Acc
674681
{
675-
unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, &mut *ptr)) }
682+
unsafe {
683+
self.inner
684+
.rfold(init, move |acc, mut ptr| g(acc, ptr.as_mut()))
685+
}
676686
}
677687
}
678688

@@ -748,7 +758,7 @@ where D: Dimension
748758
{
749759
self.iter
750760
.next()
751-
.map(|ptr| unsafe { ArrayView::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
761+
.map(|ptr| unsafe { ArrayView::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
752762
}
753763

754764
fn size_hint(&self) -> (usize, Option<usize>)
@@ -772,7 +782,7 @@ impl<'a, A> DoubleEndedIterator for LanesIter<'a, A, Ix1>
772782
{
773783
self.iter
774784
.next_back()
775-
.map(|ptr| unsafe { ArrayView::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
785+
.map(|ptr| unsafe { ArrayView::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
776786
}
777787
}
778788

@@ -800,7 +810,7 @@ where D: Dimension
800810
{
801811
self.iter
802812
.next()
803-
.map(|ptr| unsafe { ArrayViewMut::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
813+
.map(|ptr| unsafe { ArrayViewMut::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
804814
}
805815

806816
fn size_hint(&self) -> (usize, Option<usize>)
@@ -824,7 +834,7 @@ impl<'a, A> DoubleEndedIterator for LanesIterMut<'a, A, Ix1>
824834
{
825835
self.iter
826836
.next_back()
827-
.map(|ptr| unsafe { ArrayViewMut::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
837+
.map(|ptr| unsafe { ArrayViewMut::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) })
828838
}
829839
}
830840

src/iterators/windows.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ impl_iterator! {
115115

116116
fn item(&mut self, ptr) {
117117
unsafe {
118-
ArrayView::new_(
118+
ArrayView::new(
119119
ptr,
120120
self.window.clone(),
121121
self.strides.clone())

0 commit comments

Comments
 (0)