Skip to content

Commit d6eb1ad

Browse files
Merge pull request #371 from AayushSabharwal/as/broadcast-index
fix: fix indexing using array symbolics, `Colon`
2 parents 39ae861 + 0879183 commit d6eb1ad

File tree

3 files changed

+60
-24
lines changed

3 files changed

+60
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ StaticArrays = "1.6"
5959
StaticArraysCore = "1.4"
6060
Statistics = "1.10"
6161
StructArrays = "0.6.11"
62-
SymbolicIndexingInterface = "0.3.19"
62+
SymbolicIndexingInterface = "0.3.20"
6363
Tables = "1.11"
6464
Test = "1"
6565
Tracker = "0.2.15"

src/vector_of_array.jl

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -351,49 +351,73 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymboli
351351
DiffEqArray(A.u[I], A.t[I], parameter_values(A), symbolic_container(A))
352352
end
353353

354+
struct ParameterIndexingError <: Exception
355+
sym
356+
end
357+
358+
function Base.showerror(io::IO, pie::ParameterIndexingError)
359+
print(io, "Indexing with parameters is deprecated. Use `getp(A, $(pie.sym))` for parameter indexing.")
360+
end
361+
354362
# Symbolic Indexing Methods
355-
for symtype in [ScalarSymbolic, ArraySymbolic]
356-
paramcheck = quote
357-
if is_parameter(A, sym) || (sym isa AbstractArray && symbolic_type(eltype(sym)) !== NotSymbolic() || sym isa Tuple) && all(x -> is_parameter(A, x), sym)
358-
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
359-
end
360-
end
361-
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym)
362-
$paramcheck
363-
getu(A, sym)(A)
363+
for (symtype, elsymtype, valtype, errcheck) in [
364+
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))),
365+
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))),
366+
(NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Union{<:Tuple, <:AbstractArray},
367+
:(all(x -> is_parameter(A, x), sym))),
368+
]
369+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
370+
::$elsymtype, sym::$valtype)
371+
if $errcheck
372+
throw(ParameterIndexingError(sym))
364373
end
365-
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg)
366-
$paramcheck
367-
getu(A, sym)(A, arg)
374+
getu(A, sym)(A)
375+
end
376+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
377+
::$elsymtype, sym::$valtype, arg)
378+
if $errcheck
379+
throw(ParameterIndexingError(sym))
368380
end
369-
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Union{AbstractArray{Int}, AbstractArray{Bool}})
370-
$paramcheck
371-
getu(A, sym).((A,), arg)
381+
getu(A, sym)(A, arg)
382+
end
383+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
384+
::$elsymtype, sym::$valtype, arg::Union{AbstractArray{Int}, AbstractArray{Bool}})
385+
if $errcheck
386+
throw(ParameterIndexingError(sym))
372387
end
373-
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype, sym, arg::Colon)
374-
$paramcheck
375-
getu(A, sym)(A)
388+
getu(A, sym).((A,), arg)
389+
end
390+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
391+
::$elsymtype, sym::$valtype, ::Colon)
392+
if $errcheck
393+
throw(ParameterIndexingError(sym))
376394
end
395+
getu(A, sym)(A)
396+
end
377397
end
378398

379399
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic,
380-
::SymbolicIndexingInterface.SolvedVariables, args...)
400+
::NotSymbolic, ::SymbolicIndexingInterface.SolvedVariables, args...)
381401
return getindex(A, variable_symbols(A), args...)
382402
end
383403

384404
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic,
385-
::SymbolicIndexingInterface.AllVariables, args...)
405+
::NotSymbolic, ::SymbolicIndexingInterface.AllVariables, args...)
386406
return getindex(A, all_variable_symbols(A), args...)
387407
end
388408

389409
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, args...)
390410
symtype = symbolic_type(_arg)
391411
elsymtype = symbolic_type(eltype(_arg))
392412

393-
if symtype != NotSymbolic()
394-
return _getindex(A, symtype, _arg, args...)
413+
if symtype == NotSymbolic() && elsymtype == NotSymbolic()
414+
if _arg isa Union{Tuple, AbstractArray} && any(x -> symbolic_type(x) != NotSymbolic(), _arg)
415+
_getindex(A, symtype, elsymtype, _arg, args...)
416+
else
417+
_getindex(A, symtype, _arg, args...)
418+
end
395419
else
396-
return _getindex(A, elsymtype, _arg, args...)
420+
_getindex(A, symtype, elsymtype, _arg, args...)
397421
end
398422
end
399423

test/downstream/symbol_indexing.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,15 @@ sol_ts = sol(ts)
6161
@assert sol_ts isa DiffEqArray
6262
test_tables_interface(sol_ts, [:timestamp, Symbol("x(t)"), Symbol("y(t)")],
6363
hcat(ts, Array(sol_ts)'))
64+
65+
# Array variables
66+
using LinearAlgebra
67+
sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0
68+
ps = @parameters p[1:3] = [1, 2, 3]
69+
eqs = [collect(D.(x) .~ x)
70+
D(y) ~ norm(collect(x)) * y - x[1]]
71+
@mtkbuild sys = ODESystem(eqs, t, sts, ps)
72+
prob = ODEProblem(sys, [], (0, 1.0))
73+
sol = solve(prob, Tsit5())
74+
@test sol[x .+ [y, 2y, 3y]] vcat.(getindex.((sol,), [x[1] + y, x[2] + 2y, x[3] + 3y])...)
75+
@test sol[x, :] sol[x]

0 commit comments

Comments
 (0)