Skip to content

Drop StaticArrays by using StaticArraysCore #217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ jobs:
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
GROUP: ${{ matrix.group }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v1
with:
Expand Down
11 changes: 6 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,27 @@ version = "2.30.0"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
ArrayInterfaceStaticArrays = "b0d46f97-bff5-4637-a19a-dd75974142cd"
ArrayInterfaceStaticArraysCore = "dd5226c6-a4d4-4bc7-8575-46859f9c95b9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Adapt = "3"
ArrayInterfaceCore = "0.1.1"
ArrayInterfaceStaticArrays = "0.1"
ArrayInterfaceStaticArraysCore = "0.1"
ChainRulesCore = "0.10.7, 1"
DocStringExtensions = "0.8, 0.9"
FillArrays = "0.11, 0.12, 0.13"
GPUArraysCore = "0.1"
RecipesBase = "0.7, 0.8, 1.0"
StaticArrays = "0.12, 1.0"
StaticArraysCore = "1"
ZygoteRules = "0.2"
julia = "1.6"

Expand All @@ -36,10 +36,11 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StructArrays", "Zygote"]
test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"]
4 changes: 2 additions & 2 deletions src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ $(DocStringExtensions.README)
module RecursiveArrayTools

using DocStringExtensions
using RecipesBase, StaticArrays, Statistics,
using RecipesBase, StaticArraysCore, Statistics,
ArrayInterfaceCore, LinearAlgebra

import ChainRulesCore
Expand All @@ -15,7 +15,7 @@ import ZygoteRules, Adapt
# Required for the downstream_events.jl test
# Since `ismutable` on an ArrayPartition needs
# to know static arrays are not mutable
import ArrayInterfaceStaticArrays
import ArrayInterfaceStaticArraysCore

using FillArrays

Expand Down
45 changes: 40 additions & 5 deletions src/array_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ Base.zero(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = zero(A)

## Array

Base.Array(A::ArrayPartition) = ArrayPartition(Array.(A.x))
Base.Array(A::ArrayPartition) = reduce(vcat,Array.(A.x))
Base.Array(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A <: AbstractVector{<:ArrayPartition}} = reduce(hcat,Array.(VA.u))

## ones
Expand Down Expand Up @@ -390,13 +390,13 @@ end
# [U11 U12 U13] [ b1 ]
# [ 0 U22 U23] \ [ b2 ]
# [ 0 0 U33] [ b3 ]
function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitUpperTriangular,UpperTriangular}
function LinearAlgebra.ldiv!(A::UnitUpperTriangular, bb::ArrayPartition)
A = A.data
n = npartitions(bb)
b = bb.x
lens = map(length, b)
@inbounds for j in n:-1:1
Ajj = T(getblock(A, lens, j, j))
Ajj = UnitUpperTriangular(getblock(A, lens, j, j))
xj = ldiv!(Ajj, vec(b[j]))
for i in j-1:-1:1
Aij = getblock(A, lens, i, j)
Expand All @@ -407,13 +407,30 @@ function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitUpperT
return bb
end

function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitLowerTriangular,LowerTriangular}
function LinearAlgebra.ldiv!(A::UpperTriangular, bb::ArrayPartition)
A = A.data
n = npartitions(bb)
b = bb.x
lens = map(length, b)
@inbounds for j in n:-1:1
Ajj = UpperTriangular(getblock(A, lens, j, j))
xj = ldiv!(Ajj, vec(b[j]))
for i in j-1:-1:1
Aij = getblock(A, lens, i, j)
# bi = -Aij * xj + bi
mul!(vec(b[i]), Aij, xj, -1, true)
end
end
return bb
end

function LinearAlgebra.ldiv!(A::UnitLowerTriangular, bb::ArrayPartition)
A = A.data
n = npartitions(bb)
b = bb.x
lens = map(length, b)
@inbounds for j in 1:n
Ajj = T(getblock(A, lens, j, j))
Ajj = UnitLowerTriangular(getblock(A, lens, j, j))
xj = ldiv!(Ajj, vec(b[j]))
for i in j+1:n
Aij = getblock(A, lens, i, j)
Expand All @@ -423,6 +440,24 @@ function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitLowerT
end
return bb
end

function LinearAlgebra.ldiv!(A::LowerTriangular, bb::ArrayPartition)
A = A.data
n = npartitions(bb)
b = bb.x
lens = map(length, b)
@inbounds for j in 1:n
Ajj = LowerTriangular(getblock(A, lens, j, j))
xj = ldiv!(Ajj, vec(b[j]))
for i in j+1:n
Aij = getblock(A, lens, i, j)
# bi = -Aij * xj + b[i]
mul!(vec(b[i]), Aij, xj, -1, true)
end
end
return bb
end

# TODO: optimize
function LinearAlgebra._ipiv_rows!(A::LU, order::OrdinalRange, B::ArrayPartition)
for i = order
Expand Down
23 changes: 16 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ like `copy` on arrays of scalars.
function recursivecopy(a)
deepcopy(a)
end
recursivecopy(a::Union{SVector,SMatrix,SArray,Number}) = copy(a)
recursivecopy(a::Union{StaticArraysCore.SVector,StaticArraysCore.SMatrix,
StaticArraysCore.SArray,Number}) = copy(a)
function recursivecopy(a::AbstractArray{T,N}) where {T<:Number,N}
copy(a)
end
Expand All @@ -33,7 +34,7 @@ like `copy!` on arrays of scalars.
"""
function recursivecopy! end

function recursivecopy!(b::AbstractArray{T,N},a::AbstractArray{T2,N}) where {T<:StaticArray,T2<:StaticArray,N}
function recursivecopy!(b::AbstractArray{T,N},a::AbstractArray{T2,N}) where {T<:StaticArraysCore.StaticArray,T2<:StaticArraysCore.StaticArray,N}
@inbounds for i in eachindex(a)
# TODO: Check for `setindex!`` and use `copy!(b[i],a[i])` or `b[i] = a[i]`, see #19
b[i] = copy(a[i])
Expand Down Expand Up @@ -68,13 +69,13 @@ A recursive `fill!` function.
"""
function recursivefill! end

function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:StaticArray,T2<:StaticArray,N}
function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:StaticArraysCore.StaticArray,T2<:StaticArraysCore.StaticArray,N}
@inbounds for i in eachindex(b)
b[i] = copy(a)
end
end

function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:SArray,T2<:Union{Number,Bool},N}
function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:StaticArraysCore.SArray,T2<:Union{Number,Bool},N}
@inbounds for i in eachindex(b)
b[i] = fill(a, typeof(b[i]))
end
Expand All @@ -88,7 +89,7 @@ function recursivefill!(b::AbstractArray{T,N},a::T2) where {T<:Union{Number,Bool
fill!(b, a)
end

function recursivefill!(b::AbstractArray{T,N},a) where {T<:MArray,N}
function recursivefill!(b::AbstractArray{T,N},a) where {T<:StaticArraysCore.MArray,N}
@inbounds for i in eachindex(b)
if isassigned(b,i)
recursivefill!(b[i],a)
Expand Down Expand Up @@ -151,7 +152,7 @@ If `i<length(x)`, it's simply a `recursivecopy!` to the `i`th element. Otherwise
function copyat_or_push!(a::AbstractVector{T},i::Int,x,nc::Type{Val{perform_copy}}=Val{true}) where {T,perform_copy}
@inbounds if length(a) >= i
if !ArrayInterfaceCore.ismutable(T) || !perform_copy
# TODO: Check for `setindex!`` if T <: StaticArray and use `copy!(b[i],a[i])`
# TODO: Check for `setindex!`` if T <: StaticArraysCore.StaticArray and use `copy!(b[i],a[i])`
# or `b[i] = a[i]`, see https://github.com/JuliaDiffEq/RecursiveArrayTools.jl/issues/19
a[i] = x
else
Expand Down Expand Up @@ -208,7 +209,15 @@ ones has a `Array{Array{Float64,N},N}`, this will return `Array{Float64,N}`.
"""
recursive_unitless_eltype(a) = recursive_unitless_eltype(eltype(a))
recursive_unitless_eltype(a::Type{Any}) = Any
recursive_unitless_eltype(a::Type{T}) where {T<:StaticArray} = similar_type(a,recursive_unitless_eltype(eltype(a)))

# Should be:
# recursive_unitless_eltype(a::Type{T}) where {T<:StaticArray} = similar_type(a,recursive_unitless_eltype(eltype(a)))
# But missing from StaticArraysCore
recursive_unitless_eltype(a::Type{StaticArraysCore.SArray{S, T, N, L}}) where {S, T, N, L} = StaticArraysCore.SArray{S, typeof(one(T)), N, L}
recursive_unitless_eltype(a::Type{StaticArraysCore.MArray{S, T, N, L}}) where {S, T, N, L} = StaticArraysCore.MArray{S, typeof(one(T)), N, L}
recursive_unitless_eltype(a::Type{StaticArraysCore.SizedArray{S, T, N, M, TData}}) where {
S, T, N, M, TData} = StaticArraysCore.SizedArray{S, typeof(one(T)), N, M, TData}

recursive_unitless_eltype(a::Type{T}) where {T<:Array} = Array{recursive_unitless_eltype(eltype(a)),ndims(a)}
recursive_unitless_eltype(a::Type{T}) where {T<:Number} = typeof(one(eltype(a)))
recursive_unitless_eltype(::Type{<:Enum{T}}) where T = T
Expand Down
1 change: 1 addition & 0 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using LinearAlgebra
n, m = 5, 6
bb = rand(n), rand(m)
b = ArrayPartition(bb)
@test Array(b) isa Array
@test Array(b) == collect(b) == vcat(bb...)
A = randn(MersenneTwister(123), n+m, n+m)

Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ end

@time begin

if !is_APPVEYOR && GROUP == "Core"
if GROUP == "Core" || GROUP == "All"
@time @testset "Utils Tests" begin include("utils_test.jl") end
@time @testset "Partitions Tests" begin include("partitions_test.jl") end
@time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end
Expand Down
37 changes: 25 additions & 12 deletions test/upstream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ dyn(u, p, t) = ArrayPartition(
ArrayPartition(zeros(1), [0.0])
)

solve(
@test solve(
ODEProblem(
dyn,
ArrayPartition(
Expand All @@ -45,15 +45,28 @@ solve(
),
(0.0, 1.0)
),AutoTsit5(Rodas5())
)

@test_broken solve(
ODEProblem(
dyn,
ArrayPartition(
ArrayPartition(zeros(1), [-1.0]),
ArrayPartition(zeros(1), [0.75])
),
(0.0, 1.0)
),Rodas5()
).retcode == :Success

if VERSION < v"1.7"
@test solve(
ODEProblem(
dyn,
ArrayPartition(
ArrayPartition(zeros(1), [-1.0]),
ArrayPartition(zeros(1), [0.75])
),
(0.0, 1.0)
),Rodas5()
).retcode == :Success
else
@test_broken solve(
ODEProblem(
dyn,
ArrayPartition(
ArrayPartition(zeros(1), [-1.0]),
ArrayPartition(zeros(1), [0.75])
),
(0.0, 1.0)
),Rodas5()
).retcode == :Success
end