Skip to content

Commit 17a3c77

Browse files
authored
Add StepRange support for CartesianIndices (#37829)
1 parent 9405bf5 commit 17a3c77

File tree

5 files changed

+494
-78
lines changed

5 files changed

+494
-78
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ Standard library changes
111111
* `first` and `last` functions now accept an integer as second argument to get that many
112112
leading or trailing elements of any iterable ([#34868]).
113113
* `intersect` on `CartesianIndices` now returns `CartesianIndices` instead of `Vector{<:CartesianIndex}` ([#36643]).
114+
* `CartesianIndices` now supports step different from `1`. It can also be constructed from three
115+
`CartesianIndex`es `I`, `S`, `J` using `I:S:J`. `step` for `CartesianIndices` now returns a
116+
`CartesianIndex`. ([#37829])
114117
* `push!(c::Channel, v)` now returns channel `c`. Previously, it returned the pushed value `v` ([#34202]).
115118
* `RegexMatch` objects can now be probed for whether a named capture group exists within it through `haskey()` ([#36717]).
116119
* For consistency `haskey(r::RegexMatch, i::Integer)` has also been added and returns if the capture group for `i` exists ([#37300]).

base/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1112,7 +1112,7 @@ broadcasted(::typeof(+), j::CartesianIndex{N}, I::CartesianIndices{N}) where N =
11121112
broadcasted(::typeof(-), I::CartesianIndices{N}, j::CartesianIndex{N}) where N =
11131113
CartesianIndices(map((rng, offset)->rng .- offset, I.indices, Tuple(j)))
11141114
function broadcasted(::typeof(-), j::CartesianIndex{N}, I::CartesianIndices{N}) where N
1115-
diffrange(offset, rng) = range(offset-last(rng), length=length(rng))
1115+
diffrange(offset, rng) = range(offset-last(rng), length=length(rng), step=step(rng))
11161116
Iterators.reverse(CartesianIndices(map(diffrange, Tuple(j), I.indices)))
11171117
end
11181118

base/multidimensional.jl

Lines changed: 138 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ module IteratorsMD
1111
using .Base: IndexLinear, IndexCartesian, AbstractCartesianIndex, fill_to_length, tail,
1212
ReshapedArray, ReshapedArrayLF, OneTo
1313
using .Base.Iterators: Reverse, PartitionIterator
14+
using .Base: @propagate_inbounds
1415

1516
export CartesianIndex, CartesianIndices
1617

@@ -149,13 +150,13 @@ module IteratorsMD
149150
function Base.nextind(a::AbstractArray{<:Any,N}, i::CartesianIndex{N}) where {N}
150151
iter = CartesianIndices(axes(a))
151152
# might overflow
152-
I = inc(i.I, first(iter).I, last(iter).I)
153+
I = inc(i.I, iter.indices)
153154
return I
154155
end
155156
function Base.prevind(a::AbstractArray{<:Any,N}, i::CartesianIndex{N}) where {N}
156157
iter = CartesianIndices(axes(a))
157158
# might underflow
158-
I = dec(i.I, last(iter).I, first(iter).I)
159+
I = dec(i.I, iter.indices)
159160
return I
160161
end
161162

@@ -169,15 +170,15 @@ module IteratorsMD
169170
# Iteration
170171
"""
171172
CartesianIndices(sz::Dims) -> R
172-
CartesianIndices((istart:istop, jstart:jstop, ...)) -> R
173+
CartesianIndices((istart:[istep:]istop, jstart:[jstep:]jstop, ...)) -> R
173174
174175
Define a region `R` spanning a multidimensional rectangular range
175176
of integer indices. These are most commonly encountered in the
176177
context of iteration, where `for I in R ... end` will return
177178
[`CartesianIndex`](@ref) indices `I` equivalent to the nested loops
178179
179-
for j = jstart:jstop
180-
for i = istart:istop
180+
for j = jstart:jstep:jstop
181+
for i = istart:istep:istop
181182
...
182183
end
183184
end
@@ -190,6 +191,10 @@ module IteratorsMD
190191
As a convenience, constructing a `CartesianIndices` from an array makes a
191192
range of its indices.
192193
194+
!!! compat "Julia 1.6"
195+
The step range method `CartesianIndices((istart:istep:istop, jstart:[jstep:]jstop, ...))`
196+
requires at least Julia 1.6.
197+
193198
# Examples
194199
```jldoctest
195200
julia> foreach(println, CartesianIndices((2, 2, 2)))
@@ -222,6 +227,15 @@ module IteratorsMD
222227
223228
julia> cartesian[4]
224229
CartesianIndex(1, 2)
230+
231+
julia> cartesian = CartesianIndices((1:2:5, 1:2))
232+
3×2 CartesianIndices{2, Tuple{StepRange{Int64, Int64}, UnitRange{Int64}}}:
233+
CartesianIndex(1, 1) CartesianIndex(1, 2)
234+
CartesianIndex(3, 1) CartesianIndex(3, 2)
235+
CartesianIndex(5, 1) CartesianIndex(5, 2)
236+
237+
julia> cartesian[2, 2]
238+
CartesianIndex(3, 2)
225239
```
226240
227241
## Broadcasting
@@ -248,29 +262,37 @@ module IteratorsMD
248262
249263
For cartesian to linear index conversion, see [`LinearIndices`](@ref).
250264
"""
251-
struct CartesianIndices{N,R<:NTuple{N,AbstractUnitRange{Int}}} <: AbstractArray{CartesianIndex{N},N}
265+
struct CartesianIndices{N,R<:NTuple{N,OrdinalRange{Int, Int}}} <: AbstractArray{CartesianIndex{N},N}
252266
indices::R
253267
end
254268

255269
CartesianIndices(::Tuple{}) = CartesianIndices{0,typeof(())}(())
256-
CartesianIndices(inds::NTuple{N,AbstractUnitRange{<:Integer}}) where {N} =
257-
CartesianIndices(map(r->convert(AbstractUnitRange{Int}, r), inds))
270+
function CartesianIndices(inds::NTuple{N,OrdinalRange{<:Integer, <:Integer}}) where {N}
271+
indices = map(r->convert(OrdinalRange{Int, Int}, r), inds)
272+
CartesianIndices{N, typeof(indices)}(indices)
273+
end
258274

259275
CartesianIndices(index::CartesianIndex) = CartesianIndices(index.I)
260-
CartesianIndices(sz::NTuple{N,<:Integer}) where {N} = CartesianIndices(map(Base.OneTo, sz))
261-
CartesianIndices(inds::NTuple{N,Union{<:Integer,AbstractUnitRange{<:Integer}}}) where {N} =
262-
CartesianIndices(map(i->first(i):last(i), inds))
276+
CartesianIndices(inds::NTuple{N,Union{<:Integer,OrdinalRange{<:Integer}}}) where {N} =
277+
CartesianIndices(map(_convert2ind, inds))
263278

264279
CartesianIndices(A::AbstractArray) = CartesianIndices(axes(A))
265280

281+
_convert2ind(sz::Integer) = Base.OneTo(sz)
282+
_convert2ind(sz::AbstractUnitRange) = first(sz):last(sz)
283+
_convert2ind(sz::OrdinalRange) = first(sz):step(sz):last(sz)
284+
266285
"""
267-
(:)(I::CartesianIndex, J::CartesianIndex)
286+
(:)(start::CartesianIndex, [step::CartesianIndex], stop::CartesianIndex)
268287
269-
Construct [`CartesianIndices`](@ref) from two `CartesianIndex`.
288+
Construct [`CartesianIndices`](@ref) from two `CartesianIndex` and an optional step.
270289
271290
!!! compat "Julia 1.1"
272291
This method requires at least Julia 1.1.
273292
293+
!!! compat "Julia 1.6"
294+
The step range method start:step:stop requires at least Julia 1.6.
295+
274296
# Examples
275297
```jldoctest
276298
julia> I = CartesianIndex(2,1);
@@ -281,17 +303,26 @@ module IteratorsMD
281303
2×3 CartesianIndices{2, Tuple{UnitRange{Int64}, UnitRange{Int64}}}:
282304
CartesianIndex(2, 1) CartesianIndex(2, 2) CartesianIndex(2, 3)
283305
CartesianIndex(3, 1) CartesianIndex(3, 2) CartesianIndex(3, 3)
306+
307+
julia> I:CartesianIndex(1, 2):J
308+
2×2 CartesianIndices{2, Tuple{StepRange{Int64, Int64}, StepRange{Int64, Int64}}}:
309+
CartesianIndex(2, 1) CartesianIndex(2, 3)
310+
CartesianIndex(3, 1) CartesianIndex(3, 3)
284311
```
285312
"""
286313
(:)(I::CartesianIndex{N}, J::CartesianIndex{N}) where N =
287314
CartesianIndices(map((i,j) -> i:j, Tuple(I), Tuple(J)))
315+
(:)(I::CartesianIndex{N}, S::CartesianIndex{N}, J::CartesianIndex{N}) where N =
316+
CartesianIndices(map((i,s,j) -> i:s:j, Tuple(I), Tuple(S), Tuple(J)))
288317

289318
promote_rule(::Type{CartesianIndices{N,R1}}, ::Type{CartesianIndices{N,R2}}) where {N,R1,R2} =
290319
CartesianIndices{N,Base.indices_promote_type(R1,R2)}
291320

292321
convert(::Type{Tuple{}}, R::CartesianIndices{0}) = ()
293-
convert(::Type{NTuple{N,AbstractUnitRange{Int}}}, R::CartesianIndices{N}) where {N} =
294-
R.indices
322+
for RT in (OrdinalRange{Int, Int}, StepRange{Int, Int}, AbstractUnitRange{Int})
323+
@eval convert(::Type{NTuple{N,$RT}}, R::CartesianIndices{N}) where {N} =
324+
map(x->convert($RT, x), R.indices)
325+
end
295326
convert(::Type{NTuple{N,AbstractUnitRange}}, R::CartesianIndices{N}) where {N} =
296327
convert(NTuple{N,AbstractUnitRange{Int}}, R)
297328
convert(::Type{NTuple{N,UnitRange{Int}}}, R::CartesianIndices{N}) where {N} =
@@ -318,13 +349,8 @@ module IteratorsMD
318349
# AbstractArray implementation
319350
Base.axes(iter::CartesianIndices{N,R}) where {N,R} = map(Base.axes1, iter.indices)
320351
Base.IndexStyle(::Type{CartesianIndices{N,R}}) where {N,R} = IndexCartesian()
321-
@inline function Base.getindex(iter::CartesianIndices{N,<:NTuple{N,Base.OneTo}}, I::Vararg{Int, N}) where {N}
322-
@boundscheck checkbounds(iter, I...)
323-
CartesianIndex(I)
324-
end
325-
@inline function Base.getindex(iter::CartesianIndices{N,R}, I::Vararg{Int, N}) where {N,R}
326-
@boundscheck checkbounds(iter, I...)
327-
CartesianIndex(I .- first.(Base.axes1.(iter.indices)) .+ first.(iter.indices))
352+
@propagate_inbounds function Base.getindex(iter::CartesianIndices{N,R}, I::Vararg{Int, N}) where {N,R}
353+
CartesianIndex(getindex.(iter.indices, I))
328354
end
329355

330356
ndims(R::CartesianIndices) = ndims(typeof(R))
@@ -344,47 +370,65 @@ module IteratorsMD
344370
IteratorSize(::Type{<:CartesianIndices{N}}) where {N} = Base.HasShape{N}()
345371

346372
@inline function iterate(iter::CartesianIndices)
347-
iterfirst, iterlast = first(iter), last(iter)
348-
if any(map(>, iterfirst.I, iterlast.I))
373+
iterfirst = first(iter)
374+
if !all(map(in, iterfirst.I, iter.indices))
349375
return nothing
350376
end
351377
iterfirst, iterfirst
352378
end
353379
@inline function iterate(iter::CartesianIndices, state)
354-
valid, I = __inc(state.I, first(iter).I, last(iter).I)
380+
valid, I = __inc(state.I, iter.indices)
355381
valid || return nothing
356382
return CartesianIndex(I...), CartesianIndex(I...)
357383
end
358384

359385
# increment & carry
360-
@inline function inc(state, start, stop)
361-
_, I = __inc(state, start, stop)
386+
@inline function inc(state, indices)
387+
_, I = __inc(state, indices)
362388
return CartesianIndex(I...)
363389
end
364390

365-
# increment post check to avoid integer overflow
366-
@inline __inc(::Tuple{}, ::Tuple{}, ::Tuple{}) = false, ()
367-
@inline function __inc(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int})
368-
valid = state[1] < stop[1]
369-
return valid, (state[1]+1,)
391+
# Unlike ordinary ranges, CartesianIndices continues the iteration in the next column when the
392+
# current column is consumed. The implementation is written recursively to achieve this.
393+
# `iterate` returns `Union{Nothing, Tuple}`, we explicitly pass a `valid` flag to eliminate
394+
# the type instability inside the core `__inc` logic, and this gives better runtime performance.
395+
__inc(::Tuple{}, ::Tuple{}) = false, ()
396+
@inline function __inc(state::Tuple{Int}, indices::Tuple{<:OrdinalRange})
397+
rng = indices[1]
398+
I = state[1] + step(rng)
399+
valid = __is_valid_range(I, rng) && state[1] != last(rng)
400+
return valid, (I, )
401+
end
402+
@inline function __inc(state, indices)
403+
rng = indices[1]
404+
I = state[1] + step(rng)
405+
if __is_valid_range(I, rng) && state[1] != last(rng)
406+
return true, (I, tail(state)...)
407+
end
408+
valid, I = __inc(tail(state), tail(indices))
409+
return valid, (first(rng), I...)
370410
end
371411

372-
@inline function __inc(state, start, stop)
373-
if state[1] < stop[1]
374-
return true, (state[1]+1, tail(state)...)
412+
@inline __is_valid_range(I, rng::AbstractUnitRange) = I in rng
413+
@inline function __is_valid_range(I, rng::OrdinalRange)
414+
if step(rng) > 0
415+
lo, hi = first(rng), last(rng)
416+
else
417+
lo, hi = last(rng), first(rng)
375418
end
376-
valid, I = __inc(tail(state), tail(start), tail(stop))
377-
return valid, (start[1], I...)
419+
lo <= I <= hi
378420
end
379421

380422
# 0-d cartesian ranges are special-cased to iterate once and only once
381423
iterate(iter::CartesianIndices{0}, done=false) = done ? nothing : (CartesianIndex(), true)
382424

383-
size(iter::CartesianIndices) = map(dimlength, first(iter).I, last(iter).I)
384-
dimlength(start, stop) = stop-start+1
425+
size(iter::CartesianIndices) = map(length, iter.indices)
385426

386427
length(iter::CartesianIndices) = prod(size(iter))
387428

429+
# make CartesianIndices a multidimensional range
430+
Base.step(iter::CartesianIndices) = CartesianIndex(map(step, iter.indices))
431+
388432
first(iter::CartesianIndices) = CartesianIndex(map(first, iter.indices))
389433
last(iter::CartesianIndices) = CartesianIndex(map(last, iter.indices))
390434

@@ -395,11 +439,8 @@ module IteratorsMD
395439
@inline to_indices(A, inds, I::Tuple{CartesianIndices{0},Vararg{Any}}) =
396440
(first(I), to_indices(A, inds, tail(I))...)
397441

398-
@inline function in(i::CartesianIndex{N}, r::CartesianIndices{N}) where {N}
399-
_in(true, i.I, first(r).I, last(r).I)
400-
end
401-
_in(b, ::Tuple{}, ::Tuple{}, ::Tuple{}) = b
402-
@inline _in(b, i, start, stop) = _in(b & (start[1] <= i[1] <= stop[1]), tail(i), tail(start), tail(stop))
442+
@inline in(i::CartesianIndex, r::CartesianIndices) = false
443+
@inline in(i::CartesianIndex{N}, r::CartesianIndices{N}) where {N} = all(map(in, i.I, r.indices))
403444

404445
simd_outer_range(iter::CartesianIndices{0}) = iter
405446
function simd_outer_range(iter::CartesianIndices)
@@ -410,8 +451,8 @@ module IteratorsMD
410451
simd_inner_length(iter::CartesianIndices, I::CartesianIndex) = Base.length(iter.indices[1])
411452

412453
simd_index(iter::CartesianIndices{0}, ::CartesianIndex, I1::Int) = first(iter)
413-
@inline function simd_index(iter::CartesianIndices, Ilast::CartesianIndex, I1::Int)
414-
CartesianIndex((I1+first(iter.indices[1]), Ilast.I...))
454+
@propagate_inbounds function simd_index(iter::CartesianIndices, Ilast::CartesianIndex, I1::Int)
455+
CartesianIndex(getindex(iter.indices[1], I1+first(Base.axes1(iter.indices[1]))), Ilast.I...)
415456
end
416457

417458
# Split out the first N elements of a tuple
@@ -440,44 +481,79 @@ module IteratorsMD
440481

441482
# reversed CartesianIndices iteration
442483

484+
Base.reverse(iter::CartesianIndices) = CartesianIndices(reverse.(iter.indices))
485+
443486
@inline function iterate(r::Reverse{<:CartesianIndices})
444-
iterfirst, iterlast = last(r.itr), first(r.itr)
445-
if any(map(<, iterfirst.I, iterlast.I))
487+
iterfirst = last(r.itr)
488+
if !all(map(in, iterfirst.I, r.itr.indices))
446489
return nothing
447490
end
448491
iterfirst, iterfirst
449492
end
450493
@inline function iterate(r::Reverse{<:CartesianIndices}, state)
451-
valid, I = __dec(state.I, last(r.itr).I, first(r.itr).I)
494+
valid, I = __dec(state.I, r.itr.indices)
452495
valid || return nothing
453496
return CartesianIndex(I...), CartesianIndex(I...)
454497
end
455498

456499
# decrement & carry
457-
@inline function dec(state, start, stop)
458-
_, I = __dec(state, start, stop)
500+
@inline function dec(state, indices)
501+
_, I = __dec(state, indices)
459502
return CartesianIndex(I...)
460503
end
461504

462505
# decrement post check to avoid integer overflow
463-
@inline __dec(::Tuple{}, ::Tuple{}, ::Tuple{}) = false, ()
464-
@inline function __dec(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int})
465-
valid = state[1] > stop[1]
466-
return valid, (state[1]-1,)
506+
@inline __dec(::Tuple{}, ::Tuple{}) = false, ()
507+
@inline function __dec(state::Tuple{Int}, indices::Tuple{<:OrdinalRange})
508+
rng = indices[1]
509+
I = state[1] - step(rng)
510+
valid = __is_valid_range(I, rng) && state[1] != first(rng)
511+
return valid, (I,)
467512
end
468513

469-
@inline function __dec(state, start, stop)
470-
if state[1] > stop[1]
471-
return true, (state[1]-1, tail(state)...)
514+
@inline function __dec(state, indices)
515+
rng = indices[1]
516+
I = state[1] - step(rng)
517+
if __is_valid_range(I, rng) && state[1] != first(rng)
518+
return true, (I, tail(state)...)
472519
end
473-
valid, I = __dec(tail(state), tail(start), tail(stop))
474-
return valid, (start[1], I...)
520+
valid, I = __dec(tail(state), tail(indices))
521+
return valid, (last(rng), I...)
475522
end
476523

477524
# 0-d cartesian ranges are special-cased to iterate once and only once
478525
iterate(iter::Reverse{<:CartesianIndices{0}}, state=false) = state ? nothing : (CartesianIndex(), true)
479526

480-
Base.LinearIndices(inds::CartesianIndices{N,R}) where {N,R} = LinearIndices{N,R}(inds.indices)
527+
function Base.LinearIndices(inds::CartesianIndices{N,R}) where {N,R<:NTuple{N, AbstractUnitRange}}
528+
LinearIndices{N,R}(inds.indices)
529+
end
530+
function Base.LinearIndices(inds::CartesianIndices)
531+
indices = inds.indices
532+
if all(x->step(x)==1, indices)
533+
indices = map(rng->first(rng):last(rng), indices)
534+
LinearIndices{length(indices), typeof(indices)}(indices)
535+
else
536+
# Given the fact that StepRange 1:2:4 === 1:2:3, we lost the original size information
537+
# and thus cannot calculate the correct linear indices when the steps are not 1.
538+
throw(ArgumentError("LinearIndices for $(typeof(inds)) with non-1 step size is not yet supported."))
539+
end
540+
end
541+
542+
# This is currently needed because converting to LinearIndices is only available when steps are
543+
# all 1
544+
# NOTE: this is only a temporary patch and could be possibly removed when StepRange support to
545+
# LinearIndices is done
546+
function Base.collect(inds::CartesianIndices{N, R}) where {N,R<:NTuple{N, AbstractUnitRange}}
547+
Base._collect_indices(axes(inds), inds)
548+
end
549+
function Base.collect(inds::CartesianIndices)
550+
dest = Array{eltype(inds), ndims(inds)}(undef, size(inds))
551+
i = 0
552+
@inbounds for a in inds
553+
dest[i+=1] = a
554+
end
555+
dest
556+
end
481557

482558
# array operations
483559
Base.intersect(a::CartesianIndices{N}, b::CartesianIndices{N}) where N =
@@ -501,7 +577,7 @@ module IteratorsMD
501577
end
502578
@inline function iterate(iter::CartesianPartition, (state, n))
503579
n >= length(iter) && return nothing
504-
I = IteratorsMD.inc(state.I, first(iter.parent.parent).I, last(iter.parent.parent).I)
580+
I = IteratorsMD.inc(state.I, iter.parent.parent.indices)
505581
return I, (I, n+1)
506582
end
507583

0 commit comments

Comments
 (0)