From 57a3fa26cbf3a0478142e93ec16ebb25a62b5847 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 25 Mar 2022 21:18:26 -0400
Subject: [PATCH 1/7] attempt 62

---
 src/destructure.jl  |  6 +++---
 test/destructure.jl | 47 ++++++++++++++++++++++++++++++++++++++++++---
 test/runtests.jl    | 17 ++++++++++++++++
 3 files changed, 64 insertions(+), 6 deletions(-)

diff --git a/src/destructure.jl b/src/destructure.jl
index 2b91983d..15a4bb64 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -91,7 +91,7 @@ _getat(y::AbstractArray, o::Int, flat::AbstractVector) =
 
 function _trainable_biwalk(f, x, aux)
   ch, re = functor(typeof(x), x)
-  au, _ = functor(typeof(x), aux)
+  au, _ = functor(aux) 
   _trainmap(f, ch, _trainable(x), au) |> re
 end
 
@@ -103,7 +103,7 @@ end
 
 function _Tangent_biwalk(f, x, aux)  # use with prune = NoT
   ch, re = functor(typeof(x), x)
-  au, _ = functor(typeof(x), aux)
+  au, _ = functor(aux)
   y = _trainmap(f, ch, _trainable(x), au)
   y isa Tuple{} && return NoT
   p = ProjectTo(x)
@@ -126,7 +126,7 @@ ChainRulesCore.@non_differentiable _zero(x)
 function _grad!(x, dx, off, flat::AbstractVector)
   x′, _ = functor(typeof(x), x)
   dx′, _ = functor(typeof(x), base(dx))
-  off′, _ = functor(typeof(x), off)
+  off′, _ = functor(off)
   foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
   flat
 end
diff --git a/test/destructure.jl b/test/destructure.jl
index 043315b3..fe5699cb 100644
--- a/test/destructure.jl
+++ b/test/destructure.jl
@@ -49,7 +49,7 @@ m9 = (a = m1, b = mat, c = [mat, m1])
   m8′ = destructure(m8)[2](1:5)
   @test m8′[1].x === m8′[1].y
   @test m8′[2].b.y === false
-  @test m8′[3][1] == [5.0]
+  @test m8′[3][1] == [5.0] # broken
 
   m9′ = destructure(m9)[2](10:10:70)
   @test m9′.b === m9′.c[1]
@@ -79,7 +79,7 @@ end
   g8 = gradient(m -> sum(abs2, destructure(m)[1]), m8)[1]
   @test g8[1].x == [2,4,6]
   @test g8[2].b.x == [8]
-  @test g8[3] == [[10.0]]
+  @test g8[3] == [[10.0]]  # fails
 
   g9 = gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1]
   @test g9.c === nothing
@@ -130,7 +130,7 @@ end
 
   v8, re8 = destructure(m8)
   @test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0]
-  @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]
+  @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]  # fails
 
   re9 = destructure(m9)[2]
   @test gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14]
@@ -180,3 +180,44 @@ end
     4(sum(m.x) + sum(m.y)) + 13*sum(m.z)  # again two gradients are ===, so it eliminates one
   end == ([17,17,4,4],)  # Flux gave ([4.0, 4.0, 13.0, 13.0],)
 end
+
+@testset "issue 62" begin
+  # Flux.Chain used to have children which aren't its own fields, which Skip immitates.
+
+  sk = Skip([1.0, 2.0], (x=3, y=[4.0, 5.0]))
+  @test fmap(identity, sk) == sk
+
+  gk = gradient(x -> sum(x[2].y), sk)[1]
+  @test fmap(Zygote.accum, sk, gk) isa Skip  # this relies on functor(typeof(x), dx)
+
+  st = fmapstructure(identity, sk)
+  @test st isa Tuple{Vector, NamedTuple}
+  @test_throws Exception fmap(+, sk, st)  # this fails because of functor(typeof(x), dx)
+
+  v, re = destructure(sk)
+  @test v == [1,2,4,5]
+  @test re(10v) isa Skip
+  @test re(10v)[1] == [10, 20]
+
+  @test gradient(zero(v)) do w
+    re(w)[2].y[1]
+  end == ([0,0,1,0],)
+
+  # gradient(sk) do x
+  #   w, _ = destructure(x)
+  #   w[1]
+  # end
+#=
+
+ERROR: ArgumentError: Tangent for the primal Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}} should be backed by a NamedTuple type, not by Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}}.
+Stacktrace:
+  [1] _backing_error(P::Type, G::Type, E::Type)
+    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_types/tangent.jl:62
+  [2] ChainRulesCore.Tangent{Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}}, Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}}}(backing::Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}})
+    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_types/tangent.jl:36
+  [3] _Tangent_biwalk(f::Function, x::Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}}, aux::Tuple{Int64, NamedTuple{(:x, :y), Tuple{Tuple{}, Int64}}})
+    @ Optimisers ~/.julia/dev/Optimisers/src/destructure.jl:116
+
+=#
+
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index d47bce08..1a54c5e4 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -13,6 +13,13 @@ struct TwoThirds a; b; c; end
 Functors.@functor TwoThirds (a, c)
 Optimisers.trainable(x::TwoThirds) = (a = x.a,)
 
+struct Skip{T}  # like Flux 0.12's Chain
+  layers::T
+  Skip(ls...) = new{typeof(ls)}(ls)
+end
+Base.getindex(x::Skip, i::Integer) = x.layers[i]
+Functors.functor(::Type{<:Skip}, x) = x.layers, ls -> Skip(ls...)
+
 @testset verbose=true "Optimisers.jl" begin
   @testset verbose=true "Features" begin
 
@@ -165,6 +172,16 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
       @test_throws ArgumentError Optimisers.setup(ADAMW(), m2)
     end
 
+    @testset "issue 62" begin
+      m62 = (s = Skip([1.0, 2.0], Foo([3.0], false)), t = [4.0, 5.0])
+      s62 = Optimisers.setup(Descent(), m62)
+      g62 = gradient(m -> m.s[2].x[1] + 3 * m.t[2], m62)
+      s, m = Optimisers.update(s62, m62, g62...)
+      @test m.s isa Skip
+      @test m.s[2].x ≈ [2.9]
+      @test m.t ≈ [4, 4.7]
+    end
+
   end
   @testset verbose=true "Destructure" begin
     include("destructure.jl")

From 81e41fe69eb550e618cf4e5b9639962d5d048d9f Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 25 Mar 2022 22:29:14 -0400
Subject: [PATCH 2/7] next idea

---
 src/destructure.jl  |  9 ++++++---
 test/destructure.jl | 12 ++++++------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/src/destructure.jl b/src/destructure.jl
index 15a4bb64..36928134 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -91,10 +91,13 @@ _getat(y::AbstractArray, o::Int, flat::AbstractVector) =
 
 function _trainable_biwalk(f, x, aux)
   ch, re = functor(typeof(x), x)
-  au, _ = functor(aux) 
+  au = _aux_children(aux) 
   _trainmap(f, ch, _trainable(x), au) |> re
 end
 
+_aux_children(off) = functor(off)[1]
+_aux_children(off::AbstractArray) = off  # leaflike according to Functors, but we need to see each offset
+
 function _trainmap(f, ch, tr, aux)
   map(ch, tr, aux) do c, t, a  # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
     isnothing(t) ? c : f(t, a)
@@ -103,7 +106,7 @@ end
 
 function _Tangent_biwalk(f, x, aux)  # use with prune = NoT
   ch, re = functor(typeof(x), x)
-  au, _ = functor(aux)
+  au = _aux_children(aux)
   y = _trainmap(f, ch, _trainable(x), au)
   y isa Tuple{} && return NoT
   p = ProjectTo(x)
@@ -126,7 +129,7 @@ ChainRulesCore.@non_differentiable _zero(x)
 function _grad!(x, dx, off, flat::AbstractVector)
   x′, _ = functor(typeof(x), x)
   dx′, _ = functor(typeof(x), base(dx))
-  off′, _ = functor(off)
+  off′ = _aux_children(off)
   foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
   flat
 end
diff --git a/test/destructure.jl b/test/destructure.jl
index fe5699cb..df1ecffb 100644
--- a/test/destructure.jl
+++ b/test/destructure.jl
@@ -49,7 +49,7 @@ m9 = (a = m1, b = mat, c = [mat, m1])
   m8′ = destructure(m8)[2](1:5)
   @test m8′[1].x === m8′[1].y
   @test m8′[2].b.y === false
-  @test m8′[3][1] == [5.0] # broken
+  @test m8′[3][1] == [5.0]
 
   m9′ = destructure(m9)[2](10:10:70)
   @test m9′.b === m9′.c[1]
@@ -130,7 +130,7 @@ end
 
   v8, re8 = destructure(m8)
   @test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0]
-  @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]  # fails
+  @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]
 
   re9 = destructure(m9)[2]
   @test gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14]
@@ -203,10 +203,10 @@ end
     re(w)[2].y[1]
   end == ([0,0,1,0],)
 
-  # gradient(sk) do x
-  #   w, _ = destructure(x)
-  #   w[1]
-  # end
+  gradient(sk) do x
+    w, _ = destructure(x)
+    w[1]
+  end
 #=
 
 ERROR: ArgumentError: Tangent for the primal Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}} should be backed by a NamedTuple type, not by Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}}.

From 6a0ba9f259778a22728be529c0f6da6bb24164f4 Mon Sep 17 00:00:00 2001
From: Jonathan Doucette <jdoucette@physics.ubc.ca>
Date: Wed, 20 Apr 2022 23:47:19 -0700
Subject: [PATCH 3/7] `Offset` wrapper to avoid confusing `isleaf` with
 `AbstractArray{Int}`s of offsets (also simplifying `_aux_children`); fix
 broken test for issue #62

---
 src/destructure.jl  | 22 +++++++++++++---------
 test/destructure.jl |  6 +++---
 2 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/src/destructure.jl b/src/destructure.jl
index e114ff1e..b7f82c2f 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -53,16 +53,20 @@ end
 Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
 Base.length(re::Restructure) = re.length
 
+struct Offset
+  i::Int
+end
+
 # This flattens a model, and returns a web of offsets for later use:
 function _flatten(x)
-  isnumeric(x) && return vcat(_vec(x)), 0, length(x)  # trivial case
+  isnumeric(x) && return vcat(_vec(x)), Offset(0), length(x)  # trivial case
   arrays = AbstractVector[]
   len = Ref(0)
   off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
     push!(arrays, _vec(y))
     o = len[]
     len[] = o + length(y)
-    o
+    Offset(o)
   end
   reduce(vcat, arrays), off, len[]
 end
@@ -85,9 +89,9 @@ function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _trai
   end
 end
 
-_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1])
-_getat(y::AbstractArray, o::Int, flat::AbstractVector) =
-  ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y)))  # ProjectTo is just correcting eltypes
+_getat(y::Number, off::Offset, flat::AbstractVector) = ProjectTo(y)(flat[off.i + 1])
+_getat(y::AbstractArray, off::Offset, flat::AbstractVector) =
+  ProjectTo(y)(reshape(flat[off.i .+ (1:length(y))], axes(y)))  # ProjectTo is just correcting eltypes
 
 function _trainable_biwalk(f, x, aux)
   ch, re = functor(typeof(x), x)
@@ -96,7 +100,6 @@ function _trainable_biwalk(f, x, aux)
 end
 
 _aux_children(off) = functor(off)[1]
-_aux_children(off::AbstractArray) = off  # leaflike according to Functors, but we need to see each offset
 
 function _trainmap(f, ch, tr, aux)
   map(ch, tr, aux) do c, t, a  # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
@@ -113,6 +116,7 @@ function _Tangent_biwalk(f, x, aux)  # use with prune = NoT
   if p isa ProjectTo  # e.g. Array, NamedTuple
     p(y)
   else  # p === identity for unknown structs
+    y = backing(re(y)) # extract NamedTuple backing from re(y); required if x has children which aren't its own fields
     Tangent{typeof(x), typeof(y)}(y)
   end
 end
@@ -135,17 +139,17 @@ function _grad!(x, dx, off, flat::AbstractVector)
   end
   flat
 end
-function _grad!(x, dx, off::Integer, flat::AbstractVector{T}) where T
+function _grad!(x, dx, off::Offset, flat::AbstractVector{T}) where T
   dx_un = unthunk(dx)
   T2 = promote_type(T, eltype(dx_un))
   if T != T2  # then we must widen the type
     flat = copyto!(similar(flat, T2), flat)
   end
-  @views flat[off .+ (1:length(x))] .+= vec(dx_un)  # must visit all tied nodes
+  @views flat[off.i .+ (1:length(x))] .+= vec(dx_un)  # must visit all tied nodes
   flat
 end
 _grad!(x, dx::Zero, off, flat::AbstractVector) = flat
-_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = flat  # ambiguity
+_grad!(x, dx::Zero, off::Offset, flat::AbstractVector) = flat  # ambiguity
 
 # These are only needed for 2nd derivatives:
 function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
diff --git a/test/destructure.jl b/test/destructure.jl
index 76a71011..5afbc746 100644
--- a/test/destructure.jl
+++ b/test/destructure.jl
@@ -203,10 +203,10 @@ end
     re(w)[2].y[1]
   end == ([0,0,1,0],)
 
-  gradient(sk) do x
+  @test gradient(sk) do x
     w, _ = destructure(x)
-    w[1]
-  end
+    w[1] + w[4]
+  end == ((layers = ([1.0, 0.0], (x = nothing, y = [0.0, 1.0])),),)
 #=
 
 ERROR: ArgumentError: Tangent for the primal Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}} should be backed by a NamedTuple type, not by Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}}.

From 927e095ae1f4742a27cc0710132ea1000ad1e20e Mon Sep 17 00:00:00 2001
From: Jonathan Doucette <jdoucette@physics.ubc.ca>
Date: Fri, 22 Apr 2022 11:01:18 -0700
Subject: [PATCH 4/7] test no longer fails; offset structure is not leaflike
 using `Offset` wrapper

---
 test/destructure.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/destructure.jl b/test/destructure.jl
index 5afbc746..a303a485 100644
--- a/test/destructure.jl
+++ b/test/destructure.jl
@@ -79,7 +79,7 @@ end
   g8 = gradient(m -> sum(abs2, destructure(m)[1]), m8)[1]
   @test g8[1].x == [2,4,6]
   @test g8[2].b.x == [8]
-  @test g8[3] == [[10.0]]  # fails
+  @test g8[3] == [[10.0]]
 
   g9 = gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1]
   @test g9.c === nothing

From ba909d238cb020070b28aa1f4c6b775c4b3c7785 Mon Sep 17 00:00:00 2001
From: Jonathan Doucette <jdoucette@physics.ubc.ca>
Date: Fri, 22 Apr 2022 14:05:56 -0700
Subject: [PATCH 5/7] delete error message for gradient of `destructure`, which
 is working now

---
 test/destructure.jl | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/test/destructure.jl b/test/destructure.jl
index a303a485..9740b495 100644
--- a/test/destructure.jl
+++ b/test/destructure.jl
@@ -207,18 +207,6 @@ end
     w, _ = destructure(x)
     w[1] + w[4]
   end == ((layers = ([1.0, 0.0], (x = nothing, y = [0.0, 1.0])),),)
-#=
-
-ERROR: ArgumentError: Tangent for the primal Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}} should be backed by a NamedTuple type, not by Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}}.
-Stacktrace:
-  [1] _backing_error(P::Type, G::Type, E::Type)
-    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_types/tangent.jl:62
-  [2] ChainRulesCore.Tangent{Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}}, Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}}}(backing::Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}})
-    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_types/tangent.jl:36
-  [3] _Tangent_biwalk(f::Function, x::Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}}, aux::Tuple{Int64, NamedTuple{(:x, :y), Tuple{Tuple{}, Int64}}})
-    @ Optimisers ~/.julia/dev/Optimisers/src/destructure.jl:116
-
-=#
 end
 
 @testset "DiffEqFlux issue 699" begin

From abf87388b694463c5d299a2f9ce1dbf44a89bcab Mon Sep 17 00:00:00 2001
From: Jonathan Doucette <jdoucette@physics.ubc.ca>
Date: Fri, 22 Apr 2022 14:12:09 -0700
Subject: [PATCH 6/7] remove `_aux_children`

---
 src/destructure.jl | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/destructure.jl b/src/destructure.jl
index b7f82c2f..5fdb9e8e 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -95,12 +95,10 @@ _getat(y::AbstractArray, off::Offset, flat::AbstractVector) =
 
 function _trainable_biwalk(f, x, aux)
   ch, re = functor(typeof(x), x)
-  au = _aux_children(aux)
+  au, _ = functor(aux)
   _trainmap(f, ch, _trainable(x), au) |> re
 end
 
-_aux_children(off) = functor(off)[1]
-
 function _trainmap(f, ch, tr, aux)
   map(ch, tr, aux) do c, t, a  # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
     isnothing(t) ? c : f(t, a)
@@ -109,7 +107,7 @@ end
 
 function _Tangent_biwalk(f, x, aux)  # use with prune = NoT
   ch, re = functor(typeof(x), x)
-  au = _aux_children(aux)
+  au, _ = functor(aux)
   y = _trainmap(f, ch, _trainable(x), au)
   y isa Tuple{} && return NoT
   p = ProjectTo(x)
@@ -133,7 +131,7 @@ ChainRulesCore.@non_differentiable _zero(x)
 function _grad!(x, dx, off, flat::AbstractVector)
   x′, _ = functor(typeof(x), x)
   dx′, _ = functor(typeof(x), base(dx))
-  off′ = _aux_children(off)
+  off′, _ = functor(off)
   for (xᵢ, dxᵢ, oᵢ) in zip(x′, dx′, off′)
     flat = _grad!(xᵢ, dxᵢ, oᵢ, flat)
   end

From 415b5971ad90ff3ac2eecfb2edb40703fee5e36f Mon Sep 17 00:00:00 2001
From: Jonathan Doucette <jdoucette@physics.ubc.ca>
Date: Sat, 30 Apr 2022 21:31:24 -0700
Subject: [PATCH 7/7] modified `_trainmap` which returns `NoT` for `functor`-ed
 values which are not `trainable`; filter primal values from `backing(re(y))`

---
 src/destructure.jl  | 10 ++++++++--
 test/destructure.jl | 10 ++++++++--
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/src/destructure.jl b/src/destructure.jl
index 5fdb9e8e..49f73ed1 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -108,13 +108,19 @@ end
 function _Tangent_biwalk(f, x, aux)  # use with prune = NoT
   ch, re = functor(typeof(x), x)
   au, _ = functor(aux)
-  y = _trainmap(f, ch, _trainable(x), au)
+  y = map(ch, _trainable(x), au) do c, t, a  # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
+    isnothing(t) ? NoT : f(t, a)
+  end
   y isa Tuple{} && return NoT
   p = ProjectTo(x)
   if p isa ProjectTo  # e.g. Array, NamedTuple
     p(y)
   else  # p === identity for unknown structs
-    y = backing(re(y)) # extract NamedTuple backing from re(y); required if x has children which aren't its own fields
+    y = map(backing(x), backing(re(y))) do c, t
+      # backing(re(y)) extracts NamedTuple backing from re(y); required if x has children which aren't its own fields
+      # however, re(y) will repopulate primal field values from x which weren't functor-ed; these gradients should be NoT
+      c === t ? NoT : t
+    end
     Tangent{typeof(x), typeof(y)}(y)
   end
 end
diff --git a/test/destructure.jl b/test/destructure.jl
index 9740b495..fce4ceeb 100644
--- a/test/destructure.jl
+++ b/test/destructure.jl
@@ -205,8 +205,14 @@ end
 
   @test gradient(sk) do x
     w, _ = destructure(x)
-    w[1] + w[4]
-  end == ((layers = ([1.0, 0.0], (x = nothing, y = [0.0, 1.0])),),)
+    w[1]^2 + w[4]^2
+  end == ((layers = ([2.0, 0.0], (x = nothing, y = [0.0, 10.0])),),)
+
+  ac = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])  # a,c are functor-ed, and only a is trainable
+  @test gradient(ac) do x
+    w2, _ = destructure(x)
+    w2[2]^2
+  end == ((a = [0.0, 4.0], b = nothing, c = nothing),)
 end
 
 @testset "DiffEqFlux issue 699" begin