Skip to content

Commit d738142

Browse files
committed
Specialize LinearAlgebra.BLAS.dot for strided vectors of floats.
Fixes #37767.
1 parent 67b9d4d commit d738142

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

stdlib/LinearAlgebra/src/blas.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,53 @@ for (fname, elty) in ((:cblas_zdotu_sub,:ComplexF64),
434434
end
435435
end
436436

437+
@inline function _dot_length_check(x,y)
438+
n = length(x)
439+
if n != length(y)
440+
throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))"))
441+
end
442+
n
443+
end
444+
445+
for (elty, f) in ((Float32, :dot), (Float64, :dot),
446+
(ComplexF32, :dotc), (ComplexF64, :dotc),
447+
(ComplexF32, :dotu), (ComplexF64, :dotu))
448+
@eval begin
449+
function $f(x::DenseArray{$elty}, y::DenseArray{$elty})
450+
n = _dot_length_check(x,y)
451+
$f(n, x, 1, y, 1)
452+
end
453+
454+
function $f(x::StridedVector{$elty}, y::DenseArray{$elty})
455+
n = _dot_length_check(x,y)
456+
xstride = stride(x,1)
457+
ystride = stride(y,1)
458+
# see subarray.jl / pointer(x)
459+
# (for x a SubArray of a Vector, pointer(x) points past the last element
460+
# in the case of negative stride)
461+
x_delta = xstride < 0 ? n : 1
462+
GC.@preserve x $f(n,pointer(x,x_delta),xstride,y,ystride)
463+
end
464+
465+
function $f(x::DenseArray{$elty}, y::StridedVector{$elty})
466+
n = _dot_length_check(x,y)
467+
xstride = stride(x,1)
468+
ystride = stride(y,1)
469+
y_delta = ystride < 0 ? n : 1
470+
GC.@preserve y $f(n,x,xstride,pointer(y,y_delta),ystride)
471+
end
472+
473+
function $f(x::StridedVector{$elty}, y::StridedVector{$elty})
474+
n = _dot_length_check(x,y)
475+
xstride = stride(x,1)
476+
ystride = stride(y,1)
477+
x_delta = xstride < 0 ? n : 1
478+
y_delta = ystride < 0 ? n : 1
479+
GC.@preserve x y $f(n,pointer(x,x_delta),xstride,pointer(y,y_delta),ystride)
480+
end
481+
end
482+
end
483+
437484
function dot(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasReal
438485
require_one_based_indexing(DX, DY)
439486
n = length(DX)

stdlib/LinearAlgebra/test/matmul.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,19 @@ end
205205
@test *(Asub, adjoint(Asub)) == *(Aref, adjoint(Aref))
206206
end
207207

208+
@testset "dot product of subarrays of vectors (floats, negative stride, issue #37767)" begin
209+
for T in (Float32, Float64, ComplexF32, ComplexF64)
210+
a = Vector{T}(3:2:7)
211+
b = Vector{T}(1:10)
212+
v = view(b,7:-2:3)
213+
@test dot(a,Vector(v)) 67.0
214+
@test dot(a,v) 67.0
215+
@test dot(v,a) 67.0
216+
@test dot(Vector(v),Vector(v)) 83.0
217+
@test dot(v,v) 83.0
218+
end
219+
end
220+
208221
@testset "Complex matrix x real MatOrVec etc (issue #29224)" for T1 in (Float32,Float64)
209222
for T2 in (Float32,Float64)
210223
for arg1_real in (true,false)

0 commit comments

Comments
 (0)