Skip to content

Commit 75b5f93

Browse files
committed
Make iterators covariant in element type
The internal Baseiter type underlies most of the ndarray iterators, and it used `*mut A` for element type A. Update it to use `NonNull<A>` which behaves identically except it's guaranteed to be non-null and is covariant w.r.t the parameter A. Add compile test from the issue. Fixes #1290
1 parent 84fe611 commit 75b5f93

File tree

5 files changed

+47
-18
lines changed

5 files changed

+47
-18
lines changed

src/impl_owned_array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,7 @@ where D: Dimension
907907

908908
// iter is a raw pointer iterator traversing the array in memory order now with the
909909
// sorted axes.
910-
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
910+
let mut iter = Baseiter::new(self_.ptr, self_.dim, self_.strides);
911911
let mut dropped_elements = 0;
912912

913913
let mut last_ptr = data_ptr;

src/impl_views/conversions.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ where D: Dimension
199199
#[inline]
200200
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
201201
{
202-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
202+
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
203203
}
204204
}
205205

@@ -209,7 +209,7 @@ where D: Dimension
209209
#[inline]
210210
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
211211
{
212-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
212+
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
213213
}
214214
}
215215

@@ -220,7 +220,7 @@ where D: Dimension
220220
#[inline]
221221
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
222222
{
223-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
223+
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
224224
}
225225

226226
#[inline]
@@ -262,7 +262,7 @@ where D: Dimension
262262
#[inline]
263263
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
264264
{
265-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
265+
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
266266
}
267267

268268
#[inline]

src/iterators/into_iter.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,15 @@ impl<A, D> IntoIter<A, D>
3333
where D: Dimension
3434
{
3535
/// Create a new by-value iterator that consumes `array`
36-
pub(crate) fn new(mut array: Array<A, D>) -> Self
36+
pub(crate) fn new(array: Array<A, D>) -> Self
3737
{
3838
unsafe {
3939
let array_head_ptr = array.ptr;
40-
let ptr = array.as_mut_ptr();
4140
let mut array_data = array.data;
4241
let data_len = array_data.release_all_elements();
4342
debug_assert!(data_len >= array.dim.size());
4443
let has_unreachable_elements = array.dim.size() != data_len;
45-
let inner = Baseiter::new(ptr, array.dim, array.strides);
44+
let inner = Baseiter::new(array_head_ptr, array.dim, array.strides);
4645

4746
IntoIter {
4847
array_data,

src/iterators/mod.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use alloc::vec::Vec;
1919
use std::iter::FromIterator;
2020
use std::marker::PhantomData;
2121
use std::ptr;
22+
use std::ptr::NonNull;
2223

2324
use crate::Ix1;
2425

@@ -38,7 +39,7 @@ use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
3839
#[derive(Debug)]
3940
pub struct Baseiter<A, D>
4041
{
41-
ptr: *mut A,
42+
ptr: NonNull<A>,
4243
dim: D,
4344
strides: D,
4445
index: Option<D>,
@@ -50,7 +51,7 @@ impl<A, D: Dimension> Baseiter<A, D>
5051
/// to be correct to avoid performing an unsafe pointer offset while
5152
/// iterating.
5253
#[inline]
53-
pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter<A, D>
54+
pub unsafe fn new(ptr: NonNull<A>, len: D, stride: D) -> Baseiter<A, D>
5455
{
5556
Baseiter {
5657
ptr,
@@ -74,7 +75,7 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D>
7475
};
7576
let offset = D::stride_offset(&index, &self.strides);
7677
self.index = self.dim.next_for(index);
77-
unsafe { Some(self.ptr.offset(offset)) }
78+
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
7879
}
7980

8081
fn size_hint(&self) -> (usize, Option<usize>)
@@ -99,7 +100,7 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D>
99100
let mut i = 0;
100101
let i_end = len - elem_index;
101102
while i < i_end {
102-
accum = g(accum, row_ptr.offset(i as isize * stride));
103+
accum = g(accum, row_ptr.offset(i as isize * stride).as_ptr());
103104
i += 1;
104105
}
105106
}
@@ -145,7 +146,7 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
145146
self.index = None;
146147
}
147148

148-
unsafe { Some(self.ptr.offset(offset)) }
149+
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
149150
}
150151

151152
fn nth_back(&mut self, n: usize) -> Option<*mut A>
@@ -158,7 +159,7 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
158159
if index == self.dim {
159160
self.index = None;
160161
}
161-
unsafe { Some(self.ptr.offset(offset)) }
162+
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
162163
} else {
163164
self.index = None;
164165
None
@@ -178,7 +179,8 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
178179
accum = g(
179180
accum,
180181
self.ptr
181-
.offset(Ix1::stride_offset(&self.dim, &self.strides)),
182+
.offset(Ix1::stride_offset(&self.dim, &self.strides))
183+
.as_ptr(),
182184
);
183185
}
184186
}

tests/iterators.rs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
#![allow(
2-
clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names
3-
)]
1+
#![allow(clippy::deref_addrof, clippy::unreadable_literal)]
42

53
use ndarray::prelude::*;
64
use ndarray::{arr3, indices, s, Slice, Zip};
@@ -1055,3 +1053,33 @@ impl Drop for DropCount<'_>
10551053
self.drops.set(self.drops.get() + 1);
10561054
}
10571055
}
1056+
1057+
#[test]
1058+
fn test_impl_iter_compiles()
1059+
{
1060+
// Requires that the iterators are covariant in the element type
1061+
1062+
// base case: std
1063+
fn slice_iter_non_empty_indices<'s, 'a>(array: &'a Vec<&'s str>) -> impl Iterator<Item = usize> + 'a
1064+
{
1065+
array
1066+
.iter()
1067+
.enumerate()
1068+
.filter(|(_index, elem)| !elem.is_empty())
1069+
.map(|(index, _elem)| index)
1070+
}
1071+
1072+
let _ = slice_iter_non_empty_indices;
1073+
1074+
// ndarray case
1075+
fn array_iter_non_empty_indices<'s, 'a>(array: &'a Array<&'s str, Ix1>) -> impl Iterator<Item = usize> + 'a
1076+
{
1077+
array
1078+
.iter()
1079+
.enumerate()
1080+
.filter(|(_index, elem)| !elem.is_empty())
1081+
.map(|(index, _elem)| index)
1082+
}
1083+
1084+
let _ = array_iter_non_empty_indices;
1085+
}

0 commit comments

Comments
 (0)