diff --git a/src/back.jl b/src/back.jl
index e0cc739..23c8383 100644
--- a/src/back.jl
+++ b/src/back.jl
@@ -1,11 +1,14 @@
 # The AD generates fairly large backtraces that are unhelpful if you interrupt
 # while training; this just cleans that up.
 macro interrupts(ex)
-  :(try $(esc(ex))
+  :(
+    try
+      $(esc(ex))
     catch e
       e isa InterruptException || rethrow()
       throw(e)
-    end)
+    end
+  )
 end
 
 # In-place gradients
@@ -31,37 +34,41 @@ function scan(x)
   return
 end
 
-function back_(c::Call, Δ, once)
+function back_(c::Call, Δ)
   Δs = c.func(Δ)
   (Δs isa Tuple && length(Δs) >= length(c.args)) ||
     error("Gradient is not a tuple of length $(length(c.args))")
-  foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
+  foreach((x, d) -> back(x, d), c.args, data.(Δs))
 end
 
-back_(::Call{Nothing}, Δ, once) = nothing
-back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
+back_(::Call{Nothing}, Δ) = nothing
+back_(::Call{Missing}, Δ) = error("`back!` was already used")
 
 accum!(x, Δ) = x .+ Δ
 accum!(x::AbstractArray, Δ) = (x .+= Δ)
 
-function back(x::Tracked, Δ, once)
-  x.isleaf && (x.grad = accum!(x.grad, Δ); return)
-  ref = x.ref -= 1
-  grad = if isdefined(x, :grad)
-    x.grad = accum!(x.grad, Δ)
-  elseif ref > 0
-    x.grad = Δ
-  else
-    Δ
-  end
-  if ref == 0
-    back_(x.f, grad, once)
-    once && !x.isleaf && (x.f = Call(missing, ()))
-  end
-  return
-end
+function back(x::Tracked, Δ)
+  if x.isleaf
+    x.grad = accum!(x.grad, Δ)
+    return
+  end
+
+  # `scan` pre-counts how many times each node's value is used; every
+  # incoming gradient decrements that count, and the pullback only fires
+  # once all uses have contributed, so shared nodes are visited just once.
+  ref = x.ref -= 1
+  grad = if isdefined(x, :grad)
+    x.grad = accum!(x.grad, Δ)
+  elseif ref > 0
+    x.grad = Δ
+  else
+    Δ
+  end
+  ref == 0 && back_(x.f, grad)
+  return
+end
 
-back(::Nothing, Δ, once) = return
+back(::Nothing, Δ) = return
 
 # Interface methods
 
@@ -71,10 +78,12 @@ back(::Nothing, Δ, once) = return
 # Refcounts are also probably not safe in some situations (e.g. back called
 # from within a backpropagator)
 
-function back!(x, Δ; once = true)
+# `once` is kept for API compatibility; the once-only machinery is gone, so
+# repeated calls simply rescan the graph and accumulate fresh gradients.
+function back!(x, Δ; once=true)
   istracked(x) || return
   scan(x)
-  back(tracker(x), Δ, once)
+  back(tracker(x), Δ)
   return
 end
 
@@ -161,7 +170,7 @@ function gradient_nested(f, args...)
   return back(1)
 end
 
-gradient(f, xs...; nest = false) =
+gradient(f, xs...; nest=false) =
   nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
 
 # Jacobians and Hessians
@@ -219,14 +228,14 @@ julia> withgradient(model, rand(Float32, 2)) do m, x
 ```
 """
 function withgradient(f, xs...)
-  pxs = fmap(param, xs; exclude = isnumeric, walk = _trainable_walk)
+  pxs = fmap(param, xs; exclude=isnumeric, walk=_trainable_walk)
   l = f(pxs...)
-  l1 = l isa Union{Tuple, NamedTuple} ? first(l) : l
-  val = l isa Union{Tuple, NamedTuple} ? fmap(data, l) : data(l)
+  l1 = l isa Union{Tuple,NamedTuple} ? first(l) : l
+  val = l isa Union{Tuple,NamedTuple} ? fmap(data, l) : data(l)
   losscheck(l1)
-  l1 isa TrackedReal || return (; val, grad = map(_ -> nothing, xs))
+  l1 isa TrackedReal || return (; val, grad=map(_ -> nothing, xs))
   @interrupts back!(l1)
-  (; val, grad = rec_grad(pxs))
+  (; val, grad=rec_grad(pxs))
 end
 
 function _trainable_walk(f, x)
@@ -247,9 +256,9 @@ rec_grad(x::Number) = nothing
 rec_grad(x::Union{Tuple,NamedTuple,AbstractArray}) = map(rec_grad, x)
 rec_grad(::Tuple{}) = nothing
-rec_grad(::NamedTuple{(), Tuple{}}) = nothing
+rec_grad(::NamedTuple{(),Tuple{}}) = nothing
 function rec_grad(x::T) where {T}
   F = fieldnames(T)
   isempty(F) && return nothing
   map(f -> rec_grad(getfield(x, f)), NamedTuple{F}(F))
 end
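For reference, a minimal usage sketch of the reworked entry points. This is not part of the patch; it assumes the file is Tracker.jl's `src/back.jl`, so `param`, `data`, `back!`, `gradient`, and `Tracker.grad` are available with their usual Tracker semantics:

```julia
using Tracker

# `gradient` drives the simplified reverse pass: d/da sum(a .* b) == b.
ga, gb = Tracker.gradient((a, b) -> sum(a .* b), [1.0, 2.0], [3.0, 4.0])
@assert Tracker.data(ga) == [3.0, 4.0]
@assert Tracker.data(gb) == [1.0, 2.0]

# Driving `back!` by hand: `scan` recounts uses and zeroes non-leaf grads,
# while leaf (`param`) gradients accumulate in place.
W = Tracker.param([1.0 2.0; 3.0 4.0])
loss = sum(W * [1.0, 1.0])       # Σᵢⱼ Wᵢⱼ vⱼ with v = [1, 1]
Tracker.back!(loss, 1)
@assert Tracker.grad(W) == [1.0 1.0; 1.0 1.0]
```

Note that this sketch relies on `scan` staying in the pipeline: it is what seeds `x.ref` before the walk, so the `ref == 0` gate in `back` fires exactly once per shared node; without the pre-pass no gradient would propagate at all.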