Skip to content

Commit 81753cc

Browse files
committed
Use sparse triangular solvers for sparse triangular solves. Fixes #13792.
Make fwd/bwdTriSolve! work for triagular views Add check for triangular matrices in sparse factorize
1 parent ee584e8 commit 81753cc

File tree

3 files changed

+74
-36
lines changed

3 files changed

+74
-36
lines changed

base/sparse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Base: Func, AddFun, OrFun, ConjFun, IdFun
66
using Base.Sort: Forward
77
using Base.LinAlg: AbstractTriangular, PosDefException
88

9-
import Base: +, -, *, &, |, $, .+, .-, .*, ./, .\, .^, .<, .!=, ==
9+
import Base: +, -, *, \, &, |, $, .+, .-, .*, ./, .\, .^, .<, .!=, ==
1010
import Base: A_mul_B!, Ac_mul_B, Ac_mul_B!, At_mul_B, At_mul_B!, A_ldiv_B!
1111

1212
import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,

base/sparse/linalg.jl

Lines changed: 61 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -167,16 +167,10 @@ end
167167
## solvers
168168
function fwdTriSolve!(A::SparseMatrixCSC, B::AbstractVecOrMat)
169169
# forward substitution for CSC matrices
170-
n = length(B)
171-
if isa(B, Vector)
172-
nrowB = n
173-
ncolB = 1
174-
else
175-
nrowB, ncolB = size(B)
176-
end
177-
ncol = chksquare(A)
170+
nrowB, ncolB = size(B, 1), size(B, 2)
171+
ncol = LinAlg.chksquare(A)
178172
if nrowB != ncol
179-
throw(DimensionMismatch("A is $(ncol)X$(ncol) and B has length $(n)"))
173+
throw(DimensionMismatch("A is $(ncol) columns and B has $(nrowB) rows"))
180174
end
181175

182176
aa = A.nzval
@@ -185,56 +179,88 @@ function fwdTriSolve!(A::SparseMatrixCSC, B::AbstractVecOrMat)
185179

186180
joff = 0
187181
for k = 1:ncolB
188-
for j = 1:(nrowB-1)
189-
jb = joff + j
182+
for j = 1:nrowB
190183
i1 = ia[j]
191-
i2 = ia[j+1]-1
192-
B[jb] /= aa[i1]
193-
bj = B[jb]
194-
for i = i1+1:i2
195-
B[joff+ja[i]] -= bj*aa[i]
184+
i2 = ia[j + 1] - 1
185+
186+
# loop through the structural zeros
187+
ii = i1
188+
jai = ja[ii]
189+
while ii <= i2 && jai < j
190+
ii += 1
191+
jai = ja[ii]
192+
end
193+
194+
# check for zero pivot and divide with pivot
195+
if jai == j
196+
bj = B[joff + jai]/aa[ii]
197+
B[joff + jai] = bj
198+
ii += 1
199+
else
200+
throw(LinAlg.SingularException(j))
201+
end
202+
203+
# update remaining part
204+
for i = ii:i2
205+
B[joff + ja[i]] -= bj*aa[i]
196206
end
197207
end
198208
joff += nrowB
199-
B[joff] /= aa[end]
200209
end
201-
return B
210+
B
202211
end
203212

204213
function bwdTriSolve!(A::SparseMatrixCSC, B::AbstractVecOrMat)
205214
# backward substitution for CSC matrices
206-
n = length(B)
207-
if isa(B, Vector)
208-
nrowB = n
209-
ncolB = 1
210-
else
211-
nrowB, ncolB = size(B)
215+
nrowB, ncolB = size(B, 1), size(B, 2)
216+
ncol = LinAlg.chksquare(A)
217+
if nrowB != ncol
218+
throw(DimensionMismatch("A is $(ncol) columns and B has $(nrowB) rows"))
212219
end
213-
ncol = chksquare(A)
214-
if nrowB != ncol throw(DimensionMismatch("A is $(ncol)X$(ncol) and B has length $(n)")) end
215220

216221
aa = A.nzval
217222
ja = A.rowval
218223
ia = A.colptr
219224

220225
joff = 0
221226
for k = 1:ncolB
222-
for j = nrowB:-1:2
223-
jb = joff + j
227+
for j = nrowB:-1:1
224228
i1 = ia[j]
225-
i2 = ia[j+1]-1
226-
B[jb] /= aa[i2]
227-
bj = B[jb]
228-
for i = i2-1:-1:i1
229-
B[joff+ja[i]] -= bj*aa[i]
229+
i2 = ia[j + 1] - 1
230+
231+
# loop through the structural zeros
232+
ii = i2
233+
jai = ja[ii]
234+
while ii >= i1 && jai > j
235+
ii -= 1
236+
jai = ja[ii]
237+
end
238+
239+
# check for zero pivot and divide with pivot
240+
if jai == j
241+
bj = B[joff + jai]/aa[ii]
242+
B[joff + jai] = bj
243+
ii -= 1
244+
else
245+
throw(LinAlg.SingularException(j))
246+
end
247+
248+
# update remaining part
249+
for i = ii:-1:i1
250+
B[joff + ja[i]] -= bj*aa[i]
230251
end
231252
end
232-
B[joff+1] /= aa[1]
233253
joff += nrowB
234254
end
235-
return B
255+
B
236256
end
237257

258+
A_ldiv_B!{T,Ti}(L::LowerTriangular{T,SparseMatrixCSC{T,Ti}}, B::StridedVecOrMat) = fwdTriSolve!(L.data, B)
259+
A_ldiv_B!{T,Ti}(U::UpperTriangular{T,SparseMatrixCSC{T,Ti}}, B::StridedVecOrMat) = bwdTriSolve!(U.data, B)
260+
261+
(\){T,Ti}(L::LowerTriangular{T,SparseMatrixCSC{T,Ti}}, B::SparseMatrixCSC) = A_ldiv_B!(L, full(B))
262+
(\){T,Ti}(U::UpperTriangular{T,SparseMatrixCSC{T,Ti}}, B::SparseMatrixCSC) = A_ldiv_B!(U, full(B))
263+
238264
## triu, tril
239265

240266
function triu{Tv,Ti}(S::SparseMatrixCSC{Tv,Ti}, k::Integer=0)

test/sparsedir/sparse.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,3 +1167,15 @@ let
11671167
@test_throws ErrorException eig(A)
11681168
@test_throws ErrorException inv(A)
11691169
end
1170+
1171+
let
1172+
n = 100
1173+
A = sprandn(n, n, 0.5) + sqrt(n)*I
1174+
x = LowerTriangular(A)*ones(n)
1175+
@test LowerTriangular(A)\x ones(n)
1176+
x = UpperTriangular(A)*ones(n)
1177+
@test UpperTriangular(A)\x ones(n)
1178+
A[2,2] = 0
1179+
@test_throws LinAlg.SingularException LowerTriangular(A)\ones(n)
1180+
@test_throws LinAlg.SingularException UpperTriangular(A)\ones(n)
1181+
end

0 commit comments

Comments
 (0)