Skip to content

Commit 2a2b620

Browse files
feat: reduce reliance on metadata in structural_simplify
1 parent 2579558 commit 2a2b620

File tree

1 file changed

+104
-79
lines changed

1 file changed

+104
-79
lines changed

src/systems/systemstructure.jl

Lines changed: 104 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -253,105 +253,106 @@ function Base.push!(ev::EquationsView, eq)
253253
push!(ev.ts.extra_eqs, eq)
254254
end
255255

256+
function symbolic_contains(var, set)
257+
var in set || symbolic_type(var) == ArraySymbolic() && Symbolics.shape(var) != Symbolics.Unknown() && all(i -> var[i] in set, eachindex(var))
258+
end
259+
256260
function TearingState(sys; quick_cancel = false, check = true)
261+
# flatten system
257262
sys = flatten(sys)
258263
ivs = independent_variables(sys)
259264
iv = length(ivs) == 1 ? ivs[1] : nothing
260-
# scalarize array equations, without scalarizing arguments to registered functions
261-
eqs = flatten_equations(copy(equations(sys)))
265+
# flatten array equations
266+
eqs = flatten_equations(equations(sys))
262267
neqs = length(eqs)
263-
dervaridxs = OrderedSet{Int}()
264-
var2idx = Dict{Any, Int}()
265-
symbolic_incidence = []
266-
fullvars = []
267-
var_counter = Ref(0)
268-
var_types = VariableType[]
269-
addvar! = let fullvars = fullvars, var_counter = var_counter, var_types = var_types
268+
# * Scalarize unknowns
269+
dvs = Set{BasicSymbolic}()
270+
fullvars = BasicSymbolic[]
271+
for x in unknowns(sys)
272+
push!(dvs, x)
273+
xx = Symbolics.scalarize(x)
274+
if xx isa AbstractArray
275+
union!(dvs, xx)
276+
append!(fullvars, xx)
277+
else
278+
push!(fullvars, xx)
279+
end
280+
end
281+
var2idx = Dict{BasicSymbolic, Int}(v => k for (k, v) in enumerate(fullvars))
282+
addvar! = let fullvars = fullvars, dvs = dvs, var2idx = var2idx
270283
var -> get!(var2idx, var) do
284+
push!(dvs, var)
271285
push!(fullvars, var)
272-
push!(var_types, getvariabletype(var))
273-
var_counter[] += 1
286+
return length(fullvars)
274287
end
275288
end
276289

277-
vars = OrderedSet()
278-
varsvec = []
279-
for (i, eq′) in enumerate(eqs)
280-
if eq′.lhs isa Connection
281-
check ? error("$(nameof(sys)) has unexpanded `connect` statements") :
282-
return nothing
283-
end
284-
if _iszero(eq′.lhs)
285-
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
286-
eq = eq′
287-
else
288-
lhs = quick_cancel ? quick_cancel_expr(eq′.lhs) : eq′.lhs
289-
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
290-
eq = 0 ~ rhs - lhs
290+
# build symbolic incidence
291+
symbolic_incidence = Vector{BasicSymbolic}[]
292+
varsbuf = Set()
293+
for (i, eq) in enumerate(eqs)
294+
rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs
295+
if !_iszero(eq.lhs)
296+
lhs = quick_cancel ? quick_cancel_expr(eq.lhs) : eq.lhs
297+
eq = eqs[i] = 0 ~ rhs - lhs
291298
end
292-
vars!(vars, eq.rhs, op = Symbolics.Operator)
293-
for v in vars
294-
_var, _ = var_from_nested_derivative(v)
295-
any(isequal(_var), ivs) && continue
296-
if isparameter(_var) ||
297-
(iscall(_var) && isparameter(operation(_var)) || isconstant(_var))
298-
continue
299+
empty!(varsbuf)
300+
vars!(varsbuf, eq; op = Symbolics.Operator)
301+
incidence = Set{BasicSymbolic}()
302+
for v in varsbuf
303+
# FIXME: This check still needs to rely on metadata
304+
isconstant(v) && continue
305+
vtype = getvariabletype(v)
306+
# additionally track brownians in fullvars
307+
# TODO: When uniting system types, track brownians in their own field
308+
if vtype == BROWNIAN
309+
i = addvar!(v)
310+
push!(incidence, v)
299311
end
300-
v = scalarize(v)
301-
if v isa AbstractArray
302-
append!(varsvec, v)
303-
else
304-
push!(varsvec, v)
305-
end
306-
end
307-
isalgeq = true
308-
unknownvars = []
309-
for var in varsvec
310-
ModelingToolkit.isdelay(var, iv) && continue
311-
set_incidence = true
312-
@label ANOTHER_VAR
313-
_var, _ = var_from_nested_derivative(var)
314-
any(isequal(_var), ivs) && continue
315-
if isparameter(_var) ||
316-
(iscall(_var) && isparameter(operation(_var)) || isconstant(_var))
317-
continue
318-
end
319-
varidx = addvar!(var)
320-
set_incidence && push!(unknownvars, var)
321-
322-
dvar = var
323-
idx = varidx
324-
while isdifferential(dvar)
325-
if !(idx in dervaridxs)
326-
push!(dervaridxs, idx)
312+
313+
vtype == VARIABLE || continue
314+
315+
if !symbolic_contains(v, dvs)
316+
isvalid = iscall(v) && operation(v) isa Union{Shift, Sample, Hold}
317+
v′ = v
318+
while !isvalid && iscall(v′) && operation(v′) isa Union{Differential, Shift}
319+
v′ = arguments(v)[1]
320+
if v′ in dvs || getmetadata(v′, SymScope, LocalScope()) isa GlobalScope
321+
isvalid = true
322+
break
323+
end
324+
end
325+
if !isvalid
326+
throw(ArgumentError("$v is present in the system but $v′ is not an unknown."))
327327
end
328-
isalgeq = false
329-
dvar = arguments(dvar)[1]
330-
idx = addvar!(dvar)
331-
end
332328

333-
dvar = var
334-
idx = varidx
329+
addvar!(v)
330+
if iscall(v) && operation(v) isa Symbolics.Operator && !isdifferential(v) && (it = input_timedomain(v)) !== nothing
331+
v′ = only(arguments(v))
332+
addvar!(setmetadata(v′, VariableTimeDomain, it))
333+
end
334+
end
335335

336-
if iscall(var) && operation(var) isa Symbolics.Operator &&
337-
!isdifferential(var) && (it = input_timedomain(var)) !== nothing
338-
set_incidence = false
339-
var = only(arguments(var))
340-
var = setmetadata(var, VariableTimeDomain, it)
341-
@goto ANOTHER_VAR
336+
if symbolic_type(v) == ArraySymbolic()
337+
union!(incidence, collect(v))
338+
else
339+
push!(incidence, v)
342340
end
343341
end
344-
push!(symbolic_incidence, copy(unknownvars))
345-
empty!(unknownvars)
346-
empty!(vars)
347-
empty!(varsvec)
348-
if isalgeq
349-
eqs[i] = eq
350-
else
351-
eqs[i] = eqs[i].lhs ~ rhs
342+
343+
push!(symbolic_incidence, collect(incidence))
344+
end
345+
346+
dervaridxs = Int[]
347+
for (i, v) in enumerate(fullvars)
348+
while isdifferential(v)
349+
push!(dervaridxs, i)
350+
v = arguments(v)[1]
351+
i = addvar!(v)
352352
end
353353
end
354354

355+
# Handle shifts - find lowest shift and add intermediates with derivative edges
355356
### Handle discrete variables
356357
lowest_shift = Dict()
357358
for var in fullvars
@@ -391,6 +392,9 @@ function TearingState(sys; quick_cancel = false, check = true)
391392
end
392393
end
393394
end
395+
396+
var_types = Vector{VariableType}(getvariabletype.(fullvars))
397+
394398
# sort `fullvars` such that the mass matrix is as diagonal as possible.
395399
dervaridxs = collect(dervaridxs)
396400
sorted_fullvars = OrderedSet(fullvars[dervaridxs])
@@ -414,6 +418,7 @@ function TearingState(sys; quick_cancel = false, check = true)
414418
var2idx = Dict(fullvars .=> eachindex(fullvars))
415419
dervaridxs = 1:length(dervaridxs)
416420

421+
# build `var_to_diff`
417422
nvars = length(fullvars)
418423
diffvars = []
419424
var_to_diff = DiffGraph(nvars, true)
@@ -425,20 +430,24 @@ function TearingState(sys; quick_cancel = false, check = true)
425430
var_to_diff[diffvaridx] = dervaridx
426431
end
427432

433+
# build incidence graph
428434
graph = BipartiteGraph(neqs, nvars, Val(false))
429435
for (ie, vars) in enumerate(symbolic_incidence), v in vars
430436
jv = var2idx[v]
431437
add_edge!(graph, ie, jv)
432438
end
433439

434440
@set! sys.eqs = eqs
441+
@set! sys.unknowns = [v for (i, v) in enumerate(fullvars) if var_types[i] != BROWNIAN]
435442

436443
eq_to_diff = DiffGraph(nsrcs(graph))
437444

438445
ts = TearingState(sys, fullvars,
439446
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
440447
complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem),
441448
Any[])
449+
450+
# `shift_discrete_system`
442451
if sys isa DiscreteSystem
443452
ts = shift_discrete_system(ts)
444453
end
@@ -726,3 +735,19 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
726735

727736
ModelingToolkit.invalidate_cache!(sys), input_idxs
728737
end
738+
739+
struct DifferentiatedVariableNotUnknownError <: Exception
740+
differentiated
741+
undifferentiated
742+
end
743+
744+
function Base.showerror(io::IO, err::DifferentiatedVariableNotUnknownError)
745+
undiff = err.undifferentiated
746+
diff = err.differentiated
747+
print(io, "Variable $undiff occurs differentiated as $diff but is not an unknown of the system.")
748+
scope = getmetadata(undiff, SymScope, LocalScope())
749+
depth = expected_scope_depth(scope)
750+
if depth > 0
751+
print(io, "\nVariable $undiff expects $depth more levels in the hierarchy to be an unknown.")
752+
end
753+
end

0 commit comments

Comments
 (0)