Skip to content

Commit d5514e3

Browse files
dkarraschmaleadt
authored andcommitted
Partial 1.10 enablement (JuliaGPU#330)
Co-authored-by: Tim Besard <[email protected]>
1 parent b07fd5f commit d5514e3

File tree

4 files changed

+36
-12
lines changed

4 files changed

+36
-12
lines changed

.buildkite/pipeline.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,16 @@ steps:
3131
- "1.6"
3232
- "1.7"
3333
- "1.8"
34-
- "1.9-nightly"
34+
- "1.9"
35+
- "1.10-nightly"
3536
- "nightly"
3637
adjustments:
3738
- with:
3839
julia: "nightly"
3940
soft_fail: true
41+
- with:
42+
julia: "1.10-nightly"
43+
soft_fail: true
4044

4145
# Special tests
4246
- group: ":eyes: Special"

lib/mkl/linalg.jl

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,14 @@ if VERSION < v"1.10.0-DEV.1365"
137137
end
138138

139139
# triangular
140+
if isdefined(LinearAlgebra, :generic_trimatmul!) # VERSION >= v"1.10-DEVXYZ"
141+
# multiplication
142+
LinearAlgebra.generic_trimatmul!(c::oneStridedVector{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, b::AbstractVector{T}) where {T<:onemklFloat} =
143+
trmv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, c === b ? c : copyto!(c, b))
144+
# division
145+
LinearAlgebra.generic_trimatdiv!(C::oneStridedVector{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::AbstractVector{T}) where {T<:onemklFloat} =
146+
trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B))
147+
else
140148
## direct multiplication/division
141149
for (t, uploc, isunitc) in ((:LowerTriangular, 'L', 'N'),
142150
(:UnitLowerTriangular, 'L', 'U'),
@@ -183,6 +191,7 @@ for (t, uploc, isunitc) in ((:LowerTriangular, 'U', 'N'),
183191
trsv!($uploc, 'C', $isunitc, parent(parent(A)), B)
184192
end
185193
end
194+
end # VERSION
186195

187196

188197
#
@@ -254,23 +263,34 @@ end
254263
end # VERSION
255264

256265
# triangular
266+
if isdefined(LinearAlgebra, :generic_trimatmul!) # VERSION >= v"1.10-DEVXYZ"
267+
LinearAlgebra.generic_trimatmul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
268+
trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
269+
LinearAlgebra.generic_mattrimul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
270+
trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
271+
LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
272+
trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
273+
LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
274+
trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
275+
else
257276
## direct multiplication/division
258277
for (t, uploc, isunitc) in ((:LowerTriangular, 'L', 'N'),
259278
(:UnitLowerTriangular, 'L', 'U'),
260279
(:UpperTriangular, 'U', 'N'),
261280
(:UnitUpperTriangular, 'U', 'U'))
262281
@eval begin
263282
# Multiplication
264-
LinearAlgebra.lmul!(A::$t{T,<:oneStridedVecOrMat},
265-
B::oneStridedVecOrMat{T}) where {T<:onemklFloat} =
266-
trmm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B, B)
267-
LinearAlgebra.rmul!(A::oneStridedVecOrMat{T},
268-
B::$t{T,<:oneStridedVecOrMat}) where {T<:onemklFloat} =
269-
trmm!('R', $uploc, 'N', $isunitc, one(T), parent(B), A, A)
283+
LinearAlgebra.lmul!(A::$t{T,<:oneStridedMatrix},
284+
B::oneStridedMatrix{T}) where {T<:onemklFloat} =
285+
trmm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B)
286+
LinearAlgebra.rmul!(A::oneStridedMatrix{T},
287+
B::$t{T,<:oneStridedMatrix}) where {T<:onemklFloat} =
288+
trmm!('R', $uploc, 'N', $isunitc, one(T), parent(B), A)
270289

271290
# Left division
272-
LinearAlgebra.ldiv!(A::$t{T,<:oneStridedVecOrMat},
273-
B::oneStridedVecOrMat{T}) where {T<:onemklFloat} =
291+
LinearAlgebra.ldiv!(A::$t{T,<:oneStridedMatrix},
292+
B::oneStridedMatrix{T}) where {T<:onemklFloat} =
274293
trsm!('L', $uploc, 'N', $isunitc, one(T), parent(A), B)
275294
end
276295
end
296+
end # VERSION

lib/mkl/wrappers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1246,7 +1246,7 @@ function trmm(side::Char,
12461246
alpha::Number,
12471247
A::oneStridedMatrix{T},
12481248
B::oneStridedMatrix{T}) where T
1249-
trmm!(side, uplo, transa, diag, alpha, A, B)
1249+
trmm!(side, uplo, transa, diag, alpha, A, copy(B))
12501250
end
12511251
function trsm(side::Char,
12521252
uplo::Char,

test/onemkl.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -655,9 +655,9 @@ end
655655
dA = oneArray(A)
656656
dB = oneArray(B)
657657
C = alpha*A*B
658-
oneMKL.trmm('L','U','N','N',alpha,dA,dB)
658+
dC = oneMKL.trmm('L','U','N','N',alpha,dA,dB)
659659
# move to host and compare
660-
h_C = Array(dB)
660+
h_C = Array(dC)
661661
@test C ≈ h_C
662662
end
663663

0 commit comments

Comments
 (0)