Skip to content

Commit d7aafe2

Browse files
Merge pull request #94 from SciML/myb/linalg
Add fast blocked substitution for triangular matrices
2 parents c9958b8 + f72d37c commit d7aafe2

File tree

4 files changed

+99
-5
lines changed

4 files changed

+99
-5
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "2.3.0"
4+
version = "2.3.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -25,6 +25,7 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
2525
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2626
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2727
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
28+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2829

2930
[targets]
30-
test = ["NLsolve", "OrdinaryDiffEq", "Test", "Unitful"]
31+
test = ["NLsolve", "OrdinaryDiffEq", "Test", "Unitful", "Random"]

src/array_partition.jl

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,70 @@ common_number(a, b) =
299299
## Linear Algebra
300300

301301
# Build the zero matrix for an `ArrayPartition` by flattening every partition with `vec`,
# concatenating them vertically, and delegating to `ArrayInterface.zeromatrix`.
function ArrayInterface.zeromatrix(A::ArrayPartition)
    flattened = reduce(vcat, map(vec, A.x))
    return ArrayInterface.zeromatrix(flattened)
end
302-
LinearAlgebra.ldiv!(A::LinearAlgebra.LU,b::ArrayPartition) = ldiv!(A,Array(b))
303-
LinearAlgebra.ldiv!(A::LinearAlgebra.QR,b::ArrayPartition) = ldiv!(A,Array(b))
304-
LinearAlgebra.ldiv!(A::LinearAlgebra.SVD,b::ArrayPartition) = ldiv!(A,Array(b))
302+
303+
# Generic fallback for any factorization: solve against a dense copy of `b`, then copy
# the solution back into the partitioned storage. The value of `copyto!` is returned.
function LinearAlgebra.ldiv!(A::Factorization, b::ArrayPartition)
    dense = ldiv!(A, Array(b))
    return copyto!(b, dense)
end
304+
# In-place LU solve on an `ArrayPartition`: apply the row pivots to `b`, solve with the
# unit-lower factor, then with the upper factor. Both factors are read from `A.factors`.
function LinearAlgebra.ldiv!(A::LU, b::ArrayPartition)
    pivot_range = 1:length(A.ipiv)
    LinearAlgebra._ipiv_rows!(A, pivot_range, b)
    y = ldiv!(UnitLowerTriangular(A.factors), b)
    ldiv!(UpperTriangular(A.factors), y)
    return b
end
309+
310+
# block matrix indexing
311+
"""
    getblock(A, lens, i, j)

Return a view of block `(i, j)` of the matrix `A`, where rows and columns of `A` are both
split into consecutive blocks whose sizes are `lens[1], lens[2], ...`.

The returned view covers rows `sum(lens[1:i-1]) + 1 : sum(lens[1:i])` and columns
`sum(lens[1:j-1]) + 1 : sum(lens[1:j])`; writing to it writes to `A`.
"""
function getblock(A, lens, i, j)
    # NOTE: the former `@inbounds` in front of `function` did not apply to the body,
    # so it has been dropped; indexing below stays bounds-checked exactly as before.
    # The first block has no predecessors, so its offset is 0 (this also avoids
    # reducing over an empty range).
    rowoffset = i == 1 ? 0 : sum(k -> lens[k], 1:(i - 1))
    coloffset = j == 1 ? 0 : sum(k -> lens[k], 1:(j - 1))
    rows = (rowoffset + 1):(rowoffset + lens[i])
    cols = (coloffset + 1):(coloffset + lens[j])
    return @view A[rows, cols]
end
319+
# fast ldiv for UpperTriangular and UnitLowerTriangular
320+
# [U11 U12 U13] [ b1 ]
321+
# [ 0 U22 U23] \ [ b2 ]
322+
# [ 0 0 U33] [ b3 ]
323+
# Wrap a diagonal block in the same kind of triangular wrapper as the full matrix.
# Calling the bare wrapper keeps the block a view; the previous `T(block)` used the
# concrete parametric type of `A` and therefore converted (copied) every block.
_triangular_like(::UpperTriangular, blk) = UpperTriangular(blk)
_triangular_like(::UnitUpperTriangular, blk) = UnitUpperTriangular(blk)

"""
    ldiv!(A::Union{UnitUpperTriangular,UpperTriangular}, bb::ArrayPartition)

Solve `A \\ bb` in place by blocked back substitution and return `bb`.

The matrix stored in `A.data` is split into blocks whose sizes are the lengths of the
partitions of `bb`. Working from the last partition to the first, each diagonal block is
solved against its partition and the result is subtracted from the partitions above it.
Every partition of `bb` is overwritten with its part of the solution.
"""
function LinearAlgebra.ldiv!(
    A::Union{UnitUpperTriangular,UpperTriangular}, bb::ArrayPartition
)
    data = A.data
    nblocks = npartitions(bb)
    b = bb.x
    lens = map(length, b)
    @inbounds for j in nblocks:-1:1
        Ajj = _triangular_like(A, getblock(data, lens, j, j))
        xj = ldiv!(Ajj, b[j])
        for i in (j - 1):-1:1
            Aij = getblock(data, lens, i, j)
            # b[i] = -Aij * xj + b[i]
            mul!(b[i], Aij, xj, -1, true)
        end
    end
    return bb
end
339+
340+
# Wrap a diagonal block in the same kind of triangular wrapper as the full matrix.
# Calling the bare wrapper keeps the block a view; the previous `T(block)` used the
# concrete parametric type of `A` and therefore converted (copied) every block.
_triangular_like(::LowerTriangular, blk) = LowerTriangular(blk)
_triangular_like(::UnitLowerTriangular, blk) = UnitLowerTriangular(blk)

"""
    ldiv!(A::Union{UnitLowerTriangular,LowerTriangular}, bb::ArrayPartition)

Solve `A \\ bb` in place by blocked forward substitution and return `bb`.

The matrix stored in `A.data` is split into blocks whose sizes are the lengths of the
partitions of `bb`. Working from the first partition to the last, each diagonal block is
solved against its partition and the result is subtracted from the partitions below it.
Every partition of `bb` is overwritten with its part of the solution.
"""
function LinearAlgebra.ldiv!(
    A::Union{UnitLowerTriangular,LowerTriangular}, bb::ArrayPartition
)
    data = A.data
    nblocks = npartitions(bb)
    b = bb.x
    lens = map(length, b)
    @inbounds for j in 1:nblocks
        Ajj = _triangular_like(A, getblock(data, lens, j, j))
        xj = ldiv!(Ajj, b[j])
        for i in (j + 1):nblocks
            Aij = getblock(data, lens, i, j)
            # b[i] = -Aij * xj + b[i]
            mul!(b[i], Aij, xj, -1, true)
        end
    end
    return bb
end
356+
# TODO: optimize
357+
# Apply the row interchanges recorded in `A.ipiv` to `B`, visiting the pivots in the
# order given by `order`. Entries that already sit at their pivot position are skipped.
function LinearAlgebra._ipiv_rows!(A::LU, order::OrdinalRange, B::ArrayPartition)
    for row in order
        target = A.ipiv[row]
        row == target && continue
        LinearAlgebra._swap_rows!(B, row, target)
    end
    return B
end
365+
# Exchange the entries of `B` at linear positions `i` and `j`, then return `B`.
function LinearAlgebra._swap_rows!(B::ArrayPartition, i::Integer, j::Integer)
    at_j, at_i = B[j], B[i]
    B[i] = at_j
    B[j] = at_i
    return B
end

test/linalg.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using RecursiveArrayTools, Test, Random
using LinearAlgebra

# All random inputs come from one seeded generator so that failures are reproducible.
# `A` is drawn first, so it is the same matrix as `randn(MersenneTwister(123), ...)`.
rng = MersenneTwister(123)
n, m = 5, 6
A = randn(rng, n + m, n + m)
bb = rand(rng, n), rand(rng, m)
b = ArrayPartition(bb)
@test Array(b) == collect(b) == vcat(bb...)

# Triangular solves: out-of-place `\`, three-argument `ldiv!`, and in-place `ldiv!`.
for T in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular)
    B = T(A)
    @test B * Array(B \ b) ≈ b
    bbb = copy(b)
    @test ldiv!(bbb, B, b) === bbb
    copyto!(bbb, b)
    @test ldiv!(B, bbb) === bbb
    @test B * Array(bbb) ≈ b
end

# Factorization solves: the same three entry points for LU, SVD and QR.
for ff in (lu, svd, qr)
    FF = ff(A)
    @test A * (FF \ b) ≈ b
    bbb = copy(b)
    @test ldiv!(bbb, FF, b) === bbb
    copyto!(bbb, b)
    @test ldiv!(FF, bbb) === bbb
    @test A * bbb ≈ b
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ using Test
77
@time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end
88
@time @testset "VecOfArr Interface Tests" begin include("interface_tests.jl") end
99
@time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end
10+
@time @testset "Linear Algebra Tests" begin include("linalg.jl") end
1011
@time @testset "Upstream Tests" begin include("upstream.jl") end
1112
end

0 commit comments

Comments
 (0)