@@ -253,105 +253,106 @@ function Base.push!(ev::EquationsView, eq)
253
253
push! (ev. ts. extra_eqs, eq)
254
254
end
255
255
256
+ function symbolic_contains (var, set)
257
+ var in set || symbolic_type (var) == ArraySymbolic () && Symbolics. shape (var) != Symbolics. Unknown () && all (i -> var[i] in set, eachindex (var))
258
+ end
259
+
256
260
function TearingState (sys; quick_cancel = false , check = true )
261
+ # flatten system
257
262
sys = flatten (sys)
258
263
ivs = independent_variables (sys)
259
264
iv = length (ivs) == 1 ? ivs[1 ] : nothing
260
- # scalarize array equations, without scalarizing arguments to registered functions
261
- eqs = flatten_equations (copy ( equations (sys) ))
265
+ # flatten array equations
266
+ eqs = flatten_equations (equations (sys))
262
267
neqs = length (eqs)
263
- dervaridxs = OrderedSet {Int} ()
264
- var2idx = Dict {Any, Int} ()
265
- symbolic_incidence = []
266
- fullvars = []
267
- var_counter = Ref (0 )
268
- var_types = VariableType[]
269
- addvar! = let fullvars = fullvars, var_counter = var_counter, var_types = var_types
268
+ # * Scalarize unknowns
269
+ dvs = Set {BasicSymbolic} ()
270
+ fullvars = BasicSymbolic[]
271
+ for x in unknowns (sys)
272
+ push! (dvs, x)
273
+ xx = Symbolics. scalarize (x)
274
+ if xx isa AbstractArray
275
+ union! (dvs, xx)
276
+ append! (fullvars, xx)
277
+ else
278
+ push! (fullvars, xx)
279
+ end
280
+ end
281
+ var2idx = Dict {BasicSymbolic, Int} (v => k for (k, v) in enumerate (fullvars))
282
+ addvar! = let fullvars = fullvars, dvs = dvs, var2idx = var2idx
270
283
var -> get! (var2idx, var) do
284
+ push! (dvs, var)
271
285
push! (fullvars, var)
272
- push! (var_types, getvariabletype (var))
273
- var_counter[] += 1
286
+ return length (fullvars)
274
287
end
275
288
end
276
289
277
- vars = OrderedSet ()
278
- varsvec = []
279
- for (i, eq′) in enumerate (eqs)
280
- if eq′. lhs isa Connection
281
- check ? error (" $(nameof (sys)) has unexpanded `connect` statements" ) :
282
- return nothing
283
- end
284
- if _iszero (eq′. lhs)
285
- rhs = quick_cancel ? quick_cancel_expr (eq′. rhs) : eq′. rhs
286
- eq = eq′
287
- else
288
- lhs = quick_cancel ? quick_cancel_expr (eq′. lhs) : eq′. lhs
289
- rhs = quick_cancel ? quick_cancel_expr (eq′. rhs) : eq′. rhs
290
- eq = 0 ~ rhs - lhs
290
+ # build symbolic incidence
291
+ symbolic_incidence = Vector{BasicSymbolic}[]
292
+ varsbuf = Set ()
293
+ for (i, eq) in enumerate (eqs)
294
+ rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
295
+ if ! _iszero (eq. lhs)
296
+ lhs = quick_cancel ? quick_cancel_expr (eq. lhs) : eq. lhs
297
+ eq = eqs[i] = 0 ~ rhs - lhs
291
298
end
292
- vars! (vars, eq. rhs, op = Symbolics. Operator)
293
- for v in vars
294
- _var, _ = var_from_nested_derivative (v)
295
- any (isequal (_var), ivs) && continue
296
- if isparameter (_var) ||
297
- (iscall (_var) && isparameter (operation (_var)) || isconstant (_var))
298
- continue
299
+ empty! (varsbuf)
300
+ vars! (varsbuf, eq; op = Symbolics. Operator)
301
+ incidence = Set {BasicSymbolic} ()
302
+ for v in varsbuf
303
+ # FIXME : This check still needs to rely on metadata
304
+ isconstant (v) && continue
305
+ vtype = getvariabletype (v)
306
+ # additionally track brownians in fullvars
307
+ # TODO : When uniting system types, track brownians in their own field
308
+ if vtype == BROWNIAN
309
+ i = addvar! (v)
310
+ push! (incidence, v)
299
311
end
300
- v = scalarize (v)
301
- if v isa AbstractArray
302
- append! (varsvec, v)
303
- else
304
- push! (varsvec, v)
305
- end
306
- end
307
- isalgeq = true
308
- unknownvars = []
309
- for var in varsvec
310
- ModelingToolkit. isdelay (var, iv) && continue
311
- set_incidence = true
312
- @label ANOTHER_VAR
313
- _var, _ = var_from_nested_derivative (var)
314
- any (isequal (_var), ivs) && continue
315
- if isparameter (_var) ||
316
- (iscall (_var) && isparameter (operation (_var)) || isconstant (_var))
317
- continue
318
- end
319
- varidx = addvar! (var)
320
- set_incidence && push! (unknownvars, var)
321
-
322
- dvar = var
323
- idx = varidx
324
- while isdifferential (dvar)
325
- if ! (idx in dervaridxs)
326
- push! (dervaridxs, idx)
312
+
313
+ vtype == VARIABLE || continue
314
+
315
+ if ! symbolic_contains (v, dvs)
316
+ isvalid = iscall (v) && operation (v) isa Union{Shift, Sample, Hold}
317
+ v′ = v
318
+ while ! isvalid && iscall (v′) && operation (v′) isa Union{Differential, Shift}
319
+ v′ = arguments (v)[1 ]
320
+ if v′ in dvs || getmetadata (v′, SymScope, LocalScope ()) isa GlobalScope
321
+ isvalid = true
322
+ break
323
+ end
324
+ end
325
+ if ! isvalid
326
+ throw (ArgumentError (" $v is present in the system but $v′ is not an unknown." ))
327
327
end
328
- isalgeq = false
329
- dvar = arguments (dvar)[1 ]
330
- idx = addvar! (dvar)
331
- end
332
328
333
- dvar = var
334
- idx = varidx
329
+ addvar! (v)
330
+ if iscall (v) && operation (v) isa Symbolics. Operator && ! isdifferential (v) && (it = input_timedomain (v)) != = nothing
331
+ v′ = only (arguments (v))
332
+ addvar! (setmetadata (v′, VariableTimeDomain, it))
333
+ end
334
+ end
335
335
336
- if iscall (var) && operation (var) isa Symbolics. Operator &&
337
- ! isdifferential (var) && (it = input_timedomain (var)) != = nothing
338
- set_incidence = false
339
- var = only (arguments (var))
340
- var = setmetadata (var, VariableTimeDomain, it)
341
- @goto ANOTHER_VAR
336
+ if symbolic_type (v) == ArraySymbolic ()
337
+ union! (incidence, collect (v))
338
+ else
339
+ push! (incidence, v)
342
340
end
343
341
end
344
- push! (symbolic_incidence, copy (unknownvars))
345
- empty! (unknownvars)
346
- empty! (vars)
347
- empty! (varsvec)
348
- if isalgeq
349
- eqs[i] = eq
350
- else
351
- eqs[i] = eqs[i]. lhs ~ rhs
342
+
343
+ push! (symbolic_incidence, collect (incidence))
344
+ end
345
+
346
+ dervaridxs = Int[]
347
+ for (i, v) in enumerate (fullvars)
348
+ while isdifferential (v)
349
+ push! (dervaridxs, i)
350
+ v = arguments (v)[1 ]
351
+ i = addvar! (v)
352
352
end
353
353
end
354
354
355
+ # Handle shifts - find lowest shift and add intermediates with derivative edges
355
356
# ## Handle discrete variables
356
357
lowest_shift = Dict ()
357
358
for var in fullvars
@@ -391,6 +392,9 @@ function TearingState(sys; quick_cancel = false, check = true)
391
392
end
392
393
end
393
394
end
395
+
396
+ var_types = Vector {VariableType} (getvariabletype .(fullvars))
397
+
394
398
# sort `fullvars` such that the mass matrix is as diagonal as possible.
395
399
dervaridxs = collect (dervaridxs)
396
400
sorted_fullvars = OrderedSet (fullvars[dervaridxs])
@@ -414,6 +418,7 @@ function TearingState(sys; quick_cancel = false, check = true)
414
418
var2idx = Dict (fullvars .=> eachindex (fullvars))
415
419
dervaridxs = 1 : length (dervaridxs)
416
420
421
+ # build `var_to_diff`
417
422
nvars = length (fullvars)
418
423
diffvars = []
419
424
var_to_diff = DiffGraph (nvars, true )
@@ -425,20 +430,24 @@ function TearingState(sys; quick_cancel = false, check = true)
425
430
var_to_diff[diffvaridx] = dervaridx
426
431
end
427
432
433
+ # build incidence graph
428
434
graph = BipartiteGraph (neqs, nvars, Val (false ))
429
435
for (ie, vars) in enumerate (symbolic_incidence), v in vars
430
436
jv = var2idx[v]
431
437
add_edge! (graph, ie, jv)
432
438
end
433
439
434
440
@set! sys. eqs = eqs
441
+ @set! sys. unknowns = [v for (i, v) in enumerate (fullvars) if var_types[i] != BROWNIAN]
435
442
436
443
eq_to_diff = DiffGraph (nsrcs (graph))
437
444
438
445
ts = TearingState (sys, fullvars,
439
446
SystemStructure (complete (var_to_diff), complete (eq_to_diff),
440
447
complete (graph), nothing , var_types, sys isa AbstractDiscreteSystem),
441
448
Any[])
449
+
450
+ # `shift_discrete_system`
442
451
if sys isa DiscreteSystem
443
452
ts = shift_discrete_system (ts)
444
453
end
@@ -726,3 +735,19 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
726
735
727
736
ModelingToolkit. invalidate_cache! (sys), input_idxs
728
737
end
738
+
739
+ struct DifferentiatedVariableNotUnknownError <: Exception
740
+ differentiated
741
+ undifferentiated
742
+ end
743
+
744
+ function Base. showerror (io:: IO , err:: DifferentiatedVariableNotUnknownError )
745
+ undiff = err. undifferentiated
746
+ diff = err. differentiated
747
+ print (io, " Variable $undiff occurs differentiated as $diff but is not an unknown of the system." )
748
+ scope = getmetadata (undiff, SymScope, LocalScope ())
749
+ depth = expected_scope_depth (scope)
750
+ if depth > 0
751
+ print (io, " \n Variable $undiff expects $depth more levels in the hierarchy to be an unknown." )
752
+ end
753
+ end
0 commit comments