@@ -254,44 +254,60 @@ function Base.push!(ev::EquationsView, eq)
254
254
push! (ev. ts. extra_eqs, eq)
255
255
end
256
256
257
- function is_time_dependent_parameter (p, iv)
258
- return iv != = nothing && isparameter (p) && iscall (p) &&
259
- (operation (p) === getindex && is_time_dependent_parameter (arguments (p)[1 ], iv) ||
257
+ function is_time_dependent_parameter (p, allps, iv)
258
+ return iv != = nothing && p in allps && iscall (p) &&
259
+ (operation (p) === getindex && is_time_dependent_parameter (arguments (p)[1 ], allps, iv) ||
260
260
(args = arguments (p); length (args)) == 1 && isequal (only (args), iv))
261
261
end
262
262
263
+ function symbolic_contains (var, set)
264
+ var in set || symbolic_type (var) == ArraySymbolic () && Symbolics. shape (var) != Symbolics. Unknown () && all (i -> var[i] in set, eachindex (var))
265
+ end
266
+
263
267
function TearingState (sys; quick_cancel = false , check = true , sort_eqs = true )
268
+ # flatten system
264
269
sys = flatten (sys)
265
270
ivs = independent_variables (sys)
266
271
iv = length (ivs) == 1 ? ivs[1 ] : nothing
267
- # scalarize array equations, without scalarizing arguments to registered functions
268
- eqs = flatten_equations (copy ( equations (sys) ))
272
+ # flatten array equations
273
+ eqs = flatten_equations (equations (sys))
269
274
neqs = length (eqs)
270
- dervaridxs = OrderedSet {Int} ()
271
- var2idx = Dict {Any, Int} ()
272
- symbolic_incidence = []
273
- fullvars = []
274
275
param_derivative_map = Dict {BasicSymbolic, Any} ()
275
- var_counter = Ref (0 )
276
- var_types = VariableType[]
277
- addvar! = let fullvars = fullvars, var_counter = var_counter, var_types = var_types
276
+ # * Scalarize unknowns
277
+ dvs = Set {BasicSymbolic} ()
278
+ fullvars = BasicSymbolic[]
279
+ for x in unknowns (sys)
280
+ push! (dvs, x)
281
+ xx = Symbolics. scalarize (x)
282
+ if xx isa AbstractArray
283
+ union! (dvs, xx)
284
+ append! (fullvars, xx)
285
+ else
286
+ push! (fullvars, xx)
287
+ end
288
+ end
289
+ ps = Set {BasicSymbolic} ()
290
+ for x in parameters (sys)
291
+ push! (ps, x)
292
+ xx = Symbolics. scalarize (x)
293
+ xx isa AbstractArray && union! (dvs, x)
294
+ end
295
+ var2idx = Dict {BasicSymbolic, Int} (v => k for (k, v) in enumerate (fullvars))
296
+ addvar! = let fullvars = fullvars, dvs = dvs, var2idx = var2idx
278
297
var -> get! (var2idx, var) do
298
+ push! (dvs, var)
279
299
push! (fullvars, var)
280
- push! (var_types, getvariabletype (var))
281
- var_counter[] += 1
300
+ return length (fullvars)
282
301
end
283
302
end
284
303
285
- vars = OrderedSet ()
286
- varsvec = []
304
+ # build symbolic incidence
305
+ symbolic_incidence = Vector{BasicSymbolic}[]
306
+ varsbuf = Set ()
287
307
eqs_to_retain = trues (length (eqs))
288
- for (i, eq′) in enumerate (eqs)
289
- if eq′. lhs isa Connection
290
- check ? error (" $(nameof (sys)) has unexpanded `connect` statements" ) :
291
- return nothing
292
- end
308
+ for (i, eq) in enumerate (eqs)
293
309
if iscall (eq′. lhs) && (op = operation (eq′. lhs)) isa Differential &&
294
- isequal (op. x, iv) && is_time_dependent_parameter (only (arguments (eq′. lhs)), iv)
310
+ isequal (op. x, iv) && is_time_dependent_parameter (only (arguments (eq′. lhs)), ps, iv)
295
311
# parameter derivatives are opted out by specifying `D(p) ~ missing`, but
296
312
# we want to store `nothing` in the map because that means `fast_substitute`
297
313
# will ignore the rule. We will this identify the presence of `eq′.lhs` in
@@ -301,80 +317,71 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
301
317
# change the equation if the RHS is `missing` so the rest of this loop works
302
318
eq′ = eq′. lhs ~ coalesce (eq′. rhs, 0.0 )
303
319
end
304
- if _iszero (eq′. lhs)
305
- rhs = quick_cancel ? quick_cancel_expr (eq′. rhs) : eq′. rhs
306
- eq = eq′
307
- else
308
- lhs = quick_cancel ? quick_cancel_expr (eq′. lhs) : eq′. lhs
309
- rhs = quick_cancel ? quick_cancel_expr (eq′. rhs) : eq′. rhs
310
- eq = 0 ~ rhs - lhs
320
+ rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
321
+ if ! _iszero (eq. lhs)
322
+ lhs = quick_cancel ? quick_cancel_expr (eq. lhs) : eq. lhs
323
+ eq = eqs[i] = 0 ~ rhs - lhs
311
324
end
312
- vars! (vars, eq. rhs, op = Symbolics. Operator)
313
- for v in vars
314
- _var, _ = var_from_nested_derivative (v)
315
- any (isequal (_var), ivs) && continue
316
- if isparameter (_var) ||
317
- (iscall (_var) && isparameter (operation (_var)) || isconstant (_var))
318
- if is_time_dependent_parameter (_var, iv) &&
319
- ! haskey (param_derivative_map, Differential (iv)(_var))
325
+ empty! (varsbuf)
326
+ vars! (varsbuf, eq; op = Symbolics. Operator)
327
+ incidence = Set {BasicSymbolic} ()
328
+ for v in varsbuf
329
+ # FIXME : This check still needs to rely on metadata
330
+ isconstant (v) && continue
331
+ vtype = getvariabletype (v)
332
+ # additionally track brownians in fullvars
333
+ # TODO : When uniting system types, track brownians in their own field
334
+ if vtype == BROWNIAN
335
+ i = addvar! (v)
336
+ push! (incidence, v)
337
+ end
338
+
339
+ if symbolic_contains (v, ps)
340
+ if is_time_dependent_parameter (v, ps, iv) && ! haskey (param_derivative_map, Differential (iv)(_var))
320
341
# Parameter derivatives default to zero - they stay constant
321
342
# between callbacks
322
343
param_derivative_map[Differential (iv)(_var)] = 0.0
323
344
end
324
345
continue
325
346
end
326
- v = scalarize (v)
327
- if v isa AbstractArray
328
- append! (varsvec, v)
329
- else
330
- push! (varsvec, v)
331
- end
332
- end
333
- isalgeq = true
334
- unknownvars = []
335
- for var in varsvec
336
- ModelingToolkit. isdelay (var, iv) && continue
337
- set_incidence = true
338
- @label ANOTHER_VAR
339
- _var, _ = var_from_nested_derivative (var)
340
- any (isequal (_var), ivs) && continue
341
- if isparameter (_var) ||
342
- (iscall (_var) && isparameter (operation (_var)) || isconstant (_var))
343
- continue
344
- end
345
- varidx = addvar! (var)
346
- set_incidence && push! (unknownvars, var)
347
-
348
- dvar = var
349
- idx = varidx
350
- while isdifferential (dvar)
351
- if ! (idx in dervaridxs)
352
- push! (dervaridxs, idx)
347
+
348
+ if ! symbolic_contains (v, dvs)
349
+ isvalid = iscall (v) && operation (v) isa Union{Shift, Sample, Hold}
350
+ v′ = v
351
+ while ! isvalid && iscall (v′) && operation (v′) isa Union{Differential, Shift}
352
+ v′ = arguments (v)[1 ]
353
+ if v′ in dvs || getmetadata (v′, SymScope, LocalScope ()) isa GlobalScope
354
+ isvalid = true
355
+ break
356
+ end
357
+ end
358
+ if ! isvalid
359
+ throw (ArgumentError (" $v is present in the system but $v′ is not an unknown." ))
353
360
end
354
- isalgeq = false
355
- dvar = arguments (dvar)[1 ]
356
- idx = addvar! (dvar)
357
- end
358
361
359
- dvar = var
360
- idx = varidx
362
+ addvar! (v)
363
+ if iscall (v) && operation (v) isa Symbolics. Operator && ! isdifferential (v) && (it = input_timedomain (v)) != = nothing
364
+ v′ = only (arguments (v))
365
+ addvar! (setmetadata (v′, VariableTimeDomain, it))
366
+ end
367
+ end
361
368
362
- if iscall (var) && operation (var) isa Symbolics. Operator &&
363
- ! isdifferential (var) && (it = input_timedomain (var)) != = nothing
364
- set_incidence = false
365
- var = only (arguments (var))
366
- var = setmetadata (var, VariableTimeDomain, it)
367
- @goto ANOTHER_VAR
369
+ if symbolic_type (v) == ArraySymbolic ()
370
+ union! (incidence, collect (v))
371
+ else
372
+ push! (incidence, v)
368
373
end
369
374
end
370
- push! (symbolic_incidence, copy (unknownvars))
371
- empty! (unknownvars)
372
- empty! (vars)
373
- empty! (varsvec)
374
- if isalgeq
375
- eqs[i] = eq
376
- else
377
- eqs[i] = eqs[i]. lhs ~ rhs
375
+
376
+ push! (symbolic_incidence, collect (incidence))
377
+ end
378
+
379
+ dervaridxs = Int[]
380
+ for (i, v) in enumerate (fullvars)
381
+ while isdifferential (v)
382
+ push! (dervaridxs, i)
383
+ v = arguments (v)[1 ]
384
+ i = addvar! (v)
378
385
end
379
386
end
380
387
eqs = eqs[eqs_to_retain]
@@ -389,6 +396,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
389
396
symbolic_incidence = symbolic_incidence[sortidxs]
390
397
end
391
398
399
+ # Handle shifts - find lowest shift and add intermediates with derivative edges
392
400
# ## Handle discrete variables
393
401
lowest_shift = Dict ()
394
402
for var in fullvars
@@ -428,6 +436,9 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
428
436
end
429
437
end
430
438
end
439
+
440
+ var_types = Vector {VariableType} (getvariabletype .(fullvars))
441
+
431
442
# sort `fullvars` such that the mass matrix is as diagonal as possible.
432
443
dervaridxs = collect (dervaridxs)
433
444
sorted_fullvars = OrderedSet (fullvars[dervaridxs])
@@ -451,6 +462,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
451
462
var2idx = Dict (fullvars .=> eachindex (fullvars))
452
463
dervaridxs = 1 : length (dervaridxs)
453
464
465
+ # build `var_to_diff`
454
466
nvars = length (fullvars)
455
467
diffvars = []
456
468
var_to_diff = DiffGraph (nvars, true )
@@ -462,13 +474,15 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
462
474
var_to_diff[diffvaridx] = dervaridx
463
475
end
464
476
477
+ # build incidence graph
465
478
graph = BipartiteGraph (neqs, nvars, Val (false ))
466
479
for (ie, vars) in enumerate (symbolic_incidence), v in vars
467
480
jv = var2idx[v]
468
481
add_edge! (graph, ie, jv)
469
482
end
470
483
471
484
@set! sys. eqs = eqs
485
+ @set! sys. unknowns = [v for (i, v) in enumerate (fullvars) if var_types[i] != BROWNIAN]
472
486
473
487
eq_to_diff = DiffGraph (nsrcs (graph))
474
488
@@ -731,3 +745,19 @@ function _structural_simplify!(state::TearingState; simplify = false,
731
745
732
746
ModelingToolkit. invalidate_cache! (sys)
733
747
end
748
+
749
+ struct DifferentiatedVariableNotUnknownError <: Exception
750
+ differentiated
751
+ undifferentiated
752
+ end
753
+
754
+ function Base. showerror (io:: IO , err:: DifferentiatedVariableNotUnknownError )
755
+ undiff = err. undifferentiated
756
+ diff = err. differentiated
757
+ print (io, " Variable $undiff occurs differentiated as $diff but is not an unknown of the system." )
758
+ scope = getmetadata (undiff, SymScope, LocalScope ())
759
+ depth = expected_scope_depth (scope)
760
+ if depth > 0
761
+ print (io, " \n Variable $undiff expects $depth more levels in the hierarchy to be an unknown." )
762
+ end
763
+ end
0 commit comments