Skip to content

Commit f33e147

Browse files
authored
ShapedIndex (#199)
* Use `CartesianIndices` for conversion of `Int` -> CartesianIndex. * The `CartesianIndex` -> `Int` conversion is managed by composing a `StrideIndex`, where the strides are computed using `size_to_strides`, instead of the internal memory representation. * Improve known_size: Previously an array that needed a unique method for `known_size` also needed a unique one for `known_size(::Type{A}, dim)`. Now `known_size` is called and then indexed, requring only one new method.
1 parent 669658f commit f33e147

File tree

6 files changed

+95
-104
lines changed

6 files changed

+95
-104
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "3.1.30"
3+
version = "3.1.31"
44

55
[deps]
66
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"

src/array_index.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ end
207207
Base.firstindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = 1
208208
Base.lastindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count
209209
Base.length(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count
210+
211+
## getindex
210212
@propagate_inbounds Base.getindex(x::ArrayIndex, i::CanonicalInt, ii::CanonicalInt...) = x[NDIndex(i, ii...)]
211213
@propagate_inbounds function Base.getindex(ind::BidiagonalIndex, i::Int)
212214
@boundscheck 1 <= i <= ind.count || throw(BoundsError(ind, i))
@@ -274,11 +276,11 @@ end
274276
ind.reflocalinds[p][_i] + ind.refcoords[p] - 1
275277
end
276278

277-
@inline function Base.getindex(x::StrideIndex{N}, i::AbstractCartesianIndex{N}) where {N}
278-
return _strides2int(offsets(x), strides(x), Tuple(i)) + offset1(x)
279+
@inline function Base.getindex(x::StrideIndex{N}, i::AbstractCartesianIndex) where {N}
280+
return _strides2int(offsets(x), strides(x), Tuple(i)) + static(1)
279281
end
280282
@generated function _strides2int(o::O, s::S, i::I) where {O,S,I}
281-
N = known_length(I)
283+
N = known_length(S)
282284
out = :()
283285
for i in 1:N
284286
tmp = :(((getfield(i, $i) - getfield(o, $i)) * getfield(s, $i)))

src/indexing.jl

Lines changed: 79 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -141,22 +141,6 @@ to_index(::IndexLinear, axis, arg::CartesianIndices{1}) = axes(arg, 1)
141141
@propagate_inbounds function to_index(::IndexLinear, axis, arg::AbstractCartesianIndex{1})
142142
return to_index(axis, first(Tuple(arg)))
143143
end
144-
function to_index(::IndexLinear, x, arg::AbstractCartesianIndex{N}) where {N}
145-
inds = Tuple(arg)
146-
o = offsets(x)
147-
s = size(x)
148-
return first(inds) + (static(1) - first(o)) + _subs2int(first(s), tail(s), tail(o), tail(inds))
149-
end
150-
@inline function _subs2int(stride, s::Tuple{Any,Vararg}, o::Tuple{Any,Vararg}, inds::Tuple{Any,Vararg})
151-
i = ((first(inds) - first(o)) * stride)
152-
return i + _subs2int(stride * first(s), tail(s), tail(o), tail(inds))
153-
end
154-
function _subs2int(stride, s::Tuple{Any}, o::Tuple{Any}, inds::Tuple{Any})
155-
return (first(inds) - first(o)) * stride
156-
end
157-
# trailing inbounds can only be 1 or 1:1
158-
_subs2int(stride, ::Tuple{}, ::Tuple{}, ::Tuple{Any}) = static(0)
159-
160144
@propagate_inbounds function to_index(::IndexLinear, x, arg::Union{Array{Bool}, BitArray})
161145
@boundscheck checkbounds(x, arg)
162146
return LogicalIndex{Int}(arg)
@@ -215,13 +199,18 @@ end
215199
@boundscheck checkbounds(x, arg)
216200
return LogicalIndex{Int}(arg)
217201
end
218-
to_index(::IndexCartesian, x, i::Integer) = NDIndex(_int2subs(offsets(x), size(x), i - static(1)))
219-
@inline function _int2subs(o::Tuple{Any,Vararg{Any}}, s::Tuple{Any,Vararg{Any}}, i)
220-
len = first(s)
221-
inext = div(i, len)
222-
return (canonicalize(i - len * inext + first(o)), _int2subs(tail(o), tail(s), inext)...)
202+
203+
# TODO delete this once the layout interface is working
204+
_array_index(::IndexLinear, a, i::CanonicalInt) = i
205+
@inline function _array_index(::IndexStyle, a, i::CanonicalInt)
206+
CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a))))[i]
223207
end
224-
_int2subs(o::Tuple{Any}, s::Tuple{Any}, i) = canonicalize(i + first(o))
208+
_array_index(::IndexLinear, a, i::AbstractCartesianIndex{1}) = getfield(Tuple(i), 1)
209+
@inline function _array_index(::IndexLinear, a, i::AbstractCartesianIndex)
210+
N = ndims(a)
211+
StrideIndex{N,ntuple(+, Val(N)),nothing}(size_to_strides(size(a), static(1)), offsets(a))[i]
212+
end
213+
_array_index(::IndexStyle, a, i::AbstractCartesianIndex) = i
225214

226215
"""
227216
unsafe_reconstruct(A, data; kwargs...)
@@ -326,54 +315,50 @@ another instance of `ArrayInterface.getindex` should only be done by overloading
326315
Changing indexing based on a given argument from `args` should be done through,
327316
[`to_index`](@ref), or [`to_axis`](@ref).
328317
"""
329-
@propagate_inbounds getindex(A, args...) = unsafe_get_index(A, to_indices(A, args))
318+
@propagate_inbounds getindex(A, args...) = unsafe_getindex(A, to_indices(A, args)...)
330319
@propagate_inbounds function getindex(A; kwargs...)
331-
return unsafe_get_index(A, to_indices(A, order_named_inds(dimnames(A), values(kwargs))))
320+
return unsafe_getindex(A, to_indices(A, order_named_inds(dimnames(A), values(kwargs)))...)
332321
end
333322
@propagate_inbounds getindex(x::Tuple, i::Int) = getfield(x, i)
334323
@propagate_inbounds getindex(x::Tuple, ::StaticInt{i}) where {i} = getfield(x, i)
335324

336-
## unsafe_get_index ##
337-
unsafe_get_index(A, i::Tuple{}) = unsafe_get_element(A, ())
338-
unsafe_get_index(A, i::Tuple{CanonicalInt}) = unsafe_get_element(A, getfield(i, 1))
339-
function unsafe_get_index(A, i::Tuple{CanonicalInt,Vararg{CanonicalInt}})
340-
unsafe_get_element(A, NDIndex(i))
325+
## unsafe_getindex ##
326+
function unsafe_getindex(a::A) where {A}
327+
parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A,)))
328+
return unsafe_getindex(parent(a))
341329
end
342-
unsafe_get_index(A, i::Tuple) = unsafe_get_collection(A, i)
343-
344-
#=
345-
unsafe_get_element(A::AbstractArray{T}, inds::Tuple) -> T
346-
347-
Returns an element of `A` at the indices `inds`. This method assumes all `inds`
348-
have been checked for being in bounds. Any new array type using `ArrayInterface.getindex`
349-
must define `unsafe_get_element(::NewArrayType, inds)`.
350-
=#
351-
unsafe_get_element(a::A, inds) where {A} = _unsafe_get_element(has_parent(A), a, inds)
352-
_unsafe_get_element(::True, a, inds) = unsafe_get_element(parent(a), inds)
353-
_unsafe_get_element(::False, a, inds) = @inbounds(parent(a)[inds])
354-
_unsafe_get_element(::False, a::AbstractArray2, i) = unsafe_get_element_error(a, i)
355-
356-
## Array ##
357-
unsafe_get_element(A::Array, ::Tuple{}) = Base.arrayref(false, A, 1)
358-
unsafe_get_element(A::Array, i::Integer) = Base.arrayref(false, A, Int(i))
359-
unsafe_get_element(A::Array, i::NDIndex) = unsafe_get_element(A, to_index(A, i))
330+
function unsafe_getindex(a::A, i::CanonicalInt) where {A}
331+
idx = _array_index(IndexStyle(A), a, i)
332+
if idx === i
333+
parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A, i)))
334+
return unsafe_getindex(parent(a), i)
335+
else
336+
return unsafe_getindex(a, idx)
337+
end
338+
end
339+
function unsafe_getindex(a::A, i::AbstractCartesianIndex) where {A}
340+
idx = _array_index(IndexStyle(A), a, i)
341+
if idx === i
342+
parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A, i)))
343+
return unsafe_getindex(parent(a), i)
344+
else
345+
return unsafe_getindex(a, idx)
346+
end
347+
end
348+
function unsafe_getindex(a, i::CanonicalInt, ii::Vararg{CanonicalInt})
349+
unsafe_getindex(a, NDIndex(i, ii...))
350+
end
351+
unsafe_getindex(a, i::Vararg{Any}) = unsafe_get_collection(a, i)
360352

361-
## LinearIndices ##
362-
unsafe_get_element(A::LinearIndices, i::Integer) = Int(i)
363-
unsafe_get_element(A::LinearIndices, i::NDIndex) = unsafe_get_element(A, to_index(A, i))
353+
unsafe_getindex(A::Array) = Base.arrayref(false, A, 1)
354+
unsafe_getindex(A::Array, i::CanonicalInt) = Base.arrayref(false, A, Int(i))
364355

365-
unsafe_get_element(A::CartesianIndices, i::NDIndex) = CartesianIndex(i)
366-
unsafe_get_element(A::CartesianIndices, i::Integer) = unsafe_get_element(A, to_index(A, i))
356+
unsafe_getindex(A::LinearIndices, i::CanonicalInt) = Int(i)
367357

368-
unsafe_get_element(A::ReshapedArray, i::Integer) = unsafe_get_element(parent(A), i)
369-
function unsafe_get_element(A::ReshapedArray, i::NDIndex)
370-
return unsafe_get_element(parent(A), to_index(IndexLinear(), A, i))
371-
end
358+
unsafe_getindex(A::CartesianIndices, i::AbstractCartesianIndex) = CartesianIndex(i)
372359

373-
unsafe_get_element(A::SubArray, i) = @inbounds(A[i])
374-
function unsafe_get_element_error(@nospecialize(A), @nospecialize(i))
375-
throw(MethodError(unsafe_get_element, (A, i)))
376-
end
360+
unsafe_getindex(A::SubArray, i::CanonicalInt) = @inbounds(A[i])
361+
unsafe_getindex(A::SubArray, i::AbstractCartesianIndex) = @inbounds(A[i])
377362

378363
# This is based on Base._unsafe_getindex from https://github.com/JuliaLang/julia/blob/c5ede45829bf8eb09f2145bfd6f089459d77b2b1/base/multidimensional.jl#L755.
379364
#=
@@ -402,7 +387,7 @@ function _generate_unsafe_get_index!_body(N::Int)
402387
# the optimizer is not clever enough to split the union without it
403388
Dy === nothing && return dest
404389
(idx, state) = Dy
405-
dest[idx] = unsafe_get_element(src, NDIndex(Base.Cartesian.@ntuple($N, j)))
390+
dest[idx] = unsafe_getindex(src, NDIndex(Base.Cartesian.@ntuple($N, j)))
406391
Dy = iterate(D, state)
407392
end
408393
return dest
@@ -441,45 +426,49 @@ Store the given values at the given key or index within a collection.
441426
"""
442427
@propagate_inbounds function setindex!(A, val, args...)
443428
if can_setindex(A)
444-
return unsafe_set_index!(A, val, to_indices(A, args))
429+
return unsafe_setindex!(A, val, to_indices(A, args)...)
445430
else
446431
error("Instance of type $(typeof(A)) are not mutable and cannot change elements after construction.")
447432
end
448433
end
449434
@propagate_inbounds function setindex!(A, val; kwargs...)
450-
return unsafe_set_index!(A, val, to_indices(A, order_named_inds(dimnames(A), values(kwargs))))
435+
return unsafe_setindex!(A, val, to_indices(A, order_named_inds(dimnames(A), values(kwargs)))...)
451436
end
452437

453-
## unsafe_get_index ##
454-
unsafe_set_index!(A, v, i::Tuple{}) = unsafe_set_element!(A, v, ())
455-
unsafe_set_index!(A, v, i::Tuple{CanonicalInt}) = unsafe_set_element!(A, v, getfield(i, 1))
456-
function unsafe_set_index!(A, v, i::Tuple{CanonicalInt,Vararg{CanonicalInt}})
457-
unsafe_set_element!(A, v, NDIndex(i))
438+
## unsafe_setindex! ##
439+
function unsafe_setindex!(a::A, v) where {A}
440+
parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v)))
441+
return unsafe_setindex!(parent(a), v)
458442
end
459-
unsafe_set_index!(A, v, i::Tuple) = unsafe_set_collection!(A, v, i)
460-
461-
#=
462-
unsafe_set_element!(A, val, inds::Tuple)
463-
464-
Sets an element of `A` to `val` at indices `inds`. This method assumes all `inds`
465-
have been checked for being in bounds. Any new array type using `ArrayInterface.setindex!`
466-
must define `unsafe_set_element!(::NewArrayType, val, inds)`.
467-
=#
468-
unsafe_set_element!(a, val, inds) = _unsafe_set_element!(has_parent(a), a, val, inds)
469-
_unsafe_set_element!(::True, a, val, inds) = unsafe_set_element!(parent(a), val, inds)
470-
_unsafe_set_element!(::False, a, val, inds) = @inbounds(parent(a)[inds] = val)
471-
472-
function _unsafe_set_element!(::False, a::AbstractArray2, val, inds)
473-
unsafe_set_element_error(a, val, inds)
443+
function unsafe_setindex!(a::A, v, i::CanonicalInt) where {A}
444+
idx = _array_index(IndexStyle(A), a, i)
445+
if idx === i
446+
parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v, i)))
447+
return unsafe_setindex!(parent(a), v, i)
448+
else
449+
return unsafe_setindex!(a, v, idx)
450+
end
474451
end
475-
unsafe_set_element_error(A, v, i) = throw(MethodError(unsafe_set_element!, (A, v, i)))
476-
477-
function unsafe_set_element!(A::Array{T}, val, ::Tuple{}) where {T}
478-
Base.arrayset(false, A, convert(T, val)::T, 1)
452+
function unsafe_setindex!(a::A, v, i::AbstractCartesianIndex) where {A}
453+
idx = _array_index(IndexStyle(A), a, i)
454+
if idx === i
455+
parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v, i)))
456+
return unsafe_setindex!(parent(a), v, i)
457+
else
458+
return unsafe_setindex!(a, v, idx)
459+
end
460+
end
461+
function unsafe_setindex!(a, v, i::CanonicalInt, ii::Vararg{CanonicalInt})
462+
unsafe_setindex!(a, v, NDIndex(i, ii...))
479463
end
480-
function unsafe_set_element!(A::Array{T}, val, i::Integer) where {T}
481-
return Base.arrayset(false, A, convert(T, val)::T, Int(i))
464+
function unsafe_setindex!(A::Array{T}, v) where {T}
465+
Base.arrayset(false, A, convert(T, v)::T, 1)
482466
end
467+
function unsafe_setindex!(A::Array{T}, v, i::CanonicalInt) where {T}
468+
return Base.arrayset(false, A, convert(T, v)::T, Int(i))
469+
end
470+
471+
unsafe_setindex!(a, v, i::Vararg{Any}) = unsafe_set_collection!(a, v, i)
483472

484473
# This is based on Base._unsafe_setindex!.
485474
#=
@@ -501,7 +490,7 @@ function _generate_unsafe_setindex!_body(N::Int)
501490
# the optimizer that it does not need to emit error paths
502491
Xy === nothing && break
503492
(val, state) = Xy
504-
unsafe_set_element!(A, val, NDIndex(Base.Cartesian.@ntuple($N, i)))
493+
unsafe_setindex!(A, val, NDIndex(Base.Cartesian.@ntuple($N, i)))
505494
Xy = iterate(x′, state)
506495
end
507496
A

src/size.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ end
2525

2626
size(x::SubArray) = eachop(_sub_size, to_parent_dims(x), x.indices)
2727
_sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = static_length(getfield(x, dim))
28-
2928
@inline size(B::VecAdjTrans) = (One(), length(parent(B)))
3029
@inline size(B::MatAdjTrans) = permute(size(parent(B)), to_parent_dims(B))
3130
@inline function size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A}
@@ -80,15 +79,15 @@ compile time. If a dimension does not have a known size along a dimension then `
8079
returned in its position.
8180
"""
8281
known_size(x) = known_size(typeof(x))
83-
known_size(::Type{T}) where {T} = eachop(known_size, nstatic(Val(ndims(T))), T)
84-
82+
known_size(::Type{T}) where {T} = eachop(_known_size, nstatic(Val(ndims(T))), axes_types(T))
83+
_known_size(::Type{T}, dim::StaticInt) where {T} = known_length(_get_tuple(T, dim))
8584
@inline known_size(x, dim) = known_size(typeof(x), dim)
8685
@inline known_size(::Type{T}, dim) where {T} = known_size(T, to_dims(T, dim))
87-
@inline function known_size(::Type{T}, dim::Integer) where {T}
86+
@inline function known_size(::Type{T}, dim::CanonicalInt) where {T}
8887
if ndims(T) < dim
8988
return 1
9089
else
91-
return known_length(axes_types(T, dim))
90+
return known_size(T)[dim]
9291
end
9392
end
9493

src/stridelayout.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ Returns offsets of indices with respect to 0. If values are known at compile tim
5252
it should return them as `Static` numbers.
5353
For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1))`.
5454
"""
55-
offsets(x::StrideIndex) = getfield(x, :offsets)
5655
@inline offsets(x, i) = static_first(indices(x, i))
5756
offsets(::Tuple) = (One(),)
57+
offsets(x::StrideIndex) = getfield(x, :offsets)
5858
offsets(x) = eachop(_offsets, nstatic(Val(ndims(x))), x)
5959
function _offsets(x::X, dim::StaticInt{D}) where {X,D}
6060
start = known_first(axes_types(X, dim))

test/indexing.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ end
2626
@test @inferred(ArrayInterface.to_index(axis, CartesianIndices(()))) === CartesianIndices(())
2727

2828
x = LinearIndices((static(0):static(3),static(3):static(5),static(-2):static(0)));
29-
@test @inferred(ArrayInterface.to_index(x, NDIndex((0, 3, -2)))) === 1
30-
@test @inferred(ArrayInterface.to_index(x, NDIndex(static(0), static(3), static(-2)))) === static(1)
29+
30+
# @test @inferred(ArrayInterface.to_index(x, NDIndex((0, 3, -2)))) === 1
31+
# @test @inferred(ArrayInterface.to_index(x, NDIndex(static(0), static(3), static(-2)))) === static(1)
3132

3233
@test_throws BoundsError ArrayInterface.to_index(axis, 4)
3334
@test_throws BoundsError ArrayInterface.to_index(axis, 1:4)
@@ -125,8 +126,8 @@ end
125126
#@test_throws ArgumentError Base._sub2ind((1:3,), 2)
126127
#@test_throws ArgumentError Base._ind2sub((1:3,), 2)
127128
x = Array{Int,2}(undef, (2, 2))
128-
ArrayInterface.unsafe_set_index!(x, 1, (2, 2))
129-
@test ArrayInterface.unsafe_get_index(x, (2, 2)) === 1
129+
ArrayInterface.unsafe_setindex!(x, 1, 2, 2)
130+
@test ArrayInterface.unsafe_getindex(x, 2, 2) === 1
130131

131132
# FIXME @test_throws MethodError ArrayInterface.unsafe_set_element!(x, 1, (:x, :x))
132133
# FIXME @test_throws MethodError ArrayInterface.unsafe_get_element(x, (:x, :x))

0 commit comments

Comments
 (0)