Skip to content
Closed
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ New library functions
* `isunordered(x)` returns true if `x` is value that is normally unordered, such as `NaN` or `missing`.
* New macro `Base.@invokelatest f(args...; kwargs...)` provides a convenient way to call `Base.invokelatest(f, args...; kwargs...)` ([#37971])
* New macro `Base.@invoke f(arg1::T1, arg2::T2; kwargs...)` provides an easier syntax to call `invoke(f, Tuple{T1,T2}, arg1, arg2; kwargs...)` ([#38438])
* New function `Base.eachstoredindex` returns an index iterator over the structural non-zero indices of a container ([#40103])

New library features
--------------------
Expand Down
10 changes: 10 additions & 0 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,16 @@ function _all_match_first(f::F, inds, A, B...) where F<:Function
end
_all_match_first(f::F, inds) where F<:Function = true

"""
eachstoredindex(A)

Returns an iterable over the indices of `A` where the values are structurally non-zero.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to define this as being tied to structural nonzeros? Or simply stored values? We could simultaneously introduce a unstoredvalue to give room for other sorts of structures.

In any case, I like this and think it's very needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. What would be expected for a FillArrays.Ones? In the case of stored values, it would return an empty iterator? I am not sure we can make this useful in general, except for constructors?

It falls back to `eachindex(A)` and can be redefined by array types.
`eachstoredindex(A)` is not guaranteed to return the same shape as `eachindex(A)`.
"""
eachstoredindex(A) = eachindex(A)


# keys with an IndexStyle
keys(s::IndexStyle, A::AbstractArray, B::AbstractArray...) = eachindex(s, A, B...)

Expand Down
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ end

parent(D::Diagonal) = D.diag

Base.eachstoredindex(D::Diagonal) = diagind(D)

ishermitian(D::Diagonal{<:Real}) = true
ishermitian(D::Diagonal{<:Number}) = isreal(D.diag)
ishermitian(D::Diagonal) = all(ishermitian, D.diag)
Expand Down
29 changes: 28 additions & 1 deletion stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1515,7 +1515,7 @@ function findnz(S::AbstractSparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
V = Vector{Tv}(undef, numnz)

count = 1
@inbounds for col = 1 : size(S, 2), k = getcolptr(S)[col] : (getcolptr(S)[col+1]-1)
@inbounds for col in 1:size(S, 2), k in nzrange(S, col)
I[count] = rowvals(S)[k]
J[count] = col
V[count] = nonzeros(S)[k]
Expand All @@ -1525,6 +1525,33 @@ function findnz(S::AbstractSparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
return (I, J, V)
end

function Base.eachstoredindex(S::AbstractSparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
numnz = nnz(S)
indices = Vector{CartesianIndex{2}}(undef, numnz)
count = 1
@inbounds for col in 1:size(S, 2), k in nzrange(S, col)
indices[count] = CartesianIndex(rowvals(S)[k], col)
count += 1
end
return indices
end

function Base.eachstoredindex(S::Union{Adjoint{T, ST}, Transpose{T, ST}}) where {T, ST <: AbstractSparseMatrixCSC{T}}
P = parent(S)
numnz = nnz(P)
indices = Vector{CartesianIndex{2}}(undef, numnz)
count = 1
@inbounds for col = 1 : size(P, 2), k = getcolptr(P)[col] : (getcolptr(P)[col+1]-1)
indices[count] = CartesianIndex(col, rowvals(P)[k])
count += 1
end
return indices
end

function Base.eachstoredindex(S::Union{Symmetric{T, ST}, LinearAlgebra.AbstractTriangular{T, ST}, UpperHessenberg{T, ST}, Hermitian{T, ST}}) where {T, ST <: AbstractSparseMatrixCSC{T}}
return Base.eachstoredindex(parent(S))
end

function _sparse_findnextnz(m::AbstractSparseMatrixCSC, ij::CartesianIndex{2})
row, col = Tuple(ij)
col > size(m, 2) && return nothing
Expand Down
3 changes: 3 additions & 0 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ function nonzeroinds(x::SparseColumnView)
@inbounds y = view(rowvals(A), nzrange(A, colidx))
return y
end

Base.eachstoredindex(x::SparseVector) = getfield(x, :nzind)

nonzeroinds(x::SparseVectorView) = nonzeroinds(parent(x))

rowvals(x::SparseVectorUnion) = nonzeroinds(x)
Expand Down
10 changes: 10 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1833,7 +1833,9 @@ end
# test that stored zeros are still stored zeros in the diagonal
S = sparse([1,3],[1,3],[0.0,0.0]); V = diag(S)
@test nonzeroinds(V) == [1,3]
@test Base.eachstoredindex(V) == [1,3]
@test nonzeros(V) == [0.0,0.0]
@test V[Base.eachstoredindex(V)] == nonzeros(V)
end

@testset "expandptr" begin
Expand Down Expand Up @@ -2875,6 +2877,14 @@ end
@test SparseMatrixCSC(at(wr(A))) == Matrix(at(wr(B)))
end

@testset "eachstoredindex($(wr))" for wr in (UpperTriangular, LowerTriangular,
UnitUpperTriangular, UnitLowerTriangular,
Hermitian, (Hermitian, :L), Symmetric, (Symmetric, :L), Transpose, Adjoint)
S = dowrap(wr, A)
sum(S[Base.eachstoredindex(S)]) == sum(S)
sum(S[Base.eachstoredindex(S)]) == sum(Matrix(S))
end

@test sparse([1,2,3,4,5]') == SparseMatrixCSC([1 2 3 4 5])
@test sparse(UpperTriangular(A')) == UpperTriangular(B')
@test sparse(Adjoint(UpperTriangular(A'))) == Adjoint(UpperTriangular(B'))
Expand Down
2 changes: 2 additions & 0 deletions stdlib/SparseArrays/test/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ x1_full[SparseArrays.nonzeroinds(spv_x1)] = nonzeros(spv_x1)
@test count(!iszero, x) == 3
@test nnz(x) == 3
@test SparseArrays.nonzeroinds(x) == [2, 5, 6]
@test Base.eachstoredindex(x) == [2, 5, 6]
@test nonzeros(x) == [1.25, -0.75, 3.5]
@test nonzeros(x) == x[Base.eachstoredindex(x)]
@test count(SparseVector(8, [2, 5, 6], [true,false,true])) == 2
end

Expand Down