From 3eb5e07ded38af28e74946d778a5531fbc1ecd07 Mon Sep 17 00:00:00 2001 From: Mike J Innes <mike.j.innes@gmail.com> Date: Wed, 31 Oct 2018 20:07:26 +0000 Subject: [PATCH] basic TrackedComplex --- src/tracker/Tracker.jl | 1 + src/tracker/back.jl | 14 ++++++------ src/tracker/lib/complex.jl | 44 ++++++++++++++++++++++++++++++++++++++ src/tracker/lib/real.jl | 7 +++--- test/tracker.jl | 2 ++ 5 files changed, 57 insertions(+), 11 deletions(-) create mode 100644 src/tracker/lib/complex.jl diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index e99bc1cd..ca007ee7 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -70,6 +70,7 @@ include("idset.jl") include("back.jl") include("numeric.jl") include("lib/real.jl") +include("lib/complex.jl") include("lib/array.jl") """ diff --git a/src/tracker/back.jl b/src/tracker/back.jl index af130dd3..d74810d5 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -14,10 +14,7 @@ function scan(x::Tracked) return end -function scan(x) - istracked(x) && scan(tracker(x)) - return -end +scan(::Nothing) = return function back_(c::Call, Δ, once) Δs = c.func(Δ) @@ -61,7 +58,7 @@ back(::Nothing, Δ, once) = return function back!(x, Δ; once = true) istracked(x) || return - scan(x) + scan(tracker(x)) back(tracker(x), Δ, once) return end @@ -143,16 +140,19 @@ function forward(f, ps::Params) y, function (Δ) g = Grads(ps) if istracked(y) - scan(y) + scan(tracker(y)) back(g, tracker(y), Δ) end return g end end +# Essentially a hack for complex numbers +unwrap(x) = x + function forward(f, args...) args = param.(args) - y, back = forward(() -> f(args...), Params(args)) + y, back = forward(() -> f(unwrap.(args)...), Params(args)) y, Δ -> getindex.(Ref(back(Δ)), args) end diff --git a/src/tracker/lib/complex.jl b/src/tracker/lib/complex.jl new file mode 100644 index 00000000..f9e82e50 --- /dev/null +++ b/src/tracker/lib/complex.jl @@ -0,0 +1,44 @@ +# Internal interface + +struct _TrackedComplex{T<:Real} + data::Complex{T} + tracker::Tracked{Complex{T}} +end + +_TrackedComplex(x::Complex) = _TrackedComplex(x, Tracked{typeof(x)}(Call(), zero(x))) + +data(x::_TrackedComplex) = x.data +tracker(x::_TrackedComplex) = x.tracker + +Base.real(x::_TrackedComplex) = track(real, x) +Base.imag(x::_TrackedComplex) = track(imag, x) + +@grad real(x::_TrackedComplex) = real(data(x)), r̄ -> (r̄ + zero(r̄)*im,) +@grad imag(x::_TrackedComplex) = imag(data(x)), ī -> (zero(ī) + ī*im,) + +unwrap(x::_TrackedComplex) = real(x) + imag(x)*im + +track(f::Call, x::Complex) = + unwrap(_TrackedComplex(x, Tracked{typeof(x)}(f, zero(x)))) + +param(x::Complex) = _TrackedComplex(float(x)) + +# External interface + +TrackedComplex{T<:Real} = Complex{TrackedReal{T}} + +data(x::TrackedComplex) = data(real(x)) + data(imag(x))*im + +tracker(x::TrackedComplex) = + Tracked{typeof(data(x))}(Call(c -> (real(c), imag(c)), + (tracker(real(x)),tracker(imag(x)))), + zero(data(x))) + +function Base.show(io::IO, x::TrackedComplex) + show(io, data(x)) + print(io, " (tracked)") +end + +Base.log(x::TrackedComplex) = track(log, x) + +@grad log(x::TrackedComplex) = log(data(x)), ȳ -> (ȳ/x,) diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index 3546beba..cba881ca 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -11,9 +11,8 @@ tracker(x::TrackedReal) = x.tracker track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x))) function back!(x::TrackedReal; once = true) - isinf(x) && error("Loss is Inf") - isnan(x) && error("Loss is NaN") - return back!(x, 1, once = once) + losscheck(data(x)) + return back!(x, 1, once = once) end function Base.show(io::IO, x::TrackedReal) @@ -32,7 +31,7 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = error("Not implemented: convert tracked $S to tracked $T") -for op in [:(==), :≈, :<] +for op in [:(==), :≈, :<, :<=] @eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y) @eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y)) @eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y)) diff --git a/test/tracker.jl b/test/tracker.jl index b12e3bb6..b6638712 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -291,4 +291,6 @@ end @test count == 3 end +@test Tracker.gradient(x -> abs2(log(x)), 1+2im)[1] isa Complex + end #testset