diff --git a/src/structuredbroadcast.jl b/src/structuredbroadcast.jl index 76e33bc2..79a78dba 100644 --- a/src/structuredbroadcast.jl +++ b/src/structuredbroadcast.jl @@ -269,13 +269,24 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle}) return dest end +# Recursively replace wrapped matrices by their parents to improve broadcasting performance +# We may do this because the indexing within `copyto!` is restricted to the stored indices +preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A) +function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T} + args = map(x -> preprocess_broadcasted(T, x), bc.args) + Broadcast.Broadcasted(bc.f, args, bc.axes) +end +_preprocess_broadcasted(::Type{LowerTriangular}, A) = lowertridata(A) +_preprocess_broadcasted(::Type{UpperTriangular}, A) = uppertridata(A) + function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle}) isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc)) axs = axes(dest) axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + bc_unwrapped = preprocess_broadcasted(LowerTriangular, bc) for j in axs[2] for i in j:axs[1][end] - @inbounds dest.data[i,j] = bc[CartesianIndex(i, j)] + @inbounds dest.data[i,j] = bc_unwrapped[CartesianIndex(i, j)] end end return dest @@ -285,9 +296,10 @@ function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle} isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc)) axs = axes(dest) axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + bc_unwrapped = preprocess_broadcasted(UpperTriangular, bc) for j in axs[2] for i in 1:j - @inbounds dest.data[i,j] = bc[CartesianIndex(i, j)] + @inbounds dest.data[i,j] = bc_unwrapped[CartesianIndex(i, j)] end end return dest diff --git a/src/triangular.jl b/src/triangular.jl index b0674ac5..5b476d24 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -702,10 +702,8 @@ end uppertridata(A) = A lowertridata(A) = A -# we restrict these specializations only to strided matrices to avoid cases where an UpperTriangular type -# doesn't share its indexing with the parent -uppertridata(A::UpperTriangular{<:Any, <:StridedMatrix}) = parent(A) -lowertridata(A::LowerTriangular{<:Any, <:StridedMatrix}) = parent(A) +uppertridata(A::UpperTriangular) = parent(A) +lowertridata(A::LowerTriangular) = parent(A) @inline _rscale_add!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) = @stable_muladdmul _triscale!(A, B, C, MulAddMul(alpha, beta)) diff --git a/test/structuredbroadcast.jl b/test/structuredbroadcast.jl index e4930a65..5e6a68d7 100644 --- a/test/structuredbroadcast.jl +++ b/test/structuredbroadcast.jl @@ -388,4 +388,12 @@ end @test ind == CartesianIndex(1,1) end +@testset "nested triangular broadcast" begin + for T in (LowerTriangular, UpperTriangular) + L = T(rand(Int,4,4)) + M = Matrix(L) + @test L .+ L .+ 0 .+ L .+ 0 .- L == 2M + end +end + end