diff --git a/Project.toml b/Project.toml index 5b7e510379..ee1ce3b60c 100644 --- a/Project.toml +++ b/Project.toml @@ -78,6 +78,7 @@ julia = "1.2" [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GalacticOptim = "a75be94c-b780-496d-a8a9-0878b188d577" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" @@ -91,4 +92,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["BenchmarkTools", "ForwardDiff", "GalacticOptim", "NonlinearSolve", "OrdinaryDiffEq", "Optim", "Random", "ReferenceTests", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"] +test = ["BenchmarkTools", "DiffEqNoiseProcess", "ForwardDiff", "GalacticOptim", "NonlinearSolve", "OrdinaryDiffEq", "Optim", "Random", "ReferenceTests", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"] diff --git a/docs/src/basics/ContextualVariables.md b/docs/src/basics/ContextualVariables.md index 52035389c5..ddc5c6d5ec 100644 --- a/docs/src/basics/ContextualVariables.md +++ b/docs/src/basics/ContextualVariables.md @@ -35,7 +35,7 @@ with the unit m^3/s: @variables x[1:2,1:2] [connect = Flow; unit = u"m^3/s"] ``` -ModelingToolkit defines `connect`, `unit`, `noise`, and `description` keys for +ModelingToolkit defines `connect`, `unit`, `noise`, `default`, and `description` keys for the metadata. One can get and set metadata by ```julia diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 4b896ebdbc..04744cc707 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -492,6 +492,29 @@ function push_defaults!(stmt, defs, var2name) return defs_name end +function push_ctxvars!(stmt, name, idx, v, prop) + # name = nameof(Symbolics.operation(Symbolics.unwrap(x))) + # or + vname = Symbolics.tosymbol(v; escape=false) + ctxvar_type = Symbolics.option_to_metadata_type(Val(prop)) + if hasmetadata(v, ctxvar_type) + val = getmetadata(v, ctxvar_type) + if idx !== nothing + push!(stmt, :($name[$idx] = setmetadata($name[$idx], $ctxvar_type, $val))) + else + push!(stmt, :($name = setmetadata($v, $ctxvar_type, $val))) + end + end +end + +# function push_metadata!(stmt, vars) +function push_metadata!(stmt, name, idx, v) + for ctxvar in [:default, :connect, :unit, :noise, :description] + # for ctxvar in [:connect, :unit, :noise, :description] + push_ctxvars!(stmt, name, idx, v, ctxvar) + end +end + ### ### System I/O ### @@ -512,7 +535,17 @@ function toexpr(sys::AbstractSystem) psname = gensym(:ps) ps = parameters(sys) push_vars!(stmt, psname, Symbol("@parameters"), ps) - + + if iv !== nothing + push_metadata!(stmt, ivname, nothing, iv) + end + for (idx, st) in enumerate(sts) + push_metadata!(stmt, stsname, idx, st) + end + for (idx, p) in enumerate(ps) + push_metadata!(stmt, psname, idx, p) + end + var2name = Dict{Any,Symbol}() for v in Iterators.flatten((sts, ps)) var2name[v] = getname(v) diff --git a/src/utils.jl b/src/utils.jl index 14fa98c127..8c4a75913b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -160,6 +160,18 @@ hasdefault(v) = hasmetadata(v, Symbolics.VariableDefaultValue) getdefault(v) = value(getmetadata(v, Symbolics.VariableDefaultValue)) setdefault(v, val) = val === nothing ? v : setmetadata(v, Symbolics.VariableDefaultValue, value(val)) +# should it be get_unit or getunit (not necesary, just utility) +for prop in [:default, :connect, :unit, :noise, :description] + fname1 = Symbol(:get, prop) + fname2 = Symbol(:has, prop) + fname3 = Symbol(:set, prop) + @eval begin + $fname1(x::Num) = getmetadata(x, Symbolics.option_to_metadata_type(Val($(QuoteNode(prop))))) + $fname2(x::Num) = hasmetadata(x, Symbolics.option_to_metadata_type(Val($(QuoteNode(prop))))) + $fname3(x::Num, val) = setmetadata(x, Symbolics.option_to_metadata_type(Val($(QuoteNode(prop)))), val) + end +end + function collect_defaults!(defs, vars) for v in vars; (haskey(defs, v) || !hasdefault(v)) && continue defs[v] = getdefault(v) diff --git a/src/variables.jl b/src/variables.jl index a63c226e06..e8c844a8cb 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -1,11 +1,12 @@ struct VariableUnit end struct VariableConnectType end -struct VariabelNoiseType end -struct VariabelDescriptionType end +struct VariableNoiseType end +struct VariableDescriptionType end Symbolics.option_to_metadata_type(::Val{:unit}) = VariableUnit Symbolics.option_to_metadata_type(::Val{:connect}) = VariableConnectType -Symbolics.option_to_metadata_type(::Val{:noise}) = VariabelNoiseType -Symbolics.option_to_metadata_type(::Val{:description}) = VariabelDescriptionType +Symbolics.option_to_metadata_type(::Val{:noise}) = VariableNoiseType +Symbolics.option_to_metadata_type(::Val{:description}) = VariableDescriptionType +Symbolics.option_to_metadata_type(::Val{:default}) = Symbolics.VariableDefaultValue """ $(SIGNATURES) diff --git a/test/serialization.jl b/test/serialization.jl index ac8968cb8c..268ee57b04 100644 --- a/test/serialization.jl +++ b/test/serialization.jl @@ -1,4 +1,4 @@ -using ModelingToolkit, SciMLBase, Serialization +using ModelingToolkit, SciMLBase, Serialization, Unitful, DiffEqNoiseProcess @parameters t @variables x(t) @@ -25,3 +25,58 @@ io = IOBuffer() write(io, rc_model) sys = include_string(@__MODULE__, String(take!(io))) @test sys == flatten(rc_model) + +# test metadata is preserved in `toexpr` +W = WienerProcess(0, 0, 0) +@parameters begin + # (t=0), [unit=u"s"] + t, [unit = u"s"] + (σ = 28.), [description = "sigma"] + (ρ = 10) + (β = 8 / 3) +end + +@variables begin + (x(t) = 0), + [unit = u"m/s" + description = "rate of convection" + noise = W] + (y(t) = 0), [unit = u"m/s"; connect = Flow] + (z(t) = 0), [unit = u"m/s"] +end + +D = Differential(t) + +eqs = [D(x) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z] + +sys = ODESystem(eqs) +ps = parameters(sys) +sigma_metadata = ps[1].metadata +expr = toexpr(sys) +sys2 = eval(expr) +ps2 = parameters(sys2) +sigma_metadata2 = ps2[1].metadata +@test sigma_metadata2 == sigma_metadata # fails if :default is not in push_metadata! + +sts2 = states(sys2) +for (i, st) in enumerate(states(sys)) + @test st.metadata == sts2[i].metadata +end + +ps2 = parameters(sys2) +for (i, p) in enumerate(parameters(sys)) + @test p.metadata == ps2[i].metadata +end + +iv = independent_variable(sys) +iv2 = independent_variable(sys2) +@test iv.metadata == iv2.metadata + +# [tests nothing] +# do we want to recurse into ps, sts, etc to ensure they have equivalent metadata? +io = IOBuffer() +write(io, sys) +sys2 = include_string(@__MODULE__, String(take!(io))) +@test sys2 == flatten(sys)