Skip to content

Commit c0e7570

Browse files
fix oop handling
1 parent 5da8e11 commit c0e7570

File tree

1 file changed

+54
-38
lines changed

1 file changed

+54
-38
lines changed

src/norecompile.jl

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
struct OrdinaryDiffEqTag end
22

33
const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1}
4-
const NORECOMPILE_SUPPORTED_ARGS = (Tuple{Vector{Float64}, Vector{Float64},
5-
Vector{Float64}, Float64},
6-
Tuple{Vector{Float64}, Vector{Float64},
7-
SciMLBase.NullParameters, Float64})
8-
const arglists = (Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64},
9-
Tuple{Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64
10-
},
11-
Tuple{Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT},
12-
Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64},
13-
Tuple{Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64},
14-
Tuple{Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT})
15-
const iip_returnlists = ntuple(x -> Nothing, length(arglists))
4+
const NORECOMPILE_IIP_SUPPORTED_ARGS = (Tuple{Vector{Float64}, Vector{Float64},
5+
Vector{Float64}, Float64},
6+
Tuple{Vector{Float64}, Vector{Float64},
7+
SciMLBase.NullParameters, Float64})
8+
const iip_arglists = (Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64},
9+
Tuple{Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters,
10+
Float64
11+
},
12+
Tuple{Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT},
13+
Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64},
14+
Tuple{Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64
15+
},
16+
Tuple{Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT
17+
})
18+
const iip_returnlists = ntuple(x -> Nothing, length(iip_arglists))
1619
function void(@nospecialize(f::Function))
1720
function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}),
1821
@nospecialize(p::Vector{Float64}), @nospecialize(t::Float64))
@@ -65,52 +68,63 @@ function void(@nospecialize(f::Function))
6568
f2
6669
end
6770

68-
const oop_returnlists = (Vector{Float64},Vector{Float64},
69-
ntuple(x -> Vector{dualT}, length(arglists)-2)...)
71+
const oop_arglists = (Tuple{Vector{Float64}, Vector{Float64}, Float64},
72+
Tuple{Vector{Float64}, SciMLBase.NullParameters, Float64},
73+
Tuple{Vector{Float64}, Vector{Float64}, dualT},
74+
Tuple{Vector{dualT}, Vector{Float64}, Float64},
75+
Tuple{Vector{dualT}, SciMLBase.NullParameters, Float64},
76+
Tuple{Vector{Float64}, SciMLBase.NullParameters, dualT})
77+
78+
const NORECOMPILE_OOP_SUPPORTED_ARGS = (Tuple{Vector{Float64},
79+
Vector{Float64}, Float64},
80+
Tuple{Vector{Float64},
81+
SciMLBase.NullParameters, Float64})
82+
const oop_returnlists = (Vector{Float64}, Vector{Float64},
83+
ntuple(x -> Vector{dualT}, length(oop_arglists) - 2)...)
7084

7185
function typestablemapping(@nospecialize(f::Function))
72-
function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}),
86+
function f2(@nospecialize(u::Vector{Float64}),
7387
@nospecialize(p::Vector{Float64}), @nospecialize(t::Float64))
7488
f(u, p, t)::Vector{Float64}
7589
end
7690

77-
function f2(@nospecialize(du::Vector{Float64}), @nospecialize(u::Vector{Float64}),
91+
function f2(@nospecialize(u::Vector{Float64}),
7892
@nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64))
7993
f(u, p, t)::Vector{Float64}
8094
end
8195

82-
function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}),
96+
function f2(@nospecialize(u::Vector{dualT}),
8397
@nospecialize(p::Vector{Float64}), @nospecialize(t::Float64))
8498
f(u, p, t)::Vector{dualT}
8599
end
86100

87-
function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{dualT}),
101+
function f2(@nospecialize(u::Vector{dualT}),
88102
@nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::Float64))
89103
f(u, p, t)::Vector{dualT}
90104
end
91105

92-
function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}),
106+
function f2(@nospecialize(u::Vector{Float64}),
93107
@nospecialize(p::Vector{Float64}), @nospecialize(t::dualT))
94108
f(u, p, t)::Vector{dualT}
95109
end
96110

97-
function f2(@nospecialize(du::Vector{dualT}), @nospecialize(u::Vector{Float64}),
111+
function f2(@nospecialize(u::Vector{Float64}),
98112
@nospecialize(p::SciMLBase.NullParameters), @nospecialize(t::dualT))
99113
f(u, p, t)::Vector{dualT}
100114
end
101-
precompile(f, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64))
102-
precompile(f, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64))
103-
precompile(f, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64))
104-
precompile(f, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64))
105-
precompile(f, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT))
106-
precompile(f, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT))
107-
108-
precompile(f2, (Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64))
109-
precompile(f2, (Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64))
110-
precompile(f2, (Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64))
111-
precompile(f2, (Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, Float64))
112-
precompile(f2, (Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT))
113-
precompile(f2, (Vector{dualT}, Vector{Float64}, SciMLBase.NullParameters, dualT))
115+
precompile(f, (Vector{Float64}, Vector{Float64}, Float64))
116+
precompile(f, (Vector{Float64}, SciMLBase.NullParameters, Float64))
117+
precompile(f, (Vector{dualT}, Vector{Float64}, Float64))
118+
precompile(f, (Vector{dualT}, SciMLBase.NullParameters, Float64))
119+
precompile(f, (Vector{Float64}, Vector{Float64}, dualT))
120+
precompile(f, (Vector{Float64}, SciMLBase.NullParameters, dualT))
121+
122+
precompile(f2, (Vector{Float64}, Vector{Float64}, Float64))
123+
precompile(f2, (Vector{Float64}, SciMLBase.NullParameters, Float64))
124+
precompile(f2, (Vector{dualT}, Vector{Float64}, Float64))
125+
precompile(f2, (Vector{dualT}, SciMLBase.NullParameters, Float64))
126+
precompile(f2, (Vector{Float64}, Vector{Float64}, dualT))
127+
precompile(f2, (Vector{Float64}, SciMLBase.NullParameters, dualT))
114128
f2
115129
end
116130

@@ -122,7 +136,7 @@ const NORECOMPILE_ARGUMENT_MESSAGE = """
122136
"""
123137

124138
struct NoRecompileArgumentError <: Exception
125-
args
139+
args::Any
126140
end
127141

128142
function Base.showerror(io::IO, e::NoRecompileArgumentError)
@@ -133,18 +147,20 @@ end
133147

134148
function wrapfun_oop(ff, inputs::Tuple)
135149
IT = Tuple{map(typeof, inputs)...}
136-
if IT NORECOMPILE_SUPPORTED_ARGS
150+
if IT NORECOMPILE_OOP_SUPPORTED_ARGS
137151
throw(NoRecompileArgumentError(IT))
138152
end
139-
FunctionWrappersWrappers.FunctionWrappersWrapper(void(ff), arglists, oop_returnlists)
153+
FunctionWrappersWrappers.FunctionWrappersWrapper(typestablemapping(ff), oop_arglists,
154+
oop_returnlists)
140155
end
141156

142157
function wrapfun_iip(ff, inputs::Tuple)
143158
IT = Tuple{map(typeof, inputs)...}
144-
if IT NORECOMPILE_SUPPORTED_ARGS
159+
if IT NORECOMPILE_IIP_SUPPORTED_ARGS
145160
throw(NoRecompileArgumentError(IT))
146161
end
147-
FunctionWrappersWrappers.FunctionWrappersWrapper(void(ff), arglists, iip_returnlists)
162+
FunctionWrappersWrappers.FunctionWrappersWrapper(void(ff), iip_arglists,
163+
iip_returnlists)
148164
end
149165

150166
function unwrap_fw(fw::FunctionWrapper)

0 commit comments

Comments
 (0)