diff --git a/.gitignore b/.gitignore index 49541d368..a9edec329 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ *.jl.mem deps/deps.jl Manifest.toml +*~ diff --git a/Project.toml b/Project.toml index d70b9896f..32fb4c1a2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ArrayInterface" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "2.12.1" +version = "2.13.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -16,10 +16,11 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "Aqua"] +test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "OffsetArrays", "Aqua"] diff --git a/README.md b/README.md index ca1e862f9..e632a8b04 100644 --- a/README.md +++ b/README.md @@ -134,12 +134,80 @@ Otherwise, returns `nothing`. For example, `known_step(UnitRange{Int})` returns If `length` of an instance of type `T` is known at compile time, return it. Otherwise, return `nothing`. -## Static(N::Int) +## device(::Type{T}) + +Indicates the most efficient way to access elements from the collection in low level code. +For `GPUArrays`, will return `ArrayInterface.GPU()`. +For `AbstractArray` supporting a `pointer` method, returns `ArrayInterface.CPUPointer()`. +For other `AbstractArray`s and `Tuple`s, returns `ArrayInterface.CPUIndex()`. +Otherwise, returns `nothing`. + +## contiguous_axis(::Type{T}) + +Returns the axis of an array of type `T` containing contiguous data. +If no axis is contiguous, it returns `Contiguous{-1}`. +If unknown, it returns `nothing`. + +## contiguous_axis_indicator(::Type{T}) + +Returns a tuple of boolean `Val`s indicating whether that axis is contiguous. + +## contiguous_batch_size(::Type{T}) + +Returns the size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`. +If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`. +If `contiguous_axis(T) == -1`, it will return `ContiguousBatch{-1}()`. +If unknown, it will return `nothing`. + +## stride_rank(::Type{T}) + +Returns the rank of each stride. + +## dense_dims(::Type{T}) +Returns a tuple of indicators for whether each axis is dense. +An axis `i` of array `A` is dense if `stride(A, i) * size(A, i) == stride(A, j)` where `j` is the axis (if it exists) such that `stride_rank(A)[i] + 1 == stride_rank(A)[j]`. + +## ArrayInterface.size(A) + +Returns the size of `A`. If the size of any axes are known at compile time, +these should be returned as `StaticInt`s. For example: +```julia +julia> using StaticArrays, ArrayInterface + +julia> A = @SMatrix rand(3,4); + +julia> ArrayInterface.size(A) +(StaticInt{3}(), StaticInt{4}()) +``` + +## ArrayInterface.strides(A) + +Returns the strides of array `A`. If any strides are known at compile time, +these should be returned as `StaticInt`s. For example: +```julia +julia> using ArrayInterface + +julia> A = rand(3,4); + +julia> ArrayInterface.strides(A) +(StaticInt{1}(), 3) +``` +## offsets(A) + +Returns offsets of indices with respect to 0. If values are known at compile time, +it should return them as `StaticInt`s. +For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1))`. + +## can_avx(f) + +Is the function `f` whitelisted for `LoopVectorization.@avx`? + +## StaticInt(N::Int) Creates a static integer with value known at compile time. It is a number, -supporting basic arithmetic. Many operations with two `Static` integers -will produce another `Static` integer. If one of the arguments to a -function call isn't static (e.g., `Static(4) + 3`) then the `Static` +supporting basic arithmetic. Many operations with two `StaticInt` integers +will produce another `StaticInt` integer. If one of the arguments to a +function call isn't static (e.g., `StaticInt(4) + 3`) then the `StaticInt` number will promote to a dynamic value. # List of things to add diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 170da63bb..bc20d024e 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -33,8 +33,9 @@ If `length` of an instance of type `T` is known at compile time, return it. Otherwise, return `nothing`. """ known_length(x) = known_length(typeof(x)) -known_length(::Type{<:NTuple{N,<:Any}}) where {N} = N known_length(::Type{<:NamedTuple{L}}) where {L} = length(L) +known_length(::Type{T}) where {T<:Base.Slice} = known_length(parent_type(T)) +known_length(::Type{<:Tuple{Vararg{Any,N}}}) where {N} = N known_length(::Type{<:Number}) = 1 function known_length(::Type{T}) where {T} if parent_type(T) <: T @@ -52,7 +53,7 @@ _known_length(x::Tuple{Vararg{Int}}) = prod(x) """ can_change_size(::Type{T}) -> Bool -Returns `true` if the size of `T` can change, in which case operations +Returns `true` if the Base.size of `T` can change, in which case operations such as `pop!` and `popfirst!` are available for collections of type `T`. """ can_change_size(x) = can_change_size(typeof(x)) @@ -102,7 +103,7 @@ function Base.setindex(x::AbstractVector,v,i::Int) end function Base.setindex(x::AbstractMatrix,v,i::Int,j::Int) - n,m = size(x) + n,m = Base.size(x) x .* (i .!== 1:n) .* (j .!== i:m)' .+ v .* (i .== 1:n) .* (j .== i:m)' end @@ -202,7 +203,7 @@ Return: (I,J) #indexable objects Find sparsity pattern of special matrices, the same as the first two elements of findnz(::SparseMatrixCSC) """ function findstructralnz(x::Diagonal) - n=size(x,1) + n = Base.size(x,1) (1:n,1:n) end @@ -412,7 +413,7 @@ function Base.getindex(ind::BandedBlockBandedMatrixIndex,i::Int) end function findstructralnz(x::Bidiagonal) - n=size(x,1) + n= Base.size(x,1) isup= x.uplo=='U' ? true : false rowind=BidiagonalIndex(n+n-1,isup) colind=BidiagonalIndex(n+n-1,!isup) @@ -420,7 +421,7 @@ function findstructralnz(x::Bidiagonal) end function findstructralnz(x::Union{Tridiagonal,SymTridiagonal}) - n=size(x,1) + n= Base.size(x,1) rowind=TridiagonalIndex(n+n-1+n-1,n,true) colind=TridiagonalIndex(n+n-1+n-1,n,false) (rowind,colind) @@ -447,10 +448,10 @@ fast_matrix_colors(A::Type{<:Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagona matrix_colors(A::Union{Array,UpperTriangular,LowerTriangular}) The color vector for dense matrix and triangular matrix is simply -`[1,2,3,...,size(A,2)]` +`[1,2,3,..., Base.size(A,2)]` """ function matrix_colors(A::Union{Array,UpperTriangular,LowerTriangular}) - eachindex(1:size(A,2)) # Vector size matches number of rows + eachindex(1:Base.size(A,2)) # Vector Base.size matches number of rows end function _cycle(repetend,len) @@ -458,15 +459,15 @@ function _cycle(repetend,len) end function matrix_colors(A::Diagonal) - fill(1,size(A,2)) + fill(1, Base.size(A,2)) end function matrix_colors(A::Bidiagonal) - _cycle(1:2,size(A,2)) + _cycle(1:2, Base.size(A,2)) end function matrix_colors(A::Union{Tridiagonal,SymTridiagonal}) - _cycle(1:3,size(A,2)) + _cycle(1:3, Base.size(A,2)) end """ @@ -540,9 +541,64 @@ function restructure(x,y) end function restructure(x::Array,y) - reshape(convert(Array,y),size(x)...) + reshape(convert(Array,y), Base.size(x)...) end +abstract type AbstractDevice end +abstract type AbstractCPU <: AbstractDevice end +struct CPUPointer <: AbstractCPU end +struct CheckParent end +struct CPUIndex <: AbstractCPU end +struct GPU <: AbstractDevice end +""" +device(::Type{T}) + +Indicates the most efficient way to access elements from the collection in low level code. +For `GPUArrays`, will return `ArrayInterface.GPU()`. +For `AbstractArray` supporting a `pointer` method, returns `ArrayInterface.CPUPointer()`. +For other `AbstractArray`s and `Tuple`s, returns `ArrayInterface.CPUIndex()`. +Otherwise, returns `nothing`. +""" +device(A) = device(typeof(A)) +device(::Type) = nothing +device(::Type{<:Tuple}) = CPUIndex() +# Relies on overloading for GPUArrays that have subtyped `StridedArray`. +device(::Type{<:StridedArray}) = CPUPointer() +function device(::Type{T}) where {T <: AbstractArray} + P = parent_type(T) + T === P ? CPUIndex() : device(P) +end + + +""" +defines_strides(::Type{T}) -> Bool + +Is strides(::T) defined? +""" +defines_strides(::Type) = false +defines_strides(x) = defines_strides(typeof(x)) +defines_strides(::Type{<:StridedArray}) = true +defines_strides(::Type{A}) where {A <: Union{<:Transpose,<:Adjoint,<:SubArray,<:PermutedDimsArray}} = defines_strides(parent_type(A)) + +""" +can_avx(f) + +Returns `true` if the function `f` is guaranteed to be compatible with `LoopVectorization.@avx` for supported element and array types. +While a return value of `false` does not indicate the function isn't supported, this allows a library to conservatively apply `@avx` +only when it is known to be safe to do so. + +```julia +function mymap!(f, y, args...) + if can_avx(f) + @avx @. y = f(args...) + else + @. y = f(args...) + end +end +``` +""" +can_avx(::Any) = false + """ insert(collection, index, item) @@ -654,6 +710,7 @@ function __init__() ismutable(::Type{<:StaticArrays.StaticArray}) = false can_setindex(::Type{<:StaticArrays.StaticArray}) = false ismutable(::Type{<:StaticArrays.MArray}) = true + ismutable(::Type{<:StaticArrays.SizedArray}) = true function lu_instance(_A::StaticArrays.StaticMatrix{N,N}) where {N} A = StaticArrays.SArray(_A) @@ -675,6 +732,26 @@ function __init__() known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N + device(::Type{<:StaticArrays.MArray}) = CPUPointer() + contiguous_axis(::Type{<:StaticArrays.StaticArray}) = Contiguous{1}() + contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = ContiguousBatch{0}() + stride_rank(::Type{T}) where {N, T <: StaticArrays.StaticArray{<:Any,<:Any,N}} = StrideRank{ntuple(identity, Val{N}())}() + dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N} = DenseDims{ntuple(_ -> true, Val(N))}() + defines_strides(::Type{<:StaticArrays.MArray}) = true + @generated function size(A::StaticArrays.StaticArray{S}) where {S} + t = Expr(:tuple); Sp = S.parameters + for n in 1:length(Sp) + push!(t.args, Expr(:call, Expr(:curly, :StaticInt, Sp[n]))) + end + t + end + @generated function strides(A::StaticArrays.StaticArray{S}) where {S} + t = Expr(:tuple, Expr(:call, Expr(:curly, :StaticInt, 1))); Sp = S.parameters; x = 1 + for n in 1:length(Sp)-1 + push!(t.args, Expr(:call, Expr(:curly, :StaticInt, (x *= Sp[n])))) + end + t + end @require Adapt="79e6a3ab-5dfb-504d-930d-738a2a938a0e" begin function Adapt.adapt_storage(::Type{<:StaticArrays.SArray{S}},xs::Array) where S StaticArrays.SArray{S}(xs) @@ -694,7 +771,7 @@ function __init__() aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal,N}) where N = Tracker.collect(x) end - @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin + @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin @require Adapt="79e6a3ab-5dfb-504d-930d-738a2a938a0e" begin include("cuarrays.jl") end @@ -717,7 +794,7 @@ function __init__() @require BandedMatrices="aae01518-5342-5314-be14-df237901396f" begin function findstructralnz(x::BandedMatrices.BandedMatrix) l,u=BandedMatrices.bandwidths(x) - rowsize,colsize=size(x) + rowsize,colsize= Base.size(x) rowind=BandedMatrixIndex(rowsize,colsize,l,u,true) colind=BandedMatrixIndex(rowsize,colsize,l,u,false) (rowind,colind) @@ -730,7 +807,7 @@ function __init__() function matrix_colors(A::BandedMatrices.BandedMatrix) l,u=BandedMatrices.bandwidths(A) width=u+l+1 - _cycle(1:width,size(A,2)) + _cycle(1:width, Base.size(A,2)) end end @@ -794,9 +871,19 @@ function __init__() end end end + @require OffsetArrays="6fe1bfb0-de20-5000-8ca7-80f57d26f881" begin + size(A::OffsetArrays.OffsetArray) = size(parent(A)) + strides(A::OffsetArrays.OffsetArray) = strides(parent(A)) + # offsets(A::OffsetArrays.OffsetArray) = map(+, A.offsets, offsets(parent(A))) + device(::OffsetArrays.OffsetArray) = CheckParent() + contiguous_axis(A::OffsetArrays.OffsetArray) = contiguous_axis(parent(A)) + contiguous_batch_size(A::OffsetArrays.OffsetArray) = contiguous_batch_size(parent(A)) + stride_rank(A::OffsetArrays.OffsetArray) = stride_rank(parent(A)) + end end include("static.jl") include("ranges.jl") +include("stridelayout.jl") end diff --git a/src/cuarrays.jl b/src/cuarrays.jl index d412f77d1..6a0c9e706 100644 --- a/src/cuarrays.jl +++ b/src/cuarrays.jl @@ -9,5 +9,8 @@ function Base.setindex(x::CuArrays.CuArray,v,i::Int) end function restructure(x::CuArrays.CuArray,y) - reshape(Adapt.adapt(parameterless_type(x),y),size(x)...) + reshape(Adapt.adapt(parameterless_type(x),y), Base.size(x)...) end + +Device(::Type{<:CuArrays.CuArray}) = GPU() + diff --git a/src/cuarrays2.jl b/src/cuarrays2.jl index b8ec51c7d..962c86686 100644 --- a/src/cuarrays2.jl +++ b/src/cuarrays2.jl @@ -9,5 +9,8 @@ function Base.setindex(x::CUDA.CuArray,v,i::Int) end function restructure(x::CUDA.CuArray,y) - reshape(Adapt.adapt(parameterless_type(x),y),size(x)...) + reshape(Adapt.adapt(parameterless_type(x),y), Base.size(x)...) end + +Device(::Type{<:CUDA.CuArray}) = GPU() + diff --git a/src/ranges.jl b/src/ranges.jl index 84ce7167c..988458d2f 100644 --- a/src/ranges.jl +++ b/src/ranges.jl @@ -78,17 +78,17 @@ struct OptionallyStaticUnitRange{F <: Integer, L <: Integer} <: AbstractUnitRang end end -Base.:(:)(L::Integer, ::Static{U}) where {U} = OptionallyStaticUnitRange(L, Static(U)) -Base.:(:)(::Static{L}, U::Integer) where {L} = OptionallyStaticUnitRange(Static(L), U) -Base.:(:)(::Static{L}, ::Static{U}) where {L,U} = OptionallyStaticUnitRange(Static(L), Static(U)) +Base.:(:)(L::Integer, ::StaticInt{U}) where {U} = OptionallyStaticUnitRange(L, StaticInt(U)) +Base.:(:)(::StaticInt{L}, U::Integer) where {L} = OptionallyStaticUnitRange(StaticInt(L), U) +Base.:(:)(::StaticInt{L}, ::StaticInt{U}) where {L,U} = OptionallyStaticUnitRange(StaticInt(L), StaticInt(U)) Base.first(r::OptionallyStaticUnitRange) = r.start -Base.step(::OptionallyStaticUnitRange) = Static(1) +Base.step(::OptionallyStaticUnitRange) = StaticInt(1) Base.last(r::OptionallyStaticUnitRange) = r.stop -known_first(::Type{<:OptionallyStaticUnitRange{Static{F}}}) where {F} = F +known_first(::Type{<:OptionallyStaticUnitRange{StaticInt{F}}}) where {F} = F known_step(::Type{<:OptionallyStaticUnitRange}) = 1 -known_last(::Type{<:OptionallyStaticUnitRange{<:Any,Static{L}}}) where {L} = L +known_last(::Type{<:OptionallyStaticUnitRange{<:Any,StaticInt{L}}}) where {L} = L function Base.isempty(r::OptionallyStaticUnitRange) if known_first(r) === oneunit(eltype(r)) @@ -102,7 +102,7 @@ unsafe_isempty_one_to(lst) = lst <= zero(lst) unsafe_isempty_unit_range(fst, lst) = fst > lst unsafe_length_one_to(lst::Int) = lst -unsafe_length_one_to(::Static{L}) where {L} = lst +unsafe_length_one_to(::StaticInt{L}) where {L} = lst Base.@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer) if known_first(r) === oneunit(r) @@ -127,18 +127,24 @@ end return convert(eltype(r), val) end -@inline _try_static(::Static{N}, ::Static{N}) where {N} = Static{N}() -@inline _try_static(::Static{M}, ::Static{N}) where {M, N} = @assert false "Unequal Indices: Static{$M}() != Static{$N}()" -function _try_static(::Static{N}, x) where {N} - @assert N == x "Unequal Indices: Static{$N}() != x == $x" - return Static{N}() +@inline _try_static(::StaticInt{N}, ::StaticInt{N}) where {N} = StaticInt{N}() +@inline _try_static(::StaticInt{M}, ::StaticInt{N}) where {M, N} = @assert false "Unequal Indices: StaticInt{$M}() != StaticInt{$N}()" +@propagate_inbounds function _try_static(::StaticInt{N}, x) where {N} + @boundscheck begin + @assert N == x "Unequal Indices: StaticInt{$N}() != x == $x" + end + return StaticInt{N}() end -function _try_static(x, ::Static{N}) where {N} - @assert N == x "Unequal Indices: x == $x != Static{$N}()" - return Static{N}() +@propagate_inbounds function _try_static(x, ::StaticInt{N}) where {N} + @boundscheck begin + @assert N == x "Unequal Indices: x == $x != StaticInt{$N}()" + end + return StaticInt{N}() end -function _try_static(x, y) - @assert x == y "Unequal Indicess: x == $x != $y == y" +@propagate_inbounds function _try_static(x, y) + @boundscheck begin + @assert x == y "Unequal Indicess: x == $x != $y == y" + end return x end diff --git a/src/static.jl b/src/static.jl index 013dbcdb9..8ac44a5f4 100644 --- a/src/static.jl +++ b/src/static.jl @@ -1,94 +1,95 @@ """ A statically sized `Int`. -Use `Static(N)` instead of `Val(N)` when you want it to behave like a number. +Use `StaticInt(N)` instead of `Val(N)` when you want it to behave like a number. """ -struct Static{N} <: Integer - Static{N}() where {N} = new{N::Int}() +struct StaticInt{N} <: Integer + StaticInt{N}() where {N} = new{N::Int}() end -const Zero = Static{0} -const One = Static{1} +const Zero = StaticInt{0} +const One = StaticInt{1} -Base.@pure Static(N::Int) = Static{N}() -Static(N::Integer) = Static(convert(Int, N)) -Static(::Static{N}) where {N} = Static{N}() -Static(::Val{N}) where {N} = Static{N}() -Base.Val(::Static{N}) where {N} = Val{N}() -Base.convert(::Type{T}, ::Static{N}) where {T<:Number,N} = convert(T, N) -Base.convert(::Type{Static{N}}, ::Static{N}) where {N} = Static{N}() +Base.@pure StaticInt(N::Int) = StaticInt{N}() +StaticInt(N::Integer) = StaticInt(convert(Int, N)) +StaticInt(::StaticInt{N}) where {N} = StaticInt{N}() +StaticInt(::Val{N}) where {N} = StaticInt{N}() +# Base.Val(::StaticInt{N}) where {N} = Val{N}() +Base.convert(::Type{T}, ::StaticInt{N}) where {T<:Number,N} = convert(T, N) +# (::Type{T})(::ArrayInterface.StaticInt{N}) where {T,N} = T(N) +Base.convert(::Type{StaticInt{N}}, ::StaticInt{N}) where {N} = StaticInt{N}() -Base.promote_rule(::Type{<:Static}, ::Type{T}) where {T <: AbstractIrrational} = promote_rule(Int, T) -Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T <: AbstractIrrational} = promote_rule(T, Int) +Base.promote_rule(::Type{<:StaticInt}, ::Type{T}) where {T <: AbstractIrrational} = promote_rule(Int, T) +Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T <: AbstractIrrational} = promote_rule(T, Int) for (S,T) ∈ [(:Complex,:Real), (:Rational, :Integer), (:(Base.TwicePrecision),:Any)] - @eval Base.promote_rule(::Type{$S{T}}, ::Type{<:Static}) where {T <: $T} = promote_rule($S{T}, Int) + @eval Base.promote_rule(::Type{$S{T}}, ::Type{<:StaticInt}) where {T <: $T} = promote_rule($S{T}, Int) end -Base.promote_rule(::Type{Union{Nothing,Missing}}, ::Type{<:Static}) = Union{Nothing, Missing, Int} -Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Union{Missing,Nothing}} = promote_rule(T, Int) -Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Nothing} = promote_rule(T, Int) -Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Missing} = promote_rule(T, Int) +Base.promote_rule(::Type{Union{Nothing,Missing}}, ::Type{<:StaticInt}) = Union{Nothing, Missing, Int} +Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Union{Missing,Nothing}} = promote_rule(T, Int) +Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Nothing} = promote_rule(T, Int) +Base.promote_rule(::Type{T}, ::Type{<:StaticInt}) where {T >: Missing} = promote_rule(T, Int) for T ∈ [:Bool, :Missing, :BigFloat, :BigInt, :Nothing, :Any] # let S = :Any @eval begin - Base.promote_rule(::Type{S}, ::Type{$T}) where {S <: Static} = promote_rule(Int, $T) - Base.promote_rule(::Type{$T}, ::Type{S}) where {S <: Static} = promote_rule($T, Int) + Base.promote_rule(::Type{S}, ::Type{$T}) where {S <: StaticInt} = promote_rule(Int, $T) + Base.promote_rule(::Type{$T}, ::Type{S}) where {S <: StaticInt} = promote_rule($T, Int) end end -Base.promote_rule(::Type{<:Static}, ::Type{<:Static}) = Int -Base.:(%)(::Static{N}, ::Type{Integer}) where {N} = N +Base.promote_rule(::Type{<:StaticInt}, ::Type{<:StaticInt}) = Int +Base.:(%)(::StaticInt{N}, ::Type{Integer}) where {N} = N -Base.eltype(::Type{T}) where {T<:Static} = Int +Base.eltype(::Type{T}) where {T<:StaticInt} = Int Base.iszero(::Zero) = true -Base.iszero(::Static) = false +Base.iszero(::StaticInt) = false Base.isone(::One) = true -Base.isone(::Static) = false -Base.zero(::Type{T}) where {T<:Static} = Zero() -Base.one(::Type{T}) where {T<:Static} = One() +Base.isone(::StaticInt) = false +Base.zero(::Type{T}) where {T<:StaticInt} = Zero() +Base.one(::Type{T}) where {T<:StaticInt} = One() for T = [:Real, :Rational, :Integer] @eval begin @inline Base.:(+)(i::$T, ::Zero) = i - @inline Base.:(+)(i::$T, ::Static{M}) where {M} = i + M + @inline Base.:(+)(i::$T, ::StaticInt{M}) where {M} = i + M @inline Base.:(+)(::Zero, i::$T) = i - @inline Base.:(+)(::Static{M}, i::$T) where {M} = M + i + @inline Base.:(+)(::StaticInt{M}, i::$T) where {M} = M + i @inline Base.:(-)(i::$T, ::Zero) = i - @inline Base.:(-)(i::$T, ::Static{M}) where {M} = i - M + @inline Base.:(-)(i::$T, ::StaticInt{M}) where {M} = i - M @inline Base.:(*)(i::$T, ::Zero) = Zero() @inline Base.:(*)(i::$T, ::One) = i - @inline Base.:(*)(i::$T, ::Static{M}) where {M} = i * M + @inline Base.:(*)(i::$T, ::StaticInt{M}) where {M} = i * M @inline Base.:(*)(::Zero, i::$T) = Zero() @inline Base.:(*)(::One, i::$T) = i - @inline Base.:(*)(::Static{M}, i::$T) where {M} = M * i + @inline Base.:(*)(::StaticInt{M}, i::$T) where {M} = M * i end end @inline Base.:(+)(::Zero, ::Zero) = Zero() -@inline Base.:(+)(::Zero, ::Static{M}) where {M} = Static{M}() -@inline Base.:(+)(::Static{M}, ::Zero) where {M} = Static{M}() +@inline Base.:(+)(::Zero, ::StaticInt{M}) where {M} = StaticInt{M}() +@inline Base.:(+)(::StaticInt{M}, ::Zero) where {M} = StaticInt{M}() -@inline Base.:(-)(::Static{M}, ::Zero) where {M} = Static{M}() +@inline Base.:(-)(::StaticInt{M}, ::Zero) where {M} = StaticInt{M}() @inline Base.:(*)(::Zero, ::Zero) = Zero() @inline Base.:(*)(::One, ::Zero) = Zero() @inline Base.:(*)(::Zero, ::One) = Zero() @inline Base.:(*)(::One, ::One) = One() -@inline Base.:(*)(::Static{M}, ::Zero) where {M} = Zero() -@inline Base.:(*)(::Zero, ::Static{M}) where {M} = Zero() -@inline Base.:(*)(::Static{M}, ::One) where {M} = Static{M}() -@inline Base.:(*)(::One, ::Static{M}) where {M} = Static{M}() +@inline Base.:(*)(::StaticInt{M}, ::Zero) where {M} = Zero() +@inline Base.:(*)(::Zero, ::StaticInt{M}) where {M} = Zero() +@inline Base.:(*)(::StaticInt{M}, ::One) where {M} = StaticInt{M}() +@inline Base.:(*)(::One, ::StaticInt{M}) where {M} = StaticInt{M}() for f ∈ [:(+), :(-), :(*), :(/), :(÷), :(%), :(<<), :(>>), :(>>>), :(&), :(|), :(⊻)] - @eval @generated Base.$f(::Static{M}, ::Static{N}) where {M,N} = Expr(:call, Expr(:curly, :Static, $f(M, N))) + @eval @generated Base.$f(::StaticInt{M}, ::StaticInt{N}) where {M,N} = Expr(:call, Expr(:curly, :StaticInt, $f(M, N))) end for f ∈ [:(==), :(!=), :(<), :(≤), :(>), :(≥)] @eval begin - @inline Base.$f(::Static{M}, ::Static{N}) where {M,N} = $f(M, N) - @inline Base.$f(::Static{M}, x::Int) where {M} = $f(M, x) - @inline Base.$f(x::Int, ::Static{M}) where {M} = $f(x, M) + @inline Base.$f(::StaticInt{M}, ::StaticInt{N}) where {M,N} = $f(M, N) + @inline Base.$f(::StaticInt{M}, x::Int) where {M} = $f(M, x) + @inline Base.$f(x::Int, ::StaticInt{M}) where {M} = $f(x, M) end end @inline function maybe_static(f::F, g::G, x) where {F, G} L = f(x) - isnothing(L) ? g(x) : Static(L) + isnothing(L) ? g(x) : StaticInt(L) end @inline static_length(x) = maybe_static(known_length, length, x) @inline static_first(x) = maybe_static(known_first, first, x) diff --git a/src/stridelayout.jl b/src/stridelayout.jl new file mode 100644 index 000000000..943ee129b --- /dev/null +++ b/src/stridelayout.jl @@ -0,0 +1,302 @@ +struct Contiguous{N} end +Base.@pure Contiguous(N::Int) = Contiguous{N}() +_get(::Contiguous{N}) where {N} = N +""" +contiguous_axis(::Type{T}) -> Contiguous{N} + +Returns the axis of an array of type `T` containing contiguous data. +If no axis is contiguous, it returns `Contiguous{-1}`. +If unknown, it returns `nothing`. +""" +contiguous_axis(x) = contiguous_axis(typeof(x)) +contiguous_axis(::Type) = nothing +contiguous_axis(::Type{<:Array}) = Contiguous{1}() +contiguous_axis(::Type{<:Tuple}) = Contiguous{1}() +function contiguous_axis(::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}) where {T,A<:AbstractVector{T}} + c = contiguous_axis(A) + isnothing(c) && return nothing + c === Contiguous{1}() ? Contiguous{2}() : Contiguous{-1}() +end +function contiguous_axis(::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}) where {T,A<:AbstractMatrix{T}} + c = contiguous_axis(A) + isnothing(c) && return nothing + contig = _get(c) + new_contig = contig == -1 ? -1 : 3 - contig + Contiguous{new_contig}() +end +function contiguous_axis(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A<:AbstractArray{T,N}} + c = contiguous_axis(A) + isnothing(c) && return nothing + new_contig = I2[_get(c)] + Contiguous{new_contig}() +end +function contiguous_axis(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S <: SubArray{T,N,A,I}} + _contiguous_axis(S, contiguous_axis(A)) +end +_contiguous_axis(::Any, ::Nothing) = nothing +@generated function _contiguous_axis(::Type{S}, ::Contiguous{C}) where {C,N,NP,T,A<:AbstractArray{T,NP},I,S <: SubArray{T,N,A,I}} + n = 0 + new_contig = contig = C + for np in 1:NP + if I.parameters[np] <: AbstractUnitRange + n += 1 + if np == contig + new_contig = n + end + else + if np == contig + new_contig = -1 + end + end + end + # If n != N, then an axis was indeced by something other than an integer or `AbstractUnitRange`, so we return `nothing` + n == N || return nothing + Expr(:call, Expr(:curly, :Contiguous, new_contig)) +end + +""" +contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}} + +Returns a tuple boolean `Val`s indicating whether that axis is contiguous. +""" +contiguous_axis_indicator(::Type{A}) where {D, A <: AbstractArray{<:Any,D}} = contiguous_axis_indicator(contiguous_axis(A), Val(D)) +contiguous_axis_indicator(::A) where {A <: AbstractArray} = contiguous_axis_indicator(A) +Base.@pure contiguous_axis_indicator(::Contiguous{N}, ::Val{D}) where {N,D} = ntuple(d -> Val{d == N}(), Val{D}()) + +""" +If the contiguous dimension is not the dimension with `Stride_rank{1}` +""" +struct ContiguousBatch{N} end +Base.@pure ContiguousBatch(N::Int) = ContiguousBatch{N}() +_get(::ContiguousBatch{N}) where {N} = N + +""" +contiguous_batch_size(::Type{T}) -> ContiguousBatch{N} + +Returns the Base.size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`. +If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`. +If `contiguous_axis(T) == -1`, it will return `ContiguousBatch{-1}()`. +If unknown, it will return `nothing`. +""" +contiguous_batch_size(x) = contiguous_batch_size(typeof(x)) +contiguous_batch_size(::Type) = nothing +contiguous_batch_size(::Type{Array{T,N}}) where {T,N} = ContiguousBatch{0}() +contiguous_batch_size(::Type{<:Tuple}) = ContiguousBatch{0}() +contiguous_batch_size(::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}) where {T,A<:AbstractVecOrMat{T}} = contiguous_batch_size(A) +contiguous_batch_size(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A<:AbstractArray{T,N}} = contiguous_batch_size(A) +function contiguous_batch_size(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S <: SubArray{T,N,A,I}} + _contiguous_batch_size(S, contiguous_batch_size(A), contiguous_axis(A)) +end +_contiguous_batch_size(::Any, ::Any, ::Any) = nothing +@generated function _contiguous_batch_size(::Type{S}, ::ContiguousBatch{B}, ::Contiguous{C}) where {B,C,N,NP,T,A<:AbstractArray{T,NP},I,S <: SubArray{T,N,A,I}} + if I.parameters[C] <: AbstractUnitRange + Expr(:call, Expr(:curly, :ContiguousBatch, B)) + else + Expr(:call, Expr(:curly, :ContiguousBatch, -1)) + end +end + +struct StrideRank{R} end +Base.@pure StrideRank(R::NTuple{<:Any,Int}) = StrideRank{R}() +_get(::StrideRank{R}) where {R} = R +Base.collect(::StrideRank{R}) where {R} = collect(R) +@inline Base.getindex(::StrideRank{R}, i::Integer) where {R} = R[i] +@inline Base.getindex(::StrideRank{R}, ::Val{I}) where {R,I} = StrideRank{permute(R, I)}() + +""" +rank_to_sortperm(::StrideRank) -> NTuple{N,Int} + +Returns the `sortperm` of the stride ranks. +""" +function rank_to_sortperm(R::NTuple{N,Int}) where {N} + sp = ntuple(zero, Val{N}()) + r = ntuple(n -> sum(R[n] .≥ R), Val{N}()) + @inbounds for n in 1:N + sp = Base.setindex(sp, n, r[n]) + end + sp +end +@generated Base.sortperm(::StrideRank{R}) where {R} = rank_to_sortperm(R) + +stride_rank(x) = stride_rank(typeof(x)) +stride_rank(::Type) = nothing +stride_rank(::Type{Array{T,N}}) where {T,N} = StrideRank{ntuple(identity, Val{N}())}() +stride_rank(::Type{<:Tuple}) = StrideRank{(1,)}() + +stride_rank(::Type{B}) where {T, A <: AbstractVector{T}, B <: Union{Transpose{T,A},Adjoint{T,A}}} = StrideRank{(2, 1)}() +stride_rank(::Type{B}) where {T, A <: AbstractMatrix{T}, B <: Union{Transpose{T,A},Adjoint{T,A}}} = _stride_rank(B, stride_rank(A)) +_stride_rank(::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}, ::Nothing) where {T,A<:AbstractMatrix{T}} = nothing +_stride_rank(::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}, rank) where {T,A<:AbstractMatrix{T}} = rank[Val{(2,1)}()] + +stride_rank(::Type{B}) where {T,N,I1,I2,A<:AbstractArray{T,N},B<:PermutedDimsArray{T,N,I1,I2,A}} = _stride_rank(B, stride_rank(A)) +_stride_rank(::Type{B}, ::Nothing) where {T,N,I1,I2,A<:AbstractArray{T,N},B<:PermutedDimsArray{T,N,I1,I2,A}} = nothing +_stride_rank(::Type{B}, rank) where {T,N,I1,I2,A<:AbstractArray{T,N},B<:PermutedDimsArray{T,N,I1,I2,A}} = rank[Val{I1}()] +function stride_rank(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S <: SubArray{T,N,A,I}} + _stride_rank(S, stride_rank(A)) +end +_stride_rank(::Any, ::Any) = nothing +@generated function _stride_rank(::Type{S}, ::StrideRank{R}) where {R,N,NP,T,A<:AbstractArray{T,NP},I,S <: SubArray{T,N,A,I}} + rankv = collect(R) + rank_new = Int[] + n = 0 + for np in 1:NP + r = rankv[np] + if I.parameters[np] <: AbstractUnitRange + n += 1 + push!(rank_new, r) + end + end + # If n != N, then an axis was indeced by something other than an integer or `AbstractUnitRange`, so we return `nothing` + n == N || return nothing + ranktup = Expr(:tuple); append!(ranktup.args, rank_new) # dynamic splats bad + Expr(:call, Expr(:curly, :StrideRank, ranktup)) +end +stride_rank(x, i) = stride_rank(x)[i] + +struct DenseDims{D} end +Base.@pure DenseDims(D::NTuple{<:Any,Bool}) = DenseDims{D}() +@inline Base.getindex(::DenseDims{D}, i::Integer) where {D} = D[i] +@inline Base.getindex(::DenseDims{D}, ::Val{I}) where {D,I} = DenseDims{permute(D, I)}() +""" +dense_dims(::Type{T}) -> NTuple{N,Bool} + +Returns a tuple of indicators for whether each axis is dense. +An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)` where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`. +""" +dense_dims(x) = dense_dims(typeof(x)) +dense_dims(::Type) = nothing +dense_dims(::Type{Array{T,N}}) where {T,N} = DenseDims{ntuple(_ -> true, Val{N}())}() +dense_dims(::Type{<:Tuple}) = DenseDims{(true,)}() +function dense_dims(::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}) where {T,A<:AbstractMatrix{T}} + dense = dense_dims(A) + isnothing(dense) ? nothing : dense[Val{(2,1)}()] +end +function dense_dims(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A<:AbstractArray{T,N}} + dense = dense_dims(A) + isnothing(dense) ? nothing : dense[Val{I1}()] +end +function dense_dims(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S <: SubArray{T,N,A,I}} + _dense_dims(S, dense_dims(A), stride_rank(A)) +end +_dense_dims(::Any, ::Any) = nothing +@generated function _dense_dims(::Type{S}, ::DenseDims{D}, ::StrideRank{R}) where {D,R,N,NP,T,A<:AbstractArray{T,NP},I,S <: SubArray{T,N,A,I}} + still_dense = true + sp = rank_to_sortperm(R) + densev = Vector{Bool}(undef, NP) + for np in 1:NP + spₙ = sp[np] + still_dense &= D[spₙ] + densev[spₙ] = still_dense + # a dim not being complete makes later dims not dense + still_dense &= (I.parameters[spₙ] <: Base.Slice)::Bool + end + dense_tup = Expr(:tuple) + for np in 1:NP + spₙ = sp[np] + if I.parameters[np] <: Base.Slice + push!(dense_tup.args, densev[np]) + elseif I.parameters[np] <: AbstractUnitRange + push!(dense_tup.args, densev[np]) + end + end + # If n != N, then an axis was indexed by something other than an integer or `AbstractUnitRange`, so we return `nothing` + length(dense_tup.args) == N ? Expr(:call, Expr(:curly, :DenseDims, dense_tup)) : nothing +end + +permute(t::NTuple{N}, I::NTuple{N,Int}) where {N} = ntuple(n -> t[I[n]], Val{N}()) +@generated function permute(t::Tuple{Vararg{Any,N}}, ::Val{I}) where {N,I} + t = Expr(:tuple) + foreach(i -> push!(t.args, Expr(:ref, :t, i)), I) + Expr(:block, Expr(:meta, :inline), t) +end + +""" + size(A) + +Returns the size of `A`. If the size of any axes are known at compile time, +these should be returned as `Static` numbers. For example: +```julia +julia> using StaticArrays, ArrayInterface + +julia> A = @SMatrix rand(3,4); + +julia> ArrayInterface.size(A) +(StaticInt{3}(), StaticInt{4}()) +``` +""" +size(A) = Base.size(A) +""" + strides(A) + +Returns the strides of array `A`. If any strides are known at compile time, +these should be returned as `Static` numbers. For example: +```julia +julia> A = rand(3,4); + +julia> ArrayInterface.strides(A) +(StaticInt{1}(), 3) +``` +""" +strides(A) = Base.strides(A) +""" + offsets(A) + +Returns offsets of indices with respect to 0. If values are known at compile time, +it should return them as `Static` numbers. +For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1))`. +""" +offsets(::Any) = (StaticInt{1}(),) # Assume arbitrary Julia data structures use 1-based indexing by default. +@inline strides(A::Vector{<:Any}) = (StaticInt(1),) +@inline strides(A::Array{<:Any,N}) where {N} = (StaticInt(1), Base.tail(Base.strides(A))...) +@inline strides(A::AbstractArray{<:Any,N}) where {N} = Base.strides(A) + +@inline function offsets(x, i) + inds = indices(x, i) + start = known_first(inds) + isnothing(start) ? first(inds) : StaticInt(start) +end +# @inline offsets(A::AbstractArray{<:Any,N}) where {N} = ntuple(n -> offsets(A, n), Val{N}()) +# Explicit tuple needed for inference. +@generated function offsets(A::AbstractArray{<:Any,N}) where {N} + quote + $(Expr(:meta, :inline)) + Base.Cartesian.@ntuple $N n -> offsets(A, n) + end +end + + +@inline size(B::Union{Transpose{T,A},Adjoint{T,A}}) where {T,A<:AbstractMatrix{T}} = permute(size(parent(B)), Val{(2,1)}()) +@inline size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A<:AbstractArray{T,N}} = permute(size(parent(B)), Val{I1}()) +@inline size(A::AbstractArray, ::StaticInt{N}) where {N} = size(A)[N] +@inline size(A::AbstractArray, ::Val{N}) where {N} = size(A)[N] +@inline strides(B::Union{Transpose{T,A},Adjoint{T,A}}) where {T,A<:AbstractMatrix{T}} = permute(strides(parent(B)), Val{(2,1)}()) +@inline strides(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A<:AbstractArray{T,N}} = permute(strides(parent(B)), Val{I1}()) +@inline stride(A::AbstractArray, ::StaticInt{N}) where {N} = strides(A)[N] +@inline stride(A::AbstractArray, ::Val{N}) where {N} = strides(A)[N] +stride(A, i) = Base.stride(A, i) + +size(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S <: SubArray{T,N,A,I}} = _size(size(parent(B)), B.indices, map(static_length, B.indices)) +strides(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S <: SubArray{T,N,A,I}} = _strides(strides(parent(B)), B.indices) +@generated function _size(A::Tuple{Vararg{Any,N}}, inds::I, l::L) where {N, I<:Tuple, L} + t = Expr(:tuple) + for n in 1:N + if (I.parameters[n] <: Base.Slice) + push!(t.args, :(@inbounds(_try_static(A[$n], l[$n])))) + elseif I.parameters[n] <: AbstractUnitRange + push!(t.args, Expr(:ref, :l, n)) + end + end + Expr(:block, Expr(:meta, :inline), t) +end +@generated function _strides(A::Tuple{Vararg{Any,N}}, inds::I) where {N, I<:Tuple} + t = Expr(:tuple) + for n in 1:N + if I.parameters[n] <: AbstractRange + push!(t.args, Expr(:ref, :A, n)) + elseif !(I.parameters[n] <: Integer) + return nothing + end + end + Expr(:block, Expr(:meta, :inline), t) +end + diff --git a/test/runtests.jl b/test/runtests.jl index d61ba30d2..c12168e74 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using ArrayInterface, Test using Base: setindex -import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance, Static +import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance, device, contiguous_axis, contiguous_batch_size, stride_rank, dense_dims, StaticInt @test ArrayInterface.ismutable(rand(3)) using Aqua @@ -190,18 +190,157 @@ using ArrayInterface: parent_type end @testset "Range Interface" begin - @test isnothing(ArrayInterface.known_first(typeof(1:4))) - @test isone(ArrayInterface.known_first(Base.OneTo(4))) - @test isone(ArrayInterface.known_first(typeof(Base.OneTo(4)))) + @test isnothing(@inferred(ArrayInterface.known_first(typeof(1:4)))) + @test isone(@inferred(ArrayInterface.known_first(Base.OneTo(4)))) + @test isone(@inferred(ArrayInterface.known_first(typeof(Base.OneTo(4))))) - @test isnothing(ArrayInterface.known_last(1:4)) - @test isnothing(ArrayInterface.known_last(typeof(1:4))) + @test isnothing(@inferred(ArrayInterface.known_last(1:4))) + @test isnothing(@inferred(ArrayInterface.known_last(typeof(1:4)))) - @test isnothing(ArrayInterface.known_step(typeof(1:0.2:4))) - @test isone(ArrayInterface.known_step(1:4)) - @test isone(ArrayInterface.known_step(typeof(1:4))) + @test isnothing(@inferred(ArrayInterface.known_step(typeof(1:0.2:4)))) + @test isone(@inferred(ArrayInterface.known_step(1:4))) + @test isone(@inferred(ArrayInterface.known_step(typeof(1:4)))) end +@testset "Memory Layout" begin + A = zeros(3,4,5); + @test device(A) === ArrayInterface.CPUPointer() + @test device((1,2,3)) === ArrayInterface.CPUIndex() + @test device(PermutedDimsArray(A,(3,1,2))) === ArrayInterface.CPUPointer() + @test device(view(A, 1, :, 2:4)) === ArrayInterface.CPUPointer() + @test device(view(A, 1, :, 2:4)') === ArrayInterface.CPUPointer() + @test device(@SArray(zeros(2,2,2))) === ArrayInterface.CPUIndex() + @test device(@view(@SArray(zeros(2,2,2))[1,1:2,:])) === ArrayInterface.CPUIndex() + @test device(@MArray(zeros(2,2,2))) === ArrayInterface.CPUPointer() + @test isnothing(device("Hello, world!")) + + @test @inferred(contiguous_axis(@SArray(zeros(2,2,2)))) === ArrayInterface.Contiguous(1) + @test @inferred(contiguous_axis(A)) === ArrayInterface.Contiguous(1) + @test @inferred(contiguous_axis(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.Contiguous(2) + @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === ArrayInterface.Contiguous(1) + @test @inferred(contiguous_axis(transpose(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])))) === ArrayInterface.Contiguous(2) + @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === ArrayInterface.Contiguous(2) + @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === ArrayInterface.Contiguous(-1) + @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.Contiguous(-1) + @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.Contiguous(1) + + @test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) === (Val(true),Val(false),Val(false)) + @test @inferred(ArrayInterface.contiguous_axis_indicator(A)) === (Val(true),Val(false),Val(false)) + @test @inferred(ArrayInterface.contiguous_axis_indicator(PermutedDimsArray(A,(3,1,2)))) === (Val(false),Val(true),Val(false)) + @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === (Val(true),Val(false)) + @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) === (Val(false),Val(true)) + @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === (Val(false),Val(true),Val(false)) + @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === (Val(false),Val(false)) + @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === (Val(false),Val(false)) + @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === (Val(true),Val(false)) + + @test @inferred(contiguous_batch_size(@SArray(zeros(2,2,2)))) === ArrayInterface.ContiguousBatch(0) + @test @inferred(contiguous_batch_size(A)) === ArrayInterface.ContiguousBatch(0) + @test @inferred(contiguous_batch_size(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.ContiguousBatch(0) + @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === ArrayInterface.ContiguousBatch(0) + @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) === ArrayInterface.ContiguousBatch(0) + @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === ArrayInterface.ContiguousBatch(0) + @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === ArrayInterface.ContiguousBatch(-1) + @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.ContiguousBatch(-1) + @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.ContiguousBatch(0) + + @test @inferred(stride_rank(@SArray(zeros(2,2,2)))) === ArrayInterface.StrideRank((1, 2, 3)) + @test @inferred(stride_rank(A)) === ArrayInterface.StrideRank((1,2,3)) + @test @inferred(stride_rank(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.StrideRank((3, 1, 2)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === ArrayInterface.StrideRank((1, 2)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) === ArrayInterface.StrideRank((2, 1)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === ArrayInterface.StrideRank((3, 1, 2)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === ArrayInterface.StrideRank((3, 2)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.StrideRank((2, 3)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.StrideRank((1, 3)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,2,1])')) === ArrayInterface.StrideRank((2, 1)) + + @test @inferred(dense_dims(@SArray(zeros(2,2,2)))) === ArrayInterface.DenseDims((true,true,true)) + @test @inferred(dense_dims(A)) === ArrayInterface.DenseDims((true,true,true)) + @test @inferred(dense_dims(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.DenseDims((true,true,true)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === ArrayInterface.DenseDims((true,false)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) === ArrayInterface.DenseDims((false,true)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === ArrayInterface.DenseDims((false,true,false)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,:,1:2]))) === ArrayInterface.DenseDims((false,true,true)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === ArrayInterface.DenseDims((false,false)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.DenseDims((false,false)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.DenseDims((true,false)) + + B = Array{Int8}(undef, 2,2,2,2); + doubleperm = PermutedDimsArray(PermutedDimsArray(B,(4,2,3,1)), (4,2,1,3)); + @test collect(strides(B))[collect(stride_rank(doubleperm))] == collect(strides(doubleperm)) +end + +using OffsetArrays +@testset "Static-Dynamic Size, Strides, and Offsets" begin + A = zeros(3,4,5); Ap = @view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])'; + S = @SArray zeros(2,3,4); Sp = @view(PermutedDimsArray(S,(3,1,2))[2:3,1:2,:]); + M = @MArray zeros(2,3,4); Mp = @view(PermutedDimsArray(M,(3,1,2))[:,2,:])'; + Sp2 = @view(PermutedDimsArray(S,(3,2,1))[2:3,:,:]); + Mp2 = @view(PermutedDimsArray(M,(3,1,2))[2:3,:,2])'; + + @test @inferred(ArrayInterface.size(A)) === (3,4,5) + @test @inferred(ArrayInterface.size(Ap)) === (2,5) + @test @inferred(ArrayInterface.size(A)) === size(A) + @test @inferred(ArrayInterface.size(Ap)) === size(Ap) + + @test @inferred(ArrayInterface.size(S)) === (StaticInt(2), StaticInt(3), StaticInt(4)) + @test @inferred(ArrayInterface.size(Sp)) === (2, 2, StaticInt(3)) + @test @inferred(ArrayInterface.size(Sp2)) === (2, StaticInt(3), StaticInt(2)) + @test @inferred(ArrayInterface.size(S)) == size(S) + @test @inferred(ArrayInterface.size(Sp)) == size(Sp) + @test @inferred(ArrayInterface.size(Sp2)) == size(Sp2) + @test @inferred(ArrayInterface.size(Sp2, StaticInt(1))) === 2 + @test @inferred(ArrayInterface.size(Sp2, StaticInt(2))) === StaticInt(3) + @test @inferred(ArrayInterface.size(Sp2, StaticInt(3))) === StaticInt(2) + + @test @inferred(ArrayInterface.size(M)) === (StaticInt(2), StaticInt(3), StaticInt(4)) + @test @inferred(ArrayInterface.size(Mp)) === (StaticInt(3), StaticInt(4)) + @test @inferred(ArrayInterface.size(Mp2)) === (StaticInt(2), 2) + @test @inferred(ArrayInterface.size(M)) == size(M) + @test @inferred(ArrayInterface.size(Mp)) == size(Mp) + @test @inferred(ArrayInterface.size(Mp2)) == size(Mp2) + + @test @inferred(ArrayInterface.strides(A)) === (StaticInt(1), 3, 12) + @test @inferred(ArrayInterface.strides(Ap)) === (StaticInt(1), 12) + @test @inferred(ArrayInterface.strides(A)) == strides(A) + @test @inferred(ArrayInterface.strides(Ap)) == strides(Ap) + + @test @inferred(ArrayInterface.strides(S)) === (StaticInt(1), StaticInt(2), StaticInt(6)) + @test @inferred(ArrayInterface.strides(Sp)) === (StaticInt(6), StaticInt(1), StaticInt(2)) + @test @inferred(ArrayInterface.strides(Sp2)) === (StaticInt(6), StaticInt(2), StaticInt(1)) + @test @inferred(ArrayInterface.stride(Sp2, StaticInt(1))) === StaticInt(6) + @test @inferred(ArrayInterface.stride(Sp2, StaticInt(2))) === StaticInt(2) + @test @inferred(ArrayInterface.stride(Sp2, StaticInt(3))) === StaticInt(1) + + @test @inferred(ArrayInterface.strides(M)) === (StaticInt(1), StaticInt(2), StaticInt(6)) + @test @inferred(ArrayInterface.strides(Mp)) === (StaticInt(2), StaticInt(6)) + @test @inferred(ArrayInterface.strides(Mp2)) === (StaticInt(1), StaticInt(6)) + @test @inferred(ArrayInterface.strides(M)) == strides(M) + @test @inferred(ArrayInterface.strides(Mp)) == strides(Mp) + @test @inferred(ArrayInterface.strides(Mp2)) == strides(Mp2) + + @test @inferred(ArrayInterface.offsets(A)) === (StaticInt(1), StaticInt(1), StaticInt(1)) + @test @inferred(ArrayInterface.offsets(Ap)) === (StaticInt(1), StaticInt(1)) + + @test @inferred(ArrayInterface.offsets(S)) === (StaticInt(1), StaticInt(1), StaticInt(1)) + @test @inferred(ArrayInterface.offsets(Sp)) === (StaticInt(1), StaticInt(1), StaticInt(1)) + @test @inferred(ArrayInterface.offsets(Sp2)) === (StaticInt(1), StaticInt(1), StaticInt(1)) + + @test @inferred(ArrayInterface.offsets(M)) === (StaticInt(1), StaticInt(1), StaticInt(1)) + @test @inferred(ArrayInterface.offsets(Mp)) === (StaticInt(1), StaticInt(1)) + @test @inferred(ArrayInterface.offsets(Mp2)) === (StaticInt(1), StaticInt(1)) + + O = OffsetArray(A, 3, 7, 10); + Op = PermutedDimsArray(O,(3,1,2)); + @test @inferred(ArrayInterface.offsets(O)) === (4, 8, 11) + @test @inferred(ArrayInterface.offsets(Op)) === (11, 4, 8) + + @test @inferred(ArrayInterface.offsets((1,2,3))) === (StaticInt(1),) +end + +@test ArrayInterface.can_avx(ArrayInterface.can_avx) == false + @testset "can_change_size" begin @test ArrayInterface.can_change_size([1]) @test ArrayInterface.can_change_size(Vector{Int}) @@ -211,7 +350,7 @@ end end @testset "known_length" begin - @test ArrayInterface.known_length(ArrayInterface.indices(SOneTo(7))) == 7 + @test ArrayInterface.known_length(@inferred(ArrayInterface.indices(SOneTo(7)))) == 7 @test ArrayInterface.known_length(1:2) == nothing @test ArrayInterface.known_length((1,)) == 1 @test ArrayInterface.known_length((a=1,b=2)) == 2 @@ -232,32 +371,32 @@ end @test @inferred(ArrayInterface.indices(A23)) == 1:6 @test @inferred(ArrayInterface.indices(SA23)) == 1:6 @test @inferred(ArrayInterface.indices(A23, 1)) == 1:2 - @test @inferred(ArrayInterface.indices(SA23, Static(1))) === Base.Slice(Static(1):Static(2)) + @test @inferred(ArrayInterface.indices(SA23, StaticInt(1))) === Base.Slice(StaticInt(1):StaticInt(2)) @test @inferred(ArrayInterface.indices((A23, A32), (1, 2))) == 1:2 - @test @inferred(ArrayInterface.indices((SA23, A32), (Static(1), 2))) === Base.Slice(Static(1):Static(2)) - @test @inferred(ArrayInterface.indices((A23, SA32), (1, Static(2)))) === Base.Slice(Static(1):Static(2)) - @test @inferred(ArrayInterface.indices((SA23, SA32), (Static(1), Static(2)))) === Base.Slice(Static(1):Static(2)) + @test @inferred(ArrayInterface.indices((SA23, A32), (StaticInt(1), 2))) === Base.Slice(StaticInt(1):StaticInt(2)) + @test @inferred(ArrayInterface.indices((A23, SA32), (1, StaticInt(2)))) === Base.Slice(StaticInt(1):StaticInt(2)) + @test @inferred(ArrayInterface.indices((SA23, SA32), (StaticInt(1), StaticInt(2)))) === Base.Slice(StaticInt(1):StaticInt(2)) @test @inferred(ArrayInterface.indices((A23, A23), 1)) == 1:2 - @test @inferred(ArrayInterface.indices((SA23, SA23), Static(1))) === Base.Slice(Static(1):Static(2)) - @test @inferred(ArrayInterface.indices((SA23, A23), Static(1))) === Base.Slice(Static(1):Static(2)) - @test @inferred(ArrayInterface.indices((A23, SA23), Static(1))) === Base.Slice(Static(1):Static(2)) - @test @inferred(ArrayInterface.indices((SA23, SA23), Static(1))) === Base.Slice(Static(1):Static(2)) + @test @inferred(ArrayInterface.indices((SA23, SA23), StaticInt(1))) === Base.Slice(StaticInt(1):StaticInt(2)) + @test @inferred(ArrayInterface.indices((SA23, A23), StaticInt(1))) === Base.Slice(StaticInt(1):StaticInt(2)) + @test @inferred(ArrayInterface.indices((A23, SA23), StaticInt(1))) === Base.Slice(StaticInt(1):StaticInt(2)) + @test @inferred(ArrayInterface.indices((SA23, SA23), StaticInt(1))) === Base.Slice(StaticInt(1):StaticInt(2)) @test_throws AssertionError ArrayInterface.indices((A23, ones(3, 3)), 1) @test_throws AssertionError ArrayInterface.indices((A23, ones(3, 3)), (1, 2)) - @test_throws AssertionError ArrayInterface.indices((SA23, ones(3, 3)), Static(1)) - @test_throws AssertionError ArrayInterface.indices((SA23, ones(3, 3)), (Static(1), 2)) - @test_throws AssertionError ArrayInterface.indices((SA23, SA23), (Static(1), Static(2))) + @test_throws AssertionError ArrayInterface.indices((SA23, ones(3, 3)), StaticInt(1)) + @test_throws AssertionError ArrayInterface.indices((SA23, ones(3, 3)), (StaticInt(1), 2)) + @test_throws AssertionError ArrayInterface.indices((SA23, SA23), (StaticInt(1), StaticInt(2))) end @testset "Static" begin - @test iszero(Static(0)) - @test !iszero(Static(1)) - @test @inferred(one(Static)) === Static(1) - @test @inferred(zero(Static)) === Static(0) - @test eltype(one(Static)) <: Int + @test iszero(StaticInt(0)) + @test !iszero(StaticInt(1)) + @test @inferred(one(StaticInt)) === StaticInt(1) + @test @inferred(zero(StaticInt)) === StaticInt(0) + @test eltype(one(StaticInt)) <: Int # test for ambiguities and correctness - for i ∈ [Static(0), Static(1), Static(2), 3] - for j ∈ [Static(0), Static(1), Static(2), 3] + for i ∈ [StaticInt(0), StaticInt(1), StaticInt(2), 3] + for j ∈ [StaticInt(0), StaticInt(1), StaticInt(2), 3] i === j === 3 && continue for f ∈ [+, -, *, ÷, %, <<, >>, >>>, &, |, ⊻, ==, ≤, ≥] (iszero(j) && ((f === ÷) || (f === %))) && continue # integer division error @@ -269,7 +408,7 @@ end x = f(convert(Int, i), 1.4) y = f(1.4, convert(Int, i)) @test convert(typeof(x), @inferred(f(i, 1.4))) === x - @test convert(typeof(y), @inferred(f(1.4, i))) === y # if f is division and i === Static(0), returns `NaN`; hence use of ==== in check. + @test convert(typeof(y), @inferred(f(1.4, i))) === y # if f is division and i === StaticInt(0), returns `NaN`; hence use of ==== in check. end end end