From 3eb5e07ded38af28e74946d778a5531fbc1ecd07 Mon Sep 17 00:00:00 2001
From: Mike J Innes <mike.j.innes@gmail.com>
Date: Wed, 31 Oct 2018 20:07:26 +0000
Subject: [PATCH] basic TrackedComplex

---
 src/tracker/Tracker.jl     |  1 +
 src/tracker/back.jl        | 14 ++++++------
 src/tracker/lib/complex.jl | 44 ++++++++++++++++++++++++++++++++++++++
 src/tracker/lib/real.jl    |  7 +++---
 test/tracker.jl            |  2 ++
 5 files changed, 57 insertions(+), 11 deletions(-)
 create mode 100644 src/tracker/lib/complex.jl

diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl
index e99bc1cd..ca007ee7 100644
--- a/src/tracker/Tracker.jl
+++ b/src/tracker/Tracker.jl
@@ -70,6 +70,7 @@ include("idset.jl")
 include("back.jl")
 include("numeric.jl")
 include("lib/real.jl")
+include("lib/complex.jl")
 include("lib/array.jl")
 
 """
diff --git a/src/tracker/back.jl b/src/tracker/back.jl
index af130dd3..d74810d5 100644
--- a/src/tracker/back.jl
+++ b/src/tracker/back.jl
@@ -14,10 +14,7 @@ function scan(x::Tracked)
   return
 end
 
-function scan(x)
-  istracked(x) && scan(tracker(x))
-  return
-end
+scan(::Nothing) = return
 
 function back_(c::Call, Δ, once)
   Δs = c.func(Δ)
@@ -61,7 +58,7 @@ back(::Nothing, Δ, once) = return
 
 function back!(x, Δ; once = true)
   istracked(x) || return
-  scan(x)
+  scan(tracker(x))
   back(tracker(x), Δ, once)
   return
 end
@@ -143,16 +140,19 @@ function forward(f, ps::Params)
   y, function (Δ)
     g = Grads(ps)
     if istracked(y)
-      scan(y)
+      scan(tracker(y))
       back(g, tracker(y), Δ)
     end
     return g
   end
 end
 
+# Essentially a hack for complex numbers
+unwrap(x) = x
+
 function forward(f, args...)
   args = param.(args)
-  y, back = forward(() -> f(args...), Params(args))
+  y, back = forward(() -> f(unwrap.(args)...), Params(args))
   y, Δ -> getindex.(Ref(back(Δ)), args)
 end
 
diff --git a/src/tracker/lib/complex.jl b/src/tracker/lib/complex.jl
new file mode 100644
index 00000000..f9e82e50
--- /dev/null
+++ b/src/tracker/lib/complex.jl
@@ -0,0 +1,44 @@
+# Internal interface
+
+struct _TrackedComplex{T<:Real}
+  data::Complex{T}
+  tracker::Tracked{Complex{T}}
+end
+
+_TrackedComplex(x::Complex) = _TrackedComplex(x, Tracked{typeof(x)}(Call(), zero(x)))
+
+data(x::_TrackedComplex) = x.data
+tracker(x::_TrackedComplex) = x.tracker
+
+Base.real(x::_TrackedComplex) = track(real, x)
+Base.imag(x::_TrackedComplex) = track(imag, x)
+
+@grad real(x::_TrackedComplex) = real(data(x)), r̄ -> (r̄ + zero(r̄)*im,)
+@grad imag(x::_TrackedComplex) = imag(data(x)), ī -> (zero(ī) + ī*im,)
+
+unwrap(x::_TrackedComplex) = real(x) + imag(x)*im
+
+track(f::Call, x::Complex) =
+  unwrap(_TrackedComplex(x, Tracked{typeof(x)}(f, zero(x))))
+
+param(x::Complex) = _TrackedComplex(float(x))
+
+# External interface
+
+TrackedComplex{T<:Real} = Complex{TrackedReal{T}}
+
+data(x::TrackedComplex) = data(real(x)) + data(imag(x))*im
+
+tracker(x::TrackedComplex) =
+  Tracked{typeof(data(x))}(Call(c -> (real(c), imag(c)),
+                                (tracker(real(x)),tracker(imag(x)))),
+                           zero(data(x)))
+
+function Base.show(io::IO, x::TrackedComplex)
+  show(io, data(x))
+  print(io, " (tracked)")
+end
+
+Base.log(x::TrackedComplex) = track(log, x)
+
+@grad log(x::TrackedComplex) = log(data(x)), ȳ -> (ȳ/x,)
diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl
index 3546beba..cba881ca 100644
--- a/src/tracker/lib/real.jl
+++ b/src/tracker/lib/real.jl
@@ -11,9 +11,8 @@ tracker(x::TrackedReal) = x.tracker
 track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
 
 function back!(x::TrackedReal; once = true)
-    isinf(x) && error("Loss is Inf")
-    isnan(x) && error("Loss is NaN")
-    return back!(x, 1, once = once)
+  losscheck(data(x))
+  return back!(x, 1, once = once)
 end
 
 function Base.show(io::IO, x::TrackedReal)
@@ -32,7 +31,7 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
 Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
   error("Not implemented: convert tracked $S to tracked $T")
 
-for op in [:(==), :≈, :<]
+for op in [:(==), :≈, :<, :<=]
   @eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
   @eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
   @eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
diff --git a/test/tracker.jl b/test/tracker.jl
index b12e3bb6..b6638712 100644
--- a/test/tracker.jl
+++ b/test/tracker.jl
@@ -291,4 +291,6 @@ end
   @test count == 3
 end
 
+@test Tracker.gradient(x -> abs2(log(x)), 1+2im)[1] isa Complex
+
 end #testset