@@ -176,14 +176,16 @@ Base.all(f, A::ArrayPartition) = all(f, (all(f, x) for x in A.x))
176
176
Base. all (f:: Function , A:: ArrayPartition ) = all ((all (f, x) for x in A. x))
177
177
Base. all (A:: ArrayPartition ) = all (identity, A)
178
178
179
- function Base. copyto! (dest:: AbstractArray , A:: ArrayPartition )
180
- @assert length (dest) == length (A)
181
- cur = 1
182
- @inbounds for i in 1 : length (A. x)
183
- dest[cur: (cur + length (A. x[i]) - 1 )] .= vec (A. x[i])
184
- cur += length (A. x[i])
179
+ for type in [AbstractArray, SparseArrays. AbstractCompressedVector, PermutedDimsArray]
180
+ @eval function Base. copyto! (dest:: $ (type), A:: ArrayPartition )
181
+ @assert length (dest) == length (A)
182
+ cur = 1
183
+ @inbounds for i in 1 : length (A. x)
184
+ dest[cur: (cur + length (A. x[i]) - 1 )] .= vec (A. x[i])
185
+ cur += length (A. x[i])
186
+ end
187
+ dest
185
188
end
186
- dest
187
189
end
188
190
189
191
function Base. copyto! (A:: ArrayPartition , src:: ArrayPartition )
@@ -419,30 +421,38 @@ end
419
421
420
422
ArrayInterface. zeromatrix (A:: ArrayPartition ) = ArrayInterface. zeromatrix (Vector (A))
421
423
422
- function LinearAlgebra. ldiv! (A:: Factorization , b:: ArrayPartition )
423
- (x = ldiv! (A, Array (b)); copyto! (b, x))
424
+ function __get_subtypes_in_module (mod, supertype; include_supertype = true , all= false , except= [])
425
+ return filter ([getproperty (mod, name) for name in names (mod; all) if ! in (name, except)]) do value
426
+ return value isa Type && (value <: supertype ) && (include_supertype || value != supertype) && ! in (value, except)
427
+ end
424
428
end
425
429
426
- @static if VERSION >= v " 1.9"
427
- function LinearAlgebra. ldiv! (A:: LinearAlgebra.SVD{T, Tr, M} ,
428
- b:: ArrayPartition ) where {Tr, T, M <: AbstractArray{T} }
430
+ for factorization in vcat (__get_subtypes_in_module (LinearAlgebra, Factorization; include_supertype = false , all= true , except= [:LU , :LAPACKFactorizations ]), LDLt{T,<: SymTridiagonal{T,V} where {V<: AbstractVector{T} }} where {T})
431
+ @eval function LinearAlgebra. ldiv! (A:: T , b:: ArrayPartition ) where {T<: $factorization }
429
432
(x = ldiv! (A, Array (b)); copyto! (b, x))
430
433
end
434
+ end
431
435
432
- function LinearAlgebra. ldiv! (A:: LinearAlgebra.QRCompactWY{T, M, C} ,
433
- b:: ArrayPartition ) where {
434
- T <: Union{Float32, Float64, ComplexF64, ComplexF32} ,
435
- M <: AbstractMatrix{T} ,
436
- C <: AbstractMatrix{T} ,
437
- }
438
- (x = ldiv! (A, Array (b)); copyto! (b, x))
439
- end
436
+ function LinearAlgebra. ldiv! (A:: LinearAlgebra.SVD{T, Tr, M} ,
437
+ b:: ArrayPartition ) where {Tr, T, M <: AbstractArray{T} }
438
+ (x = ldiv! (A, Array (b)); copyto! (b, x))
440
439
end
441
440
442
- function LinearAlgebra. ldiv! (A:: LU , b:: ArrayPartition )
443
- LinearAlgebra. _ipiv_rows! (A, 1 : length (A. ipiv), b)
444
- ldiv! (UpperTriangular (A. factors), ldiv! (UnitLowerTriangular (A. factors), b))
445
- return b
441
+ function LinearAlgebra. ldiv! (A:: LinearAlgebra.QRCompactWY{T, M, C} ,
442
+ b:: ArrayPartition ) where {
443
+ T <: Union{Float32, Float64, ComplexF64, ComplexF32} ,
444
+ M <: AbstractMatrix{T} ,
445
+ C <: AbstractMatrix{T} ,
446
+ }
447
+ (x = ldiv! (A, Array (b)); copyto! (b, x))
448
+ end
449
+
450
+ for type in [LU, LU{T,Tridiagonal{T,V}} where {T,V}]
451
+ @eval function LinearAlgebra. ldiv! (A:: $type , b:: ArrayPartition )
452
+ LinearAlgebra. _ipiv_rows! (A, 1 : length (A. ipiv), b)
453
+ ldiv! (UpperTriangular (A. factors), ldiv! (UnitLowerTriangular (A. factors), b))
454
+ return b
455
+ end
446
456
end
447
457
448
458
# block matrix indexing
@@ -458,78 +468,31 @@ end
458
468
# [U11 U12 U13] [ b1 ]
459
469
# [ 0 U22 U23] \ [ b2 ]
460
470
# [ 0 0 U33] [ b3 ]
461
- function LinearAlgebra. ldiv! (A:: UnitUpperTriangular , bb:: ArrayPartition )
462
- A = A. data
463
- n = npartitions (bb)
464
- b = bb. x
465
- lens = map (length, b)
466
- @inbounds for j in n: - 1 : 1
467
- Ajj = UnitUpperTriangular (getblock (A, lens, j, j))
468
- xj = ldiv! (Ajj, vec (b[j]))
469
- for i in (j - 1 ): - 1 : 1
470
- Aij = getblock (A, lens, i, j)
471
- # bi = -Aij * xj + bi
472
- mul! (vec (b[i]), Aij, xj, - 1 , true )
473
- end
474
- end
475
- return bb
476
- end
477
-
478
- function LinearAlgebra. ldiv! (A:: UpperTriangular , bb:: ArrayPartition )
479
- A = A. data
480
- n = npartitions (bb)
481
- b = bb. x
482
- lens = map (length, b)
483
- @inbounds for j in n: - 1 : 1
484
- Ajj = UpperTriangular (getblock (A, lens, j, j))
485
- xj = ldiv! (Ajj, vec (b[j]))
486
- for i in (j - 1 ): - 1 : 1
487
- Aij = getblock (A, lens, i, j)
488
- # bi = -Aij * xj + bi
489
- mul! (vec (b[i]), Aij, xj, - 1 , true )
490
- end
491
- end
492
- return bb
493
- end
494
-
495
- function LinearAlgebra. ldiv! (A:: UnitLowerTriangular , bb:: ArrayPartition )
496
- A = A. data
497
- n = npartitions (bb)
498
- b = bb. x
499
- lens = map (length, b)
500
- @inbounds for j in 1 : n
501
- Ajj = UnitLowerTriangular (getblock (A, lens, j, j))
502
- xj = ldiv! (Ajj, vec (b[j]))
503
- for i in (j + 1 ): n
504
- Aij = getblock (A, lens, i, j)
505
- # bi = -Aij * xj + b[i]
506
- mul! (vec (b[i]), Aij, xj, - 1 , true )
471
+ for basetype in [UnitUpperTriangular, UpperTriangular, UnitLowerTriangular, LowerTriangular]
472
+ for type in [basetype, basetype{T, <: Adjoint{T} } where {T}, basetype{T, <: Transpose{T} } where {T}]
473
+ j_iter, i_iter = if basetype <: UnitUpperTriangular || basetype <: UpperTriangular
474
+ (:(n: - 1 : 1 ), :(j- 1 : - 1 : 1 ))
475
+ else
476
+ (:(1 : n), :((j+ 1 ): n))
507
477
end
508
- end
509
- return bb
510
- end
511
- function _ldiv! (A :: LowerTriangular , bb :: ArrayPartition )
512
- A = A . data
513
- n = npartitions (bb)
514
- b = bb . x
515
- lens = map (length, b )
516
- @inbounds for j in 1 : n
517
- Ajj = LowerTriangular ( getblock (A, lens, j , j) )
518
- xj = ldiv! (Ajj, vec (b[j]))
519
- for i in (j + 1 ) : n
520
- Aij = getblock (A, lens, i, j)
521
- # bi = -Aij * xj + b[i]
522
- mul! ( vec (b[i]), Aij, xj, - 1 , true )
478
+ @eval function LinearAlgebra . ldiv! (A :: $type , bb :: ArrayPartition )
479
+ A = A . data
480
+ n = npartitions (bb)
481
+ b = bb . x
482
+ lens = map (length, b)
483
+ @inbounds for j in $ j_iter
484
+ Ajj = $ basetype ( getblock (A, lens, j, j))
485
+ xj = ldiv! (Ajj, vec (b[j]) )
486
+ for i in $ i_iter
487
+ Aij = getblock (A, lens, i , j)
488
+ # bi = -Aij * xj + bi
489
+ mul! ( vec (b[i]), Aij, xj, - 1 , true )
490
+ end
491
+ end
492
+ return bb
523
493
end
524
494
end
525
- return bb
526
- end
527
-
528
- function LinearAlgebra. ldiv! (A:: LowerTriangular{T, <:LinearAlgebra.Adjoint{T}} ,
529
- bb:: ArrayPartition ) where {T}
530
- _ldiv! (A, bb)
531
495
end
532
- LinearAlgebra. ldiv! (A:: LowerTriangular , bb:: ArrayPartition ) = _ldiv! (A, bb)
533
496
534
497
# TODO : optimize
535
498
function LinearAlgebra. _ipiv_rows! (A:: LU , order:: OrdinalRange , B:: ArrayPartition )
0 commit comments