Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 6c8b821

Browse files
authoredDec 6, 2021
Merge pull request #980 from YuhanLiin/const-generics
Const generics Improvements
2 parents d1bb045 + 75a27e5 commit 6c8b821

File tree

6 files changed

+162
-169
lines changed

6 files changed

+162
-169
lines changed
 

‎.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
- stable
2222
- beta
2323
- nightly
24-
- 1.49.0 # MSRV
24+
- 1.51.0 # MSRV
2525

2626
steps:
2727
- uses: actions/checkout@v2

‎src/arraytraits.rs

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,20 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9-
use std::hash;
10-
use std::iter::FromIterator;
9+
use alloc::boxed::Box;
10+
use alloc::vec::Vec;
1111
use std::iter::IntoIterator;
1212
use std::mem;
1313
use std::ops::{Index, IndexMut};
14-
use alloc::boxed::Box;
15-
use alloc::vec::Vec;
14+
use std::{hash, mem::size_of};
15+
use std::{iter::FromIterator, slice};
1616

1717
use crate::imp_prelude::*;
18-
use crate::iter::{Iter, IterMut};
19-
use crate::NdIndex;
20-
21-
use crate::numeric_util;
22-
use crate::{FoldWhile, Zip};
18+
use crate::{
19+
dimension,
20+
iter::{Iter, IterMut},
21+
numeric_util, FoldWhile, NdIndex, Zip,
22+
};
2323

2424
#[cold]
2525
#[inline(never)]
@@ -323,6 +323,30 @@ where
323323
}
324324
}
325325

326+
/// Implementation of ArrayView2::from(&S) where S is a slice to a 2D array
327+
///
328+
/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A
329+
/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes).
330+
impl<'a, A, const N: usize> From<&'a [[A; N]]> for ArrayView<'a, A, Ix2> {
331+
/// Create a two-dimensional read-only array view of the data in `slice`
332+
fn from(xs: &'a [[A; N]]) -> Self {
333+
let cols = N;
334+
let rows = xs.len();
335+
let dim = Ix2(rows, cols);
336+
if size_of::<A>() == 0 {
337+
dimension::size_of_shape_checked(&dim)
338+
.expect("Product of non-zero axis lengths must not overflow isize.");
339+
}
340+
341+
// `cols * rows` is guaranteed to fit in `isize` because we checked that it fits in
342+
// `isize::MAX`
343+
unsafe {
344+
let data = slice::from_raw_parts(xs.as_ptr() as *const A, cols * rows);
345+
ArrayView::from_shape_ptr(dim, data.as_ptr())
346+
}
347+
}
348+
}
349+
326350
/// Implementation of `ArrayView::from(&A)` where `A` is an array.
327351
impl<'a, A, S, D> From<&'a ArrayBase<S, D>> for ArrayView<'a, A, D>
328352
where
@@ -355,6 +379,30 @@ where
355379
}
356380
}
357381

382+
/// Implementation of ArrayViewMut2::from(&S) where S is a slice to a 2D array
383+
///
384+
/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A
385+
/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes).
386+
impl<'a, A, const N: usize> From<&'a mut [[A; N]]> for ArrayViewMut<'a, A, Ix2> {
387+
/// Create a two-dimensional read-write array view of the data in `slice`
388+
fn from(xs: &'a mut [[A; N]]) -> Self {
389+
let cols = N;
390+
let rows = xs.len();
391+
let dim = Ix2(rows, cols);
392+
if size_of::<A>() == 0 {
393+
dimension::size_of_shape_checked(&dim)
394+
.expect("Product of non-zero axis lengths must not overflow isize.");
395+
}
396+
397+
// `cols * rows` is guaranteed to fit in `isize` because we checked that it fits in
398+
// `isize::MAX`
399+
unsafe {
400+
let data = slice::from_raw_parts_mut(xs.as_mut_ptr() as *mut A, cols * rows);
401+
ArrayViewMut::from_shape_ptr(dim, data.as_mut_ptr())
402+
}
403+
}
404+
}
405+
358406
/// Implementation of `ArrayViewMut::from(&mut A)` where `A` is an array.
359407
impl<'a, A, S, D> From<&'a mut ArrayBase<S, D>> for ArrayViewMut<'a, A, D>
360408
where

‎src/dimension/ndindex.rs

Lines changed: 58 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -140,50 +140,6 @@ macro_rules! ndindex_with_array {
140140
0
141141
}
142142
}
143-
144-
// implement NdIndex<IxDyn> for Dim<[Ix; 2]> and so on
145-
unsafe impl NdIndex<IxDyn> for Dim<[Ix; $n]> {
146-
#[inline]
147-
fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option<isize> {
148-
debug_assert_eq!(strides.ndim(), $n,
149-
"Attempted to index with {:?} in array with {} axes",
150-
self, strides.ndim());
151-
stride_offset_checked(dim.ix(), strides.ix(), self.ix())
152-
}
153-
154-
#[inline]
155-
fn index_unchecked(&self, strides: &IxDyn) -> isize {
156-
debug_assert_eq!(strides.ndim(), $n,
157-
"Attempted to index with {:?} in array with {} axes",
158-
self, strides.ndim());
159-
$(
160-
stride_offset(get!(self, $index), get!(strides, $index)) +
161-
)*
162-
0
163-
}
164-
}
165-
166-
// implement NdIndex<IxDyn> for [Ix; 2] and so on
167-
unsafe impl NdIndex<IxDyn> for [Ix; $n] {
168-
#[inline]
169-
fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option<isize> {
170-
debug_assert_eq!(strides.ndim(), $n,
171-
"Attempted to index with {:?} in array with {} axes",
172-
self, strides.ndim());
173-
stride_offset_checked(dim.ix(), strides.ix(), self)
174-
}
175-
176-
#[inline]
177-
fn index_unchecked(&self, strides: &IxDyn) -> isize {
178-
debug_assert_eq!(strides.ndim(), $n,
179-
"Attempted to index with {:?} in array with {} axes",
180-
self, strides.ndim());
181-
$(
182-
stride_offset(self[$index], get!(strides, $index)) +
183-
)*
184-
0
185-
}
186-
}
187143
)+
188144
};
189145
}
@@ -198,6 +154,64 @@ ndindex_with_array! {
198154
[6, Ix6 0 1 2 3 4 5]
199155
}
200156

157+
// implement NdIndex<IxDyn> for Dim<[Ix; 2]> and so on
158+
unsafe impl<const N: usize> NdIndex<IxDyn> for Dim<[Ix; N]> {
159+
#[inline]
160+
fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option<isize> {
161+
debug_assert_eq!(
162+
strides.ndim(),
163+
N,
164+
"Attempted to index with {:?} in array with {} axes",
165+
self,
166+
strides.ndim()
167+
);
168+
stride_offset_checked(dim.ix(), strides.ix(), self.ix())
169+
}
170+
171+
#[inline]
172+
fn index_unchecked(&self, strides: &IxDyn) -> isize {
173+
debug_assert_eq!(
174+
strides.ndim(),
175+
N,
176+
"Attempted to index with {:?} in array with {} axes",
177+
self,
178+
strides.ndim()
179+
);
180+
(0..N)
181+
.map(|i| stride_offset(get!(self, i), get!(strides, i)))
182+
.sum()
183+
}
184+
}
185+
186+
// implement NdIndex<IxDyn> for [Ix; 2] and so on
187+
unsafe impl<const N: usize> NdIndex<IxDyn> for [Ix; N] {
188+
#[inline]
189+
fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option<isize> {
190+
debug_assert_eq!(
191+
strides.ndim(),
192+
N,
193+
"Attempted to index with {:?} in array with {} axes",
194+
self,
195+
strides.ndim()
196+
);
197+
stride_offset_checked(dim.ix(), strides.ix(), self)
198+
}
199+
200+
#[inline]
201+
fn index_unchecked(&self, strides: &IxDyn) -> isize {
202+
debug_assert_eq!(
203+
strides.ndim(),
204+
N,
205+
"Attempted to index with {:?} in array with {} axes",
206+
self,
207+
strides.ndim()
208+
);
209+
(0..N)
210+
.map(|i| stride_offset(self[i], get!(strides, i)))
211+
.sum()
212+
}
213+
}
214+
201215
impl<'a> IntoDimension for &'a [Ix] {
202216
type Dim = IxDyn;
203217
fn into_dimension(self) -> Self::Dim {

‎src/free_functions.rs

Lines changed: 24 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9-
use std::mem::{forget, size_of};
10-
use alloc::slice;
119
use alloc::vec;
1210
use alloc::vec::Vec;
11+
use std::mem::{forget, size_of};
1312

1413
use crate::imp_prelude::*;
1514
use crate::{dimension, ArcArray1, ArcArray2};
@@ -87,26 +86,10 @@ pub fn aview1<A>(xs: &[A]) -> ArrayView1<'_, A> {
8786

8887
/// Create a two-dimensional array view with elements borrowing `xs`.
8988
///
90-
/// **Panics** if the product of non-zero axis lengths overflows `isize`. (This
91-
/// can only occur when `V` is zero-sized.)
92-
pub fn aview2<A, V: FixedInitializer<Elem = A>>(xs: &[V]) -> ArrayView2<'_, A> {
93-
let cols = V::len();
94-
let rows = xs.len();
95-
let dim = Ix2(rows, cols);
96-
if size_of::<V>() == 0 {
97-
dimension::size_of_shape_checked(&dim)
98-
.expect("Product of non-zero axis lengths must not overflow isize.");
99-
}
100-
// `rows` is guaranteed to fit in `isize` because we've checked the ZST
101-
// case and slices never contain > `isize::MAX` bytes. `cols` is guaranteed
102-
// to fit in `isize` because `FixedInitializer` is not implemented for any
103-
// array lengths > `isize::MAX`. `cols * rows` is guaranteed to fit in
104-
// `isize` because we've checked the ZST case and slices never contain >
105-
// `isize::MAX` bytes.
106-
unsafe {
107-
let data = slice::from_raw_parts(xs.as_ptr() as *const A, cols * rows);
108-
ArrayView::from_shape_ptr(dim, data.as_ptr())
109-
}
89+
/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A
90+
/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes).
91+
pub fn aview2<A, const N: usize>(xs: &[[A; N]]) -> ArrayView2<'_, A> {
92+
ArrayView2::from(xs)
11093
}
11194

11295
/// Create a one-dimensional read-write array view with elements borrowing `xs`.
@@ -127,16 +110,15 @@ pub fn aview_mut1<A>(xs: &mut [A]) -> ArrayViewMut1<'_, A> {
127110

128111
/// Create a two-dimensional read-write array view with elements borrowing `xs`.
129112
///
130-
/// **Panics** if the product of non-zero axis lengths overflows `isize`. (This
131-
/// can only occur when `V` is zero-sized.)
113+
/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A
114+
/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes).
132115
///
133116
/// # Example
134117
///
135118
/// ```
136119
/// use ndarray::aview_mut2;
137120
///
138-
/// // The inner (nested) array must be of length 1 to 16, but the outer
139-
/// // can be of any length.
121+
/// // The inner (nested) and outer arrays can be of any length.
140122
/// let mut data = [[0.; 2]; 128];
141123
/// {
142124
/// // Make a 128 x 2 mut array view then turn it into 2 x 128
@@ -148,57 +130,10 @@ pub fn aview_mut1<A>(xs: &mut [A]) -> ArrayViewMut1<'_, A> {
148130
/// // look at the start of the result
149131
/// assert_eq!(&data[..3], [[1., -1.], [1., -1.], [1., -1.]]);
150132
/// ```
151-
pub fn aview_mut2<A, V: FixedInitializer<Elem = A>>(xs: &mut [V]) -> ArrayViewMut2<'_, A> {
152-
let cols = V::len();
153-
let rows = xs.len();
154-
let dim = Ix2(rows, cols);
155-
if size_of::<V>() == 0 {
156-
dimension::size_of_shape_checked(&dim)
157-
.expect("Product of non-zero axis lengths must not overflow isize.");
158-
}
159-
// `rows` is guaranteed to fit in `isize` because we've checked the ZST
160-
// case and slices never contain > `isize::MAX` bytes. `cols` is guaranteed
161-
// to fit in `isize` because `FixedInitializer` is not implemented for any
162-
// array lengths > `isize::MAX`. `cols * rows` is guaranteed to fit in
163-
// `isize` because we've checked the ZST case and slices never contain >
164-
// `isize::MAX` bytes.
165-
unsafe {
166-
let data = slice::from_raw_parts_mut(xs.as_mut_ptr() as *mut A, cols * rows);
167-
ArrayViewMut::from_shape_ptr(dim, data.as_mut_ptr())
168-
}
133+
pub fn aview_mut2<A, const N: usize>(xs: &mut [[A; N]]) -> ArrayViewMut2<'_, A> {
134+
ArrayViewMut2::from(xs)
169135
}
170136

171-
/// Fixed-size array used for array initialization
172-
#[allow(clippy::missing_safety_doc)] // Should not be implemented downstream and to be deprecated.
173-
pub unsafe trait FixedInitializer {
174-
type Elem;
175-
fn as_init_slice(&self) -> &[Self::Elem];
176-
fn len() -> usize;
177-
}
178-
179-
macro_rules! impl_arr_init {
180-
(__impl $n: expr) => (
181-
unsafe impl<T> FixedInitializer for [T; $n] {
182-
type Elem = T;
183-
fn as_init_slice(&self) -> &[T] { self }
184-
fn len() -> usize { $n }
185-
}
186-
);
187-
() => ();
188-
($n: expr, $($m:expr,)*) => (
189-
impl_arr_init!(__impl $n);
190-
impl_arr_init!($($m,)*);
191-
)
192-
193-
}
194-
195-
// For implementors: If you ever implement `FixedInitializer` for array lengths
196-
// > `isize::MAX` (e.g. once Rust adds const generics), you must update
197-
// `aview2` and `aview_mut2` to perform the necessary checks. In particular,
198-
// the assumption that `cols` can never exceed `isize::MAX` would be incorrect.
199-
// (Consider e.g. `let xs: &[[i32; ::std::usize::MAX]] = &[]`.)
200-
impl_arr_init!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,);
201-
202137
/// Create a two-dimensional array with elements from `xs`.
203138
///
204139
/// ```
@@ -210,22 +145,16 @@ impl_arr_init!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,);
210145
/// a.shape() == [2, 3]
211146
/// );
212147
/// ```
213-
pub fn arr2<A: Clone, V: FixedInitializer<Elem = A>>(xs: &[V]) -> Array2<A>
214-
where
215-
V: Clone,
216-
{
148+
pub fn arr2<A: Clone, const N: usize>(xs: &[[A; N]]) -> Array2<A> {
217149
Array2::from(xs.to_vec())
218150
}
219151

220-
impl<A, V> From<Vec<V>> for Array2<A>
221-
where
222-
V: FixedInitializer<Elem = A>,
223-
{
152+
impl<A, const N: usize> From<Vec<[A; N]>> for Array2<A> {
224153
/// Converts the `Vec` of arrays to an owned 2-D array.
225154
///
226155
/// **Panics** if the product of non-zero axis lengths overflows `isize`.
227-
fn from(mut xs: Vec<V>) -> Self {
228-
let dim = Ix2(xs.len(), V::len());
156+
fn from(mut xs: Vec<[A; N]>) -> Self {
157+
let dim = Ix2(xs.len(), N);
229158
let ptr = xs.as_mut_ptr();
230159
let cap = xs.capacity();
231160
let expand_len = dimension::size_of_shape_checked(&dim)
@@ -234,29 +163,25 @@ where
234163
unsafe {
235164
let v = if size_of::<A>() == 0 {
236165
Vec::from_raw_parts(ptr as *mut A, expand_len, expand_len)
237-
} else if V::len() == 0 {
166+
} else if N == 0 {
238167
Vec::new()
239168
} else {
240169
// Guaranteed not to overflow in this case since A is non-ZST
241170
// and Vec never allocates more than isize bytes.
242-
let expand_cap = cap * V::len();
171+
let expand_cap = cap * N;
243172
Vec::from_raw_parts(ptr as *mut A, expand_len, expand_cap)
244173
};
245174
ArrayBase::from_shape_vec_unchecked(dim, v)
246175
}
247176
}
248177
}
249178

250-
impl<A, V, U> From<Vec<V>> for Array3<A>
251-
where
252-
V: FixedInitializer<Elem = U>,
253-
U: FixedInitializer<Elem = A>,
254-
{
179+
impl<A, const N: usize, const M: usize> From<Vec<[[A; M]; N]>> for Array3<A> {
255180
/// Converts the `Vec` of arrays to an owned 3-D array.
256181
///
257182
/// **Panics** if the product of non-zero axis lengths overflows `isize`.
258-
fn from(mut xs: Vec<V>) -> Self {
259-
let dim = Ix3(xs.len(), V::len(), U::len());
183+
fn from(mut xs: Vec<[[A; M]; N]>) -> Self {
184+
let dim = Ix3(xs.len(), N, M);
260185
let ptr = xs.as_mut_ptr();
261186
let cap = xs.capacity();
262187
let expand_len = dimension::size_of_shape_checked(&dim)
@@ -265,12 +190,12 @@ where
265190
unsafe {
266191
let v = if size_of::<A>() == 0 {
267192
Vec::from_raw_parts(ptr as *mut A, expand_len, expand_len)
268-
} else if V::len() == 0 || U::len() == 0 {
193+
} else if N == 0 || M == 0 {
269194
Vec::new()
270195
} else {
271196
// Guaranteed not to overflow in this case since A is non-ZST
272197
// and Vec never allocates more than isize bytes.
273-
let expand_cap = cap * V::len() * U::len();
198+
let expand_cap = cap * N * M;
274199
Vec::from_raw_parts(ptr as *mut A, expand_len, expand_cap)
275200
};
276201
ArrayBase::from_shape_vec_unchecked(dim, v)
@@ -280,7 +205,7 @@ where
280205

281206
/// Create a two-dimensional array with elements from `xs`.
282207
///
283-
pub fn rcarr2<A: Clone, V: Clone + FixedInitializer<Elem = A>>(xs: &[V]) -> ArcArray2<A> {
208+
pub fn rcarr2<A: Clone, const N: usize>(xs: &[[A; N]]) -> ArcArray2<A> {
284209
arr2(xs).into_shared()
285210
}
286211

@@ -301,23 +226,11 @@ pub fn rcarr2<A: Clone, V: Clone + FixedInitializer<Elem = A>>(xs: &[V]) -> ArcA
301226
/// a.shape() == [3, 2, 2]
302227
/// );
303228
/// ```
304-
pub fn arr3<A: Clone, V: FixedInitializer<Elem = U>, U: FixedInitializer<Elem = A>>(
305-
xs: &[V],
306-
) -> Array3<A>
307-
where
308-
V: Clone,
309-
U: Clone,
310-
{
229+
pub fn arr3<A: Clone, const N: usize, const M: usize>(xs: &[[[A; M]; N]]) -> Array3<A> {
311230
Array3::from(xs.to_vec())
312231
}
313232

314233
/// Create a three-dimensional array with elements from `xs`.
315-
pub fn rcarr3<A: Clone, V: FixedInitializer<Elem = U>, U: FixedInitializer<Elem = A>>(
316-
xs: &[V],
317-
) -> ArcArray<A, Ix3>
318-
where
319-
V: Clone,
320-
U: Clone,
321-
{
234+
pub fn rcarr3<A: Clone, const N: usize, const M: usize>(xs: &[[[A; M]; N]]) -> ArcArray<A, Ix3> {
322235
arr3(xs).into_shared()
323236
}

‎src/zip/ndproducer.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
use crate::imp_prelude::*;
32
use crate::Layout;
43
use crate::NdIndex;
@@ -168,6 +167,26 @@ impl<'a, A: 'a> IntoNdProducer for &'a mut [A] {
168167
}
169168
}
170169

170+
/// A one-dimensional array is a one-dimensional producer
171+
impl<'a, A: 'a, const N: usize> IntoNdProducer for &'a [A; N] {
172+
type Item = <Self::Output as NdProducer>::Item;
173+
type Dim = Ix1;
174+
type Output = ArrayView1<'a, A>;
175+
fn into_producer(self) -> Self::Output {
176+
<_>::from(self)
177+
}
178+
}
179+
180+
/// A mutable one-dimensional array is a mutable one-dimensional producer
181+
impl<'a, A: 'a, const N: usize> IntoNdProducer for &'a mut [A; N] {
182+
type Item = <Self::Output as NdProducer>::Item;
183+
type Dim = Ix1;
184+
type Output = ArrayViewMut1<'a, A>;
185+
fn into_producer(self) -> Self::Output {
186+
<_>::from(self)
187+
}
188+
}
189+
171190
/// A Vec is a one-dimensional producer
172191
impl<'a, A: 'a> IntoNdProducer for &'a Vec<A> {
173192
type Item = <Self::Output as NdProducer>::Item;
@@ -399,4 +418,3 @@ impl<A, D: Dimension> NdProducer for RawArrayViewMut<A, D> {
399418
self.split_at(axis, index)
400419
}
401420
}
402-

‎tests/array.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ fn diag() {
731731
let a = arr2(&[[1., 2., 3.0f32], [0., 0., 0.]]);
732732
let d = a.view().into_diag();
733733
assert_eq!(d.dim(), 2);
734-
let d = arr2::<f32, _>(&[[]]).into_diag();
734+
let d = arr2::<f32, 0>(&[[]]).into_diag();
735735
assert_eq!(d.dim(), 0);
736736
let d = ArcArray::<f32, _>::zeros(()).into_diag();
737737
assert_eq!(d.dim(), 1);
@@ -960,7 +960,7 @@ fn zero_axes() {
960960
a.map_inplace(|_| panic!());
961961
a.for_each(|_| panic!());
962962
println!("{:?}", a);
963-
let b = arr2::<f32, _>(&[[], [], [], []]);
963+
let b = arr2::<f32, 0>(&[[], [], [], []]);
964964
println!("{:?}\n{:?}", b.shape(), b);
965965

966966
// we can even get a subarray of b

0 commit comments

Comments
 (0)
Please sign in to comment.