@@ -299,6 +299,70 @@ common_number(a, b) =
299
299
# # Linear Algebra
300
300
301
301
ArrayInterface. zeromatrix (A:: ArrayPartition ) = ArrayInterface. zeromatrix (reduce (vcat,vec .(A. x)))
302
- LinearAlgebra. ldiv! (A:: LinearAlgebra.LU ,b:: ArrayPartition ) = ldiv! (A,Array (b))
303
- LinearAlgebra. ldiv! (A:: LinearAlgebra.QR ,b:: ArrayPartition ) = ldiv! (A,Array (b))
304
- LinearAlgebra. ldiv! (A:: LinearAlgebra.SVD ,b:: ArrayPartition ) = ldiv! (A,Array (b))
302
+
303
+ LinearAlgebra. ldiv! (A:: Factorization , b:: ArrayPartition ) = (x = ldiv! (A, Array (b)); copyto! (b, x))
304
+ function LinearAlgebra. ldiv! (A:: LU , b:: ArrayPartition )
305
+ LinearAlgebra. _ipiv_rows! (A, 1 : length (A. ipiv), b)
306
+ ldiv! (UpperTriangular (A. factors), ldiv! (UnitLowerTriangular (A. factors), b))
307
+ return b
308
+ end
309
+
310
+ # block matrix indexing
311
+ @inbounds function getblock (A, lens, i, j)
312
+ ii1 = i == 1 ? 0 : sum (ii-> lens[ii], 1 : i- 1 )
313
+ jj1 = j == 1 ? 0 : sum (ii-> lens[ii], 1 : j- 1 )
314
+ ij1 = CartesianIndex (ii1, jj1)
315
+ cc1 = CartesianIndex ((1 , 1 ))
316
+ inc = CartesianIndex (lens[i], lens[j])
317
+ return @view A[(ij1+ cc1): (ij1+ inc)]
318
+ end
319
+ # fast ldiv for UpperTriangular and UnitLowerTriangular
320
+ # [U11 U12 U13] [ b1 ]
321
+ # [ 0 U22 U23] \ [ b2 ]
322
+ # [ 0 0 U33] [ b3 ]
323
+ function LinearAlgebra. ldiv! (A:: T , bb:: ArrayPartition ) where T<: Union{UnitUpperTriangular,UpperTriangular}
324
+ A = A. data
325
+ n = npartitions (bb)
326
+ b = bb. x
327
+ lens = map (length, b)
328
+ @inbounds for j in n: - 1 : 1
329
+ Ajj = T (getblock (A, lens, j, j))
330
+ xj = ldiv! (Ajj, b[j])
331
+ for i in j- 1 : - 1 : 1
332
+ Aij = getblock (A, lens, i, j)
333
+ # bi = -Aij * xj + bi
334
+ mul! (b[i], Aij, xj, - 1 , true )
335
+ end
336
+ end
337
+ return bb
338
+ end
339
+
340
+ function LinearAlgebra. ldiv! (A:: T , bb:: ArrayPartition ) where T<: Union{UnitLowerTriangular,LowerTriangular}
341
+ A = A. data
342
+ n = npartitions (bb)
343
+ b = bb. x
344
+ lens = map (length, b)
345
+ @inbounds for j in 1 : n
346
+ Ajj = T (getblock (A, lens, j, j))
347
+ xj = ldiv! (Ajj, b[j])
348
+ for i in j+ 1 : n
349
+ Aij = getblock (A, lens, i, j)
350
+ # bi = -Aij * xj + b[i]
351
+ mul! (b[i], Aij, xj, - 1 , true )
352
+ end
353
+ end
354
+ return bb
355
+ end
356
+ # TODO : optimize
357
+ function LinearAlgebra. _ipiv_rows! (A:: LU , order:: OrdinalRange , B:: ArrayPartition )
358
+ for i = order
359
+ if i != A. ipiv[i]
360
+ LinearAlgebra. _swap_rows! (B, i, A. ipiv[i])
361
+ end
362
+ end
363
+ return B
364
+ end
365
+ function LinearAlgebra. _swap_rows! (B:: ArrayPartition , i:: Integer , j:: Integer )
366
+ B[i], B[j] = B[j], B[i]
367
+ return B
368
+ end
0 commit comments