-
-
Notifications
You must be signed in to change notification settings - Fork 68
Differentiating ArrayPartitions in DEProblems #221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
using Zygote, ForwardDiff
using ChainRules, ChainRulesCore
using RecursiveArrayTools
u0 = [1.0; 1.0]
function f(x)
_x = ArrayPartition(x, x)
sum(abs, _x.x[1])
end
f(u0)
Fdx = ForwardDiff.gradient(f, u0)
Zdx = Zygote.gradient(f, u0)[1] # works
struct foo{uType}
u0::uType
end
# function ChainRulesCore.rrule(::Type{foo}, u0)
# foo_pullback(Δfoo) = NoTangent(), Δfoo.u0
# return foo(u0), foo_pullback
# end
function f2(x)
_x = ArrayPartition(x, x)
_prob = foo(_x)
sum(abs, _prob.u0[1])
end
f2(u0)
Fdx = ForwardDiff.gradient(f2, u0)
Zdx = Zygote.gradient(f2, u0)[1]
using DiffEqBase
function f3(x)
_x = ArrayPartition(x, x)
_prob = ODEProblem((u, p, t) -> u, _x, (0.0, 1.0))
sum(abs, _prob.u0[1])
end
f3(u0)
Fdx = ForwardDiff.gradient(f3, u0)
Zdx = Zygote.gradient(f3, u0)[1] If I use the rrule for the constructor of Zdx = Zygote.gradient(f2, u0)[1]
Zdx = Zygote.gradient(f3, u0)[1] throw the same error: ERROR: Can't differentiate foreigncall expression.
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/dev/limitations.html
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] Pullback
@ ./essentials.jl:599 [inlined]
[3] (::typeof(∂(getindex)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[4] Pullback
@ ~/.julia/dev/RecursiveArrayTools/src/array_partition.jl:503 [inlined]
[5] Pullback
@ ./ntuple.jl:49 [inlined]
[6] Pullback
@ ~/.julia/dev/RecursiveArrayTools/src/array_partition.jl:502 [inlined]
[7] (::typeof(∂(convert)))(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[8] Pullback
@ ~/SciMLSensitivity/689/MWE.jl:21 [inlined]
[9] (::typeof(∂(foo{ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}})))(Δ::NamedTuple{(:u0,), Tuple{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[10] Pullback
@ ~/SciMLSensitivity/689/MWE.jl:21 [inlined]
[11] (::typeof(∂(foo)))(Δ::NamedTuple{(:u0,), Tuple{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[12] Pullback
@ ~/SciMLSensitivity/689/MWE.jl:31 [inlined]
[13] (::typeof(∂(f2)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[14] (::Zygote.var"#60#61"{typeof(∂(f2))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:41
[15] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:76
[16] top-level scope
@ ~/SciMLSensitivity/689/MWE.jl:37 ERROR: Can't differentiate foreigncall expression.
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/dev/limitations.html
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] Pullback
@ ./essentials.jl:599 [inlined]
[3] (::typeof(∂(getindex)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[4] Pullback
@ ~/.julia/dev/RecursiveArrayTools/src/array_partition.jl:503 [inlined]
[5] Pullback
@ ./ntuple.jl:49 [inlined]
[6] Pullback
@ ~/.julia/dev/RecursiveArrayTools/src/array_partition.jl:502 [inlined]
[7] (::typeof(∂(convert)))(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:93 [inlined]
[9] (::typeof(∂(_#241)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[10] Pullback
@ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:92 [inlined]
[11] (::typeof(∂(ODEProblem{false})))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[12] Pullback
@ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:92 [inlined]
[13] (::typeof(∂(ODEProblem{false})))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[14] (::Zygote.var"#216#217"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(ODEProblem{false}))})(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/lib/lib.jl:207
[15] (::Zygote.var"#1909#back#218"{Zygote.var"#216#217"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(ODEProblem{false}))}})(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[16] Pullback
@ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:131 [inlined]
[17] (::typeof(∂(#ODEProblem#251)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[18] (::Zygote.var"#216#217"{Tuple{NTuple{5, Nothing}, Tuple{Nothing}}, typeof(∂(#ODEProblem#251))})(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/lib/lib.jl:207
[19] #1909#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[20] Pullback
@ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:131 [inlined]
[21] (::typeof(∂(ODEProblem)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[22] Pullback
@ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:135 [inlined]
[23] (::typeof(∂(#ODEProblem#252)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[24] Pullback
@ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:135 [inlined]
[25] (::typeof(∂(ODEProblem)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[26] Pullback
@ ~/.julia/packages/SciMLBase/QzHjf/src/problems/ode_problems.jl:135 [inlined]
[27] (::typeof(∂(ODEProblem)))(Δ::NamedTuple{(:f, :u0, :tspan, :p, :kwargs, :problem_type), Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing, Nothing, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[28] Pullback
@ ~/SciMLSensitivity/689/MWE.jl:43 [inlined]
[29] (::typeof(∂(f3)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
[30] (::Zygote.var"#60#61"{typeof(∂(f3))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:41
[31] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:76
[32] top-level scope
@ ~/SciMLSensitivity/689/MWE.jl:49 @ChrisRackauckas where is the rrule for using Zygote, ForwardDiff
using ChainRules, ChainRulesCore
using RecursiveArrayTools
u0 = [1.0; 1.0]
struct foo{uType}
u0::uType
end
# function ChainRulesCore.rrule(::Type{foo}, u0)
# foo_pullback(Δfoo) = NoTangent(), Δfoo.u0
# return foo(u0), foo_pullback
# end
function f2(x)
_x = ArrayPartition(x, x)
_prob = foo(_x)
sum(abs, _prob.u0[1])
end
f2(u0)
Fdx = ForwardDiff.gradient(f2, u0)
Zdx = Zygote.gradient(f2, u0)[1]
|
https://github.com/SciML/DiffEqBase.jl/blob/master/src/chainrules.jl |
Related issue: SciML/SciMLSensitivity.jl#689
The reason why some ArrayPartition tests are working is that the
@code_typed
from Zygote uses the ZygoteRulesThe failing ones instead use the ChainRules definition