Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c23d5bc

Browse files
committedMar 22, 2022
Setup FunctionWrappersWrappers norecompile mode
Needs: - SciML/SciMLBase.jl#143 - SciML/OrdinaryDiffEq.jl#1627 ```julia using OrdinaryDiffEq function f(du, u, p, t) du[1] = 0.2u[1] du[2] = 0.4u[2] end u0 = ones(2) tspan = (0.0, 1.0) prob = ODEProblem{true,false}(f, u0, tspan, Float64[]) function lorenz(du, u, p, t) du[1] = 10.0(u[2] - u[1]) du[2] = u[1] * (28.0 - u[3]) - u[2] du[3] = u[1] * u[2] - (8 / 3) * u[3] end lorenzprob = ODEProblem{true,false}(lorenz, [1.0; 0.0; 0.0], (0.0, 1.0), Float64[]) typeof(prob) === typeof(lorenzprob) # true @time sol = solve(prob, Rosenbrock23(autodiff=false)) @time sol = solve(prob, Rosenbrock23(chunk_size=1)) ``` ``` 2.763588 seconds (10.32 M allocations: 648.718 MiB, 4.92% gc time, 99.89% compilation time) 10.577789 seconds (45.44 M allocations: 2.760 GiB, 4.87% gc time, 99.97% compilation time) ``` While the types of `prob` are exactly the same, there is still a significant amount of compile time, even with that exact same time being called in `using` at OrdinaryDiffEq. Maybe this needs to be run on master?
1 parent 9aa4d87 commit c23d5bc

File tree

3 files changed

+59
-0
lines changed

3 files changed

+59
-0
lines changed
 

‎Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1313
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1414
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1515
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
16+
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
1617
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
1718
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1819
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

‎src/DiffEqBase.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ using Setfield
3838

3939
using ForwardDiff
4040

41+
import FunctionWrappersWrappers
4142
@reexport using SciMLBase
4243

4344
using SciMLBase: @def, DEIntegrator, DEProblem, AbstractDiffEqOperator,
@@ -127,6 +128,7 @@ include("forwarddiff.jl")
127128
include("chainrules.jl")
128129

129130
include("precompile.jl")
131+
include("norecompile.jl")
130132

131133
"""
132134
$(TYPEDEF)

‎src/norecompile.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
struct OrdinaryDiffEqTag end
2+
3+
const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag,Float64},Float64,1}
4+
const arglists = (Tuple{Vector{Float64},Vector{Float64},Vector{Float64},Float64},
5+
Tuple{Vector{Float64},Vector{Float64},SciMLBase.NullParameters,Float64},
6+
Tuple{Vector{dualT},Vector{Float64},Vector{Float64},dualT},
7+
Tuple{Vector{dualT},Vector{dualT},Vector{Float64},Float64},
8+
Tuple{Vector{dualT},Vector{dualT},SciMLBase.NullParameters,Float64},
9+
Tuple{Vector{dualT},Vector{Float64},SciMLBase.NullParameters,dualT})
10+
const returnlists = ntuple(x -> Nothing, length(arglists))
11+
void(f) = function (du, u, p, t)
12+
f(du, u, p, t)
13+
nothing
14+
end
15+
const NORECOMPILE_FUNCTION = typeof(FunctionWrappersWrappers.FunctionWrappersWrapper(void(() -> nothing), arglists, returnlists))
16+
wrap_norecompile(f) = FunctionWrappersWrappers.FunctionWrappersWrapper(void(f), arglists, returnlists)
17+
18+
function ODEFunction{iip,false}(f;
19+
mass_matrix=I,
20+
analytic=nothing,
21+
tgrad=nothing,
22+
jac=nothing,
23+
jvp=nothing,
24+
vjp=nothing,
25+
jac_prototype=nothing,
26+
sparsity=jac_prototype,
27+
Wfact=nothing,
28+
Wfact_t=nothing,
29+
paramjac=nothing,
30+
syms=nothing,
31+
indepsym=nothing,
32+
observed=SciMLBase.DEFAULT_OBSERVED,
33+
colorvec=nothing) where {iip}
34+
35+
if jac === nothing && isa(jac_prototype, AbstractDiffEqLinearOperator)
36+
if iip
37+
jac = update_coefficients! #(J,u,p,t)
38+
else
39+
jac = (u, p, t) -> update_coefficients!(deepcopy(jac_prototype), u, p, t)
40+
end
41+
end
42+
43+
if jac_prototype !== nothing && colorvec === nothing && ArrayInterface.fast_matrix_colors(jac_prototype)
44+
_colorvec = ArrayInterface.matrix_colors(jac_prototype)
45+
else
46+
_colorvec = colorvec
47+
end
48+
49+
ODEFunction{iip,
50+
NORECOMPILE_FUNCTION,Any,Any,Any,Any,
51+
Any,Any,Any,Any,Any,
52+
Any,Any,typeof(syms),typeof(indepsym),Any,typeof(_colorvec)}(
53+
wrap_norecompile(f), mass_matrix, analytic, tgrad, jac,
54+
jvp, vjp, jac_prototype, sparsity, Wfact,
55+
Wfact_t, paramjac, syms, indepsym, observed, _colorvec)
56+
end

0 commit comments

Comments
 (0)
Please sign in to comment.