Skip to content

Commit 2b06d09

Browse files
committed
Specialize LinearAlgebra.BLAS.dot for strided vectors of floats.
Fixes #37767.
1 parent c79309b commit 2b06d09

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

stdlib/LinearAlgebra/src/blas.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,50 @@ for (fname, elty) in ((:cblas_zdotu_sub,:ComplexF64),
344344
end
345345
end
346346

347+
@inline function _dot_length_check(x,y)
348+
n = length(x)
349+
if n != length(y)
350+
throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))"))
351+
end
352+
n
353+
end
354+
355+
for (elty, f) in ((Float32, :dot), (Float64, :dot),
356+
(ComplexF32, :dotc), (ComplexF64, :dotc),
357+
(ComplexF32, :dotu), (ComplexF64, :dotu))
358+
@eval begin
359+
function $f(x::DenseArray{$elty}, y::DenseArray{$elty})
360+
n = _dot_length_check(x,y)
361+
$f(n, x, 1, y, 1)
362+
end
363+
364+
function $f(x::StridedVector{$elty}, y::DenseArray{$elty})
365+
n = _dot_length_check(x,y)
366+
xstride = stride(x,1)
367+
ystride = stride(y,1)
368+
x_delta = xstride < 0 ? n : 1
369+
GC.@preserve x $f(n,pointer(x,x_delta),xstride,y,ystride)
370+
end
371+
372+
function $f(x::DenseArray{$elty}, y::StridedVector{$elty})
373+
n = _dot_length_check(x,y)
374+
xstride = stride(x,1)
375+
ystride = stride(y,1)
376+
y_delta = ystride < 0 ? n : 1
377+
GC.@preserve y $f(n,x,xstride,pointer(y,y_delta),ystride)
378+
end
379+
380+
function $f(x::StridedVector{$elty}, y::StridedVector{$elty})
381+
n = _dot_length_check(x,y)
382+
xstride = stride(x,1)
383+
ystride = stride(y,1)
384+
x_delta = xstride < 0 ? n : 1
385+
y_delta = ystride < 0 ? n : 1
386+
GC.@preserve x y $f(n,pointer(x,x_delta),xstride,pointer(y,y_delta),ystride)
387+
end
388+
end
389+
end
390+
347391
function dot(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasReal
348392
require_one_based_indexing(DX, DY)
349393
n = length(DX)

stdlib/LinearAlgebra/test/matmul.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,19 @@ end
205205
@test *(Asub, adjoint(Asub)) == *(Aref, adjoint(Aref))
206206
end
207207

208+
@testset "dot product of subarrays of vectors (floats, negative stride, issue #37767)" begin
209+
for T in (Float32, Float64, ComplexF32, ComplexF64)
210+
a = Vector{T}(3:2:7)
211+
b = Vector{T}(1:10)
212+
v = view(b,7:-2:3)
213+
@test dot(a,Vector(v)) 67.0
214+
@test dot(a,v) 67.0
215+
@test dot(v,a) 67.0
216+
@test dot(Vector(v),Vector(v)) 83.0
217+
@test dot(v,v) 83.0
218+
end
219+
end
220+
208221
@testset "Complex matrix x real MatOrVec etc (issue #29224)" for T1 in (Float32,Float64)
209222
for T2 in (Float32,Float64)
210223
for arg1_real in (true,false)

0 commit comments

Comments
 (0)