diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 3bbcab9432190..21312ca356b08 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -317,8 +317,9 @@ function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta return C end -(/)(A::AbstractVecOrMat, D::Diagonal) = - _rdiv!(similar(A, promote_op(/, eltype(A), eltype(D))), A, D) +_promote_dotop(f, args...) = promote_op(f, eltype.(args)...) + +/(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(similar(A, _promote_dotop(/, A, D), size(A)), A, D) rdiv!(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(A, A, D) # avoid copy when possible via internal 3-arg backend @@ -338,21 +339,10 @@ function _rdiv!(B::AbstractVecOrMat, A::AbstractVecOrMat, D::Diagonal) end B end -# Optimization for Diagonal / Diagonal -function _rdiv!(Dc::Diagonal, Db::Diagonal, Da::Diagonal) - n, k = length(Db.diag), length(Db.diag) - n == k || throw(DimensionMismatch("left hand side has $n columns but D is $k by $k")) - j = findfirst(iszero, Da.diag) - isnothing(j) || throw(SingularException(j)) - Dc.diag .= Db.diag ./ Da.diag - Dc -end -(\)(D::Diagonal, B::AbstractVecOrMat) = - ldiv!(similar(B, promote_op(\, eltype(D), eltype(B))), D, B) +\(D::Diagonal, B::AbstractVecOrMat) = ldiv!(similar(B, _promote_dotop(\, D, B), size(B)), D, B) ldiv!(D::Diagonal, B::AbstractVecOrMat) = ldiv!(B, D, B) -ldiv!(Dc::Diagonal, Da::Diagonal, Db::Diagonal) = Diagonal(ldiv!(Dc.diag, Da, Db.diag)) function ldiv!(B::AbstractVecOrMat, D::Diagonal, A::AbstractVecOrMat) require_one_based_indexing(A, B) d = length(D.diag) @@ -365,6 +355,19 @@ function ldiv!(B::AbstractVecOrMat, D::Diagonal, A::AbstractVecOrMat) B .= D.diag .\ A end +# Optimizations for \, / between Diagonals +\(D::Diagonal, B::Diagonal) = ldiv!(similar(B, _promote_dotop(\, D, B)), D, B) +/(A::Diagonal, D::Diagonal) = _rdiv!(similar(A, _promote_dotop(/, A, D)), A, D) +function _rdiv!(Dc::Diagonal, Db::Diagonal, Da::Diagonal) + n, k = length(Db.diag), length(Db.diag) + n == k || throw(DimensionMismatch("left hand side has $n columns but D is $k by $k")) + j = findfirst(iszero, Da.diag) + isnothing(j) || throw(SingularException(j)) + Dc.diag .= Db.diag ./ Da.diag + Dc +end +ldiv!(Dc::Diagonal, Da::Diagonal, Db::Diagonal) = Diagonal(ldiv!(Dc.diag, Da, Db.diag)) + # (l/r)mul!, l/rdiv!, *, / and \ Optimization for AbstractTriangular. # These functions are generally more efficient if we calculate the whole data field. # The following code implements them in a unified patten to avoid missing. diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 3f21ad32ec0a6..bd2ddfdbff7f4 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -204,8 +204,8 @@ Random.seed!(1) @test D2*D' ≈ Array(D2)*Array(D)' #division of two Diagonals - @test D/D2 ≈ Diagonal(D.diag./D2.diag) - @test D\D2 ≈ Diagonal(D2.diag./D.diag) + @test (D/D2)::Diagonal ≈ Diagonal(D.diag./D2.diag) + @test (D\D2)::Diagonal ≈ Diagonal(D2.diag./D.diag) # QR \ Diagonal A = rand(elty, n, n) diff --git a/stdlib/LinearAlgebra/test/hessenberg.jl b/stdlib/LinearAlgebra/test/hessenberg.jl index 48f813595c92a..9b623273666c2 100644 --- a/stdlib/LinearAlgebra/test/hessenberg.jl +++ b/stdlib/LinearAlgebra/test/hessenberg.jl @@ -90,12 +90,10 @@ let n = 10 @testset "Multiplication/division" begin for x = (5, 5I, Diagonal(d), Bidiagonal(d,dl,:U), UpperTriangular(A), UnitUpperTriangular(A)) - @test H*x == Array(H)*x broken = eltype(H) <: Furlong && x isa Bidiagonal - @test x*H == x*Array(H) broken = eltype(H) <: Furlong && x isa Bidiagonal + @test (H*x)::UpperHessenberg == Array(H)*x broken = eltype(H) <: Furlong && x isa Bidiagonal + @test (x*H)::UpperHessenberg == x*Array(H) broken = eltype(H) <: Furlong && x isa Bidiagonal @test H/x == Array(H)/x broken = eltype(H) <: Furlong && x isa Union{Bidiagonal, UpperTriangular} @test x\H == x\Array(H) broken = eltype(H) <: Furlong && x isa Union{Bidiagonal, UpperTriangular} - @test H*x isa UpperHessenberg broken = eltype(H) <: Furlong && x isa Bidiagonal - @test x*H isa UpperHessenberg broken = eltype(H) <: Furlong && x isa Bidiagonal @test H/x isa UpperHessenberg broken = eltype(H) <: Furlong && x isa Bidiagonal @test x\H isa UpperHessenberg broken = eltype(H) <: Furlong && x isa Bidiagonal end