From 79f8b3bf5e047d319768a244a25dbacd6ca3442c Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 4 May 2025 12:34:25 +0530 Subject: [PATCH 1/4] Unwrap triangular matrices in broadcast --- src/structuredbroadcast.jl | 13 ++++++++++++- src/triangular.jl | 6 ++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/structuredbroadcast.jl b/src/structuredbroadcast.jl index 76e33bc2..658a363b 100644 --- a/src/structuredbroadcast.jl +++ b/src/structuredbroadcast.jl @@ -269,13 +269,24 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle}) return dest end +# Recursively replace wrapped matrices by their parents to improve broadcasting performance +# We may do this because the indexing within `copyto!` is restricted to the stored indices +preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A) +function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T} + args = map(x -> preprocess_broadcasted(T, x), bc.args) + Broadcast.Broadcasted(bc.f, args, bc.axes) +end +_preprocess_broadcasted(::Type{LowerTriangular}, A) = lowertridata(A) +_preprocess_broadcasted(::Type{UpperTriangular}, A) = uppertridata(A) + function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle}) isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc)) axs = axes(dest) axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + bc2 = preprocess_broadcasted(LowerTriangular, bc) for j in axs[2] for i in j:axs[1][end] - @inbounds dest.data[i,j] = bc[CartesianIndex(i, j)] + @inbounds dest.data[i,j] = bc2[CartesianIndex(i, j)] end end return dest diff --git a/src/triangular.jl b/src/triangular.jl index b0674ac5..5b476d24 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -702,10 +702,8 @@ end uppertridata(A) = A lowertridata(A) = A -# we restrict these specializations only to strided matrices to avoid cases where an UpperTriangular type -# doesn't share its indexing with the parent -uppertridata(A::UpperTriangular{<:Any, <:StridedMatrix}) = parent(A) -lowertridata(A::LowerTriangular{<:Any, <:StridedMatrix}) = parent(A) +uppertridata(A::UpperTriangular) = parent(A) +lowertridata(A::LowerTriangular) = parent(A) @inline _rscale_add!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) = @stable_muladdmul _triscale!(A, B, C, MulAddMul(alpha, beta)) From b0e12f9edc13c18643d30c7fd371554ab842a802 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 4 May 2025 12:43:00 +0530 Subject: [PATCH 2/4] unwrap in UpperTriangular --- src/structuredbroadcast.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/structuredbroadcast.jl b/src/structuredbroadcast.jl index 658a363b..79a78dba 100644 --- a/src/structuredbroadcast.jl +++ b/src/structuredbroadcast.jl @@ -283,10 +283,10 @@ function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle} isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc)) axs = axes(dest) axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) - bc2 = preprocess_broadcasted(LowerTriangular, bc) + bc_unwrapped = preprocess_broadcasted(LowerTriangular, bc) for j in axs[2] for i in j:axs[1][end] - @inbounds dest.data[i,j] = bc2[CartesianIndex(i, j)] + @inbounds dest.data[i,j] = bc_unwrapped[CartesianIndex(i, j)] end end return dest @@ -296,9 +296,10 @@ function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle} isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc)) axs = axes(dest) axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + bc_unwrapped = preprocess_broadcasted(UpperTriangular, bc) for j in axs[2] for i in 1:j - @inbounds dest.data[i,j] = bc[CartesianIndex(i, j)] + @inbounds dest.data[i,j] = bc_unwrapped[CartesianIndex(i, j)] end end return dest From 73d06ca7fbb8470fb7e8ffd64bc874f22adce4c5 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 4 May 2025 20:52:51 +0530 Subject: [PATCH 3/4] Test nested broadcast --- test/structuredbroadcast.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/structuredbroadcast.jl b/test/structuredbroadcast.jl index e4930a65..9da4355a 100644 --- a/test/structuredbroadcast.jl +++ b/test/structuredbroadcast.jl @@ -388,4 +388,10 @@ end @test ind == CartesianIndex(1,1) end +@testset "nested triangular broadcast" begin + L = LowerTriangular(rand(Int,4,4)) + M = Matrix(L) + @test L .+ L .+ 0 .+ L .+ 0 .- L == 2M +end + end From b50674eb70e3d4aebdfbc2a4eff662756107d288 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 4 May 2025 20:53:51 +0530 Subject: [PATCH 4/4] Test UpperTriangular --- test/structuredbroadcast.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/structuredbroadcast.jl b/test/structuredbroadcast.jl index 9da4355a..5e6a68d7 100644 --- a/test/structuredbroadcast.jl +++ b/test/structuredbroadcast.jl @@ -389,9 +389,11 @@ end end @testset "nested triangular broadcast" begin - L = LowerTriangular(rand(Int,4,4)) - M = Matrix(L) - @test L .+ L .+ 0 .+ L .+ 0 .- L == 2M + for T in (LowerTriangular, UpperTriangular) + L = T(rand(Int,4,4)) + M = Matrix(L) + @test L .+ L .+ 0 .+ L .+ 0 .- L == 2M + end end end