Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3099,5 +3099,16 @@ function _keepat!(a::AbstractVector, m::AbstractVector{Bool})
end
end
deleteat!(a, j:lastindex(a))
end

## 1-d circshift ##
function circshift!(a::AbstractVector, shift::Integer)
n = length(a)
n == 0 && return
shift = mod(shift, n)
shift == 0 && return
reverse!(a, 1, shift)
reverse!(a, shift+1, length(a))
reverse!(a)
return a
end
12 changes: 12 additions & 0 deletions base/combinatorics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ function swapcols!(a::AbstractMatrix, i, j)
@inbounds a[k,i],a[k,j] = a[k,j],a[k,i]
end
end

# swap rows i and j of a, in-place
function swaprows!(a::AbstractMatrix, i, j)
i == j && return
rows = axes(a,1)
@boundscheck i in rows || throw(BoundsError(a, (:,i)))
@boundscheck j in rows || throw(BoundsError(a, (:,j)))
for k in axes(a,2)
@inbounds a[i,k],a[j,k] = a[j,k],a[i,k]
end
end

# like permute!! applied to each row of a, in-place in a (overwriting p).
function permutecols!!(a::AbstractMatrix, p::AbstractVector{<:Integer})
require_one_based_indexing(a, p)
Expand Down
90 changes: 89 additions & 1 deletion stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2189,7 +2189,7 @@ end
getindex(A::AbstractSparseMatrixCSC, I::Tuple{Integer,Integer}) = getindex(A, I[1], I[2])

function getindex(A::AbstractSparseMatrixCSC{T}, i0::Integer, i1::Integer) where T
if !(1 <= i0 <= size(A, 1) && 1 <= i1 <= size(A, 2)); throw(BoundsError()); end
@boundscheck checkbounds(A, i0, i1)
r1 = Int(getcolptr(A)[i1])
r2 = Int(getcolptr(A)[i1+1]-1)
(r1 > r2) && return zero(T)
Expand Down Expand Up @@ -3840,3 +3840,91 @@ end

circshift!(O::AbstractSparseMatrixCSC, X::AbstractSparseMatrixCSC, (r,)::Base.DimsInteger{1}) = circshift!(O, X, (r,0))
circshift!(O::AbstractSparseMatrixCSC, X::AbstractSparseMatrixCSC, r::Real) = circshift!(O, X, (Integer(r),0))

## swaprows! / swapcols!
macro swap(a, b)
esc(:(($a, $b) = ($b, $a)))
end

function Base.swapcols!(A::AbstractSparseMatrixCSC, i, j)
i == j && return

# For simplicitly, let i denote the smaller of the two columns
j < i && @swap(i, j)

colptr = getcolptr(A)
irow = colptr[i]:(colptr[i+1]-1)
jrow = colptr[j]:(colptr[j+1]-1)

function rangeexchange!(arr, irow, jrow)
if length(irow) == length(jrow)
for (a, b) in zip(irow, jrow)
@inbounds @swap(arr[i], arr[j])
end
return
end
# This is similar to the triple-reverse tricks for
# circshift!, except that we have three ranges here,
# so it ends up being 4 reverse calls (but still
# 2 overall reversals for the memory range). Like
# circshift!, there's also a cycle chasing algorithm
# with optimal memory complexity, but the performance
# tradeoffs against this implementation are non-trivial,
# so let's just do this simple thing for now.
# See https://github.com/JuliaLang/julia/pull/42676 for
# discussion of circshift!-like algorithms.
reverse!(@view arr[irow])
reverse!(@view arr[jrow])
reverse!(@view arr[(last(irow)+1):(first(jrow)-1)])
reverse!(@view arr[first(irow):last(jrow)])
end
rangeexchange!(rowvals(A), irow, jrow)
rangeexchange!(nonzeros(A), irow, jrow)

if length(irow) != length(jrow)
@inbounds colptr[i+1:j] .+= length(jrow) - length(irow)
end
return nothing
end

function Base.swaprows!(A::AbstractSparseMatrixCSC, i, j)
# For simplicitly, let i denote the smaller of the two rows
j < i && @swap(i, j)

rows = rowvals(A)
vals = nonzeros(A)
for col = 1:size(A, 2)
rr = nzrange(A, col)
iidx = searchsortedfirst(@view(rows[rr]), i)
has_i = iidx <= length(rr) && rows[rr[iidx]] == i

jrange = has_i ? (iidx:last(rr)) : rr
jidx = searchsortedlast(@view(rows[jrange]), j)
has_j = jidx != 0 && rows[jrange[jidx]] == j

if !has_j && !has_i
# Has neither row - nothing to do
continue
elseif has_i && has_j
# This column had both i and j rows - swap them
@swap(vals[rr[iidx]], vals[jrange[jidx]])
elseif has_i
# Update the rowval and then rotate both nonzeros
# and the remaining rowvals into the correct place
rows[rr[iidx]] = j
jidx == 0 && continue
rotate_range = rr[iidx]:jrange[jidx]
circshift!(@view(vals[rotate_range]), -1)
circshift!(@view(rows[rotate_range]), -1)
else
# Same as i, but in the opposite direction
@assert has_j
rows[jrange[jidx]] = i
iidx > length(rr) && continue
rotate_range = rr[iidx]:jrange[jidx]
circshift!(@view(vals[rotate_range]), 1)
circshift!(@view(rows[rotate_range]), 1)
end
end
return nothing
end
18 changes: 2 additions & 16 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2085,18 +2085,6 @@ function fill!(A::Union{SparseVector, AbstractSparseMatrixCSC}, x)
return A
end



# in-place swaps (dense) blocks start:split and split+1:fin in col
function _swap!(col::AbstractVector, start::Integer, fin::Integer, split::Integer)
split == fin && return
reverse!(col, start, split)
reverse!(col, split + 1, fin)
reverse!(col, start, fin)
return
end


# in-place shifts a sparse subvector by r. Used also by sparsematrix.jl
function subvector_shifter!(R::AbstractVector, V::AbstractVector, start::Integer, fin::Integer, m::Integer, r::Integer)
split = fin
Expand All @@ -2110,16 +2098,14 @@ function subvector_shifter!(R::AbstractVector, V::AbstractVector, start::Integer
end
end
# ...but rowval should be sorted within columns
_swap!(R, start, fin, split)
_swap!(V, start, fin, split)
circshift!(@view(R[start:fin]), split-start+1)
circshift!(@view(V[start:fin]), split-start+1)
end


function circshift!(O::SparseVector, X::SparseVector, (r,)::Base.DimsInteger{1})
copy!(O, X)
subvector_shifter!(nonzeroinds(O), nonzeros(O), 1, length(nonzeroinds(O)), length(O), mod(r, length(X)))
return O
end


circshift!(O::SparseVector, X::SparseVector, r::Real,) = circshift!(O, X, (Integer(r),))
22 changes: 22 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3294,4 +3294,26 @@ end
@test eval(Meta.parse(repr(m))) == m
end

using Base: swaprows!, swapcols!
@testset "swaprows!, swapcols!" begin
S = sparse(
[ 0 0 0 0 0 0
0 -1 1 1 0 0
0 0 0 1 1 0
0 0 1 1 1 -1])

for (f!, i, j) in
((swaprows!, 1, 2), # Test swapping rows where one row is fully sparse
(swaprows!, 2, 3), # Test swapping rows of unequal length
(swaprows!, 2, 4), # Test swapping non-adjacent rows
(swapcols!, 1, 2), # Test swapping columns where one column is fully sparse
(swapcols!, 2, 3), # Test swapping coulms of unequal length
(swapcols!, 2, 4)) # Test swapping non-adjacent columns
Scopy = copy(S)
Sdense = Array(S)
f!(Scopy, i, j); f!(Sdense, i, j)
@test Scopy == Sdense
end
end

end # module