Skip to content

Commit b258ebd

Browse files
replace main Zygote.adjoints with ChainRules rrules
Part of SciML/SciMLBase.jl#69 And needs to be done with SciML/SciMLSensitivity.jl#428 But currently getting: ```julia using DiffEqSensitivity, OrdinaryDiffEq, Zygote function fiip(du,u,p,t) du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2] end function foop(u,p,t) dx = p[1]*u[1] - p[2]*u[1]*u[2] dy = -p[3]*u[2] + p[4]*u[1]*u[2] [dx,dy] end p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0] prob = ODEProblem(fiip,u0,(0.0,10.0),p) du01,dp1 = Zygote.gradient((u0,p)->sum(solve(prob,Tsit5(),u0=u0,p=p,abstol=1e-14, reltol=1e-14,saveat=0.1,sensealg=QuadratureAdjoint())),u0,p) ``` ```julia ArgumentError: tuple must be non-empty first(#unused#::Tuple{}) at tuple.jl:134 _unapply(t::Nothing, xs::Tuple{}) at lib.jl:163 _unapply(t::Tuple{Nothing}, xs::Tuple{}) at lib.jl:167 _unapply(t::Tuple{Tuple{Nothing}}, xs::Tuple{}) at lib.jl:167 _unapply(t::Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, xs::Tuple{Nothing, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Nothing}) at lib.jl:168 unapply(t::Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, xs::Tuple{Nothing, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Nothing}) at lib.jl:177 #193 at lib.jl:195 [inlined] (::Zygote.var"#1713#back#195"{Zygote.var"#193#194"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.var"#kw_zpullback#40"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#179"{Base.Iterators.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:abstol, :reltol), Tuple{Float64, Float64}}}, Tsit5, QuadratureAdjoint{0, true, Val{:central}, Bool}, Vector{Float64}, Vector{Float64}, Tuple{}, Colon, NamedTuple{(:abstol, :reltol), Tuple{Float64, Float64}}}}}})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}) at adjoint.jl:59 Pullback at solve.jl:70 [inlined] (::typeof(∂(#solve#59)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}) at interface2.jl:0 (::Zygote.var"#193#194"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#59))})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}) at lib.jl:194 (::Zygote.var"#1713#back#195"{Zygote.var"#193#194"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#59))}})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}) at adjoint.jl:59 Pullback at solve.jl:68 [inlined] (::typeof(∂(solve##kw)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}) at interface2.jl:0 Pullback at test.jl:14 [inlined] (::typeof(∂(#7)))(Δ::Float64) at interface2.jl:0 (::Zygote.var"#46#47"{typeof(∂(#7))})(Δ::Float64) at interface.jl:41 gradient(::Function, ::Vector{Float64}, ::Vararg{Vector{Float64}, N} where N) at interface.jl:59 top-level scope at test.jl:14 eval at boot.jl:360 [inlined] ```
1 parent 8f9b220 commit b258ebd

File tree

1 file changed

+1
-15
lines changed

1 file changed

+1
-15
lines changed

src/solve.jl

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -287,20 +287,14 @@ end
287287

288288
struct SensitivityADPassThrough <: SciMLBase.DEAlgorithm end
289289

290-
ZygoteRules.@adjoint function solve_up(prob,sensealg::Union{Nothing,AbstractSensitivityAlgorithm},
291-
u0,p,args...;
292-
kwargs...)
293-
_solve_adjoint(prob,sensealg,u0,p,args...;kwargs...)
294-
end
295-
296290
function ChainRulesCore.frule(::typeof(solve_up),prob,
297291
sensealg::Union{Nothing,AbstractSensitivityAlgorithm},
298292
u0,p,args...;
299293
kwargs...)
300294
_solve_forward(prob,sensealg,u0,p,args...;kwargs...)
301295
end
302296

303-
function ChainRulesCore.rrule(::typeof(solve_up),prob,
297+
function ChainRulesCore.rrule(::typeof(solve_up),prob::SciMLBase.DEProblem,
304298
sensealg::Union{Nothing,AbstractSensitivityAlgorithm},
305299
u0,p,args...;
306300
kwargs...)
@@ -314,14 +308,6 @@ end
314308
@deprecate concrete_solve(prob::SciMLBase.DEProblem,alg::Union{SciMLBase.DEAlgorithm,Nothing},
315309
u0=prob.u0,p=prob.p,args...;kwargs...) solve(prob,alg,args...;u0=u0,p=p,kwargs...)
316310

317-
ZygoteRules.@adjoint function concrete_solve(prob::SciMLBase.DEProblem,
318-
alg::Union{SciMLBase.DEAlgorithm,Nothing},
319-
u0=prob.u0,p=prob.p,args...;
320-
sensealg=nothing,
321-
kwargs...)
322-
_concrete_solve_adjoint(prob,alg,sensealg,u0,p,args...;kwargs...)
323-
end
324-
325311
function _solve_adjoint(prob,sensealg,u0,p,args...;kwargs...)
326312
if isempty(args)
327313
_concrete_solve_adjoint(prob,nothing,sensealg,u0,p;kwargs...)

0 commit comments

Comments
 (0)