diff --git a/Project.toml b/Project.toml index 932ab47fd..9a9360b21 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index eb5259c63..9c39e8cbb 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -34,6 +34,7 @@ using Setfield using ForwardDiff +import FunctionWrappersWrappers @reexport using SciMLBase using SciMLBase: @def, DEIntegrator, DEProblem, AbstractDiffEqOperator, @@ -119,6 +120,7 @@ include("init.jl") include("forwarddiff.jl") include("chainrules.jl") +include("norecompile.jl") # This is only used for oop stiff solvers default_factorize(A) = lu(A; check = false) diff --git a/src/norecompile.jl b/src/norecompile.jl new file mode 100644 index 000000000..3ac09128a --- /dev/null +++ b/src/norecompile.jl @@ -0,0 +1,54 @@ +struct OrdinaryDiffEqTag end + +const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1} +const NORECOMPILE_SUPPORTED_ARGS = (Tuple{Vector{Float64}, Vector{Float64}, + Vector{Float64}, Float64}, + Tuple{Vector{Float64}, Vector{Float64}, + SciMLBase.NullParameters, Float64}) +const arglists = (Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}, + Tuple{Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64 + }, + Tuple{Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT}, + Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64}, + Tuple{Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64}, + Tuple{Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT}) +const iip_returnlists = ntuple(x -> Nothing, length(arglists)) +function void(@nospecialize(f::Function)) + Base.Experimental.@opaque (args...)->f(args...) +end + +const oop_returnlists = (Vector{Float64},Vector{Float64}, + ntuple(x -> Vector{dualT}, length(arglists)-2)...) + +function typestablemapping(@nospecialize(f::Function)) + Base.Experimental.@opaque (args...)->f(args...) +end + +const NORECOMPILE_ARGUMENT_MESSAGE = """ + No-recompile mode is only supported for state arguments + of type `Vector{Float64}`, time arguments of `Float64` + and parameter arguments of type `Vector{Float64}` or + `SciMLBase.NullParameters`. + """ + +struct NoRecompileArgumentError <: Exception + args +end + +function Base.showerror(io::IO, e::NoRecompileArgumentError) + println(io, NORECOMPILE_ARGUMENT_MESSAGE) + print(io, "Attempted arguments: ") + print(io, e.args) +end + +function wrapfun_oop(ff, inputs::Tuple) + void(ff) +end + +function wrapfun_iip(ff, inputs::Tuple) + void(ff) +end + +function unwrap_fw(fw::FunctionWrapper) + fw.obj[] +end diff --git a/src/utils.jl b/src/utils.jl index bfda082f8..7fcfd3369 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,17 +1,3 @@ -function wrapfun_oop(ff, inputs::Tuple) - IT = map(typeof, inputs) - FunctionWrapper{IT[1], Tuple{IT...}}((args...) -> (ff(args...))) -end - -function wrapfun_iip(ff, inputs::Tuple) - IT = map(typeof, inputs) - FunctionWrapper{Nothing, Tuple{IT...}}((args...) -> (ff(args...); nothing)) -end - -function unwrap_fw(fw::FunctionWrapper) - fw.obj[] -end - _vec(v) = vec(v) _vec(v::Number) = v _vec(v::AbstractVector) = v