Skip to content

Commit 79bac77

Browse files
refactor: remove potential for StackOverflowError in AbstractVectorOfArray indexing
1 parent 9790424 commit 79bac77

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

src/vector_of_array.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -230,15 +230,15 @@ end
230230
@deprecate Base.getindex(A::AbstractDiffEqArray, i::Int) Base.getindex(A, :, i) false
231231

232232
__parameterless_type(T) = Base.typename(T).wrapper
233-
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray{T, N},
233+
Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N},
234234
::NotSymbolic, I::Colon...) where {T, N}
235235
@assert length(I) == ndims(A.u[1]) + 1
236236
vecs = vec.(A.u)
237237
return Adapt.adapt(__parameterless_type(T),
238238
reshape(reduce(hcat, vecs), size(A.u[1])..., length(A.u)))
239239
end
240240

241-
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray{T, N},
241+
Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N},
242242
::NotSymbolic, I::AbstractArray{Bool},
243243
J::Colon...) where {T, N}
244244
@assert length(J) == ndims(A.u[1]) + 1 - ndims(I)
@@ -247,34 +247,34 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray{T, N},
247247
end
248248

249249
# Need two of each methods to avoid ambiguities
250-
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int)
250+
Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int)
251251
A.u[I]
252252
end
253253

254-
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, ::NotSymbolic, I::Union{Int,AbstractArray{Int},AbstractArray{Bool},Colon}...)
254+
Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray, ::NotSymbolic, I::Union{Int,AbstractArray{Int},AbstractArray{Bool},Colon}...)
255255
if last(I) isa Int
256256
A.u[last(I)][Base.front(I)...]
257257
else
258258
stack(getindex.(A.u[last(I)], tuple.(Base.front(I))...))
259259
end
260260
end
261-
Base.@propagate_inbounds function Base.getindex(VA::AbstractVectorOfArray, ::NotSymbolic, ii::CartesianIndex)
261+
Base.@propagate_inbounds function _getindex(VA::AbstractVectorOfArray, ::NotSymbolic, ii::CartesianIndex)
262262
ti = Tuple(ii)
263263
i = last(ti)
264264
jj = CartesianIndex(Base.front(ti))
265265
return VA.u[i][jj]
266266
end
267267

268-
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Union{AbstractArray{Int},AbstractArray{Bool}})
268+
Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Union{AbstractArray{Int},AbstractArray{Bool}})
269269
VectorOfArray(A.u[I])
270270
end
271271

272-
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::NotSymbolic, ::Colon, I::Union{AbstractArray{Int},AbstractArray{Bool}})
272+
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymbolic, ::Colon, I::Union{AbstractArray{Int},AbstractArray{Bool}})
273273
DiffEqArray(A.u[I], A.t[I], parameter_values(A), symbolic_container(A))
274274
end
275275

276276
# Symbolic Indexing Methods
277-
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym)
277+
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym)
278278
if is_independent_variable(A, sym)
279279
return A.t
280280
elseif is_variable(A, sym)
@@ -296,7 +296,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::Scalar
296296
end
297297
end
298298

299-
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym, args...)
299+
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym, args...)
300300
if is_independent_variable(A, sym)
301301
return A.t[args...]
302302
elseif is_variable(A, sym)
@@ -319,11 +319,11 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::Scalar
319319
end
320320

321321

322-
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ArraySymbolic, sym, args...)
322+
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ArraySymbolic, sym, args...)
323323
return getindex(A, collect(sym), args...)
324324
end
325325

326-
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
326+
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
327327
if all(x -> is_parameter(A, x), sym)
328328
Base.depwarn("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.", :parameter_getindex)
329329
return getp(A, sym)(A)
@@ -332,7 +332,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::Scalar
332332
end
333333
end
334334

335-
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray}, args...)
335+
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray}, args...)
336336
return reduce(vcat, map(s -> A[s, args...]', sym))
337337
end
338338

@@ -341,9 +341,9 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg,
341341
elsymtype = symbolic_type(eltype(_arg))
342342

343343
if symtype != NotSymbolic()
344-
return Base.getindex(A, symtype, _arg, args...)
344+
return _getindex(A, symtype, _arg, args...)
345345
else
346-
return Base.getindex(A, elsymtype, _arg, args...)
346+
return _getindex(A, elsymtype, _arg, args...)
347347
end
348348
end
349349

0 commit comments

Comments
 (0)