Skip to content

Commit 1244a2c

Browse files
committed
Address andyferris' comments
1 parent 0772fb6 commit 1244a2c

File tree

2 files changed

+33
-35
lines changed

2 files changed

+33
-35
lines changed

src/SDiagonal.jl

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
import Base: getindex,setindex!,==,-,+,*,/,\,transpose,ctranspose,convert, size, abs, real, imag, conj, eye, inv
55
import Base.LinAlg: ishermitian, issymmetric, isposdef, factorize, diag, trace, det, logdet, expm, logm, sqrtm
66

7-
@generated function scalem{T, M, N}(a::SMatrix{M,N, T}, b::SVector{N, T})
7+
@generated function scalem{M, N}(a::StaticMatrix{M,N}, b::StaticVector{N})
88
expr = vec([:(a[$j,$i]*b[$i]) for j=1:M, i=1:N])
9-
:(SMatrix{M,N,T}($(expr...)))
9+
:(let val1 = ($(expr[1])); similar_type(SMatrix{M,N},typeof(val1))(val1, $(expr[2:end]...)); end)
1010
end
11-
@generated function scalem{T, M, N}(a::SVector{M,T}, b::SMatrix{M, N, T})
11+
@generated function scalem{M, N}(a::StaticVector{M}, b::StaticMatrix{M, N})
1212
expr = vec([:(b[$j,$i]*a[$j]) for j=1:M, i=1:N])
13-
:(SMatrix{M,N,T}($(expr...)))
13+
:(let val1 = ($(expr[1])); similar_type(SMatrix{M,N},typeof(val1))(val1, $(expr[2:end]...)); end)
1414
end
1515

1616
struct SDiagonal{N,T} <: StaticMatrix{N, N, T}
@@ -22,35 +22,21 @@ diagtype{N}(::Type{SDiagonal{N}}) = SVector{N}
2222
diagtype(::Type{SDiagonal}) = SVector
2323

2424
# this is to deal with convert.jl
25-
@inline (::Type{SD})(a::AbstractVector) where {SD <: SDiagonal} = SDiagonal(diagtype(SD)(a))
26-
@inline (::Type{SD})(a::Tuple) where {SD <: SDiagonal} = SDiagonal(diagtype(SD)(a))
25+
@inline (::Type{SD})(a::AbstractVector) where {SD <: SDiagonal} = SDiagonal(convert(diagtype(SD), a))
26+
@inline (::Type{SD})(a::Tuple) where {SD <: SDiagonal} = SDiagonal(convert(diagtype(SD), a))
2727
@inline (::Type{SDiagonal}){N,T}(a::SVector{N,T}) = SDiagonal{N,T}(a)
2828

29-
@generated function SDiagonal{N,T}(a::SMatrix{N,N,T})
29+
@generated function SDiagonal{N,T}(a::StaticMatrix{N,N,T})
3030
expr = [:(a[$i,$i]) for i=1:N]
3131
:(SDiagonal{N,T}($(expr...)))
3232
end
3333

34-
3534
convert{N,T}(::Type{SDiagonal{N,T}}, D::SDiagonal{N,T}) = D
36-
convert{N,T}(::Type{SDiagonal{N,T}}, D::SDiagonal) = SDiagonal{N,T}(convert(SVector{N,T}, D.diag))
37-
38-
size{N}(D::SDiagonal{N}) = (N,N)
39-
40-
function size{N}(D::SDiagonal{N},d::Int)
41-
if d<1
42-
throw(ArgumentError("dimension must be ≥ 1, got $d"))
43-
end
44-
return d<=2 ? N : 1
45-
end
35+
convert{N,T}(::Type{SDiagonal{N,T}}, D::SDiagonal{N}) = SDiagonal{N,T}(convert(SVector{N,T}, D.diag))
4636

4737
Base.@propagate_inbounds function getindex{N,T}(D::SDiagonal{N,T}, i::Int, j::Int)
4838
@boundscheck checkbounds(D, i, j)
49-
if i == j
50-
@inbounds return D.diag[i]
51-
else
52-
zero(T)
53-
end
39+
@inbounds return ifelse(i == j, D.diag[i], zero(T))
5440
end
5541

5642
# avoid linear indexing?
@@ -76,11 +62,12 @@ factorize(D::SDiagonal) = D
7662
*{T<:Number}(D::SDiagonal, x::T) = SDiagonal(D.diag * x)
7763
/{T<:Number}(D::SDiagonal, x::T) = SDiagonal(D.diag / x)
7864
*(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag .* Db.diag)
79-
*(D::SDiagonal, V::SVector) = D.diag .* V
80-
*(V::SVector, D::SDiagonal) = D.diag .* V
81-
*(A::SMatrix, D::SDiagonal) = scalem(A,D.diag)
82-
*(D::SDiagonal, A::SMatrix) = scalem(D.diag,A)
83-
\(D::SDiagonal, b::SVector) = D.diag .\ b
65+
*(D::SDiagonal, V::AbstractVector) = D.diag .* V
66+
*(D::SDiagonal, V::StaticVector) = D.diag .* V
67+
*(A::StaticMatrix, D::SDiagonal) = scalem(A,D.diag)
68+
*(D::SDiagonal, A::StaticMatrix) = scalem(D.diag,A)
69+
\(D::SDiagonal, b::AbstractVector) = D.diag .\ b
70+
\(D::SDiagonal, b::StaticVector) = D.diag .\ b # catch ambiguity
8471

8572
conj(D::SDiagonal) = SDiagonal(conj(D.diag))
8673
transpose(D::SDiagonal) = D
@@ -101,17 +88,22 @@ expm(D::SDiagonal) = SDiagonal(exp.(D.diag))
10188
logm(D::SDiagonal) = SDiagonal(log.(D.diag))
10289
sqrtm(D::SDiagonal) = SDiagonal(sqrt.(D.diag))
10390

104-
\(D::SDiagonal, B::SMatrix) = scalem(1 ./ D.diag, B)
105-
/(B::SMatrix, D::SDiagonal) = scalem(1 ./ D.diag, B)
91+
\(D::SDiagonal, B::StaticMatrix) = scalem(1 ./ D.diag, B)
92+
/(B::StaticMatrix, D::SDiagonal) = scalem(1 ./ D.diag, B)
10693
\(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Db.diag ./ Da.diag)
10794
/(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag ./ Db.diag )
10895

109-
function inv{N,T}(D::SDiagonal{N,T})
110-
for i = 1:N
111-
if D.diag[i] == zero(T)
112-
throw(SingularException(i))
113-
end
96+
97+
@generated function check_singular{N,T}(D::SDiagonal{N,T})
98+
expr = Expr(:block)
99+
for i=1:N
100+
push!(expr.args, :(@inbounds iszero(D.diag[$i]) && throw(Base.LinAlg.SingularException($i))))
114101
end
102+
expr
103+
end
104+
105+
function inv{N,T}(D::SDiagonal{N,T})
106+
check_singular(D)
115107
SDiagonal(inv.(D.diag))
116108
end
117109

test/SDiagonal.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
@testset "Methods" begin
1818

19+
@test StaticArrays.scalem(@SMatrix([1 1 1;1 1 1; 1 1 1]), @SVector [1,2,3]) === @SArray [1 2 3; 1 2 3; 1 2 3]
20+
1921
m = SDiagonal(@SVector [11, 12, 13, 14])
2022
m2 = diagm([11, 12, 13, 14])
2123

@@ -54,6 +56,8 @@
5456
@test_throws Exception m[1] = 1
5557

5658
@test m*b == @SVector [22,-12,26,14]
59+
@test (b'*m)' == @SVector [22,-12,26,14]
60+
5761
@test m\b == m2\b
5862
@test m*m == m2*m
5963

@@ -68,5 +72,7 @@
6872
@test m\m == eye(SDiagonal{4,Float64})
6973

7074

75+
76+
7177
end
7278
end

0 commit comments

Comments
 (0)