Skip to content

avoid customMessage in case class logs #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 28, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Caller/build.sbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value
libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.1" % Test
14 changes: 14 additions & 0 deletions Caller/src/main/scala/com/thoughtworks/deeplearning/Caller.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.thoughtworks.deeplearning

import scala.language.experimental.macros
import scala.reflect.macros.whitebox

final case class Caller[A](value: A)
object Caller {
implicit def generate: Caller[_] = macro impl

def impl(c: whitebox.Context): c.Tree = {
import c.universe._
q"new _root_.com.thoughtworks.deeplearning.Caller(this)"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.thoughtworks.deeplearning

import org.scalatest.{FreeSpec, Matchers}

object Foo {
def call(implicit caller: Caller[_]): String = {
caller.value.getClass.getName
}
}

class CallerSpec extends FreeSpec with Matchers {
"className" in {
val className: String = Foo.call
className should be(this.getClass.getName)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ object LogRecords {
private[LogRecords] abstract class LazyLogRecord(level: Level, customMessage: String = null)(
implicit fullName: sourcecode.FullName,
methodName: sourcecode.Name,
fileName: sourcecode.File)
className: Caller[_])
extends LogRecord(level, customMessage) {

setLoggerName(fullName.value)
setSourceClassName(fileName.value)
setSourceClassName(className.value.getClass.getName)
setSourceMethodName(methodName.value)

protected def makeDefaultMessage: Fastring
Expand All @@ -28,18 +28,27 @@ object LogRecords {

}

final case class UncaughtExceptionDuringBackward(thrown: Throwable, customMessage: String = null)
extends LazyLogRecord(Level.SEVERE, customMessage) {
final case class UncaughtExceptionDuringBackward(
thrown: Throwable)(implicit fullName: sourcecode.FullName, methodName: sourcecode.Name, className: Caller[_])
extends LazyLogRecord(Level.SEVERE) {
setThrown(thrown)
override protected def makeDefaultMessage = fast"An exception raised during backward"
}

final case class DeltaAccumulatorTracker(customMessage: String) extends LazyLogRecord(Level.FINER, customMessage) {
override protected def makeDefaultMessage: Fastring = fast"DeltaAccumulatorTracker default message"
final case class DeltaAccumulatorIsUpdating[Delta](
deltaAccumulator: Delta,
delta: Delta)(implicit fullName: sourcecode.FullName, methodName: sourcecode.Name, className: Caller[_])
extends LazyLogRecord(Level.FINER) {
override protected def makeDefaultMessage: Fastring =
fast"Before deltaAccumulator update, deltaAccumulator is : $deltaAccumulator, delta is : $delta"
}

final case class FloatWeightTracker(customMessage: String) extends LazyLogRecord(Level.FINER, customMessage) {
override protected def makeDefaultMessage: Fastring = fast"FloatWeightTracker default message"
final case class WeightIsUpdating[Delta](data: Delta, delta: Delta)(implicit fullName: sourcecode.FullName,
methodName: sourcecode.Name,
className: Caller[_])
extends LazyLogRecord(Level.FINER) {
override protected def makeDefaultMessage: Fastring =
fast"Before weight update, weight is : $data, delta is : $delta"
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.thoughtworks.deeplearning

import java.util.logging.{Level, Logger}

import com.thoughtworks.deeplearning.LogRecords.{DeltaAccumulatorTracker, UncaughtExceptionDuringBackward}
import com.thoughtworks.deeplearning.LogRecords.{DeltaAccumulatorIsUpdating, UncaughtExceptionDuringBackward}
import com.thoughtworks.deeplearning.Tape.Aux
import com.thoughtworks.raii._

Expand Down Expand Up @@ -31,20 +31,29 @@ object TapeTaskFactory {
operand1: Do[_ <: Tape.Aux[Data1, Delta1]])(
computeForward: (Data0, Data1) => Task[(OutputData, OutputDelta => (Do[_ <: Delta0], Do[_ <: Delta1]))])(
implicit binaryTapeTaskFactory: BinaryTapeTaskFactory[OutputData, OutputDelta],
logger: Logger): Do[Tape.Aux[OutputData, OutputDelta]] = {
logger: Logger,
fullName: sourcecode.FullName,
methodName: sourcecode.Name,
className: Caller[_]): Do[Tape.Aux[OutputData, OutputDelta]] = {
binaryTapeTaskFactory(operand0, operand1)(computeForward)
}

@inline
def unary[Data, Delta, OutputData, OutputDelta](operand: Do[_ <: Tape.Aux[Data, Delta]])(
computeForward: (Data) => Task[(OutputData, OutputDelta => Do[_ <: Delta])])(
implicit unaryTapeTaskFactory: UnaryTapeTaskFactory[OutputData, OutputDelta],
logger: Logger): Do[Tape.Aux[OutputData, OutputDelta]] = {
logger: Logger,
fullName: sourcecode.FullName,
methodName: sourcecode.Name,
className: Caller[_]): Do[Tape.Aux[OutputData, OutputDelta]] = {
unaryTapeTaskFactory(operand)(computeForward)
}

private abstract class Output[OutputData, OutputDelta: Monoid](override val data: OutputData)(
implicit logger: Logger)
implicit logger: Logger,
fullName: sourcecode.FullName,
methodName: sourcecode.Name,
className: Caller[_])
extends Tape
with ResourceT[Future, Try[Tape.Aux[OutputData, OutputDelta]]] {

Expand All @@ -67,7 +76,7 @@ object TapeTaskFactory {
tryDelta.map { delta =>
synchronized {
if (logger.isLoggable(Level.FINER)) {
logger.log(DeltaAccumulatorTracker(s"deltaAccumulator:$deltaAccumulator, delta: $delta"))
logger.log(DeltaAccumulatorIsUpdating(deltaAccumulator, delta))
}
deltaAccumulator |+|= delta
}
Expand All @@ -76,7 +85,7 @@ object TapeTaskFactory {

ResourceFactoryT.run(tryTRAIIFuture).flatMap {
case Failure(e) =>
logger.log(UncaughtExceptionDuringBackward(e, "An exception raised during backward"))
logger.log(UncaughtExceptionDuringBackward(e))
Future.now(())
case Success(()) =>
Future.now(())
Expand Down Expand Up @@ -126,7 +135,10 @@ object TapeTaskFactory {
}
}

final class MonoidBinaryTapeTaskFactory[OutputData, OutputDelta: Monoid](implicit logger: Logger)
final class MonoidBinaryTapeTaskFactory[OutputData, OutputDelta: Monoid](implicit logger: Logger,
fullName: sourcecode.FullName,
methodName: sourcecode.Name,
className: Caller[_])
extends BinaryTapeTaskFactory[OutputData, OutputDelta] {
@inline
override def apply[Data0, Delta0, Data1, Delta1](operand0: Do[_ <: Tape.Aux[Data0, Delta0]],
Expand Down Expand Up @@ -182,13 +194,19 @@ object TapeTaskFactory {

@inline
implicit def monoidBinaryTapeTaskFactory[OutputData, OutputDelta: Monoid](
implicit logger: Logger): BinaryTapeTaskFactory[OutputData, OutputDelta] = {
implicit logger: Logger,
fullName: sourcecode.FullName,
methodName: sourcecode.Name,
className: Caller[_]): BinaryTapeTaskFactory[OutputData, OutputDelta] = {
new MonoidBinaryTapeTaskFactory[OutputData, OutputDelta]
}
}

object UnaryTapeTaskFactory {
final class MonoidUnaryTapeTaskFactory[OutputData, OutputDelta: Monoid](implicit logger: Logger)
final class MonoidUnaryTapeTaskFactory[OutputData, OutputDelta: Monoid](implicit logger: Logger,
fullName: sourcecode.FullName,
methodName: sourcecode.Name,
className: Caller[_])
extends UnaryTapeTaskFactory[OutputData, OutputDelta] {
@inline
override def apply[Data, Delta](operand: Do[_ <: Tape.Aux[Data, Delta]])(
Expand Down Expand Up @@ -224,7 +242,10 @@ object TapeTaskFactory {

@inline
implicit def monoidUnaryTapeTaskFactory[OutputData, OutputDelta: Monoid](
implicit logger: Logger): UnaryTapeTaskFactory[OutputData, OutputDelta] = {
implicit logger: Logger,
fullName: sourcecode.FullName,
methodName: sourcecode.Name,
className: Caller[_]): UnaryTapeTaskFactory[OutputData, OutputDelta] = {
new MonoidUnaryTapeTaskFactory[OutputData, OutputDelta]
}
}
Expand Down
10 changes: 6 additions & 4 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,21 @@ lazy val Memory = project
lazy val Tape =
project.dependsOn(ProjectRef(file("RAII.scala"), "packageJVM"), LogRecords, ProjectRef(file("RAII.scala"), "Do"))

lazy val TapeTaskFactory = project.dependsOn(Tape, ProjectRef(file("RAII.scala"), "Do"))
lazy val TapeTaskFactory = project.dependsOn(Tape, ProjectRef(file("RAII.scala"), "Do"), Caller)

lazy val Closeables = project

lazy val Caller = project

includeFilter in unmanagedSources := (includeFilter in unmanagedSources).value && new SimpleFileFilter(_.isFile)

lazy val OpenCL = project.dependsOn(Closeables, Memory, ProjectRef(file("RAII.scala"), "ResourceFactoryTJVM"))

//lazy val LayerFactory = project.dependsOn(DifferentiableKernel)

lazy val `differentiable-float` = project.dependsOn(TapeTask, TapeTaskFactory, PolyFunctions)
lazy val `differentiable-float` = project.dependsOn(TapeTask, TapeTaskFactory, PolyFunctions, Caller)

lazy val `differentiable-double` = project.dependsOn(TapeTask, TapeTaskFactory, PolyFunctions)
lazy val `differentiable-double` = project.dependsOn(TapeTask, TapeTaskFactory, PolyFunctions, Caller)

val FloatRegex = """(?i:float)""".r

Expand Down Expand Up @@ -62,7 +64,7 @@ lazy val PolyFunctions = project.dependsOn(ToTapeTask)

lazy val TapeTask = project.dependsOn(Tape, ProjectRef(file("RAII.scala"), "Do"))

lazy val LogRecords = project
lazy val LogRecords = project.dependsOn(Caller)

lazy val AsynchronousSemaphore = project

Expand Down
Loading