Skip to content
Merged
120 changes: 83 additions & 37 deletions src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) =
Base.isstored(A::UpperOrLowerTriangular, i::Int, j::Int) =
_shouldforwardindex(A, i, j) ? Base.isstored(A.data, i, j) : false

@propagate_inbounds getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T} =
_shouldforwardindex(A, i, j) ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T))
@propagate_inbounds getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int) =
_shouldforwardindex(A, i, j) ? A.data[i,j] : diagzero(A,i,j)
@propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T}
if _shouldforwardindex(A, i, j)
A.data[i,j]
else
@boundscheck checkbounds(A, i, j)
ifelse(i == j, oneunit(T), zero(T))
end
end
@propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int)
if _shouldforwardindex(A, i, j)
A.data[i,j]
else
@boundscheck checkbounds(A, i, j)
@inbounds diagzero(A,i,j)
end
end

_shouldforwardindex(U::UpperTriangular, b::BandIndex) = b.band >= 0
_shouldforwardindex(U::LowerTriangular, b::BandIndex) = b.band <= 0
Expand All @@ -250,63 +262,97 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0

# these specialized getindex methods enable constant-propagation of the band
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, b::BandIndex) where {T}
_shouldforwardindex(A, b) ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
if _shouldforwardindex(A, b)
A.data[b]
else
@boundscheck checkbounds(A, b)
ifelse(b.band == 0, oneunit(T), zero(T))
end
end
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, b::BandIndex)
_shouldforwardindex(A, b) ? A.data[b] : diagzero(A.data, b)
if _shouldforwardindex(A, b)
A.data[b]
else
@boundscheck checkbounds(A, b)
@inbounds diagzero(A, b)
end
end

_zero_triangular_half_str(::Type{<:UpperOrUnitUpperTriangular}) = "lower"
_zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"

@noinline function throw_nonzeroerror(T, @nospecialize(x), i, j)
Ts = _zero_triangular_half_str(T)
Tn = nameof(T)
@noinline function throw_nonzeroerror(Tn::Symbol, @nospecialize(x), i, j)
zero_half = Tn in (:UpperTriangular, :UnitUpperTriangular) ? "lower" : "upper"
nstr = Tn === :UpperTriangular ? "n" : ""
throw(ArgumentError(
lazy"cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)"))
LazyString(
lazy"cannot set index ($i, $j) in the $zero_half triangular part ",
lazy"of a$nstr $Tn matrix to a nonzero value ($x)")
)
)
end
@noinline function throw_nononeerror(T, @nospecialize(x), i, j)
Tn = nameof(T)
@noinline function throw_nonuniterror(Tn::Symbol, @nospecialize(x), i, j)
throw(ArgumentError(
lazy"cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)"))
lazy"cannot set index ($i, $j) on the diagonal of a $Tn matrix to a non-unit value ($x)"))
end

@propagate_inbounds function setindex!(A::UpperTriangular, x, i::Integer, j::Integer)
if i > j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
return A
end

@propagate_inbounds function setindex!(A::UnitUpperTriangular, x, i::Integer, j::Integer)
if i > j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
elseif i == j
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
if i == j # diagonal
x == oneunit(eltype(A)) || throw_nonuniterror(nameof(typeof(A)), x, i, j)
else
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
end
return A
end

@propagate_inbounds function setindex!(A::LowerTriangular, x, i::Integer, j::Integer)
if i < j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
return A
end

@propagate_inbounds function setindex!(A::UnitLowerTriangular, x, i::Integer, j::Integer)
if i < j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
elseif i == j
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
if i == j # diagonal
x == oneunit(eltype(A)) || throw_nonuniterror(nameof(typeof(A)), x, i, j)
else
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
end
return A
end
Expand Down Expand Up @@ -560,7 +606,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un
@eval @inline function _copy!(A::$UT, B::$T)
for dind in diagind(A, IndexStyle(A))
if A[dind] != B[dind]
throw_nononeerror(typeof(A), B[dind], Tuple(dind)...)
throw_nonuniterror(nameof(typeof(A)), B[dind], Tuple(dind)...)
end
end
_copy!($T(parent(A)), B)
Expand Down Expand Up @@ -741,7 +787,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
checksize1(A, B)
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
for i in firstindex(B.data,1):(j - 1)
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
end
Expand All @@ -752,7 +798,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
checksize1(A, B)
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
for i in firstindex(B.data,1):(j - 1)
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
end
Expand Down Expand Up @@ -783,7 +829,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
checksize1(A, B)
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
for i in (j + 1):lastindex(B.data,1)
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
end
Expand All @@ -794,7 +840,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
checksize1(A, B)
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
for i in (j + 1):lastindex(B.data,1)
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
end
Expand Down
79 changes: 77 additions & 2 deletions test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -641,11 +641,11 @@ end
@testset "error message" begin
A = UpperTriangular(Ap)
B = UpperTriangular(Bp)
@test_throws "cannot set index in the lower triangular part" copyto!(A, B)
@test_throws "cannot set index (3, 1) in the lower triangular part" copyto!(A, B)

A = LowerTriangular(Ap)
B = LowerTriangular(Bp)
@test_throws "cannot set index in the upper triangular part" copyto!(A, B)
@test_throws "cannot set index (1, 2) in the upper triangular part" copyto!(A, B)
end
end

Expand Down Expand Up @@ -950,6 +950,10 @@ end
@test 2\U == 2\M
@test U*2 == M*2
@test 2*U == 2*M

U2 = copy(U)
@test rmul!(U, 1) == U2
@test lmul!(1, U) == U2
end

@testset "scaling partly initialized unit triangular" begin
Expand All @@ -966,4 +970,75 @@ end
end
end

@testset "indexing checks" begin
P = [1 2; 3 4]
@testset "getindex" begin
U = UnitUpperTriangular(P)
@test_throws BoundsError U[0,0]
@test_throws BoundsError U[1,0]
@test_throws BoundsError U[BandIndex(0,0)]
@test_throws BoundsError U[BandIndex(-1,0)]

U = UpperTriangular(P)
@test_throws BoundsError U[1,0]
@test_throws BoundsError U[BandIndex(-1,0)]

L = UnitLowerTriangular(P)
@test_throws BoundsError L[0,0]
@test_throws BoundsError L[0,1]
@test_throws BoundsError U[BandIndex(0,0)]
@test_throws BoundsError U[BandIndex(1,0)]

L = LowerTriangular(P)
@test_throws BoundsError L[0,1]
@test_throws BoundsError L[BandIndex(1,0)]
end
@testset "setindex!" begin
A = SizedArrays.SizedArray{(2,2)}(P)
M = fill(A, 2, 2)
U = UnitUpperTriangular(M)
@test_throws "Cannot `convert` an object of type $Int" U[1,1] = 1
non_unit_msg = "cannot set index $((1,1)) on the diagonal of a UnitUpperTriangular matrix to a non-unit value"
@test_throws non_unit_msg U[1,1] = A
L = UnitLowerTriangular(M)
@test_throws "Cannot `convert` an object of type $Int" L[1,1] = 1
non_unit_msg = "cannot set index $((1,1)) on the diagonal of a UnitLowerTriangular matrix to a non-unit value"
@test_throws non_unit_msg L[1,1] = A

for UT in (UnitUpperTriangular, UpperTriangular)
U = UT(M)
@test_throws "Cannot `convert` an object of type $Int" U[2,1] = 0
end
for LT in (UnitLowerTriangular, LowerTriangular)
L = LT(M)
@test_throws "Cannot `convert` an object of type $Int" L[1,2] = 0
end

U = UnitUpperTriangular(P)
@test_throws BoundsError U[0,0] = 1
@test_throws BoundsError U[1,0] = 0

U = UpperTriangular(P)
@test_throws BoundsError U[1,0] = 0

L = UnitLowerTriangular(P)
@test_throws BoundsError L[0,0] = 1
@test_throws BoundsError L[0,1] = 0

L = LowerTriangular(P)
@test_throws BoundsError L[0,1] = 0
end
end

@testset "unit triangular l/rdiv!" begin
A = rand(3,3)
@testset for (UT,T) in ((UnitUpperTriangular, UpperTriangular),
(UnitLowerTriangular, LowerTriangular))
UnitTri = UT(A)
Tri = T(LinearAlgebra.full(UnitTri))
@test 2 \ UnitTri ≈ 2 \ Tri
@test UnitTri / 2 ≈ Tri / 2
end
end

end # module TestTriangular