diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index fd415a634..6f205427d 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -615,6 +615,7 @@ end abstract type AbstractDevice end abstract type AbstractCPU <: AbstractDevice end struct CPUPointer <: AbstractCPU end +struct CPUTuple <: AbstractCPU end struct CheckParent end struct CPUIndex <: AbstractCPU end struct GPU <: AbstractDevice end @@ -630,7 +631,7 @@ Otherwise, returns `nothing`. """ device(A) = device(typeof(A)) device(::Type) = nothing -device(::Type{<:Tuple}) = CPUIndex() +device(::Type{<:Tuple}) = CPUTuple() device(::Type{T}) where {T<:Array} = CPUPointer() device(::Type{T}) where {T<:AbstractArray} = _device(has_parent(T), T) function _device(::True, ::Type{T}) where {T} @@ -880,6 +881,7 @@ function __init__() known_length(::Type{A}) where {A <: StaticArrays.StaticArray} = known_length(StaticArrays.Length(A)) device(::Type{<:StaticArrays.MArray}) = CPUPointer() + device(::Type{<:StaticArrays.SArray}) = CPUTuple() contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}() contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}() function stride_rank(::Type{T}) where {N,T<:StaticArrays.StaticArray{<:Any,<:Any,N}} diff --git a/test/runtests.jl b/test/runtests.jl index c0adfcbea..d13544b93 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -320,15 +320,15 @@ using OffsetArrays @test @inferred(device(A)) === ArrayInterface.CPUPointer() @test @inferred(device(B)) === ArrayInterface.CPUIndex() @test @inferred(device(-1:19)) === ArrayInterface.CPUIndex() - @test @inferred(device((1,2,3))) === ArrayInterface.CPUIndex() + @test @inferred(device((1,2,3))) === ArrayInterface.CPUTuple() @test @inferred(device(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.CPUPointer() @test @inferred(device(view(A, 1, :, 2:4))) === ArrayInterface.CPUPointer() @test @inferred(device(view(A, 1, :, 2:4)')) === ArrayInterface.CPUPointer() @test @inferred(device(OffsetArray(view(PermutedDimsArray(A, (3,1,2)), 1, :, 2:4)', 3, -173))) === ArrayInterface.CPUPointer() @test @inferred(device(view(OffsetArray(A,2,3,-12), 4, :, -11:-9))) === ArrayInterface.CPUPointer() @test @inferred(device(view(OffsetArray(A,2,3,-12), 3, :, [-11,-10,-9])')) === ArrayInterface.CPUIndex() - @test @inferred(device(OffsetArray(@SArray(zeros(2,2,2)),-123,29,3231))) === ArrayInterface.CPUIndex() - @test @inferred(device(OffsetArray(@view(@SArray(zeros(2,2,2))[1,1:2,:]),-3,4))) === ArrayInterface.CPUIndex() + @test @inferred(device(OffsetArray(@SArray(zeros(2,2,2)),-123,29,3231))) === ArrayInterface.CPUTuple() + @test @inferred(device(OffsetArray(@view(@SArray(zeros(2,2,2))[1,1:2,:]),-3,4))) === ArrayInterface.CPUTuple() @test @inferred(device(OffsetArray(@MArray(zeros(2,2,2)),8,-2,-5))) === ArrayInterface.CPUPointer() @test isnothing(device("Hello, world!")) @test @inferred(device(DenseWrapper{Int,2,Matrix{Int}})) === ArrayInterface.CPUPointer()