Skip to content

Commit 2d7c421

Browse files
jishnubKristofferC
authored andcommitted
Fix (l/r)mul! with Diagonal/Bidiagonal (#55052)
Currently, `rmul!(A::AbstractMatirx, D::Diagonal)` calls `mul!(A, A, D)`, but this isn't a valid call, as `mul!` assumes no aliasing between the destination and the matrices to be multiplied. As a consequence, ```julia julia> B = Bidiagonal(rand(4), rand(3), :L) 4×4 Bidiagonal{Float64, Vector{Float64}}: 0.476892 ⋅ ⋅ ⋅ 0.353756 0.139188 ⋅ ⋅ ⋅ 0.685839 0.309336 ⋅ ⋅ ⋅ 0.369038 0.304273 julia> D = Diagonal(rand(size(B,2))); julia> rmul!(B, D) 4×4 Bidiagonal{Float64, Vector{Float64}}: 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ⋅ ⋅ ⋅ 0.0 0.0 julia> B 4×4 Bidiagonal{Float64, Vector{Float64}}: 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ``` This is clearly nonsense, and happens because the internal `_mul!` function assumes that it can safely overwrite the destination with zeros before carrying out the multiplication. This is fixed in this PR by using broadcasting instead. The current implementation is generally equally performant, albeit occasionally with a minor allocation arising from `reshape`ing an `Array`. A similar problem also exists in `l/rmul!` with `Bidiaognal`, but that's a little harder to fix while remaining equally performant. (cherry picked from commit 262b40a)
1 parent d4f9808 commit 2d7c421

File tree

6 files changed

+183
-4
lines changed

6 files changed

+183
-4
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,8 +435,76 @@ const BiTri = Union{Bidiagonal,Tridiagonal}
435435
@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
436436
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
437437

438-
lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul())
439-
rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul())
438+
# B .= A * B
439+
function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
440+
_muldiag_size_check(A, B)
441+
(; dv, ev) = A
442+
if A.uplo == 'U'
443+
for k in axes(B,2)
444+
for i in axes(ev,1)
445+
B[i,k] = dv[i] * B[i,k] + ev[i] * B[i+1,k]
446+
end
447+
B[end,k] = dv[end] * B[end,k]
448+
end
449+
else
450+
for k in axes(B,2)
451+
for i in reverse(axes(dv,1)[2:end])
452+
B[i,k] = dv[i] * B[i,k] + ev[i-1] * B[i-1,k]
453+
end
454+
B[1,k] = dv[1] * B[1,k]
455+
end
456+
end
457+
return B
458+
end
459+
# B .= D * B
460+
function lmul!(D::Diagonal, B::Bidiagonal)
461+
_muldiag_size_check(D, B)
462+
(; dv, ev) = B
463+
isL = B.uplo == 'L'
464+
dv[1] = D.diag[1] * dv[1]
465+
for i in axes(ev,1)
466+
ev[i] = D.diag[i + isL] * ev[i]
467+
dv[i+1] = D.diag[i+1] * dv[i+1]
468+
end
469+
return B
470+
end
471+
# B .= B * A
472+
function rmul!(B::AbstractMatrix, A::Bidiagonal)
473+
_muldiag_size_check(A, B)
474+
(; dv, ev) = A
475+
if A.uplo == 'U'
476+
for k in reverse(axes(dv,1)[2:end])
477+
for i in axes(B,1)
478+
B[i,k] = B[i,k] * dv[k] + B[i,k-1] * ev[k-1]
479+
end
480+
end
481+
for i in axes(B,1)
482+
B[i,1] *= dv[1]
483+
end
484+
else
485+
for k in axes(ev,1)
486+
for i in axes(B,1)
487+
B[i,k] = B[i,k] * dv[k] + B[i,k+1] * ev[k]
488+
end
489+
end
490+
for i in axes(B,1)
491+
B[i,end] *= dv[end]
492+
end
493+
end
494+
return B
495+
end
496+
# B .= B * D
497+
function rmul!(B::Bidiagonal, D::Diagonal)
498+
_muldiag_size_check(B, D)
499+
(; dv, ev) = B
500+
isU = B.uplo == 'U'
501+
dv[1] *= D.diag[1]
502+
for i in axes(ev,1)
503+
ev[i] *= D.diag[i + isU]
504+
dv[i+1] *= D.diag[i+1]
505+
end
506+
return B
507+
end
440508

441509
function check_A_mul_B!_sizes(C, A, B)
442510
mA, nA = size(A)

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,49 @@ function (*)(D::Diagonal, V::AbstractVector)
310310
return D.diag .* V
311311
end
312312

313-
rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
314-
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)
313+
function rmul!(A::AbstractMatrix, D::Diagonal)
314+
_muldiag_size_check(A, D)
315+
for I in CartesianIndices(A)
316+
row, col = Tuple(I)
317+
@inbounds A[row, col] *= D.diag[col]
318+
end
319+
return A
320+
end
321+
# T .= T * D
322+
function rmul!(T::Tridiagonal, D::Diagonal)
323+
_muldiag_size_check(T, D)
324+
(; dl, d, du) = T
325+
d[1] *= D.diag[1]
326+
for i in axes(dl,1)
327+
dl[i] *= D.diag[i]
328+
du[i] *= D.diag[i+1]
329+
d[i+1] *= D.diag[i+1]
330+
end
331+
return T
332+
end
333+
334+
function lmul!(D::Diagonal, B::AbstractVecOrMat)
335+
_muldiag_size_check(D, B)
336+
for I in CartesianIndices(B)
337+
row = I[1]
338+
@inbounds B[I] = D.diag[row] * B[I]
339+
end
340+
return B
341+
end
342+
343+
# in-place multiplication with a diagonal
344+
# T .= D * T
345+
function lmul!(D::Diagonal, T::Tridiagonal)
346+
_muldiag_size_check(D, T)
347+
(; dl, d, du) = T
348+
d[1] = D.diag[1] * d[1]
349+
for i in axes(dl,1)
350+
dl[i] = D.diag[i+1] * dl[i]
351+
du[i] = D.diag[i] * du[i]
352+
d[i+1] = D.diag[i+1] * d[i+1]
353+
end
354+
return T
355+
end
315356

316357
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
317358
require_one_based_indexing(out, B)

stdlib/LinearAlgebra/test/bidiag.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,4 +884,39 @@ end
884884
@test mul!(C1, B, sv, 1, 2) == mul!(C2, B, v, 1 ,2)
885885
end
886886

887+
@testset "rmul!/lmul! with banded matrices" begin
888+
dv, ev = rand(4), rand(3)
889+
for A in (Bidiagonal(dv, ev, :U), Bidiagonal(dv, ev, :L))
890+
@testset "$(nameof(typeof(B)))" for B in (
891+
Bidiagonal(dv, ev, :U),
892+
Bidiagonal(dv, ev, :L),
893+
Diagonal(dv)
894+
)
895+
@test_throws ArgumentError rmul!(B, A)
896+
@test_throws ArgumentError lmul!(A, B)
897+
end
898+
D = Diagonal(dv)
899+
@test rmul!(copy(A), D) A * D
900+
@test lmul!(D, copy(A)) D * A
901+
end
902+
@testset "non-commutative" begin
903+
S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2))
904+
S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3))
905+
S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2))
906+
for uplo in (:L, :U)
907+
B = Bidiagonal(fill(S32, 4), fill(S32, 3), uplo)
908+
D = Diagonal(fill(S22, size(B,2)))
909+
@test rmul!(copy(B), D) B * D
910+
D = Diagonal(fill(S33, size(B,1)))
911+
@test lmul!(D, copy(B)) D * B
912+
end
913+
914+
B = Bidiagonal(fill(S33, 4), fill(S33, 3), :U)
915+
D = Diagonal(fill(S32, 4))
916+
@test lmul!(B, Array(D)) B * D
917+
B = Bidiagonal(fill(S22, 4), fill(S22, 3), :U)
918+
@test rmul!(Array(D), B) D * B
919+
end
920+
end
921+
887922
end # module TestBidiagonal

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,4 +1288,17 @@ end
12881288
@test yadj == x'
12891289
end
12901290

1291+
@testset "rmul!/lmul! with banded matrices" begin
1292+
@testset "$(nameof(typeof(B)))" for B in (
1293+
Bidiagonal(rand(4), rand(3), :L),
1294+
Tridiagonal(rand(3), rand(4), rand(3))
1295+
)
1296+
BA = Array(B)
1297+
D = Diagonal(rand(size(B,1)))
1298+
DA = Array(D)
1299+
@test rmul!(copy(B), D) B * D BA * DA
1300+
@test lmul!(D, copy(B)) D * B DA * BA
1301+
end
1302+
end
1303+
12911304
end # module TestDiagonal

stdlib/LinearAlgebra/test/tridiag.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,4 +830,23 @@ end
830830
@test axes(B) === (ax, ax)
831831
end
832832

833+
@testset "rmul!/lmul! with banded matrices" begin
834+
dl, d, du = rand(3), rand(4), rand(3)
835+
A = Tridiagonal(dl, d, du)
836+
D = Diagonal(d)
837+
@test rmul!(copy(A), D) A * D
838+
@test lmul!(D, copy(A)) D * A
839+
840+
@testset "non-commutative" begin
841+
S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2))
842+
S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3))
843+
S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2))
844+
T = Tridiagonal(fill(S32,3), fill(S32, 4), fill(S32, 3))
845+
D = Diagonal(fill(S22, size(T,2)))
846+
@test rmul!(copy(T), D) T * D
847+
D = Diagonal(fill(S33, size(T,1)))
848+
@test lmul!(D, copy(T)) D * T
849+
end
850+
end
851+
833852
end # module TestTridiagonal

test/testhelpers/SizedArrays.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ Base.first(::SOneTo) = 1
2323
Base.last(r::SOneTo) = length(r)
2424
Base.show(io::IO, r::SOneTo) = print(io, "SOneTo(", length(r), ")")
2525

26+
Broadcast.axistype(a::Base.OneTo, s::SOneTo) = s
27+
Broadcast.axistype(s::SOneTo, a::Base.OneTo) = s
28+
2629
struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N}
2730
data::A
2831
function SizedArray{SZ}(data::AbstractArray{T,N}) where {SZ,T,N}

0 commit comments

Comments
 (0)