4
4
import Base: getindex,setindex!,== ,- ,+ ,* ,/ ,\ ,transpose,ctranspose,convert, size, abs, real, imag, conj, eye, inv
5
5
import Base. LinAlg: ishermitian, issymmetric, isposdef, factorize, diag, trace, det, logdet, expm, logm, sqrtm
6
6
7
- @generated function scalem {T, M, N} (a:: SMatrix {M,N, T } , b:: SVector{N, T } )
7
+ @generated function scalem {M, N} (a:: StaticMatrix {M,N} , b:: StaticVector{N } )
8
8
expr = vec ([:(a[$ j,$ i]* b[$ i]) for j= 1 : M, i= 1 : N])
9
- :(SMatrix {M,N,T} ( $ (expr... )))
9
+ :(let val1 = ( $ (expr[ 1 ])); similar_type ( SMatrix{M,N}, typeof (val1))(val1, $ (expr[ 2 : end ] . .. )); end )
10
10
end
11
- @generated function scalem {T, M, N} (a:: SVector{M,T } , b:: SMatrix {M, N, T } )
11
+ @generated function scalem {M, N} (a:: StaticVector{M } , b:: StaticMatrix {M, N} )
12
12
expr = vec ([:(b[$ j,$ i]* a[$ j]) for j= 1 : M, i= 1 : N])
13
- :(SMatrix {M,N,T} ( $ (expr... )))
13
+ :(let val1 = ( $ (expr[ 1 ])); similar_type ( SMatrix{M,N}, typeof (val1))(val1, $ (expr[ 2 : end ] . .. )); end )
14
14
end
15
15
16
16
struct SDiagonal{N,T} <: StaticMatrix{N, N, T}
@@ -22,35 +22,21 @@ diagtype{N}(::Type{SDiagonal{N}}) = SVector{N}
22
22
diagtype (:: Type{SDiagonal} ) = SVector
23
23
24
24
# this is to deal with convert.jl
25
- @inline (:: Type{SD} )(a:: AbstractVector ) where {SD <: SDiagonal } = SDiagonal (diagtype (SD)( a))
26
- @inline (:: Type{SD} )(a:: Tuple ) where {SD <: SDiagonal } = SDiagonal (diagtype (SD)( a))
25
+ @inline (:: Type{SD} )(a:: AbstractVector ) where {SD <: SDiagonal } = SDiagonal (convert ( diagtype (SD), a))
26
+ @inline (:: Type{SD} )(a:: Tuple ) where {SD <: SDiagonal } = SDiagonal (convert ( diagtype (SD), a))
27
27
@inline (:: Type{SDiagonal} ){N,T}(a:: SVector{N,T} ) = SDiagonal {N,T} (a)
28
28
29
- @generated function SDiagonal {N,T} (a:: SMatrix {N,N,T} )
29
+ @generated function SDiagonal {N,T} (a:: StaticMatrix {N,N,T} )
30
30
expr = [:(a[$ i,$ i]) for i= 1 : N]
31
31
:(SDiagonal {N,T} ($ (expr... )))
32
32
end
33
33
34
-
35
34
convert {N,T} (:: Type{SDiagonal{N,T}} , D:: SDiagonal{N,T} ) = D
36
- convert {N,T} (:: Type{SDiagonal{N,T}} , D:: SDiagonal ) = SDiagonal {N,T} (convert (SVector{N,T}, D. diag))
37
-
38
- size {N} (D:: SDiagonal{N} ) = (N,N)
39
-
40
- function size {N} (D:: SDiagonal{N} ,d:: Int )
41
- if d< 1
42
- throw (ArgumentError (" dimension must be ≥ 1, got $d " ))
43
- end
44
- return d<= 2 ? N : 1
45
- end
35
+ convert {N,T} (:: Type{SDiagonal{N,T}} , D:: SDiagonal{N} ) = SDiagonal {N,T} (convert (SVector{N,T}, D. diag))
46
36
47
37
Base. @propagate_inbounds function getindex {N,T} (D:: SDiagonal{N,T} , i:: Int , j:: Int )
48
38
@boundscheck checkbounds (D, i, j)
49
- if i == j
50
- @inbounds return D. diag[i]
51
- else
52
- zero (T)
53
- end
39
+ @inbounds return ifelse (i == j, D. diag[i], zero (T))
54
40
end
55
41
56
42
# avoid linear indexing?
@@ -76,11 +62,12 @@ factorize(D::SDiagonal) = D
76
62
* {T<: Number }(D:: SDiagonal , x:: T ) = SDiagonal (D. diag * x)
77
63
/ {T<: Number }(D:: SDiagonal , x:: T ) = SDiagonal (D. diag / x)
78
64
* (Da:: SDiagonal , Db:: SDiagonal ) = SDiagonal (Da. diag .* Db. diag)
79
- * (D:: SDiagonal , V:: SVector ) = D. diag .* V
80
- * (V:: SVector , D:: SDiagonal ) = D. diag .* V
81
- * (A:: SMatrix , D:: SDiagonal ) = scalem (A,D. diag)
82
- * (D:: SDiagonal , A:: SMatrix ) = scalem (D. diag,A)
83
- \ (D:: SDiagonal , b:: SVector ) = D. diag .\ b
65
+ * (D:: SDiagonal , V:: AbstractVector ) = D. diag .* V
66
+ * (D:: SDiagonal , V:: StaticVector ) = D. diag .* V
67
+ * (A:: StaticMatrix , D:: SDiagonal ) = scalem (A,D. diag)
68
+ * (D:: SDiagonal , A:: StaticMatrix ) = scalem (D. diag,A)
69
+ \ (D:: SDiagonal , b:: AbstractVector ) = D. diag .\ b
70
+ \ (D:: SDiagonal , b:: StaticVector ) = D. diag .\ b # catch ambiguity
84
71
85
72
conj (D:: SDiagonal ) = SDiagonal (conj (D. diag))
86
73
transpose (D:: SDiagonal ) = D
@@ -101,17 +88,22 @@ expm(D::SDiagonal) = SDiagonal(exp.(D.diag))
101
88
logm (D:: SDiagonal ) = SDiagonal (log .(D. diag))
102
89
sqrtm (D:: SDiagonal ) = SDiagonal (sqrt .(D. diag))
103
90
104
- \ (D:: SDiagonal , B:: SMatrix ) = scalem (1 ./ D. diag, B)
105
- / (B:: SMatrix , D:: SDiagonal ) = scalem (1 ./ D. diag, B)
91
+ \ (D:: SDiagonal , B:: StaticMatrix ) = scalem (1 ./ D. diag, B)
92
+ / (B:: StaticMatrix , D:: SDiagonal ) = scalem (1 ./ D. diag, B)
106
93
\ (Da:: SDiagonal , Db:: SDiagonal ) = SDiagonal (Db. diag ./ Da. diag)
107
94
/ (Da:: SDiagonal , Db:: SDiagonal ) = SDiagonal (Da. diag ./ Db. diag )
108
95
109
- function inv {N,T} (D :: SDiagonal{N,T} )
110
- for i = 1 : N
111
- if D . diag[i] == zero (T )
112
- throw ( SingularException (i))
113
- end
96
+
97
+ @generated function check_singular {N,T} (D :: SDiagonal{N,T} )
98
+ expr = Expr ( :block )
99
+ for i = 1 : N
100
+ push! (expr . args, :( @inbounds iszero (D . diag[ $ i]) && throw (Base . LinAlg . SingularException ( $ i))))
114
101
end
102
+ expr
103
+ end
104
+
105
+ function inv {N,T} (D:: SDiagonal{N,T} )
106
+ check_singular (D)
115
107
SDiagonal (inv .(D. diag))
116
108
end
117
109
0 commit comments