Skip to content

Commit 024068a

Browse files
committed
Merge pull request #1625 from JuliaLang/vs/expm
Performance improvements to expm
2 parents bb2eb1c + a6d2c61 commit 024068a

File tree

1 file changed

+49
-11
lines changed

1 file changed

+49
-11
lines changed

base/linalg_dense.jl

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -248,18 +248,34 @@ function expm!{T<:LapackType}(A::StridedMatrix{T})
248248
else
249249
C = [120.,60.,12.,1.]
250250
end
251+
C = convert(Array{T,1}, C)
251252
A2 = A * A
252253
P = copy(I)
253-
U = C[2] * P
254-
V = C[1] * P
254+
# U = C[2] * P
255+
# V = C[1] * P
256+
U = zeros(T, n, n)
257+
V = zeros(T, n, n)
258+
C2 = C[2]; C1 = C[1]
259+
for i=1:n
260+
U[i,i] = C2
261+
V[i,i] = C1
262+
end
255263
for k in 1:(div(size(C, 1), 2) - 1)
256264
k2 = 2 * k
257265
P *= A2
258-
U += C[k2 + 2] * P
259-
V += C[k2 + 1] * P
266+
#U += C[k2 + 2] * P
267+
#V += C[k2 + 1] * P
268+
Ck21 = C[k2 + 1]
269+
Ck22 = C[k2 + 2]
270+
for i=1:length(P)
271+
U[i] += Ck22 * P[i]
272+
V[i] += Ck21 * P[i]
273+
end
260274
end
261275
U = A * U
262-
X = (V - U)\(V + U)
276+
#X = (V - U)\(V + U)
277+
X = V + U
278+
LAPACK.gesv!(V-U, X)
263279
else
264280
s = log2(nA/5.4) # power of 2 later reversed by squaring
265281
if s > 0
@@ -271,15 +287,37 @@ function expm!{T<:LapackType}(A::StridedMatrix{T})
271287
670442572800., 33522128640., 1323241920.,
272288
40840800., 960960., 16380.,
273289
182., 1.]
290+
CC = convert(Array{T,1}, CC)
274291
A2 = A * A
275292
A4 = A2 * A2
276293
A6 = A2 * A4
277-
U = A * (A6 * (CC[14]*A6 + CC[12]*A4 + CC[10]*A2) +
278-
CC[8]*A6 + CC[6]*A4 + CC[4]*A2 + CC[2]*I)
279-
V = A6 * (CC[13]*A6 + CC[11]*A4 + CC[9]*A2) +
280-
CC[7]*A6 + CC[5]*A4 + CC[3]*A2 + CC[1]*I
281-
X = (V-U)\(V+U)
282-
294+
# U = A * (A6 * (CC[14]*A6 + CC[12]*A4 + CC[10]*A2) +
295+
# CC[8]*A6 + CC[6]*A4 + CC[4]*A2 + CC[2]*I)
296+
# V = A6 * (CC[13]*A6 + CC[11]*A4 + CC[9]*A2) +
297+
# CC[7]*A6 + CC[5]*A4 + CC[3]*A2 + CC[1]*I
298+
P1 = zeros(T, n, n)
299+
P2 = zeros(T, n, n)
300+
P3 = zeros(T, n, n)
301+
P4 = zeros(T, n, n)
302+
CC14 = CC[14]; CC12 = CC[12]; CC10 = CC[10]
303+
CC8 = CC[8]; CC6 = CC[6]; CC4 = CC[4]; CC2 = CC[2];
304+
CC13 = CC[13]; CC11 = CC[11]; CC9 = CC[9]
305+
CC7 = CC[7]; CC5 = CC[5]; CC3 = CC[3]; CC1 = CC[1]
306+
for i=1:length(I)
307+
P1[i] += CC14*A6[i] + CC12*A4[i] + CC10*A2[i]
308+
P2[i] += CC8*A6[i] + CC6*A4[i] + CC4*A2[i] + CC2*I[i]
309+
P3[i] += CC13*A6[i] + CC11*A4[i] + CC9*A2[i]
310+
P4[i] += CC7*A6[i] + CC5*A4[i] + CC3*A2[i] + CC1*I[i]
311+
end
312+
#U = A * (A6*P1 + P2)
313+
#V = A6*P3 + P4
314+
U = A * (BLAS.gemm!('N', 'N', one(T), A6, P1, one(T), P2))
315+
V = BLAS.gemm!('N', 'N', one(T), A6, P3, one(T), P4)
316+
317+
#X = (V-U)\(V+U)
318+
X = V + U
319+
LAPACK.gesv!(V-U, X)
320+
283321
if s > 0 # squaring to reverse dividing by power of 2
284322
for t in 1:si X *= X end
285323
end

0 commit comments

Comments
 (0)