From 14e3e113ccf83eb884027ab9c1fecb3cce84f299 Mon Sep 17 00:00:00 2001 From: Frank Schaefer Date: Thu, 4 Aug 2022 12:33:15 -0400 Subject: [PATCH 1/4] add a sensitive test --- test/adjoints.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/adjoints.jl b/test/adjoints.jl index 406267a9..d26547f6 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -1,4 +1,5 @@ using RecursiveArrayTools, Zygote, ForwardDiff, Test +using OrdinaryDiffEq function loss(x) sum(abs2,Array(VectorOfArray([x .* i for i in 1:5]))) @@ -30,6 +31,12 @@ function loss5(x) sum(abs2,Array(ArrayPartition([x .* i for i in 1:5]...))) end +function loss6(x) + _x = ArrayPartition([x .* i for i in 1:5]...) + _prob = ODEProblem((u,p,t)->u, _x, (0,1)) + sum(abs2, Array(_prob.u0)) +end + x = float.(6:10) loss(x) @test Zygote.gradient(loss,x)[1] == ForwardDiff.gradient(loss,x) @@ -37,3 +44,4 @@ loss(x) @test Zygote.gradient(loss3,x)[1] == ForwardDiff.gradient(loss3,x) @test Zygote.gradient(loss4,x)[1] == ForwardDiff.gradient(loss4,x) @test Zygote.gradient(loss5,x)[1] == ForwardDiff.gradient(loss5,x) +@test Zygote.gradient(loss6,x)[1] == ForwardDiff.gradient(loss6,x) From 58182c332704d02e95ec7b8715a94f33ab226dba Mon Sep 17 00:00:00 2001 From: Frank Schaefer Date: Thu, 4 Aug 2022 12:38:23 -0400 Subject: [PATCH 2/4] fix ntuple bracket --- src/zygote.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zygote.jl b/src/zygote.jl index 5f14889c..623eeed8 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -19,7 +19,7 @@ function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x} function ArrayPartition_adjoint(_y) y = Array(_y) starts = vcat(0,cumsum(reduce(vcat,length.(x)))) - NoTangent(), ArrayPartition(ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i]))), length(x)), NoTangent() + NoTangent(), ArrayPartition(ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x))), NoTangent() end ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint From c3644f540628027d5db189432c84f7112b534ba3 Mon Sep 17 00:00:00 2001 From: Frank Schaefer Date: Fri, 5 Aug 2022 17:11:53 -0400 Subject: [PATCH 3/4] remove ArrayPartition vom rrule --- src/zygote.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zygote.jl b/src/zygote.jl index 623eeed8..1617b2e5 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -19,7 +19,7 @@ function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x} function ArrayPartition_adjoint(_y) y = Array(_y) starts = vcat(0,cumsum(reduce(vcat,length.(x)))) - NoTangent(), ArrayPartition(ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x))), NoTangent() + NoTangent(), (ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x))), NoTangent() end ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint From ddff8901dfd7c813f6d495b67f841ee55253cf9a Mon Sep 17 00:00:00 2001 From: Frank Schaefer Date: Tue, 9 Aug 2022 11:02:29 -0400 Subject: [PATCH 4/4] remove unnecessary brackets --- src/zygote.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zygote.jl b/src/zygote.jl index 1617b2e5..8b795c40 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -19,7 +19,7 @@ function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x} function ArrayPartition_adjoint(_y) y = Array(_y) starts = vcat(0,cumsum(reduce(vcat,length.(x)))) - NoTangent(), (ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x))), NoTangent() + NoTangent(), ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), NoTangent() end ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint