diff --git a/Project.toml b/Project.toml index 7c7cdc78d..f74d7ff68 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ArrayInterface" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "3.1.30" +version = "3.1.31" [deps] IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" diff --git a/src/array_index.jl b/src/array_index.jl index 729be7395..3e1ce456f 100644 --- a/src/array_index.jl +++ b/src/array_index.jl @@ -207,6 +207,8 @@ end Base.firstindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = 1 Base.lastindex(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count Base.length(i::Union{TridiagonalIndex,BandedBlockBandedMatrixIndex,BandedMatrixIndex,BidiagonalIndex,BlockBandedMatrixIndex}) = i.count + +## getindex @propagate_inbounds Base.getindex(x::ArrayIndex, i::CanonicalInt, ii::CanonicalInt...) = x[NDIndex(i, ii...)] @propagate_inbounds function Base.getindex(ind::BidiagonalIndex, i::Int) @boundscheck 1 <= i <= ind.count || throw(BoundsError(ind, i)) @@ -274,11 +276,11 @@ end ind.reflocalinds[p][_i] + ind.refcoords[p] - 1 end -@inline function Base.getindex(x::StrideIndex{N}, i::AbstractCartesianIndex{N}) where {N} - return _strides2int(offsets(x), strides(x), Tuple(i)) + offset1(x) +@inline function Base.getindex(x::StrideIndex{N}, i::AbstractCartesianIndex) where {N} + return _strides2int(offsets(x), strides(x), Tuple(i)) + static(1) end @generated function _strides2int(o::O, s::S, i::I) where {O,S,I} - N = known_length(I) + N = known_length(S) out = :() for i in 1:N tmp = :(((getfield(i, $i) - getfield(o, $i)) * getfield(s, $i))) diff --git a/src/indexing.jl b/src/indexing.jl index e4bc6c4ed..11d0ab979 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -141,22 +141,6 @@ to_index(::IndexLinear, axis, arg::CartesianIndices{1}) = axes(arg, 1) @propagate_inbounds function to_index(::IndexLinear, axis, arg::AbstractCartesianIndex{1}) return to_index(axis, first(Tuple(arg))) end -function to_index(::IndexLinear, x, arg::AbstractCartesianIndex{N}) where {N} - inds = Tuple(arg) - o = offsets(x) - s = size(x) - return first(inds) + (static(1) - first(o)) + _subs2int(first(s), tail(s), tail(o), tail(inds)) -end -@inline function _subs2int(stride, s::Tuple{Any,Vararg}, o::Tuple{Any,Vararg}, inds::Tuple{Any,Vararg}) - i = ((first(inds) - first(o)) * stride) - return i + _subs2int(stride * first(s), tail(s), tail(o), tail(inds)) -end -function _subs2int(stride, s::Tuple{Any}, o::Tuple{Any}, inds::Tuple{Any}) - return (first(inds) - first(o)) * stride -end -# trailing inbounds can only be 1 or 1:1 -_subs2int(stride, ::Tuple{}, ::Tuple{}, ::Tuple{Any}) = static(0) - @propagate_inbounds function to_index(::IndexLinear, x, arg::Union{Array{Bool}, BitArray}) @boundscheck checkbounds(x, arg) return LogicalIndex{Int}(arg) @@ -215,13 +199,18 @@ end @boundscheck checkbounds(x, arg) return LogicalIndex{Int}(arg) end -to_index(::IndexCartesian, x, i::Integer) = NDIndex(_int2subs(offsets(x), size(x), i - static(1))) -@inline function _int2subs(o::Tuple{Any,Vararg{Any}}, s::Tuple{Any,Vararg{Any}}, i) - len = first(s) - inext = div(i, len) - return (canonicalize(i - len * inext + first(o)), _int2subs(tail(o), tail(s), inext)...) + +# TODO delete this once the layout interface is working +_array_index(::IndexLinear, a, i::CanonicalInt) = i +@inline function _array_index(::IndexStyle, a, i::CanonicalInt) + CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a))))[i] end -_int2subs(o::Tuple{Any}, s::Tuple{Any}, i) = canonicalize(i + first(o)) +_array_index(::IndexLinear, a, i::AbstractCartesianIndex{1}) = getfield(Tuple(i), 1) +@inline function _array_index(::IndexLinear, a, i::AbstractCartesianIndex) + N = ndims(a) + StrideIndex{N,ntuple(+, Val(N)),nothing}(size_to_strides(size(a), static(1)), offsets(a))[i] +end +_array_index(::IndexStyle, a, i::AbstractCartesianIndex) = i """ unsafe_reconstruct(A, data; kwargs...) @@ -326,54 +315,50 @@ another instance of `ArrayInterface.getindex` should only be done by overloading Changing indexing based on a given argument from `args` should be done through, [`to_index`](@ref), or [`to_axis`](@ref). """ -@propagate_inbounds getindex(A, args...) = unsafe_get_index(A, to_indices(A, args)) +@propagate_inbounds getindex(A, args...) = unsafe_getindex(A, to_indices(A, args)...) @propagate_inbounds function getindex(A; kwargs...) - return unsafe_get_index(A, to_indices(A, order_named_inds(dimnames(A), values(kwargs)))) + return unsafe_getindex(A, to_indices(A, order_named_inds(dimnames(A), values(kwargs)))...) end @propagate_inbounds getindex(x::Tuple, i::Int) = getfield(x, i) @propagate_inbounds getindex(x::Tuple, ::StaticInt{i}) where {i} = getfield(x, i) -## unsafe_get_index ## -unsafe_get_index(A, i::Tuple{}) = unsafe_get_element(A, ()) -unsafe_get_index(A, i::Tuple{CanonicalInt}) = unsafe_get_element(A, getfield(i, 1)) -function unsafe_get_index(A, i::Tuple{CanonicalInt,Vararg{CanonicalInt}}) - unsafe_get_element(A, NDIndex(i)) +## unsafe_getindex ## +function unsafe_getindex(a::A) where {A} + parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A,))) + return unsafe_getindex(parent(a)) end -unsafe_get_index(A, i::Tuple) = unsafe_get_collection(A, i) - -#= - unsafe_get_element(A::AbstractArray{T}, inds::Tuple) -> T - -Returns an element of `A` at the indices `inds`. This method assumes all `inds` -have been checked for being in bounds. Any new array type using `ArrayInterface.getindex` -must define `unsafe_get_element(::NewArrayType, inds)`. -=# -unsafe_get_element(a::A, inds) where {A} = _unsafe_get_element(has_parent(A), a, inds) -_unsafe_get_element(::True, a, inds) = unsafe_get_element(parent(a), inds) -_unsafe_get_element(::False, a, inds) = @inbounds(parent(a)[inds]) -_unsafe_get_element(::False, a::AbstractArray2, i) = unsafe_get_element_error(a, i) - -## Array ## -unsafe_get_element(A::Array, ::Tuple{}) = Base.arrayref(false, A, 1) -unsafe_get_element(A::Array, i::Integer) = Base.arrayref(false, A, Int(i)) -unsafe_get_element(A::Array, i::NDIndex) = unsafe_get_element(A, to_index(A, i)) +function unsafe_getindex(a::A, i::CanonicalInt) where {A} + idx = _array_index(IndexStyle(A), a, i) + if idx === i + parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A, i))) + return unsafe_getindex(parent(a), i) + else + return unsafe_getindex(a, idx) + end +end +function unsafe_getindex(a::A, i::AbstractCartesianIndex) where {A} + idx = _array_index(IndexStyle(A), a, i) + if idx === i + parent_type(A) <: A && throw(MethodError(unsafe_getindex, (A, i))) + return unsafe_getindex(parent(a), i) + else + return unsafe_getindex(a, idx) + end +end +function unsafe_getindex(a, i::CanonicalInt, ii::Vararg{CanonicalInt}) + unsafe_getindex(a, NDIndex(i, ii...)) +end +unsafe_getindex(a, i::Vararg{Any}) = unsafe_get_collection(a, i) -## LinearIndices ## -unsafe_get_element(A::LinearIndices, i::Integer) = Int(i) -unsafe_get_element(A::LinearIndices, i::NDIndex) = unsafe_get_element(A, to_index(A, i)) +unsafe_getindex(A::Array) = Base.arrayref(false, A, 1) +unsafe_getindex(A::Array, i::CanonicalInt) = Base.arrayref(false, A, Int(i)) -unsafe_get_element(A::CartesianIndices, i::NDIndex) = CartesianIndex(i) -unsafe_get_element(A::CartesianIndices, i::Integer) = unsafe_get_element(A, to_index(A, i)) +unsafe_getindex(A::LinearIndices, i::CanonicalInt) = Int(i) -unsafe_get_element(A::ReshapedArray, i::Integer) = unsafe_get_element(parent(A), i) -function unsafe_get_element(A::ReshapedArray, i::NDIndex) - return unsafe_get_element(parent(A), to_index(IndexLinear(), A, i)) -end +unsafe_getindex(A::CartesianIndices, i::AbstractCartesianIndex) = CartesianIndex(i) -unsafe_get_element(A::SubArray, i) = @inbounds(A[i]) -function unsafe_get_element_error(@nospecialize(A), @nospecialize(i)) - throw(MethodError(unsafe_get_element, (A, i))) -end +unsafe_getindex(A::SubArray, i::CanonicalInt) = @inbounds(A[i]) +unsafe_getindex(A::SubArray, i::AbstractCartesianIndex) = @inbounds(A[i]) # This is based on Base._unsafe_getindex from https://github.com/JuliaLang/julia/blob/c5ede45829bf8eb09f2145bfd6f089459d77b2b1/base/multidimensional.jl#L755. #= @@ -402,7 +387,7 @@ function _generate_unsafe_get_index!_body(N::Int) # the optimizer is not clever enough to split the union without it Dy === nothing && return dest (idx, state) = Dy - dest[idx] = unsafe_get_element(src, NDIndex(Base.Cartesian.@ntuple($N, j))) + dest[idx] = unsafe_getindex(src, NDIndex(Base.Cartesian.@ntuple($N, j))) Dy = iterate(D, state) end return dest @@ -441,45 +426,49 @@ Store the given values at the given key or index within a collection. """ @propagate_inbounds function setindex!(A, val, args...) if can_setindex(A) - return unsafe_set_index!(A, val, to_indices(A, args)) + return unsafe_setindex!(A, val, to_indices(A, args)...) else error("Instance of type $(typeof(A)) are not mutable and cannot change elements after construction.") end end @propagate_inbounds function setindex!(A, val; kwargs...) - return unsafe_set_index!(A, val, to_indices(A, order_named_inds(dimnames(A), values(kwargs)))) + return unsafe_setindex!(A, val, to_indices(A, order_named_inds(dimnames(A), values(kwargs)))...) end -## unsafe_get_index ## -unsafe_set_index!(A, v, i::Tuple{}) = unsafe_set_element!(A, v, ()) -unsafe_set_index!(A, v, i::Tuple{CanonicalInt}) = unsafe_set_element!(A, v, getfield(i, 1)) -function unsafe_set_index!(A, v, i::Tuple{CanonicalInt,Vararg{CanonicalInt}}) - unsafe_set_element!(A, v, NDIndex(i)) +## unsafe_setindex! ## +function unsafe_setindex!(a::A, v) where {A} + parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v))) + return unsafe_setindex!(parent(a), v) end -unsafe_set_index!(A, v, i::Tuple) = unsafe_set_collection!(A, v, i) - -#= - unsafe_set_element!(A, val, inds::Tuple) - -Sets an element of `A` to `val` at indices `inds`. This method assumes all `inds` -have been checked for being in bounds. Any new array type using `ArrayInterface.setindex!` -must define `unsafe_set_element!(::NewArrayType, val, inds)`. -=# -unsafe_set_element!(a, val, inds) = _unsafe_set_element!(has_parent(a), a, val, inds) -_unsafe_set_element!(::True, a, val, inds) = unsafe_set_element!(parent(a), val, inds) -_unsafe_set_element!(::False, a, val, inds) = @inbounds(parent(a)[inds] = val) - -function _unsafe_set_element!(::False, a::AbstractArray2, val, inds) - unsafe_set_element_error(a, val, inds) +function unsafe_setindex!(a::A, v, i::CanonicalInt) where {A} + idx = _array_index(IndexStyle(A), a, i) + if idx === i + parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v, i))) + return unsafe_setindex!(parent(a), v, i) + else + return unsafe_setindex!(a, v, idx) + end end -unsafe_set_element_error(A, v, i) = throw(MethodError(unsafe_set_element!, (A, v, i))) - -function unsafe_set_element!(A::Array{T}, val, ::Tuple{}) where {T} - Base.arrayset(false, A, convert(T, val)::T, 1) +function unsafe_setindex!(a::A, v, i::AbstractCartesianIndex) where {A} + idx = _array_index(IndexStyle(A), a, i) + if idx === i + parent_type(A) <: A && throw(MethodError(unsafe_setindex!, (A, v, i))) + return unsafe_setindex!(parent(a), v, i) + else + return unsafe_setindex!(a, v, idx) + end +end +function unsafe_setindex!(a, v, i::CanonicalInt, ii::Vararg{CanonicalInt}) + unsafe_setindex!(a, v, NDIndex(i, ii...)) end -function unsafe_set_element!(A::Array{T}, val, i::Integer) where {T} - return Base.arrayset(false, A, convert(T, val)::T, Int(i)) +function unsafe_setindex!(A::Array{T}, v) where {T} + Base.arrayset(false, A, convert(T, v)::T, 1) end +function unsafe_setindex!(A::Array{T}, v, i::CanonicalInt) where {T} + return Base.arrayset(false, A, convert(T, v)::T, Int(i)) +end + +unsafe_setindex!(a, v, i::Vararg{Any}) = unsafe_set_collection!(a, v, i) # This is based on Base._unsafe_setindex!. #= @@ -501,7 +490,7 @@ function _generate_unsafe_setindex!_body(N::Int) # the optimizer that it does not need to emit error paths Xy === nothing && break (val, state) = Xy - unsafe_set_element!(A, val, NDIndex(Base.Cartesian.@ntuple($N, i))) + unsafe_setindex!(A, val, NDIndex(Base.Cartesian.@ntuple($N, i))) Xy = iterate(x′, state) end A diff --git a/src/size.jl b/src/size.jl index c3af5c6bc..eb338e5ec 100644 --- a/src/size.jl +++ b/src/size.jl @@ -25,7 +25,6 @@ end size(x::SubArray) = eachop(_sub_size, to_parent_dims(x), x.indices) _sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = static_length(getfield(x, dim)) - @inline size(B::VecAdjTrans) = (One(), length(parent(B))) @inline size(B::MatAdjTrans) = permute(size(parent(B)), to_parent_dims(B)) @inline function size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A} @@ -80,15 +79,15 @@ compile time. If a dimension does not have a known size along a dimension then ` returned in its position. """ known_size(x) = known_size(typeof(x)) -known_size(::Type{T}) where {T} = eachop(known_size, nstatic(Val(ndims(T))), T) - +known_size(::Type{T}) where {T} = eachop(_known_size, nstatic(Val(ndims(T))), axes_types(T)) +_known_size(::Type{T}, dim::StaticInt) where {T} = known_length(_get_tuple(T, dim)) @inline known_size(x, dim) = known_size(typeof(x), dim) @inline known_size(::Type{T}, dim) where {T} = known_size(T, to_dims(T, dim)) -@inline function known_size(::Type{T}, dim::Integer) where {T} +@inline function known_size(::Type{T}, dim::CanonicalInt) where {T} if ndims(T) < dim return 1 else - return known_length(axes_types(T, dim)) + return known_size(T)[dim] end end diff --git a/src/stridelayout.jl b/src/stridelayout.jl index 432ecbc6d..6a7ba5869 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -52,9 +52,9 @@ Returns offsets of indices with respect to 0. If values are known at compile tim it should return them as `Static` numbers. For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1))`. """ -offsets(x::StrideIndex) = getfield(x, :offsets) @inline offsets(x, i) = static_first(indices(x, i)) offsets(::Tuple) = (One(),) +offsets(x::StrideIndex) = getfield(x, :offsets) offsets(x) = eachop(_offsets, nstatic(Val(ndims(x))), x) function _offsets(x::X, dim::StaticInt{D}) where {X,D} start = known_first(axes_types(X, dim)) diff --git a/test/indexing.jl b/test/indexing.jl index ac324b637..93dae57fc 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -26,8 +26,9 @@ end @test @inferred(ArrayInterface.to_index(axis, CartesianIndices(()))) === CartesianIndices(()) x = LinearIndices((static(0):static(3),static(3):static(5),static(-2):static(0))); - @test @inferred(ArrayInterface.to_index(x, NDIndex((0, 3, -2)))) === 1 - @test @inferred(ArrayInterface.to_index(x, NDIndex(static(0), static(3), static(-2)))) === static(1) + + # @test @inferred(ArrayInterface.to_index(x, NDIndex((0, 3, -2)))) === 1 + # @test @inferred(ArrayInterface.to_index(x, NDIndex(static(0), static(3), static(-2)))) === static(1) @test_throws BoundsError ArrayInterface.to_index(axis, 4) @test_throws BoundsError ArrayInterface.to_index(axis, 1:4) @@ -125,8 +126,8 @@ end #@test_throws ArgumentError Base._sub2ind((1:3,), 2) #@test_throws ArgumentError Base._ind2sub((1:3,), 2) x = Array{Int,2}(undef, (2, 2)) - ArrayInterface.unsafe_set_index!(x, 1, (2, 2)) - @test ArrayInterface.unsafe_get_index(x, (2, 2)) === 1 + ArrayInterface.unsafe_setindex!(x, 1, 2, 2) + @test ArrayInterface.unsafe_getindex(x, 2, 2) === 1 # FIXME @test_throws MethodError ArrayInterface.unsafe_set_element!(x, 1, (:x, :x)) # FIXME @test_throws MethodError ArrayInterface.unsafe_get_element(x, (:x, :x))