diff --git a/src/zygote.jl b/src/zygote.jl index 5f14889c..8b795c40 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -19,7 +19,7 @@ function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x} function ArrayPartition_adjoint(_y) y = Array(_y) starts = vcat(0,cumsum(reduce(vcat,length.(x)))) - NoTangent(), ArrayPartition(ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i]))), length(x)), NoTangent() + NoTangent(), ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), NoTangent() end ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint diff --git a/test/adjoints.jl b/test/adjoints.jl index 406267a9..d26547f6 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -1,4 +1,5 @@ using RecursiveArrayTools, Zygote, ForwardDiff, Test +using OrdinaryDiffEq function loss(x) sum(abs2,Array(VectorOfArray([x .* i for i in 1:5]))) @@ -30,6 +31,12 @@ function loss5(x) sum(abs2,Array(ArrayPartition([x .* i for i in 1:5]...))) end +function loss6(x) + _x = ArrayPartition([x .* i for i in 1:5]...) + _prob = ODEProblem((u,p,t)->u, _x, (0,1)) + sum(abs2, Array(_prob.u0)) +end + x = float.(6:10) loss(x) @test Zygote.gradient(loss,x)[1] == ForwardDiff.gradient(loss,x) @@ -37,3 +44,4 @@ loss(x) @test Zygote.gradient(loss3,x)[1] == ForwardDiff.gradient(loss3,x) @test Zygote.gradient(loss4,x)[1] == ForwardDiff.gradient(loss4,x) @test Zygote.gradient(loss5,x)[1] == ForwardDiff.gradient(loss5,x) +@test Zygote.gradient(loss6,x)[1] == ForwardDiff.gradient(loss6,x)