Skip to content

Use mul! in Diagonal*Matrix #42321

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Oct 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
asin, asinh, atan, atanh, axes, big, broadcast, ceil, cis, conj, convert, copy, copyto!, cos,
cosh, cot, coth, csc, csch, eltype, exp, fill!, floor, getindex, hcat,
getproperty, imag, inv, isapprox, isequal, isone, iszero, IndexStyle, kron, kron!, length, log, map, ndims,
oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
one, oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
setindex!, show, similar, sin, sincos, sinh, size, sqrt,
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec, zero
using Base: IndexLinear, promote_eltype, promote_op, promote_typeof,
@propagate_inbounds, @pure, reduce, typed_hvcat, typed_vcat, require_one_based_indexing,
splat
Expand Down
140 changes: 85 additions & 55 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,19 +198,37 @@ Base.literal_pow(::typeof(^), D::Diagonal, valp::Val) =
Diagonal(Base.literal_pow.(^, D.diag, valp)) # for speed
Base.literal_pow(::typeof(^), D::Diagonal, ::Val{-1}) = inv(D) # for disambiguation

function _muldiag_size_check(A, B)
nA = size(A, 2)
mB = size(B, 1)
@noinline throw_dimerr(::AbstractMatrix, nA, mB) = throw(DimensionMismatch("second dimension of A, $nA, does not match first dimension of B, $mB"))
@noinline throw_dimerr(::AbstractVector, nA, mB) = throw(DimensionMismatch("second dimension of D, $nA, does not match length of V, $mB"))
nA == mB || throw_dimerr(B, nA, mB)
return nothing
end
# the output matrix should have the same size as the non-diagonal input matrix or vector
@noinline throw_dimerr(szC, szA) = throw(DimensionMismatch("output matrix has size: $szC, but should have size $szA"))
_size_check_out(C, ::Diagonal, A) = _size_check_out(C, A)
_size_check_out(C, A, ::Diagonal) = _size_check_out(C, A)
_size_check_out(C, A::Diagonal, ::Diagonal) = _size_check_out(C, A)
function _size_check_out(C, A)
szA = size(A)
szC = size(C)
szA == szC || throw_dimerr(szC, szA)
return nothing
end
function _muldiag_size_check(C, A, B)
_muldiag_size_check(A, B)
_size_check_out(C, A, B)
end

function (*)(Da::Diagonal, Db::Diagonal)
nDa, mDb = size(Da, 2), size(Db, 1)
if nDa != mDb
throw(DimensionMismatch("second dimension of Da, $nDa, does not match first dimension of Db, $mDb"))
end
_muldiag_size_check(Da, Db)
return Diagonal(Da.diag .* Db.diag)
end

function (*)(D::Diagonal, V::AbstractVector)
nD = size(D, 2)
if nD != length(V)
throw(DimensionMismatch("second dimension of D, $nD, does not match length of V, $(length(V))"))
end
_muldiag_size_check(D, V)
return D.diag .* V
end

Expand All @@ -220,29 +238,12 @@ end
lmul!(D, copy_oftype(B, promote_op(*, eltype(B), eltype(D.diag))))

(*)(A::AbstractMatrix, D::Diagonal) =
rmul!(copy_similar(A, promote_op(*, eltype(A), eltype(D.diag))), D)
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D)
(*)(D::Diagonal, A::AbstractMatrix) =
lmul!(D, copy_similar(A, promote_op(*, eltype(A), eltype(D.diag))))
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A)

function rmul!(A::AbstractMatrix, D::Diagonal)
require_one_based_indexing(A)
nA, nD = size(A, 2), length(D.diag)
if nA != nD
throw(DimensionMismatch("second dimension of A, $nA, does not match the first of D, $nD"))
end
A .= A .* permutedims(D.diag)
return A
end

function lmul!(D::Diagonal, B::AbstractVecOrMat)
require_one_based_indexing(B)
nB, nD = size(B, 1), length(D.diag)
if nB != nD
throw(DimensionMismatch("second dimension of D, $nD, does not match the first of B, $nB"))
end
B .= D.diag .* B
return B
end
rmul!(A::AbstractMatrix, D::Diagonal) = mul!(A, A, D)
lmul!(D::Diagonal, B::AbstractVecOrMat) = mul!(B, D, B)
Comment on lines +245 to +246
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rmul! and lmul! won't be vectorlized on master as it is a self-inplace broadcast. (4x slower)
Add @inline is useless here. Without fix like #43185, the only solution is recoding with for-loop.

But this should not be caught by CI as the test size is usually small.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking a look, @N5N3! I vaguely remember that you proposed in a similar situation the fix by inlining the rhs. Why doesn't that help here?

Copy link
Member

@N5N3 N5N3 Mar 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, our broadcast unaliases the src from the dest to avoid memory contention, which prevents LLVM to prove dest === src.
Our for-loop based implement doesn't include such unaliasing stage, thus @inline is enough to enable simd.


rmul!(A::Union{LowerTriangular,UpperTriangular}, D::Diagonal) = typeof(A)(rmul!(A.data, D))
function rmul!(A::UnitLowerTriangular, D::Diagonal)
Expand Down Expand Up @@ -306,37 +307,66 @@ function *(D::Diagonal, transA::Transpose{<:Any,<:AbstractMatrix})
lmul!(D, At)
end

rmul!(A::Diagonal, B::Diagonal) = Diagonal(A.diag .*= B.diag)
lmul!(A::Diagonal, B::Diagonal) = Diagonal(B.diag .= A.diag .* B.diag)
@inline function __muldiag!(out, D::Diagonal, B, alpha, beta)
if iszero(beta)
out .= (D.diag .* B) .*ₛ alpha
else
out .= (D.diag .* B) .*ₛ alpha .+ out .* beta
end
return out
end

@inline function __muldiag!(out, A, D::Diagonal, alpha, beta)
if iszero(beta)
out .= (A .* permutedims(D.diag)) .*ₛ alpha
else
out .= (A .* permutedims(D.diag)) .*ₛ alpha .+ out .* beta
end
return out
end

@inline function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha, beta)
if iszero(beta)
out.diag .= (D1.diag .* D2.diag) .*ₛ alpha
else
out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .* beta
end
return out
end

# only needed for ambiguity resolution, as mul! is explicitly defined for these arguments
@inline __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) =
mul!(out, D1, D2, alpha, beta)

@inline function _muldiag!(out, A, B, alpha, beta)
_muldiag_size_check(out, A, B)
__muldiag!(out, A, B, alpha, beta)
return out
end

# Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat
@inline mul!(out::AbstractVector, A::Diagonal, in::AbstractVector, alpha::Number, beta::Number) =
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, A::Diagonal, in::AbstractMatrix, alpha::Number, beta::Number) =
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, A::Diagonal, in::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, A::Diagonal, in::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta

@inline mul!(out::AbstractMatrix, in::AbstractMatrix, A::Diagonal, alpha::Number, beta::Number) =
out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, in::Adjoint{<:Any,<:AbstractVecOrMat}, A::Diagonal,
alpha::Number, beta::Number) =
out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, in::Transpose{<:Any,<:AbstractVecOrMat}, A::Diagonal,
alpha::Number, beta::Number) =
out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractVector, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) =
_muldiag!(out, D, V, alpha, beta)
@inline mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, alpha::Number, beta::Number) =
_muldiag!(out, D, B, alpha, beta)
@inline mul!(out::AbstractMatrix, D::Diagonal, B::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) = _muldiag!(out, D, B, alpha, beta)
@inline mul!(out::AbstractMatrix, D::Diagonal, B::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) = _muldiag!(out, D, B, alpha, beta)

@inline mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, alpha::Number, beta::Number) =
_muldiag!(out, A, D, alpha, beta)
@inline mul!(out::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, D::Diagonal,
alpha::Number, beta::Number) = _muldiag!(out, A, D, alpha, beta)
@inline mul!(out::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, D::Diagonal,
alpha::Number, beta::Number) = _muldiag!(out, A, D, alpha, beta)
@inline mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
_muldiag!(C, Da, Db, alpha, beta)

function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number)
mA = size(Da, 1)
mB = size(Db, 1)
mA == mB || throw(DimensionMismatch("A has dimensions ($mA,$mA) but B has dimensions ($mB,$mB)"))
mC, nC = size(C)
mC == nC == mA || throw(DimensionMismatch("output matrix has size: ($mC,$nC), but should have size ($mA,$mA)"))
_muldiag_size_check(C, Da, Db)
require_one_based_indexing(C)
mA = size(Da, 1)
da = Da.diag
db = Db.diag
_rmul_or_fill!(C, beta)
Expand Down
3 changes: 2 additions & 1 deletion stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
# inside this function.
function *ₛ end
Broadcast.broadcasted(::typeof(*ₛ), out, beta) =
iszero(beta::Number) ? false : broadcasted(*, out, beta)
iszero(beta::Number) ? false :
isone(beta::Number) ? broadcasted(identity, out) : broadcasted(*, out, beta)
Comment on lines 10 to +12
Copy link
Member

@N5N3 N5N3 Feb 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change might introduce more instability?
Usage in this PR should be union-splitable, and the branch should be eliminated in 3-args mul! and l/rmul!
But it might mess else where?


"""
MulAddMul(alpha, beta)
Expand Down
5 changes: 4 additions & 1 deletion stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,16 @@ function fill!(A::Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}, x)
not be filled with $x, since some of its entries are constrained."))
end

one(A::Diagonal{T}) where T = Diagonal(fill!(similar(A.diag, typeof(one(T))), one(T)))
one(D::Diagonal) = Diagonal(one.(D.diag))
one(A::Bidiagonal{T}) where T = Bidiagonal(fill!(similar(A.dv, typeof(one(T))), one(T)), fill!(similar(A.ev, typeof(one(T))), zero(one(T))), A.uplo)
one(A::Tridiagonal{T}) where T = Tridiagonal(fill!(similar(A.du, typeof(one(T))), zero(one(T))), fill!(similar(A.d, typeof(one(T))), one(T)), fill!(similar(A.dl, typeof(one(T))), zero(one(T))))
one(A::SymTridiagonal{T}) where T = SymTridiagonal(fill!(similar(A.dv, typeof(one(T))), one(T)), fill!(similar(A.ev, typeof(one(T))), zero(one(T))))
# equals and approx equals methods for structured matrices
# SymTridiagonal == Tridiagonal is already defined in tridiag.jl

zero(D::Diagonal) = Diagonal(zero.(D.diag))
oneunit(D::Diagonal) = Diagonal(oneunit.(D.diag))

# SymTridiagonal and Bidiagonal have the same field names
==(A::Diagonal, B::Union{SymTridiagonal, Bidiagonal}) = iszero(B.ev) && A.diag == B.dv
==(B::Bidiagonal, A::Diagonal) = A == B
Expand Down
80 changes: 78 additions & 2 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,41 @@ let D1 = Diagonal(rand(5)), D2 = Diagonal(rand(5))
@test LinearAlgebra.lmul!(adjoint(D1),copy(D2)) == adjoint(D1)*D2
end

@testset "multiplication of a Diagonal with a Matrix" begin
A = collect(reshape(1:8, 4, 2));
B = BigFloat.(A);
DL = Diagonal(collect(axes(A, 1)));
DR = Diagonal(Float16.(collect(axes(A, 2))));

@test DL * A == collect(DL) * A
@test A * DR == A * collect(DR)
@test DL * B == collect(DL) * B
@test B * DR == B * collect(DR)

A = reshape([ones(2,2), ones(2,2)*2, ones(2,2)*3, ones(2,2)*4], 2, 2)
Ac = collect(A)
D = Diagonal([collect(reshape(1:4, 2, 2)), collect(reshape(5:8, 2, 2))])
Dc = collect(D)
@test A * D == Ac * Dc
@test D * A == Dc * Ac
@test D * D == Dc * Dc

AS = similar(A)
mul!(AS, A, D, true, false)
@test AS == A * D

D2 = similar(D)
mul!(D2, D, D)
@test D2 == D * D

D2[diagind(D2)] .= D[diagind(D)]
lmul!(D, D2)
@test D2 == D * D
D2[diagind(D2)] .= D[diagind(D)]
rmul!(D2, D)
@test D2 == D * D
end

@testset "multiplication of QR Q-factor and Diagonal (#16615 spot test)" begin
D = Diagonal(randn(5))
Q = qr(randn(5, 5)).Q
Expand Down Expand Up @@ -686,12 +721,35 @@ end
xt = transpose(x)
A = reshape([[1 2; 3 4], zeros(Int,2,2), zeros(Int, 2, 2), [5 6; 7 8]], 2, 2)
D = Diagonal(A)
@test x'*D == x'*A == copy(x')*D == copy(x')*A
@test xt*D == xt*A == copy(xt)*D == copy(xt)*A
@test x'*D == x'*A == collect(x')*D == collect(x')*A
@test xt*D == xt*A == collect(xt)*D == collect(xt)*A
outadjxD = similar(x'*D); outtrxD = similar(xt*D);
mul!(outadjxD, x', D)
@test outadjxD == x'*D
mul!(outtrxD, xt, D)
@test outtrxD == xt*D

D1 = Diagonal([[1 2; 3 4]])
@test D1 * x' == D1 * collect(x') == collect(D1) * collect(x')
@test D1 * xt == D1 * collect(xt) == collect(D1) * collect(xt)
outD1adjx = similar(D1 * x'); outD1trx = similar(D1 * xt);
mul!(outadjxD, D1, x')
@test outadjxD == D1*x'
mul!(outtrxD, D1, xt)
@test outtrxD == D1*xt

y = [x, x]
yt = transpose(y)
@test y'*D*y == (y'*D)*y == (y'*A)*y
@test yt*D*y == (yt*D)*y == (yt*A)*y
outadjyD = similar(y'*D); outtryD = similar(yt*D);
outadjyD2 = similar(collect(y'*D)); outtryD2 = similar(collect(yt*D));
mul!(outadjyD, y', D)
mul!(outadjyD2, y', D)
@test outadjyD == outadjyD2 == y'*D
mul!(outtryD, yt, D)
mul!(outtryD2, yt, D)
@test outtryD == outtryD2 == yt*D
end

@testset "Multiplication of single element Diagonal (#36746, #40726)" begin
Expand Down Expand Up @@ -826,4 +884,22 @@ end
@test \(x, B) == /(B, x)
end

@testset "zero and one" begin
D1 = Diagonal(rand(3))
@test D1 + zero(D1) == D1
@test D1 * one(D1) == D1
@test D1 * oneunit(D1) == D1
@test oneunit(D1) isa typeof(D1)
D2 = Diagonal([collect(reshape(1:4, 2, 2)), collect(reshape(5:8, 2, 2))])
@test D2 + zero(D2) == D2
@test D2 * one(D2) == D2
@test D2 * oneunit(D2) == D2
@test oneunit(D2) isa typeof(D2)
D3 = Diagonal([D2, D2]);
@test D3 + zero(D3) == D3
@test D3 * one(D3) == D3
@test D3 * oneunit(D3) == D3
@test oneunit(D3) isa typeof(D3)
end

end # module TestDiagonal