Example:
```julia
using Optimisers, Functors, Zygote

struct TwoThirds a; b; c; end  # from the tests
Functors.@functor TwoThirds (a, c)  # children are a and c; b is not a child
Optimisers.trainable(x::TwoThirds) = (a = x.a,)  # of the children, only a is trainable

mtt = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])
v, re = destructure(mtt)
re(100v)  # TwoThirds([100.0, 200.0], [3.0], [4.0, 5.0])

gradient(mtt) do x
    w, _ = destructure(x)
    1000 * prod(w)
end  # ((a = [2000.0, 1000.0], b = nothing, c = [4.0, 5.0]),)
```
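Here `b` is correct (it is excluded from the children), but `c` (a non-trainable child) gets its value, `[4.0, 5.0]`, instead of a gradient. Since `c` never enters the flat vector `w`, the loss does not depend on it, and its gradient should be zero (`nothing`).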
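For comparison, here is a minimal hand-written sketch in plain Julia (no Optimisers, Functors, or Zygote) of the gradient shape one would expect. It flattens only the trainable field `a`, and its pullback routes the whole flat gradient to `a` while returning `nothing` for `b` and `c`. The names `flatten_trainable`, `flatten_pullback`, and `loss_grad` are hypothetical and not part of Optimisers.jl. The gradient of `1000 * prod(w)` is derived by hand here, not computed by Zygote.

```julia
# Hypothetical sketch, not Optimisers.jl's implementation: it only
# illustrates the expected gradient for the TwoThirds example.
struct TwoThirds
    a
    b
    c
end

# Flatten only the trainable field `a` (mirrors trainable(x) = (a = x.a,)).
flatten_trainable(x::TwoThirds) = copy(x.a)

# Pullback: a gradient Δ for the flat vector belongs entirely to `a`.
# `b` and `c` never enter the flat vector, so their gradients are `nothing`.
flatten_pullback(::TwoThirds, Δ::AbstractVector) = (a = Δ, b = nothing, c = nothing)

# Hand-derived gradient of 1000 * prod(w); valid only when no entry of w is zero.
loss_grad(w) = 1000 .* prod(w) ./ w

mtt = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])
w = flatten_trainable(mtt)
g = flatten_pullback(mtt, loss_grad(w))
println(g)  # (a = [2000.0, 1000.0], b = nothing, c = nothing)
```

The `a` entry matches what `gradient` returns above. The only difference is `c`, which here is `nothing` rather than the field's value.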