Skip to content

Commit 7b70fc8

Browse files
committed
Add getindex functionality
1 parent 7363170 commit 7b70fc8

File tree

5 files changed

+234
-32
lines changed

5 files changed

+234
-32
lines changed

src/LinearMaps.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ include("kronecker.jl") # Kronecker product of linear maps
248248
include("fillmap.jl") # linear maps representing constantly filled matrices
249249
include("conversion.jl") # conversion of linear maps to matrices
250250
include("show.jl") # show methods for LinearMap objects
251+
include("getindex.jl") # getindex functionality
251252

252253
"""
253254
LinearMap(A::LinearMap; kwargs...)::WrappedMap

src/getindex.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# module GetIndex
2+
3+
# using ..LinearMaps: LinearMap, AdjointMap, TransposeMap, FillMap, LinearCombination,
4+
# ScaledMap, UniformScalingMap, WrappedMap
5+
6+
# required in Base.to_indices for [:]-indexing
7+
Base.eachindex(::IndexLinear, A::LinearMap) = (Base.@_inline_meta; Base.OneTo(length(A)))
8+
Base.lastindex(A::LinearMap) = (Base.@_inline_meta; last(eachindex(IndexLinear(), A)))
9+
Base.firstindex(A::LinearMap) = (Base.@_inline_meta; first(eachindex(IndexLinear(), A)))
10+
11+
function Base.checkbounds(A::LinearMap, i, j)
12+
Base.@_inline_meta
13+
Base.checkbounds_indices(Bool, axes(A), (i, j)) || throw(BoundsError(A, (i, j)))
14+
nothing
15+
end
16+
# Linear indexing is explicitly allowed when there is only one (non-cartesian) index
17+
function Base.checkbounds(A::LinearMap, i)
18+
Base.@_inline_meta
19+
Base.checkindex(Bool, Base.OneTo(length(A)), i) || throw(BoundsError(A, i))
20+
nothing
21+
end
22+
23+
# main entry point
24+
Base.@propagate_inbounds function Base.getindex(A::LinearMap, I...)
25+
# TODO: introduce some sort of switch?
26+
Base.@_inline_meta
27+
@boundscheck checkbounds(A, I...)
28+
@inbounds _getindex(A, Base.to_indices(A, I)...)
29+
end
30+
# quick pass forward
31+
Base.@propagate_inbounds Base.getindex(A::ScaledMap, I...) = A.λ .* getindex(A.lmap, I...)
32+
Base.@propagate_inbounds Base.getindex(A::WrappedMap, I...) = A.lmap[I...]
33+
Base.@propagate_inbounds Base.getindex(A::WrappedMap, i::Integer) = A.lmap[i]
34+
Base.@propagate_inbounds Base.getindex(A::WrappedMap, i::Integer, j::Integer) = A.lmap[i,j]
35+
36+
########################
37+
# linear indexing
38+
########################
39+
Base.@propagate_inbounds function _getindex(A::LinearMap, i::Integer)
40+
Base.@_inline_meta
41+
i1, i2 = Base._ind2sub(axes(A), i)
42+
return _getindex(A, i1, i2)
43+
end
44+
Base.@propagate_inbounds _getindex(A::LinearMap, I::AbstractVector{<:Integer}) =
45+
[_getindex(A, i) for i in I]
46+
_getindex(A::LinearMap, ::Base.Slice) = vec(Matrix(A))
47+
48+
########################
49+
# Cartesian indexing
50+
########################
51+
Base.@propagate_inbounds _getindex(A::LinearMap, i::Integer, j::Integer) =
52+
_getindex(A, Base.Slice(axes(A)[1]), j)[i]
53+
Base.@propagate_inbounds function _getindex(A::LinearMap, i::Integer, J::AbstractVector{<:Integer})
54+
try
55+
return (basevec(A, i)'A)[J]
56+
catch
57+
x = zeros(eltype(A), size(A, 2))
58+
y = similar(x, eltype(A), size(A, 1))
59+
r = similar(x, eltype(A), length(J))
60+
for (ind, j) in enumerate(J)
61+
x[j] = one(eltype(A))
62+
_unsafe_mul!(y, A, x)
63+
r[ind] = y[i]
64+
x[j] = zero(eltype(A))
65+
end
66+
return r
67+
end
68+
end
69+
function _getindex(A::LinearMap, i::Integer, J::Base.Slice)
70+
try
71+
return vec(basevec(A, i)'A)
72+
catch
73+
return vec(_getindex(A, i:i, J))
74+
end
75+
end
76+
Base.@propagate_inbounds _getindex(A::LinearMap, I::AbstractVector{<:Integer}, j::Integer) =
77+
_getindex(A, Base.Slice(axes(A)[1]), j)[I] # = A[:,j][I] w/o bounds check
78+
_getindex(A::LinearMap, ::Base.Slice, j::Integer) = A*basevec(A, j)
79+
Base.@propagate_inbounds function _getindex(A::LinearMap, Is::Vararg{AbstractVector{<:Integer},2})
80+
shape = Base.index_shape(Is...)
81+
dest = zeros(eltype(A), shape)
82+
I, J = Is
83+
for (ind, ij) in zip(eachindex(dest), Iterators.product(I, J))
84+
i, j = ij
85+
dest[ind] = _getindex(A, i, j)
86+
end
87+
return dest
88+
end
89+
Base.@propagate_inbounds function _getindex(A::LinearMap, I::AbstractVector{<:Integer}, ::Base.Slice)
90+
x = zeros(eltype(A), size(A, 2))
91+
y = similar(x, eltype(A), size(A, 1))
92+
r = similar(x, eltype(A), (length(I), size(A, 2)))
93+
@views for j in axes(A)[2]
94+
x[j] = one(eltype(A))
95+
_unsafe_mul!(y, A, x)
96+
r[:,j] .= y[I]
97+
x[j] = zero(eltype(A))
98+
end
99+
return r
100+
end
101+
Base.@propagate_inbounds function _getindex(A::LinearMap, ::Base.Slice, J::AbstractVector{<:Integer})
102+
x = zeros(eltype(A), size(A, 2))
103+
y = similar(x, eltype(A), (size(A, 1), length(J)))
104+
for (i, j) in enumerate(J)
105+
x[j] = one(eltype(A))
106+
_unsafe_mul!(selectdim(y, 2, i), A, x)
107+
x[j] = zero(eltype(A))
108+
end
109+
return y
110+
end
111+
_getindex(A::LinearMap, ::Base.Slice, ::Base.Slice) = Matrix(A)
112+
113+
# specialized methods
114+
_getindex(A::FillMap, ::Integer, ::Integer) = A.λ
115+
Base.@propagate_inbounds _getindex(A::LinearCombination, i::Integer, j::Integer) =
116+
sum(a -> A.maps[a][i, j], eachindex(A.maps))
117+
Base.@propagate_inbounds _getindex(A::AdjointMap, i::Integer, j::Integer) =
118+
adjoint(A.lmap[j, i])
119+
Base.@propagate_inbounds _getindex(A::TransposeMap, i::Integer, j::Integer) =
120+
transpose(A.lmap[j, i])
121+
_getindex(A::UniformScalingMap, i::Integer, j::Integer) = ifelse(i == j, A.λ, zero(eltype(A)))
122+
123+
# helpers
124+
function basevec(A, i::Integer)
125+
x = zeros(eltype(A), size(A, 2))
126+
@inbounds x[i] = one(eltype(A))
127+
return x
128+
end
129+
130+
nogetindex_error() = error("indexing not allowed for LinearMaps; consider setting `LinearMaps.allowgetindex = true`")
131+
132+
# end # module

test/getindex.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
using BenchmarkTools, LinearAlgebra, LinearMaps, Test
2+
# using LinearMaps.GetIndex
3+
4+
function test_getindex(A::LinearMap, M::AbstractMatrix)
5+
@assert size(A) == size(M)
6+
@test all((A[i,j] == M[i,j] for i in axes(A, 1), j in axes(A, 2)))
7+
@test all((A[i] == M[i] for i in 1:length(A)))
8+
@test A[1,1] == M[1,1]
9+
@test A[:] == M[:]
10+
@test A[1,:] == M[1,:]
11+
@test A[:,1] == M[:,1]
12+
@test A[1:4,:] == M[1:4,:]
13+
@test A[:,1:4] == M[:,1:4]
14+
@test A[1,1:3] == M[1,1:3]
15+
@test A[1:3,1] == M[1:3,1]
16+
@test A[2:end,1] == M[2:end,1]
17+
@test A[1:2,1:3] == M[1:2,1:3]
18+
@test A[[2,1],1:3] == M[[2,1],1:3]
19+
@test A[:,:] == M
20+
@test A[7] == M[7]
21+
@test_throws BoundsError A[firstindex(A)-1]
22+
@test_throws BoundsError A[lastindex(A)+1]
23+
@test_throws BoundsError A[6,1]
24+
@test_throws BoundsError A[1,6]
25+
@test_throws BoundsError A[2,1:6]
26+
@test_throws BoundsError A[1:6,2]
27+
return true
28+
end
29+
30+
@testset "getindex" begin
31+
A = rand(5,5)
32+
L = LinearMap(A)
33+
# @btime getindex($A, i) setup=(i = rand(1:9));
34+
# @btime getindex($L, i) setup=(i = rand(1:9));
35+
# @btime (getindex($A, i, j)) setup=(i = rand(1:3); j = rand(1:3));
36+
# @btime (getindex($L, i, j)) setup=(i = rand(1:3); j = rand(1:3));
37+
38+
struct TwoMap <: LinearMaps.LinearMap{Float64} end
39+
Base.size(::TwoMap) = (5,5)
40+
LinearMaps._getindex(::TwoMap, i::Integer, j::Integer) = 2.0
41+
LinearMaps._unsafe_mul!(y::AbstractVector, ::TwoMap, x::AbstractVector) = fill!(y, 2.0*sum(x))
42+
43+
@test test_getindex(TwoMap(), fill(2.0, 5, 5))
44+
Base.adjoint(A::TwoMap) = A
45+
@test test_getindex(TwoMap(), fill(2.0, 5, 5))
46+
47+
MA = rand(ComplexF64, 5, 5)
48+
for FA in (
49+
LinearMap{ComplexF64}((y, x) -> mul!(y, MA, x), (y, x) -> mul!(y, MA', x), 5, 5),
50+
LinearMap{ComplexF64}((y, x) -> mul!(y, MA, x), 5, 5),
51+
)
52+
@test test_getindex(FA, MA)
53+
@test test_getindex(3FA, 3MA)
54+
@test test_getindex(FA + FA, 2MA)
55+
if !isnothing(FA.fc)
56+
@test test_getindex(transpose(FA), transpose(MA))
57+
@test test_getindex(transpose(3FA), transpose(3MA))
58+
@test test_getindex(3transpose(FA), transpose(3MA))
59+
@test test_getindex(adjoint(FA), adjoint(MA))
60+
@test test_getindex(adjoint(3FA), adjoint(3MA))
61+
@test test_getindex(3adjoint(FA), adjoint(3MA))
62+
end
63+
end
64+
65+
@test test_getindex(FillMap(0.5, (5, 5)), fill(0.5, (5, 5)))
66+
@test test_getindex(LinearMap(0.5I, 5), Matrix(0.5I, 5, 5))
67+
end

test/linearmaps.jl

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,20 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays, BenchmarkTools
3333
end
3434
end
3535

36-
# new type
37-
struct SimpleFunctionMap <: LinearMap{Float64}
38-
f::Function
39-
N::Int
40-
end
41-
struct SimpleComplexFunctionMap <: LinearMap{Complex{Float64}}
42-
f::Function
43-
N::Int
44-
end
45-
Base.size(A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}) = (A.N, A.N)
46-
Base.:(*)(A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}, v::AbstractVector) = A.f(v)
47-
LinearAlgebra.mul!(y::AbstractVector, A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}, x::AbstractVector) = copyto!(y, *(A, x))
48-
4936
@testset "new LinearMap type" begin
37+
# new type
38+
struct SimpleFunctionMap <: LinearMap{Float64}
39+
f::Function
40+
N::Int
41+
end
42+
struct SimpleComplexFunctionMap <: LinearMap{Complex{Float64}}
43+
f::Function
44+
N::Int
45+
end
46+
Base.size(A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}) = (A.N, A.N)
47+
Base.:(*)(A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}, v::AbstractVector) = A.f(v)
48+
LinearAlgebra.mul!(y::AbstractVector, A::Union{SimpleFunctionMap,SimpleComplexFunctionMap}, x::AbstractVector) = copyto!(y, *(A, x))
49+
5050
F = SimpleFunctionMap(cumsum, 10)
5151
FC = SimpleComplexFunctionMap(cumsum, 10)
5252
@test @inferred ndims(F) == 2
@@ -83,27 +83,27 @@ LinearAlgebra.mul!(y::AbstractVector, A::Union{SimpleFunctionMap,SimpleComplexFu
8383
@test Fs isa SparseMatrixCSC
8484
end
8585

86-
struct MyFillMap{T} <: LinearMaps.LinearMap{T}
87-
λ::T
88-
size::Dims{2}
89-
function MyFillMap::T, dims::Dims{2}) where {T}
90-
all(d -> d >= 0, dims) || throw(ArgumentError("dims of MyFillMap must be non-negative"))
91-
promote_type(T, typeof(λ)) == T || throw(InexactError())
92-
return new{T}(λ, dims)
86+
@testset "transpose of new LinearMap type" begin
87+
struct MyFillMap{T} <: LinearMaps.LinearMap{T}
88+
λ::T
89+
size::Dims{2}
90+
function MyFillMap::T, dims::Dims{2}) where {T}
91+
all(d -> d >= 0, dims) || throw(ArgumentError("dims of MyFillMap must be non-negative"))
92+
promote_type(T, typeof(λ)) == T || throw(InexactError())
93+
return new{T}(λ, dims)
94+
end
95+
end
96+
Base.size(A::MyFillMap) = A.size
97+
function LinearAlgebra.mul!(y::AbstractVecOrMat, A::MyFillMap, x::AbstractVector)
98+
LinearMaps.check_dim_mul(y, A, x)
99+
return fill!(y, iszero(A.λ) ? zero(eltype(y)) : A.λ*sum(x))
100+
end
101+
function LinearAlgebra.mul!(y::AbstractVecOrMat, transA::LinearMaps.TransposeMap{<:Any,<:MyFillMap}, x::AbstractVector)
102+
LinearMaps.check_dim_mul(y, transA, x)
103+
λ = transA.lmap.λ
104+
return fill!(y, iszero(λ) ? zero(eltype(y)) : transpose(λ)*sum(x))
93105
end
94-
end
95-
Base.size(A::MyFillMap) = A.size
96-
function LinearAlgebra.mul!(y::AbstractVecOrMat, A::MyFillMap, x::AbstractVector)
97-
LinearMaps.check_dim_mul(y, A, x)
98-
return fill!(y, iszero(A.λ) ? zero(eltype(y)) : A.λ*sum(x))
99-
end
100-
function LinearAlgebra.mul!(y::AbstractVecOrMat, transA::LinearMaps.TransposeMap{<:Any,<:MyFillMap}, x::AbstractVector)
101-
LinearMaps.check_dim_mul(y, transA, x)
102-
λ = transA.lmap.λ
103-
return fill!(y, iszero(λ) ? zero(eltype(y)) : transpose(λ)*sum(x))
104-
end
105106

106-
@testset "transpose of new LinearMap type" begin
107107
A = MyFillMap(5.0, (3, 3))
108108
x = ones(3)
109109
@test A * x == fill(15.0, 3)

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,5 @@ include("left.jl")
3434
include("fillmap.jl")
3535

3636
include("nontradaxes.jl")
37+
38+
include("getindex.jl")

0 commit comments

Comments
 (0)