Skip to content

Commit ef04e4b

Browse files
committed
parametrize SymTridiagonal on the wrapped vector type
1 parent 0f2a014 commit ef04e4b

File tree

5 files changed

+87
-26
lines changed

5 files changed

+87
-26
lines changed

NEWS.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ This section lists changes that do not have deprecation warnings.
7373
longer present. Use `first(R)` and `last(R)` to obtain
7474
start/stop. ([#20974])
7575

76-
* The `Diagonal` type definition has changed from `Diagonal{T}` to
77-
`Diagonal{T,V<:AbstractVector{T}}` ([#22718]).
76+
* The `Diagonal`, `Bidiagonal` and `SymTridiagonal` type definitions have changed from
77+
`Diagonal{T}`, `Bidiagonal{T}` and `SymTridiagonal{T}` to
78+
`Diagonal{T,V<:AbstractVector{T}}`, `Bidiagonal{T,V<:AbstractVector{T}}`
79+
and `SymTridiagonal{T,V<:AbstractVector{T}}` respectively ([#22718], [#22925], [#23035]).
7880

7981
* Spaces are no longer allowed between `@` and the name of a macro in a macro call ([#22868]).
8082

@@ -142,8 +144,9 @@ Library improvements
142144

143145
* `Char`s can now be concatenated with `String`s and/or other `Char`s using `*` ([#22532]).
144146

145-
* `Diagonal` is now parameterized on the type of the wrapped vector. This allows
146-
for `Diagonal` matrices with arbitrary `AbstractVector`s ([#22718]).
147+
* `Diagonal`, `Bidiagonal` and `SymTridiagonal` are now parameterized on the type
148+
of the wrapped vectors, allowing `Diagonal`, `Bidiagonal` and `SymTridiagonal`
149+
matrices with arbitrary `AbstractVector`s ([#22718], [#22925], [#23035]).
147150

148151
* Mutating versions of `randperm` and `randcycle` have been added:
149152
`randperm!` and `randcycle!` ([#22723]).

base/linalg/tridiag.jl

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,55 +3,80 @@
33
#### Specialized matrix types ####
44

55
## (complex) symmetric tridiagonal matrices
6-
struct SymTridiagonal{T} <: AbstractMatrix{T}
7-
dv::Vector{T} # diagonal
8-
ev::Vector{T} # subdiagonal
9-
function SymTridiagonal{T}(dv::Vector{T}, ev::Vector{T}) where T
6+
struct SymTridiagonal{T,V<:AbstractVector{T}} <: AbstractMatrix{T}
7+
dv::V # diagonal
8+
ev::V # subdiagonal
9+
function SymTridiagonal{T}(dv::V, ev::V) where {T,V<:AbstractVector{T}}
1010
if !(length(dv) - 1 <= length(ev) <= length(dv))
1111
throw(DimensionMismatch("subdiagonal has wrong length. Has length $(length(ev)), but should be either $(length(dv) - 1) or $(length(dv))."))
1212
end
13-
new(dv,ev)
13+
new{T,V}(dv,ev)
1414
end
1515
end
1616

1717
"""
1818
SymTridiagonal(dv, ev)
1919
20-
Construct a symmetric tridiagonal matrix from the diagonal and first sub/super-diagonal,
21-
respectively. The result is of type `SymTridiagonal` and provides efficient specialized
22-
eigensolvers, but may be converted into a regular matrix with
23-
[`convert(Array, _)`](@ref) (or `Array(_)` for short).
20+
Construct a symmetric tridiagonal matrix from the diagonal (`dv`) and first
21+
sub/super-diagonal (`ev`), respectively. The result is of type `SymTridiagonal`
22+
and provides efficient specialized eigensolvers, but may be converted into a
23+
regular matrix with [`convert(Array, _)`](@ref) (or `Array(_)` for short).
2424
2525
# Examples
2626
```jldoctest
27-
julia> dv = [1; 2; 3; 4]
27+
julia> dv = [1, 2, 3, 4]
2828
4-element Array{Int64,1}:
2929
1
3030
2
3131
3
3232
4
3333
34-
julia> ev = [7; 8; 9]
34+
julia> ev = [7, 8, 9]
3535
3-element Array{Int64,1}:
3636
7
3737
8
3838
9
3939
4040
julia> SymTridiagonal(dv, ev)
41-
4×4 SymTridiagonal{Int64}:
41+
4×4 SymTridiagonal{Int64,Array{Int64,1}}:
4242
1 7 ⋅ ⋅
4343
7 2 8 ⋅
4444
⋅ 8 3 9
4545
⋅ ⋅ 9 4
4646
```
4747
"""
48-
SymTridiagonal(dv::Vector{T}, ev::Vector{T}) where {T} = SymTridiagonal{T}(dv, ev)
48+
SymTridiagonal(dv::V, ev::V) where {T,V<:AbstractVector{T}} = SymTridiagonal{T}(dv, ev)
49+
50+
function SymTridiagonal(dv::AbstractVector{T}, ev::AbstractVector{S}) where {T,S}
51+
R = promote_type(T, S)
52+
SymTridiagonal(convert(AbstractVector{R}, dv), convert(AbstractVector{R}, ev))
53+
end
4954

50-
function SymTridiagonal(dv::AbstractVector{Td}, ev::AbstractVector{Te}) where {Td,Te}
51-
T = promote_type(Td,Te)
55+
function SymTridiagonal(dv::AbstractVector{T}, ev::AbstractVector{T}) where T
5256
SymTridiagonal(convert(Vector{T}, dv), convert(Vector{T}, ev))
5357
end
5458

59+
"""
60+
SymTridiagonal(A::AbstractMatrix)
61+
62+
Construct a symmetric tridiagonal matrix from the
63+
diagonal and first sub/super-diagonal, of `A`.
64+
65+
# Examples
66+
```jldoctest
67+
julia> A = [1 2 3; 2 4 5; 3 5 6]
68+
3×3 Array{Int64,2}:
69+
1 2 3
70+
2 4 5
71+
3 5 6
72+
73+
julia> SymTridiagonal(A)
74+
3×3 SymTridiagonal{Int64}:
75+
1 2 ⋅
76+
2 4 5
77+
⋅ 5 6
78+
```
79+
"""
5580
function SymTridiagonal(A::AbstractMatrix)
5681
if diag(A,1) == diag(A,-1)
5782
SymTridiagonal(diag(A), diag(A,1))
@@ -61,10 +86,10 @@ function SymTridiagonal(A::AbstractMatrix)
6186
end
6287

6388
convert(::Type{SymTridiagonal{T}}, S::SymTridiagonal) where {T} =
64-
SymTridiagonal(convert(Vector{T}, S.dv), convert(Vector{T}, S.ev))
89+
SymTridiagonal(convert(AbstractVector{T}, S.dv), convert(AbstractVector{T}, S.ev))
6590
convert(::Type{AbstractMatrix{T}}, S::SymTridiagonal) where {T} =
66-
SymTridiagonal(convert(Vector{T}, S.dv), convert(Vector{T}, S.ev))
67-
function convert(::Type{Matrix{T}}, M::SymTridiagonal{T}) where T
91+
SymTridiagonal(convert(AbstractVector{T}, S.dv), convert(AbstractVector{T}, S.ev))
92+
function convert(::Type{Matrix{T}}, M::SymTridiagonal) where T
6893
n = size(M, 1)
6994
Mf = zeros(T, n, n)
7095
@inbounds begin
@@ -311,7 +336,7 @@ end
311336
# R. Usmani, "Inversion of a tridiagonal Jacobi matrix",
312337
# Linear Algebra and its Applications 212-213 (1994), pp.413-414
313338
# doi:10.1016/0024-3795(94)90414-6
314-
function inv_usmani(a::Vector{T}, b::Vector{T}, c::Vector{T}) where T
339+
function inv_usmani(a::V, b::V, c::V) where {T,V<:AbstractVector{T}}
315340
n = length(b)
316341
θ = ZeroOffsetVector(zeros(T, n+1)) #principal minors of A
317342
θ[0] = 1
@@ -341,7 +366,7 @@ end
341366

342367
#Implements the determinant using principal minors
343368
#Inputs and reference are as above for inv_usmani()
344-
function det_usmani(a::Vector{T}, b::Vector{T}, c::Vector{T}) where T
369+
function det_usmani(a::V, b::V, c::V) where {T,V<:AbstractVector{T}}
345370
n = length(b)
346371
θa = one(T)
347372
if n == 0
@@ -635,7 +660,7 @@ convert(::Type{AbstractMatrix{T}},M::Tridiagonal) where {T} = convert(Tridiagona
635660
convert(::Type{Tridiagonal{T}}, M::SymTridiagonal{T}) where {T} = Tridiagonal(M)
636661
function convert(::Type{SymTridiagonal{T}}, M::Tridiagonal) where T
637662
if M.dl == M.du
638-
return SymTridiagonal(convert(Vector{T},M.d), convert(Vector{T},M.dl))
663+
return SymTridiagonal{T}(convert(AbstractVector{T},M.d), convert(AbstractVector{T},M.dl))
639664
else
640665
throw(ArgumentError("Tridiagonal is not symmetric, cannot convert to SymTridiagonal"))
641666
end

base/test.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,6 +1315,10 @@ end
13151315
GenericArray{T}(args...) where {T} = GenericArray(Array{T}(args...))
13161316
GenericArray{T,N}(args...) where {T,N} = GenericArray(Array{T,N}(args...))
13171317

1318+
function Base.convert(::Type{AbstractArray{T,N}}, a::GenericArray) where {T,N}
1319+
GenericArray{T,N}(convert(Array{T,N}, a.a))
1320+
end
1321+
13181322
Base.eachindex(a::GenericArray) = eachindex(a.a)
13191323
Base.indices(a::GenericArray) = indices(a.a)
13201324
Base.length(a::GenericArray) = length(a.a)

test/linalg/tridiag.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@ for elty in (Float32, Float64, Complex64, Complex128, Int)
2828
v = convert(Vector{elty}, v)
2929
B = convert(Matrix{elty}, B)
3030
end
31+
32+
@testset "constructor" begin
33+
for (x, y) in ((d, dl), (GenericArray(d), GenericArray(dl)))
34+
ST = SymTridiagonal(x, y)
35+
@test ST::SymTridiagonal{elty, typeof(x)} == Matrix(ST)
36+
@test ST.dv === x
37+
@test ST.ev === y
38+
end
39+
end
3140
ε = eps(abs2(float(one(elty))))
3241
T = Tridiagonal(dl, d, du)
3342
Ts = SymTridiagonal(d, dl)
@@ -225,6 +234,26 @@ for elty in (Float32, Float64, Complex64, Complex128, Int)
225234
end
226235
end
227236

237+
@testset "arbitrary AbstractVector in SymTridiagonal" begin
238+
S, T = Float32, Float64
239+
R = promote_type(T, S)
240+
dvT = rand(T, n); dvS = rand(S, n)
241+
evT = rand(T, n-1); evS = rand(S, n-1)
242+
GdvT = GenericArray(dvT); GdvS = GenericArray(dvS)
243+
GevT = GenericArray(evT); GevS = GenericArray(evS)
244+
@test isa(SymTridiagonal(dvT, evT), SymTridiagonal{T,Vector{T}})
245+
@test isa(SymTridiagonal(dvT, evS), SymTridiagonal{R,Vector{R}})
246+
@test isa(SymTridiagonal(dvS, evT), SymTridiagonal{R,Vector{R}})
247+
248+
@test isa(SymTridiagonal(GdvT, GevT), SymTridiagonal{T,GenericArray{T,1}})
249+
@test isa(SymTridiagonal(GdvT, GevS), SymTridiagonal{R,GenericArray{R,1}})
250+
@test isa(SymTridiagonal(GdvS, GevT), SymTridiagonal{R,GenericArray{R,1}})
251+
252+
@test isa(SymTridiagonal(GdvT, evT), SymTridiagonal{T,Vector{T}})
253+
@test isa(SymTridiagonal(GdvT, evS), SymTridiagonal{R,Vector{R}})
254+
@test isa(SymTridiagonal(GdvS, evT), SymTridiagonal{R,Vector{R}})
255+
end
256+
228257
#Test equivalence of eigenvectors/singular vectors taking into account possible phase (sign) differences
229258
function test_approx_eq_vecs(a::StridedVecOrMat{S}, b::StridedVecOrMat{T}, error=nothing) where {S<:Real,T<:Real}
230259
n = size(a, 1)

test/show.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ A = reshape(1:16,4,4)
558558
@test replstr(Diagonal(A)) == "4×4 Diagonal{$(Int),Array{$(Int),1}}:\n 1 ⋅ ⋅ ⋅\n ⋅ 6 ⋅ ⋅\n ⋅ ⋅ 11 ⋅\n ⋅ ⋅ ⋅ 16"
559559
@test replstr(Bidiagonal(A,:U)) == "4×4 Bidiagonal{$Int}:\n 1 5 ⋅ ⋅\n ⋅ 6 10 ⋅\n ⋅ ⋅ 11 15\n ⋅ ⋅ ⋅ 16"
560560
@test replstr(Bidiagonal(A,:L)) == "4×4 Bidiagonal{$Int}:\n 1 ⋅ ⋅ ⋅\n 2 6 ⋅ ⋅\n ⋅ 7 11 ⋅\n ⋅ ⋅ 12 16"
561-
@test replstr(SymTridiagonal(A+A')) == "4×4 SymTridiagonal{$Int}:\n 2 7 ⋅ ⋅\n 7 12 17 ⋅\n ⋅ 17 22 27\n ⋅ ⋅ 27 32"
561+
@test replstr(SymTridiagonal(A+A')) == "4×4 SymTridiagonal{$(Int),Array{$(Int),1}}:\n 2 7 ⋅ ⋅\n 7 12 17 ⋅\n ⋅ 17 22 27\n ⋅ ⋅ 27 32"
562562
@test replstr(Tridiagonal(diag(A,-1),diag(A),diag(A,+1))) == "4×4 Tridiagonal{$Int}:\n 1 5 ⋅ ⋅\n 2 6 10 ⋅\n ⋅ 7 11 15\n ⋅ ⋅ 12 16"
563563
@test replstr(UpperTriangular(copy(A))) == "4×4 UpperTriangular{$Int,Array{$Int,2}}:\n 1 5 9 13\n ⋅ 6 10 14\n ⋅ ⋅ 11 15\n ⋅ ⋅ ⋅ 16"
564564
@test replstr(LowerTriangular(copy(A))) == "4×4 LowerTriangular{$Int,Array{$Int,2}}:\n 1 ⋅ ⋅ ⋅\n 2 6 ⋅ ⋅\n 3 7 11 ⋅\n 4 8 12 16"

0 commit comments

Comments
 (0)