diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index dedf67cd..e0530430 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -34,6 +34,8 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + env: + GROUP: ${{ matrix.group }} - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v1 with: diff --git a/Project.toml b/Project.toml index 502c64d8..8e8efca2 100644 --- a/Project.toml +++ b/Project.toml @@ -6,27 +6,27 @@ version = "2.30.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2" -ArrayInterfaceStaticArrays = "b0d46f97-bff5-4637-a19a-dd75974142cd" +ArrayInterfaceStaticArraysCore = "dd5226c6-a4d4-4bc7-8575-46859f9c95b9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] Adapt = "3" ArrayInterfaceCore = "0.1.1" -ArrayInterfaceStaticArrays = "0.1" +ArrayInterfaceStaticArraysCore = "0.1" ChainRulesCore = "0.10.7, 1" DocStringExtensions = "0.8, 0.9" FillArrays = "0.11, 0.12, 0.13" GPUArraysCore = "0.1" RecipesBase = "0.7, 0.8, 1.0" -StaticArrays = "0.12, 1.0" +StaticArraysCore = "1" ZygoteRules = "0.2" julia = "1.6" @@ -36,10 +36,11 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StructArrays", "Zygote"] +test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"] diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index eaa9f4b2..25529e1e 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -5,7 +5,7 @@ $(DocStringExtensions.README) module RecursiveArrayTools using DocStringExtensions -using RecipesBase, StaticArrays, Statistics, +using RecipesBase, StaticArraysCore, Statistics, ArrayInterfaceCore, LinearAlgebra import ChainRulesCore @@ -15,7 +15,7 @@ import ZygoteRules, Adapt # Required for the downstream_events.jl test # Since `ismutable` on an ArrayPartition needs # to know static arrays are not mutable -import ArrayInterfaceStaticArrays +import ArrayInterfaceStaticArraysCore using FillArrays diff --git a/src/array_partition.jl b/src/array_partition.jl index 6ff690a3..1bb9eff2 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -90,7 +90,7 @@ Base.zero(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = zero(A) ## Array -Base.Array(A::ArrayPartition) = ArrayPartition(Array.(A.x)) +Base.Array(A::ArrayPartition) = reduce(vcat,Array.(A.x)) Base.Array(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A <: AbstractVector{<:ArrayPartition}} = reduce(hcat,Array.(VA.u)) ## ones @@ -390,13 +390,13 @@ end # [U11 U12 U13] [ b1 ] # [ 0 U22 U23] \ [ b2 ] # [ 0 0 U33] [ b3 ] -function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitUpperTriangular,UpperTriangular} +function LinearAlgebra.ldiv!(A::UnitUpperTriangular, bb::ArrayPartition) A = A.data n = npartitions(bb) b = bb.x lens = map(length, b) @inbounds for j in n:-1:1 - Ajj = T(getblock(A, lens, j, j)) + Ajj = UnitUpperTriangular(getblock(A, lens, j, j)) xj = ldiv!(Ajj, vec(b[j])) for i in j-1:-1:1 Aij = getblock(A, lens, i, j) @@ -407,13 +407,30 @@ function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitUpperT return bb end -function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitLowerTriangular,LowerTriangular} +function LinearAlgebra.ldiv!(A::UpperTriangular, bb::ArrayPartition) + A = A.data + n = npartitions(bb) + b = bb.x + lens = map(length, b) + @inbounds for j in n:-1:1 + Ajj = UpperTriangular(getblock(A, lens, j, j)) + xj = ldiv!(Ajj, vec(b[j])) + for i in j-1:-1:1 + Aij = getblock(A, lens, i, j) + # bi = -Aij * xj + bi + mul!(vec(b[i]), Aij, xj, -1, true) + end + end + return bb +end + +function LinearAlgebra.ldiv!(A::UnitLowerTriangular, bb::ArrayPartition) A = A.data n = npartitions(bb) b = bb.x lens = map(length, b) @inbounds for j in 1:n - Ajj = T(getblock(A, lens, j, j)) + Ajj = UnitLowerTriangular(getblock(A, lens, j, j)) xj = ldiv!(Ajj, vec(b[j])) for i in j+1:n Aij = getblock(A, lens, i, j) @@ -423,6 +440,24 @@ function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitLowerT end return bb end + +function LinearAlgebra.ldiv!(A::LowerTriangular, bb::ArrayPartition) + A = A.data + n = npartitions(bb) + b = bb.x + lens = map(length, b) + @inbounds for j in 1:n + Ajj = LowerTriangular(getblock(A, lens, j, j)) + xj = ldiv!(Ajj, vec(b[j])) + for i in j+1:n + Aij = getblock(A, lens, i, j) + # bi = -Aij * xj + b[i] + mul!(vec(b[i]), Aij, xj, -1, true) + end + end + return bb +end + # TODO: optimize function LinearAlgebra._ipiv_rows!(A::LU, order::OrdinalRange, B::ArrayPartition) for i = order diff --git a/src/utils.jl b/src/utils.jl index d2efdeba..f9a08fce 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,7 +9,8 @@ like `copy` on arrays of scalars. function recursivecopy(a) deepcopy(a) end -recursivecopy(a::Union{SVector,SMatrix,SArray,Number}) = copy(a) +recursivecopy(a::Union{StaticArraysCore.SVector,StaticArraysCore.SMatrix, + StaticArraysCore.SArray,Number}) = copy(a) function recursivecopy(a::AbstractArray{T,N}) where {T<:Number,N} copy(a) end @@ -33,7 +34,7 @@ like `copy!` on arrays of scalars. """ function recursivecopy! end -function recursivecopy!(b::AbstractArray{T,N},a::AbstractArray{T2,N}) where {T<:StaticArray,T2<:StaticArray,N} +function recursivecopy!(b::AbstractArray{T,N},a::AbstractArray{T2,N}) where {T<:StaticArraysCore.StaticArray,T2<:StaticArraysCore.StaticArray,N} @inbounds for i in eachindex(a) # TODO: Check for `setindex!`` and use `copy!(b[i],a[i])` or `b[i] = a[i]`, see #19 b[i] = copy(a[i]) @@ -68,13 +69,13 @@ A recursive `fill!` function. """ function recursivefill! end -function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:StaticArray,T2<:StaticArray,N} +function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:StaticArraysCore.StaticArray,T2<:StaticArraysCore.StaticArray,N} @inbounds for i in eachindex(b) b[i] = copy(a) end end -function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:SArray,T2<:Union{Number,Bool},N} +function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:StaticArraysCore.SArray,T2<:Union{Number,Bool},N} @inbounds for i in eachindex(b) b[i] = fill(a, typeof(b[i])) end @@ -88,7 +89,7 @@ function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:Union{Number,Bool fill!(b, a) end -function recursivefill!(b::AbstractArray{T,N},a) where {T<:MArray,N} +function recursivefill!(b::AbstractArray{T,N},a) where {T<:StaticArraysCore.MArray,N} @inbounds for i in eachindex(b) if isassigned(b,i) recursivefill!(b[i],a) @@ -151,7 +152,7 @@ If `i= i if !ArrayInterfaceCore.ismutable(T) || !perform_copy - # TODO: Check for `setindex!`` if T <: StaticArray and use `copy!(b[i],a[i])` + # TODO: Check for `setindex!`` if T <: StaticArraysCore.StaticArray and use `copy!(b[i],a[i])` # or `b[i] = a[i]`, see https://github.com/JuliaDiffEq/RecursiveArrayTools.jl/issues/19 a[i] = x else @@ -208,7 +209,15 @@ ones has a `Array{Array{Float64,N},N}`, this will return `Array{Float64,N}`. """ recursive_unitless_eltype(a) = recursive_unitless_eltype(eltype(a)) recursive_unitless_eltype(a::Type{Any}) = Any -recursive_unitless_eltype(a::Type{T}) where {T<:StaticArray} = similar_type(a,recursive_unitless_eltype(eltype(a))) + +# Should be: +# recursive_unitless_eltype(a::Type{T}) where {T<:StaticArray} = similar_type(a,recursive_unitless_eltype(eltype(a))) +# But missing from StaticArraysCore +recursive_unitless_eltype(a::Type{StaticArraysCore.SArray{S, T, N, L}}) where {S, T, N, L} = StaticArraysCore.SArray{S, typeof(one(T)), N, L} +recursive_unitless_eltype(a::Type{StaticArraysCore.MArray{S, T, N, L}}) where {S, T, N, L} = StaticArraysCore.MArray{S, typeof(one(T)), N, L} +recursive_unitless_eltype(a::Type{StaticArraysCore.SizedArray{S, T, N, M, TData}}) where { + S, T, N, M, TData} = StaticArraysCore.SizedArray{S, typeof(one(T)), N, M, TData} + recursive_unitless_eltype(a::Type{T}) where {T<:Array} = Array{recursive_unitless_eltype(eltype(a)),ndims(a)} recursive_unitless_eltype(a::Type{T}) where {T<:Number} = typeof(one(eltype(a))) recursive_unitless_eltype(::Type{<:Enum{T}}) where T = T diff --git a/test/linalg.jl b/test/linalg.jl index 1c7b137e..492dc336 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -4,6 +4,7 @@ using LinearAlgebra n, m = 5, 6 bb = rand(n), rand(m) b = ArrayPartition(bb) +@test Array(b) isa Array @test Array(b) == collect(b) == vcat(bb...) A = randn(MersenneTwister(123), n+m, n+m) diff --git a/test/runtests.jl b/test/runtests.jl index 54ae5aad..efa80e65 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,7 +19,7 @@ end @time begin -if !is_APPVEYOR && GROUP == "Core" +if GROUP == "Core" || GROUP == "All" @time @testset "Utils Tests" begin include("utils_test.jl") end @time @testset "Partitions Tests" begin include("partitions_test.jl") end @time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end diff --git a/test/upstream.jl b/test/upstream.jl index 021eceb0..21f73a40 100644 --- a/test/upstream.jl +++ b/test/upstream.jl @@ -36,7 +36,7 @@ dyn(u, p, t) = ArrayPartition( ArrayPartition(zeros(1), [0.0]) ) -solve( +@test solve( ODEProblem( dyn, ArrayPartition( @@ -45,15 +45,28 @@ solve( ), (0.0, 1.0) ),AutoTsit5(Rodas5()) -) - -@test_broken solve( - ODEProblem( - dyn, - ArrayPartition( - ArrayPartition(zeros(1), [-1.0]), - ArrayPartition(zeros(1), [0.75]) - ), - (0.0, 1.0) - ),Rodas5() ).retcode == :Success + +if VERSION < v"1.7" + @test solve( + ODEProblem( + dyn, + ArrayPartition( + ArrayPartition(zeros(1), [-1.0]), + ArrayPartition(zeros(1), [0.75]) + ), + (0.0, 1.0) + ),Rodas5() + ).retcode == :Success +else + @test_broken solve( + ODEProblem( + dyn, + ArrayPartition( + ArrayPartition(zeros(1), [-1.0]), + ArrayPartition(zeros(1), [0.75]) + ), + (0.0, 1.0) + ),Rodas5() + ).retcode == :Success +end \ No newline at end of file