Skip to content

Commit 0a42082

Browse files
Add slicing functionality (#165)
Co-authored-by: Jeff Fessler <[email protected]>
1 parent 12c13fb commit 0a42082

File tree

5 files changed

+241
-1
lines changed

5 files changed

+241
-1
lines changed

docs/src/history.md

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Version history
22

33
## What's new in v3.7
4+
45
* `mul!(M::AbstractMatrix, A::LinearMap, s::Number, a, b)` methods are provided, mimicking
56
similar methods in `Base.LinearAlgebra`. This version allows for the memory efficient
67
implementation of in-place addition and conversion of a `LinearMap` to `Matrix`.
@@ -9,7 +10,6 @@
910
conversion, construction, and inplace addition will benefit from this supplied efficient
1011
implementation. If no specialisation is supplied, a generic fallback is used that is based
1112
on feeding the canonical basis of unit vectors to the linear map.
12-
1313
* A new map type called `EmbeddedMap` is introduced. It is a wrapper of a "small" `LinearMap`
1414
(or a suitably converted `AbstractVecOrMat`) embedded into a "larger" zero map. Hence,
1515
the "small" map acts only on a subset of the coordinates and maps to another subset of
@@ -24,6 +24,31 @@
2424
* `LinearMap(A::MapOrVecOrMat, dims::Dims{2}; offset::Dims{2})`, where the keyword
2525
argument `offset` determines the dimension of a virtual upper-left zero block, to which
2626
`A` gets (virtually) diagonally appended.
27+
* An often requested new feature has been added: slicing (i.e., non-scalar indexing) any
28+
`LinearMap` object via `Base.getindex` overloads. Note, however, that only rather
29+
efficient complete slicing operations are implemented: `A[:,j]`, `A[:,J]`, and `A[:,:]`,
30+
where `j::Integer` and `J` is either of type `AbstractVector{<:Integer>}` or an
31+
`AbstractVector{Bool}` of appropriate length ("logical slicing"). Partial slicing
32+
operations such as `A[I,j]` and `A[I,J]` where `I` is as `J` above are disallowed.
33+
34+
Scalar indexing `A[i::Integer,j::Integer]` as well as other indexing operations that fall
35+
back on scalar indexing such as logical indexing by some `AbstractMatrix{Bool}`, or
36+
indexing by vectors of (linear or Cartesian) indices are not supported; as an exception,
37+
`getindex` calls on wrapped `AbstractVecOrMat`s is forwarded to corresponding `getindex`
38+
methods from `Base` and therefore allow any type of usual indexing/slicing.
39+
If scalar indexing is really required, consider using `A[:,j][i]` which is as efficient
40+
as a reasonable generic implementation for `LinearMap`s can be.
41+
42+
Furthermore, (predominantly) horizontal slicing operations require the adjoint operation
43+
of the `LinearMap` type to be defined, or will fail otherwise. Important note:
44+
`LinearMap` objects are meant to model objects that act on vectors efficiently, and are
45+
in general *not* backed up by storage-like types like `Array`s. Therefore, slicing of
46+
`LinearMap`s is potentially slow, and it may require the (repeated) allocation of
47+
standard unit vectors. As a consequence, generic algorithms relying heavily on indexing
48+
and/or slicing are likely to run much slower than expected for `AbstractArray`s. To avoid
49+
repeated indexing operations which may involve redundant computations, it is strongly
50+
recommended to consider `convert`ing `LinearMap`-typed objects to `Matrix` or
51+
`SparseMatrixCSC` first, if memory permits.
2752

2853
## What's new in v3.6
2954

src/LinearMaps.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ include("fillmap.jl") # linear maps representing constantly filled matrices
345345
include("embeddedmap.jl") # embedded linear maps
346346
include("conversion.jl") # conversion of linear maps to matrices
347347
include("show.jl") # show methods for LinearMap objects
348+
include("getindex.jl") # getindex functionality
348349

349350
"""
350351
LinearMap(A::LinearMap; kwargs...)::WrappedMap

src/getindex.jl

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
const Indexer = AbstractVector{<:Integer}
2+
3+
Base.IndexStyle(::LinearMap) = IndexCartesian()
4+
# required in Base.to_indices for [:]-indexing (only size check)
5+
Base.eachindex(::IndexLinear, A::LinearMap) = Base.OneTo(length(A))
6+
# Base.lastindex(A::LinearMap) = last(eachindex(IndexLinear(), A))
7+
# Base.firstindex(A::LinearMap) = first(eachindex(IndexLinear(), A))
8+
9+
function Base.checkbounds(A::LinearMap, i, j)
10+
Base.checkbounds_indices(Bool, axes(A), (i, j)) || throw(BoundsError(A, (i, j)))
11+
nothing
12+
end
13+
function Base.checkbounds(A::LinearMap, i)
14+
Base.checkindex(Bool, Base.OneTo(length(A)), i) || throw(BoundsError(A, i))
15+
nothing
16+
end
17+
# checkbounds in indexing via CartesianIndex
18+
Base.checkbounds(A::LinearMap, i::Union{CartesianIndex{2}, AbstractArray{CartesianIndex{2}}}) =
19+
Base.checkbounds_indices(Bool, axes(A), (i,)) || throw(BoundsError(A, i))
20+
Base.checkbounds(A::LinearMap, I::AbstractMatrix{Bool}) =
21+
axes(A) == axes(I) || throw(BoundsError(A, I))
22+
23+
# main entry point
24+
function Base.getindex(A::LinearMap, I...)
25+
@boundscheck checkbounds(A, I...)
26+
_getindex(A, Base.to_indices(A, I)...)
27+
end
28+
# quick pass forward
29+
Base.@propagate_inbounds Base.getindex(A::ScaledMap, I...) = A.λ * A.lmap[I...]
30+
Base.@propagate_inbounds Base.getindex(A::WrappedMap, I...) = A.lmap[I...]
31+
32+
########################
33+
# linear indexing
34+
########################
35+
_getindex(A::LinearMap, _) = error("linear indexing of LinearMaps is not supported")
36+
37+
########################
38+
# Cartesian indexing (partial slicing is not supported)
39+
########################
40+
_getindex(A::LinearMap, i::Integer, j::Integer) =
41+
error("scalar indexing of LinearMaps is not supported, consider using A[:,j][i] instead")
42+
_getindex(A::LinearMap, I::Indexer, j::Integer) =
43+
error("partial vertical slicing of LinearMaps is not supported, consider using A[:,j][I] instead")
44+
_getindex(A::LinearMap, i::Integer, J::Indexer) =
45+
error("partial horizontal slicing of LinearMaps is not supported, consider using A[i,:][J] instead")
46+
_getindex(A::LinearMap, I::Indexer, J::Indexer) =
47+
error("partial two-dimensional slicing of LinearMaps is not supported, consider using A[:,J][I] or A[I,:][J] instead")
48+
49+
_getindex(A::LinearMap, ::Base.Slice, j::Integer) = A*unitvec(A, 2, j)
50+
function _getindex(A::LinearMap, i::Integer, J::Base.Slice)
51+
try
52+
# requires adjoint action to be defined
53+
return vec(unitvec(A, 1, i)'A)
54+
catch
55+
error("horizontal slicing A[$i,:] requires the adjoint of $(typeof(A)) to be defined")
56+
end
57+
end
58+
_getindex(A::LinearMap, ::Base.Slice, ::Base.Slice) = convert(AbstractMatrix, A)
59+
_getindex(A::LinearMap, I::Base.Slice, J::Indexer) = __getindex(A, I, J)
60+
_getindex(A::LinearMap, I::Indexer, J::Base.Slice) = __getindex(A, I, J)
61+
function __getindex(A, I, J)
62+
dest = zeros(eltype(A), Base.index_shape(I, J))
63+
# choose whichever requires less map applications
64+
if length(I) <= length(J)
65+
try
66+
# requires adjoint action to be defined
67+
_fillbyrows!(dest, A, I, J)
68+
catch
69+
error("wide slicing A[I,J] with length(I) <= length(J) requires the adjoint of $(typeof(A)) to be defined")
70+
end
71+
else
72+
_fillbycols!(dest, A, I, J)
73+
end
74+
return dest
75+
end
76+
77+
# helpers
78+
function unitvec(A, dim, i)
79+
x = zeros(eltype(A), size(A, dim))
80+
@inbounds x[i] = one(eltype(A))
81+
return x
82+
end
83+
84+
function _fillbyrows!(dest, A, I, J)
85+
x = zeros(eltype(A), size(A, 1))
86+
temp = similar(x, eltype(A), size(A, 2))
87+
@views @inbounds for (di, i) in zip(eachrow(dest), I)
88+
x[i] = one(eltype(A))
89+
_unsafe_mul!(temp, A', x)
90+
di .= adjoint.(temp[J])
91+
x[i] = zero(eltype(A))
92+
end
93+
return dest
94+
end
95+
function _fillbycols!(dest, A, I::Indexer, J)
96+
x = zeros(eltype(A), size(A, 2))
97+
temp = similar(x, eltype(A), size(A, 1))
98+
@inbounds for (ind, j) in enumerate(J)
99+
x[j] = one(eltype(A))
100+
_unsafe_mul!(temp, A, x)
101+
dest[:,ind] .= temp[I]
102+
x[j] = zero(eltype(A))
103+
end
104+
return dest
105+
end
106+
function _fillbycols!(dest, A, ::Base.Slice, J)
107+
x = zeros(eltype(A), size(A, 2))
108+
@inbounds for (ind, j) in enumerate(J)
109+
x[j] = one(eltype(A))
110+
_unsafe_mul!(selectdim(dest, 2, ind), A, x)
111+
x[j] = zero(eltype(A))
112+
end
113+
return dest
114+
end

test/getindex.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
using LinearAlgebra, LinearMaps, Test
2+
using LinearMaps: VecOrMatMap, ScaledMap
3+
# using BenchmarkTools
4+
5+
function test_getindex(A::LinearMap, M::AbstractMatrix)
6+
@assert size(A) == size(M)
7+
mask = rand(Bool, size(A))
8+
imask = rand(Bool, size(A, 1))
9+
jmask = rand(Bool, size(A, 2))
10+
@test A[1,:] == M[1,:]
11+
@test A[:,1] == M[:,1]
12+
@test A[1:lastindex(A,1)-2,:] == M[1:lastindex(A,1)-2,:]
13+
@test A[:,1:4] == M[:,1:4]
14+
@test A[[2,1],:] == M[[2,1],:]
15+
@test A[:,[2,1]] == M[:,[2,1]]
16+
@test A[:,:] == M
17+
@test (lastindex(A, 1), lastindex(A, 2)) == size(A)
18+
if A isa VecOrMatMap || A isa ScaledMap{<:Any,<:Any,<:VecOrMatMap}
19+
@test A[:] == M[:]
20+
@test A[1,1] == M[1,1]
21+
@test A[1,1:3] == M[1,1:3]
22+
@test A[1:3,1] == M[1:3,1]
23+
@test A[2:end,1] == M[2:end,1]
24+
@test A[1:2,1:3] == M[1:2,1:3]
25+
@test A[[2,1],1:3] == M[[2,1],1:3]
26+
@test A[7] == M[7]
27+
@test A[3:7] == M[3:7]
28+
@test A[mask] == M[mask]
29+
@test A[findall(mask)] == M[findall(mask)]
30+
@test A[CartesianIndex(1,1)] == M[CartesianIndex(1,1)]
31+
@test A[imask, 1] == M[imask, 1]
32+
@test A[1, jmask] == M[1, jmask]
33+
@test A[imask, jmask] == M[imask, jmask]
34+
else
35+
@test_throws ErrorException A[:]
36+
@test_throws ErrorException A[1,1]
37+
@test_throws ErrorException A[1,1:3]
38+
@test_throws ErrorException A[1:3,1]
39+
@test_throws ErrorException A[2:end,1]
40+
@test_throws ErrorException A[1:2,1:3]
41+
@test_throws ErrorException A[[2,1],1:3]
42+
@test_throws ErrorException A[7]
43+
@test_throws ErrorException A[3:7]
44+
@test_throws ErrorException A[mask]
45+
@test_throws ErrorException A[findall(mask)]
46+
@test_throws ErrorException A[CartesianIndex(1,1)]
47+
@test_throws ErrorException A[imask, 1]
48+
@test_throws ErrorException A[1, jmask]
49+
@test_throws ErrorException A[imask, jmask]
50+
end
51+
@test_throws BoundsError A[lastindex(A,1)+1,1]
52+
@test_throws BoundsError A[1,lastindex(A,2)+1]
53+
@test_throws BoundsError A[2,1:lastindex(A,2)+1]
54+
@test_throws BoundsError A[1:lastindex(A,1)+1,2]
55+
@test_throws BoundsError A[ones(Bool, 2, 2)]
56+
@test_throws BoundsError A[[true, true], 1]
57+
@test_throws BoundsError A[1, [true, true]]
58+
return nothing
59+
end
60+
61+
@testset "getindex" begin
62+
M = rand(4,6)
63+
A = LinearMap(M)
64+
test_getindex(A, M)
65+
test_getindex(2A, 2M)
66+
# @btime getindex($M, i) setup=(i = rand(1:24));
67+
# @btime getindex($A, i) setup=(i = rand(1:24));
68+
# @btime (getindex($M, i, j)) setup=(i = rand(1:4); j = rand(1:6));
69+
# @btime (getindex($A, i, j)) setup=(i = rand(1:4); j = rand(1:6));
70+
71+
struct TwoMap <: LinearMaps.LinearMap{Float64} end
72+
Base.size(::TwoMap) = (5,5)
73+
LinearMaps._unsafe_mul!(y::AbstractVector, ::TwoMap, x::AbstractVector) = fill!(y, 2.0*sum(x))
74+
T = TwoMap()
75+
@test_throws ErrorException T[1,:]
76+
77+
Base.transpose(A::TwoMap) = A
78+
test_getindex(TwoMap(), fill(2.0, size(T)))
79+
80+
MA = rand(ComplexF64, 5, 5)
81+
FA = LinearMap{ComplexF64}((y, x) -> mul!(y, MA, x), (y, x) -> mul!(y, MA', x), 5, 5)
82+
F = LinearMap{ComplexF64}(x -> MA*x, y -> MA'y, 5, 5)
83+
test_getindex(FA, MA)
84+
test_getindex([FA FA], [MA MA])
85+
test_getindex([FA; FA], [MA; MA])
86+
test_getindex(F, MA)
87+
test_getindex(3FA, 3MA)
88+
test_getindex(FA + FA, 2MA)
89+
test_getindex(transpose(FA), transpose(MA))
90+
test_getindex(transpose(3FA), transpose(3MA))
91+
test_getindex(3transpose(FA), transpose(3MA))
92+
test_getindex(adjoint(FA), adjoint(MA))
93+
test_getindex(adjoint(3FA), adjoint(3MA))
94+
test_getindex(3adjoint(FA), adjoint(3MA))
95+
96+
test_getindex(FillMap(0.5, (5, 5)), fill(0.5, (5, 5)))
97+
test_getindex(LinearMap(0.5I, 5), Matrix(0.5I, 5, 5))
98+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,5 @@ include("fillmap.jl")
3535
include("nontradaxes.jl")
3636

3737
include("embeddedmap.jl")
38+
39+
include("getindex.jl")

0 commit comments

Comments
 (0)