Skip to content

Commit 8f71f30

Browse files
feat: reduce reliance on metadata in structural_simplify
1 parent ce959f7 commit 8f71f30

File tree

1 file changed

+114
-84
lines changed

1 file changed

+114
-84
lines changed

src/systems/systemstructure.jl

Lines changed: 114 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -254,44 +254,60 @@ function Base.push!(ev::EquationsView, eq)
254254
push!(ev.ts.extra_eqs, eq)
255255
end
256256

257-
function is_time_dependent_parameter(p, iv)
258-
return iv !== nothing && isparameter(p) && iscall(p) &&
259-
(operation(p) === getindex && is_time_dependent_parameter(arguments(p)[1], iv) ||
257+
function is_time_dependent_parameter(p, allps, iv)
258+
return iv !== nothing && p in allps && iscall(p) &&
259+
(operation(p) === getindex && is_time_dependent_parameter(arguments(p)[1], allps, iv) ||
260260
(args = arguments(p); length(args)) == 1 && isequal(only(args), iv))
261261
end
262262

263+
function symbolic_contains(var, set)
264+
var in set || symbolic_type(var) == ArraySymbolic() && Symbolics.shape(var) != Symbolics.Unknown() && all(i -> var[i] in set, eachindex(var))
265+
end
266+
263267
function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
268+
# flatten system
264269
sys = flatten(sys)
265270
ivs = independent_variables(sys)
266271
iv = length(ivs) == 1 ? ivs[1] : nothing
267-
# scalarize array equations, without scalarizing arguments to registered functions
268-
eqs = flatten_equations(copy(equations(sys)))
272+
# flatten array equations
273+
eqs = flatten_equations(equations(sys))
269274
neqs = length(eqs)
270-
dervaridxs = OrderedSet{Int}()
271-
var2idx = Dict{Any, Int}()
272-
symbolic_incidence = []
273-
fullvars = []
274275
param_derivative_map = Dict{BasicSymbolic, Any}()
275-
var_counter = Ref(0)
276-
var_types = VariableType[]
277-
addvar! = let fullvars = fullvars, var_counter = var_counter, var_types = var_types
276+
# * Scalarize unknowns
277+
dvs = Set{BasicSymbolic}()
278+
fullvars = BasicSymbolic[]
279+
for x in unknowns(sys)
280+
push!(dvs, x)
281+
xx = Symbolics.scalarize(x)
282+
if xx isa AbstractArray
283+
union!(dvs, xx)
284+
append!(fullvars, xx)
285+
else
286+
push!(fullvars, xx)
287+
end
288+
end
289+
ps = Set{BasicSymbolic}()
290+
for x in parameters(sys)
291+
push!(ps, x)
292+
xx = Symbolics.scalarize(x)
293+
xx isa AbstractArray && union!(dvs, x)
294+
end
295+
var2idx = Dict{BasicSymbolic, Int}(v => k for (k, v) in enumerate(fullvars))
296+
addvar! = let fullvars = fullvars, dvs = dvs, var2idx = var2idx
278297
var -> get!(var2idx, var) do
298+
push!(dvs, var)
279299
push!(fullvars, var)
280-
push!(var_types, getvariabletype(var))
281-
var_counter[] += 1
300+
return length(fullvars)
282301
end
283302
end
284303

285-
vars = OrderedSet()
286-
varsvec = []
304+
# build symbolic incidence
305+
symbolic_incidence = Vector{BasicSymbolic}[]
306+
varsbuf = Set()
287307
eqs_to_retain = trues(length(eqs))
288-
for (i, eq′) in enumerate(eqs)
289-
if eq′.lhs isa Connection
290-
check ? error("$(nameof(sys)) has unexpanded `connect` statements") :
291-
return nothing
292-
end
308+
for (i, eq) in enumerate(eqs)
293309
if iscall(eq′.lhs) && (op = operation(eq′.lhs)) isa Differential &&
294-
isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq′.lhs)), iv)
310+
isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq′.lhs)), ps, iv)
295311
# parameter derivatives are opted out by specifying `D(p) ~ missing`, but
296312
# we want to store `nothing` in the map because that means `fast_substitute`
297313
# will ignore the rule. We will this identify the presence of `eq′.lhs` in
@@ -301,80 +317,71 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
301317
# change the equation if the RHS is `missing` so the rest of this loop works
302318
eq′ = eq′.lhs ~ coalesce(eq′.rhs, 0.0)
303319
end
304-
if _iszero(eq′.lhs)
305-
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
306-
eq = eq′
307-
else
308-
lhs = quick_cancel ? quick_cancel_expr(eq′.lhs) : eq′.lhs
309-
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
310-
eq = 0 ~ rhs - lhs
320+
rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs
321+
if !_iszero(eq.lhs)
322+
lhs = quick_cancel ? quick_cancel_expr(eq.lhs) : eq.lhs
323+
eq = eqs[i] = 0 ~ rhs - lhs
311324
end
312-
vars!(vars, eq.rhs, op = Symbolics.Operator)
313-
for v in vars
314-
_var, _ = var_from_nested_derivative(v)
315-
any(isequal(_var), ivs) && continue
316-
if isparameter(_var) ||
317-
(iscall(_var) && isparameter(operation(_var)) || isconstant(_var))
318-
if is_time_dependent_parameter(_var, iv) &&
319-
!haskey(param_derivative_map, Differential(iv)(_var))
325+
empty!(varsbuf)
326+
vars!(varsbuf, eq; op = Symbolics.Operator)
327+
incidence = Set{BasicSymbolic}()
328+
for v in varsbuf
329+
# FIXME: This check still needs to rely on metadata
330+
isconstant(v) && continue
331+
vtype = getvariabletype(v)
332+
# additionally track brownians in fullvars
333+
# TODO: When uniting system types, track brownians in their own field
334+
if vtype == BROWNIAN
335+
i = addvar!(v)
336+
push!(incidence, v)
337+
end
338+
339+
if symbolic_contains(v, ps)
340+
if is_time_dependent_parameter(v, ps, iv) && !haskey(param_derivative_map, Differential(iv)(_var))
320341
# Parameter derivatives default to zero - they stay constant
321342
# between callbacks
322343
param_derivative_map[Differential(iv)(_var)] = 0.0
323344
end
324345
continue
325346
end
326-
v = scalarize(v)
327-
if v isa AbstractArray
328-
append!(varsvec, v)
329-
else
330-
push!(varsvec, v)
331-
end
332-
end
333-
isalgeq = true
334-
unknownvars = []
335-
for var in varsvec
336-
ModelingToolkit.isdelay(var, iv) && continue
337-
set_incidence = true
338-
@label ANOTHER_VAR
339-
_var, _ = var_from_nested_derivative(var)
340-
any(isequal(_var), ivs) && continue
341-
if isparameter(_var) ||
342-
(iscall(_var) && isparameter(operation(_var)) || isconstant(_var))
343-
continue
344-
end
345-
varidx = addvar!(var)
346-
set_incidence && push!(unknownvars, var)
347-
348-
dvar = var
349-
idx = varidx
350-
while isdifferential(dvar)
351-
if !(idx in dervaridxs)
352-
push!(dervaridxs, idx)
347+
348+
if !symbolic_contains(v, dvs)
349+
isvalid = iscall(v) && operation(v) isa Union{Shift, Sample, Hold}
350+
v′ = v
351+
while !isvalid && iscall(v′) && operation(v′) isa Union{Differential, Shift}
352+
v′ = arguments(v)[1]
353+
if v′ in dvs || getmetadata(v′, SymScope, LocalScope()) isa GlobalScope
354+
isvalid = true
355+
break
356+
end
357+
end
358+
if !isvalid
359+
throw(ArgumentError("$v is present in the system but $v′ is not an unknown."))
353360
end
354-
isalgeq = false
355-
dvar = arguments(dvar)[1]
356-
idx = addvar!(dvar)
357-
end
358361

359-
dvar = var
360-
idx = varidx
362+
addvar!(v)
363+
if iscall(v) && operation(v) isa Symbolics.Operator && !isdifferential(v) && (it = input_timedomain(v)) !== nothing
364+
v′ = only(arguments(v))
365+
addvar!(setmetadata(v′, VariableTimeDomain, it))
366+
end
367+
end
361368

362-
if iscall(var) && operation(var) isa Symbolics.Operator &&
363-
!isdifferential(var) && (it = input_timedomain(var)) !== nothing
364-
set_incidence = false
365-
var = only(arguments(var))
366-
var = setmetadata(var, VariableTimeDomain, it)
367-
@goto ANOTHER_VAR
369+
if symbolic_type(v) == ArraySymbolic()
370+
union!(incidence, collect(v))
371+
else
372+
push!(incidence, v)
368373
end
369374
end
370-
push!(symbolic_incidence, copy(unknownvars))
371-
empty!(unknownvars)
372-
empty!(vars)
373-
empty!(varsvec)
374-
if isalgeq
375-
eqs[i] = eq
376-
else
377-
eqs[i] = eqs[i].lhs ~ rhs
375+
376+
push!(symbolic_incidence, collect(incidence))
377+
end
378+
379+
dervaridxs = Int[]
380+
for (i, v) in enumerate(fullvars)
381+
while isdifferential(v)
382+
push!(dervaridxs, i)
383+
v = arguments(v)[1]
384+
i = addvar!(v)
378385
end
379386
end
380387
eqs = eqs[eqs_to_retain]
@@ -389,6 +396,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
389396
symbolic_incidence = symbolic_incidence[sortidxs]
390397
end
391398

399+
# Handle shifts - find lowest shift and add intermediates with derivative edges
392400
### Handle discrete variables
393401
lowest_shift = Dict()
394402
for var in fullvars
@@ -428,6 +436,9 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
428436
end
429437
end
430438
end
439+
440+
var_types = Vector{VariableType}(getvariabletype.(fullvars))
441+
431442
# sort `fullvars` such that the mass matrix is as diagonal as possible.
432443
dervaridxs = collect(dervaridxs)
433444
sorted_fullvars = OrderedSet(fullvars[dervaridxs])
@@ -451,6 +462,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
451462
var2idx = Dict(fullvars .=> eachindex(fullvars))
452463
dervaridxs = 1:length(dervaridxs)
453464

465+
# build `var_to_diff`
454466
nvars = length(fullvars)
455467
diffvars = []
456468
var_to_diff = DiffGraph(nvars, true)
@@ -462,13 +474,15 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
462474
var_to_diff[diffvaridx] = dervaridx
463475
end
464476

477+
# build incidence graph
465478
graph = BipartiteGraph(neqs, nvars, Val(false))
466479
for (ie, vars) in enumerate(symbolic_incidence), v in vars
467480
jv = var2idx[v]
468481
add_edge!(graph, ie, jv)
469482
end
470483

471484
@set! sys.eqs = eqs
485+
@set! sys.unknowns = [v for (i, v) in enumerate(fullvars) if var_types[i] != BROWNIAN]
472486

473487
eq_to_diff = DiffGraph(nsrcs(graph))
474488

@@ -731,3 +745,19 @@ function _structural_simplify!(state::TearingState; simplify = false,
731745

732746
ModelingToolkit.invalidate_cache!(sys)
733747
end
748+
749+
struct DifferentiatedVariableNotUnknownError <: Exception
750+
differentiated
751+
undifferentiated
752+
end
753+
754+
function Base.showerror(io::IO, err::DifferentiatedVariableNotUnknownError)
755+
undiff = err.undifferentiated
756+
diff = err.differentiated
757+
print(io, "Variable $undiff occurs differentiated as $diff but is not an unknown of the system.")
758+
scope = getmetadata(undiff, SymScope, LocalScope())
759+
depth = expected_scope_depth(scope)
760+
if depth > 0
761+
print(io, "\nVariable $undiff expects $depth more levels in the hierarchy to be an unknown.")
762+
end
763+
end

0 commit comments

Comments
 (0)