Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,24 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
return dest
end

# Recursively replace wrapped matrices by their parents to improve broadcasting performance
# We may do this because the indexing within `copyto!` is restricted to the stored indices
preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A)
function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T}
args = map(x -> preprocess_broadcasted(T, x), bc.args)
Broadcast.Broadcasted(bc.f, args, bc.axes)
end
_preprocess_broadcasted(::Type{LowerTriangular}, A) = lowertridata(A)
_preprocess_broadcasted(::Type{UpperTriangular}, A) = uppertridata(A)

function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle})
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
bc_unwrapped = preprocess_broadcasted(LowerTriangular, bc)
for j in axs[2]
for i in j:axs[1][end]
@inbounds dest.data[i,j] = bc[CartesianIndex(i, j)]
@inbounds dest.data[i,j] = bc_unwrapped[CartesianIndex(i, j)]
end
end
return dest
Expand All @@ -285,9 +296,10 @@ function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
bc_unwrapped = preprocess_broadcasted(UpperTriangular, bc)
for j in axs[2]
for i in 1:j
@inbounds dest.data[i,j] = bc[CartesianIndex(i, j)]
@inbounds dest.data[i,j] = bc_unwrapped[CartesianIndex(i, j)]
end
end
return dest
Expand Down
6 changes: 2 additions & 4 deletions src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -702,10 +702,8 @@ end

uppertridata(A) = A
lowertridata(A) = A
# we restrict these specializations only to strided matrices to avoid cases where an UpperTriangular type
# doesn't share its indexing with the parent
uppertridata(A::UpperTriangular{<:Any, <:StridedMatrix}) = parent(A)
lowertridata(A::LowerTriangular{<:Any, <:StridedMatrix}) = parent(A)
uppertridata(A::UpperTriangular) = parent(A)
lowertridata(A::LowerTriangular) = parent(A)

@inline _rscale_add!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) =
@stable_muladdmul _triscale!(A, B, C, MulAddMul(alpha, beta))
Expand Down
8 changes: 8 additions & 0 deletions test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,4 +388,12 @@ end
@test ind == CartesianIndex(1,1)
end

@testset "nested triangular broadcast" begin
for T in (LowerTriangular, UpperTriangular)
L = T(rand(Int,4,4))
M = Matrix(L)
@test L .+ L .+ 0 .+ L .+ 0 .- L == 2M
end
end

end