diff --git a/Project.toml b/Project.toml index 22d07c4f..458634f6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecursiveArrayTools" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" authors = ["Chris Rackauckas "] -version = "2.3.0" +version = "2.3.1" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -25,6 +25,7 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [targets] -test = ["NLsolve", "OrdinaryDiffEq", "Test", "Unitful"] +test = ["NLsolve", "OrdinaryDiffEq", "Test", "Unitful", "Random"] diff --git a/src/array_partition.jl b/src/array_partition.jl index 65743356..c7d8c3ec 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -299,6 +299,70 @@ common_number(a, b) = ## Linear Algebra ArrayInterface.zeromatrix(A::ArrayPartition) = ArrayInterface.zeromatrix(reduce(vcat,vec.(A.x))) -LinearAlgebra.ldiv!(A::LinearAlgebra.LU,b::ArrayPartition) = ldiv!(A,Array(b)) -LinearAlgebra.ldiv!(A::LinearAlgebra.QR,b::ArrayPartition) = ldiv!(A,Array(b)) -LinearAlgebra.ldiv!(A::LinearAlgebra.SVD,b::ArrayPartition) = ldiv!(A,Array(b)) + +LinearAlgebra.ldiv!(A::Factorization, b::ArrayPartition) = (x = ldiv!(A, Array(b)); copyto!(b, x)) +function LinearAlgebra.ldiv!(A::LU, b::ArrayPartition) + LinearAlgebra._ipiv_rows!(A, 1 : length(A.ipiv), b) + ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), b)) + return b +end + +# block matrix indexing +@inbounds function getblock(A, lens, i, j) + ii1 = i == 1 ? 0 : sum(ii->lens[ii], 1:i-1) + jj1 = j == 1 ? 0 : sum(ii->lens[ii], 1:j-1) + ij1 = CartesianIndex(ii1, jj1) + cc1 = CartesianIndex((1, 1)) + inc = CartesianIndex(lens[i], lens[j]) + return @view A[(ij1+cc1):(ij1+inc)] +end +# fast ldiv for UpperTriangular and UnitLowerTriangular +# [U11 U12 U13] [ b1 ] +# [ 0 U22 U23] \ [ b2 ] +# [ 0 0 U33] [ b3 ] +function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitUpperTriangular,UpperTriangular} + A = A.data + n = npartitions(bb) + b = bb.x + lens = map(length, b) + @inbounds for j in n:-1:1 + Ajj = T(getblock(A, lens, j, j)) + xj = ldiv!(Ajj, b[j]) + for i in j-1:-1:1 + Aij = getblock(A, lens, i, j) + # bi = -Aij * xj + bi + mul!(b[i], Aij, xj, -1, true) + end + end + return bb +end + +function LinearAlgebra.ldiv!(A::T, bb::ArrayPartition) where T<:Union{UnitLowerTriangular,LowerTriangular} + A = A.data + n = npartitions(bb) + b = bb.x + lens = map(length, b) + @inbounds for j in 1:n + Ajj = T(getblock(A, lens, j, j)) + xj = ldiv!(Ajj, b[j]) + for i in j+1:n + Aij = getblock(A, lens, i, j) + # bi = -Aij * xj + b[i] + mul!(b[i], Aij, xj, -1, true) + end + end + return bb +end +# TODO: optimize +function LinearAlgebra._ipiv_rows!(A::LU, order::OrdinalRange, B::ArrayPartition) + for i = order + if i != A.ipiv[i] + LinearAlgebra._swap_rows!(B, i, A.ipiv[i]) + end + end + return B +end +function LinearAlgebra._swap_rows!(B::ArrayPartition, i::Integer, j::Integer) + B[i], B[j] = B[j], B[i] + return B +end diff --git a/test/linalg.jl b/test/linalg.jl new file mode 100644 index 00000000..ac35452c --- /dev/null +++ b/test/linalg.jl @@ -0,0 +1,28 @@ +using RecursiveArrayTools, Test, Random +using LinearAlgebra + +n, m = 5, 6 +bb = rand(n), rand(m) +b = ArrayPartition(bb) +@test Array(b) == collect(b) == vcat(bb...) +A = randn(MersenneTwister(123), n+m, n+m) + +for T in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular) + B = T(A) + @test B*Array(B \ b) ≈ b + bbb = copy(b) + @test ldiv!(bbb, B, b) === bbb + copyto!(bbb, b) + @test ldiv!(B, bbb) === bbb + @test B*Array(bbb) ≈ b +end + +for ff in (lu, svd, qr) + FF = ff(A) + @test A*(FF \ b) ≈ b + bbb = copy(b) + @test ldiv!(bbb, FF, b) === bbb + copyto!(bbb, b) + @test ldiv!(FF, bbb) === bbb + @test A*bbb ≈ b +end diff --git a/test/runtests.jl b/test/runtests.jl index 46ceadaf..300a84f4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,5 +7,6 @@ using Test @time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end @time @testset "VecOfArr Interface Tests" begin include("interface_tests.jl") end @time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end + @time @testset "Linear Algebra Tests" begin include("linalg.jl") end @time @testset "Upstream Tests" begin include("upstream.jl") end end