Skip to content

Commit 5345992

Browse files
committed
Backport "Fix (l/r)mul! with Diagonal/Bidiagonal #55052" to v1.11
1 parent eaa792f commit 5345992

File tree

6 files changed

+195
-4
lines changed

6 files changed

+195
-4
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,8 +435,76 @@ const BiTri = Union{Bidiagonal,Tridiagonal}
435435
@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
436436
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
437437

438-
lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul())
439-
rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul())
438+
# B .= A * B
439+
function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
440+
_muldiag_size_check(A, B)
441+
(; dv, ev) = A
442+
if A.uplo == 'U'
443+
for k in axes(B,2)
444+
for i in axes(ev,1)
445+
B[i,k] = dv[i] * B[i,k] + ev[i] * B[i+1,k]
446+
end
447+
B[end,k] = dv[end] * B[end,k]
448+
end
449+
else
450+
for k in axes(B,2)
451+
for i in reverse(axes(dv,1)[2:end])
452+
B[i,k] = dv[i] * B[i,k] + ev[i-1] * B[i-1,k]
453+
end
454+
B[1,k] = dv[1] * B[1,k]
455+
end
456+
end
457+
return B
458+
end
459+
# B .= D * B
460+
function lmul!(D::Diagonal, B::Bidiagonal)
461+
_muldiag_size_check(D, B)
462+
(; dv, ev) = B
463+
isL = B.uplo == 'L'
464+
dv[1] = D.diag[1] * dv[1]
465+
for i in axes(ev,1)
466+
ev[i] = D.diag[i + isL] * ev[i]
467+
dv[i+1] = D.diag[i+1] * dv[i+1]
468+
end
469+
return B
470+
end
471+
# B .= B * A
472+
function rmul!(B::AbstractMatrix, A::Bidiagonal)
473+
_muldiag_size_check(A, B)
474+
(; dv, ev) = A
475+
if A.uplo == 'U'
476+
for k in reverse(axes(dv,1)[2:end])
477+
for i in axes(B,1)
478+
B[i,k] = B[i,k] * dv[k] + B[i,k-1] * ev[k-1]
479+
end
480+
end
481+
for i in axes(B,1)
482+
B[i,1] *= dv[1]
483+
end
484+
else
485+
for k in axes(ev,1)
486+
for i in axes(B,1)
487+
B[i,k] = B[i,k] * dv[k] + B[i,k+1] * ev[k]
488+
end
489+
end
490+
for i in axes(B,1)
491+
B[i,end] *= dv[end]
492+
end
493+
end
494+
return B
495+
end
496+
# B .= B * D
497+
function rmul!(B::Bidiagonal, D::Diagonal)
498+
_muldiag_size_check(B, D)
499+
(; dv, ev) = B
500+
isU = B.uplo == 'U'
501+
dv[1] *= D.diag[1]
502+
for i in axes(ev,1)
503+
ev[i] *= D.diag[i + isU]
504+
dv[i+1] *= D.diag[i+1]
505+
end
506+
return B
507+
end
440508

441509
function check_A_mul_B!_sizes(C, A, B)
442510
mA, nA = size(A)

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,49 @@ function (*)(D::Diagonal, V::AbstractVector)
310310
return D.diag .* V
311311
end
312312

313-
rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
314-
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)
313+
function rmul!(A::AbstractMatrix, D::Diagonal)
314+
_muldiag_size_check(A, D)
315+
for I in CartesianIndices(A)
316+
row, col = Tuple(I)
317+
@inbounds A[row, col] *= D.diag[col]
318+
end
319+
return A
320+
end
321+
# T .= T * D
322+
function rmul!(T::Tridiagonal, D::Diagonal)
323+
_muldiag_size_check(T, D)
324+
(; dl, d, du) = T
325+
d[1] *= D.diag[1]
326+
for i in axes(dl,1)
327+
dl[i] *= D.diag[i]
328+
du[i] *= D.diag[i+1]
329+
d[i+1] *= D.diag[i+1]
330+
end
331+
return T
332+
end
333+
334+
function lmul!(D::Diagonal, B::AbstractVecOrMat)
335+
_muldiag_size_check(D, B)
336+
for I in CartesianIndices(B)
337+
row = I[1]
338+
@inbounds B[I] = D.diag[row] * B[I]
339+
end
340+
return B
341+
end
342+
343+
# in-place multiplication with a diagonal
344+
# T .= D * T
345+
function lmul!(D::Diagonal, T::Tridiagonal)
346+
_muldiag_size_check(D, T)
347+
(; dl, d, du) = T
348+
d[1] = D.diag[1] * d[1]
349+
for i in axes(dl,1)
350+
dl[i] = D.diag[i+1] * dl[i]
351+
du[i] = D.diag[i] * du[i]
352+
d[i+1] = D.diag[i+1] * d[i+1]
353+
end
354+
return T
355+
end
315356

316357
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
317358
require_one_based_indexing(out, B)

stdlib/LinearAlgebra/test/bidiag.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,4 +884,38 @@ end
884884
@test mul!(C1, B, sv, 1, 2) == mul!(C2, B, v, 1 ,2)
885885
end
886886

887+
@testset "rmul!/lmul! with banded matrices" begin
888+
dv, ev = rand(4), rand(3)
889+
for A in (Bidiagonal(dv, ev, :U), Bidiagonal(dv, ev, :L))
890+
@testset "$(nameof(typeof(B)))" for B in (
891+
Bidiagonal(dv, ev, :U),
892+
Bidiagonal(dv, ev, :L),
893+
Diagonal(dv)
894+
)
895+
@test_throws ArgumentError rmul!(B, A)
896+
@test_throws ArgumentError lmul!(A, B)
897+
end
898+
D = Diagonal(dv)
899+
@test rmul!(copy(A), D) A * D
900+
@test lmul!(D, copy(A)) D * A
901+
end
902+
@testset "non-commutative" begin
903+
S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2))
904+
S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3))
905+
S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2))
906+
for uplo in (:L, :U)
907+
B = Bidiagonal(fill(S32, 4), fill(S32, 3), uplo)
908+
D = Diagonal(fill(S22, size(B,2)))
909+
@test rmul!(copy(B), D) B * D
910+
D = Diagonal(fill(S33, size(B,1)))
911+
@test lmul!(D, copy(B)) D * B
912+
end
913+
914+
B = Bidiagonal(fill(S33, 4), fill(S33, 3), :U)
915+
D = Diagonal(fill(S32, 4))
916+
@test lmul!(B, Array(D)) B * D
917+
B = Bidiagonal(fill(S22, 4), fill(S22, 3), :U)
918+
@test rmul!(Array(D), B) D * B
919+
end
920+
end
887921
end # module TestBidiagonal

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,4 +1288,16 @@ end
12881288
@test yadj == x'
12891289
end
12901290

1291+
@testset "rmul!/lmul! with banded matrices" begin
1292+
@testset "$(nameof(typeof(B)))" for B in (
1293+
Bidiagonal(rand(4), rand(3), :L),
1294+
Tridiagonal(rand(3), rand(4), rand(3))
1295+
)
1296+
BA = Array(B)
1297+
D = Diagonal(rand(size(B,1)))
1298+
DA = Array(D)
1299+
@test rmul!(copy(B), D) B * D BA * DA
1300+
@test lmul!(D, copy(B)) D * B DA * BA
1301+
end
1302+
end
12911303
end # module TestDiagonal

stdlib/LinearAlgebra/test/tridiag.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,4 +830,22 @@ end
830830
@test axes(B) === (ax, ax)
831831
end
832832

833+
@testset "rmul!/lmul! with banded matrices" begin
834+
dl, d, du = rand(3), rand(4), rand(3)
835+
A = Tridiagonal(dl, d, du)
836+
D = Diagonal(d)
837+
@test rmul!(copy(A), D) A * D
838+
@test lmul!(D, copy(A)) D * A
839+
840+
@testset "non-commutative" begin
841+
S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2))
842+
S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3))
843+
S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2))
844+
T = Tridiagonal(fill(S32,3), fill(S32, 4), fill(S32, 3))
845+
D = Diagonal(fill(S22, size(T,2)))
846+
@test rmul!(copy(T), D) T * D
847+
D = Diagonal(fill(S33, size(T,1)))
848+
@test lmul!(D, copy(T)) D * T
849+
end
850+
end
833851
end # module TestTridiagonal

test/testhelpers/SizedArrays.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ Base.first(::SOneTo) = 1
2323
Base.last(r::SOneTo) = length(r)
2424
Base.show(io::IO, r::SOneTo) = print(io, "SOneTo(", length(r), ")")
2525

26+
Broadcast.axistype(a::Base.OneTo, s::SOneTo) = s
27+
Broadcast.axistype(s::SOneTo, a::Base.OneTo) = s
28+
2629
struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N}
2730
data::A
2831
function SizedArray{SZ}(data::AbstractArray{T,N}) where {SZ,T,N}
@@ -43,10 +46,25 @@ Base.size(a::SizedArray) = size(typeof(a))
4346
Base.size(::Type{<:SizedArray{SZ}}) where {SZ} = SZ
4447
Base.axes(a::SizedArray) = map(SOneTo, size(a))
4548
Base.getindex(A::SizedArray, i...) = getindex(A.data, i...)
49+
Base.setindex!(A::SizedArray, v, i...) = setindex!(A.data, v, i...)
4650
Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T)))
51+
Base.parent(S::SizedArray) = S.data
4752
+(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data)
4853
==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data
4954

55+
homogenize_shape(t::Tuple) = (_homogenize_shape(first(t)), homogenize_shape(Base.tail(t))...)
56+
homogenize_shape(::Tuple{}) = ()
57+
_homogenize_shape(x::Integer) = x
58+
_homogenize_shape(x::AbstractUnitRange) = length(x)
59+
const Dims = Union{Integer, Base.OneTo, SOneTo}
60+
function Base.similar(::Type{A}, shape::Tuple{Dims, Vararg{Dims}}) where {A<:AbstractArray}
61+
similar(A, homogenize_shape(shape))
62+
end
63+
function Base.similar(::Type{A}, shape::Tuple{SOneTo, Vararg{SOneTo}}) where {A<:AbstractArray}
64+
R = similar(A, length.(shape))
65+
SizedArray{length.(shape)}(R)
66+
end
67+
5068
const SizedMatrixLike = Union{SizedMatrix, Transpose{<:Any, <:SizedMatrix}, Adjoint{<:Any, <:SizedMatrix}}
5169

5270
_data(S::SizedArray) = S.data

0 commit comments

Comments
 (0)