From a88c6a7911a28a39128ec508c01863c1059e7db0 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 22 Mar 2022 06:08:34 -0400 Subject: [PATCH 1/9] Setup FunctionWrappersWrappers norecompile mode Needs: - https://github.com/SciML/SciMLBase.jl/pull/143 - https://github.com/SciML/OrdinaryDiffEq.jl/pull/1627 ```julia using OrdinaryDiffEq function f(du, u, p, t) du[1] = 0.2u[1] du[2] = 0.4u[2] end u0 = ones(2) tspan = (0.0, 1.0) prob = ODEProblem{true,false}(f, u0, tspan, Float64[]) function lorenz(du, u, p, t) du[1] = 10.0(u[2] - u[1]) du[2] = u[1] * (28.0 - u[3]) - u[2] du[3] = u[1] * u[2] - (8 / 3) * u[3] end lorenzprob = ODEProblem{true,false}(lorenz, [1.0; 0.0; 0.0], (0.0, 1.0), Float64[]) typeof(prob) === typeof(lorenzprob) # true @time sol = solve(prob, Rosenbrock23(autodiff=false)) @time sol = solve(prob, Rosenbrock23(chunk_size=1)) ``` ``` 2.763588 seconds (10.32 M allocations: 648.718 MiB, 4.92% gc time, 99.89% compilation time) 10.577789 seconds (45.44 M allocations: 2.760 GiB, 4.87% gc time, 99.97% compilation time) ``` While the types of `prob` are exactly the same, there is still a significant amount of compile time, even with that exact same time being called in `using` at OrdinaryDiffEq. Maybe this needs to be run on master? --- Project.toml | 1 + src/DiffEqBase.jl | 3 +++ src/norecompile.jl | 56 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+) create mode 100644 src/norecompile.jl diff --git a/Project.toml b/Project.toml index 932ab47fd..9a9360b21 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index eb5259c63..db58f911a 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -34,6 +34,7 @@ using Setfield using ForwardDiff +import FunctionWrappersWrappers @reexport using SciMLBase using SciMLBase: @def, DEIntegrator, DEProblem, AbstractDiffEqOperator, @@ -119,6 +120,8 @@ include("init.jl") include("forwarddiff.jl") include("chainrules.jl") +include("precompile.jl") +include("norecompile.jl") # This is only used for oop stiff solvers default_factorize(A) = lu(A; check = false) diff --git a/src/norecompile.jl b/src/norecompile.jl new file mode 100644 index 000000000..4e7fcf861 --- /dev/null +++ b/src/norecompile.jl @@ -0,0 +1,56 @@ +struct OrdinaryDiffEqTag end + +const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag,Float64},Float64,1} +const arglists = (Tuple{Vector{Float64},Vector{Float64},Vector{Float64},Float64}, + Tuple{Vector{Float64},Vector{Float64},SciMLBase.NullParameters,Float64}, + Tuple{Vector{dualT},Vector{Float64},Vector{Float64},dualT}, + Tuple{Vector{dualT},Vector{dualT},Vector{Float64},Float64}, + Tuple{Vector{dualT},Vector{dualT},SciMLBase.NullParameters,Float64}, + Tuple{Vector{dualT},Vector{Float64},SciMLBase.NullParameters,dualT}) +const returnlists = ntuple(x -> Nothing, length(arglists)) +void(f) = function (du, u, p, t) + f(du, u, p, t) + nothing +end +const NORECOMPILE_FUNCTION = typeof(FunctionWrappersWrappers.FunctionWrappersWrapper(void(() -> nothing), arglists, returnlists)) +wrap_norecompile(f) = FunctionWrappersWrappers.FunctionWrappersWrapper(void(f), arglists, returnlists) + +function ODEFunction{iip,false}(f; + mass_matrix=I, + analytic=nothing, + tgrad=nothing, + jac=nothing, + jvp=nothing, + vjp=nothing, + jac_prototype=nothing, + sparsity=jac_prototype, + Wfact=nothing, + Wfact_t=nothing, + paramjac=nothing, + syms=nothing, + indepsym=nothing, + observed=SciMLBase.DEFAULT_OBSERVED, + colorvec=nothing) where {iip} + + if jac === nothing && isa(jac_prototype, AbstractDiffEqLinearOperator) + if iip + jac = update_coefficients! #(J,u,p,t) + else + jac = (u, p, t) -> update_coefficients!(deepcopy(jac_prototype), u, p, t) + end + end + + if jac_prototype !== nothing && colorvec === nothing && ArrayInterface.fast_matrix_colors(jac_prototype) + _colorvec = ArrayInterface.matrix_colors(jac_prototype) + else + _colorvec = colorvec + end + + ODEFunction{iip, + NORECOMPILE_FUNCTION,Any,Any,Any,Any, + Any,Any,Any,Any,Any, + Any,Any,typeof(syms),typeof(indepsym),Any,typeof(_colorvec)}( + wrap_norecompile(f), mass_matrix, analytic, tgrad, jac, + jvp, vjp, jac_prototype, sparsity, Wfact, + Wfact_t, paramjac, syms, indepsym, observed, _colorvec) +end \ No newline at end of file From 188db39b4f12608d596af9c69f18cd8838ca0369 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 26 Mar 2022 11:04:52 -0400 Subject: [PATCH 2/9] with extra specialization! --- src/norecompile.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/norecompile.jl b/src/norecompile.jl index 4e7fcf861..9f4f8caff 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -47,9 +47,10 @@ function ODEFunction{iip,false}(f; end ODEFunction{iip, - NORECOMPILE_FUNCTION,Any,Any,Any,Any, - Any,Any,Any,Any,Any, - Any,Any,typeof(syms),typeof(indepsym),Any,typeof(_colorvec)}( + NORECOMPILE_FUNCTION,typeof(mass_matrix),typeof(analytic),typeof(tgrad),typeof(jac), + typeof(jvp),typeof(vjp),typeof(jac_prototype),typeof(sparsity),typeof(Wfact), + typeof(Wfact_t),typeof(paramjac),typeof(syms),typeof(indepsym), + typeof(observed),typeof(_colorvec)}( wrap_norecompile(f), mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, observed, _colorvec) From b604656409715b9f92b6ff66f55abed4edba4c98 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 27 Mar 2022 16:50:59 -0400 Subject: [PATCH 3/9] try `@nospecialize` on the arguments --- src/norecompile.jl | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/src/norecompile.jl b/src/norecompile.jl index 9f4f8caff..781be65f7 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -8,9 +8,37 @@ const arglists = (Tuple{Vector{Float64},Vector{Float64},Vector{Float64},Float64} Tuple{Vector{dualT},Vector{dualT},SciMLBase.NullParameters,Float64}, Tuple{Vector{dualT},Vector{Float64},SciMLBase.NullParameters,dualT}) const returnlists = ntuple(x -> Nothing, length(arglists)) -void(f) = function (du, u, p, t) - f(du, u, p, t) - nothing +function void(@nospecialize(f::Function)) + function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}), @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}), @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}), @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}), @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}), @nospecialize(p::Vector{Float64}), @nospecialize(t::dualT)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}), @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::dualT)) + f(du, u, p, t) + nothing + end + f2 end const NORECOMPILE_FUNCTION = typeof(FunctionWrappersWrappers.FunctionWrappersWrapper(void(() -> nothing), arglists, returnlists)) wrap_norecompile(f) = FunctionWrappersWrappers.FunctionWrappersWrapper(void(f), arglists, returnlists) From 3f8370bd194f5fa742dc577650c2e51b16623952 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 28 Mar 2022 05:18:36 -0400 Subject: [PATCH 4/9] force precompilation? --- src/norecompile.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/norecompile.jl b/src/norecompile.jl index 781be65f7..b15a35c6f 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -38,8 +38,15 @@ function void(@nospecialize(f::Function)) f(du, u, p, t) nothing end + precompile(f2, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) + precompile(f2, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) + precompile(f2, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) + precompile(f2, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64)) + precompile(f2, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT)) + precompile(f2, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT)) f2 end + const NORECOMPILE_FUNCTION = typeof(FunctionWrappersWrappers.FunctionWrappersWrapper(void(() -> nothing), arglists, returnlists)) wrap_norecompile(f) = FunctionWrappersWrappers.FunctionWrappersWrapper(void(f), arglists, returnlists) From 954ee433000a33df40f54ecdb0e7c3d304f8e282 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 28 Mar 2022 05:24:48 -0400 Subject: [PATCH 5/9] try more --- src/norecompile.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/norecompile.jl b/src/norecompile.jl index b15a35c6f..1451d3939 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -38,6 +38,13 @@ function void(@nospecialize(f::Function)) f(du, u, p, t) nothing end + precompile(f, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) + precompile(f, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) + precompile(f, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) + precompile(f, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64)) + precompile(f, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT)) + precompile(f, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT)) + precompile(f2, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) precompile(f2, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) precompile(f2, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) From 2a186eaddac4912f3efda5763995fe9b1ef6c190 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 19 Aug 2022 23:56:57 -0400 Subject: [PATCH 6/9] no precompile file --- src/DiffEqBase.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index db58f911a..9c39e8cbb 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -120,7 +120,6 @@ include("init.jl") include("forwarddiff.jl") include("chainrules.jl") -include("precompile.jl") include("norecompile.jl") # This is only used for oop stiff solvers default_factorize(A) = lu(A; check = false) From f05cf28b8a755495695adef68119c44514697e43 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 20 Aug 2022 00:04:18 -0400 Subject: [PATCH 7/9] add sys --- src/norecompile.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/norecompile.jl b/src/norecompile.jl index 1451d3939..4ffed7290 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -72,7 +72,7 @@ function ODEFunction{iip,false}(f; syms=nothing, indepsym=nothing, observed=SciMLBase.DEFAULT_OBSERVED, - colorvec=nothing) where {iip} + colorvec=nothing, sys = nothing) where {iip} if jac === nothing && isa(jac_prototype, AbstractDiffEqLinearOperator) if iip @@ -92,8 +92,8 @@ function ODEFunction{iip,false}(f; NORECOMPILE_FUNCTION,typeof(mass_matrix),typeof(analytic),typeof(tgrad),typeof(jac), typeof(jvp),typeof(vjp),typeof(jac_prototype),typeof(sparsity),typeof(Wfact), typeof(Wfact_t),typeof(paramjac),typeof(syms),typeof(indepsym), - typeof(observed),typeof(_colorvec)}( + typeof(observed),typeof(_colorvec), typeof(sys)}( wrap_norecompile(f), mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, - Wfact_t, paramjac, syms, indepsym, observed, _colorvec) -end \ No newline at end of file + Wfact_t, paramjac, syms, indepsym, observed, _colorvec, sys) +end From 5da8e11da72b2b7c22d607be3592e9117e2f450a Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 20 Aug 2022 07:40:38 -0400 Subject: [PATCH 8/9] do not require SciMLBase changes and support oop --- src/norecompile.jl | 239 +++++++++++++++++++++++++++------------------ src/utils.jl | 14 --- 2 files changed, 146 insertions(+), 107 deletions(-) diff --git a/src/norecompile.jl b/src/norecompile.jl index 4ffed7290..026aa01c7 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -1,99 +1,152 @@ struct OrdinaryDiffEqTag end -const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag,Float64},Float64,1} -const arglists = (Tuple{Vector{Float64},Vector{Float64},Vector{Float64},Float64}, - Tuple{Vector{Float64},Vector{Float64},SciMLBase.NullParameters,Float64}, - Tuple{Vector{dualT},Vector{Float64},Vector{Float64},dualT}, - Tuple{Vector{dualT},Vector{dualT},Vector{Float64},Float64}, - Tuple{Vector{dualT},Vector{dualT},SciMLBase.NullParameters,Float64}, - Tuple{Vector{dualT},Vector{Float64},SciMLBase.NullParameters,dualT}) -const returnlists = ntuple(x -> Nothing, length(arglists)) -function void(@nospecialize(f::Function)) - function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}), @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) - f(du, u, p, t) - nothing - end - - function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}), @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) - f(du, u, p, t) - nothing - end - - function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}), @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) - f(du, u, p, t) - nothing - end - - function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}), @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) - f(du, u, p, t) - nothing - end - - function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}), @nospecialize(p::Vector{Float64}), @nospecialize(t::dualT)) - f(du, u, p, t) - nothing - end - - function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}), @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::dualT)) - f(du, u, p, t) - nothing - end - precompile(f, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) - precompile(f, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) - precompile(f, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) - precompile(f, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64)) - precompile(f, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT)) - precompile(f, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT)) - - precompile(f2, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) - precompile(f2, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) - precompile(f2, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) - precompile(f2, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64)) - precompile(f2, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT)) - precompile(f2, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT)) - f2 +const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1} +const NORECOMPILE_SUPPORTED_ARGS = (Tuple{Vector{Float64}, Vector{Float64}, + Vector{Float64}, Float64}, + Tuple{Vector{Float64}, Vector{Float64}, + SciMLBase.NullParameters, Float64}) +const arglists = (Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}, + Tuple{Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64 + }, + Tuple{Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT}, + Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64}, + Tuple{Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64}, + Tuple{Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT}) +const iip_returnlists = ntuple(x -> Nothing, length(arglists)) +function void(@nospecialize(f::Function)) + function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}), + @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}), + @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}), + @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}), + @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}), + @nospecialize(p::Vector{Float64}), @nospecialize(t::dualT)) + f(du, u, p, t) + nothing + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}), + @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::dualT)) + f(du, u, p, t) + nothing + end + precompile(f, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) + precompile(f, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) + precompile(f, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) + precompile(f, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64)) + precompile(f, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT)) + precompile(f, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT)) + + precompile(f2, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) + precompile(f2, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) + precompile(f2, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) + precompile(f2, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64)) + precompile(f2, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT)) + precompile(f2, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT)) + f2 end -const NORECOMPILE_FUNCTION = typeof(FunctionWrappersWrappers.FunctionWrappersWrapper(void(() -> nothing), arglists, returnlists)) -wrap_norecompile(f) = FunctionWrappersWrappers.FunctionWrappersWrapper(void(f), arglists, returnlists) - -function ODEFunction{iip,false}(f; - mass_matrix=I, - analytic=nothing, - tgrad=nothing, - jac=nothing, - jvp=nothing, - vjp=nothing, - jac_prototype=nothing, - sparsity=jac_prototype, - Wfact=nothing, - Wfact_t=nothing, - paramjac=nothing, - syms=nothing, - indepsym=nothing, - observed=SciMLBase.DEFAULT_OBSERVED, - colorvec=nothing, sys = nothing) where {iip} - - if jac === nothing && isa(jac_prototype, AbstractDiffEqLinearOperator) - if iip - jac = update_coefficients! #(J,u,p,t) - else - jac = (u, p, t) -> update_coefficients!(deepcopy(jac_prototype), u, p, t) +const oop_returnlists = (Vector{Float64},Vector{Float64}, + ntuple(x -> Vector{dualT}, length(arglists)-2)...) + +function typestablemapping(@nospecialize(f::Function)) + function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}), + @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) + f(u, p, t)::Vector{Float64} + end + + function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}), + @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) + f(u, p, t)::Vector{Float64} + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}), + @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) + f(u, p, t)::Vector{dualT} + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}), + @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) + f(u, p, t)::Vector{dualT} end - end - - if jac_prototype !== nothing && colorvec === nothing && ArrayInterface.fast_matrix_colors(jac_prototype) - _colorvec = ArrayInterface.matrix_colors(jac_prototype) - else - _colorvec = colorvec - end - - ODEFunction{iip, - NORECOMPILE_FUNCTION,typeof(mass_matrix),typeof(analytic),typeof(tgrad),typeof(jac), - typeof(jvp),typeof(vjp),typeof(jac_prototype),typeof(sparsity),typeof(Wfact), - typeof(Wfact_t),typeof(paramjac),typeof(syms),typeof(indepsym), - typeof(observed),typeof(_colorvec), typeof(sys)}( - wrap_norecompile(f), mass_matrix, analytic, tgrad, jac, - jvp, vjp, jac_prototype, sparsity, Wfact, - Wfact_t, paramjac, syms, indepsym, observed, _colorvec, sys) + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}), + @nospecialize(p::Vector{Float64}), @nospecialize(t::dualT)) + f(u, p, t)::Vector{dualT} + end + + function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}), + @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::dualT)) + f(u, p, t)::Vector{dualT} + end + precompile(f, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) + precompile(f, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) + precompile(f, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) + precompile(f, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64)) + precompile(f, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT)) + precompile(f, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT)) + + precompile(f2, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) + precompile(f2, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) + precompile(f2, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) + precompile(f2, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64)) + precompile(f2, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT)) + precompile(f2, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT)) + f2 +end + +const NORECOMPILE_ARGUMENT_MESSAGE = """ + No-recompile mode is only supported for state arguments + of type `Vector{Float64}`, time arguments of `Float64` + and parameter arguments of type `Vector{Float64}` or + `SciMLBase.NullParameters`. + """ + +struct NoRecompileArgumentError <: Exception + args +end + +function Base.showerror(io::IO, e::NoRecompileArgumentError) + println(io, NORECOMPILE_ARGUMENT_MESSAGE) + print(io, "Attempted arguments: ") + print(io, e.args) +end + +function wrapfun_oop(ff, inputs::Tuple) + IT = Tuple{map(typeof, inputs)...} + if IT ∉ NORECOMPILE_SUPPORTED_ARGS + throw(NoRecompileArgumentError(IT)) + end + FunctionWrappersWrappers.FunctionWrappersWrapper(void(ff), arglists, oop_returnlists) +end + +function wrapfun_iip(ff, inputs::Tuple) + IT = Tuple{map(typeof, inputs)...} + if IT ∉ NORECOMPILE_SUPPORTED_ARGS + throw(NoRecompileArgumentError(IT)) + end + FunctionWrappersWrappers.FunctionWrappersWrapper(void(ff), arglists, iip_returnlists) +end + +function unwrap_fw(fw::FunctionWrapper) + fw.obj[] end diff --git a/src/utils.jl b/src/utils.jl index bfda082f8..7fcfd3369 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,17 +1,3 @@ -function wrapfun_oop(ff, inputs::Tuple) - IT = map(typeof, inputs) - FunctionWrapper{IT[1], Tuple{IT...}}((args...) -> (ff(args...))) -end - -function wrapfun_iip(ff, inputs::Tuple) - IT = map(typeof, inputs) - FunctionWrapper{Nothing, Tuple{IT...}}((args...) -> (ff(args...); nothing)) -end - -function unwrap_fw(fw::FunctionWrapper) - fw.obj[] -end - _vec(v) = vec(v) _vec(v::Number) = v _vec(v::AbstractVector) = v From 96b6526e3026b5a9b36b0c26071db37e82a749d9 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 20 Aug 2022 07:53:30 -0400 Subject: [PATCH 9/9] setup opaque closure form so that no other changes are necessary --- src/norecompile.jl | 106 ++------------------------------------------- 1 file changed, 4 insertions(+), 102 deletions(-) diff --git a/src/norecompile.jl b/src/norecompile.jl index 026aa01c7..3ac09128a 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -14,104 +14,14 @@ const arglists = (Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float Tuple{Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT}) const iip_returnlists = ntuple(x -> Nothing, length(arglists)) function void(@nospecialize(f::Function)) - function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}), - @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) - f(du, u, p, t) - nothing - end - - function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}), - @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) - f(du, u, p, t) - nothing - end - - function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}), - @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) - f(du, u, p, t) - nothing - end - - function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}), - @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) - f(du, u, p, t) - nothing - end - - function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}), - @nospecialize(p::Vector{Float64}), @nospecialize(t::dualT)) - f(du, u, p, t) - nothing - end - - function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}), - @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::dualT)) - f(du, u, p, t) - nothing - end - precompile(f, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) - precompile(f, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) - precompile(f, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) - precompile(f, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64)) - precompile(f, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT)) - precompile(f, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT)) - - precompile(f2, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) - precompile(f2, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) - precompile(f2, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) - precompile(f2, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64)) - precompile(f2, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT)) - precompile(f2, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT)) - f2 + Base.Experimental.@opaque (args...)->f(args...) end const oop_returnlists = (Vector{Float64},Vector{Float64}, ntuple(x -> Vector{dualT}, length(arglists)-2)...) function typestablemapping(@nospecialize(f::Function)) - function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}), - @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) - f(u, p, t)::Vector{Float64} - end - - function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}), - @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) - f(u, p, t)::Vector{Float64} - end - - function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}), - @nospecialize(p::Vector{Float64}), @nospecialize(t::Float64)) - f(u, p, t)::Vector{dualT} - end - - function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}), - @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64)) - f(u, p, t)::Vector{dualT} - end - - function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}), - @nospecialize(p::Vector{Float64}), @nospecialize(t::dualT)) - f(u, p, t)::Vector{dualT} - end - - function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}), - @nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::dualT)) - f(u, p, t)::Vector{dualT} - end - precompile(f, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) - precompile(f, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) - precompile(f, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) - precompile(f, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64)) - precompile(f, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT)) - precompile(f, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT)) - - precompile(f2, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64)) - precompile(f2, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64)) - precompile(f2, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64)) - precompile(f2, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64)) - precompile(f2, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT)) - precompile(f2, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT)) - f2 + Base.Experimental.@opaque (args...)->f(args...) end const NORECOMPILE_ARGUMENT_MESSAGE = """ @@ -132,19 +42,11 @@ function Base.showerror(io::IO, e::NoRecompileArgumentError) end function wrapfun_oop(ff, inputs::Tuple) - IT = Tuple{map(typeof, inputs)...} - if IT ∉ NORECOMPILE_SUPPORTED_ARGS - throw(NoRecompileArgumentError(IT)) - end - FunctionWrappersWrappers.FunctionWrappersWrapper(void(ff), arglists, oop_returnlists) + void(ff) end function wrapfun_iip(ff, inputs::Tuple) - IT = Tuple{map(typeof, inputs)...} - if IT ∉ NORECOMPILE_SUPPORTED_ARGS - throw(NoRecompileArgumentError(IT)) - end - FunctionWrappersWrappers.FunctionWrappersWrapper(void(ff), arglists, iip_returnlists) + void(ff) end function unwrap_fw(fw::FunctionWrapper)