From 24003c294eb8af89bad656048996fa962b7e5c2e Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Wed, 22 Jun 2022 17:03:16 -0400 Subject: [PATCH 1/2] trait for bounds checking --- lib/ArrayInterfaceCore/Project.toml | 2 +- .../src/ArrayInterfaceCore.jl | 115 +++++++++++++++++- 2 files changed, 114 insertions(+), 3 deletions(-) diff --git a/lib/ArrayInterfaceCore/Project.toml b/lib/ArrayInterfaceCore/Project.toml index f01f0f9d9..534eeabb0 100644 --- a/lib/ArrayInterfaceCore/Project.toml +++ b/lib/ArrayInterfaceCore/Project.toml @@ -1,6 +1,6 @@ name = "ArrayInterfaceCore" uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" -version = "0.1.12" +version = "0.1.13" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl index 65e44a2ef..1047d0c14 100644 --- a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl +++ b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl @@ -520,6 +520,115 @@ known_step(x) = known_step(typeof(x)) known_step(T::Type) = is_forwarding_wrapper(T) ? known_step(parent_type(T)) : nothing known_step(@nospecialize T::Type{<:AbstractUnitRange}) = 1 +abstract type CheckIndexStyle end + +""" + ArrayInterfaceCore.CheckIndexNone + +A [`ArrayInterfaceCore.CheckIndexStyle`](@ref) trait-value indicating that no bounds-checking +is needed. This is only appropriate for select types such as slice indexing +(e.g., `:`, `Base.Slice`). + +# Examples + +```jldoctest +julia> ArrayInterfaceCore.CheckIndexStyle(:) +ArrayInterfaceCore.CheckIndexNone() + +julia> ArrayInterfaceCore.CheckIndexStyle(Base.Slice(1:10)) +ArrayInterfaceCore.CheckIndexNone() + +``` + +""" +struct CheckIndexNone <: CheckIndexStyle end + +""" + ArrayInterfaceCore.CheckIndexFirstLast() + +A [`ArrayInterfaceCore.CheckIndexStyle`](@ref) trait-value indicating that bounds-checking +only need test the first and last elements in an index vector. + +# Examples + +```jldoctest +julia> r = 3:7; ArrayInterfaceCore.CheckIndexStyle(r) +ArrayInterfaceCore.CheckIndexFirstLast() + +``` +Ranges are declared `CheckIndexFirstLast` because `x[r]` can be tested +for out-of-bounds indexing using just the first and last elements of `r`. +See also [`ArrayInterfaceCore.CheckIndexAll`](@ref) and [`ArrayInterfaceCore.CheckIndexAxes`](@ref). +""" +struct CheckIndexFirstLast <: CheckIndexStyle end + +""" + ArrayInterfaceCore.CheckIndexAll() + +A [`ArrayInterfaceCore.CheckIndexStyle`](@ref) trait-value indicating that bounds-checking +must test all elements in an index vector. + +# Examples + +```jldoctest +julia> idx = [3,4,5,6,7]; ArrayInterfaceCore.CheckIndexStyle(idx) +ArrayInterfaceCore.CheckIndexAll() + +``` + +Since the entries in `idx` could be arbitrary, we have to check each +entry for bounds violations. +See also [`ArrayInterfaceCore..CheckIndexFirstLast`](@ref) and [`ArrayInterfaceCore.CheckIndexAxes`](@ref). +""" +struct CheckIndexAll <: CheckIndexStyle end + +""" + ArrayInterfaceCore.CheckIndexAxes() + +A [`CheckIndexStyle`](@ref) trait-value indicating that bounds-checking +should consider the axes of the index rather than the values of the +index. This is used in cases where the index acts as a filter to +select elements. + +# Examples + +```jldoctest +julia> idx = [true, false, true]; ArrayInterfaceCore.CheckIndexStyle(idx) +ArrayInterfaceCore.CheckIndexAxes() + +``` +When `idx` is used in `x[idx]`, it returns the entries in `x` +corresponding to `true` entries in `idx`. Consequently, indexing +should insist on `idx` and `x` having identical axes. +See also [`ArrayInterfaceCore.CheckIndexFirstLast`](@ref) and [`ArrayInterfaceCore.CheckIndexAll`](@ref). +""" +struct CheckIndexAxes <: CheckIndexStyle end + +""" + ArrayInterfaceCore.CheckIndexStyle(typeof(idx)) + ArrayInterfaceCore.CheckIndexStyle(::Type{T}) + +`CheckIndexStyle` specifies how bounds-checking of `x[idx]` should be performed. Certain +types of `idx`, such as ranges, may have particularly efficient ways to perform the +bounds-checking. When you define a new [`AbstractArray`](@ref) type, you can choose to +define a specific value for this trait: + +ArrayInterfaceCore.CheckIndexStyle(::Type{<:MyRange}) = CheckIndexFirstLast() +The default is [`CheckIndexAll()`](@ref), except for `AbstractVector`s with `Bool` +`eltype` (which default to [`ArrayInterfaceCore.CheckIndexAxes()`](@ref)) and subtypes of `AbstractRange` +(which default to [`ArrayInterfaceCore.CheckIndexFirstLast()`](@ref).) +""" +CheckIndexStyle(x) = CheckIndexStyle(typeof(x)) +CheckIndexStyle(::Type) = CheckIndexAll() +CheckIndexStyle(@nospecialize T::Type{<:AbstractArray}) = eltype(T) === Bool ? CheckIndexAxes() : CheckIndexAll() +CheckIndexStyle(@nospecialize T::Type{<:AbstractRange}) = eltype(T) === Bool ? CheckIndexAxes() : CheckIndexFirstLast() +CheckIndexStyle(@nospecialize T::Type{<:Base.FastContiguousSubArray}) = CheckIndexStyle(parent_type(T)) +CheckIndexStyle(::Type{Colon}) = CheckIndexNone() +CheckIndexStyle(@nospecialize T::Type{<:Base.Slice}) = CheckIndexNone() +CheckIndexStyle(@nospecialize T::Type{<:Base.Fix2{<:Union{typeof(<),typeof(isless),typeof(>=),typeof(>),typeof(isless)},<:Number}}) = CheckIndexNone() +CheckIndexStyle(@nospecialize T::Type{<:Number}) = CheckIndexFirstLast() + + """ is_splat_index(::Type{T}) -> Bool Returns `static(true)` if `T` is a type that splats across multiple dimensions. @@ -564,19 +673,21 @@ Provides basic trait information for each index type in in the tuple `T`. `NI`, `IS` are tuples of [`ndims_index`](@ref), [`ndims_shape`](@ref), and [`is_splat_index`](@ref) (respectively) for each field of `T`. """ -struct IndicesInfo{NI,NS,IS} end +struct IndicesInfo{NI,NS,IS,CI} end IndicesInfo(@nospecialize x::Tuple) = IndicesInfo(typeof(x)) @generated function IndicesInfo(::Type{T}) where {T<:Tuple} NI = Expr(:tuple) NS = Expr(:tuple) IS = Expr(:tuple) + CI = Expr(:tuple) for i in 1:fieldcount(T) T_i = fieldtype(T, i) push!(NI.args, :(ndims_index($(T_i)))) push!(NS.args, :(ndims_shape($(T_i)))) push!(IS.args, :(is_splat_index($(T_i)))) + push!(CI.args, :(CheckIndexStyle($(T_i)))) end - Expr(:block, Expr(:meta, :inline), :(IndicesInfo{$(NI),$(NS),$(IS)}())) + Expr(:block, Expr(:meta, :inline), :(IndicesInfo{$(NI),$(NS),$(IS),$(CI)}())) end """ From 655f1589eb3f87f87b6b20070f11f00133ec0c60 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Thu, 30 Jun 2022 04:06:23 -0400 Subject: [PATCH 2/2] Fix formatting --- lib/ArrayInterfaceCore/Project.toml | 1 - lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/ArrayInterfaceCore/Project.toml b/lib/ArrayInterfaceCore/Project.toml index 00db610ac..708575a42 100644 --- a/lib/ArrayInterfaceCore/Project.toml +++ b/lib/ArrayInterfaceCore/Project.toml @@ -2,7 +2,6 @@ name = "ArrayInterfaceCore" uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" version = "0.1.14" - [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl index 5fd8ec82e..1b1d25d5b 100644 --- a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl +++ b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl @@ -729,14 +729,13 @@ CheckIndexStyle(::Type) = CheckIndexAll() CheckIndexStyle(@nospecialize T::Type{<:AbstractArray}) = eltype(T) === Bool ? CheckIndexAxes() : CheckIndexAll() CheckIndexStyle(@nospecialize T::Type{<:AbstractRange}) = eltype(T) === Bool ? CheckIndexAxes() : CheckIndexFirstLast() CheckIndexStyle(@nospecialize T::Type{<:Base.FastContiguousSubArray}) = CheckIndexStyle(parent_type(T)) -CheckIndexStyle(::Type{Colon}) = CheckIndexNone() -CheckIndexStyle(@nospecialize T::Type{<:Base.Slice}) = CheckIndexNone() +CheckIndexStyle(@nospecialize T::Type{<:Union{Base.Slice,Colon}}) = CheckIndexNone() CheckIndexStyle(@nospecialize T::Type{<:Base.Fix2{<:Union{typeof(<),typeof(isless),typeof(>=),typeof(>),typeof(isless)},<:Number}}) = CheckIndexNone() CheckIndexStyle(@nospecialize T::Type{<:Number}) = CheckIndexFirstLast() - """ is_splat_index(::Type{T}) -> Bool + Returns `static(true)` if `T` is a type that splats across multiple dimensions. """ is_splat_index(T::Type) = false