-
-
Notifications
You must be signed in to change notification settings - Fork 78
Closed
Description
using RecursiveArrayTools
using OrdinaryDiffEq
using SciMLSensitivity
using Zygote, ForwardDiff
using Test
"""
    fiip(du, u, p, t)

In-place right-hand side of a two-species Lotka–Volterra-form system.

The state lives in the first partition of `u` (`u.x[1]`). The derivatives are
written into the first partition of `du` (`du.x[1]`). `p` holds the four rate
coefficients and `t` is unused. Returns `nothing`.
"""
function fiip(du, u, p, t)
    pop = first(u.x)
    dpop = first(du.x)
    # first component: linear growth minus the interaction term
    dpop[1] = p[1] * pop[1] - p[2] * pop[1] * pop[2]
    # second component: linear decay plus the interaction term
    dpop[2] = -p[3] * pop[2] + p[4] * pop[1] * pop[2]
    return nothing
end
# Parameters, initial state and the reference problem shared by the experiments below.
p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0, 1.0]
prob = ODEProblem(fiip, ArrayPartition(u0), (0.0, 10.0), p)
sol = solve(prob, Tsit5(); abstol = 1e-14, reltol = 1e-14, saveat = 0.1)

# this works: the new u0/p are passed to `solve` as keywords on the existing
# `prob` instead of constructing a new problem.
du01, dp1 = Zygote.gradient(u0, p) do u0_, p_
    sum(solve(prob, Tsit5(); u0 = ArrayPartition(u0_), p = p_,
              abstol = 1e-14, reltol = 1e-14, saveat = 0.1))
end
# Building a fresh problem from (u0, p) fails under Zygote with either route:
#   _prob = remake(prob, u0 = ArrayPartition(u0), p = p)
#   _prob = ODEProblem(fiip, ArrayPartition(u0), (0.0, 10.0), p)   # used below
"""
    loss(u0, p)

Build a new `ODEProblem` for `fiip` from initial state `u0` (wrapped in an
`ArrayPartition`) and parameters `p`. Solve it with `Tsit5` at tolerances 1e-14,
saving every 0.1, and return `sum` of the saved solution.
"""
function loss(u0, p)
    # Float64 time span identical to `prob`'s (rather than the Int tuple `(0, 10)`),
    # so this loss matches the keyword-remapped reference gradient computation.
    _prob = ODEProblem(fiip, ArrayPartition(u0), (0.0, 10.0), p)
    _sol = solve(_prob, Tsit5(); abstol = 1e-14, reltol = 1e-14, saveat = 0.1)
    return sum(_sol)
end
# Reverse-mode (Zygote) vs forward-mode (ForwardDiff) gradients of `loss`.
Zdu0, Zdp = Zygote.gradient(loss, u0, p)
Fdu0 = ForwardDiff.gradient(x -> loss(x, p), u0)
Fdp = ForwardDiff.gradient(θ -> loss(u0, θ), p)
# Both AD backends should agree; `@test` comes from the `using Test` at the top.
# rtol is loosened because the adjoint solve is only accurate to the ODE tolerances.
@test Zdu0 ≈ Fdu0 rtol = 1e-6
@test Zdp ≈ Fdp rtol = 1e-6
The `Zygote.gradient` call on `loss` (the `Zdu0, Zdp = ...` line) fails with:
ERROR: MethodError: no method matching ntuple(::RecursiveArrayTools.var"#96#98"{Vector{Int64}, Vector{Float64}, Tuple{Vector{Float64}}})
Closest candidates are:
ntuple(::F, ::Integer) where F at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/ntuple.jl:17
ntuple(::F, ::Static.StaticInt{N}) where {F, N} at ~/.julia/packages/Static/SlJif/src/Static.jl:513
ntuple(::Any, ::Val{0}) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/ntuple.jl:47
...
Stacktrace:
[1] (::RecursiveArrayTools.var"#ArrayPartition_adjoint#97"{Tuple{Vector{Float64}}})(_y::ArrayPartition{Float64, Tuple{Vector{Float64}}})
@ RecursiveArrayTools ~/.julia/packages/RecursiveArrayTools/WvCPk/src/zygote.jl:22
[2] ZBack
@ ~/.julia/packages/Zygote/IoW2g/src/compiler/chainrules.jl:205 [inlined]
Full stacktrace
ERROR: MethodError: no method matching ntuple(::RecursiveArrayTools.var"#96#98"{Vector{Int64}, Vector{Float64}, Tuple{Vector{Float64}}})
Closest candidates are:
ntuple(::F, ::Integer) where F at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/ntuple.jl:17
ntuple(::F, ::Static.StaticInt{N}) where {F, N} at ~/.julia/packages/Static/SlJif/src/Static.jl:513
ntuple(::Any, ::Val{0}) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/ntuple.jl:47
...
Stacktrace:
[1] (::RecursiveArrayTools.var"#ArrayPartition_adjoint#97"{Tuple{Vector{Float64}}})(_y::ArrayPartition{Float64, Tuple{Vector{Float64}}})
@ RecursiveArrayTools ~/.julia/packages/RecursiveArrayTools/WvCPk/src/zygote.jl:22
[2] ZBack
@ ~/.julia/packages/Zygote/IoW2g/src/compiler/chainrules.jl:205 [inlined]
[3] Pullback
@ ~/.julia/packages/RecursiveArrayTools/WvCPk/src/array_partition.jl:502 [inlined]
[4] (::typeof(∂(convert)))(Δ::ArrayPartition{Float64, Tuple{Vector{Float64}}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[5] Pullback
@ ~/.julia/packages/SciMLBase/TqBga/src/problems/ode_problems.jl:93 [inlined]
[6] (::typeof(∂(_#241)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, ArrayPartition{Float64, Tuple{Vector{Float64}}}, Nothing, Vector{Float64}, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[7] Pullback
@ ~/.julia/packages/SciMLBase/TqBga/src/problems/ode_problems.jl:92 [inlined]
[8] (::typeof(∂(ODEProblem{true})))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, ArrayPartition{Float64, Tuple{Vector{Float64}}}, Nothing, Vector{Float64}, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[9] Pullback
@ ~/.julia/packages/SciMLBase/TqBga/src/problems/ode_problems.jl:92 [inlined]
[10] (::typeof(∂(ODEProblem{true})))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, ArrayPartition{Float64, Tuple{Vector{Float64}}}, Nothing, Vector{Float64}, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[11] (::Zygote.var"#216#217"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(ODEProblem{true}))})(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, ArrayPartition{Float64, Tuple{Vector{Float64}}}, Nothing, Vector{Float64}, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/lib/lib.jl:207
[12] (::Zygote.var"#1909#back#218"{Zygote.var"#216#217"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(ODEProblem{true}))}})(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, ArrayPartition{Float64, Tuple{Vector{Float64}}}, Nothing, Vector{Float64}, Nothing, Nothing}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[13] Pullback
@ ~/.julia/packages/SciMLBase/TqBga/src/problems/ode_problems.jl:131 [inlined]
[14] (::typeof(∂(#ODEProblem#251)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, ArrayPartition{Float64, Tuple{Vector{Float64}}}, Nothing, Vector{Float64}, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[15] (::Zygote.var"#216#217"{Tuple{NTuple{5, Nothing}, Tuple{Nothing}}, typeof(∂(#ODEProblem#251))})(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, ArrayPartition{Float64, Tuple{Vector{Float64}}}, Nothing, Vector{Float64}, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/lib/lib.jl:207
[16] #1909#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[17] Pullback
@ ~/.julia/packages/SciMLBase/TqBga/src/problems/ode_problems.jl:131 [inlined]
[18] (::typeof(∂(ODEProblem)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, ArrayPartition{Float64, Tuple{Vector{Float64}}}, Nothing, Vector{Float64}, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[19] Pullback
@ ~/.julia/packages/SciMLBase/TqBga/src/problems/ode_problems.jl:135 [inlined]
[20] (::typeof(∂(#ODEProblem#252)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, ArrayPartition{Float64, Tuple{Vector{Float64}}}, Nothing, Vector{Float64}, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[21] Pullback
@ ~/.julia/packages/SciMLBase/TqBga/src/problems/ode_problems.jl:135 [inlined]
[22] (::typeof(∂(ODEProblem)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, ArrayPartition{Float64, Tuple{Vector{Float64}}}, Nothing, Vector{Float64}, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
...
Metadata
Metadata
Assignees
Labels
No labels