From 014ff44faf4588ac41746f8381323cfe4777db19 Mon Sep 17 00:00:00 2001 From: Jim Clarke <JimClarke5@me.com> Date: Sun, 30 Aug 2020 18:42:48 -0400 Subject: [PATCH 01/14] Add ability to change learning rate between steps by adding a Placeholder into each Optimizer. Also, added to each Optimizer a corresponding Tensor that holds the value of the learning rate, and added a feed dictionary that maps the placeholder to the Tensor, so that it can be fed into the runner when running or evaluating. When setLearning rate is called the learning rate tensor and the feed dictionary are updated. --- .../tensorflow/keras/optimizers/AdaDelta.java | 25 +- .../tensorflow/keras/optimizers/AdaGrad.java | 25 +- .../keras/optimizers/AdaGradDA.java | 38 +- .../org/tensorflow/keras/optimizers/Adam.java | 21 +- .../tensorflow/keras/optimizers/Adamax.java | 81 +- .../org/tensorflow/keras/optimizers/Ftrl.java | 56 +- .../tensorflow/keras/optimizers/Nadam.java | 60 +- .../keras/optimizers/OptimizerInterface.java | 22 +- .../keras/optimizers/Optimizers.java | 6 +- .../tensorflow/keras/optimizers/RMSProp.java | 19 +- .../org/tensorflow/keras/optimizers/SGD.java | 22 +- .../keras/optimizers/AdaDeltaTest.java | 345 +++--- .../keras/optimizers/AdaGradDATest.java | 26 +- .../keras/optimizers/AdaGradTest.java | 28 +- .../tensorflow/keras/optimizers/AdamTest.java | 35 +- .../keras/optimizers/AdamaxTest.java | 33 +- .../tensorflow/keras/optimizers/FtrlTest.java | 106 +- .../keras/optimizers/NadamTest.java | 34 +- .../keras/optimizers/RMSPropTest.java | 535 ++++----- .../tensorflow/keras/optimizers/SGDTest.java | 34 +- .../keras/utils/EagerTestSession.java | 92 +- .../keras/utils/GraphTestSession.java | 275 +++-- .../tensorflow/keras/utils/TestSession.java | 1023 ++++++++++++----- 23 files changed, 1765 insertions(+), 1176 deletions(-) diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaDelta.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaDelta.java index b0a9dcf7d68..119a2311f3a 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaDelta.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaDelta.java @@ -31,7 +31,6 @@ * * <p>Two accumulation steps are required: 1) the accumulation of gradients squared, 2) the * accumulation of updates squared. - * */ public class AdaDelta extends org.tensorflow.framework.optimizers.AdaDelta implements OptimizerInterface { @@ -45,11 +44,9 @@ public class AdaDelta extends org.tensorflow.framework.optimizers.AdaDelta public static final float EPSILON_DEFAULT = 1e-7F; private Map<String, Object> config = new HashMap<>(); - private float learningRate; private List<Op> initializers = new ArrayList<>(); - /** * Create an Adadelta optimizer with default name="Adadelta", learning_rate=0.001F, rho=0.95F, and * epsilon=1e-7F @@ -127,7 +124,7 @@ protected Optional<Op> prepare(String name) { case 1: return Optional.of(initializers.get(0)); default: - return Optional.of( tf.withSubScope(name).withControlDependencies(initializers).noOp()); + return Optional.of(tf.withSubScope(name).withControlDependencies(initializers).noOp()); } } @@ -146,9 +143,8 @@ public static AdaDelta fromConfig(Ops tf, Map<String, Object> config) { * Create an Adadelta optimizer * * @param tf the tensorflow Ops - * @param config a config object to initialize, the config - * object has keys for "name", "learning_rate", "rho" and "epsilon". If a key is missing the - * default value is used. + * @param config a config object to initialize, the config object has keys for "name", + * "learning_rate", "rho" and "epsilon". If a key is missing the default value is used. */ public static AdaDelta create(Ops tf, Map<String, Object> config) { String name = (String) config.get(NAME_KEY); @@ -171,7 +167,6 @@ public static AdaDelta create(Ops tf, Map<String, Object> config) { * @param epsilon A constant epsilon used to better conditioning the grad update. */ private void initConfig(float learningRate, float rho, float epsilon) { - this.learningRate = learningRate; config.put(NAME_KEY, this.getOptimizerName()); config.put(LEARNING_RATE_KEY, learningRate); config.put(RHO_RATE_KEY, rho); @@ -183,18 +178,4 @@ private void initConfig(float learningRate, float rho, float epsilon) { public Map<String, Object> getConfig() { return config; } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public void setLearningRate(float learningRate) { - this.learningRate = learningRate; - } - - } diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGrad.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGrad.java index 039cf4a0d82..98476fbdce5 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGrad.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGrad.java @@ -17,7 +17,13 @@ import java.util.HashMap; import java.util.Map; import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; + +import org.tensorflow.Operand; +import org.tensorflow.Tensor; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.family.TType; /** * AdaGrad Optimizer that implements the AdaGrad algorithm. Adagrad is an optimizer with @@ -34,7 +40,6 @@ public class AdaGrad extends org.tensorflow.framework.optimizers.AdaGrad public static final float INITIAL_ACCUM__DEFAULT = 0.1f; private Map<String, Object> config = new HashMap<>(); - private float learningRate; /** * Create an AdaGrad Optimizer with name="Adagrad", learningRate=0.001F, and initial @@ -99,8 +104,9 @@ public AdaGrad(Ops tf, float learningRate, float initialAccumulatorValue) { */ public AdaGrad(Ops tf, String name, float learningRate, float initialAccumulatorValue) { super(assertGraph(tf), name, learningRate, initialAccumulatorValue); - if(initialAccumulatorValue < 0.0F) - throw new IllegalArgumentException( "initial_accumulator_value must be non-negative: " + initialAccumulatorValue); + if (initialAccumulatorValue < 0.0F) + throw new IllegalArgumentException( + "initial_accumulator_value must be non-negative: " + initialAccumulatorValue); initConfig(learningRate, initialAccumulatorValue); } @@ -141,7 +147,6 @@ public static AdaGrad create(Ops tf, Map<String, Object> config) { * @param initialAccumulatorValue the initial Accumulator value */ private void initConfig(float learningRate, float initialAccumulatorValue) { - this.learningRate = learningRate; config.put(NAME_KEY, this.getOptimizerName()); config.put(LEARNING_RATE_KEY, learningRate); config.put(INITIAL_ACCUM_KEY, initialAccumulatorValue); @@ -152,16 +157,4 @@ private void initConfig(float learningRate, float initialAccumulatorValue) { public Map<String, Object> getConfig() { return config; } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public void setLearningRate(float learningRate) { - this.learningRate = learningRate; - } } diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGradDA.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGradDA.java index 2f15024bf56..f7d11697623 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGradDA.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGradDA.java @@ -14,10 +14,12 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; +import org.tensorflow.op.Ops; + import java.util.HashMap; import java.util.Map; + import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; -import org.tensorflow.op.Ops; /** Optimizer that implements the Adagrad Dual-Averaging algorithm. */ public class AdaGradDA extends org.tensorflow.framework.optimizers.AdaGradDA @@ -33,8 +35,7 @@ public class AdaGradDA extends org.tensorflow.framework.optimizers.AdaGradDA public static final float L1STRENGTH_DEFAULT = 0.0F; public static final float L2STRENGTH_DEFAULT = 0.0F; - private Map<String, Object> config = new HashMap<>(); - private float learningRate; + private final Map<String, Object> config = new HashMap<>(); /** * Create an AdagradDA Optimizer with default values name="adagrad-da". learning_rate=.001, @@ -85,11 +86,12 @@ public AdaGradDA( float l1Strength, float l2Strength) { super(assertGraph(tf), learningRate, initialAccumulatorValue, l1Strength, l2Strength); - if( initialAccumulatorValue < 0.0F) - throw new IllegalArgumentException("initial_accumulator_value must be non-negative: " + initialAccumulatorValue); - if(l1Strength < 0) + if (initialAccumulatorValue < 0.0F) + throw new IllegalArgumentException( + "initial_accumulator_value must be non-negative: " + initialAccumulatorValue); + if (l1Strength < 0) throw new IllegalArgumentException("l1Strength must be non-negative: " + l1Strength); - if(l2Strength < 0) + if (l2Strength < 0) throw new IllegalArgumentException("l2Strength must be non-negative: " + l2Strength); initConfig(learningRate, initialAccumulatorValue, l1Strength, l2Strength); } @@ -112,11 +114,12 @@ public AdaGradDA( float l1Strength, float l2Strength) { super(assertGraph(tf), name, learningRate, initialAccumulatorValue, l1Strength, l2Strength); - if( initialAccumulatorValue < 0.0F) - throw new IllegalArgumentException("initial_accumulator_value must be non-negative: " + initialAccumulatorValue); - if(l1Strength < 0) + if (initialAccumulatorValue < 0.0F) + throw new IllegalArgumentException( + "initial_accumulator_value must be non-negative: " + initialAccumulatorValue); + if (l1Strength < 0) throw new IllegalArgumentException("l1Strength must be non-negative: " + l1Strength); - if(l2Strength < 0) + if (l2Strength < 0) throw new IllegalArgumentException("l2Strength must be non-negative: " + l2Strength); initConfig(learningRate, initialAccumulatorValue, l1Strength, l2Strength); initConfig(learningRate, initialAccumulatorValue, l1Strength, l2Strength); @@ -168,7 +171,6 @@ public static AdaGradDA create(Ops tf, Map<String, Object> config) { */ private void initConfig( float learningRate, float initialAccumulatorValue, float l1Strength, float l2Strength) { - this.learningRate = learningRate; config.put(NAME_KEY, this.getOptimizerName()); config.put(LEARNING_RATE_KEY, learningRate); config.put(INITIAL_ACCUM_KEY, initialAccumulatorValue); @@ -181,16 +183,4 @@ private void initConfig( public Map<String, Object> getConfig() { return config; } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public void setLearningRate(float learningRate) { - this.learningRate = learningRate; - } } diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adam.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adam.java index 5d74c7e27f4..593ddcd88f3 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adam.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adam.java @@ -14,11 +14,12 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; +import org.tensorflow.op.Ops; + import java.util.HashMap; import java.util.Map; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; -import org.tensorflow.op.Ops; /** Adam Optimizer that implements the Adam algorithm. */ public class Adam extends org.tensorflow.framework.optimizers.Adam implements OptimizerInterface { @@ -33,8 +34,7 @@ public class Adam extends org.tensorflow.framework.optimizers.Adam implements Op public static final float BETA_ONE_DEFAULT = 0.9F; public static final float BETA_TWO_DEFAULT = 0.999F; - private float learningRate; - private Map<String, Object> config = new HashMap<>(); + private final Map<String, Object> config = new HashMap<>(); /** * Create an Adam Optimizer @@ -154,7 +154,6 @@ public static Adam create(Ops tf, Map<String, Object> config) { * 1 of the paper. Defaults to 1e-7. */ protected void initConfig(float learningRate, float betaOne, float betaTwo, float epsilon) { - this.learningRate = learningRate; config.put(NAME_KEY, this.getOptimizerName()); config.put(LEARNING_RATE_KEY, learningRate); config.put(EPSILON_KEY, epsilon); @@ -167,16 +166,4 @@ protected void initConfig(float learningRate, float betaOne, float betaTwo, floa public Map<String, Object> getConfig() { return config; } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public void setLearningRate(float learningRate) { - this.learningRate = learningRate; - } } diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adamax.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adamax.java index a976a6e51dd..4158043e95b 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adamax.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adamax.java @@ -14,29 +14,28 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; +import org.tensorflow.Tensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.Scope; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyAdaMax; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; +import java.util.*; + +import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; + /** Adamax Optimizer that implements the Adamax algorithm. */ public class Adamax extends org.tensorflow.framework.optimizers.Optimizer - implements OptimizerInterface { + implements OptimizerInterface, AutoCloseable { public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; @@ -51,15 +50,17 @@ public class Adamax extends org.tensorflow.framework.optimizers.Optimizer public static final float BETA_ONE_DEFAULT = 0.9F; public static final float BETA_TWO_DEFAULT = 0.999F; - private Scope scope; - private Map<String, Object> config = new HashMap<>(); + private final Map<String, Object> config = new HashMap<>(); private float learningRate; + private Tensor<TFloat32> learningRateTensor; + private final Placeholder<TFloat32> learningRatePlaceholder; + private Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict; + private final float betaOne; private final float betaTwo; private final float epsilon; - private Constant<TFloat32> learningRateConst; private Constant<TFloat32> epsilonConst; private Constant<TFloat32> betaOneConst; private Constant<TFloat32> betaTwoConst; @@ -117,10 +118,14 @@ public Adamax(Ops tf, String name, float learningRate) { public Adamax(Ops tf, float learningRate, float betaOne, float betaTwo, float epsilon) { super(assertGraph(tf)); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; - this.scope = tf.scope(); initConfig(learningRate, betaOne, betaTwo, epsilon); } @@ -138,10 +143,14 @@ public Adamax( Ops tf, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { super(assertGraph(tf), name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; - this.scope = tf.scope(); initConfig(learningRate, betaOne, betaTwo, epsilon); } @@ -191,8 +200,31 @@ public float getLearningRate() { /** {@inheritDoc} */ @Override - public void setLearningRate(float learningRate) { + public final void setLearningRate(float learningRate) { this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** + * Get the Feed Dictionary for the run methods to set the Placeholder values(s) + * + * @return the current Feed Dictionary for the run methods + */ + public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } } /** {@inheritDoc} */ @@ -200,7 +232,6 @@ public void setLearningRate(float learningRate) { protected Optional<Op> prepare(String scopeName) { betaOneConst = tf.constant(betaOne); betaTwoConst = tf.constant(betaTwo); - learningRateConst = tf.constant(learningRate); epsilonConst = tf.constant(epsilon); return Optional.empty(); @@ -238,16 +269,16 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable Variable<T> firstMomentSlot = getSlot(variable, FIRST_MOMENT).get(); Variable<T> secondMomentSlot = getSlot(variable, SECOND_MOMENT).get(); return ApplyAdaMax.create( - scope, - (Operand) variable, - (Operand) firstMomentSlot, - (Operand) secondMomentSlot, - (Operand) tf.dtypes.cast(betaOnePower, gradient.dataType()), - (Operand) tf.dtypes.cast(learningRateConst, gradient.dataType()), - (Operand) tf.dtypes.cast(betaOneConst, gradient.dataType()), - (Operand) tf.dtypes.cast(betaTwoConst, gradient.dataType()), - (Operand) tf.dtypes.cast(epsilonConst, gradient.dataType()), - (Operand) gradient); + tf.scope(), + (Operand<T>) variable, + (Operand<T>) firstMomentSlot, + (Operand<T>) secondMomentSlot, + (Operand<T>) tf.dtypes.cast(betaOnePower, gradient.dataType()), + (Operand<T>) tf.dtypes.cast(this.learningRatePlaceholder, gradient.dataType()), + (Operand<T>) tf.dtypes.cast(betaOneConst, gradient.dataType()), + (Operand<T>) tf.dtypes.cast(betaTwoConst, gradient.dataType()), + (Operand<T>) tf.dtypes.cast(epsilonConst, gradient.dataType()), + (Operand<T>) gradient); } /** {@inheritDoc} */ diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Ftrl.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Ftrl.java index db73f60c77e..22dad158a4a 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Ftrl.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Ftrl.java @@ -14,26 +14,28 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.HashMap; -import java.util.List; -import java.util.Map; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.Session; -import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; +import org.tensorflow.Tensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyFtrl; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; + /** Ftrl (Follow the Regularized Leader) Optimizer that implements the FTRL algorithm. */ public class Ftrl extends org.tensorflow.framework.optimizers.Optimizer - implements OptimizerInterface { + implements OptimizerInterface, AutoCloseable { public static final String LEARNING_RATE_KEY = "learning_rate"; public static final String LEARNING_RATE_POWER_KEY = "learning_rate_power"; @@ -55,15 +57,18 @@ public class Ftrl extends org.tensorflow.framework.optimizers.Optimizer private final String name; private float learningRate; + private Tensor<TFloat32> learningRateTensor; + private final Placeholder<TFloat32> learningRatePlaceholder; + private Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict; private final float learningRatePower; private final float initialAccumulatorValue; private final float l1RegularizationStrength; private final float l2RegularizationStrength; private final float l2ShrinkageRegularizationStrength; - private Map<String, Object> config = new HashMap<>(); + private final Map<String, Object> config = new HashMap<>(); - private boolean useLocking = true; + private final boolean useLocking = true; /** * Create a Ftrl Optimizer @@ -161,6 +166,11 @@ public Ftrl( super(assertGraph(tf)); this.name = getOptimizerName(); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.learningRatePower = learningRatePower; this.initialAccumulatorValue = initialAccumulatorValue; this.l1RegularizationStrength = l1Strength; @@ -198,6 +208,11 @@ public Ftrl( super(assertGraph(tf), name); this.name = name; this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.learningRatePower = learningRatePower; this.initialAccumulatorValue = initialAccumulatorValue; this.l1RegularizationStrength = l1Strength; @@ -331,7 +346,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable accumSlot, // accum linearSlot, // linear gradient, // gradient - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), // lr + tf.dtypes.cast(this.learningRatePlaceholder, gradient.dataType()), // lr tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.dataType()), // l1 tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.dataType()), // l2 tf.dtypes.cast( @@ -360,7 +375,26 @@ public float getLearningRate() { /** {@inheritDoc} */ @Override - public void setLearningRate(float learningRate) { + public final void setLearningRate(float learningRate) { this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } } } diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java index a2eba4ecb49..f9f796d7738 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java @@ -14,29 +14,25 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import org.tensorflow.DataType; -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Output; -import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; +import org.tensorflow.*; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; +import java.util.*; + +import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; + /** Nadam Optimizer that implements the NAdam algorithm. */ public class Nadam extends org.tensorflow.framework.optimizers.Optimizer - implements OptimizerInterface { + implements OptimizerInterface, AutoCloseable { public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; @@ -55,6 +51,9 @@ public class Nadam extends org.tensorflow.framework.optimizers.Optimizer private final Map<String, Object> config = new HashMap<>(); private float learningRate; + private Tensor<TFloat32> learningRateTensor; + private final Placeholder<TFloat32> learningRatePlaceholder; + private Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict; private final float betaOne; private final float betaTwo; private final float epsilon; @@ -63,7 +62,6 @@ public class Nadam extends org.tensorflow.framework.optimizers.Optimizer private long iterations = 0; - private Constant<TFloat32> learningRateConst; private Constant<TFloat32> betaOneConst; private Constant<TFloat32> betaTwoConst; private Constant<TInt64> localStepConst; @@ -140,6 +138,11 @@ public Nadam(Ops tf, String name, float learningRate) { public Nadam(Ops tf, float learningRate, float betaOne, float betaTwo, float epsilon) { super(assertGraph(tf)); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; @@ -160,6 +163,11 @@ public Nadam( Ops tf, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { super(assertGraph(tf), name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; @@ -200,8 +208,31 @@ public float getLearningRate() { /** {@inheritDoc} */ @Override - public void setLearningRate(float learningRate) { + public final void setLearningRate(float learningRate) { this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** + * Get the Feed Dictionary for the run methods to set the Placeholder values(s) + * + * @return the current Feed Dictionary for the run methods + */ + public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } } /** {@inheritDoc} */ @@ -248,7 +279,6 @@ protected Optional<Op> prepare(String scopeName) { Constant<TFloat32> one = tf.constant(1.0F); Constant<TFloat32> point5 = tf.constant(0.5F); - learningRateConst = tf.constant(learningRate); betaOneConst = tf.constant(betaOne); betaTwoConst = tf.constant(betaTwo); localStepConst = tf.constant(this.iterations + 1); @@ -350,7 +380,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable tf.math.sub( variable, tf.math.div( - tf.math.mul(tf.dtypes.cast(learningRateConst, dType), m_t_bar), + tf.math.mul(tf.dtypes.cast(this.learningRatePlaceholder, dType), m_t_bar), tf.math.add(tf.math.sqrt(v_t_prime), tf.dtypes.cast(epsilonConst, dType)))); // assign(var, var_t, use_locking=self._use_locking) return tf.assign(variable, var_t, Assign.useLocking(true)); diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java index 0074ecb0f0a..183c71dd976 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java @@ -14,10 +14,11 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.Map; import org.tensorflow.Graph; import org.tensorflow.op.Ops; +import java.util.Map; + /** The main Interface for Keras Optimizers */ public interface OptimizerInterface { @@ -32,8 +33,9 @@ public interface OptimizerInterface { * @throws java.lang.IllegalArgumentException if the TensorFlow Ops does not represent Graph mode */ static Graph assertGraph(Ops tf) { - if(!tf.scope().env().isGraph()) { - throw new IllegalArgumentException("Invalid environment, Optimizers can only be used in Graph Mode"); + if (!tf.scope().env().isGraph()) { + throw new IllegalArgumentException( + "Invalid environment, Optimizers can only be used in Graph Mode"); } return (Graph) tf.scope().env(); } @@ -44,18 +46,4 @@ static Graph assertGraph(Ops tf) { * @return the config object used to initialize the Optimizer */ Map<String, Object> getConfig(); - - /** - * Return the current learning rate - * - * @return the current learning rate - */ - float getLearningRate(); - - /** - * Set the learning rate - * - * @param learningRate the learning rate; - */ - void setLearningRate(float learningRate); } diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java index 1facb307b38..aecd8dcf537 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java @@ -22,9 +22,9 @@ import java.util.HashMap; import java.util.Map; import java.util.function.Function; +import java.util.function.Supplier; import java.util.logging.Level; import java.util.logging.Logger; -import java.util.function.Supplier; /** * Functions to get an Optimizer based on String name, an Optimizer class, or lambda function. @@ -79,8 +79,8 @@ public static Optimizer get(Ops tf, Function<Ops, Optimizer> func) { * * @param optimizerFunction either a String that identifies the Optimizer, an Optimizer class, or * * an Optimizer object. - * @param custom_functions a map of Optimizer lambdas that will be queried if the Optimizer is - * not found in the standard keys + * @param custom_functions a map of Optimizer lambdas that will be queried if the Optimizer is not + * found in the standard keys * @return the Optimizer object */ public static Optimizer get( diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java index c66c6bdd388..03fc4c01f71 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java @@ -14,11 +14,12 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; +import org.tensorflow.op.Ops; + import java.util.HashMap; import java.util.Map; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; -import org.tensorflow.op.Ops; /** RMSProp Optimizer that implements the RMSProp algorithm. */ public class RMSProp extends org.tensorflow.framework.optimizers.RMSProp @@ -37,7 +38,6 @@ public class RMSProp extends org.tensorflow.framework.optimizers.RMSProp public static final boolean CENTERED_DEFAULT = false; private Map<String, Object> config = new HashMap<>(); - private float learningRate; /** * Create an RMSProp Optimizer with the following defaults, name="RMSProp", learning_rate=0.001, @@ -172,7 +172,6 @@ public static RMSProp create(Ops tf, Map<String, Object> config) { */ private void initConfig( float learningRate, float decay, float momentum, float epsilon, boolean centered) { - this.learningRate = learningRate; config.put(NAME_KEY, this.getOptimizerName()); config.put(LEARNING_RATE_KEY, learningRate); config.put(DECAY_KEY, decay); @@ -186,16 +185,4 @@ private void initConfig( public Map<String, Object> getConfig() { return config; } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public void setLearningRate(float learningRate) { - this.learningRate = learningRate; - } } diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java index f89682f6820..5e7155c2ab5 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java @@ -14,10 +14,12 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; +import org.tensorflow.op.Ops; + import java.util.HashMap; import java.util.Map; + import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; -import org.tensorflow.op.Ops; /** Stochastic Gradient Descent and momentum optimizer. */ public class SGD extends org.tensorflow.framework.optimizers.Momentum @@ -32,7 +34,6 @@ public class SGD extends org.tensorflow.framework.optimizers.Momentum public static final boolean NESTEROV_DEFAULT = false; private Map<String, Object> config = new HashMap<>(); - private float learningRate; /** * Create a Stochastic Gradient Descent optimizer using defaults: name="SGD", learning_rate=0.01, @@ -102,7 +103,7 @@ public SGD(Ops tf, String name, float learningRate, float momentum) { */ public SGD(Ops tf, float learningRate, float momentum, boolean useNesterov) { super(assertGraph(tf), learningRate, momentum, useNesterov); - if(momentum < 0 || momentum > 1) + if (momentum < 0 || momentum > 1) throw new IllegalArgumentException("\"momentum\" must be between [0, 1]."); initConfig(learningRate, momentum, useNesterov); } @@ -119,7 +120,7 @@ public SGD(Ops tf, float learningRate, float momentum, boolean useNesterov) { */ public SGD(Ops tf, String name, float learningRate, float momentum, boolean useNesterov) { super(assertGraph(tf), name, learningRate, momentum, useNesterov); - if(momentum < 0 || momentum > 1) + if (momentum < 0 || momentum > 1) throw new IllegalArgumentException("\"momentum\" must be between [0, 1]."); initConfig(learningRate, momentum, useNesterov); } @@ -166,7 +167,6 @@ public static SGD create(Ops tf, Map<String, Object> config) { * @param useNesterov Whether to apply Nesterov momentum. Defaults to `false`. */ private void initConfig(float learningRate, float momentum, boolean useNesterov) { - this.learningRate = learningRate; config.put(NAME_KEY, this.getOptimizerName()); config.put(LEARNING_RATE_KEY, learningRate); config.put(MOMENTUM_KEY, momentum); @@ -179,18 +179,6 @@ public Map<String, Object> getConfig() { return config; } - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public void setLearningRate(float learningRate) { - this.learningRate = learningRate; - } - // overide the momentum name to return "SGD" /** {@inheritDoc} */ @Override diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java index 8a7c8af9fae..e8a3bc14d9b 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java @@ -14,26 +14,8 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; -import static org.tensorflow.framework.optimizers.AdaDelta.ACCUMULATOR; -import static org.tensorflow.framework.optimizers.AdaDelta.ACCUMULATOR_UPDATE; +import org.junit.jupiter.api.*; import org.tensorflow.framework.optimizers.Optimizer.GradAndVar; -import static org.tensorflow.keras.optimizers.AdaDelta.EPSILON_DEFAULT; -import static org.tensorflow.keras.optimizers.AdaDelta.EPSILON_KEY; -import static org.tensorflow.keras.optimizers.AdaDelta.LEARNING_RATE_DEFAULT; -import static org.tensorflow.keras.optimizers.AdaDelta.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.AdaDelta.RHO_DEFAULT; -import static org.tensorflow.keras.optimizers.AdaDelta.RHO_RATE_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -43,184 +25,181 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.AdaDelta.ACCUMULATOR; +import static org.tensorflow.framework.optimizers.AdaDelta.ACCUMULATOR_UPDATE; +import static org.tensorflow.keras.optimizers.AdaDelta.*; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + /** Test cases for AdaDelta Optimizer */ public class AdaDeltaTest { - private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; + private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; - private int index; + private int index; - public AdaDeltaTest() { - } + public AdaDeltaTest() {} - @BeforeAll - public static void setUpClass() { - } + @BeforeAll + public static void setUpClass() {} - @AfterAll - public static void tearDownClass() { - } + @AfterAll + public static void tearDownClass() {} - @BeforeEach - public void setUp() { - } + @BeforeEach + public void setUp() {} - @AfterEach - public void tearDown() { - } + @AfterEach + public void tearDown() {} - /** - * Test of create method, of class AdaDelta. - */ - @Test - public void testCreate() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - Map<String, Object> config = new HashMap<>(); - config.put(NAME_KEY, "AdaDelta"); - config.put(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); - config.put(RHO_RATE_KEY, RHO_DEFAULT); - config.put(EPSILON_KEY, EPSILON_DEFAULT); - AdaDelta expResult = new AdaDelta(tf); - AdaDelta result = AdaDelta.create(tf, config); - assertEquals(expResult.getConfig(), result.getConfig()); - } + /** Test of create method, of class AdaDelta. */ + @Test + public void testCreate() { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + Map<String, Object> config = new HashMap<>(); + config.put(NAME_KEY, "AdaDelta"); + config.put(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); + config.put(RHO_RATE_KEY, RHO_DEFAULT); + config.put(EPSILON_KEY, EPSILON_DEFAULT); + AdaDelta expResult = new AdaDelta(tf); + AdaDelta result = AdaDelta.create(tf, config); + assertEquals(expResult.getConfig(), result.getConfig()); } - - @Test - public void testConstructAdadeltaWithLR() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - AdaDelta opt = new AdaDelta(tf, 1.0F, 0.9F, 1.F); - AdaDelta opt2 = new AdaDelta(tf, 0.1F, 0.9F, 1.F); - AdaDelta opt3 = new AdaDelta(tf, 0.1F, 0.9F, 1e-8F); - String format = "AdaDelta{learningRate=%s, rho=%s, epsilon=%s}"; - String optExpected = String.format(format, 1.0F, 0.9F, 1.F); - String opt2Expected = String.format(format, 0.1F, 0.9F, 1.F); - String opt3Expected = String.format(format, 0.1F, 0.9F, 1e-8F); - - String optString = opt.toString(); - String opt2String = opt2.toString(); - String opt3String = opt3.toString(); - - assertEquals(optExpected, optString); - assertEquals(opt2Expected, opt2String); - assertEquals(opt3Expected, opt3String); - } - + } + + @Test + public void testConstructAdadeltaWithLR() { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + AdaDelta opt = new AdaDelta(tf, 1.0F, 0.9F, 1.F); + AdaDelta opt1 = new AdaDelta(tf, "AdaDelta_1", 0.1F, 0.9F, 1.F); + AdaDelta opt2 = new AdaDelta(tf, "AdaDelta_2", 0.1F, 0.9F, 1e-8F); + String format = "AdaDelta{learningRate=%s, rho=%s, epsilon=%s}"; + String optExpected = String.format(format, 1.0F, 0.9F, 1.F); + String opt1Expected = String.format(format, 0.1F, 0.9F, 1.F); + String opt2Expected = String.format(format, 0.1F, 0.9F, 1e-8F); + + String optString = opt.toString(); + String opt1String = opt1.toString(); + String opt2String = opt2.toString(); + + assertEquals(optExpected, optString); + assertEquals(opt1Expected, opt1String); + assertEquals(opt2Expected, opt2String); } - - @Test - public void testConstructAdadeltaWithEpsilonValues() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - AdaDelta opt = new AdaDelta(tf); - Map<String, Object> config = opt.getConfig(); - assertEquals(EPSILON_DEFAULT, (float) config.get(EPSILON_KEY)); - - opt = new AdaDelta(tf, LEARNING_RATE_DEFAULT, RHO_DEFAULT, 1e-8F); - config = opt.getConfig(); - assertEquals(1e-8F, (float) config.get(EPSILON_KEY)); - } + } + + @Test + public void testConstructAdadeltaWithEpsilonValues() { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + AdaDelta opt = new AdaDelta(tf); + Map<String, Object> config = opt.getConfig(); + assertEquals(EPSILON_DEFAULT, (float) config.get(EPSILON_KEY)); + + opt = new AdaDelta(tf, "AdaDelta_1", LEARNING_RATE_DEFAULT, RHO_DEFAULT, 1e-8F); + config = opt.getConfig(); + assertEquals(1e-8F, (float) config.get(EPSILON_KEY)); } - - @Test - public void testBasic() { - int num_updates = 4; // # number of ADADELTA steps to perform - float[] grads = {0.2F, 0.1F, 0.01F}; - float[] lrs = {1.0F, 0.5F, 0.1F}; - for (float grad : grads) { - for (float lr : lrs) { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] fgrads = {grad, grad}; - Shape shape = Shape.of(var0_init.length); - Variable<TFloat32> var0 = tf.withName("var0").variable(shape, TFloat32.DTYPE); - Variable<TFloat32> var1 = tf.withName("var1").variable(shape, TFloat32.DTYPE); - - Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); - - Constant<TFloat32> cgrads = tf.constant(fgrads); - - float accum = 0.0F; - float accum_update = 0.0F; - float rho = 0.95F; - float epsilon = 1e-8F; - float epsilon1 = 1e-5F; - - /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); - gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var0.asOutput())); - gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var1.asOutput())); - - /* get the Optimizer */ - AdaDelta adaDelta = new AdaDelta(tf, lr, rho, epsilon); - - /** - * apply gradients - */ - Op adadelta_update = adaDelta.applyGradients(gradsAndVars, "AdaDeltaTest"); - - /* Create and validae the shapes of the slota */ - Variable<TFloat32>[] slots = new Variable[2]; - Variable<TFloat32>[] slotUpdates = new Variable[2]; - - slots[0] = adaDelta.getSlot(var0.asOutput(), ACCUMULATOR).get(); - assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape()); - - slotUpdates[0] = adaDelta.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get(); - assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape()); - - slots[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR).get(); - assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape()); - - slotUpdates[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get(); - assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape()); - - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - - /** - * initialize the accumulators - */ - session.run(tf.init()); - - /** - * make sure the variables were initialized properly - */ - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); - - float[] updates = new float[num_updates]; - float tot_update = 0; - for (int step = 0; step < num_updates; step++) { - session.run(adadelta_update); - accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho); - updates[step] = ((float) Math.sqrt(accum_update + epsilon) - * (float) (1 / Math.sqrt(accum + epsilon)) * grad); - accum_update = (accum_update * rho + ((float) Math.pow(updates[step], 2) * (1.0F - rho))); - tot_update += updates[step] * lr; - - for (int i = 0; i < 2; i++) { - session.evaluate(accum, slots[i]); - session.evaluate(accum_update, slotUpdates[i]); - } - - Float[] var0_initUpdate = {var0_init[0] - tot_update, var0_init[1] - tot_update}; - Float[] var1_initUpdate = {var1_init[0] - tot_update, var1_init[1] - tot_update}; - - session.evaluate(var0_initUpdate, var0); - session.evaluate(var1_initUpdate, var1); - - } - - } + } + + @Test + public void testBasic() { + int num_updates = 4; // # number of ADADELTA steps to perform + float[] grads = {0.2F, 0.1F, 0.01F}; + float[] lrs = {1.0F, 0.5F, 0.1F}; + for (float grad : grads) { + for (float lr : lrs) { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] fgrads = {grad, grad}; + Shape shape = Shape.of(var0_init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant<TFloat32> cgrads = tf.constant(fgrads); + + float accum = 0.0F; + float accum_update = 0.0F; + float rho = 0.95F; + float epsilon = 1e-8F; + float epsilon1 = 1e-5F; + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var1.asOutput())); + + /* get the Optimizer */ + AdaDelta adaDelta = new AdaDelta(tf, lr, rho, epsilon); + + /** apply gradients */ + Op adadelta_update = adaDelta.applyGradients(gradsAndVars, "AdaDeltaTest"); + + /* Create and validae the shapes of the slota */ + Variable<TFloat32>[] slots = new Variable[2]; + Variable<TFloat32>[] slotUpdates = new Variable[2]; + + slots[0] = adaDelta.getSlot(var0.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape()); + + slotUpdates[0] = adaDelta.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape()); + + slots[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape()); + + slotUpdates[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + + /** make sure the variables were initialized properly */ + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + float[] updates = new float[num_updates]; + float tot_update = 0; + for (int step = 0; step < num_updates; step++) { + session.run(adadelta_update, adaDelta.getFeedDict()); + accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho); + updates[step] = + ((float) Math.sqrt(accum_update + epsilon) + * (float) (1 / Math.sqrt(accum + epsilon)) + * grad); + accum_update = + (accum_update * rho + ((float) Math.pow(updates[step], 2) * (1.0F - rho))); + tot_update += updates[step] * lr; + + for (int i = 0; i < 2; i++) { + session.evaluate(accum, slots[i]); + session.evaluate(accum_update, slotUpdates[i]); } + + Float[] var0_initUpdate = {var0_init[0] - tot_update, var0_init[1] - tot_update}; + Float[] var1_initUpdate = {var1_init[0] - tot_update, var1_init[1] - tot_update}; + + session.evaluate(var0_initUpdate, var0); + session.evaluate(var1_initUpdate, var1); + } } + } } - + } } diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java index 3931db4da97..85f4220c4c7 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java @@ -14,20 +14,8 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.*; import org.tensorflow.framework.optimizers.Optimizer; -import static org.tensorflow.keras.optimizers.AdaGradDA.INITIAL_ACCUM_KEY; -import static org.tensorflow.keras.optimizers.AdaGradDA.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -37,6 +25,16 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.keras.optimizers.AdaGradDA.INITIAL_ACCUM_KEY; +import static org.tensorflow.keras.optimizers.AdaGradDA.LEARNING_RATE_KEY; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + /** Test cases for AdaGradDA Optimizer */ public class AdaGradDATest { @@ -116,7 +114,7 @@ public void testBasic() { session.evaluate(var0_init, var0); session.evaluate(var1_init, var1); - session.run(ada_update); + session.run(ada_update, instance.getFeedDict()); float[] expected0 = {-0.904534F, -1.603567F}; session.evaluate(expected0, var0); float[] expected1 = {-0.094821f, -0.189358f}; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java index 28c45c4c8c3..b6f1d7c88fc 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java @@ -14,21 +14,8 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; -import static org.tensorflow.framework.optimizers.AdaGrad.ACCUMULATOR; +import org.junit.jupiter.api.*; import org.tensorflow.framework.optimizers.Optimizer; -import static org.tensorflow.keras.optimizers.AdaGrad.INITIAL_ACCUM_KEY; -import static org.tensorflow.keras.optimizers.AdaGrad.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; import org.tensorflow.keras.utils.ND; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -41,6 +28,17 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.AdaGrad.ACCUMULATOR; +import static org.tensorflow.keras.optimizers.AdaGrad.INITIAL_ACCUM_KEY; +import static org.tensorflow.keras.optimizers.AdaGrad.LEARNING_RATE_KEY; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + /** Test cases for AdaGrad Optimizer */ public class AdaGradTest { private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; @@ -138,7 +136,7 @@ public void testBasic() { session.evaluate(var1_init, var1); for (int step = 0; step < numSteps; step++) { - session.run(ada_update); + session.run(ada_update, instance.getFeedDict()); accum0_np = caclulateAccum(accum0_np, grads0_np); var0_np = calculate(var0_np, accum0_np, grads0_np, learningRate); diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java index 6f1d13d83d6..6a8f0f5078c 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java @@ -14,29 +14,9 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.*; import org.tensorflow.Tensor; -import static org.tensorflow.framework.optimizers.Adam.FIRST_MOMENT; -import static org.tensorflow.framework.optimizers.Adam.SECOND_MOMENT; import org.tensorflow.framework.optimizers.Optimizer; -import static org.tensorflow.keras.optimizers.Adam.BETA_ONE_DEFAULT; -import static org.tensorflow.keras.optimizers.Adam.BETA_ONE_KEY; -import static org.tensorflow.keras.optimizers.Adam.BETA_TWO_DEFAULT; -import static org.tensorflow.keras.optimizers.Adam.BETA_TWO_KEY; -import static org.tensorflow.keras.optimizers.Adam.EPSILON_DEFAULT; -import static org.tensorflow.keras.optimizers.Adam.EPSILON_KEY; -import static org.tensorflow.keras.optimizers.Adam.LEARNING_RATE_DEFAULT; -import static org.tensorflow.keras.optimizers.Adam.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; import org.tensorflow.keras.utils.ND; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -49,6 +29,17 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.Adam.FIRST_MOMENT; +import static org.tensorflow.framework.optimizers.Adam.SECOND_MOMENT; +import static org.tensorflow.keras.optimizers.Adam.*; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + /** Test cases for Adam Optimizer */ public class AdamTest { private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; @@ -203,7 +194,7 @@ public void testBasic() { assertEquals(powers[1], f.getFloat(), epsilon1); }); } - session.run(update); + session.run(update, instance.getFeedDict()); float lr_t = learningRate diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java index 1d3dc9e76bf..3f6b232c179 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java @@ -14,29 +14,9 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.*; import org.tensorflow.Tensor; import org.tensorflow.framework.optimizers.Optimizer; -import static org.tensorflow.keras.optimizers.Adamax.BETA_ONE_DEFAULT; -import static org.tensorflow.keras.optimizers.Adamax.BETA_ONE_KEY; -import static org.tensorflow.keras.optimizers.Adamax.BETA_TWO_DEFAULT; -import static org.tensorflow.keras.optimizers.Adamax.BETA_TWO_KEY; -import static org.tensorflow.keras.optimizers.Adamax.EPSILON_DEFAULT; -import static org.tensorflow.keras.optimizers.Adamax.EPSILON_KEY; -import static org.tensorflow.keras.optimizers.Adamax.FIRST_MOMENT; -import static org.tensorflow.keras.optimizers.Adamax.LEARNING_RATE_DEFAULT; -import static org.tensorflow.keras.optimizers.Adamax.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.Adamax.SECOND_MOMENT; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; import org.tensorflow.keras.utils.ND; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -49,6 +29,15 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.keras.optimizers.Adamax.*; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + /** Test cases for Adamax Optimizer */ public class AdamaxTest { private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; @@ -195,7 +184,7 @@ public void testBasic() { assertEquals(beta1_power, f.getFloat(), epsilon1); }); } - session.run(update); + session.run(update, instance.getFeedDict()); FloatNdArray[] resultNP = calculate(var0_np, grads0_np, step, m0, v0); var0_np = resultNP[VAR]; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/FtrlTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/FtrlTest.java index d61197348af..ba5d7ccb7a2 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/FtrlTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/FtrlTest.java @@ -14,24 +14,8 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.*; import org.tensorflow.framework.optimizers.Optimizer; -import static org.tensorflow.keras.optimizers.Ftrl.INITIAL_ACCUM_VALUE_KEY; -import static org.tensorflow.keras.optimizers.Ftrl.L1STRENGTH_KEY; -import static org.tensorflow.keras.optimizers.Ftrl.L2STRENGTH_KEY; -import static org.tensorflow.keras.optimizers.Ftrl.L2_SHRINKAGE_REGULARIZATION_STRENGTH_KEY; -import static org.tensorflow.keras.optimizers.Ftrl.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.Ftrl.LEARNING_RATE_POWER_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -41,8 +25,18 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; -/** Test cases for Ftrl Optimizer */ +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.keras.optimizers.Ftrl.*; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + +/** Test the Ftrl Optimizer */ public class FtrlTest { + private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; int index; @@ -147,7 +141,7 @@ public void testFtrlWithL1_L2_L2Shrinkage() { session.evaluate(var1_init, var1); for (int i = 0; i < numSteps; i++) { - session.run(ftrl_update); + session.run(ftrl_update, instance.getFeedDict()); } float[] expectedVar0 = {-0.22578995F, -0.44345796F}; @@ -214,7 +208,7 @@ public void testFtrlWithL1() { session.evaluate(var1_init, var1); for (int i = 0; i < numSteps; i++) { - session.run(ftrl_update); + session.run(ftrl_update, instance.getFeedDict()); } float[] expectedVar0 = {-7.66718769F, -10.91273689F}; @@ -282,7 +276,7 @@ public void testFtrlWithL1_L2() { session.evaluate(var1_init, var1); for (int i = 0; i < numSteps; i++) { - session.run(ftrl_update); + session.run(ftrl_update, instance.getFeedDict()); } float[] expectedVar0 = {-0.24059935F, -0.46829352F}; @@ -293,6 +287,74 @@ public void testFtrlWithL1_L2() { } } + @Test + public void testChangingLearningRate() { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + int numSteps = 10; + float learningRate = 3.0F; + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {4.0F, 3.0F}; + float[] grads0_init = {0.1F, 0.2F}; + float[] grads1_init = {0.01F, 0.02F}; + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant<TFloat32> grads0 = tf.constant(grads0_init); + Constant<TFloat32> grads1 = tf.constant(grads1_init); + + Ftrl instance = + new Ftrl( + tf, + learningRate, + Ftrl.LEARNING_RATE_POWER_DEFAULT, // learningRatePower + 0.1F, // initial_accumulator_value + 0.001F, // l1_regularization_strength + 2.0F, // l2_regularization_strength + Ftrl + .L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT // l2_shrinkage_regularization_strength + ); + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op ftrl_update = instance.applyGradients(gradsAndVars, "FtrlTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + float expected[][][] = { + {{-0.022833f, -0.038881f}, {-0.002141f, -0.004474f}}, + {{-0.037825f, -0.067760f}, {-0.003717f, -0.007587f}}, + {{-0.019528f, -0.034022f}, {-0.001979f, -0.004008f}}, + {{-0.003895f, -0.007653f}, {-0.000355f, -0.000720f}}, + {{-0.000596f, -0.001364f}, {-0.000046f, -0.000094f}}, + {{-0.000084f, -0.000221f}, {-0.000006f, -0.000012f}}, + {{-0.000011f, -0.000034f}, {-0.000001f, -0.000001f}}, + {{-0.000002f, -0.000005f}, {-0.000000f, -0.000000f}}, + {{-0.000000f, -0.000001f}, {-0.000000f, -0.000000f}}, + {{-0.000000f, -0.000000f}, {-0.000000f, -0.000000f}} + }; + for (int i = 0; i < numSteps; i++) { + session.run(ftrl_update, instance.getFeedDict()); + session.evaluate(expected[i][0], var0); + session.evaluate(expected[i][1], var1); + learningRate *= 0.1f; + instance.setLearningRate(learningRate); + } + } + } + @Test public void doTestFtrlwithoutRegularization() { float[] var0_init = {0.0F, 0.0F}; @@ -339,7 +401,7 @@ public void doTestFtrlwithoutRegularization() { session.evaluate(var1_init, var1); for (int i = 0; i < numSteps; i++) { - session.run(ftrl_update); + session.run(ftrl_update, instance.getFeedDict()); } float[] expectedVar0 = {-2.60260963F, -4.29698515F}; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java index 2b8bce40471..6314b4b8b4c 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java @@ -14,29 +14,9 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.*; import org.tensorflow.Tensor; import org.tensorflow.framework.optimizers.Optimizer; -import static org.tensorflow.keras.optimizers.Adamax.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.Nadam.BETA_ONE_DEFAULT; -import static org.tensorflow.keras.optimizers.Nadam.BETA_ONE_KEY; -import static org.tensorflow.keras.optimizers.Nadam.BETA_TWO_DEFAULT; -import static org.tensorflow.keras.optimizers.Nadam.BETA_TWO_KEY; -import static org.tensorflow.keras.optimizers.Nadam.EPSILON_DEFAULT; -import static org.tensorflow.keras.optimizers.Nadam.EPSILON_KEY; -import static org.tensorflow.keras.optimizers.Nadam.FIRST_MOMENT; -import static org.tensorflow.keras.optimizers.Nadam.LEARNING_RATE_DEFAULT; -import static org.tensorflow.keras.optimizers.Nadam.SECOND_MOMENT; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; import org.tensorflow.keras.utils.ND; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -49,6 +29,16 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.keras.optimizers.Adamax.LEARNING_RATE_KEY; +import static org.tensorflow.keras.optimizers.Nadam.*; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + /** Test cases for Nadam Optimizer */ public class NadamTest { private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; @@ -199,7 +189,7 @@ public void testBasic() { for (int step = 0; step < numSteps; step++) { - session.run(update); + session.run(update, instance.getFeedDict()); float mut = Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java index b8fb4f40ee9..2a43bdb3df2 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java @@ -14,30 +14,8 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.*; import org.tensorflow.framework.optimizers.Optimizer; -import static org.tensorflow.framework.optimizers.RMSProp.MG; -import static org.tensorflow.framework.optimizers.RMSProp.MOMENTUM; -import static org.tensorflow.framework.optimizers.RMSProp.RMS; -import static org.tensorflow.keras.optimizers.Ftrl.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; -import static org.tensorflow.keras.optimizers.RMSProp.CENTERED_DEFAULT; -import static org.tensorflow.keras.optimizers.RMSProp.CENTERED_KEY; -import static org.tensorflow.keras.optimizers.RMSProp.DECAY_DEFAULT; -import static org.tensorflow.keras.optimizers.RMSProp.DECAY_KEY; -import static org.tensorflow.keras.optimizers.RMSProp.EPSILON_DEFAULT; -import static org.tensorflow.keras.optimizers.RMSProp.EPSILON_KEY; -import static org.tensorflow.keras.optimizers.RMSProp.MOMENTUM_DEFAULT; -import static org.tensorflow.keras.optimizers.RMSProp.MOMENTUM_KEY; import org.tensorflow.keras.utils.ND; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -50,258 +28,281 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; -/** Test cases for RMSProp Optimizer */ -public class RMSPropTest { - private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; - - final int VAR_T = 0; - final int MG_T = 1; - final int RMS_T = 2; - final int MOM_T = 3; - - int index; - - public RMSPropTest() { - } - - @BeforeAll - public static void setUpClass() { - } - - @AfterAll - public static void tearDownClass() { - } - - @BeforeEach - public void setUp() { - } +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; - @AfterEach - public void tearDown() { - } +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.RMSProp.*; +import static org.tensorflow.keras.optimizers.Ftrl.LEARNING_RATE_KEY; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; +import static org.tensorflow.keras.optimizers.RMSProp.*; - /** - * Test of create method, of class RMSProp. - */ - @Test - public void testCreate() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - Map<String, Object> config = new HashMap<>(); - config.put(NAME_KEY, "Ftrl"); - config.put(LEARNING_RATE_KEY, 2.0F); - config.put(DECAY_KEY, DECAY_DEFAULT); - config.put(MOMENTUM_KEY, MOMENTUM_DEFAULT); - config.put(EPSILON_KEY, EPSILON_DEFAULT); - config.put(CENTERED_KEY, CENTERED_DEFAULT); - Ftrl expResult = new Ftrl(tf, 2.0F); - Ftrl result = Ftrl.create(tf, config); - assertEquals(expResult.getConfig(), result.getConfig()); - } +/** Test cases for RMSProp Optimizer */ +public class RMSPropTest { + private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; + + final int VAR_T = 0; + final int MG_T = 1; + final int RMS_T = 2; + final int MOM_T = 3; + + int index; + + public RMSPropTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + /** Test of create method, of class RMSProp. */ + @Test + public void testCreate() { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + Map<String, Object> config = new HashMap<>(); + config.put(NAME_KEY, "Ftrl"); + config.put(LEARNING_RATE_KEY, 2.0F); + config.put(DECAY_KEY, DECAY_DEFAULT); + config.put(MOMENTUM_KEY, MOMENTUM_DEFAULT); + config.put(EPSILON_KEY, EPSILON_DEFAULT); + config.put(CENTERED_KEY, CENTERED_DEFAULT); + Ftrl expResult = new Ftrl(tf, 2.0F); + Ftrl result = Ftrl.create(tf, config); + assertEquals(expResult.getConfig(), result.getConfig()); } + } + + Object[][] _test_param_values = { + // learning_rate, rho (decay), momentum, epsilon, centered + {0.05F, 0.9F, 0.0F, 1e-3F, true}, + {0.05F, 0.9F, 0.0F, 1e-3F, false}, + {0.1F, 0.9F, 0.0F, 1e-3F, true}, + {0.01F, 0.9F, 0.0F, 1e-5F, true}, + {0.01F, 0.9F, 0.9F, 1e-5F, true} + }; + + @Test + public void testDense() { + + int numSteps = 3; + + for (int run = 0; run < _test_param_values.length; run++) { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + session.setEpsilon(1e-2f); + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.2F}; + float[] grads1_init = {0.01F, 0.2F}; + final float epsilon1 = 1e-2F; + + FloatNdArray var0_np = NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant<TFloat32> grads0 = tf.constant(grads0_init); + Constant<TFloat32> grads1 = tf.constant(grads1_init); - Object[][] _test_param_values = { // learning_rate, rho (decay), momentum, epsilon, centered - {0.05F, 0.9F, 0.0F, 1e-3F, true}, - {0.05F, 0.9F, 0.0F, 1e-3F, false}, - {0.1F, 0.9F, 0.0F, 1e-3F, true}, - {0.01F, 0.9F, 0.0F, 1e-5F, true}, - {0.01F, 0.9F, 0.9F, 1e-5F, true} - }; - - @Test - public void testDense() { - - int numSteps = 3; - - for (int run = 0; run < _test_param_values.length; run++) { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - session.setEpsilon(1e-2f); - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.2F}; - float[] grads1_init = {0.01F, 0.2F}; - final float epsilon1 = 1e-2F; - - FloatNdArray var0_np = NdArrays.vectorOf(var0_init); - FloatNdArray var1_np = NdArrays.vectorOf(var1_init); - FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); - FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); - - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); - Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - - Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); - - Constant<TFloat32> grads0 = tf.constant(grads0_init); - Constant<TFloat32> grads1 = tf.constant(grads1_init); - - // learning_rate, rho (decay), momentum, epsilon, centered - float learningRate = (float) (float) _test_param_values[run][0]; - float decay = (float) _test_param_values[run][1]; - float momentum = (float) _test_param_values[run][2]; - float epsilon = (float) _test_param_values[run][3]; - boolean centered = (boolean) _test_param_values[run][4]; - - RMSProp instance = new RMSProp(tf, - learningRate, - decay, - momentum, - epsilon, - centered); - - /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); - gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); - gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - - Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); - - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - - /** - * initialize the accumulators - */ - session.run(tf.init()); - - /** - * make sure the variables were initialized properly - */ - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); - - Variable<TFloat32> mg0 = centered ? instance.getSlot(var0.asOutput(), MG).get() : null; - Variable<TFloat32> mg1 = centered ? instance.getSlot(var1.asOutput(), MG).get() : null; - Variable<TFloat32> mom0 = momentum > 0.F ? instance.getSlot(var0.asOutput(), MOMENTUM).get() : null; - Variable<TFloat32> mom1 = momentum > 0.F ? instance.getSlot(var1.asOutput(), MOMENTUM).get() : null; - Variable<TFloat32> rms0 = instance.getSlot(var0.asOutput(), RMS).get(); - Variable<TFloat32> rms1 = instance.getSlot(var1.asOutput(), RMS).get(); - - float[] zeros = {0.0F, 0.0F}; - float[] ones = {1.0F, 1.0F}; // temp to match RMSProp - FloatNdArray mg0_np = NdArrays.vectorOf(zeros); - FloatNdArray mg1_np = NdArrays.vectorOf(zeros); - FloatNdArray rms0_np = NdArrays.vectorOf(ones); - FloatNdArray rms1_np = NdArrays.vectorOf(ones); - FloatNdArray mom0_np = NdArrays.vectorOf(zeros); - FloatNdArray mom1_np = NdArrays.vectorOf(zeros); - - - - for (int i = 0; i < numSteps; i++) { - session.run(update); - FloatNdArray[] result0 = calc(var0_np, grads0_np, mg0_np, rms0_np, - mom0_np, learningRate, decay, momentum, epsilon, centered); - var0_np = result0[VAR_T]; - mg0_np = result0[MG_T]; - rms0_np = result0[RMS_T]; - mom0_np = result0[MOM_T]; - - FloatNdArray[] result1 = calc(var1_np, grads1_np, mg1_np, rms1_np, - mom1_np, learningRate, decay, momentum, epsilon, centered); - - var1_np = result1[VAR_T]; - mg1_np = result1[MG_T]; - rms1_np = result1[RMS_T]; - mom1_np = result1[MOM_T]; - - if (centered) { - session.evaluate(mg0_np, mg0); - session.evaluate(mg0_np, mg0); - } - if (momentum > 0.F) { - session.evaluate(mom0_np, mom0); - session.evaluate(mom1_np, mom1); - } - - /* TODO the values returned from rms slot, do not match what I see in the python test */ - session.evaluate(rms0_np, rms0); - session.evaluate(rms1_np, rms1); - - session.evaluate(var0_np, var0); - session.evaluate(var1_np, var1); - } - } - } - } - - FloatNdArray[] calc(FloatNdArray var_np, FloatNdArray grad_np, FloatNdArray mg_np, - FloatNdArray rms_np, FloatNdArray mom, float lr, float decay, float momentum, - float epsilon, boolean centered) { - - FloatNdArray[] result = new FloatNdArray[4]; // var_t, mg_t, rms_t, mom_t - result[RMS_T] = calcRMS(rms_np, grad_np, decay); // RMS - - FloatNdArray denom_t; - if (centered) { - result[MG_T] = calcMG(mg_np, grad_np, decay); - //rms_t - mg_t * mg_t - denom_t = ND.sub(result[RMS_T], ND.square(result[MG_T])); - } else { - result[MG_T] = mg_np; - denom_t = rms_np; - } - if (momentum > 0.F) { - //momentum * mom + lr * g / (np.sqrt(denom_t + epsilon)) - result[MOM_T] = calcMom(momentum, mom, lr, grad_np, denom_t, epsilon); - //var_t = var - mom_t - result[VAR_T] = ND.sub(var_np, result[MOM_T]); - } else { - result[MOM_T] = mom; - result[VAR_T] = calcVar(var_np, grad_np, lr, denom_t, epsilon); + float learningRate = (float) (float) _test_param_values[run][0]; + float decay = (float) _test_param_values[run][1]; + float momentum = (float) _test_param_values[run][2]; + float epsilon = (float) _test_param_values[run][3]; + boolean centered = (boolean) _test_param_values[run][4]; + + RMSProp instance = new RMSProp(tf, learningRate, decay, momentum, epsilon, centered); + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + + /** make sure the variables were initialized properly */ + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + Variable<TFloat32> mg0 = centered ? instance.getSlot(var0.asOutput(), MG).get() : null; + Variable<TFloat32> mg1 = centered ? instance.getSlot(var1.asOutput(), MG).get() : null; + Variable<TFloat32> mom0 = + momentum > 0.F ? instance.getSlot(var0.asOutput(), MOMENTUM).get() : null; + Variable<TFloat32> mom1 = + momentum > 0.F ? instance.getSlot(var1.asOutput(), MOMENTUM).get() : null; + Variable<TFloat32> rms0 = instance.getSlot(var0.asOutput(), RMS).get(); + Variable<TFloat32> rms1 = instance.getSlot(var1.asOutput(), RMS).get(); + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; // temp to match RMSProp + FloatNdArray mg0_np = NdArrays.vectorOf(zeros); + FloatNdArray mg1_np = NdArrays.vectorOf(zeros); + FloatNdArray rms0_np = NdArrays.vectorOf(ones); + FloatNdArray rms1_np = NdArrays.vectorOf(ones); + FloatNdArray mom0_np = NdArrays.vectorOf(zeros); + FloatNdArray mom1_np = NdArrays.vectorOf(zeros); + + for (int i = 0; i < numSteps; i++) { + session.run(update, instance.getFeedDict()); + FloatNdArray[] result0 = + calc( + var0_np, + grads0_np, + mg0_np, + rms0_np, + mom0_np, + learningRate, + decay, + momentum, + epsilon, + centered); + var0_np = result0[VAR_T]; + mg0_np = result0[MG_T]; + rms0_np = result0[RMS_T]; + mom0_np = result0[MOM_T]; + + FloatNdArray[] result1 = + calc( + var1_np, + grads1_np, + mg1_np, + rms1_np, + mom1_np, + learningRate, + decay, + momentum, + epsilon, + centered); + + var1_np = result1[VAR_T]; + mg1_np = result1[MG_T]; + rms1_np = result1[RMS_T]; + mom1_np = result1[MOM_T]; + + if (centered) { + session.evaluate(mg0_np, mg0); + session.evaluate(mg0_np, mg0); + } + if (momentum > 0.F) { + session.evaluate(mom0_np, mom0); + session.evaluate(mom1_np, mom1); + } + + /* TODO the values returned from rms slot, do not match what I see in the python test */ + session.evaluate(rms0_np, rms0); + session.evaluate(rms1_np, rms1); + + session.evaluate(var0_np, var0); + session.evaluate(var1_np, var1); } - - - return result; - + } } - - private FloatNdArray calcRMS(FloatNdArray rms_np, FloatNdArray grad_np, float decay) { - //rms * rho + (1 - rho) * g * g - FloatNdArray rms_rho = ND.mul(rms_np, decay); - FloatNdArray squareG = ND.square(grad_np); - float oneRHO = 1.0F - decay; - FloatNdArray decayG2 = ND.mul(oneRHO, squareG); - FloatNdArray result = ND.add(rms_rho, decayG2); - return result; - } - - private FloatNdArray calcMG(FloatNdArray mg_np, FloatNdArray grad_np, float decay) { - //mg_t = mg * rho + (1 - rho) * g - FloatNdArray mg_rho = ND.mul(mg_np, decay); - float oneRHO = 1.0F - decay; - FloatNdArray decayG = ND.mul(oneRHO, grad_np); - FloatNdArray result = ND.add(mg_rho, decayG); - return result; - + } + + FloatNdArray[] calc( + FloatNdArray var_np, + FloatNdArray grad_np, + FloatNdArray mg_np, + FloatNdArray rms_np, + FloatNdArray mom, + float lr, + float decay, + float momentum, + float epsilon, + boolean centered) { + + FloatNdArray[] result = new FloatNdArray[4]; // var_t, mg_t, rms_t, mom_t + result[RMS_T] = calcRMS(rms_np, grad_np, decay); // RMS + + FloatNdArray denom_t; + if (centered) { + result[MG_T] = calcMG(mg_np, grad_np, decay); + // rms_t - mg_t * mg_t + denom_t = ND.sub(result[RMS_T], ND.square(result[MG_T])); + } else { + result[MG_T] = mg_np; + denom_t = rms_np; } - - private FloatNdArray calcMom(float momentum, FloatNdArray mom, float lr, - FloatNdArray grad_np, FloatNdArray denom_t, float epsilon) { - // momentum * mom + lr * g / (np.sqrt(denom_t + epsilon)) - FloatNdArray moMo = ND.mul(momentum, mom); - FloatNdArray dividend = ND.mul(lr, grad_np); - FloatNdArray divisor = ND.sqrt(ND.add(denom_t, epsilon)); - FloatNdArray quotient = ND.div(dividend, divisor); - FloatNdArray result = ND.add(moMo, quotient); - return result; - + if (momentum > 0.F) { + // momentum * mom + lr * g / (np.sqrt(denom_t + epsilon)) + result[MOM_T] = calcMom(momentum, mom, lr, grad_np, denom_t, epsilon); + // var_t = var - mom_t + result[VAR_T] = ND.sub(var_np, result[MOM_T]); + } else { + result[MOM_T] = mom; + result[VAR_T] = calcVar(var_np, grad_np, lr, denom_t, epsilon); } - private FloatNdArray calcVar(FloatNdArray var_np, FloatNdArray grad_np, float lr, - FloatNdArray denom_t, float epsilon) { - // var - lr * g / (np.sqrt(denom_t) + epsilon) - FloatNdArray dividend = ND.mul(lr, grad_np); - FloatNdArray divisor = ND.add(ND.sqrt(denom_t), epsilon); - FloatNdArray quotient = ND.div(dividend, divisor); - FloatNdArray result = ND.sub(var_np, quotient); - return result; - - } + return result; + } + + private FloatNdArray calcRMS(FloatNdArray rms_np, FloatNdArray grad_np, float decay) { + // rms * rho + (1 - rho) * g * g + FloatNdArray rms_rho = ND.mul(rms_np, decay); + FloatNdArray squareG = ND.square(grad_np); + float oneRHO = 1.0F - decay; + FloatNdArray decayG2 = ND.mul(oneRHO, squareG); + FloatNdArray result = ND.add(rms_rho, decayG2); + return result; + } + + private FloatNdArray calcMG(FloatNdArray mg_np, FloatNdArray grad_np, float decay) { + // mg_t = mg * rho + (1 - rho) * g + FloatNdArray mg_rho = ND.mul(mg_np, decay); + float oneRHO = 1.0F - decay; + FloatNdArray decayG = ND.mul(oneRHO, grad_np); + FloatNdArray result = ND.add(mg_rho, decayG); + return result; + } + + private FloatNdArray calcMom( + float momentum, + FloatNdArray mom, + float lr, + FloatNdArray grad_np, + FloatNdArray denom_t, + float epsilon) { + // momentum * mom + lr * g / (np.sqrt(denom_t + epsilon)) + FloatNdArray moMo = ND.mul(momentum, mom); + FloatNdArray dividend = ND.mul(lr, grad_np); + FloatNdArray divisor = ND.sqrt(ND.add(denom_t, epsilon)); + FloatNdArray quotient = ND.div(dividend, divisor); + FloatNdArray result = ND.add(moMo, quotient); + return result; + } + + private FloatNdArray calcVar( + FloatNdArray var_np, FloatNdArray grad_np, float lr, FloatNdArray denom_t, float epsilon) { + // var - lr * g / (np.sqrt(denom_t) + epsilon) + FloatNdArray dividend = ND.mul(lr, grad_np); + FloatNdArray divisor = ND.add(ND.sqrt(denom_t), epsilon); + FloatNdArray quotient = ND.div(dividend, divisor); + FloatNdArray result = ND.sub(var_np, quotient); + return result; + } } diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java index 7e12b957f84..3d24b85239a 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java @@ -14,24 +14,8 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; -import static org.tensorflow.framework.optimizers.Momentum.MOMENTUM; +import org.junit.jupiter.api.*; import org.tensorflow.framework.optimizers.Optimizer; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; -import static org.tensorflow.keras.optimizers.SGD.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.SGD.MOMENTUM_DEFAULT; -import static org.tensorflow.keras.optimizers.SGD.MOMENTUM_KEY; -import static org.tensorflow.keras.optimizers.SGD.NESTEROV_DEFAULT; -import static org.tensorflow.keras.optimizers.SGD.NESTEROV_KEY; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -41,6 +25,16 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.Momentum.MOMENTUM; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; +import static org.tensorflow.keras.optimizers.SGD.*; + /** Test cases for SGD Optimizer */ public class SGDTest { @@ -134,7 +128,7 @@ public void testBasic() { session.evaluate(var0_init, var0); session.evaluate(var1_init, var1); - session.run(update); // 1 step + session.run(update, instance.getFeedDict()); // 1 step float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F}; float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F}; @@ -194,7 +188,7 @@ public void testMomentum() { session.evaluate(var0_init, var0); session.evaluate(var1_init, var1); - session.run(update); // 1 step + session.run(update, instance.getFeedDict()); // 1 step float[] expectedMomentum0 = {0.1F, 0.1F}; float[] expectedMomentum1 = {0.01F, 0.01F}; @@ -206,7 +200,7 @@ public void testMomentum() { session.evaluate(expectedVar0, var0); session.evaluate(expectedVar1, var1); - session.run(update); // step 2 + session.run(update, instance.getFeedDict()); // step 2 float[] expectedMomentum0_2 = {(0.9f * 0.1f + 0.1f), (0.9f * 0.1f + 0.1f)}; float[] expectedMomentum1_2 = {(0.9f * 0.01f + 0.01f), (0.9f * 0.01f + 0.01f)}; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java index 6b7ebf9e2f2..6d286c311ff 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java @@ -14,37 +14,30 @@ =======================================================================*/ package org.tensorflow.keras.utils; -import java.io.PrintWriter; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Predicate; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; -import org.tensorflow.DataType; -import org.tensorflow.EagerSession; -import org.tensorflow.Operand; -import org.tensorflow.Output; -import org.tensorflow.Session; +import org.tensorflow.*; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.TString; +import org.tensorflow.types.*; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; -/** Eaager Mode Test Session */ +import java.io.PrintWriter; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Predicate; + +import static org.junit.jupiter.api.Assertions.*; + +/** @author Jim Clarke */ public class EagerTestSession extends TestSession { private final EagerSession session; private final Ops tf; - /** Create an Eager mode test session. */ + /** Create a EagerTestSession */ public EagerTestSession() { this.session = EagerSession.create(); this.tf = Ops.create(session).withName("test"); @@ -57,8 +50,9 @@ public Ops getTF() { } /** - * Get the TensorFlow EagerSession instance - * @return the TensorFlow EagerSession instance + * Returns the EagerSession for this Test session + * + * @return the EagerSession for this Test session */ public EagerSession getSession() { return session; @@ -90,7 +84,22 @@ public EagerSession getEagerSession() { /** {@inheritDoc} */ @Override - public <T extends TNumber> void evaluate(double expected, Operand<T> input) { + public void run(Op op) { + /* Empty */ + } + + /** {@inheritDoc} */ + @Override + public void run(Op op, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + /* Empty */ + } + + /** {@inheritDoc} */ + @Override + public <U extends TNumber> void evaluate( + double expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { DataType dtype = input.asOutput().dataType(); if (dtype == TFloat32.DTYPE) { Operand<TFloat32> o = (Operand<TFloat32>) input; @@ -169,7 +178,10 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { /** {@inheritDoc} */ @Override - public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { + public <U extends TNumber> void evaluate( + Number[] expected, + Output<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); assertEquals( expected.length, @@ -254,7 +266,10 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { /** {@inheritDoc} */ @Override - public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { + public <U extends TNumber> void evaluate( + FloatNdArray expected, + Output<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { DataType dtype = input.dataType(); if (dtype == TFloat32.DTYPE) { Output<TFloat32> o = (Output<TFloat32>) input; @@ -334,7 +349,10 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { /** {@inheritDoc} */ @Override - public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predicate) { + public <U extends TNumber> void evaluate( + Output<U> input, + Predicate<Number> predicate, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { AtomicInteger index = new AtomicInteger(); DataType dtype = input.asOutput().dataType(); boolean isScalar = input.shape().equals(Shape.scalar()); @@ -457,7 +475,10 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic /** {@inheritDoc} */ @Override - public void evaluate(String[] expected, Output<TString> input) { + public void evaluate( + String[] expected, + Output<TString> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); assertEquals( expected.length, @@ -485,7 +506,10 @@ public void evaluate(String[] expected, Output<TString> input) { /** {@inheritDoc} */ @Override - public void evaluate(Boolean[] expected, Output<TBool> input) { + public void evaluate( + Boolean[] expected, + Output<TBool> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); assertEquals( expected.length, @@ -513,10 +537,13 @@ public void evaluate(Boolean[] expected, Output<TBool> input) { /** {@inheritDoc} */ @Override - public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { + public <T extends TType> void evaluate( + Output<T> expected, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { assert input.shape().equals(expected.shape()) : String.format( - "expected shape (%s) != to input shape (%ds)", + "expected shape (%s) != to input shape (%s)", expected.shape().toString(), input.shape().toString()); DataType dtype = input.asOutput().dataType(); boolean isScalar = input.shape().equals(Shape.scalar()); @@ -683,7 +710,10 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { /** {@inheritDoc} */ @Override - public <T extends TType> void print(PrintWriter writer, Output<T> input) { + public <T extends TType> void print( + PrintWriter writer, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { DataType dtype = input.asOutput().dataType(); if (dtype == TFloat32.DTYPE) { Output<TFloat32> o = (Output<TFloat32>) input; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java index 1a22289f4bf..ff18b338ce2 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java @@ -14,41 +14,32 @@ =======================================================================*/ package org.tensorflow.keras.utils; -import java.io.PrintWriter; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Predicate; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; -import org.tensorflow.DataType; -import org.tensorflow.EagerSession; -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Output; -import org.tensorflow.Session; -import org.tensorflow.Tensor; +import org.tensorflow.*; +import org.tensorflow.Session.Runner; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.TString; +import org.tensorflow.types.*; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; -/** Graph Mode Test Session */ +import java.io.PrintWriter; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Predicate; + +import static org.junit.jupiter.api.Assertions.*; + +/** @author Jim Clarke */ public class GraphTestSession extends TestSession { private final Graph graph; private final Session session; private final Ops tf; - /** Create a Graph mode test session. */ + /** Create a Graph Test Session */ public GraphTestSession() { graph = new Graph(); session = new Session(graph); @@ -61,15 +52,19 @@ public Ops getTF() { return tf; } - /** Get the Graph object that is represented by this Test Session */ + /** + * Get the Graph instance for this test Session + * + * @return + */ public Graph getGraph() { return graph; } - /** - * Get the TensorFlow Session instance - * @return the TensorFlow Session instance + * Get the Graph session instance for this test Session + * + * @return */ public Session getSession() { return session; @@ -119,13 +114,50 @@ public void run(Op op) { /** {@inheritDoc} */ @Override - public <T extends TNumber> void evaluate(double expected, Operand<T> input) { + public void run(Op op, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + createRunner(op, feedDict).run(); + } + + /** + * Create a runner for the Operation + * + * @param op the operation + * @return the runner + */ + public Runner createRunner(Op op) { + return createRunner(op, null); + } + + /** + * Create a runner for the Operation + * + * @param op the operation + * @param feedDict the dictionary of values to use for the runner's feed operations. Required when + * placeholders are used. + * @return the runner + */ + public Runner createRunner( + Op op, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + Runner runner = session.runner(); + runner.addTarget(op.op()); + if (feedDict != null && !feedDict.isEmpty()) { + feedDict.forEach((name, tensor) -> runner.feed(name, tensor)); + } + return runner; + } + + /** {@inheritDoc} */ + @Override + public <U extends TNumber> void evaluate( + double expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { DataType dtype = input.asOutput().dataType(); if (dtype == TFloat32.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -137,7 +169,7 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { } index.set(0); try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -150,7 +182,7 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -162,7 +194,7 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { } index.set(0); try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -175,7 +207,7 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -187,7 +219,7 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { } index.set(0); try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -201,7 +233,7 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -213,7 +245,7 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { } index.set(0); try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -229,18 +261,24 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { /** {@inheritDoc} */ @Override - public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); + public <U extends TNumber> void evaluate( + Number[] expected, + Output<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + long size = input.shape().size() == 0 ? 1 : input.shape().size(); + if (size != Shape.UNKNOWN_SIZE) { + assertEquals( + expected.length, + size, + () -> + String.format("expected length (%d) != to input length (%d)", expected.length, size)); + } DataType dtype = input.asOutput().dataType(); if (dtype == TFloat32.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -252,7 +290,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { } index.set(0); try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -266,7 +304,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -278,7 +316,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { } index.set(0); try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -292,7 +330,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -304,7 +342,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { } index.set(0); try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -318,7 +356,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -330,7 +368,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { } index.set(0); try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -346,13 +384,16 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { /** {@inheritDoc} */ @Override - public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { + public <U extends TNumber> void evaluate( + FloatNdArray expected, + Output<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { DataType dtype = input.asOutput().dataType(); if (dtype == TFloat32.DTYPE) { AtomicLong index = new AtomicLong(); if (debug) { try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -364,7 +405,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { } index.set(0); try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -377,7 +418,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -389,7 +430,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { } index.set(0); try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -403,7 +444,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -415,7 +456,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { } index.set(0); try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -429,7 +470,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -441,7 +482,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { } index.set(0); try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -457,7 +498,10 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { /** {@inheritDoc} */ @Override - public void evaluate(String[] expected, Output<TString> input) { + public void evaluate( + String[] expected, + Output<TString> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); assertEquals( expected.length, @@ -466,7 +510,7 @@ public void evaluate(String[] expected, Output<TString> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TString> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { result .data() .scalars() @@ -478,7 +522,7 @@ public void evaluate(String[] expected, Output<TString> input) { } index.set(0); try (Tensor<TString> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { result .data() .scalars() @@ -491,7 +535,10 @@ public void evaluate(String[] expected, Output<TString> input) { /** {@inheritDoc} */ @Override - public void evaluate(Boolean[] expected, Output<TBool> input) { + public void evaluate( + Boolean[] expected, + Output<TBool> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); assertEquals( expected.length, @@ -500,7 +547,7 @@ public void evaluate(Boolean[] expected, Output<TBool> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TBool> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { result .data() .scalars() @@ -512,7 +559,7 @@ public void evaluate(Boolean[] expected, Output<TBool> input) { } index.set(0); try (Tensor<TBool> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { result .data() .scalars() @@ -525,10 +572,13 @@ public void evaluate(Boolean[] expected, Output<TBool> input) { /** {@inheritDoc} */ @Override - public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { + public <T extends TType> void evaluate( + Output<T> expected, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { assert input.shape().equals(expected.shape()) : String.format( - "expected shape (%s) != to input shape (%ds)", + "expected shape (%s) != to input shape (%s)", expected.shape().toString(), input.shape().toString()); AtomicInteger index = new AtomicInteger(); DataType dtype = input.asOutput().dataType(); @@ -537,9 +587,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { final Output<TFloat32> finalExpected = (Output<TFloat32>) expected; if (debug) { try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE); Tensor<TFloat32> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { System.out.printf( "0). %f <==> %f\n", expectedResult.data().getFloat(), result.data().getFloat()); @@ -560,9 +610,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } index.set(0); try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE); Tensor<TFloat32> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getFloat(), result.data().getFloat(), epsilon); } else { @@ -579,9 +629,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { final Output<TFloat64> finalExpected = (Output<TFloat64>) expected; if (debug) { try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE); Tensor<TFloat64> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { System.out.printf( "0). %f <==> %f\n", expectedResult.data().getDouble(), result.data().getDouble()); @@ -602,9 +652,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } index.set(0); try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE); Tensor<TFloat64> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getDouble(), result.data().getDouble(), epsilon); } else { @@ -621,9 +671,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { final Output<TInt32> finalExpected = (Output<TInt32>) expected; if (debug) { try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE); Tensor<TInt32> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { System.out.printf( "0). %d <==> %d\n", expectedResult.data().getInt(), result.data().getInt()); @@ -642,9 +692,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } index.set(0); try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE); Tensor<TInt32> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getInt(), result.data().getInt(), epsilon); } else { @@ -661,9 +711,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { final Output<TInt64> finalExpected = (Output<TInt64>) expected; if (debug) { try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE); Tensor<TInt64> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { System.out.printf( "0). %d <==> %d\n", expectedResult.data().getLong(), result.data().getLong()); @@ -682,9 +732,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } index.set(0); try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE); Tensor<TInt64> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getLong(), result.data().getLong(), epsilon); } else { @@ -701,9 +751,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { final Output<TBool> finalExpected = (Output<TBool>) expected; if (debug) { try (Tensor<TBool> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE); Tensor<TBool> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { if (isScalar) { System.out.printf( "0). %b <==> %b\n", expectedResult.data().getBoolean(), result.data().getBoolean()); @@ -724,9 +774,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } index.set(0); try (Tensor<TBool> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE); Tensor<TBool> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getBoolean(), result.data().getBoolean()); } else { @@ -743,9 +793,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { final Output<TString> finalExpected = (Output<TString>) expected; if (debug) { try (Tensor<TString> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE); Tensor<TString> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { if (isScalar) { System.out.printf( "0). %s <==> %s\n", expectedResult.data().getObject(), result.data().getObject()); @@ -766,9 +816,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } index.set(0); try (Tensor<TString> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE); Tensor<TString> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getObject(), result.data().getObject()); } else { @@ -787,15 +837,17 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } /** {@inheritDoc} */ - @Override - public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predicate) { + public <U extends TNumber> void evaluate( + Output<U> input, + Predicate<Number> predicate, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { AtomicInteger index = new AtomicInteger(); DataType dtype = input.asOutput().dataType(); boolean isScalar = input.shape().equals(Shape.scalar()); if (dtype == TFloat32.DTYPE) { if (debug) { try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { System.out.printf( "0). %b <==> %f\n", @@ -815,7 +867,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic } index.set(0); try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { assertTrue(predicate.test(result.data().getFloat())); } else { @@ -831,7 +883,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic } else if (dtype == TFloat64.DTYPE) { if (debug) { try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { System.out.printf( "0). %b <==> %f\n", @@ -851,9 +903,9 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic } index.set(0); try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE); Tensor<TFloat64> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { assertTrue(predicate.test(result.data().getDouble())); } else { @@ -869,7 +921,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic } else if (dtype == TInt32.DTYPE) { if (debug) { try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { System.out.printf( "0). %b <==> %d\n", predicate.test(result.data().getInt()), result.data().getInt()); @@ -888,9 +940,9 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic } index.set(0); try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE); Tensor<TInt32> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { assertTrue(predicate.test(result.data().getInt())); } else { @@ -906,7 +958,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic } else if (dtype == TInt64.DTYPE) { if (debug) { try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { System.out.printf( "0). %b <==> %d\n", @@ -926,9 +978,9 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic } index.set(0); try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE); Tensor<TInt64> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { assertTrue(predicate.test(result.data().getLong())); } else { @@ -948,14 +1000,17 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic /** {@inheritDoc} */ @Override - public <T extends TType> void print(PrintWriter writer, Output<T> input) { + public <T extends TType> void print( + PrintWriter writer, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { boolean isScalar = input.asOutput().shape().size() == 1; DataType dtype = input.dataType(); if (dtype == TFloat32.DTYPE) { AtomicInteger index = new AtomicInteger(); try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { writer.printf("%d). %f\n", index.getAndIncrement(), result.data().getFloat()); } else { @@ -972,7 +1027,7 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) { AtomicInteger index = new AtomicInteger(); try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { writer.printf( "%d). %f\n", index.getAndIncrement(), ((Output<TFloat64>) input).data().getDouble()); @@ -990,7 +1045,7 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) { AtomicInteger index = new AtomicInteger(); try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { writer.printf( "%d). %f\n", index.getAndIncrement(), ((Output<TInt32>) input).data().getInt()); @@ -1008,7 +1063,7 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) { AtomicInteger index = new AtomicInteger(); try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { writer.printf( "%d). %f\n", index.getAndIncrement(), ((Output<TInt64>) input).data().getLong()); @@ -1026,7 +1081,7 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) { AtomicInteger index = new AtomicInteger(); try (Tensor<TBool> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { if (isScalar) { writer.printf( "%d). %b\n", index.getAndIncrement(), ((Output<TBool>) input).data().getBoolean()); @@ -1044,7 +1099,7 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) { AtomicInteger index = new AtomicInteger(); try (Tensor<TString> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { if (isScalar) { writer.printf( "%d). %s\n", index.getAndIncrement(), ((Output<TString>) input).data().getObject()); diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java index 1e5393aa2af..34348ccc1f4 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java @@ -14,16 +14,7 @@ =======================================================================*/ package org.tensorflow.keras.utils; -import java.io.OutputStream; -import java.io.OutputStreamWriter; -import java.io.PrintWriter; -import java.io.Writer; -import java.util.function.Predicate; -import static org.junit.jupiter.api.Assertions.assertTrue; -import org.tensorflow.EagerSession; -import org.tensorflow.Operand; -import org.tensorflow.Output; -import org.tensorflow.Session; +import org.tensorflow.*; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -32,569 +23,1065 @@ import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; -/** Base class for Test Session */ +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.io.Writer; +import java.util.Map; +import java.util.function.Predicate; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** @author Jim Clarke */ public abstract class TestSession implements AutoCloseable { protected float epsilon = 1e-5F; protected boolean debug; - /** The Test Session mode, either Eager or Graph */ + /** Enumerate between Eager and Graph Mode */ public enum Mode { EAGER, GRAPH; } - /** - * Create an Eager Test Session - * - * @return the Eager Test Session - */ public static TestSession createEagerSession() { return new EagerTestSession(); } - /** - * Create a Graph Test Session - * - * @return the Graph Test Session - */ public static TestSession createGraphSession() { return new GraphTestSession(); } - /** - * Create a Test Session - * - * @param mode - * @return - */ public static TestSession createTestSession(Mode mode) { return mode == Mode.EAGER ? createEagerSession() : createGraphSession(); } - /** Initialize the Test Session, default implementation is do nothing. */ public void initialize() { // empty } /** - * Run the Operation + * Perform session.run() + * + * <p>If in eager mode, this does nothing. * - * @param op the Operation to run + * @param op The Operation to run */ - public void run(Op op) { - // empty + public abstract void run(Op op); + + /** + * Perform session.run() + * + * <p>If in eager mode, this does nothing. + * + * @param op The Operation to run + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public abstract void run(Op op, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate(Number expected, Operand<U> input) { + evaluate(new Number[] {expected}, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(Number expected, Operand<T> input) { - evaluate(new Number[] {expected}, input); + public <U extends TNumber> void evaluate( + Number expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + evaluate(new Number[] {expected}, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value */ - public <T extends TNumber> void evaluate(Number expected, Op input) { - evaluate(new Number[] {expected}, input); + public void evaluate(Number expected, Op input) { + evaluate(new Number[] {expected}, input, null); } /** - * Evaluate the input against the expected values + * Evaluate the expected results versus the actual results * - * @param expected the expected values - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <T> the data type for the feedDict entries */ - public <T extends TNumber> void evaluate(Number[] expected, Op input) { - Output output = input.op().output(0); - evaluate(expected, output); + public <T extends TType> void evaluate( + Number expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + evaluate(new Number[] {expected}, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param <U> the data type for the input + */ + public <U extends TNumber> void evaluate(Number[] expected, Op input) { + Output<U> output = input.op().output(0); + evaluate(expected, output, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type for the input + */ + public <U extends TNumber> void evaluate( + Number[] expected, + Op input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + Output<U> output = input.op().output(0); + evaluate(expected, output, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(Number[] expected, Operand<T> input) { + public <U extends TNumber> void evaluate(Number[] expected, Operand<U> input) { + Output<U> output = input.asOutput(); + evaluate(expected, output, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate( + Number[] expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { Output output = input.asOutput(); - evaluate(expected, output); + evaluate(expected, output, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(byte expected, Operand<T> input) { - evaluate((double) expected, input); + public <U extends TNumber> void evaluate(byte expected, Operand<U> input) { + evaluate((double) expected, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(int expected, Operand<T> input) { - evaluate((double) expected, input); + public <U extends TNumber> void evaluate( + byte expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + evaluate((double) expected, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(long expected, Operand<T> input) { - evaluate((double) expected, input); + public <U extends TNumber> void evaluate(int expected, Operand<U> input) { + evaluate((double) expected, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(float expected, Operand<T> input) { - evaluate((double) expected, input); + public <U extends TNumber> void evaluate( + int expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + evaluate((double) expected, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param <U> the data type of the input */ - public abstract <T extends TNumber> void evaluate(double expected, Operand<T> input); + public <U extends TNumber> void evaluate(long expected, Operand<U> input) { + evaluate((double) expected, input, null); + } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(byte[] expected, Operand<T> input) { + public <U extends TNumber> void evaluate( + long expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + evaluate((double) expected, input, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate(float expected, Operand<U> input) { + evaluate((double) expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate( + float expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + evaluate((double) expected, input, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate(double expected, Operand<U> input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input + */ + public abstract <U extends TNumber> void evaluate( + double expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate(byte[] expected, Operand<U> input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate( + byte[] expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { Byte[] iArray = new Byte[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(int[] expected, Operand<T> input) { + public <U extends TNumber> void evaluate(int[] expected, Operand<U> input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate( + int[] expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { Integer[] iArray = new Integer[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(long[] expected, Operand<T> input) { + public <U extends TNumber> void evaluate(long[] expected, Operand<U> input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate( + long[] expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { Long[] iArray = new Long[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedDict); } /** - * Evaluate the input against the expected values + * Evaluate the expected results versus the actual results * - * @param expected the expected values - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param expected the expected value + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(float[] expected, Operand<T> input) { + public <U extends TNumber> void evaluate(float[] expected, Operand<U> input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate( + float[] expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { Float[] iArray = new Float[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate(double[] expected, Operand<U> input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(double[] expected, Operand<T> input) { + public <U extends TNumber> void evaluate( + double[] expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { Double[] iArray = new Double[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param <U> the data type of the input */ - public abstract <T extends TNumber> void evaluate(Number[] expected, Output<T> input); + public <U extends TNumber> void evaluate(Number[] expected, Output<U> input) { + evaluate(expected, input, null); + } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input + */ + public abstract <U extends TNumber> void evaluate( + Number[] expected, + Output<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value */ public void evaluate(String expected, Operand<TString> input) { - evaluate(new String[] {expected}, input); + evaluate(expected, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void evaluate( + String expected, + Operand<TString> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + evaluate(new String[] {expected}, input, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value */ public void evaluate(String expected, Op input) { - evaluate(new String[] {expected}, input); + evaluate(new String[] {expected}, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void evaluate( + String expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + evaluate(new String[] {expected}, input, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value */ public void evaluate(String[] expected, Op input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void evaluate( + String[] expected, + Op input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { Output output = input.op().output(0); - evaluate(expected, output); + evaluate(expected, output, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value */ public void evaluate(String[] expected, Operand<TString> input) { Output output = input.asOutput(); - evaluate(expected, output); + evaluate(expected, output, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. */ - public abstract void evaluate(String[] expected, Output<TString> input); + public abstract void evaluate( + String[] expected, + Output<TString> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value */ public void evaluate(Boolean expected, Operand<TBool> input) { - evaluate(new Boolean[] {expected}, input); + evaluate(new Boolean[] {expected}, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void evaluate( + Boolean expected, + Operand<TBool> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + evaluate(new Boolean[] {expected}, input, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value */ public void evaluate(Boolean expected, Op input) { - evaluate(new Boolean[] {expected}, input); + evaluate(new Boolean[] {expected}, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void evaluate( + Boolean expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + evaluate(new Boolean[] {expected}, input, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value */ public void evaluate(Boolean[] expected, Op input) { Output output = input.op().output(0); - evaluate(expected, output); + evaluate(expected, output, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void evaluate( + Boolean[] expected, + Op input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + Output output = input.op().output(0); + evaluate(expected, output, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value */ public void evaluate(Boolean[] expected, Operand<TBool> input) { Output output = input.asOutput(); - evaluate(expected, output); + evaluate(expected, output, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. */ - public abstract void evaluate(Boolean[] expected, Output<TBool> input); + public void evaluate( + Boolean[] expected, + Operand<TBool> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + Output output = input.asOutput(); + evaluate(expected, output, feedDict); + } - public <T extends TType> void evaluate(Operand<T> expected, Op input) { - Output output = input.op().output(0); - evaluate(expected, output); + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + */ + public void evaluate(Boolean[] expected, Output<TBool> input) { + evaluate(expected, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public abstract void evaluate( + Boolean[] expected, + Output<TBool> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + */ + public <T extends TType> void evaluate(Operand<T> expected, Output<T> input) { + evaluate(expected.asOutput(), input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param <T> the data type for the feedDict entries */ public <T extends TType> void evaluate(Operand<T> expected, Operand<T> input) { - evaluate(expected.asOutput(), input.asOutput()); + evaluate(expected.asOutput(), input.asOutput(), null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <T> the data type for the feedDict entries */ - public abstract <T extends TType> void evaluate(Output<T> expected, Output<T> input); + public abstract <T extends TType> void evaluate( + Output<T> expected, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TType> void evaluate(FloatNdArray expected, Operand<T> input) { - evaluate(expected, input.asOutput()); + public <U extends TNumber> void evaluate(FloatNdArray expected, Operand<U> input) { + evaluate(expected, input.asOutput(), null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input */ - public abstract <T extends TType> void evaluate(FloatNdArray expected, Output<T> input); - - public <T extends TType> void evaluate(Operand<T> input, Predicate<Number> predicate) { - evaluate(input.asOutput(), predicate); + public <U extends TNumber> void evaluate( + FloatNdArray expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + evaluate(expected, input.asOutput(), feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param <U> the data type of the input */ - public abstract <T extends TType> void evaluate(Output<T> input, Predicate<Number> predicate); + public <U extends TNumber> void evaluate(FloatNdArray expected, Output<U> input) { + evaluate(expected, input, null); + } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input + */ + public abstract <U extends TNumber> void evaluate( + FloatNdArray expected, + Output<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + + /** + * Evaluate the actual results using a predicate + * + * @param input the actual value + * @param predicate a predicate that accepts a Number as an argument, if the result of the + * predicate is false, then the test will fail + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate(Operand<U> input, Predicate<Number> predicate) { + evaluate(input.asOutput(), predicate, null); + } + + /** + * Evaluate the actual results using a predicate + * + * @param input the actual value + * @param predicate a predicate that accepts a Number as an argument, if the result of the + * predicate is false, then the test will fail + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <U> the data type of the input + */ + public abstract <U extends TNumber> void evaluate( + Output<U> input, + Predicate<Number> predicate, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + + /** + * Evaluate the actual results using a predicate + * + * @param input the actual value + * @param predicate a predicate that accepts a Number as an argument, if the result of the + * predicate is false, then the test will fail */ - public <T extends TType> void evaluate(FloatNdArray input, Predicate<Number> predicate) { + public void evaluate(FloatNdArray input, Predicate<Number> predicate) { input.scalars().forEach(f -> assertTrue(predicate.test(f.getFloat()))); } /** - * Print the input + * Print the results to output stream * * @param out the output stream - * @param input the operand to print - * @param <T> the data type of the input + * @param input the actual value + * @param <T> the data type for the input */ public <T extends TType> void print(OutputStream out, Operand<T> input) { - print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput()); + print(out, input, null); } /** - * Print the input + * Print the results to output stream * * @param out the output stream - * @param input the op to print - * @param <T> the data type of the input + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <T> the data type for the feedDict entries */ - public <T extends TType> void print(OutputStream out, Op input) { - print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0)); + public <T extends TType> void print( + OutputStream out, + Operand<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput(), feedDict); } /** - * Print the input + * Print the results to output stream * * @param out the output stream - * @param input the op to print - * @param <T> the data type of the input + * @param input the actual value + */ + public void print(OutputStream out, Op input) { + print(out, input, null); + } + + /** + * Print the results to output stream + * + * @param out the output stream + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void print( + OutputStream out, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0), feedDict); + } + + /** + * Print the results to output stream + * + * @param out the output stream + * @param input the actual value + * @param <T> the data type for the input */ public <T extends TType> void print(OutputStream out, Output<T> input) { - print(new PrintWriter(new OutputStreamWriter(out)), input); + print(out, input, null); } /** - * Print the input + * Print the results to output stream * - * @param witer the output writer - * @param input the operand to print - * @param <T> the data type of the input + * @param out the output stream + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <T> the data type for the input + */ + public <T extends TType> void print( + OutputStream out, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + print(new PrintWriter(new OutputStreamWriter(out)), input, feedDict); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value + * @param <T> the data type for the input */ public <T extends TType> void print(Writer writer, Operand<T> input) { - print(new PrintWriter(writer), input.asOutput()); + print(writer, input, null); } /** - * Print the input + * Print the results to the character stream * - * @param witer the output writer - * @param input the op to print - * @param <T> the data type of the input + * @param writer the character stream + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <T> the data type for the input */ - public <T extends TType> void print(Writer writer, Op input) { - print(new PrintWriter(writer), input.op().output(0)); + public <T extends TType> void print( + Writer writer, + Operand<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + print(new PrintWriter(writer), input.asOutput(), feedDict); } /** - * Print the input + * Print the results to the character stream * - * @param witer the output writer - * @param input the op to print - * @param <T> the data type of the input + * @param writer the character stream + * @param input the actual value + */ + public void print(Writer writer, Op input) { + print(writer, input, null); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void print( + Writer writer, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + print(new PrintWriter(writer), input.op().output(0), feedDict); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value + * @param <T> the data type for the input */ public <T extends TType> void print(Writer writer, Output<T> input) { - print(new PrintWriter(writer), input); + print(writer, input, null); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <T> the data type for the input + */ + public <T extends TType> void print( + Writer writer, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + print(new PrintWriter(writer), input, feedDict); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value + * @param <T> the data type for the input + */ + public <T extends TType> void print(PrintWriter writer, Output<T> input) { + print(writer, input, null); } /** - * Print the input + * Print the results to the character stream * - * @param witer the output writer - * @param input the op to print + * @param writer the character stream + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <T> the data type for the input */ - public abstract <T extends TType> void print(PrintWriter writer, Output<T> input); + public abstract <T extends TType> void print( + PrintWriter writer, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); /** - * Get the TensorFlow Ops + * Get the TensorFlow Ops for this test session * - * @return the TensorFlow Ops + * @return the TensorFlow Ops for this test session */ public abstract Ops getTF(); /** - * Determine if this Test Session represents an Eager Session + * Determine whether this session is in Eager mode * - * @return true, if this Test Session represents an Eager Session + * @return true if the this session is in Eager mode */ public abstract boolean isEager(); /** - * Determine if this Test Session represents a Graph Session + * Determine whether this session is in Graph mode * - * @return true, if this Test Session represents a Graph Session + * @return true if the this session is in Graph mode */ public boolean isGraph() { return !isEager(); } /** - * Get the epsilon value for evaluating float values + * Get the current EPSILON value for floating point number comparison. * - * @return the epsilon value for evaluating float values + * @return the current EPSILON value for floating point number comparison. */ public float getEpsilon() { return this.epsilon; } /** - * Set the epsilon value for evaluating float values + * Set the current EPSILON value for floating point number comparison. * - * @param epsilon the epsilon value for evaluating float values + * @param epsilon the new EPSILON value for floating point number comparison. */ public void setEpsilon(float epsilon) { this.epsilon = epsilon; } /** - * Get the TensorFlow session object associated with this Test Session + * Get the TensorFlow Session object * - * @return a TensorFlow session if this is a Graph session, otherwise null + * @return the TensorFlow Session object, returns null if this is not a Graph Test Session */ public abstract Session getGraphSession(); /** - * Get the TensorFlow eager session object associated with this Test Session + * Get the TensorFlow EagerSession object * - * @return a TensorFlow session if this is an eager session, otherwise null + * @return the TensorFlow Session object, returns null if this is not a Graph Test Session */ public abstract EagerSession getEagerSession(); @@ -602,15 +1089,21 @@ public void setEpsilon(float epsilon) { @Override public abstract void close(); - /** @return the debug setting */ + /** + * Get the debug setting + * + * @return the debug setting + */ public boolean isDebug() { return debug; } /** - * Set the debug flag + * Sets the debug setting. + * + * <p>If true, then evaluate methods will also print the Tensor values to System.out. * - * @param debug the setting for debugging + * @param debug the debug to set */ public void setDebug(boolean debug) { this.debug = debug; From fa936ccb8e0675452c95ddf60bb50721f0b400a9 Mon Sep 17 00:00:00 2001 From: Jim Clarke <JimClarke5@me.com> Date: Sun, 30 Aug 2020 18:43:15 -0400 Subject: [PATCH 02/14] Add ability to change learning rate between steps by adding a Placeholder into each Optimizer. Also, added to each Optimizer a corresponding Tensor that holds the value of the learning rate, and added a feed dictionary that maps the placeholder to the Tensor, so that it can be fed into the runner when running or evaluating. When setLearning rate is called the learning rate tensor and the feed dictionary are updated. --- .../framework/optimizers/AdaDelta.java | 54 +++++++++++++++++- .../framework/optimizers/AdaGrad.java | 53 +++++++++++++++++- .../framework/optimizers/AdaGradDA.java | 51 ++++++++++++++++- .../tensorflow/framework/optimizers/Adam.java | 52 +++++++++++++++-- .../framework/optimizers/GradientDescent.java | 54 +++++++++++++++++- .../framework/optimizers/Momentum.java | 53 +++++++++++++++++- .../framework/optimizers/Optimizer.java | 31 +++++++--- .../framework/optimizers/RMSProp.java | 56 ++++++++++++++++++- 8 files changed, 380 insertions(+), 24 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index b5dc2434d60..aa77a6146f9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -15,12 +15,19 @@ */ package org.tensorflow.framework.optimizers; +import java.util.Collections; import java.util.List; +import java.util.Map; + import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; /** @@ -33,7 +40,10 @@ public class AdaDelta extends Optimizer { public static final String ACCUMULATOR = "accum"; public static final String ACCUMULATOR_UPDATE = "accum_update"; - private final float learningRate; + private float learningRate; + private Tensor<TFloat32> learningRateTensor; + private final Placeholder<TFloat32> learningRatePlaceholder; + private Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict; private final float rho; @@ -46,6 +56,10 @@ public AdaDelta(Graph graph, float learningRate) { public AdaDelta(Graph graph, float learningRate, float rho, float epsilon) { super(graph); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.rho = rho; this.epsilon = epsilon; } @@ -57,6 +71,10 @@ public AdaDelta(Graph graph, String name, float learningRate) { public AdaDelta(Graph graph, String name, float learningRate, float rho, float epsilon) { super(graph, name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.rho = rho; this.epsilon = epsilon; } @@ -82,7 +100,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable Variable<T> accumSlot = getSlot(variable, ACCUMULATOR).get(); Variable<T> accumUpdateSlot = getSlot(variable, ACCUMULATOR_UPDATE).get(); return tf.train.applyAdadelta(variable, accumSlot, accumUpdateSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), tf.dtypes.cast(tf.constant(rho), gradient.dataType()), tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), gradient); @@ -101,4 +119,36 @@ public String toString() { public String getOptimizerName() { return "Adadelta"; } + + + /** {@inheritDoc} */ + @Override + public float getLearningRate() { + return this.learningRate; + } + + /** {@inheritDoc} */ + @Override + public final void setLearningRate(float learningRate) { + this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 4dfabb21357..7df1ddfa991 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -15,12 +15,19 @@ */ package org.tensorflow.framework.optimizers; +import java.util.Collections; import java.util.List; +import java.util.Map; + import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; /** @@ -33,7 +40,10 @@ public class AdaGrad extends Optimizer { public static final String ACCUMULATOR = "accumulator"; - private final float learningRate; + private float learningRate; + private Tensor<TFloat32> learningRateTensor; + private final Placeholder<TFloat32> learningRatePlaceholder; + private Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict; private final float initialAccumulatorValue; @@ -44,6 +54,10 @@ public AdaGrad(Graph graph, float learningRate) { public AdaGrad(Graph graph, float learningRate, float initialAccumulatorValue) { super(graph); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.initialAccumulatorValue = initialAccumulatorValue; } @@ -54,6 +68,10 @@ public AdaGrad(Graph graph, String name, float learningRate) { public AdaGrad(Graph graph, String name, float learningRate, float initialAccumulatorValue) { super(graph, name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.initialAccumulatorValue = initialAccumulatorValue; } @@ -74,7 +92,7 @@ private <T extends TType> void createAdaGradSlot(Output<T> v) { protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) { Variable<T> slot = getSlot(variable, ACCUMULATOR).get(); return tf.train - .applyAdagrad(variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + .applyAdagrad(variable, slot, tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), gradient); } @@ -90,4 +108,35 @@ public String toString() { public String getOptimizerName() { return "Adagrad"; } + + /** {@inheritDoc} */ + @Override + public float getLearningRate() { + return this.learningRate; + } + + /** {@inheritDoc} */ + @Override + public final void setLearningRate(float learningRate) { + this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index 0544309dc7f..4d590906b2b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -15,15 +15,20 @@ */ package org.tensorflow.framework.optimizers; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.Tensor; import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -36,7 +41,10 @@ public class AdaGradDA extends Optimizer { public static final String ACCUMULATOR = "gradient_accumulator"; public static final String SQUARED_ACCUMULATOR = "gradient_squared_accumulator"; - private final float learningRate; + private float learningRate; + private Tensor<TFloat32> learningRateTensor; + private final Placeholder<TFloat32> learningRatePlaceholder; + private Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict; private final float initialAccumulatorValue; private final float l1Strength; private final float l2Strength; @@ -50,6 +58,10 @@ public AdaGradDA(Graph graph, float learningRate, float initialAccumulatorValue, float l2Strength) { super(graph); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.initialAccumulatorValue = initialAccumulatorValue; this.l1Strength = l1Strength; this.l2Strength = l2Strength; @@ -63,6 +75,10 @@ public AdaGradDA(Graph graph, String name, float learningRate, float initialAccu float l2Strength) { super(graph, name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.initialAccumulatorValue = initialAccumulatorValue; this.l1Strength = l1Strength; this.l2Strength = l2Strength; @@ -97,7 +113,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable Variable<T> gradSlot = getSlot(variable, ACCUMULATOR).get(); Variable<T> gradSquaredSlot = getSlot(variable, SQUARED_ACCUMULATOR).get(); return tf.train.applyAdagradDa(variable, gradSlot, gradSquaredSlot, gradient, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), tf.dtypes.cast(tf.constant(l1Strength), gradient.dataType()), tf.dtypes.cast(tf.constant(l2Strength), gradient.dataType()), globalStep); @@ -133,4 +149,35 @@ public String toString() { public String getOptimizerName() { return "adagrad-da"; } + + /** {@inheritDoc} */ + @Override + public float getLearningRate() { + return this.learningRate; + } + + /** {@inheritDoc} */ + @Override + public final void setLearningRate(float learningRate) { + this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java index 11ab4be6b64..ac07f4e2fc9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java @@ -15,17 +15,21 @@ */ package org.tensorflow.framework.optimizers; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.Tensor; import org.tensorflow.op.Op; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.ndarray.Shape; import org.tensorflow.types.TFloat32; @@ -42,7 +46,10 @@ public class Adam extends Optimizer { public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; - private final float learningRate; + private float learningRate; + private Tensor<TFloat32> learningRateTensor; + private final Placeholder<TFloat32> learningRatePlaceholder; + private Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict; private final float betaOne; @@ -50,7 +57,6 @@ public class Adam extends Optimizer { private final float epsilon; - private Constant<TFloat32> learningRateConst; private Constant<TFloat32> epsilonConst; private Constant<TFloat32> betaOneConst; private Constant<TFloat32> betaTwoConst; @@ -64,6 +70,10 @@ public Adam(Graph graph, float learningRate) { public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { super(graph); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; @@ -76,6 +86,10 @@ public Adam(Graph graph, String name, float learningRate) { public Adam(Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { super(graph, name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; @@ -121,7 +135,6 @@ protected void createSlots(List<Output<? extends TType>> variables) { protected Optional<Op> prepare(String scopeName) { betaOneConst = tf.constant(betaOne); betaTwoConst = tf.constant(betaTwo); - learningRateConst = tf.constant(learningRate); epsilonConst = tf.constant(epsilon); return Optional.empty(); } @@ -142,7 +155,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable return tf.train.applyAdam(variable, firstMomentSlot, secondMomentSlot, tf.dtypes.cast(betaOnePower, gradient.dataType()), tf.dtypes.cast(betaTwoPower, gradient.dataType()), - tf.dtypes.cast(learningRateConst, gradient.dataType()), + tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), tf.dtypes.cast(betaOneConst, gradient.dataType()), tf.dtypes.cast(betaTwoConst, gradient.dataType()), tf.dtypes.cast(epsilonConst, gradient.dataType()), @@ -179,4 +192,35 @@ public String toString() { public String getOptimizerName() { return "Adam"; } + + /** {@inheritDoc} */ + @Override + public float getLearningRate() { + return this.learningRate; + } + + /** {@inheritDoc} */ + @Override + public final void setLearningRate(float learningRate) { + this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java index 7ed90c846f1..3ba437d241e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java @@ -16,31 +16,50 @@ package org.tensorflow.framework.optimizers; import org.tensorflow.Graph; +import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; +import java.util.Collections; +import java.util.Map; + /** * Basic SGD. */ public class GradientDescent extends Optimizer { - private final float learningRate; + private float learningRate; + private Tensor<TFloat32> learningRateTensor; + private final Placeholder<TFloat32> learningRatePlaceholder; + private Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict; public GradientDescent(Graph graph, float learningRate) { super(graph); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); } public GradientDescent(Graph graph, String name, float learningRate) { super(graph, name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); } @Override protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) { return tf.train.applyGradientDescent(variable, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), gradient); + tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), gradient); } @Override @@ -54,4 +73,35 @@ public String toString() { public String getOptimizerName() { return "GradientDescent"; } + + /** {@inheritDoc} */ + @Override + public float getLearningRate() { + return this.learningRate; + } + + /** {@inheritDoc} */ + @Override + public final void setLearningRate(float learningRate) { + this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java index b8582b4e278..a058649373a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java @@ -15,13 +15,20 @@ */ package org.tensorflow.framework.optimizers; +import java.util.Collections; import java.util.List; +import java.util.Map; + import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyMomentum; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; /** @@ -34,7 +41,10 @@ public class Momentum extends Optimizer { public static final String MOMENTUM = "momentum"; - private final float learningRate; + private float learningRate; + private Tensor<TFloat32> learningRateTensor; + private final Placeholder<TFloat32> learningRatePlaceholder; + private Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict; private final float momentum; @@ -43,6 +53,10 @@ public class Momentum extends Optimizer { public Momentum(Graph graph, float learningRate, float momentum, boolean useNesterov) { super(graph); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.momentum = momentum; this.useNesterov = useNesterov; } @@ -50,6 +64,10 @@ public Momentum(Graph graph, float learningRate, float momentum, boolean useNest public Momentum(Graph graph, String name, float learningRate, float momentum, boolean useNesterov) { super(graph, name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.momentum = momentum; this.useNesterov = useNesterov; } @@ -71,7 +89,7 @@ private <T extends TType> void createMomentumSlot(Output<T> v) { protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) { Variable<T> slot = getSlot(variable, MOMENTUM).get(); return tf.train - .applyMomentum(variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + .applyMomentum(variable, slot, tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), gradient, tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), ApplyMomentum.useNesterov(useNesterov)); @@ -90,4 +108,35 @@ public String toString() { public String getOptimizerName() { return "Momentum"; } + + /** {@inheritDoc} */ + @Override + public float getLearningRate() { + return this.learningRate; + } + + /** {@inheritDoc} */ + @Override + public final void setLearningRate(float learningRate) { + this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index ffff35a8ddd..def464a86ca 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -21,10 +21,8 @@ import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.Output; + +import org.tensorflow.*; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; @@ -36,8 +34,8 @@ /** * Base class for gradient optimizers. */ -public abstract class Optimizer { - +public abstract class Optimizer implements AutoCloseable { + public static final String LEARNING_RATE = "learning_rate"; public static final String VARIABLE_V2 = "VariableV2"; /** * Global state variables @@ -247,7 +245,26 @@ protected Op finish(List<Op> updateOperations, String name) { public abstract String getOptimizerName(); /** - * Optional attributes for {@link org.tensorflow.training.optimizers.Optimizer} + * Set the learning rate + * @param learningRate the learning rate + */ + public abstract void setLearningRate(float learningRate); + + /** + * Get the learning rate + * @return the learning rate + */ + public abstract float getLearningRate(); + + /** + * Get the Feed Dictionary for the run methods to set the Placeholder values(s) + * + * @return the current Feed Dictionary for the run methods + */ + public abstract Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict(); + + /** + * Optional attributes for {@link org.tensorflow.framework.optimizers.Optimizer} */ public static class Options { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index cc64a23de3d..3d28c016de7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -15,12 +15,19 @@ */ package org.tensorflow.framework.optimizers; +import java.util.Collections; import java.util.List; +import java.util.Map; + import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; /** @@ -35,7 +42,10 @@ public class RMSProp extends Optimizer { public static final String MG = "mg"; // mean gradient? public static final String MOMENTUM = "momentum"; - private final float learningRate; + private float learningRate; + private Tensor<TFloat32> learningRateTensor; + private final Placeholder<TFloat32> learningRatePlaceholder; + private Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict; private final float decay; private final float momentum; private final float epsilon; @@ -49,6 +59,11 @@ public RMSProp(Graph graph, float learningRate, float decay, float momentum, flo boolean centered) { super(graph); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + this.decay = decay; this.momentum = momentum; this.epsilon = epsilon; @@ -63,6 +78,10 @@ public RMSProp(Graph graph, String name, float learningRate, float decay, float boolean centered) { super(graph, name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.decay = decay; this.momentum = momentum; this.epsilon = epsilon; @@ -97,14 +116,14 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable if (centered) { Variable<T> mgSlot = getSlot(variable, MG).get(); return tf.train.applyCenteredRmsProp(variable, mgSlot, rmsSlot, momentumSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), tf.dtypes.cast(tf.constant(decay), gradient.dataType()), tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), gradient); } return tf.train.applyRmsProp(variable, rmsSlot, momentumSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), tf.dtypes.cast(tf.constant(decay), gradient.dataType()), tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), @@ -126,4 +145,35 @@ public String toString() { public String getOptimizerName() { return "RMSProp"; } + + /** {@inheritDoc} */ + @Override + public float getLearningRate() { + return this.learningRate; + } + + /** {@inheritDoc} */ + @Override + public final void setLearningRate(float learningRate) { + this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } + } } From 0afdb9ccd17e55941d9f48f78b8bbd7e31ec926a Mon Sep 17 00:00:00 2001 From: Jim Clarke <JimClarke5@me.com> Date: Tue, 1 Sep 2020 12:55:34 -0400 Subject: [PATCH 03/14] Add support for hanling feed dicts when evaluating or printing Operands. --- .../keras/utils/EagerTestSession.java | 2 +- .../keras/utils/GraphTestSession.java | 2 +- .../tensorflow/keras/utils/TestSession.java | 68 ++++++++++++++++++- 3 files changed, 69 insertions(+), 3 deletions(-) diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java index 6d286c311ff..8c0f26b21e7 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java @@ -31,7 +31,7 @@ import static org.junit.jupiter.api.Assertions.*; -/** @author Jim Clarke */ +/** An Eager Mode Test Session */ public class EagerTestSession extends TestSession { private final EagerSession session; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java index ff18b338ce2..98ff9d40c04 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java @@ -32,7 +32,7 @@ import static org.junit.jupiter.api.Assertions.*; -/** @author Jim Clarke */ +/** A Graph Mode Test Session */ public class GraphTestSession extends TestSession { private final Graph graph; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java index 34348ccc1f4..cd4b891a039 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java @@ -32,7 +32,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; -/** @author Jim Clarke */ +/** Abstract class for Test Sessions */ public abstract class TestSession implements AutoCloseable { protected float epsilon = 1e-5F; @@ -851,6 +851,72 @@ public void evaluate(FloatNdArray input, Predicate<Number> predicate) { input.scalars().forEach(f -> assertTrue(predicate.test(f.getFloat()))); } + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + * @param <T> the data type for the input + */ + public <T extends TType> void print(Operand<T> input) { + print(System.out, input, null); + } + + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <T> the data type for the feedDict entries + */ + public <T extends TType> void print( + Operand<T> input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input.asOutput(), feedDict); + } + + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + */ + public void print(Op input) { + print(System.out, input, null); + } + + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void print(Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input.op().output(0), feedDict); + } + + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + * @param <T> the data type for the input + */ + public <T extends TType> void print(Output<T> input) { + print(System.out, input, null); + } + + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param <T> the data type for the input + */ + public <T extends TType> void print( + Output<T> input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input, feedDict); + } + /** * Print the results to output stream * From 9f10da9e877aeea343b7513bcab1038a03fb9869 Mon Sep 17 00:00:00 2001 From: Jim Clarke <JimClarke5@me.com> Date: Tue, 1 Sep 2020 12:58:59 -0400 Subject: [PATCH 04/14] Add tests for changing learning rates --- .../keras/optimizers/AdaDeltaTest.java | 95 +++++++++++ .../keras/optimizers/AdaGradDATest.java | 67 ++++++++ .../keras/optimizers/AdaGradTest.java | 80 +++++++++ .../tensorflow/keras/optimizers/AdamTest.java | 151 +++++++++++++++++ .../keras/optimizers/AdamaxTest.java | 122 ++++++++++++- .../keras/optimizers/NadamTest.java | 160 +++++++++++++++++- .../keras/optimizers/RMSPropTest.java | 136 +++++++++++++++ .../tensorflow/keras/optimizers/SGDTest.java | 73 ++++++++ 8 files changed, 882 insertions(+), 2 deletions(-) diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java index e8a3bc14d9b..403803295e9 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java @@ -202,4 +202,99 @@ public void testBasic() { } } } + + @Test + public void testWithLearningRateDecay() { + int numSteps = 4; // # number of ADADELTA steps to perform + float[] grads = {0.2F, 0.1F, 0.01F}; + + for (float grad : grads) { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + float lr = 1.0F; + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] fgrads = {grad, grad}; + Shape shape = Shape.of(var0_init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant<TFloat32> cgrads = tf.constant(fgrads); + + float accum = 0.0F; + float accum_update = 0.0F; + float rho = 0.95F; + float epsilon = 1e-8F; + float epsilon1 = 1e-5F; + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var1.asOutput())); + + /* get the Optimizer */ + AdaDelta instance = new AdaDelta(tf, lr, rho, epsilon); + + Op adadelta_update = instance.applyGradients(gradsAndVars, "AdaDeltaTest"); + + /* Create and validae the shapes of the slota */ + Variable<TFloat32>[] slots = new Variable[2]; + Variable<TFloat32>[] slotUpdates = new Variable[2]; + + slots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape()); + + slotUpdates[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape()); + + slots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape()); + + slotUpdates[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + + /** make sure the variables were initialized properly */ + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + float[] updates = new float[numSteps]; + float totUpdate = 0; + for (int step = 0; step < numSteps; step++) { + session.run(adadelta_update, instance.getFeedDict()); + accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho); + updates[step] = + ((float) Math.sqrt(accum_update + epsilon) + * (float) (1 / Math.sqrt(accum + epsilon)) + * grad); + accum_update = (accum_update * rho + ((float) Math.pow(updates[step], 2) * (1.0F - rho))); + totUpdate += updates[step] * lr; + + for (int i = 0; i < 2; i++) { + session.evaluate(accum, slots[i]); + session.evaluate(accum_update, slotUpdates[i]); + } + + Float[] var0_initUpdate = {var0_init[0] - totUpdate, var0_init[1] - totUpdate}; + Float[] var1_initUpdate = {var1_init[0] - totUpdate, var1_init[1] - totUpdate}; + + session.evaluate(var0_initUpdate, var0); + session.evaluate(var1_initUpdate, var1); + + // Adjust learning rate + lr *= 0.9F; + instance.setLearningRate(lr); + } + } + } + } } diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java index 85f4220c4c7..98c8145515b 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java @@ -121,4 +121,71 @@ public void testBasic() { session.evaluate(expected1, var1); } } + + @Test + public void testWithLearningRateDecay() { + float[] var0_init = {0.0F, 0.0F}; + float[] var1_init = {0.0F, 0.0F}; + float[] grads0_init = {0.1F, 0.2F}; + float[] grads1_init = {0.01F, 0.02F}; + float epsilon = 1e-8F; + float epsilon1 = 1e-5F; + int numSteps = 4; + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant<TFloat32> grads0 = tf.constant(grads0_init); + Constant<TFloat32> grads1 = tf.constant(grads1_init); + + /* initialize the local variables */ + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + float learningRate = 3.0F; + + AdaGrad instance = new AdaGrad(tf, learningRate); + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdGradDATest"); + + /** initialize the accumulators */ + session.run(tf.init()); + + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + float[][] expected0 = { + {-0.904534F, -1.603567F}, + {-1.683957F, -2.8763597F}, + {-2.3579178F, -3.9125152F}, + {-2.942418F, -4.770327F} + }; + float[][] expected1 = { + {-0.094821F, -0.189358F}, + {-0.18011717F, -0.35944232F}, + {-0.2568455F, -0.51221514F}, + {-0.3258666F, -0.6494397F} + }; + for (int i = 0; i < numSteps; i++) { + session.run(update, instance.getFeedDict()); + System.out.println("step: " + i); + session.evaluate(expected0[i], var0); + session.evaluate(expected1[i], var1); + learningRate *= 0.9; + instance.setLearningRate(learningRate); + } + } + } } diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java index b6f1d7c88fc..0de84aac0c6 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java @@ -41,6 +41,7 @@ /** Test cases for AdaGrad Optimizer */ public class AdaGradTest { + private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; int index; @@ -149,6 +150,85 @@ public void testBasic() { } } + @Test + public void testWithLearningRateDecay() { + int numSteps = 3; + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.1F}; + float[] grads1_init = {0.01F, 0.01F}; + float epsilon = 1e-8F; + float epsilon1 = 1e-5F; + float[] accum0 = {0.1f, 0.1f}; + float[] accum1 = {0.1f, 0.1f}; + + FloatNdArray var0_np = NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + FloatNdArray accum0_np = NdArrays.vectorOf(accum0); + FloatNdArray accum1_np = NdArrays.vectorOf(accum1); + + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant<TFloat32> grads0 = tf.constant(grads0_init); + Constant<TFloat32> grads1 = tf.constant(grads1_init); + + float learningRate = 3.0F; + + AdaGrad instance = new AdaGrad(tf, learningRate); + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op ada_update = instance.applyGradients(gradsAndVars, "AdGradTest"); + + Variable<TFloat32>[] accumulatorSlots = new Variable[2]; + accumulatorSlots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get(); + assertEquals(accumulatorSlots[0].asOutput().shape(), var0.asOutput().shape()); + + accumulatorSlots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get(); + assertEquals(accumulatorSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + + /** make sure the variables were initialized properly */ + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + for (int step = 0; step < numSteps; step++) { + session.run(ada_update, instance.getFeedDict()); + + accum0_np = caclulateAccum(accum0_np, grads0_np); + var0_np = calculate(var0_np, accum0_np, grads0_np, learningRate); + session.evaluate(var0_np, var0); + + accum1_np = caclulateAccum(accum1_np, grads1_np); + var1_np = calculate(var1_np, accum1_np, grads1_np, learningRate); + session.evaluate(var1_np, var1); + + learningRate *= 0.9; + instance.setLearningRate(learningRate); + } + } + } + private FloatNdArray caclulateAccum(FloatNdArray accum, FloatNdArray grads) { // accum + g_t * g_t FloatNdArray squareG = ND.square(grads); diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java index 6a8f0f5078c..67bc9701935 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java @@ -42,6 +42,7 @@ /** Test cases for Adam Optimizer */ public class AdamTest { + private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; int index; @@ -224,6 +225,156 @@ public void testBasic() { } } + @Test + public void testWithLearningRateDecay() { + float m0 = 0.0F; + float v0 = 0.0F; + float m1 = 0.0F; + float v1 = 0.0F; + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.1F}; + float[] grads1_init = {0.01F, 0.01F}; + FloatNdArray var0_np = NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + + float epsilon1 = 1e-3F; + + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + + session.setEpsilon(epsilon1); + + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant<TFloat32> grads0 = tf.constant(grads0_init); + Constant<TFloat32> grads1 = tf.constant(grads1_init); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + float learningRate = 0.001F; + float beta1 = 0.9F; + float beta2 = 0.999F; + float epsilon = 1e-8F; + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Adam instance = new Adam(tf, learningRate); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validae the shapes of the slota */ + Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + Variable<TFloat32>[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /** initialize the accumulators */ + session.run(tf.init()); + + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + FloatNdArray m0_np = NdArrays.ofFloats(shape1); + FloatNdArray v0_np = NdArrays.ofFloats(shape1); + FloatNdArray m1_np = NdArrays.ofFloats(shape1); + FloatNdArray v1_np = NdArrays.ofFloats(shape1); + + for (int step = 0; step < 3; step++) { + + // Test powers + final float[] powers = { + (float) Math.pow(beta1, step + 1), (float) Math.pow(beta2, step + 1) + }; + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result + .data() + .scalars() + .forEach( + f -> { + assertEquals(powers[0], f.getFloat(), epsilon1); + }); + } + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("beta2_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result + .data() + .scalars() + .forEach( + f -> { + assertEquals(powers[1], f.getFloat(), epsilon1); + }); + } + session.run(update, instance.getFeedDict()); + + float lr_t = + learningRate + * (float) Math.sqrt(1 - (float) Math.pow(beta2, (step + 1))) + / (1 - (float) Math.pow(beta1, (step + 1))); + + m0_np = calculateM(m0_np, grads0_np, beta1); + v0_np = calculateV(v0_np, grads0_np, beta2); + var0_np = calculateParam(var0_np, lr_t, m0_np, v0_np, 1e-7F); + + m1_np = calculateM(m1_np, grads1_np, beta1); + v1_np = calculateV(v1_np, grads1_np, beta2); + var1_np = calculateParam(var1_np, lr_t, m1_np, v1_np, 1e-7F); + + // evaluate var 0 and var1 + session.evaluate(var0_np, var0); + session.evaluate(var1_np, var1); + + // first moment + session.evaluate(m0_np, firstMomentSlots[0]); + session.evaluate(m1_np, firstMomentSlots[1]); + + // second moment + session.evaluate(v0_np, secondMomentSlots[0]); + session.evaluate(v1_np, secondMomentSlots[1]); + + learningRate *= 0.9; + instance.setLearningRate(learningRate); + } + } + } + private FloatNdArray calculateM(FloatNdArray m, FloatNdArray g_t, float beta) { // m_t = beta1 * m + (1 - beta1) * g_t return ND.add(ND.mul(m, beta), ND.mul(g_t, (1 - beta))); diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java index 3f6b232c179..24ec7cb15cc 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java @@ -40,6 +40,7 @@ /** Test cases for Adamax Optimizer */ public class AdamaxTest { + private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; private static final int VAR = 0; @@ -197,16 +198,135 @@ public void testBasic() { v1 = resultNP[V]; // evaluate var0 and var1 + session.evaluate(var0_np, var0); + session.evaluate(var1_np, var1); + } + } + } + + @Test + public void testWithLearningRateDecay() { + + float epsilon = 1e-6f; + float epsilon1 = 1e-3F; + int numSteps = 3; + + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + float[] zeros = {0.0F, 0.0F}; + FloatNdArray m0 = NdArrays.vectorOf(zeros); + FloatNdArray v0 = NdArrays.vectorOf(zeros); + FloatNdArray m1 = NdArrays.vectorOf(zeros); + FloatNdArray v1 = NdArrays.vectorOf(zeros); + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.1F}; + float[] grads1_init = {0.01F, 0.01F}; + FloatNdArray var0_np = NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant<TFloat32> grads0 = tf.constant(grads0_init); + Constant<TFloat32> grads1 = tf.constant(grads1_init); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + float learningRate = 0.001F; + + Adamax instance = new Adamax(tf, learningRate); + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validae the shapes of the slota */ + Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + Variable<TFloat32>[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /** initialize the accumulators */ + session.run(tf.init()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + session.setEpsilon(epsilon1); + for (int step = 0; step < numSteps; step++) { + // Test powers + final float beta1_power = (float) Math.pow(BETA_ONE_DEFAULT, step + 1); + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result + .data() + .scalars() + .forEach( + f -> { + assertEquals(beta1_power, f.getFloat(), epsilon1); + }); + } + session.run(update, instance.getFeedDict()); + + FloatNdArray[] resultNP = calculate(var0_np, grads0_np, step, m0, v0, learningRate); + var0_np = resultNP[VAR]; + m0 = resultNP[M]; + v0 = resultNP[V]; + + resultNP = calculate(var1_np, grads1_np, step, m1, v1, learningRate); + var1_np = resultNP[VAR]; + m1 = resultNP[M]; + v1 = resultNP[V]; + + // evaluate var0 and var1 session.evaluate(var0_np, var0); session.evaluate(var1_np, var1); + + learningRate *= 0.9F; + instance.setLearningRate(learningRate); } } } private FloatNdArray[] calculate( FloatNdArray var_np, FloatNdArray grads_np, int step, FloatNdArray m, FloatNdArray v) { - float alpha = 0.001F; + return calculate(var_np, grads_np, step, m, v, 0.001F); + } + + private FloatNdArray[] calculate( + FloatNdArray var_np, + FloatNdArray grads_np, + int step, + FloatNdArray m, + FloatNdArray v, + float alpha) { float beta1 = BETA_ONE_DEFAULT; float beta2 = BETA_TWO_DEFAULT; float espilon = 1e-8F; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java index 6314b4b8b4c..32d90ea91ed 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java @@ -237,6 +237,154 @@ public void testBasic() { } } + @Test + public void testWithLearningRateDecay() { + int numSteps = 3; + + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.1F}; + float[] grads1_init = {0.01F, 0.01F}; + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; + FloatNdArray m0 = NdArrays.vectorOf(zeros); + FloatNdArray v0 = NdArrays.vectorOf(zeros); + FloatNdArray m1 = NdArrays.vectorOf(zeros); + FloatNdArray v1 = NdArrays.vectorOf(zeros); + FloatNdArray mcache = NdArrays.vectorOf(ones); + FloatNdArray var0_np = NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + + float epsilon = 1e-6f; + float epsilon1 = 1e-3F; + + float learningRate = 0.001F; + + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant<TFloat32> grads0 = tf.constant(grads0_init); + Constant<TFloat32> grads1 = tf.constant(grads1_init); + + Nadam instance = new Nadam(tf, learningRate); + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validae the shapes of the slota */ + Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + Variable<TFloat32>[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + + session.setEpsilon(epsilon1); + + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result + .data() + .scalars() + .forEach( + f -> { + assertEquals(1F, f.getFloat(), epsilon1); + }); + } + momentum = 1F; + + for (int step = 0; step < numSteps; step++) { + + session.run(update, instance.getFeedDict()); + + float mut = + Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); + momentum = momentum * mut; + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result + .data() + .scalars() + .forEach( + f -> { + assertEquals(momentum, f.getFloat(), epsilon1); + }); + } + mcache = ND.mul(mcache, momentum); + FloatNdArray[] resultsNP = + nadam_update_numpy(var0_np, grads0_np, step, m0, v0, mcache, learningRate); + var0_np = resultsNP[VAR]; + m0 = resultsNP[M]; + v0 = resultsNP[V]; + + resultsNP = nadam_update_numpy(var1_np, grads1_np, step, m1, v1, mcache, learningRate); + var1_np = resultsNP[VAR]; + m1 = resultsNP[M]; + v1 = resultsNP[V]; + + // evaluate m0 and m1 + session.evaluate(m0, firstMomentSlots[0]); + session.evaluate(m1, firstMomentSlots[1]); + + // evaluate v0 and v1 + session.evaluate(v0, secondMomentSlots[0]); + session.evaluate(v1, secondMomentSlots[1]); + + // evaluate var0 and var1 + session.evaluate(var0_np, var0); + session.evaluate(var1_np, var1); + + learningRate *= 0.9; + instance.setLearningRate(learningRate); + } + } + } + private FloatNdArray update_m_cache(FloatNdArray mcache, int t) { float mu_t = 0.9F * (1.0F - 0.5F * (float) Math.pow(0.96, (0.004 * (t + 1)))); return ND.mul(mu_t, mcache); @@ -249,8 +397,18 @@ private FloatNdArray[] nadam_update_numpy( FloatNdArray m, FloatNdArray v, FloatNdArray m_cache) { + return nadam_update_numpy(var_np, grads_np, t, m, v, m_cache, 0.001F); + } + + private FloatNdArray[] nadam_update_numpy( + FloatNdArray var_np, + FloatNdArray grads_np, + int t, + FloatNdArray m, + FloatNdArray v, + FloatNdArray m_cache, + float alpha) { - float alpha = 0.001F; float beta1 = 0.9F; float beta2 = 0.999F; float epsilon = 1e-8F; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java index 2a43bdb3df2..7651872643b 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java @@ -41,6 +41,7 @@ /** Test cases for RMSProp Optimizer */ public class RMSPropTest { + private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; final int VAR_T = 0; @@ -224,6 +225,141 @@ public void testDense() { } } + @Test + public void testWithLearningRateDecay() { + int numSteps = 3; + + for (int run = 0; run < _test_param_values.length; run++) { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + session.setEpsilon(1e-2f); + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.2F}; + float[] grads1_init = {0.01F, 0.2F}; + final float epsilon1 = 1e-2F; + + FloatNdArray var0_np = NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant<TFloat32> grads0 = tf.constant(grads0_init); + Constant<TFloat32> grads1 = tf.constant(grads1_init); + + // learning_rate, rho (decay), momentum, epsilon, centered + float learningRate = (float) (float) _test_param_values[run][0]; + float decay = (float) _test_param_values[run][1]; + float momentum = (float) _test_param_values[run][2]; + float epsilon = (float) _test_param_values[run][3]; + boolean centered = (boolean) _test_param_values[run][4]; + + RMSProp instance = new RMSProp(tf, learningRate, decay, momentum, epsilon, centered); + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + + /** make sure the variables were initialized properly */ + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + Variable<TFloat32> mg0 = centered ? instance.getSlot(var0.asOutput(), MG).get() : null; + Variable<TFloat32> mg1 = centered ? instance.getSlot(var1.asOutput(), MG).get() : null; + Variable<TFloat32> mom0 = + momentum > 0.F ? instance.getSlot(var0.asOutput(), MOMENTUM).get() : null; + Variable<TFloat32> mom1 = + momentum > 0.F ? instance.getSlot(var1.asOutput(), MOMENTUM).get() : null; + Variable<TFloat32> rms0 = instance.getSlot(var0.asOutput(), RMS).get(); + Variable<TFloat32> rms1 = instance.getSlot(var1.asOutput(), RMS).get(); + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; // temp to match RMSProp + FloatNdArray mg0_np = NdArrays.vectorOf(zeros); + FloatNdArray mg1_np = NdArrays.vectorOf(zeros); + FloatNdArray rms0_np = NdArrays.vectorOf(ones); + FloatNdArray rms1_np = NdArrays.vectorOf(ones); + FloatNdArray mom0_np = NdArrays.vectorOf(zeros); + FloatNdArray mom1_np = NdArrays.vectorOf(zeros); + + for (int i = 0; i < numSteps; i++) { + session.run(update, instance.getFeedDict()); + FloatNdArray[] result0 = + calc( + var0_np, + grads0_np, + mg0_np, + rms0_np, + mom0_np, + learningRate, + decay, + momentum, + epsilon, + centered); + var0_np = result0[VAR_T]; + mg0_np = result0[MG_T]; + rms0_np = result0[RMS_T]; + mom0_np = result0[MOM_T]; + + FloatNdArray[] result1 = + calc( + var1_np, + grads1_np, + mg1_np, + rms1_np, + mom1_np, + learningRate, + decay, + momentum, + epsilon, + centered); + + var1_np = result1[VAR_T]; + mg1_np = result1[MG_T]; + rms1_np = result1[RMS_T]; + mom1_np = result1[MOM_T]; + + if (centered) { + session.evaluate(mg0_np, mg0); + session.evaluate(mg0_np, mg0); + } + if (momentum > 0.F) { + session.evaluate(mom0_np, mom0); + session.evaluate(mom1_np, mom1); + } + + /* TODO the values returned from rms slot, do not match what I see in the python test */ + session.evaluate(rms0_np, rms0); + session.evaluate(rms1_np, rms1); + + session.evaluate(var0_np, var0); + session.evaluate(var1_np, var1); + + learningRate *= 0.9F; + instance.setLearningRate(learningRate); + } + } + } + } + FloatNdArray[] calc( FloatNdArray var_np, FloatNdArray grad_np, diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java index 3d24b85239a..1cf20f1b0d2 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java @@ -218,4 +218,77 @@ public void testMomentum() { session.evaluate(expectedVar1_2, var1); } } + + @Test + public void testWithLearningRateDecay() { + int numSteps = 2; + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.1F}; + float[] grads1_init = {0.01F, 0.01F}; + + float learningRate = 3.0F; + + float epsilon = 1e-6F; + float epsilon1 = 1e-2F; + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant<TFloat32> grads0 = tf.constant(grads0_init); + Constant<TFloat32> grads1 = tf.constant(grads1_init); + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + SGD instance = new SGD(tf, learningRate); + Op update = instance.applyGradients(gradsAndVars, "SGDTest"); + + Variable<TFloat32> momentumSlot0 = instance.getSlot(var0.asOutput(), MOMENTUM).get(); + assertEquals(momentumSlot0.asOutput().shape(), var0.asOutput().shape()); + Variable<TFloat32> momentumSlot1 = instance.getSlot(var1.asOutput(), MOMENTUM).get(); + assertEquals(momentumSlot1.asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + + /** make sure the variables were initialized properly */ + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + float[][] expectedVar0 = { + {0.7F, 1.7F}, + {0.66999996F, 1.6700001F}, + {0.66699994F, 1.667F}, + {0.66669995F, 1.6667F}, + {0.66666996F, 1.66667F} + }; + float[][] expectedVar1 = { + {2.97F, 3.97F}, + {2.967F, 3.967F}, + {2.9667F, 3.9667F}, + {2.96667F, 3.96667F}, + {2.966667F, 3.966667F} + }; + for (int step = 0; step < numSteps; step++) { + session.run(update, instance.getFeedDict()); + session.evaluate(expectedVar0[step], var0); + session.evaluate(expectedVar1[step], var1); + learningRate *= 0.1; + instance.setLearningRate(learningRate); + } + } + } } From d8fab044e35d973aa80e7cde3765a166f783c38b Mon Sep 17 00:00:00 2001 From: Jim Clarke <JimClarke5@me.com> Date: Mon, 14 Sep 2020 12:45:48 -0400 Subject: [PATCH 05/14] Moved Optimizers to Keras. Added support for chanign learning rate. --- .../src/bazel/op_generator/op_generator.cc | 10 +- .../annotations/org/tensorflow/op/NnOps.java | 139 ++++-- .../java/org/tensorflow/op/core/Abort.java | 1 - .../org/tensorflow/op/core/AssertThat.java | 1 - .../op/core/AssignAddVariableOp.java | 1 - .../op/core/AssignSubVariableOp.java | 1 - .../tensorflow/op/core/AssignVariableOp.java | 1 - .../org/tensorflow/op/core/BarrierClose.java | 1 - .../tensorflow/op/core/BarrierInsertMany.java | 1 - .../tensorflow/op/core/ConsumeMutexLock.java | 1 - .../tensorflow/op/core/ControlTrigger.java | 1 - .../op/core/DeleteSessionTensor.java | 1 - .../tensorflow/op/core/DestroyResourceOp.java | 1 - .../org/tensorflow/op/core/DeviceIndex.java | 3 + .../tensorflow/op/core/InitializeTable.java | 1 - .../op/core/InitializeTableFromTextFile.java | 1 - .../tensorflow/op/core/LookupTableImport.java | 1 - .../tensorflow/op/core/LookupTableInsert.java | 1 - .../tensorflow/op/core/LookupTableRemove.java | 1 - .../java/org/tensorflow/op/core/MapClear.java | 1 - .../java/org/tensorflow/op/core/MapStage.java | 1 - .../gen/java/org/tensorflow/op/core/NoOp.java | 1 - .../tensorflow/op/core/OrderedMapClear.java | 1 - .../tensorflow/op/core/OrderedMapStage.java | 1 - .../java/org/tensorflow/op/core/Print.java | 1 - .../op/core/ResourceScatterAdd.java | 1 - .../op/core/ResourceScatterDiv.java | 1 - .../op/core/ResourceScatterMax.java | 1 - .../op/core/ResourceScatterMin.java | 1 - .../op/core/ResourceScatterMul.java | 1 - .../op/core/ResourceScatterNdAdd.java | 1 - .../op/core/ResourceScatterNdMax.java | 2 + .../op/core/ResourceScatterNdMin.java | 2 + .../op/core/ResourceScatterNdSub.java | 1 - .../op/core/ResourceScatterNdUpdate.java | 1 - .../op/core/ResourceScatterSub.java | 1 - .../op/core/ResourceScatterUpdate.java | 1 - .../op/core/ResourceStridedSliceAssign.java | 1 - .../org/tensorflow/op/core/ScatterNdMax.java | 3 + .../org/tensorflow/op/core/ScatterNdMin.java | 3 + .../gen/java/org/tensorflow/op/core/Send.java | 1 - .../java/org/tensorflow/op/core/Stage.java | 1 - .../org/tensorflow/op/core/StageClear.java | 1 - .../tensorflow/op/core/TensorArrayClose.java | 1 - .../core/TensorForestCreateTreeVariable.java | 1 - .../op/core/TensorForestTreeDeserialize.java | 1 - .../op/core/TensorScatterNdMax.java | 3 + .../op/core/TensorScatterNdMin.java | 3 + .../op/core/XlaSpmdFullToShardShape.java | 3 + .../op/core/XlaSpmdShardToFullShape.java | 3 + .../tensorflow/op/data/DatasetToTfRecord.java | 1 - .../tensorflow/op/data/DeleteIterator.java | 1 - .../tensorflow/op/data/DeleteMemoryCache.java | 1 - .../op/data/DeleteMultiDeviceIterator.java | 1 - .../op/data/DeserializeIterator.java | 1 - .../op/data/InitializeTableFromDataset.java | 2 + .../org/tensorflow/op/data/MakeIterator.java | 1 - .../tensorflow/op/data/RegisterDataset.java | 3 + .../op/data/ShuffleAndRepeatDataset.java | 2 +- .../tensorflow/op/data/ShuffleDataset.java | 2 +- .../op/data/experimental/CompressElement.java | 3 + .../data/experimental/DataServiceDataset.java | 3 + .../data/experimental/DatasetToTFRecord.java | 1 - .../experimental/DummyIterationCounter.java | 3 + .../StatsAggregatorSetSummaryWriter.java | 1 - .../data/experimental/UncompressElement.java | 3 + .../estimator/BoostedTreesCreateEnsemble.java | 1 - ...stedTreesCreateQuantileStreamResource.java | 1 - .../BoostedTreesDeserializeEnsemble.java | 1 - ...eesQuantileStreamResourceAddSummaries.java | 1 - ...reesQuantileStreamResourceDeserialize.java | 1 - ...ostedTreesQuantileStreamResourceFlush.java | 1 - .../estimator/BoostedTreesUpdateEnsemble.java | 1 - .../BoostedTreesUpdateEnsembleV2.java | 1 - .../tensorflow/op/image/ExtractGlimpse.java | 2 +- .../java/org/tensorflow/op/io/QueueClose.java | 1 - .../org/tensorflow/op/io/QueueEnqueue.java | 1 - .../tensorflow/op/io/QueueEnqueueMany.java | 1 - .../org/tensorflow/op/io/ReaderReset.java | 1 - .../tensorflow/op/io/ReaderRestoreState.java | 1 - .../java/org/tensorflow/op/io/WriteFile.java | 1 - .../op/linalg/BandedTriangularSolve.java | 3 + .../java/org/tensorflow/op/math/BesselI0.java | 3 + .../java/org/tensorflow/op/math/BesselI1.java | 3 + .../org/tensorflow/op/math/DenseBincount.java | 3 + .../tensorflow/op/math/special/BesselJ0.java | 3 + .../tensorflow/op/math/special/BesselJ1.java | 3 + .../tensorflow/op/math/special/BesselK0.java | 3 + .../tensorflow/op/math/special/BesselK0e.java | 3 + .../tensorflow/op/math/special/BesselK1.java | 3 + .../tensorflow/op/math/special/BesselK1e.java | 3 + .../tensorflow/op/math/special/BesselY0.java | 3 + .../tensorflow/op/math/special/BesselY1.java | 3 + .../tensorflow/op/ragged/RaggedBincount.java | 3 + .../op/ragged/RaggedCountSparseOutput.java | 3 + .../org/tensorflow/op/ragged/RaggedCross.java | 3 + .../op/random/AnonymousSeedGenerator.java | 3 + .../op/random/DeleteRandomSeedGenerator.java | 1 - .../op/random/DeleteSeedGenerator.java | 2 + .../org/tensorflow/op/random/RngSkip.java | 1 - ...StatelessParameterizedTruncatedNormal.java | 3 + .../experimental/DummySeedGenerator.java | 3 + .../op/sparse/DenseCountSparseOutput.java | 3 + .../SparseAccumulatorApplyGradient.java | 1 - .../tensorflow/op/sparse/SparseBincount.java | 3 + .../op/sparse/SparseCountSparseOutput.java | 3 + .../org/tensorflow/op/sparse/SparseCross.java | 2 +- .../op/sparse/SparseCrossHashed.java | 3 + .../op/summary/CloseSummaryWriter.java | 1 - .../op/summary/CreateSummaryDbWriter.java | 1 - .../op/summary/CreateSummaryFileWriter.java | 1 - .../op/summary/FlushSummaryWriter.java | 1 - .../tensorflow/op/summary/ImportEvent.java | 1 - .../op/summary/WriteAudioSummary.java | 1 - .../op/summary/WriteGraphSummary.java | 1 - .../op/summary/WriteHistogramSummary.java | 1 - .../op/summary/WriteImageSummary.java | 1 - .../op/summary/WriteRawProtoSummary.java | 1 - .../op/summary/WriteScalarSummary.java | 1 - .../tensorflow/op/summary/WriteSummary.java | 1 - .../op/tpu/ConfigureTPUEmbedding.java | 1 - .../tpu/EnqueueTPUEmbeddingIntegerBatch.java | 1 - .../EnqueueTPUEmbeddingRaggedTensorBatch.java | 2 + .../tpu/EnqueueTPUEmbeddingSparseBatch.java | 1 - .../EnqueueTPUEmbeddingSparseTensorBatch.java | 1 - .../org/tensorflow/op/tpu/InfeedEnqueue.java | 1 - .../tpu/InfeedEnqueuePrelinearizedBuffer.java | 1 - .../tensorflow/op/tpu/InfeedEnqueueTuple.java | 1 - .../tpu/LoadTPUEmbeddingADAMParameters.java | 1 - ...EmbeddingADAMParametersGradAccumDebug.java | 1 - .../LoadTPUEmbeddingAdadeltaParameters.java | 1 - ...ddingAdadeltaParametersGradAccumDebug.java | 1 - .../LoadTPUEmbeddingAdagradParameters.java | 1 - ...eddingAdagradParametersGradAccumDebug.java | 1 - ...TPUEmbeddingCenteredRMSPropParameters.java | 1 - .../tpu/LoadTPUEmbeddingFTRLParameters.java | 1 - ...EmbeddingFTRLParametersGradAccumDebug.java | 1 - ...TPUEmbeddingMDLAdagradLightParameters.java | 1 - .../LoadTPUEmbeddingMomentumParameters.java | 1 - ...ddingMomentumParametersGradAccumDebug.java | 1 - ...TPUEmbeddingProximalAdagradParameters.java | 1 - ...oximalAdagradParametersGradAccumDebug.java | 1 - ...oadTPUEmbeddingProximalYogiParameters.java | 2 + ...gProximalYogiParametersGradAccumDebug.java | 2 + .../LoadTPUEmbeddingRMSPropParameters.java | 1 - ...eddingRMSPropParametersGradAccumDebug.java | 1 - ...ngStochasticGradientDescentParameters.java | 1 - ...adientDescentParametersGradAccumDebug.java | 2 + .../org/tensorflow/op/tpu/OutfeedEnqueue.java | 1 - .../op/tpu/OutfeedEnqueueTuple.java | 1 - ...eveTPUEmbeddingProximalYogiParameters.java | 3 + ...gProximalYogiParametersGradAccumDebug.java | 3 + ...adientDescentParametersGradAccumDebug.java | 3 + .../op/tpu/SendTPUEmbeddingGradients.java | 1 - .../op/tpu/ShutdownDistributedTPU.java | 1 - .../op/tpu/TPUReplicateMetadata.java | 1 - .../op/train/AccumulatorApplyGradient.java | 1 - .../op/train/AccumulatorSetGlobalStep.java | 1 - .../op/train/MergeV2Checkpoints.java | 1 - .../org/tensorflow/op/train/NegTrain.java | 1 - .../ResourceAccumulatorApplyGradient.java | 1 - .../ResourceAccumulatorSetGlobalStep.java | 1 - .../op/train/ResourceApplyAdaMax.java | 1 - .../op/train/ResourceApplyAdadelta.java | 1 - .../op/train/ResourceApplyAdagrad.java | 1 - .../op/train/ResourceApplyAdagradDa.java | 1 - .../op/train/ResourceApplyAdam.java | 1 - .../train/ResourceApplyAdamWithAmsgrad.java | 1 - .../op/train/ResourceApplyAddSign.java | 1 - .../train/ResourceApplyCenteredRmsProp.java | 1 - .../op/train/ResourceApplyFtrl.java | 8 +- .../train/ResourceApplyGradientDescent.java | 1 - .../op/train/ResourceApplyKerasMomentum.java | 1 - .../op/train/ResourceApplyMomentum.java | 1 - .../op/train/ResourceApplyPowerSign.java | 1 - .../train/ResourceApplyProximalAdagrad.java | 1 - .../ResourceApplyProximalGradientDescent.java | 1 - .../op/train/ResourceApplyRmsProp.java | 1 - .../op/train/ResourceSparseApplyAdadelta.java | 1 - .../op/train/ResourceSparseApplyAdagrad.java | 1 - .../train/ResourceSparseApplyAdagradDa.java | 1 - .../train/ResourceSparseApplyAdagradV2.java | 1 - .../ResourceSparseApplyCenteredRmsProp.java | 1 - .../op/train/ResourceSparseApplyFtrl.java | 7 +- .../ResourceSparseApplyKerasMomentum.java | 1 - .../op/train/ResourceSparseApplyMomentum.java | 1 - .../ResourceSparseApplyProximalAdagrad.java | 1 - ...rceSparseApplyProximalGradientDescent.java | 1 - .../op/train/ResourceSparseApplyRmsProp.java | 1 - .../java/org/tensorflow/op/train/Save.java | 1 - .../org/tensorflow/op/train/SaveSlices.java | 1 - .../org/tensorflow/op/train/SdcaShrinkL1.java | 1 - .../gen/java/org/tensorflow/op/xla/Send.java | 1 - .../main/java/org/tensorflow/op/core/NN.java | 379 --------------- .../op/nn/SigmoidCrossEntropyWithLogits.java | 108 +++++ .../op/nn/SoftmaxCrossEntropyWithLogits.java | 214 +++++++++ .../SparseSoftmaxCrossEntropyWithLogits.java | 161 +++++++ .../main/java/org/tensorflow/types/TBool.java | 9 +- .../java/org/tensorflow/types/TString.java | 11 +- .../framework/optimizers/Momentum.java | 163 ++++--- .../framework/optimizers/Nadam.java | 295 ++++++++++++ .../framework/optimizers/Optimizer.java | 189 +++++--- .../framework/optimizers/Optimizers.java | 41 ++ .../framework/optimizers/RMSProp.java | 223 +++++---- .../schedules/PiecewiseConstantDecay.java | 58 +++ .../optimizers/schedules/PolynomialDecay.java | 127 +++++ .../framework/optimizers/MomentumTest.java | 182 +++---- .../framework}/optimizers/NadamTest.java | 307 ++++++------ .../framework/optimizers/OptimizersTest.java | 134 ++++++ .../framework/optimizers/RMSPropTest.java | 450 ++++++++++++++++++ .../schedules/PiecewiseConstantDecayTest.java | 16 + .../schedules/PolynomialDecayTest.java | 24 + .../org/tensorflow/framework}/utils/ND.java | 38 +- .../framework}/utils/TestSession.java | 261 +++++----- .../tensorflow/keras/optimizers/Nadam.java | 429 ----------------- .../keras/optimizers/OptimizerInterface.java | 49 -- .../keras/optimizers/Optimizers.java | 125 ----- .../tensorflow/keras/optimizers/RMSProp.java | 188 -------- .../org/tensorflow/keras/optimizers/SGD.java | 188 -------- .../keras/optimizers/RMSPropTest.java | 444 ----------------- 220 files changed, 2601 insertions(+), 2651 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/NN.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizers.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java rename tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java => tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java (58%) rename {tensorflow-keras/src/test/java/org/tensorflow/keras => tensorflow-framework/src/test/java/org/tensorflow/framework}/optimizers/NadamTest.java (50%) create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java rename {tensorflow-keras/src/test/java/org/tensorflow/keras => tensorflow-framework/src/test/java/org/tensorflow/framework}/utils/ND.java (96%) rename {tensorflow-keras/src/test/java/org/tensorflow/keras => tensorflow-framework/src/test/java/org/tensorflow/framework}/utils/TestSession.java (82%) delete mode 100644 tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java delete mode 100644 tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java delete mode 100644 tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java delete mode 100644 tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java delete mode 100644 tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java delete mode 100644 tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc index 03db4be125b..843f3bdb247 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc +++ b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc @@ -514,11 +514,15 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, Javadoc name_javadoc = Javadoc::Create("The name of this op, as known by TensorFlow core engine"); string quoted_string = "\"" + op.graph_op_name() + "\""; writer.WriteFieldWithInitializer(nameVariable, PUBLIC|STATIC|FINAL, &name_javadoc, quoted_string ); - writer.EndLine(); - for (const ArgumentSpec& output : op.outputs()) { - writer.WriteField(output.var(), PRIVATE); + + if(!op.outputs().empty()) { + writer.EndLine(); + for (const ArgumentSpec& output : op.outputs()) { + writer.WriteField(output.var(), PRIVATE); + } } + RenderConstructor(op, op_class, &writer); writer.EndType(); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java index 8374a864ec2..33caf02d890 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java @@ -20,7 +20,6 @@ import java.util.List; import org.tensorflow.DataType; import org.tensorflow.Operand; -import org.tensorflow.op.core.NN; import org.tensorflow.op.nn.AvgPool; import org.tensorflow.op.nn.AvgPool3d; import org.tensorflow.op.nn.AvgPool3dGrad; @@ -84,10 +83,13 @@ import org.tensorflow.op.nn.Relu; import org.tensorflow.op.nn.Relu6; import org.tensorflow.op.nn.Selu; +import org.tensorflow.op.nn.SigmoidCrossEntropyWithLogits; import org.tensorflow.op.nn.Softmax; +import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; import org.tensorflow.op.nn.Softsign; import org.tensorflow.op.nn.SpaceToBatch; import org.tensorflow.op.nn.SpaceToDepth; +import org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits; import org.tensorflow.op.nn.TopK; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -1756,49 +1758,53 @@ public <T extends TNumber> Selu<T> selu(Operand<T> features) { } /** - * Computes sigmoid cross entropy given `logits`. + * Computes sigmoid cross entropy given <code>logits</code>. * * <p>Measures the probability error in discrete classification tasks in which each class is * independent and not mutually exclusive. For instance, one could perform multilabel * classification where a picture can contain both an elephant and a dog at the same time. * - * <p>For brevity, let `x = logits`, `z = labels`. The logistic loss is + * <p>For brevity, let <code>x = logits</code>, <code>z = labels</code>. The logistic loss in + * pseudo-code is * * <pre> - * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) - * = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x))) - * = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x))) - * = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)) - * = (1 - z) * x + log(1 + exp(-x)) - * = x - x * z + log(1 + exp(-x)) + * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) + * = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x))) + * = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x))) + * = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)) + * = (1 - z) * x + log(1 + exp(-x)) + * = x - x * z + log(1 + exp(-x)) * </pre> * - * <p>For x < 0, to avoid overflow in exp(-x), we reformulate the above + * <p>For <code>x < 0</code>, to avoid overflow in <code>exp(-x)</code>, we reformulate the above * * <pre> - * x - x * z + log(1 + exp(-x)) - * = log(exp(x)) - x * z + log(1 + exp(-x)) - * = - x * z + log(1 + exp(x)) + * x - x * z + log(1 + exp(-x)) + * = log(exp(x)) - x * z + log(1 + exp(-x)) + * = - x * z + log(1 + exp(x)) * </pre> * * <p>Hence, to ensure stability and avoid overflow, the implementation uses this equivalent * formulation * * <pre> - * max(x, 0) - x * z + log(1 + exp(-abs(x))) + * max(x, 0) - x * z + log(1 + exp(-abs(x))) * </pre> * - * <p>`logits` and `labels` must have the same type and shape. + * <p></ode>logits</code> and <code>labels</code> must have the same type and shape. + * + * <p> * * @param scope The TensorFlow scope * @param labels the labels * @param logits the logits of type float32 or float64 * @param <T> the type of labels and logits * @return the component-wise logistic losses. + * @throws IllegalArgumentException if logits' and labels' do not have the same shape */ public <T extends TNumber> Operand<T> sigmoidCrossEntropyWithLogits(Operand<T> labels, Operand<T> logits) { - return NN.sigmoidCrossEntropyWithLogits(scope, labels, logits); + return SigmoidCrossEntropyWithLogits.sigmoidCrossEntropyWithLogits(scope, labels, logits); } /** @@ -1817,48 +1823,54 @@ public <T extends TNumber> Softmax<T> softmax(Operand<T> logits) { } /** - * Computes softmax cross entropy between `logits` and `labels`. + * Computes softmax cross entropy between <code>logits</code> and <code>labels</code>. * * <p>Measures the probability error in discrete classification tasks in which the classes are * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is * labeled with one and only one label: an image can be a dog or a truck, but not both. * - * <p>**NOTE:** While the classes are mutually exclusive, their probabilities need not be. All - * that is required is that each row of `labels` is a valid probability distribution. If they are - * not, the computation of the gradient will be incorrect. + * <p><b>NOTE:</b> + * + * <p>While the classes are mutually exclusive, their probabilities need not be. All that is + * required is that each row of <code>labels</code> is a valid probability distribution. If they + * are not, the computation of the gradient will be incorrect. * - * <p>If using exclusive `labels` (wherein one and only one class is true at a time), see - * `sparse_softmax_cross_entropy_with_logits`. + * <p>If using exclusive <code>labels</code> (wherein one and only one class is true at a time), + * see {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits} * * <p>Usage: * * <pre> - * >>> logits = [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]] - * >>> labels = [[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]] - * >>> tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits) - * <tf.Tensor: shape=(2,), dtype=float32, - * numpy=array([0.16984604, 0.82474494], dtype=float32)> + * Operand<TFloat32> logits = + * tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} ); + * Operand<TFloat32> labels = + * tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} ); + * Operand<TFloat32> output = + * tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1); + * // output Shape = [2] + * // dataType = FLOAT (1) + * // values { 0.169846, 0.824745 } * </pre> * - * <p>Backpropagation will happen into both `logits` and `labels`. To disallow backpropagation - * into `labels`, pass label tensors through `tf.stop_gradient` before feeding it to this - * function. + * <p>Backpropagation will happen into both <code>logits</code> and <code>labels</code>. To + * disallow backpropagation into <code>labels</code>, pass label tensors through <code> + * tf.stopGradient</code> before feeding it to this function. * * @param scope current scope * @param labels Each vector along the class dimension should hold a valid probability - * distribution e.g. for the case in which labels are of shape `[batch_size, num_classes]`, - * each row of `labels[i]` must be a valid probability distribution. + * distribution e.g. for the case in which labels are of shape <code>[batch_size, num_classes] + * </code>, each row of <code>labels[i]</code> must be a valid probability distribution. * @param logits Per-label activations, typically a linear output. These activation energies are * interpreted as unnormalized log probabilities. * @param axis The class dimension. -1 is the last dimension. - * @param <U> the data type of the logits * @param <T> the number type of the operands - * @return the softmax cross entropy loss. Its type is the same as `logits` and its shape is the - * same as `labels` except that it does not have the last dimension of `labels`. + * @return the softmax cross entropy loss. Its type is the same as <code>logits</code> and its + * shape is the same as <code>labels</code> except that it does not have the last dimension of + * <code>labels</code>. */ - public <U extends TType, T extends TNumber> Operand<T> softmaxCrossEntropyWithLogits( - Operand<T> labels, Operand<U> logits, int axis) { - return NN.softmaxCrossEntropyWithLogits(scope, labels, logits, axis); + public <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntropyWithLogits( + Operand<U> labels, Operand<T> logits, int axis) { + return SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(scope, labels, logits, axis); } /** @@ -2050,22 +2062,51 @@ public <T extends TType> SpaceToDepth<T> spaceToDepth(Operand<T> input, Long blo } /** - * Computes sparse softmax cross entropy between `logits` and `labels`. + * Computes sparse softmax cross entropy between <code>logits</code> and <code>labels</code>. + * + * <p>Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + * <p><b>NOTE:</b> + * + * <p>For this operation, the probability of a given label is considered exclusive. That is, soft + * classes are not allowed, and the <code>labels</code> vector must provide a single specific + * index for the true class for each row of <code>logits</code> (each minibatch entry). For soft + * softmax classification with a probability distribution for each entry, {@link + * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}. + * + * <p><b>WARNING:</b> + * + * <p>This op expects unscaled logits, since it performs a <code>softmax</code> on <code>logits + * </code> internally for efficiency. Do not call this op with the output of <code>softmax</code>, + * as it will produce incorrect results. + * + * <p>A common use case is to have logits of shape <code>[batchSize, numClasses]</code> and have + * labels of shape <code>[batchSize]</code>, but higher dimensions are supported, in which case + * the <code>dim</code>-th dimension is assumed to be of size <code>numClasses</code>. <code> + * logits</code> must have the <cod>dataType</cod> of <code>TFloat16</code>, <code>TFloat32</code> + * , or <code>TFloat64</code>, and <code>labels</code> must have the dtype of <code>TInt32</code> + * or <code>TInt64</code>. * * @param scope current scope - * @param labels `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of `labels` and - * result) and dtype `int32` or `int64`. Each entry in `labels` must be an index in `[0, - * num_classes)`. Other values will raise an exception when this op is run on CPU, and return - * `NaN` for corresponding loss and gradient rows on GPU. - * @param logits Per-label activations (typically a linear output) of shape `[d_0, d_1, ..., - * d_{r-1}, num_classes]` and dtype `float16`, `float32`, or `float64`. These activation - * energies are interpreted as unnormalized log probabilities. - * @return A `Tensor` of the same shape as `labels` and of the same type as `logits` with the - * softmax cross entropy loss. + * @param labels <code>Tensor</code> of shape <code>[d_0, d_1, ..., d_{r-1}]</code> (where <code>r + * </code> is rank of <code>labels</code> and result) and the dataType is <code>TInt32</code> + * or <code>TInt64</code>. Each entry in <code>labels</code> must be an index in <code>[0, + * numClasses)</code>. Other values will raise an exception when this op is run on CPU, and + * return <code>NaN</code> for corresponding loss and gradient rows on GPU. + * @param logits Per-label activations (typically a linear output) of shape <code>[d_0, d_1, ..., + * d_{r-1}, numClasses]</code> and dataType of <code>TFloat16</code>, <code>TFloat32</code>, + * or <code>TFloat64</code>. These activation energies are interpreted as unnormalized log + * probabilities. + * @return A <code>Tensor</code> of the same shape as <code>labels</code> and of the same type as + * <code>logits</code> with the softmax cross entropy loss. + * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank + * of the labels is not equal to the rank of the logits minus one. */ public <T extends TNumber, U extends TNumber> Operand sparseSoftmaxCrossEntropyWithLogits( Operand<T> labels, Operand<U> logits) { - return NN.sparseSoftmaxCrossEntropyWithLogits(scope, labels, logits); + return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits(scope, labels, logits); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Abort.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Abort.java index 53e9401dfa2..a84f2405b19 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Abort.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Abort.java @@ -104,7 +104,6 @@ public static Options exitWithoutError(Boolean exitWithoutError) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "Abort"; - private Abort(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssertThat.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssertThat.java index 950830b7462..dce70c04e5a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssertThat.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssertThat.java @@ -90,7 +90,6 @@ public static Options summarize(Long summarize) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "Assert"; - private AssertThat(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignAddVariableOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignAddVariableOp.java index 5adaccf15e0..53edc808882 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignAddVariableOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignAddVariableOp.java @@ -55,7 +55,6 @@ public static <T extends TType> AssignAddVariableOp create(Scope scope, Operand< /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "AssignAddVariableOp"; - private AssignAddVariableOp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignSubVariableOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignSubVariableOp.java index 4bb683c97d2..372a71b2168 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignSubVariableOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignSubVariableOp.java @@ -55,7 +55,6 @@ public static <T extends TType> AssignSubVariableOp create(Scope scope, Operand< /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "AssignSubVariableOp"; - private AssignSubVariableOp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignVariableOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignVariableOp.java index 90cabd12a24..ac08d62f9a8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignVariableOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignVariableOp.java @@ -55,7 +55,6 @@ public static <T extends TType> AssignVariableOp create(Scope scope, Operand<?> /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "AssignVariableOp"; - private AssignVariableOp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierClose.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierClose.java index a777d684ec1..514f4f50edf 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierClose.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierClose.java @@ -95,7 +95,6 @@ public static Options cancelPendingEnqueues(Boolean cancelPendingEnqueues) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BarrierClose"; - private BarrierClose(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierInsertMany.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierInsertMany.java index 31488738838..b652c11a35c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierInsertMany.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierInsertMany.java @@ -63,7 +63,6 @@ public static <T extends TType> BarrierInsertMany create(Scope scope, Operand<TS /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BarrierInsertMany"; - private BarrierInsertMany(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ConsumeMutexLock.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ConsumeMutexLock.java index 39dd4ffc7d9..094f6d5e4b0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ConsumeMutexLock.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ConsumeMutexLock.java @@ -57,7 +57,6 @@ public static ConsumeMutexLock create(Scope scope, Operand<?> mutexLock) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ConsumeMutexLock"; - private ConsumeMutexLock(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ControlTrigger.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ControlTrigger.java index 721112b8204..e40715c9f2c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ControlTrigger.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ControlTrigger.java @@ -48,7 +48,6 @@ public static ControlTrigger create(Scope scope) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ControlTrigger"; - private ControlTrigger(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeleteSessionTensor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeleteSessionTensor.java index 50c7615a0ff..5f92cc26ca2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeleteSessionTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeleteSessionTensor.java @@ -50,7 +50,6 @@ public static DeleteSessionTensor create(Scope scope, Operand<TString> handle) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DeleteSessionTensor"; - private DeleteSessionTensor(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DestroyResourceOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DestroyResourceOp.java index e1958682ee1..8a427166874 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DestroyResourceOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DestroyResourceOp.java @@ -88,7 +88,6 @@ public static Options ignoreLookupError(Boolean ignoreLookupError) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DestroyResourceOp"; - private DestroyResourceOp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeviceIndex.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeviceIndex.java index 26f984e840d..f033d3fcc9d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeviceIndex.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeviceIndex.java @@ -68,6 +68,9 @@ public Output<TInt32> asOutput() { return index; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "DeviceIndex"; + private Output<TInt32> index; private DeviceIndex(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTable.java index 5de2ca6ff07..48662ed420d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTable.java @@ -54,7 +54,6 @@ public static <T extends TType, U extends TType> InitializeTable create(Scope sc /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "InitializeTableV2"; - private InitializeTable(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTableFromTextFile.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTableFromTextFile.java index 0a88cea3ef2..2050c4d8628 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTableFromTextFile.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTableFromTextFile.java @@ -121,7 +121,6 @@ public static Options delimiter(String delimiter) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "InitializeTableFromTextFileV2"; - private InitializeTableFromTextFile(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableImport.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableImport.java index a94393a50f1..9884a40e3cb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableImport.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableImport.java @@ -57,7 +57,6 @@ public static <T extends TType, U extends TType> LookupTableImport create(Scope /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LookupTableImportV2"; - private LookupTableImport(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableInsert.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableInsert.java index c31784ea942..0f09ae25d1b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableInsert.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableInsert.java @@ -57,7 +57,6 @@ public static <T extends TType, U extends TType> LookupTableInsert create(Scope /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LookupTableInsertV2"; - private LookupTableInsert(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableRemove.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableRemove.java index 584e7e1325c..41463ad7539 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableRemove.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableRemove.java @@ -54,7 +54,6 @@ public static <T extends TType> LookupTableRemove create(Scope scope, Operand<?> /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LookupTableRemoveV2"; - private LookupTableRemove(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java index ea7581ef2c7..bad1e90554f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java @@ -145,7 +145,6 @@ public static Options sharedName(String sharedName) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "MapClear"; - private MapClear(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java index 5d72ce8f22f..9291b32d53b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java @@ -160,7 +160,6 @@ public static Options sharedName(String sharedName) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "MapStage"; - private MapStage(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NoOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NoOp.java index 862aabcd795..922b5d55ce3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NoOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NoOp.java @@ -46,7 +46,6 @@ public static NoOp create(Scope scope) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "NoOp"; - private NoOp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java index 29f4133ce09..05a1b7ab984 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java @@ -145,7 +145,6 @@ public static Options sharedName(String sharedName) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "OrderedMapClear"; - private OrderedMapClear(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java index b51f94c148a..7e02973e3c6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java @@ -162,7 +162,6 @@ public static Options sharedName(String sharedName) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "OrderedMapStage"; - private OrderedMapStage(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Print.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Print.java index 3e96c00d369..52b933329a0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Print.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Print.java @@ -105,7 +105,6 @@ public static Options end(String end) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "PrintV2"; - private Print(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterAdd.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterAdd.java index 5383062823b..0966dd5fcc4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterAdd.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterAdd.java @@ -75,7 +75,6 @@ public static <T extends TNumber, U extends TType> ResourceScatterAdd create(Sco /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterAdd"; - private ResourceScatterAdd(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterDiv.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterDiv.java index ed950863242..9560bddf284 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterDiv.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterDiv.java @@ -75,7 +75,6 @@ public static <T extends TNumber, U extends TType> ResourceScatterDiv create(Sco /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterDiv"; - private ResourceScatterDiv(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMax.java index 7553fab4812..ce952ee19ba 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMax.java @@ -75,7 +75,6 @@ public static <T extends TNumber, U extends TType> ResourceScatterMax create(Sco /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterMax"; - private ResourceScatterMax(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMin.java index 68518b4c640..51ec6b7637e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMin.java @@ -75,7 +75,6 @@ public static <T extends TNumber, U extends TType> ResourceScatterMin create(Sco /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterMin"; - private ResourceScatterMin(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMul.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMul.java index f52b338de57..2d5f71e006d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMul.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMul.java @@ -75,7 +75,6 @@ public static <T extends TNumber, U extends TType> ResourceScatterMul create(Sco /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterMul"; - private ResourceScatterMul(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdAdd.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdAdd.java index 5abfcbea5ee..11e45c33098 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdAdd.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdAdd.java @@ -125,7 +125,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterNdAdd"; - private ResourceScatterNdAdd(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMax.java index e24e3d68fef..82c1f766308 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMax.java @@ -91,6 +91,8 @@ public static Options useLocking(Boolean useLocking) { return new Options().useLocking(useLocking); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "ResourceScatterNdMax"; private ResourceScatterNdMax(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMin.java index 3ffc78afa87..88e107c65c7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMin.java @@ -91,6 +91,8 @@ public static Options useLocking(Boolean useLocking) { return new Options().useLocking(useLocking); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "ResourceScatterNdMin"; private ResourceScatterNdMin(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdSub.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdSub.java index c4b6060d611..267099b7cfc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdSub.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdSub.java @@ -125,7 +125,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterNdSub"; - private ResourceScatterNdSub(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdUpdate.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdUpdate.java index b47fb4a1367..4a1e875bc97 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdUpdate.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdUpdate.java @@ -127,7 +127,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterNdUpdate"; - private ResourceScatterNdUpdate(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterSub.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterSub.java index 2559ff21a93..7b772fab997 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterSub.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterSub.java @@ -75,7 +75,6 @@ public static <T extends TNumber, U extends TType> ResourceScatterSub create(Sco /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterSub"; - private ResourceScatterSub(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterUpdate.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterUpdate.java index eff04c6c08a..067ddf5f205 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterUpdate.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterUpdate.java @@ -66,7 +66,6 @@ public static <T extends TNumber, U extends TType> ResourceScatterUpdate create( /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterUpdate"; - private ResourceScatterUpdate(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceStridedSliceAssign.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceStridedSliceAssign.java index 2002140573b..4deb4c55f64 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceStridedSliceAssign.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceStridedSliceAssign.java @@ -176,7 +176,6 @@ public static Options shrinkAxisMask(Long shrinkAxisMask) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceStridedSliceAssign"; - private ResourceStridedSliceAssign(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMax.java index 851cbb16cf4..da94c783cae 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMax.java @@ -107,6 +107,9 @@ public Output<T> asOutput() { return outputRef; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "ScatterNdMax"; + private Output<T> outputRef; private ScatterNdMax(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMin.java index a3e3d4c9790..5aea70bc929 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMin.java @@ -107,6 +107,9 @@ public Output<T> asOutput() { return outputRef; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "ScatterNdMin"; + private Output<T> outputRef; private ScatterNdMin(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Send.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Send.java index bf3db1cd88a..d679b85319a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Send.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Send.java @@ -97,7 +97,6 @@ public static Options clientTerminated(Boolean clientTerminated) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "Send"; - private Send(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Stage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Stage.java index 408b6eca252..526462b02f4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Stage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Stage.java @@ -151,7 +151,6 @@ public static Options sharedName(String sharedName) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "Stage"; - private Stage(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java index 755e7ab72d9..60e51559f74 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java @@ -145,7 +145,6 @@ public static Options sharedName(String sharedName) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "StageClear"; - private StageClear(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayClose.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayClose.java index a16e856ae72..62180e8e5ff 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayClose.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayClose.java @@ -52,7 +52,6 @@ public static TensorArrayClose create(Scope scope, Operand<?> handle) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "TensorArrayCloseV3"; - private TensorArrayClose(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestCreateTreeVariable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestCreateTreeVariable.java index e647f58b2f3..5ca6ffa1cd0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestCreateTreeVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestCreateTreeVariable.java @@ -51,7 +51,6 @@ public static TensorForestCreateTreeVariable create(Scope scope, Operand<?> tree /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "TensorForestCreateTreeVariable"; - private TensorForestCreateTreeVariable(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestTreeDeserialize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestTreeDeserialize.java index 5fb704b2361..a5e1638035e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestTreeDeserialize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestTreeDeserialize.java @@ -51,7 +51,6 @@ public static TensorForestTreeDeserialize create(Scope scope, Operand<?> treeHan /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "TensorForestTreeDeserialize"; - private TensorForestTreeDeserialize(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMax.java index d040cf0639a..a14b20195af 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMax.java @@ -65,6 +65,9 @@ public Output<T> asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "TensorScatterMax"; + private Output<T> output; private TensorScatterNdMax(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMin.java index 797878d9c76..b202b72eebd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMin.java @@ -65,6 +65,9 @@ public Output<T> asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "TensorScatterMin"; + private Output<T> output; private TensorScatterNdMin(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdFullToShardShape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdFullToShardShape.java index 51f2c2b5dde..6615d2ef9f6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdFullToShardShape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdFullToShardShape.java @@ -68,6 +68,9 @@ public Output<T> asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "XlaSpmdFullToShardShape"; + private Output<T> output; private XlaSpmdFullToShardShape(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdShardToFullShape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdShardToFullShape.java index 5a120fb6fb8..75e31c7c317 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdShardToFullShape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdShardToFullShape.java @@ -70,6 +70,9 @@ public Output<T> asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "XlaSpmdShardToFullShape"; + private Output<T> output; private XlaSpmdShardToFullShape(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToTfRecord.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToTfRecord.java index 41617c9b690..114e11074dc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToTfRecord.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToTfRecord.java @@ -54,7 +54,6 @@ public static DatasetToTfRecord create(Scope scope, Operand<?> inputDataset, Ope /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DatasetToTFRecord"; - private DatasetToTfRecord(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteIterator.java index ec3629a8eb7..69f3af096bb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteIterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteIterator.java @@ -51,7 +51,6 @@ public static DeleteIterator create(Scope scope, Operand<?> handle, Operand<?> d /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DeleteIterator"; - private DeleteIterator(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteMemoryCache.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteMemoryCache.java index 3c0f37dc409..21c33030b66 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteMemoryCache.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteMemoryCache.java @@ -49,7 +49,6 @@ public static DeleteMemoryCache create(Scope scope, Operand<?> handle, Operand<? /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DeleteMemoryCache"; - private DeleteMemoryCache(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteMultiDeviceIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteMultiDeviceIterator.java index fb77934d585..966d6a7dbf1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteMultiDeviceIterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteMultiDeviceIterator.java @@ -53,7 +53,6 @@ public static DeleteMultiDeviceIterator create(Scope scope, Operand<?> multiDevi /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DeleteMultiDeviceIterator"; - private DeleteMultiDeviceIterator(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeserializeIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeserializeIterator.java index 528fd09bc53..4f772fd5028 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeserializeIterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeserializeIterator.java @@ -52,7 +52,6 @@ public static DeserializeIterator create(Scope scope, Operand<?> resourceHandle, /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DeserializeIterator"; - private DeserializeIterator(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InitializeTableFromDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InitializeTableFromDataset.java index 527c951377b..05a263a4ec1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InitializeTableFromDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InitializeTableFromDataset.java @@ -46,6 +46,8 @@ public static InitializeTableFromDataset create(Scope scope, Operand<?> tableHan return new InitializeTableFromDataset(opBuilder.build()); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "InitializeTableFromDataset"; private InitializeTableFromDataset(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MakeIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MakeIterator.java index 685574a92d5..4aace25184e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MakeIterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MakeIterator.java @@ -54,7 +54,6 @@ public static MakeIterator create(Scope scope, Operand<?> dataset, Operand<?> it /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "MakeIterator"; - private MakeIterator(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RegisterDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RegisterDataset.java index 7e0695768c2..5705413165c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RegisterDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RegisterDataset.java @@ -65,6 +65,9 @@ public Output<TInt64> asOutput() { return datasetId; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "RegisterDataset"; + private Output<TInt64> datasetId; private RegisterDataset(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java index 1f6d2697497..c5703e8e85c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java @@ -119,7 +119,7 @@ public Output<TType> asOutput() { } /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "ShuffleAndRepeatDataset"; + public static final String OP_NAME = "ShuffleAndRepeatDatasetV2"; private Output<?> handle; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java index ce3de6f787f..3dd522e319c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java @@ -117,7 +117,7 @@ public Output<TType> asOutput() { } /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "ShuffleDatasetV2"; + public static final String OP_NAME = "ShuffleDatasetV3"; private Output<?> handle; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/CompressElement.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/CompressElement.java index e56a8cde614..9e4bfb34a8b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/CompressElement.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/CompressElement.java @@ -60,6 +60,9 @@ public Output<TType> asOutput() { return (Output<TType>) compressed; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "CompressElement"; + private Output<?> compressed; private CompressElement(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java index 4623ec2ea5d..b3e853f9de4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java @@ -122,6 +122,9 @@ public Output<TType> asOutput() { return (Output<TType>) handle; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "DataServiceDataset"; + private Output<?> handle; private DataServiceDataset(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DatasetToTFRecord.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DatasetToTFRecord.java index ee0f494fa73..6e0a0b8f2dc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DatasetToTFRecord.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DatasetToTFRecord.java @@ -54,7 +54,6 @@ public static DatasetToTFRecord create(Scope scope, Operand<?> inputDataset, Ope /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ExperimentalDatasetToTFRecord"; - private DatasetToTFRecord(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DummyIterationCounter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DummyIterationCounter.java index b2febce6f81..72f83847285 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DummyIterationCounter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DummyIterationCounter.java @@ -56,6 +56,9 @@ public Output<TType> asOutput() { return (Output<TType>) handle; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "DummyIterationCounter"; + private Output<?> handle; private DummyIterationCounter(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorSetSummaryWriter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorSetSummaryWriter.java index ad63ff056d8..1af246d8313 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorSetSummaryWriter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorSetSummaryWriter.java @@ -50,7 +50,6 @@ public static StatsAggregatorSetSummaryWriter create(Scope scope, Operand<?> sta /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "StatsAggregatorSetSummaryWriter"; - private StatsAggregatorSetSummaryWriter(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java index 5d94d8699ab..c5732154e94 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java @@ -76,6 +76,9 @@ public Iterator<Operand<TType>> iterator() { return (Iterator) components.iterator(); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "UncompressElement"; + private List<Output<?>> components; private UncompressElement(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateEnsemble.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateEnsemble.java index 04613900567..8841988b36d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateEnsemble.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateEnsemble.java @@ -54,7 +54,6 @@ public static BoostedTreesCreateEnsemble create(Scope scope, Operand<?> treeEnse /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesCreateEnsemble"; - private BoostedTreesCreateEnsemble(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateQuantileStreamResource.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateQuantileStreamResource.java index 59362970a58..802a61ecb2a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateQuantileStreamResource.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateQuantileStreamResource.java @@ -88,7 +88,6 @@ public static Options maxElements(Long maxElements) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesCreateQuantileStreamResource"; - private BoostedTreesCreateQuantileStreamResource(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesDeserializeEnsemble.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesDeserializeEnsemble.java index 6fd83d5785d..15371fb4df9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesDeserializeEnsemble.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesDeserializeEnsemble.java @@ -56,7 +56,6 @@ public static BoostedTreesDeserializeEnsemble create(Scope scope, Operand<?> tre /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesDeserializeEnsemble"; - private BoostedTreesDeserializeEnsemble(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceAddSummaries.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceAddSummaries.java index 76480c4be6d..418ff3b2ff6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceAddSummaries.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceAddSummaries.java @@ -56,7 +56,6 @@ public static BoostedTreesQuantileStreamResourceAddSummaries create(Scope scope, /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesQuantileStreamResourceAddSummaries"; - private BoostedTreesQuantileStreamResourceAddSummaries(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceDeserialize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceDeserialize.java index 82066e267d2..6efb58ed60c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceDeserialize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceDeserialize.java @@ -54,7 +54,6 @@ public static BoostedTreesQuantileStreamResourceDeserialize create(Scope scope, /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesQuantileStreamResourceDeserialize"; - private BoostedTreesQuantileStreamResourceDeserialize(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceFlush.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceFlush.java index 359b7b63ff6..cc10434a582 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceFlush.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceFlush.java @@ -97,7 +97,6 @@ public static Options generateQuantiles(Boolean generateQuantiles) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesQuantileStreamResourceFlush"; - private BoostedTreesQuantileStreamResourceFlush(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsemble.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsemble.java index c1b7bb44559..e6ddcf3d2da 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsemble.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsemble.java @@ -79,7 +79,6 @@ public static BoostedTreesUpdateEnsemble create(Scope scope, Operand<?> treeEnse /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesUpdateEnsemble"; - private BoostedTreesUpdateEnsemble(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsembleV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsembleV2.java index afd9d646c2f..ceaff116fd1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsembleV2.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsembleV2.java @@ -118,7 +118,6 @@ public static Options logitsDimension(Long logitsDimension) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesUpdateEnsembleV2"; - private BoostedTreesUpdateEnsembleV2(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractGlimpse.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractGlimpse.java index 172b24f74ac..05bc1d924b3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractGlimpse.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractGlimpse.java @@ -199,7 +199,7 @@ public Output<TFloat32> asOutput() { } /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "ExtractGlimpse"; + public static final String OP_NAME = "ExtractGlimpseV2"; private Output<TFloat32> glimpse; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueClose.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueClose.java index fb55f708140..ea7791f143e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueClose.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueClose.java @@ -91,7 +91,6 @@ public static Options cancelPendingEnqueues(Boolean cancelPendingEnqueues) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "QueueCloseV2"; - private QueueClose(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueue.java index 546981e8abf..a159b0cd17c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueue.java @@ -96,7 +96,6 @@ public static Options timeoutMs(Long timeoutMs) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "QueueEnqueueV2"; - private QueueEnqueue(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueueMany.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueueMany.java index 48df5d3b9d3..b1f9cbd6807 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueueMany.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueueMany.java @@ -101,7 +101,6 @@ public static Options timeoutMs(Long timeoutMs) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "QueueEnqueueManyV2"; - private QueueEnqueueMany(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderReset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderReset.java index 6e3de01134b..243d4a72080 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderReset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderReset.java @@ -49,7 +49,6 @@ public static ReaderReset create(Scope scope, Operand<?> readerHandle) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ReaderResetV2"; - private ReaderReset(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderRestoreState.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderRestoreState.java index b0abea1257c..431ba079ffc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderRestoreState.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderRestoreState.java @@ -56,7 +56,6 @@ public static ReaderRestoreState create(Scope scope, Operand<?> readerHandle, Op /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ReaderRestoreStateV2"; - private ReaderRestoreState(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/WriteFile.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/WriteFile.java index d9fba243fc3..d1c9dd9b9c8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/WriteFile.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/WriteFile.java @@ -54,7 +54,6 @@ public static WriteFile create(Scope scope, Operand<TString> filename, Operand<T /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "WriteFile"; - private WriteFile(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BandedTriangularSolve.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BandedTriangularSolve.java index 6f96958c39e..8c1e184e6c9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BandedTriangularSolve.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BandedTriangularSolve.java @@ -113,6 +113,9 @@ public Output<T> asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BandedTriangularSolve"; + private Output<T> output; private BandedTriangularSolve(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java index 53e1ac83c32..45dcd2b8e4c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java @@ -59,6 +59,9 @@ public Output<T> asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselI0"; + private Output<T> y; private BesselI0(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java index 638f6b06972..148758aa5a4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java @@ -59,6 +59,9 @@ public Output<T> asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselI1"; + private Output<T> y; private BesselI1(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java index 165be081102..e38d559f4ae 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java @@ -112,6 +112,9 @@ public Output<U> asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "DenseBincount"; + private Output<U> output; private DenseBincount(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java index bc73a0c9c02..8d2184a49cb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java @@ -59,6 +59,9 @@ public Output<T> asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselJ0"; + private Output<T> y; private BesselJ0(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java index 4fd21c42288..d8f9621a36c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java @@ -59,6 +59,9 @@ public Output<T> asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselJ1"; + private Output<T> y; private BesselJ1(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java index 8f3c540b185..eaae243f83f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java @@ -59,6 +59,9 @@ public Output<T> asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselK0"; + private Output<T> y; private BesselK0(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java index 1a8f9761c08..c57ae64e233 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java @@ -59,6 +59,9 @@ public Output<T> asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselK0e"; + private Output<T> y; private BesselK0e(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java index bcaaf6f6f9c..1858d25fe3d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java @@ -59,6 +59,9 @@ public Output<T> asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselK1"; + private Output<T> y; private BesselK1(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java index c6590805d54..e4a5cc23efd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java @@ -59,6 +59,9 @@ public Output<T> asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselK1e"; + private Output<T> y; private BesselK1e(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java index 86843a30939..9228d1b6145 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java @@ -59,6 +59,9 @@ public Output<T> asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselY0"; + private Output<T> y; private BesselY0(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java index 2cdc4ad7df0..0461416b808 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java @@ -59,6 +59,9 @@ public Output<T> asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselY1"; + private Output<T> y; private BesselY1(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java index fc1636d8d64..1e0224aa9ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java @@ -115,6 +115,9 @@ public Output<U> asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "RaggedBincount"; + private Output<U> output; private RaggedBincount(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java index 07b364f6ebb..4829e49488b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java @@ -140,6 +140,9 @@ public Output<TInt64> outputDenseShape() { return outputDenseShape; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "RaggedCountSparseOutput"; + private Output<TInt64> outputIndices; private Output<U> outputValues; private Output<TInt64> outputDenseShape; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java index fa6e811969b..9ea32878257 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java @@ -94,6 +94,9 @@ public Output<U> outputRowSplits() { return outputRowSplits; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "RaggedCross"; + private Output<T> outputValues; private Output<U> outputRowSplits; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/AnonymousSeedGenerator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/AnonymousSeedGenerator.java index f55c3222977..c724bb6d110 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/AnonymousSeedGenerator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/AnonymousSeedGenerator.java @@ -63,6 +63,9 @@ public Output<?> deleter() { return deleter; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "AnonymousSeedGenerator"; + private Output<?> handle; private Output<?> deleter; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteRandomSeedGenerator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteRandomSeedGenerator.java index 9bc34d98ebb..23b154f9d75 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteRandomSeedGenerator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteRandomSeedGenerator.java @@ -49,7 +49,6 @@ public static DeleteRandomSeedGenerator create(Scope scope, Operand<?> handle, O /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DeleteRandomSeedGenerator"; - private DeleteRandomSeedGenerator(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteSeedGenerator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteSeedGenerator.java index 2872ea12aff..16982946d1f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteSeedGenerator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteSeedGenerator.java @@ -46,6 +46,8 @@ public static DeleteSeedGenerator create(Scope scope, Operand<?> handle, Operand return new DeleteSeedGenerator(opBuilder.build()); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "DeleteSeedGenerator"; private DeleteSeedGenerator(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RngSkip.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RngSkip.java index e3411c3b989..f41cff35b04 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RngSkip.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RngSkip.java @@ -58,7 +58,6 @@ public static RngSkip create(Scope scope, Operand<?> resource, Operand<TInt64> a /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "RngSkip"; - private RngSkip(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java index 053db14c986..179160463c7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java @@ -72,6 +72,9 @@ public Output<V> asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "StatelessParameterizedTruncatedNormal"; + private Output<V> output; private StatelessParameterizedTruncatedNormal(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java index 92e58ba293f..dd537fa2d68 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java @@ -56,6 +56,9 @@ public Output<TType> asOutput() { return (Output<TType>) handle; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "DummySeedGenerator"; + private Output<?> handle; private DummySeedGenerator(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java index 62c489cab7b..ed390a7ba47 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java @@ -132,6 +132,9 @@ public Output<TInt64> outputDenseShape() { return outputDenseShape; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "DenseCountSparseOutput"; + private Output<TInt64> outputIndices; private Output<U> outputValues; private Output<TInt64> outputDenseShape; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorApplyGradient.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorApplyGradient.java index 3f36812ea57..328fe0c49ea 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorApplyGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorApplyGradient.java @@ -69,7 +69,6 @@ public static <T extends TType> SparseAccumulatorApplyGradient create(Scope scop /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "SparseAccumulatorApplyGradient"; - private SparseAccumulatorApplyGradient(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java index 7902e8544dd..344e27f1346 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java @@ -117,6 +117,9 @@ public Output<U> asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "SparseBincount"; + private Output<U> output; private SparseBincount(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java index 36230bc774e..5e5566db5ec 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java @@ -136,6 +136,9 @@ public Output<TInt64> outputDenseShape() { return outputDenseShape; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "SparseCountSparseOutput"; + private Output<TInt64> outputIndices; private Output<U> outputValues; private Output<TInt64> outputDenseShape; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCross.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCross.java index 06113f0315b..1cd471349c2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCross.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCross.java @@ -118,7 +118,7 @@ public Output<TInt64> outputShape() { } /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "SparseCross"; + public static final String OP_NAME = "SparseCrossV2"; private Output<TInt64> outputIndices; private Output<TString> outputValues; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCrossHashed.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCrossHashed.java index 9e7cc9b1e6c..2fc6976079e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCrossHashed.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCrossHashed.java @@ -122,6 +122,9 @@ public Output<TInt64> outputShape() { return outputShape; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "SparseCrossHashed"; + private Output<TInt64> outputIndices; private Output<TInt64> outputValues; private Output<TInt64> outputShape; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CloseSummaryWriter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CloseSummaryWriter.java index ff9735f0b07..f5d95d50976 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CloseSummaryWriter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CloseSummaryWriter.java @@ -47,7 +47,6 @@ public static CloseSummaryWriter create(Scope scope, Operand<?> writer) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "CloseSummaryWriter"; - private CloseSummaryWriter(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryDbWriter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryDbWriter.java index 61e7405f74d..8e40aa798d6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryDbWriter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryDbWriter.java @@ -56,7 +56,6 @@ public static CreateSummaryDbWriter create(Scope scope, Operand<?> writer, Opera /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "CreateSummaryDbWriter"; - private CreateSummaryDbWriter(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryFileWriter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryFileWriter.java index d113ebcf3f6..e429fab20e2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryFileWriter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryFileWriter.java @@ -57,7 +57,6 @@ public static CreateSummaryFileWriter create(Scope scope, Operand<?> writer, Ope /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "CreateSummaryFileWriter"; - private CreateSummaryFileWriter(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/FlushSummaryWriter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/FlushSummaryWriter.java index 6b1e610c632..e1586542972 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/FlushSummaryWriter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/FlushSummaryWriter.java @@ -47,7 +47,6 @@ public static FlushSummaryWriter create(Scope scope, Operand<?> writer) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "FlushSummaryWriter"; - private FlushSummaryWriter(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImportEvent.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImportEvent.java index 9b6dc173abe..7bd97de571e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImportEvent.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImportEvent.java @@ -50,7 +50,6 @@ public static ImportEvent create(Scope scope, Operand<?> writer, Operand<TString /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ImportEvent"; - private ImportEvent(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteAudioSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteAudioSummary.java index 306b1d88e6e..8cdf23f4982 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteAudioSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteAudioSummary.java @@ -92,7 +92,6 @@ public static Options maxOutputs(Long maxOutputs) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "WriteAudioSummary"; - private WriteAudioSummary(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteGraphSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteGraphSummary.java index fa10196ca72..dc0cbe0f222 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteGraphSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteGraphSummary.java @@ -53,7 +53,6 @@ public static WriteGraphSummary create(Scope scope, Operand<?> writer, Operand<T /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "WriteGraphSummary"; - private WriteGraphSummary(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteHistogramSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteHistogramSummary.java index a6f50ccdf8b..2069cefafa4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteHistogramSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteHistogramSummary.java @@ -57,7 +57,6 @@ public static <T extends TNumber> WriteHistogramSummary create(Scope scope, Oper /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "WriteHistogramSummary"; - private WriteHistogramSummary(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java index 286d584d695..757ddf59a1c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java @@ -94,7 +94,6 @@ public static Options maxImages(Long maxImages) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "WriteImageSummary"; - private WriteImageSummary(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteRawProtoSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteRawProtoSummary.java index 524b56bed7a..75499c1ff69 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteRawProtoSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteRawProtoSummary.java @@ -53,7 +53,6 @@ public static WriteRawProtoSummary create(Scope scope, Operand<?> writer, Operan /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "WriteRawProtoSummary"; - private WriteRawProtoSummary(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java index 2317db7bdeb..f173651001a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java @@ -57,7 +57,6 @@ public static <T extends TNumber> WriteScalarSummary create(Scope scope, Operand /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "WriteScalarSummary"; - private WriteScalarSummary(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteSummary.java index 6d257f948e8..5404e593f27 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteSummary.java @@ -58,7 +58,6 @@ public static <T extends TType> WriteSummary create(Scope scope, Operand<?> writ /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "WriteSummary"; - private WriteSummary(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ConfigureTPUEmbedding.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ConfigureTPUEmbedding.java index 1905a3082b3..76bccd51f83 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ConfigureTPUEmbedding.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ConfigureTPUEmbedding.java @@ -48,7 +48,6 @@ public static ConfigureTPUEmbedding create(Scope scope, String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ConfigureTPUEmbedding"; - private ConfigureTPUEmbedding(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingIntegerBatch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingIntegerBatch.java index 4198c38e11c..0a1a80c7a0a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingIntegerBatch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingIntegerBatch.java @@ -93,7 +93,6 @@ public static Options deviceOrdinal(Long deviceOrdinal) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "EnqueueTPUEmbeddingIntegerBatch"; - private EnqueueTPUEmbeddingIntegerBatch(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java index c605dafbc87..bf4da86d05d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java @@ -177,6 +177,8 @@ public static Options maxSequenceLengths(List<Long> maxSequenceLengths) { return new Options().maxSequenceLengths(maxSequenceLengths); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "EnqueueTPUEmbeddingRaggedTensorBatch"; private EnqueueTPUEmbeddingRaggedTensorBatch(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java index 23288018938..2cb7dfb674b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java @@ -146,7 +146,6 @@ public static Options combiners(List<String> combiners) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "EnqueueTPUEmbeddingSparseBatch"; - private EnqueueTPUEmbeddingSparseBatch(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java index 59018e1b3e5..3d93c6a0f71 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java @@ -178,7 +178,6 @@ public static Options maxSequenceLengths(List<Long> maxSequenceLengths) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "EnqueueTPUEmbeddingSparseTensorBatch"; - private EnqueueTPUEmbeddingSparseTensorBatch(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueue.java index 9c79df444e3..391d51a9ab0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueue.java @@ -135,7 +135,6 @@ public static Options deviceOrdinal(Long deviceOrdinal) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "InfeedEnqueue"; - private InfeedEnqueue(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueuePrelinearizedBuffer.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueuePrelinearizedBuffer.java index b1d32f70ec9..9344352791c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueuePrelinearizedBuffer.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueuePrelinearizedBuffer.java @@ -84,7 +84,6 @@ public static Options deviceOrdinal(Long deviceOrdinal) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "InfeedEnqueuePrelinearizedBuffer"; - private InfeedEnqueuePrelinearizedBuffer(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueueTuple.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueueTuple.java index d2a95d84244..b439df84f71 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueueTuple.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueueTuple.java @@ -124,7 +124,6 @@ public static Options deviceOrdinal(Long deviceOrdinal) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "InfeedEnqueueTuple"; - private InfeedEnqueueTuple(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParameters.java index 9e60fae350e..744688cee23 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParameters.java @@ -135,7 +135,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingADAMParameters"; - private LoadTPUEmbeddingADAMParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParametersGradAccumDebug.java index 58cfa5cf465..63df2e6aa79 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParametersGradAccumDebug.java @@ -137,7 +137,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingADAMParametersGradAccumDebug"; - private LoadTPUEmbeddingADAMParametersGradAccumDebug(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParameters.java index e4f4228f0f1..43535a2aff8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParameters.java @@ -135,7 +135,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingAdadeltaParameters"; - private LoadTPUEmbeddingAdadeltaParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParametersGradAccumDebug.java index 76af15dc0b6..ce1b759ee60 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParametersGradAccumDebug.java @@ -137,7 +137,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingAdadeltaParametersGradAccumDebug"; - private LoadTPUEmbeddingAdadeltaParametersGradAccumDebug(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParameters.java index dc4f5c62341..f9e16c5b5d6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParameters.java @@ -133,7 +133,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingAdagradParameters"; - private LoadTPUEmbeddingAdagradParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParametersGradAccumDebug.java index 6551f875f2d..7f8df653745 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParametersGradAccumDebug.java @@ -135,7 +135,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingAdagradParametersGradAccumDebug"; - private LoadTPUEmbeddingAdagradParametersGradAccumDebug(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingCenteredRMSPropParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingCenteredRMSPropParameters.java index d4a0103654c..f0b704cfaa1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingCenteredRMSPropParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingCenteredRMSPropParameters.java @@ -137,7 +137,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingCenteredRMSPropParameters"; - private LoadTPUEmbeddingCenteredRMSPropParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParameters.java index a65301f6348..c96edf58894 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParameters.java @@ -135,7 +135,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingFTRLParameters"; - private LoadTPUEmbeddingFTRLParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParametersGradAccumDebug.java index 5a1c165428d..f0a85bd945a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParametersGradAccumDebug.java @@ -137,7 +137,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingFTRLParametersGradAccumDebug"; - private LoadTPUEmbeddingFTRLParametersGradAccumDebug(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMDLAdagradLightParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMDLAdagradLightParameters.java index 407cf842f19..f418a70cc8c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMDLAdagradLightParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMDLAdagradLightParameters.java @@ -137,7 +137,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingMDLAdagradLightParameters"; - private LoadTPUEmbeddingMDLAdagradLightParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParameters.java index 35b8479749b..718bdc24f5c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParameters.java @@ -133,7 +133,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingMomentumParameters"; - private LoadTPUEmbeddingMomentumParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParametersGradAccumDebug.java index babc2de15fd..424c3c846c0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParametersGradAccumDebug.java @@ -135,7 +135,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingMomentumParametersGradAccumDebug"; - private LoadTPUEmbeddingMomentumParametersGradAccumDebug(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParameters.java index 0ebad625abe..7b7265e9b82 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParameters.java @@ -133,7 +133,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingProximalAdagradParameters"; - private LoadTPUEmbeddingProximalAdagradParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug.java index 80b05d47203..c18d2cf22f5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug.java @@ -135,7 +135,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug"; - private LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParameters.java index 651c8e189c4..2a96916c4f5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParameters.java @@ -125,6 +125,8 @@ public static Options config(String config) { return new Options().config(config); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "LoadTPUEmbeddingProximalYogiParameters"; private LoadTPUEmbeddingProximalYogiParameters(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParametersGradAccumDebug.java index 274accba1c9..e863dc554d6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParametersGradAccumDebug.java @@ -127,6 +127,8 @@ public static Options config(String config) { return new Options().config(config); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "LoadTPUEmbeddingProximalYogiParametersGradAccumDebug"; private LoadTPUEmbeddingProximalYogiParametersGradAccumDebug(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParameters.java index d0e39d22edb..1f747282d55 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParameters.java @@ -135,7 +135,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingRMSPropParameters"; - private LoadTPUEmbeddingRMSPropParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParametersGradAccumDebug.java index 98f1043a768..a7c8ed4812c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParametersGradAccumDebug.java @@ -137,7 +137,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingRMSPropParametersGradAccumDebug"; - private LoadTPUEmbeddingRMSPropParametersGradAccumDebug(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParameters.java index ca881823239..769d3436eda 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParameters.java @@ -131,7 +131,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingStochasticGradientDescentParameters"; - private LoadTPUEmbeddingStochasticGradientDescentParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java index 76a13a489ad..e408844e484 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java @@ -130,6 +130,8 @@ public static Options config(String config) { return new Options().config(config); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"; private LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueue.java index 46ee54430d9..5b5f059a81e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueue.java @@ -49,7 +49,6 @@ public static <T extends TType> OutfeedEnqueue create(Scope scope, Operand<T> in /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "OutfeedEnqueue"; - private OutfeedEnqueue(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueueTuple.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueueTuple.java index 25d110f1114..8bfd04b9a2c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueueTuple.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueueTuple.java @@ -50,7 +50,6 @@ public static OutfeedEnqueueTuple create(Scope scope, Iterable<Operand<?>> input /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "OutfeedEnqueueTuple"; - private OutfeedEnqueueTuple(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParameters.java index eaee4fdabc4..a46cae90359 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParameters.java @@ -137,6 +137,9 @@ public Output<TFloat32> m() { return m; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "RetrieveTPUEmbeddingProximalYogiParameters"; + private Output<TFloat32> parameters; private Output<TFloat32> v; private Output<TFloat32> m; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug.java index ec57d8cb424..55535a573f6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug.java @@ -143,6 +143,9 @@ public Output<TFloat32> gradientAccumulators() { return gradientAccumulators; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug"; + private Output<TFloat32> parameters; private Output<TFloat32> v; private Output<TFloat32> m; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java index f649b4d01fa..9f35ffcd8e0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java @@ -139,6 +139,9 @@ public Output<TFloat32> gradientAccumulators() { return gradientAccumulators; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"; + private Output<TFloat32> parameters; private Output<TFloat32> gradientAccumulators; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/SendTPUEmbeddingGradients.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/SendTPUEmbeddingGradients.java index 25d25c1d5bf..482080bde5d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/SendTPUEmbeddingGradients.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/SendTPUEmbeddingGradients.java @@ -64,7 +64,6 @@ public static SendTPUEmbeddingGradients create(Scope scope, Iterable<Operand<TFl /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "SendTPUEmbeddingGradients"; - private SendTPUEmbeddingGradients(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ShutdownDistributedTPU.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ShutdownDistributedTPU.java index 213ed8e6050..9ce924d153a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ShutdownDistributedTPU.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ShutdownDistributedTPU.java @@ -47,7 +47,6 @@ public static ShutdownDistributedTPU create(Scope scope) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ShutdownDistributedTPU"; - private ShutdownDistributedTPU(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/TPUReplicateMetadata.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/TPUReplicateMetadata.java index 4cc6056c8bc..3fb9b782f9c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/TPUReplicateMetadata.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/TPUReplicateMetadata.java @@ -252,7 +252,6 @@ public static Options allowSoftPlacement(Boolean allowSoftPlacement) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "TPUReplicateMetadata"; - private TPUReplicateMetadata(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorApplyGradient.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorApplyGradient.java index 82583bd691e..504acba4c96 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorApplyGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorApplyGradient.java @@ -58,7 +58,6 @@ public static <T extends TType> AccumulatorApplyGradient create(Scope scope, Ope /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "AccumulatorApplyGradient"; - private AccumulatorApplyGradient(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorSetGlobalStep.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorSetGlobalStep.java index b57bb702669..9039d3a654d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorSetGlobalStep.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorSetGlobalStep.java @@ -56,7 +56,6 @@ public static AccumulatorSetGlobalStep create(Scope scope, Operand<TString> hand /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "AccumulatorSetGlobalStep"; - private AccumulatorSetGlobalStep(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/MergeV2Checkpoints.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/MergeV2Checkpoints.java index 4fe4a27171a..986553a2d8e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/MergeV2Checkpoints.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/MergeV2Checkpoints.java @@ -96,7 +96,6 @@ public static Options deleteOldDirs(Boolean deleteOldDirs) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "MergeV2Checkpoints"; - private MergeV2Checkpoints(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/NegTrain.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/NegTrain.java index f43928961a9..b3e5316ad7e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/NegTrain.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/NegTrain.java @@ -68,7 +68,6 @@ public static NegTrain create(Scope scope, Operand<TFloat32> wIn, Operand<TFloat /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "NegTrain"; - private NegTrain(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorApplyGradient.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorApplyGradient.java index ca29897ca3d..560c3400c3b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorApplyGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorApplyGradient.java @@ -56,7 +56,6 @@ public static <T extends TType> ResourceAccumulatorApplyGradient create(Scope sc /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceAccumulatorApplyGradient"; - private ResourceAccumulatorApplyGradient(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorSetGlobalStep.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorSetGlobalStep.java index e04784aef48..37570909340 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorSetGlobalStep.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorSetGlobalStep.java @@ -54,7 +54,6 @@ public static ResourceAccumulatorSetGlobalStep create(Scope scope, Operand<?> ha /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceAccumulatorSetGlobalStep"; - private ResourceAccumulatorSetGlobalStep(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdaMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdaMax.java index 5efab216739..169da75fecd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdaMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdaMax.java @@ -107,7 +107,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyAdaMax"; - private ResourceApplyAdaMax(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdadelta.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdadelta.java index 0121155a2ef..4323a39de45 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdadelta.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdadelta.java @@ -103,7 +103,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyAdadelta"; - private ResourceApplyAdadelta(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagrad.java index 8868193464f..b60ad1fc0e0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagrad.java @@ -117,7 +117,6 @@ public static Options updateSlots(Boolean updateSlots) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyAdagradV2"; - private ResourceApplyAdagrad(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagradDa.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagradDa.java index 7f5b26056ac..7c6f06634ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagradDa.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagradDa.java @@ -101,7 +101,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyAdagradDA"; - private ResourceApplyAdagradDa(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdam.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdam.java index 20a07b4865d..4a1aea5d355 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdam.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdam.java @@ -130,7 +130,6 @@ public static Options useNesterov(Boolean useNesterov) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyAdam"; - private ResourceApplyAdam(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdamWithAmsgrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdamWithAmsgrad.java index ec13f10038d..a436bc7fdd2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdamWithAmsgrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdamWithAmsgrad.java @@ -114,7 +114,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyAdamWithAmsgrad"; - private ResourceApplyAdamWithAmsgrad(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAddSign.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAddSign.java index e64354ec3bb..85c9c587979 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAddSign.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAddSign.java @@ -104,7 +104,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyAddSign"; - private ResourceApplyAddSign(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyCenteredRmsProp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyCenteredRmsProp.java index c4c8d7c9ec7..6fc3a8a02ff 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyCenteredRmsProp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyCenteredRmsProp.java @@ -123,7 +123,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyCenteredRMSProp"; - private ResourceApplyCenteredRmsProp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyFtrl.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyFtrl.java index c9de01ad14d..e69b6b99959 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyFtrl.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyFtrl.java @@ -120,18 +120,16 @@ public static <T extends TType> ResourceApplyFtrl create(Scope scope, Operand<?> public static Options useLocking(Boolean useLocking) { return new Options().useLocking(useLocking); } - - /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "ResourceApplyFtrlV2"; - + /** * @param multiplyLinearByLr */ public static Options multiplyLinearByLr(Boolean multiplyLinearByLr) { return new Options().multiplyLinearByLr(multiplyLinearByLr); } - + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "ResourceApplyFtrlV2"; private ResourceApplyFtrl(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyGradientDescent.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyGradientDescent.java index 0e495bdd651..f33c6b9ca87 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyGradientDescent.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyGradientDescent.java @@ -90,7 +90,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyGradientDescent"; - private ResourceApplyGradientDescent(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyKerasMomentum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyKerasMomentum.java index 3a986d617eb..3922439dcad 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyKerasMomentum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyKerasMomentum.java @@ -124,7 +124,6 @@ public static Options useNesterov(Boolean useNesterov) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyKerasMomentum"; - private ResourceApplyKerasMomentum(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyMomentum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyMomentum.java index c441193d864..c554c8a939d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyMomentum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyMomentum.java @@ -124,7 +124,6 @@ public static Options useNesterov(Boolean useNesterov) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyMomentum"; - private ResourceApplyMomentum(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyPowerSign.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyPowerSign.java index c1ba8b0ebd7..662c2253264 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyPowerSign.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyPowerSign.java @@ -104,7 +104,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyPowerSign"; - private ResourceApplyPowerSign(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalAdagrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalAdagrad.java index b51ce4698e1..8036d891e33 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalAdagrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalAdagrad.java @@ -100,7 +100,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyProximalAdagrad"; - private ResourceApplyProximalAdagrad(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalGradientDescent.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalGradientDescent.java index 7f9c4f4e52c..3b217c88c67 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalGradientDescent.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalGradientDescent.java @@ -97,7 +97,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyProximalGradientDescent"; - private ResourceApplyProximalGradientDescent(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyRmsProp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyRmsProp.java index 4c400f1aaa1..ae42295c1f7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyRmsProp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyRmsProp.java @@ -113,7 +113,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyRMSProp"; - private ResourceApplyRmsProp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdadelta.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdadelta.java index 9a50137d196..baea98fc1f7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdadelta.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdadelta.java @@ -101,7 +101,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyAdadelta"; - private ResourceSparseApplyAdadelta(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagrad.java index 69a3e775622..f7816e78d0c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagrad.java @@ -120,7 +120,6 @@ public static Options updateSlots(Boolean updateSlots) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyAdagrad"; - private ResourceSparseApplyAdagrad(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradDa.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradDa.java index 4f3189f074c..417eca86a80 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradDa.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradDa.java @@ -104,7 +104,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyAdagradDA"; - private ResourceSparseApplyAdagradDa(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradV2.java index 30e6c19da15..f60d192c368 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradV2.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradV2.java @@ -121,7 +121,6 @@ public static Options updateSlots(Boolean updateSlots) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyAdagradV2"; - private ResourceSparseApplyAdagradV2(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyCenteredRmsProp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyCenteredRmsProp.java index ab9b9d3c38d..d6806c36abf 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyCenteredRmsProp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyCenteredRmsProp.java @@ -124,7 +124,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyCenteredRMSProp"; - private ResourceSparseApplyCenteredRmsProp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyFtrl.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyFtrl.java index 84caa503dec..a13382272c8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyFtrl.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyFtrl.java @@ -125,18 +125,15 @@ public static Options useLocking(Boolean useLocking) { return new Options().useLocking(useLocking); } - - /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "ResourceSparseApplyFtrlV2"; - /** * @param multiplyLinearByLr */ public static Options multiplyLinearByLr(Boolean multiplyLinearByLr) { return new Options().multiplyLinearByLr(multiplyLinearByLr); } - + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "ResourceSparseApplyFtrlV2"; private ResourceSparseApplyFtrl(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyKerasMomentum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyKerasMomentum.java index 0284564f78c..b385403f989 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyKerasMomentum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyKerasMomentum.java @@ -129,7 +129,6 @@ public static Options useNesterov(Boolean useNesterov) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyKerasMomentum"; - private ResourceSparseApplyKerasMomentum(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyMomentum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyMomentum.java index 5199932b5bc..bc303bfbbf0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyMomentum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyMomentum.java @@ -129,7 +129,6 @@ public static Options useNesterov(Boolean useNesterov) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyMomentum"; - private ResourceSparseApplyMomentum(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalAdagrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalAdagrad.java index e235a19f5d1..678601d6aea 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalAdagrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalAdagrad.java @@ -105,7 +105,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyProximalAdagrad"; - private ResourceSparseApplyProximalAdagrad(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalGradientDescent.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalGradientDescent.java index 08a9edc01c4..11ad213524c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalGradientDescent.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalGradientDescent.java @@ -101,7 +101,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyProximalGradientDescent"; - private ResourceSparseApplyProximalGradientDescent(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyRmsProp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyRmsProp.java index 982e1f30eb7..8c519504f89 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyRmsProp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyRmsProp.java @@ -116,7 +116,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyRMSProp"; - private ResourceSparseApplyRmsProp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Save.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Save.java index 781714d8121..c5de40fc91b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Save.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Save.java @@ -63,7 +63,6 @@ public static Save create(Scope scope, Operand<TString> prefix, Operand<TString> /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "SaveV2"; - private Save(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SaveSlices.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SaveSlices.java index e8e67190e63..73325d1d1bc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SaveSlices.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SaveSlices.java @@ -89,7 +89,6 @@ public static SaveSlices create(Scope scope, Operand<TString> filename, Operand< /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "SaveSlices"; - private SaveSlices(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SdcaShrinkL1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SdcaShrinkL1.java index 24c6d53ef9d..748a2eacaec 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SdcaShrinkL1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SdcaShrinkL1.java @@ -56,7 +56,6 @@ public static SdcaShrinkL1 create(Scope scope, Iterable<Operand<TFloat32>> weigh /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "SdcaShrinkL1"; - private SdcaShrinkL1(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Send.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Send.java index d1172f8e96f..b18a86458ca 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Send.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Send.java @@ -55,7 +55,6 @@ public static <T extends TType> Send create(Scope scope, Operand<T> tensor, Stri /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "XlaSend"; - private Send(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/NN.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/NN.java deleted file mode 100644 index b4fa7bd01de..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/NN.java +++ /dev/null @@ -1,379 +0,0 @@ -package org.tensorflow.op.core; - -import org.tensorflow.DataType; -import org.tensorflow.Operand; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Op; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.op.math.*; -import org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits; -import org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits; -import org.tensorflow.types.*; -import org.tensorflow.types.family.TNumber; -import org.tensorflow.op.dtypes.Cast; -import org.tensorflow.types.family.TType; -import org.tensorflow.op.linalg.Transpose; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -@Operator(group = "nn") -public abstract class NN { - - /** - * Computes softmax cross entropy between `logits` and `labels`. - * - * <p>Measures the probability error in discrete classification tasks in which the classes are - * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is - * labeled with one and only one label: an image can be a dog or a truck, but not both. - * - * <p>**NOTE:** While the classes are mutually exclusive, their probabilities need not be. All - * that is required is that each row of `labels` is a valid probability distribution. If they are - * not, the computation of the gradient will be incorrect. - * - * <p>If using exclusive `labels` (wherein one and only one class is true at a time), see - * `sparse_softmax_cross_entropy_with_logits`. - * - * <p>Usage: - * - * <pre> - * >>> logits = [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]] - * >>> labels = [[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]] - * >>> tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits) - * <tf.Tensor: shape=(2,), dtype=float32, - * numpy=array([0.16984604, 0.82474494], dtype=float32)> - * </pre> - * - * <p>Backpropagation will happen into both `logits` and `labels`. To disallow backpropagation - * into `labels`, pass label tensors through `tf.stop_gradient` before feeding it to this - * function. - * - * @param scope current scope - * @param labels Each vector along the class dimension should hold a valid probability - * distribution e.g. for the case in which labels are of shape `[batch_size, num_classes]`, - * each row of `labels[i]` must be a valid probability distribution. - * @param logits Per-label activations, typically a linear output. These activation energies are - * interpreted as unnormalized log probabilities. - * @param axis The class dimension. -1 is the last dimension. - * @param <U> the data type of the logits - * @param <T> the number type of the operands - * @return the softmax cross entropy loss. Its type is the same as `logits` and its shape is the - * same as `labels` except that it does not have the last dimension of `labels`. - */ - @Endpoint(name = "softmaxCrossEntropyWithLogits") - public static <U extends TType, T extends TNumber> Operand<T> softmaxCrossEntropyWithLogits( - Scope scope, Operand<T> labels, Operand<U> logits, int axis) { - axis = axis % logits.asOutput().shape().numDimensions(); - if (axis < 0) { - axis += logits.asOutput().shape().numDimensions(); - } - - Operand precise_logits = - logits; // cannot use generics cause logits of bool gets cast to TFloat32 - - boolean convertToFloat32 = - logits.asOutput().dataType() == TFloat16.DTYPE - || logits.asOutput().dataType() == TBfloat16.DTYPE; - if (convertToFloat32) { - precise_logits = Cast.create(scope, logits, TFloat32.DTYPE); - } - /* cannot use generics on DataType because precis_logits may have been cast. */ - DataType dtype = precise_logits.asOutput().dataType(); - labels = Cast.create(scope, labels, dtype); - Operand<TInt64> inputRank = - Cast.create(scope, Rank.create(scope, precise_logits), TInt64.DTYPE); - Shape shape = logits.asOutput().shape(); - - // Move the dim to the end if dim is not the last dimension. - if (axis != -1 && axis != precise_logits.asOutput().shape().numDimensions() - 1) { - precise_logits = moveDimToEnd(scope, precise_logits, axis, inputRank); - labels = moveDimToEnd(scope, labels, axis, inputRank); - } - - Shape inputShape = precise_logits.asOutput().shape(); - precise_logits = flattenOuterDims(scope, precise_logits); - labels = flattenOuterDims(scope, labels); - SoftmaxCrossEntropyWithLogits<T> smax = - SoftmaxCrossEntropyWithLogits.create(scope, precise_logits, labels); - /* cannot use generic on cost, because cost may be recast later. */ - Operand cost = smax.loss(); - Operand<TInt64> outputShape = - Slice.create( - scope, - Constant.vectorOf(scope, inputShape.asArray()), - Constant.vectorOf(scope, new long[] {0}), - Constant.vectorOf(scope, new long[] {inputShape.numDimensions() - 1})); - cost = Reshape.create(scope, cost, outputShape); - if (scope.env().isGraph() && !shape.hasUnknownDimension()) { - long[] array = shape.asArray(); - long[] newArray = new long[array.length - 1]; - if (axis < 0) { - axis = shape.numDimensions() + axis; - } - for (int i = 0; i < axis; i++) { - newArray[i] = shape.size(i); - } - for (int i = axis + 1; i < shape.numDimensions(); i++) { - newArray[i - 1] = shape.size(i); - } - Shape newShape = Shape.of(newArray); - cost = Reshape.create(scope, cost, Constant.vectorOf(scope, newShape.asArray())); - } - - if (convertToFloat32) { - cost = Cast.create(scope, cost, logits.asOutput().dataType()); - } - return cost; - } - - /** - * Computes sparse softmax cross entropy between `logits` and `labels`. - * - * @param scope current scope - * @param labels `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of `labels` and - * result) and dtype `int32` or `int64`. Each entry in `labels` must be an index in `[0, - * num_classes)`. Other values will raise an exception when this op is run on CPU, and return - * `NaN` for corresponding loss and gradient rows on GPU. - * @param logits Per-label activations (typically a linear output) of shape `[d_0, d_1, ..., - * d_{r-1}, num_classes]` and dtype `float16`, `float32`, or `float64`. These activation - * energies are interpreted as unnormalized log probabilities. - * @return A `Tensor` of the same shape as `labels` and of the same type as `logits` with the - * softmax cross entropy loss. - */ - @Endpoint(name = "sparseSoftmaxCrossEntropyWithLogits") - public static <T extends TNumber, U extends TNumber> Operand sparseSoftmaxCrossEntropyWithLogits( - Scope scope, Operand<T> labels, Operand<U> logits) { - // assert shapeIsCompatible(labels.asOutput().shape(), logits.asOutput().shape()): - // String.format("Shapes %s and %s are incompatible", - // labels.asOutput().shape(), logits.asOutput().shape()); - scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits"); - /** cannot use generics on precise_logits as it may be recast later */ - Operand precise_logits = logits; - boolean convertToFloat32 = - logits.asOutput().dataType() == TFloat16.DTYPE - || logits.asOutput().dataType() == TBfloat16.DTYPE; - if (convertToFloat32) { - precise_logits = Cast.create(scope, logits, TFloat32.DTYPE); - } - Shape labelsStaticShape = labels.asOutput().shape(); - org.tensorflow.op.core.Shape<TInt32> labelsShape = - org.tensorflow.op.core.Shape.create(scope, labels); - Shape logitsShape = logits.asOutput().shape(); - Shape logitsShortened = logitsShape.take(logitsShape.numDimensions() - 1); - - boolean staticShapesFullyDefined = - !labelsStaticShape.hasUnknownDimension() && !logitsShortened.hasUnknownDimension(); - if (logitsShape.numDimensions() == 0) { - throw new IllegalArgumentException( - String.format("Logits cannot be scalars - received shape %s.", logitsShape)); - } - if (!logitsShape.hasUnknownDimension() - && !labelsStaticShape.hasUnknownDimension() - && labelsStaticShape.numDimensions() != logitsShape.numDimensions() - 1) { - throw new IllegalArgumentException( - String.format( - "Rank mismatch: Rank of labels (received %s) should equal rank of logits minus 1 (received %s).", - labelsStaticShape.toString(), logitsShape.toString())); - } - - if (staticShapesFullyDefined && !labelsStaticShape.equals(logitsShortened)) { - throw new IllegalArgumentException( - String.format( - "Shape mismatch: The shape of labels (received %s) " - + "should equal the shape of logits except for the last " - + "dimension (received %s).", - labelsStaticShape.toString(), logitsShape.toString())); - } - // Check if no reshapes are required. - if (logitsShape.numDimensions() == 2) { - SparseSoftmaxCrossEntropyWithLogits smax = - SparseSoftmaxCrossEntropyWithLogits.create(scope, precise_logits, labels); - Operand loss = smax.loss(); - if (logits.asOutput().dataType() == TFloat16.DTYPE) { - loss = Cast.create(scope, loss, TFloat16.DTYPE); - } - return loss; - } - - List<Op> shapeChecks = new ArrayList<>(); - - if (!staticShapesFullyDefined) { - shapeChecks.add( - AssertThat.create( - scope, - Equal.create( - scope, - org.tensorflow.op.core.Shape.create(scope, labels), - Shapes.take( - scope, - org.tensorflow.op.core.Shape.create(scope, logits), - Constant.scalarOf(scope, -1))), - Collections.singletonList( - Constant.scalarOf( - scope, - "Shape mismatch: The shape of labels " - + "should equal the shape of logits except for the last " - + "dimension ")))); - } - - // Reshape logits to 2 dim, labels to 1 dim. - long numClassses = logitsShape.size(logitsShape.numDimensions() - 1); - - precise_logits = - Reshape.create( - scope, precise_logits, Constant.vectorOf(scope, new long[] {-1, numClassses})); - labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1)); - scope.withControlDependencies(shapeChecks); - SparseSoftmaxCrossEntropyWithLogits smax = - SparseSoftmaxCrossEntropyWithLogits.create(scope, precise_logits, labels); - Operand cost = smax.loss(); - cost = Reshape.create(scope, cost, labelsShape); - if (logits.asOutput().dataType() == TFloat16.DTYPE) { - cost = Cast.create(scope, cost, TFloat16.DTYPE); - } - return cost; - } - - /** - * Computes sigmoid cross entropy given `logits`. - * - * <p>Measures the probability error in discrete classification tasks in which each class is - * independent and not mutually exclusive. For instance, one could perform multilabel - * classification where a picture can contain both an elephant and a dog at the same time. - * - * <p>For brevity, let `x = logits`, `z = labels`. The logistic loss is - * - * <pre> - * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) - * = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x))) - * = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x))) - * = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)) - * = (1 - z) * x + log(1 + exp(-x)) - * = x - x * z + log(1 + exp(-x)) - * </pre> - * - * <p>For x < 0, to avoid overflow in exp(-x), we reformulate the above - * - * <pre> - * x - x * z + log(1 + exp(-x)) - * = log(exp(x)) - x * z + log(1 + exp(-x)) - * = - x * z + log(1 + exp(x)) - * </pre> - * - * <p>Hence, to ensure stability and avoid overflow, the implementation uses this equivalent - * formulation - * - * <pre> - * max(x, 0) - x * z + log(1 + exp(-abs(x))) - * </pre> - * - * <p>`logits` and `labels` must have the same type and shape. - * - * @param scope The TensorFlow scope - * @param labels the labels - * @param logits the logits of type float32 or float64 - * @param <T> the type of labels and logits - * @return the component-wise logistic losses. - */ - @Endpoint(name = "sigmoidCrossEntropyWithLogits") - public static <T extends TNumber> Operand<T> sigmoidCrossEntropyWithLogits( - Scope scope, Operand<T> labels, Operand<T> logits) { - if (labels.asOutput().shape().numDimensions() != logits.asOutput().shape().numDimensions()) - throw new IllegalArgumentException( - String.format( - "logits and labels must have the same shape (%s vs %s)", - labels.asOutput().shape().toString(), logits.asOutput().shape())); - Operand<T> zeros = - Cast.create(scope, ZerosLike.create(scope, logits), logits.asOutput().dataType()); - Operand<TBool> cond = GreaterEqual.create(scope, logits, zeros); - - Operand<T> relu_logits = Select.create(scope, cond, logits, zeros); - Operand<T> neg_abs_logits = Select.create(scope, cond, Neg.create(scope, logits), logits); - return Add.create( - scope, - Sub.create(scope, relu_logits, Mul.create(scope, logits, labels)), - Log1p.create(scope, Exp.create(scope, neg_abs_logits))); - } - - /** - * Flattens logits' outer dimensions and keep its last dimension. - * - * @param scope the TensorFlow scope - * @param logits the logits - * @param <T> the type of logits - * @return the flattened logits - */ - private static <T extends TNumber> Operand<T> flattenOuterDims(Scope scope, Operand<T> logits) { - Operand<TInt64> one = Constant.scalarOf(scope, 1L); - - org.tensorflow.ndarray.Shape shape = logits.asOutput().shape(); - int ndims = shape.numDimensions(); - if (!shape.hasUnknownDimension()) { - long product = 1L; - boolean productValid = true; - for (int i = ndims - 2; i >= 0; i--) { - long d = shape.size(i); - if (d == org.tensorflow.ndarray.Shape.UNKNOWN_SIZE) { - productValid = false; - break; - } - product *= d; - } - if (productValid) { - org.tensorflow.ndarray.Shape outputShape = Shape.of(product, shape.size(ndims - 1)); - return Reshape.create(scope, logits, Constant.vectorOf(scope, outputShape.asArray())); - } - } - - Operand<TInt64> rank = Cast.create(scope, Rank.create(scope, logits), TInt64.DTYPE); - Operand<TInt64> rankMinusOne = Sub.create(scope, rank, one); - - Operand<TInt64> last_dim_size = - Slice.create( - scope, - org.tensorflow.op.core.Shape.create(scope, logits, TInt64.DTYPE), - rankMinusOne, - one); - Operand<TInt64> concat = - Concat.create( - scope, - Arrays.asList(Constant.vectorOf(scope, new long[] {-1}), last_dim_size), - Constant.scalarOf(scope, 0)); - return Reshape.create(scope, logits, concat); - } - - /** - * Move the dim to the end if dim is not the last dimension. - * - * @param scope The TensorFlow Scope - * @param input the input to reshape - * @param dim_index the index to move - * @param rank the number of Dimensions in the tensor - * @param <T> the data type of the tensor. - * @param <U> the data type of the rank - * @return the reshaped input - */ - private static <T extends TNumber, U extends TNumber> Operand<T> moveDimToEnd( - Scope scope, Operand<T> input, int dim_index, Operand<U> rank) { - DataType rankDType = rank.asOutput().dataType(); - Operand<U> one = Cast.create(scope, Constant.scalarOf(scope, 1), rankDType); - List<Operand<T>> concatList = - Arrays.asList( - Range.create( - scope, - Cast.create(scope, Constant.scalarOf(scope, dim_index), rankDType), - one, - one), - Range.create( - scope, - Cast.create(scope, Constant.scalarOf(scope, (dim_index + 1)), rankDType), - rank, - one)); - return Transpose.create( - scope, input, Concat.create(scope, concatList, Constant.scalarOf(scope, 0))); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java new file mode 100644 index 00000000000..4f3e9569103 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java @@ -0,0 +1,108 @@ +package org.tensorflow.op.nn; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.core.Select; +import org.tensorflow.op.core.ZerosLike; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.math.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TNumber; + +@Operator(group = "nn") +public class SigmoidCrossEntropyWithLogits { + + /** + * Computes sigmoid cross entropy given <code>logits</code>. + * + * <p>Measures the probability error in discrete classification tasks in which each class is + * independent and not mutually exclusive. For instance, one could perform multilabel + * classification where a picture can contain both an elephant and a dog at the same time. + * + * <p>For brevity, let <code>x = logits</code>, <code>z = labels</code>. The logistic loss in + * pseudo-code is + * + * <pre> + * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) + * = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x))) + * = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x))) + * = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)) + * = (1 - z) * x + log(1 + exp(-x)) + * = x - x * z + log(1 + exp(-x)) + * </pre> + * + * <p>For <code>x < 0</code>, to avoid overflow in <code>exp(-x)</code>, we reformulate the above + * + * <pre> + * x - x * z + log(1 + exp(-x)) + * = log(exp(x)) - x * z + log(1 + exp(-x)) + * = - x * z + log(1 + exp(x)) + * </pre> + * + * <p>Hence, to ensure stability and avoid overflow, the implementation uses this equivalent + * formulation + * + * <pre> + * max(x, 0) - x * z + log(1 + exp(-abs(x))) + * </pre> + * + * <p></ode>logits</code> and <code>labels</code> must have the same type and shape. + * + * <p> + * + * @param scope The TensorFlow scope + * @param labels the labels + * @param logits the logits of type float32 or float64 + * @param <T> the type of labels and logits + * @return the component-wise logistic losses. + * @throws IllegalArgumentException if logits' and labels' do not have the same shape + */ + @Endpoint(name = "sigmoidCrossEntropyWithLogits") + public static <T extends TNumber> Operand<T> sigmoidCrossEntropyWithLogits( + Scope scope, Operand<T> labels, Operand<T> logits) { + if (!isCompatible(labels.asOutput().shape(), logits.asOutput().shape())) { + throw new IllegalArgumentException( + String.format( + "logits and labels must have the same shape (%s vs %s)", + labels.asOutput().shape().toString(), logits.asOutput().shape())); + } + scope = scope.withSubScope("SigmoidCrossEntropyWithLogits"); + + Operand<T> zeros = + Cast.create(scope, ZerosLike.create(scope, logits), logits.asOutput().dataType()); + Operand<TBool> cond = GreaterEqual.create(scope, logits, zeros); + + Operand<T> reluLogits = Select.create(scope, cond, logits, zeros); + Operand<T> negAbsLogits = Select.create(scope, cond, Neg.create(scope, logits), logits); + return Add.create( + scope, + Sub.create(scope, reluLogits, Mul.create(scope, logits, labels)), + Log1p.create(scope, Exp.create(scope, negAbsLogits))); + } + /** + * Determine if 2 shapes are compatible + * + * <p>2 shapes are compatible if they have the same number of dimensions, and if the corresponding + * dimensions are equal, or at least one of the corresponding dimensions is unknown. + * + * @param shape the first shape + * @param other the second shape + * @return true, if the shapes are compatible. + */ + private static boolean isCompatible(Shape shape, Shape other) { + if (shape.numDimensions() != other.numDimensions()) return false; + for (int i = 0; i < shape.numDimensions(); i++) { + long aShapeDim = shape.size(i); + long bShapeDim = other.size(i); + if (aShapeDim == bShapeDim + || (aShapeDim == Shape.UNKNOWN_SIZE || bShapeDim == Shape.UNKNOWN_SIZE)) { + continue; + } + return false; + } + return true; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java new file mode 100644 index 00000000000..0c8bac697ed --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -0,0 +1,214 @@ +package org.tensorflow.op.nn; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.core.*; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.linalg.Transpose; +import org.tensorflow.op.math.Sub; +import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.List; + +@Operator(group = "nn") +public class SoftmaxCrossEntropyWithLogits { + + /** + * Computes softmax cross entropy between <code>logits</code> and <code>labels</code>. + * + * <p>Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + * <p><b>NOTE:</b> + * + * <p>While the classes are mutually exclusive, their probabilities need not be. All that is + * required is that each row of <code>labels</code> is a valid probability distribution. If they + * are not, the computation of the gradient will be incorrect. + * + * <p>If using exclusive <code>labels</code> (wherein one and only one class is true at a time), + * see {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits} + * + * <p>Usage: + * + * <pre> + * Operand<TFloat32> logits = + * tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} ); + * Operand<TFloat32> labels = + * tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} ); + * Operand<TFloat32> output = + * tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1); + * // output Shape = [2] + * // dataType = FLOAT (1) + * // values { 0.169846, 0.824745 } + * </pre> + * + * <p>Backpropagation will happen into both <code>logits</code> and <code>labels</code>. To + * disallow backpropagation into <code>labels</code>, pass label tensors through <code> + * tf.stopGradient</code> before feeding it to this function. + * + * @param scope current scope + * @param labels Each vector along the class dimension should hold a valid probability + * distribution e.g. for the case in which labels are of shape <code>[batch_size, num_classes] + * </code>, each row of <code>labels[i]</code> must be a valid probability distribution. + * @param logits Per-label activations, typically a linear output. These activation energies are + * interpreted as unnormalized log probabilities. + * @param axis The class dimension. -1 is the last dimension. + * @param <T> the number type of the operands + * @return the softmax cross entropy loss. Its type is the same as <code>logits</code> and its + * shape is the same as <code>labels</code> except that it does not have the last dimension of + * <code>labels</code>. + */ + @Endpoint(name = "softmaxCrossEntropyWithLogits") + public static <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntropyWithLogits( + Scope scope, Operand<U> labels, Operand<T> logits, int axis) { + scope = scope.withSubScope("SoftmaxCrossEntropyWithLogits"); + axis = axis % logits.asOutput().shape().numDimensions(); + if (axis < 0) { + axis += logits.asOutput().shape().numDimensions(); + } + + + boolean convertToFloat32 = + logits.asOutput().dataType() == TFloat16.DTYPE + || logits.asOutput().dataType() == TBfloat16.DTYPE; + if (convertToFloat32) { + Operand<TFloat32> result = softmaxCrossEntropyWithLogits(scope, + Cast.create(scope, labels, TFloat32.DTYPE), + Cast.create(scope, logits, TFloat32.DTYPE), + axis); + return Cast.create(scope, result, logits.asOutput().dataType()); + } else if(!logits.asOutput().dataType().equals(labels.asOutput().dataType())) { + return softmaxCrossEntropyWithLogits(scope, + Cast.create(scope, labels, logits.asOutput().dataType()), + logits, + axis); + } + + Operand<TInt64> inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.DTYPE); + Shape shape = logits.asOutput().shape(); + + // Move the dim to the end if dim is not the last dimension. + if (axis != -1 && axis != logits.asOutput().shape().numDimensions() - 1) { + logits = moveDimToEnd(scope, logits, axis, inputRank); + labels = moveDimToEnd(scope, labels, axis, inputRank); + } + + Shape inputShape = logits.asOutput().shape(); + logits = flattenOuterDims(scope, logits); + labels = flattenOuterDims(scope, labels); + + org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits<T> smax = + org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits.create( + scope, logits, (Operand<T>)labels); + /* cannot use generic on cost, because cost may be recast later. */ + Operand<T> cost = smax.loss(); + Operand<TInt64> outputShape = + Slice.create( + scope, + Constant.tensorOf(scope, inputShape), + Constant.arrayOf(scope, 0L), + Constant.arrayOf(scope, inputShape.numDimensions() - 1L)); + cost = Reshape.create(scope, cost, outputShape); + if (scope.env().isGraph() && !shape.hasUnknownDimension()) { + long[] array = shape.asArray(); + long[] newArray = new long[array.length - 1]; + if (axis < 0) { + axis = shape.numDimensions() + axis; + } + for (int i = 0; i < axis; i++) { + newArray[i] = shape.size(i); + } + for (int i = axis + 1; i < shape.numDimensions(); i++) { + newArray[i - 1] = shape.size(i); + } + cost = Reshape.create(scope, cost, Constant.vectorOf(scope, newArray)); + } + + return cost; + } + + /** + * Flattens logits' outer dimensions and keep its last dimension. + * + * @param scope the TensorFlow scope + * @param logits the logits + * @param <T> the type of logits + * @return the flattened logits + */ + private static <T extends TNumber> Operand<T> flattenOuterDims(Scope scope, Operand<T> logits) { + Operand<TInt64> one = Constant.scalarOf(scope, 1L); + + Shape shape = logits.asOutput().shape(); + int ndims = shape.numDimensions(); + if (!shape.hasUnknownDimension()) { + long product = 1L; + boolean productValid = true; + for (int i = ndims - 2; i >= 0; i--) { + long d = shape.size(i); + if (d == org.tensorflow.ndarray.Shape.UNKNOWN_SIZE) { + productValid = false; + break; + } + product *= d; + } + if (productValid) { + return Reshape.create(scope, logits, Constant.arrayOf(scope, product, shape.size(-1))); + } + } + + Operand<TInt64> rank = Cast.create(scope, Rank.create(scope, logits), TInt64.DTYPE); + Operand<TInt64> rankMinusOne = Sub.create(scope, rank, one); + + Operand<TInt64> lastDimSize = + Slice.create( + scope, + org.tensorflow.op.core.Shape.create(scope, logits, TInt64.DTYPE), + rankMinusOne, + one); + Operand<TInt64> concat = + Concat.create( + scope, + Arrays.asList(Constant.arrayOf(scope, -1L), lastDimSize), + Constant.scalarOf(scope, 0)); + return Reshape.create(scope, logits, concat); + } + + /** + * Move the dim to the end if dimIndex is not the last dimension. + * + * @param scope The TensorFlow Scope + * @param input the input to reshape + * @param dimIndex the index to move + * @param rank the number of Dimensions in the tensor + * @param <T> the data type of the tensor. + * @param <U> the data type of the rank + * @return the reshaped input + */ + private static <T extends TNumber, U extends TNumber> Operand<T> moveDimToEnd( + Scope scope, Operand<T> input, int dimIndex, Operand<U> rank) { + DataType<? extends TNumber> rankDType = rank.asOutput().dataType(); + Operand one = Cast.create(scope, Constant.scalarOf(scope, 1), rankDType); + List<Operand<U>> concatList = + Arrays.asList( + Range.create( + scope, Cast.create(scope, Constant.scalarOf(scope, dimIndex), rankDType), one, one), + Range.create( + scope, + Cast.create(scope, Constant.scalarOf(scope, dimIndex + 1), rankDType), + one, + one)); + return Transpose.create( + scope, input, Concat.create(scope, concatList, Constant.scalarOf(scope, 0))); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java new file mode 100644 index 00000000000..ebd6f74e7d8 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -0,0 +1,161 @@ +package org.tensorflow.op.nn; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.core.AssertThat; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Reshape; +import org.tensorflow.op.core.Shapes; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.math.Equal; +import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +@Operator(group = "nn") +public class SparseSoftmaxCrossEntropyWithLogits { + + /** + * Computes sparse softmax cross entropy between <code>logits</code> and <code>labels</code>. + * + * <p>Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + * <p><b>NOTE:</b> + * + * <p>For this operation, the probability of a given label is considered exclusive. That is, soft + * classes are not allowed, and the <code>labels</code> vector must provide a single specific + * index for the true class for each row of <code>logits</code> (each minibatch entry). For soft + * softmax classification with a probability distribution for each entry, {@link + * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}. + * + * <p><b>WARNING:</b> + * + * <p>This op expects unscaled logits, since it performs a <code>softmax</code> on <code>logits + * </code> internally for efficiency. Do not call this op with the output of <code>softmax</code>, + * as it will produce incorrect results. + * + * <p>A common use case is to have logits of shape <code>[batchSize, numClasses]</code> and have + * labels of shape <code>[batchSize]</code>, but higher dimensions are supported, in which case + * the <code>dim</code>-th dimension is assumed to be of size <code>numClasses</code>. <code> + * logits</code> must have the <cod>dataType</cod> of <code>TFloat16</code>, <code>TFloat32</code> + * , or <code>TFloat64</code>, and <code>labels</code> must have the dtype of <code>TInt32</code> + * or <code>TInt64</code>. + * + * @param scope current scope + * @param labels <code>Tensor</code> of shape <code>[d_0, d_1, ..., d_{r-1}]</code> (where <code>r + * </code> is rank of <code>labels</code> and result) and the dataType is <code>TInt32</code> + * or <code>TInt64</code>. Each entry in <code>labels</code> must be an index in <code>[0, + * numClasses)</code>. Other values will raise an exception when this op is run on CPU, and + * return <code>NaN</code> for corresponding loss and gradient rows on GPU. + * @param logits Per-label activations (typically a linear output) of shape <code>[d_0, d_1, ..., + * d_{r-1}, numClasses]</code> and dataType of <code>TFloat16</code>, <code>TFloat32</code>, + * or <code>TFloat64</code>. These activation energies are interpreted as unnormalized log + * probabilities. + * @return A <code>Tensor</code> of the same shape as <code>labels</code> and of the same type as + * <code>logits</code> with the softmax cross entropy loss. + * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank + * of the labels is not equal to the rank of the logits minus one. + */ + @Endpoint(name = "sparseSoftmaxCrossEntropyWithLogits") + public static <T extends TNumber, U extends TNumber> Operand sparseSoftmaxCrossEntropyWithLogits( + Scope scope, Operand<T> labels, Operand<U> logits) { + scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits"); + /** cannot use generics on preciseLogits as it may be recast later */ + Operand preciseLogits = logits; + boolean convertToFloat32 = + logits.asOutput().dataType() == TFloat16.DTYPE + || logits.asOutput().dataType() == TBfloat16.DTYPE; + if (convertToFloat32) { + preciseLogits = Cast.create(scope, logits, TFloat32.DTYPE); + } + Shape labelsStaticShape = labels.asOutput().shape(); + org.tensorflow.op.core.Shape<TInt32> labelsShape = + org.tensorflow.op.core.Shape.create(scope, labels); + Shape logitsShape = logits.asOutput().shape(); + Shape logitsShortened = logitsShape.take(logitsShape.numDimensions() - 1); + + boolean staticShapesFullyDefined = + !labelsStaticShape.hasUnknownDimension() && !logitsShortened.hasUnknownDimension(); + if (logitsShape.numDimensions() == 0) { + throw new IllegalArgumentException( + String.format("Logits cannot be scalars - received shape %s.", logitsShape)); + } + if (!logitsShape.hasUnknownDimension() + && !labelsStaticShape.hasUnknownDimension() + && labelsStaticShape.numDimensions() != logitsShape.numDimensions() - 1) { + throw new IllegalArgumentException( + String.format( + "Rank mismatch: Rank of labels (received %s) should equal rank of logits minus 1 (received %s).", + labelsStaticShape.toString(), logitsShape.toString())); + } + + if (staticShapesFullyDefined && !labelsStaticShape.equals(logitsShortened)) { + throw new IllegalArgumentException( + String.format( + "Shape mismatch: The shape of labels (received %s) " + + "should equal the shape of logits except for the last " + + "dimension (received %s).", + labelsStaticShape.toString(), logitsShape.toString())); + } + // Check if no reshapes are required. + if (logitsShape.numDimensions() == 2) { + org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create( + scope, preciseLogits, labels); + Operand loss = smax.loss(); + if (logits.asOutput().dataType() == TFloat16.DTYPE) { + loss = Cast.create(scope, loss, TFloat16.DTYPE); + } + return loss; + } + + List<Op> shapeChecks = new ArrayList<>(); + + if (!staticShapesFullyDefined) { + shapeChecks.add( + AssertThat.create( + scope, + Equal.create( + scope, + org.tensorflow.op.core.Shape.create(scope, labels), + Shapes.take( + scope, + org.tensorflow.op.core.Shape.create(scope, logits), + Constant.scalarOf(scope, -1))), + Collections.singletonList( + Constant.scalarOf( + scope, + "Shape mismatch: The shape of labels " + + "should equal the shape of logits except for the last " + + "dimension ")))); + } + + // Reshape logits to 2 dims, labels to 1 dim. + long numClassses = logitsShape.size(-1); + + preciseLogits = Reshape.create(scope, preciseLogits, Constant.arrayOf(scope, -1L, numClassses)); + labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1)); + scope.withControlDependencies(shapeChecks); + org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create( + scope, preciseLogits, labels); + Operand cost = smax.loss(); + cost = Reshape.create(scope, cost, labelsShape); + if (logits.asOutput().dataType() == TFloat16.DTYPE) { + cost = Cast.create(scope, cost, TFloat16.DTYPE); + } + return cost; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java index bac5fb96f87..3cc72101893 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java @@ -17,21 +17,22 @@ package org.tensorflow.types; -import java.util.function.Consumer; import org.tensorflow.DataType; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.buffer.layout.DataLayouts; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.BooleanDataBuffer; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.layout.DataLayouts; import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray; import org.tensorflow.types.family.TType; +import java.util.function.Consumer; + /** * Boolean tensor type. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java index 0f097a16ddb..6e2e7a7ba56 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java @@ -17,23 +17,24 @@ package org.tensorflow.types; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; -import java.util.function.Function; import org.tensorflow.DataType; import org.tensorflow.Tensor; import org.tensorflow.internal.buffer.StringTensorBuffer; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffer; import org.tensorflow.ndarray.buffer.layout.DataLayout; import org.tensorflow.ndarray.buffer.layout.DataLayouts; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.impl.dense.DenseNdArray; import org.tensorflow.types.family.TType; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.function.Function; + /** * String type. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java index a058649373a..a099eae53e8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java @@ -15,63 +15,98 @@ */ package org.tensorflow.framework.optimizers; -import java.util.Collections; -import java.util.List; -import java.util.Map; - import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.Tensor; -import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; -import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyMomentum; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; +import java.util.List; + /** - * SGD plus momentum, either nesterov or traditional. - * <p> - * See the <a href="http://jmlr.org/proceedings/papers/v28/sutskever13.pdf">paper</a> for details of - * nesterov momentum. + * Stochastic gradient descent plus momentum, either nesterov or traditional. + * + * <p>See the <a href="http://jmlr.org/proceedings/papers/v28/sutskever13.pdf">paper</a> for details + * of nesterov momentum. */ public class Momentum extends Optimizer { - public static final String MOMENTUM = "momentum"; + public static final float LEARNING_RATE_DEFAULT = 0.01F; + public static final float MOMENTUM_DEFAULT = 0.0F; + public static final boolean NESTEROV_DEFAULT = false; - private float learningRate; - private Tensor<TFloat32> learningRateTensor; - private final Placeholder<TFloat32> learningRatePlaceholder; - private Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict; + public static final String MOMENTUM = "momentum"; private final float momentum; private final boolean useNesterov; + /** + * Creates a Momentum Optimizer + * + * @param graph the TensorFlow graph + */ + public Momentum(Graph graph) { + this(graph, LEARNING_RATE_DEFAULT, MOMENTUM_DEFAULT, NESTEROV_DEFAULT); + } + + /** + * Creates a Momentum Optimizer + * + * @param graph the TensorFlow graph + * @param learningRate the learning rate + */ + public Momentum(Graph graph, float learningRate) { + this(graph, learningRate, MOMENTUM_DEFAULT, NESTEROV_DEFAULT); + } + + /** + * Creates a Momentum Optimizer + * + * @param graph the TensorFlow graph + * @param learningRate the learning rate + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. Default is 0. + */ + public Momentum(Graph graph, float learningRate, float momentum) { + this(graph, learningRate, momentum, NESTEROV_DEFAULT); + } + + /** + * Creates a Momentum Optimizer + * + * @param graph the TensorFlow graph + * @param learningRate the learning rate + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. Default is 0. + * @param useNesterov Whether to apply Nesterov momentum. Defaults to false. + */ public Momentum(Graph graph, float learningRate, float momentum, boolean useNesterov) { - super(graph); - this.learningRate = learningRate; - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.learningRatePlaceholder = - tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + super(graph, learningRate); this.momentum = momentum; this.useNesterov = useNesterov; } - public Momentum(Graph graph, String name, float learningRate, float momentum, boolean useNesterov) { - super(graph, name); - this.learningRate = learningRate; - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.learningRatePlaceholder = - tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + /** + * Creates a Momentum Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer + * @param learningRate the learning rate + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. Default is 0. + * @param useNesterov Whether to apply Nesterov momentum. Defaults to false. + */ + public Momentum( + Graph graph, String name, float learningRate, float momentum, boolean useNesterov) { + super(graph, name, learningRate); this.momentum = momentum; this.useNesterov = useNesterov; } + /** {@inheritDoc} */ @Override protected void createSlots(List<Output<? extends TType>> variables) { for (Output<? extends TType> v : variables) { @@ -79,64 +114,46 @@ protected void createSlots(List<Output<? extends TType>> variables) { } } + /** + * Creates a slot for the momentum variable + * + * @param v the variable + * @param <T> the data type of the variable + */ private <T extends TType> void createMomentumSlot(Output<T> v) { - Operand<T> initializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + Operand<T> initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); createSlot(v.asOutput(), MOMENTUM, initializer); } + /** {@inheritDoc} */ @Override protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) { Variable<T> slot = getSlot(variable, MOMENTUM).get(); - return tf.train - .applyMomentum(variable, slot, tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), - gradient, - tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), - ApplyMomentum.useNesterov(useNesterov)); + return tf.train.applyMomentum( + variable, + slot, + tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), + gradient, + tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), + ApplyMomentum.useNesterov(useNesterov)); } + /** {@inheritDoc} */ @Override public String toString() { - return "Momentum{" + - "learningRate=" + learningRate + - ", momentum=" + momentum + - ", useNesterov=" + useNesterov + - '}'; + return "Momentum{" + + "learningRate=" + + learningRate + + ", momentum=" + + momentum + + ", useNesterov=" + + useNesterov + + '}'; } + /** {@inheritDoc} */ @Override public String getOptimizerName() { return "Momentum"; } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public final void setLearningRate(float learningRate) { - this.learningRate = learningRate; - if (this.learningRateTensor != null) { - this.learningRateTensor.close(); - } - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); - } - - /** {@inheritDoc} */ - public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict() { - return this.feedDict; - } - - /** {@inheritDoc} */ - @Override - public void close() throws Exception { - if (this.learningRateTensor != null) { - this.learningRateTensor.close(); - this.learningRateTensor = null; - } - } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java new file mode 100644 index 00000000000..d0228eb8b3a --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -0,0 +1,295 @@ +package org.tensorflow.framework.optimizers; + +import org.tensorflow.DataType; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +import java.util.List; +import java.util.Optional; + +/** + * Nadam Optimizer that implements the NAdam algorithm. + * + * <p>Much like Adam is essentially RMSprop with momentum, Nadam is Adam with Nesterov momentum. + * + * @see <a href="http://cs229.stanford.edu/proj2015/054_report.pdf">Dozat, 2015</a> + */ +public class Nadam extends Optimizer { + + private static final float DECAY_BASE = 0.96f; + private static final float DECAY = 0.004f; + public static final float LEARNING_RATE_DEFAULT = 0.001f; + public static final float EPSILON_DEFAULT = 1e-8f; + public static final float BETA_ONE_DEFAULT = 0.9f; + public static final float BETA_TWO_DEFAULT = 0.999f; + public static final String FIRST_MOMENT = "m"; + public static final String SECOND_MOMENT = "v"; + public static final String MOMENTUM = "momentum"; + + /** The exponential decay rate for the 1st moment estimates. */ + private final float betaOne; + + /** The exponential decay rate for the exponentially weighted infinity norm. */ + private final float betaTwo; + + /** A small constant for numerical stability. */ + private final float epsilon; + + private Constant<TFloat32> epsilonConst; + private Constant<TFloat32> betaOneConst; + private Constant<TFloat32> betaTwoConst; + + private Variable<TFloat32> betaOnePower; + private Variable<TFloat32> betaTwoPower; + private Variable<TFloat32> momentum; + + private long iterations = 0; + + // private Operand<TFloat32> mT; + private Operand<TFloat32> mT1; + + private Operand<TFloat32> oneMinusBeta1; + private Operand<TFloat32> oneMinusBeta2; + private Operand<TFloat32> oneMinusMT; + private Operand<TFloat32> oneMinusMScheduleNew; + private Operand<TFloat32> oneMinusMScheduleNext; + private Operand<TFloat32> vTPrimeDenominator; + + /** + * Creates a Nadam Optimizer + * + * @param graph the TensorFlow graph + */ + public Nadam(Graph graph) { + this(graph, LEARNING_RATE_DEFAULT, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates a Nadam Optimizer + * + * @param graph the TensorFlow graph + * @param learningRate the learning rate, defaults to 0.001 + */ + public Nadam(Graph graph, float learningRate) { + this(graph, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates a Nadam Optimizer + * + * @param graph the TensorFlow graph + * @param learningRate the learning rate, defaults to 0.001 + * @param betaOne The exponential decay rate for the 1st moment estimates. Default is 0.9. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. Default + * is 0.999. + * @param epsilon A small constant for numerical stability. Default is 1e-8. + */ + public Nadam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { + super(graph, learningRate); + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + + /** + * Creates a Nadam Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer, defaults to "Nadam" + * @param learningRate the learning rate, defaults to 0.001 + */ + public Nadam(Graph graph, String name, float learningRate) { + this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates a Nadam Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer, defaults to "Nadam" + * @param learningRate the learning rate, defaults to 0.001 + * @param betaOne The exponential decay rate for the 1st moment estimates. Default is 0.9. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. Default + * is 0.999. + * @param epsilon A small constant for numerical stability. Default is 1e-8. + */ + public Nadam( + Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { + super(graph, name, learningRate); + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + + /** {@inheritDoc} */ + @Override + protected void createSlots(List<Output<? extends TType>> variables) { + for (Output<? extends TType> v : variables) { + createNadamSlot(v.asOutput()); + } + betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.DTYPE); + Assign<TFloat32> betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne)); + ((Graph) tf.scope().env()).addInitializer(betaOnePowerInit); + + betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.DTYPE); + Assign<TFloat32> betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo)); + ((Graph) tf.scope().env()).addInitializer(betaTwoPowerInit); + + momentum = tf.withName("momentum").variable(Shape.scalar(), TFloat32.DTYPE); + Assign<TFloat32> momentumInit = tf.assign(momentum, tf.constant(1.0F)); + ((Graph) tf.scope().env()).addInitializer(momentumInit); + } + + /** + * Creates slots for first and second moments and momentum + * + * @param v the variable + * @param <T> the data type or the Variable + */ + private <T extends TType> void createNadamSlot(Output<T> v) { + Operand<T> firstMomentInitializer = + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); + Operand<T> secondMomentInitializer = + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); + + Operand<T> momentumInitializer = + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.dataType())); + createSlot(v.asOutput(), MOMENTUM, momentumInitializer); + } + + /** {@inheritDoc} */ + @Override + protected Optional<Op> prepare(String scopeName) { + Constant<TFloat32> one = tf.constant(1.0F); + Constant<TFloat32> point5 = tf.constant(0.5F); + + betaOneConst = tf.constant(betaOne); + betaTwoConst = tf.constant(betaTwo); + Constant<TInt64> localStepConst = tf.constant(this.iterations + 1); + Constant<TInt64> nextStepConst = tf.constant(this.iterations + 2); + Constant<TFloat32> decayConst = tf.constant(DECAY); + Constant<TFloat32> decayBaseConst = tf.constant(DECAY_BASE); + epsilonConst = tf.constant(this.epsilon); + + Operand<TFloat32> mT = + tf.math.mul( + betaOneConst, + tf.math.sub( + one, + tf.math.mul( + point5, + tf.math.pow( + decayBaseConst, + tf.math.mul(decayConst, tf.dtypes.cast(localStepConst, TFloat32.DTYPE)))))); + + mT1 = + tf.math.mul( + betaOneConst, + tf.math.sub( + one, + tf.math.mul( + point5, + tf.math.pow( + decayBaseConst, + tf.math.mul(decayConst, tf.dtypes.cast(nextStepConst, TFloat32.DTYPE)))))); + + Operand<TFloat32> mScheduleNew = tf.math.mul(momentum, mT); + + mScheduleNew = tf.assign(momentum, mScheduleNew, Assign.useLocking(true)); + Operand<TFloat32> mScheduleNext = tf.math.mul(mScheduleNew, mT1); + + oneMinusBeta1 = tf.math.sub(one, betaOneConst); + oneMinusBeta2 = tf.math.sub(one, betaTwoConst); + oneMinusMT = tf.math.sub(one, mT); + oneMinusMScheduleNew = tf.math.sub(one, mScheduleNew); + oneMinusMScheduleNext = tf.math.sub(one, mScheduleNext); + vTPrimeDenominator = + tf.math.sub(one, tf.math.pow(betaTwoConst, tf.dtypes.cast(localStepConst, TFloat32.DTYPE))); + return Optional.empty(); + } + + /** {@inheritDoc} */ + @Override + protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) { + DataType<T> dType = gradient.dataType(); + Variable<T> m = getSlot(variable, FIRST_MOMENT).get(); // first Moment + Variable<T> v = getSlot(variable, SECOND_MOMENT).get(); // Second Moment + + // gPrime = grad / coefficients['oneMinusMScheduleNew'] + Operand<T> gPrime = tf.math.div(gradient, tf.dtypes.cast(oneMinusMScheduleNew, dType)); + // mT = (coefficients['beta_1_t'] * m + coefficients['one_minus_beta_1_t'] * grad) + Operand<T> mT = + tf.math.add( + tf.math.mul(tf.dtypes.cast(betaOneConst, dType), m), + tf.math.mul(tf.dtypes.cast(oneMinusBeta1, dType), gradient)); + // mT = state_ops.assign(m, mT, use_locking=self._use_locking) + // update m + mT = tf.assign(m, mT, Assign.useLocking(true)); + + // mTPrime = mT / coefficients['oneMinusMScheduleNext'] + Operand<T> mTPrime = tf.math.div(mT, tf.dtypes.cast(oneMinusMScheduleNext, dType)); + + // vT = (coefficients['beta_2_t'] * v + coefficients['one_minus_beta_2_t'] * + // math_ops.square(grad)) + Operand<T> vT = + tf.math.add( + tf.math.mul(tf.dtypes.cast(betaTwoConst, dType), v), + tf.math.mul(tf.dtypes.cast(oneMinusBeta2, dType), tf.math.square(gradient))); + // vT = state_ops.assign(v, vT, use_locking=self._use_locking) + // update v + vT = tf.assign(v, vT, Assign.useLocking(true)); + + // vTPrime = vT / coefficients['vTPrimeDenominator'] + Operand<T> vTPrime = tf.math.div(vT, tf.dtypes.cast(vTPrimeDenominator, dType)); + + // m_t_bar = (coefficients['oneMinusMT'] * gPrime + coefficients['mT1'] * mTPrime) + Operand<T> m_t_bar = + tf.math.add( + tf.math.mul(tf.dtypes.cast(oneMinusMT, dType), gPrime), + tf.math.mul(tf.dtypes.cast(mT1, dType), mTPrime)); + // varT = var - coefficients['lr_t'] * m_t_bar / (math_ops.sqrt(vTPrime) + + // coefficients['epsilon']) + Operand<T> varT = + tf.math.sub( + variable, + tf.math.div( + tf.math.mul(tf.dtypes.cast(getLearningRateOperand(), dType), m_t_bar), + tf.math.add(tf.math.sqrt(vTPrime), tf.dtypes.cast(epsilonConst, dType)))); + + return tf.assign(variable, varT, Assign.useLocking(true)); + } + + /** + * Gathers up the update operations into a single op that can be used as a run target. + * + * <p>Adds the betaOne, betaTwo and mu updates to the end of the updates list. + * + * @param updateOperations The update operations. + * @param name The name of the run target. + * @return A NoOp with a control dependency on each update operation. + */ + @Override + protected Op finish(List<Op> updateOperations, String name) { + iterations++; // increment the step; + updateOperations.add(tf.assign(betaOnePower, tf.math.mul(betaOnePower, betaOneConst))); + updateOperations.add(tf.assign(betaTwoPower, tf.math.mul(betaTwoPower, betaTwoConst))); + return super.finish(updateOperations, name); + } + + /** {@inheritDoc} */ + @Override + public String getOptimizerName() { + return "Nadam"; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index def464a86ca..8e0471dc0ba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -15,50 +15,47 @@ */ package org.tensorflow.framework.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; - import org.tensorflow.*; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.NoOp; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -/** - * Base class for gradient optimizers. - */ -public abstract class Optimizer implements AutoCloseable { +import java.util.*; +import java.util.stream.Collectors; + +/** Base class for gradient optimizers. */ +public abstract class Optimizer implements AutoCloseable { public static final String LEARNING_RATE = "learning_rate"; public static final String VARIABLE_V2 = "VariableV2"; - /** - * Global state variables - */ - //TODO make this be used. + public static final float LEARNING_RATE_DEFAULT = 0.001f; + + /** Global state variables */ + // TODO make this be used. protected final List<Variable<?>> globals; - /** - * The Graph this optimizer is operating on. - */ + /** The Graph this optimizer is operating on. */ protected final Graph graph; - /** - * The ops builder for the graph. - */ + /** The ops builder for the graph. */ protected final Ops tf; - /** - * Top level map key is the variable name, lower level map key is the slot name. - */ + /** Top level map key is the variable name, lower level map key is the slot name. */ private final Map<String, Map<String, Variable<?>>> slots; + protected float learningRate; + protected Placeholder<TFloat32> learningRatePlaceholder = null; + private Tensor<TFloat32> learningRateTensor; + private Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap = null; + /** * Builds an optimizer for the supplied graph. - * <p> - * Uses the name from {@link Optimizer#getOptimizerName()} to name the operations. + * + * <p>Uses the name from {@link Optimizer#getOptimizerName()} to name the operations. + * * @param graph The graph to optimize. */ protected Optimizer(Graph graph) { @@ -66,10 +63,28 @@ protected Optimizer(Graph graph) { this.tf = Ops.create(graph).withName(getOptimizerName()); this.slots = new HashMap<>(); this.globals = new ArrayList<>(); + setLearningRate(LEARNING_RATE_DEFAULT); } /** * Builds an optimizer for the supplied graph. + * + * <p>Uses the name from {@link Optimizer#getOptimizerName()} to name the operations. + * + * @param graph The graph to optimize. + * @param learningRate the learning rate. + */ + protected Optimizer(Graph graph, float learningRate) { + this.graph = graph; + this.tf = Ops.create(graph).withName(getOptimizerName()); + this.slots = new HashMap<>(); + this.globals = new ArrayList<>(); + setLearningRate(learningRate); + } + + /** + * Builds an optimizer for the supplied graph. + * * @param graph The graph to optimize. * @param name The base name for the operations. */ @@ -78,6 +93,22 @@ protected Optimizer(Graph graph, String name) { this.tf = Ops.create(graph).withName(name); this.slots = new HashMap<>(); this.globals = new ArrayList<>(); + setLearningRate(LEARNING_RATE_DEFAULT); + } + + /** + * Builds an optimizer for the supplied graph. + * + * @param graph The graph to optimize. + * @param name The base name for the operations. + * @param learningRate the learning rate. + */ + protected Optimizer(Graph graph, String name, float learningRate) { + this.graph = graph; + this.tf = Ops.create(graph).withName(name); + this.slots = new HashMap<>(); + this.globals = new ArrayList<>(); + setLearningRate(learningRate); } public static String createName(Output<? extends TType> variable, String slotName) { @@ -96,11 +127,14 @@ public Op minimize(Operand<?> loss, String name) { public <T extends TType> List<GradAndVar<?>> computeGradients(Operand<?> loss) { List<Operation> variables = new ArrayList<>(); - graph.operations().forEachRemaining((Operation op) -> { - if (op.type().equals(VARIABLE_V2)) { - variables.add(op); - } - }); + graph + .operations() + .forEachRemaining( + (Operation op) -> { + if (op.type().equals(VARIABLE_V2)) { + variables.add(op); + } + }); Output<?>[] variableOutputArray = new Output[variables.size()]; for (int i = 0; i < variables.size(); i++) { @@ -123,8 +157,8 @@ public <T extends TType> List<GradAndVar<?>> computeGradients(Operand<?> loss) { } public Op applyGradients(List<GradAndVar<? extends TType>> gradsAndVars, String name) { - List<Output<? extends TType>> variables = gradsAndVars.stream().map(GradAndVar::getVariable) - .collect(Collectors.toList()); + List<Output<? extends TType>> variables = + gradsAndVars.stream().map(GradAndVar::getVariable).collect(Collectors.toList()); createSlots(variables); @@ -142,7 +176,7 @@ public Op applyGradients(List<GradAndVar<? extends TType>> gradsAndVars, String /** * Gets the slot associated with the specified variable and slot name. * - * @param var The variable to lookup. + * @param var The variable to lookup. * @param slotName The slot name. * @return The slot or {@link Optional#empty}. */ @@ -153,7 +187,7 @@ public <T extends TType> Optional<Variable<T>> getSlot(Output<T> var, String slo /** * Gets the slot associated with the specified variable and slot name. * - * @param varName The variable to lookup. + * @param varName The variable to lookup. * @param slotName The slot name. * @return The slot or {@link Optional#empty}. */ @@ -163,7 +197,7 @@ private <T extends TType> Optional<Variable<T>> getSlot(String varName, String s Variable<? extends TType> slot = variables.get(varName); if (slot != null) { @SuppressWarnings("unchecked") // This method should only be called when the type is known. - Optional<Variable<T>> opt = Optional.of((Variable<T>) slot); + Optional<Variable<T>> opt = Optional.of((Variable<T>) slot); return opt; } return Optional.empty(); @@ -175,20 +209,20 @@ private <T extends TType> Optional<Variable<T>> getSlot(String varName, String s * Creates a slot in the graph for the specified variable with the specified name. Adds the slot's * initializer to the graph's initializers, and the slot to the Optimizer's slot map. * - * @param variable The variable to create the slot for. - * @param slotName The name of the slot. + * @param variable The variable to create the slot for. + * @param slotName The name of the slot. * @param initializer The initializer for the slot. - * @param <T> The type of the variable. + * @param <T> The type of the variable. */ - protected <T extends TType> void createSlot(Output<T> variable, String slotName, - Operand<T> initializer) { - Variable<T> slot = tf.withName(createName(variable, slotName)) - .variable(variable.shape(), variable.dataType()); + protected <T extends TType> void createSlot( + Output<T> variable, String slotName, Operand<T> initializer) { + Variable<T> slot = + tf.withName(createName(variable, slotName)).variable(variable.shape(), variable.dataType()); Assign<T> slotInit = tf.assign(slot, initializer); graph.addInitializer(slotInit); String varName = variable.op().name(); - Map<String, Variable<? extends TType>> variables = slots - .computeIfAbsent(slotName, (k) -> new HashMap<>()); + Map<String, Variable<? extends TType>> variables = + slots.computeIfAbsent(slotName, (k) -> new HashMap<>()); variables.put(varName, slot); } @@ -206,8 +240,7 @@ protected Optional<Op> prepare(String scopeName) { * * @param variables The variables to create slots for. */ - protected void createSlots(List<Output<? extends TType>> variables) { - } + protected void createSlots(List<Output<? extends TType>> variables) {} private <T extends TType> Op applyDense(GradAndVar<T> gradVarPair) { return applyDense(gradVarPair.getGradient(), gradVarPair.getVariable()); @@ -218,7 +251,7 @@ private <T extends TType> Op applyDense(GradAndVar<T> gradVarPair) { * * @param gradient The gradient to use. * @param variable The variable to update. - * @param <T> The type of the variable. + * @param <T> The type of the variable. * @return An operand which applies the desired optimizer update to the variable. */ protected abstract <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable); @@ -227,7 +260,7 @@ private <T extends TType> Op applyDense(GradAndVar<T> gradVarPair) { * Gathers up the update operations into a single op that can be used as a run target. * * @param updateOperations The update operations. - * @param name The name of the run target. + * @param name The name of the run target. * @return A NoOp with a control dependency on each update operation. */ protected Op finish(List<Op> updateOperations, String name) { @@ -238,44 +271,78 @@ protected Op finish(List<Op> updateOperations, String name) { } /** - * Name of the optimizer. + * Gets the Name of the optimizer. * * @return The optimizer name. */ public abstract String getOptimizerName(); /** - * Set the learning rate + * Sets the learning rate + * * @param learningRate the learning rate */ - public abstract void setLearningRate(float learningRate); + public final void setLearningRate(float learningRate) { + if (this.learningRatePlaceholder == null) { + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + } + + if (this.learningRate != learningRate) { + if (this.learningRateTensor != null) this.learningRateTensor.close(); + this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedMap = Collections.singletonMap(this.learningRatePlaceholder, learningRateTensor); + } + } /** - * Get the learning rate + * Gets the learning rate + * * @return the learning rate */ - public abstract float getLearningRate(); + public float getLearningRate() { + return this.learningRate; + } /** - * Get the Feed Dictionary for the run methods to set the Placeholder values(s) + * Gets the learning rate Operand, used by subclasses in their graph operations * - * @return the current Feed Dictionary for the run methods + * @return the learning rate Operand */ - public abstract Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict(); + protected Operand<TFloat32> getLearningRateOperand() { + return this.learningRatePlaceholder; + } /** - * Optional attributes for {@link org.tensorflow.framework.optimizers.Optimizer} + * Gets the Feed Map for the run methods to set the Placeholder value(s). Each entry in the Feed + * Map contains a PlaceHolder and a Tensor with the value + * + * @return the current Feed Map for the run methods, this may be null if an LearningRate as an + * Operand has been set. */ + public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedMap() { + return this.feedMap; + } + + public void close() { + // close the learningRate Tensor if it exists. + if (this.feedMap != null) { + this.feedMap.get(this.learningRatePlaceholder).close(); + } + } + + /** Optional attributes for {@link org.tensorflow.framework.optimizers.Optimizer} */ public static class Options { protected String sharedName; - private Options() { - } + private Options() {} /** * @param sharedName If non-empty, this variable is named in the given bucket with this - * shared_name. Otherwise, the node name is used instead. + * shared_name. Otherwise, the node name is used instead. */ public Optimizer.Options sharedName(String sharedName) { this.sharedName = sharedName; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizers.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizers.java new file mode 100644 index 00000000000..8d7f9620984 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizers.java @@ -0,0 +1,41 @@ +package org.tensorflow.framework.optimizers; + +import org.tensorflow.Graph; + +import java.util.function.Function; + +/** Enumerator used to create a new Optimizer with default parameters. */ +public enum Optimizers { + ADADELTA(AdaDelta::new), + ADAGRAD(AdaGrad::new), + ADAGRAD_DA(AdaGradDA::new), + ADAM(Adam::new), + ADAMAX(Adamax::new), + FTRL(Ftrl::new), + NADAM(Nadam::new), + RMSPROP(RMSProp::new), + MOMENTUM(Momentum::new), + GRADIENT_DESCENT(GradientDescent::new); + + private final Function<Graph, Optimizer> creator; + + /** + * Creates an Optimizers enum + * + * @param creator the lambda function that accepts a Graph argument used to create the default + * Optimizer + */ + Optimizers(Function<Graph, Optimizer> creator) { + this.creator = creator; + } + + /** + * Creates an Optimizer with default settings. + * + * @param graph the TensorFlow Graph + * @return the Optimizer + */ + public Optimizer createOptimizer(Graph graph) { + return creator.apply(graph); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index 3d28c016de7..face906d682 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -15,79 +15,152 @@ */ package org.tensorflow.framework.optimizers; -import java.util.Collections; -import java.util.List; -import java.util.Map; - import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.Tensor; -import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; -import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; +import java.util.List; + /** * Optimizer that implements the RMSProp algorithm. - * <p> - * See the <a href="http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf">lecture - * notes</a> that is inexplicably the canonical reference. + * + * <p>The gist of RMSprop is to: <nl> + * <li>Maintain a moving (discounted) average of the square of gradients + * <li>Divide the gradient by the root of this average </nl> + * + * <p> + * + * <p>This implementation of RMSprop uses plain momentum, not Nesterov momentum. + * + * <p> + * + * <p>The centered version additionally maintains a moving average of the gradients, and uses + * that average to estimate the variance. + * + * <p> + * + * @see <a href="http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf">Hinton G, + * et al. 2012, lecture notes</a> that is inexplicably the canonical reference. */ public class RMSProp extends Optimizer { + public static final float LEARNING_RATE_DEFAULT = 0.001f; + public static final float DECAY_DEFAULT = 0.9f; + public static final float MOMENTUM_DEFAULT = 0.0f; + public static final float EPSILON_DEFAULT = 1e-10f; + public static final boolean CENTERED_DEFAULT = false; public static final String RMS = "rms"; public static final String MG = "mg"; // mean gradient? public static final String MOMENTUM = "momentum"; - private float learningRate; - private Tensor<TFloat32> learningRateTensor; - private final Placeholder<TFloat32> learningRatePlaceholder; - private Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict; private final float decay; private final float momentum; private final float epsilon; private final boolean centered; + /** + * Creates an RMSPRrop Optimizer + * + * @param graph the TensorFlow Graph + */ + public RMSProp(Graph graph) { + this( + graph, + LEARNING_RATE_DEFAULT, + DECAY_DEFAULT, + MOMENTUM_DEFAULT, + EPSILON_DEFAULT, + CENTERED_DEFAULT); + } + + /** + * Creates an RMSPRrop Optimizer + * + * @param graph the TensorFlow Graph + * @param learningRate the learning rate + */ public RMSProp(Graph graph, float learningRate) { - this(graph, learningRate, 0.9f, 0.0f, 1e-10f, false); + this(graph, learningRate, DECAY_DEFAULT, MOMENTUM_DEFAULT, EPSILON_DEFAULT, CENTERED_DEFAULT); } - public RMSProp(Graph graph, float learningRate, float decay, float momentum, float epsilon, + /** + * Creates an RMSPRrop Optimizer + * + * @param graph the TensorFlow Graph + * @param learningRate the learning rate + * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. + * @param momentum the acceleration factor, default is 0. + * @param epsilon A small constant for numerical stability + * @param centered If <code>true</code>, gradients are normalized by the estimated variance of the + * gradient; if <code>false</code>>, by the uncentered second moment. Setting this to <code> + * true</code>> may help with training, but is slightly more expensive in terms of computation + * and memory. Defaults to <code>false</code>. + */ + public RMSProp( + Graph graph, + float learningRate, + float decay, + float momentum, + float epsilon, boolean centered) { - super(graph); - this.learningRate = learningRate; - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.learningRatePlaceholder = - tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); - + super(graph, learningRate); this.decay = decay; this.momentum = momentum; this.epsilon = epsilon; this.centered = centered; } + /** + * Creates an RMSPRrop Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer. Defaults to "RMSProp". + * @param learningRate the learning rate + */ public RMSProp(Graph graph, String name, float learningRate) { - this(graph, name, learningRate, 0.9f, 0.0f, 1e-10f, false); + this( + graph, + name, + learningRate, + DECAY_DEFAULT, + MOMENTUM_DEFAULT, + EPSILON_DEFAULT, + CENTERED_DEFAULT); } - public RMSProp(Graph graph, String name, float learningRate, float decay, float momentum, float epsilon, + /** + * Creates an RMSPRrop Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer. Defaults to "RMSProp". + * @param learningRate the learning rate + * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. + * @param momentum The acceleration factor, default is 0. + * @param epsilon A small constant for numerical stability + * @param centered If <code>true</code>, gradients are normalized by the estimated variance of the + * gradient; if <code>false</code>>, by the uncentered second moment. Setting this to <code> + * true</code>> may help with training, but is slightly more expensive in terms of computation + * and memory. Defaults to <code>false</code>. + */ + public RMSProp( + Graph graph, + String name, + float learningRate, + float decay, + float momentum, + float epsilon, boolean centered) { - super(graph, name); - this.learningRate = learningRate; - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.learningRatePlaceholder = - tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + super(graph, name, learningRate); this.decay = decay; this.momentum = momentum; this.epsilon = epsilon; this.centered = centered; } + /** {@inheritDoc} */ @Override protected void createSlots(List<Output<? extends TType>> variables) { for (Output<? extends TType> v : variables) { @@ -95,85 +168,75 @@ protected void createSlots(List<Output<? extends TType>> variables) { } } + /** + * Creates the RMSProp Slots for Root Mean Squared (RMS), MOMENTUM, and Mean Gradient (MG) + * + * @param v the variable to install in the slot + * @param <T> the datatype of the variable. + */ private <T extends TType> void createRMSPropSlot(Output<T> v) { - Operand<T> rmsInitializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.dataType())); + Operand<T> rmsInitializer = + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.dataType())); createSlot(v.asOutput(), RMS, rmsInitializer); - Operand<T> momentumInitializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + Operand<T> momentumInitializer = + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); if (centered) { - Operand<T> mgInitializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + Operand<T> mgInitializer = + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); createSlot(v.asOutput(), MG, mgInitializer); } } + /** {@inheritDoc} */ @Override protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) { Variable<T> rmsSlot = getSlot(variable, RMS).get(); Variable<T> momentumSlot = getSlot(variable, MOMENTUM).get(); if (centered) { Variable<T> mgSlot = getSlot(variable, MG).get(); - return tf.train.applyCenteredRmsProp(variable, mgSlot, rmsSlot, momentumSlot, - tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), + return tf.train.applyCenteredRmsProp( + variable, + mgSlot, + rmsSlot, + momentumSlot, + tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), tf.dtypes.cast(tf.constant(decay), gradient.dataType()), tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), gradient); } - return tf.train.applyRmsProp(variable, rmsSlot, momentumSlot, - tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), + return tf.train.applyRmsProp( + variable, + rmsSlot, + momentumSlot, + tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), tf.dtypes.cast(tf.constant(decay), gradient.dataType()), tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), gradient); } + /** {@inheritDoc} */ @Override public String toString() { - return "RMSProp{" + - "learningRate=" + learningRate + - ", decay=" + decay + - ", momentum=" + momentum + - ", epsilon=" + epsilon + - ", centered=" + centered + - '}'; + return "RMSProp{" + + "learningRate=" + + learningRate + + ", decay=" + + decay + + ", momentum=" + + momentum + + ", epsilon=" + + epsilon + + ", centered=" + + centered + + '}'; } + /** {@inheritDoc} */ @Override public String getOptimizerName() { return "RMSProp"; } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public final void setLearningRate(float learningRate) { - this.learningRate = learningRate; - if (this.learningRateTensor != null) { - this.learningRateTensor.close(); - } - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); - } - - /** {@inheritDoc} */ - public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict() { - return this.feedDict; - } - - /** {@inheritDoc} */ - @Override - public void close() throws Exception { - if (this.learningRateTensor != null) { - this.learningRateTensor.close(); - this.learningRateTensor = null; - } - } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java new file mode 100644 index 00000000000..43f85fa0ff1 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.optimizers.schedules;; + +/** + * A LearningRateSchedule that uses a piecewise constant decay schedule. + * <p> + * <p>The function computes the piecewise constant + when passed the current optimizer step. This can be useful for changing the + learning rate value across different invocations of optimizer functions. + * <p> + * <p>Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5 + for the next 10000 steps, and 0.1 for any additional steps. + */ +public class PiecewiseConstantDecay implements LearningRateSchedule { + private float[] boundaries; + private float[] values; + + private int lastIndex = 0; + + /** + * Create an PiecewiseConstantDecay + * + * @param boundaries An array of with strictly increasing entries + * @param values An array that specifies the + values for the intervals defined by <code>boundaries</code>. It should have one + more element than <code>boundaries</code>. + * @throws java.lang.IllegalArgumentException if the the length of values does not have 1 more element than boundaries. + */ + public PiecewiseConstantDecay(float[] boundaries, float[] values) { + if(boundaries.length != values.length - 1) { + throw new IllegalArgumentException("The length of boundaries should be 1 less than the length of values"); + } + this.boundaries = boundaries; + this.values = values; + } + + + @Override + public float call(int step) { + if(lastIndex < boundaries.length && step > boundaries[lastIndex]) + lastIndex++; + return values[lastIndex]; + } + +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java new file mode 100644 index 00000000000..0988577c38f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.optimizers.schedules; + +/** + * A LearningRateSchedule that uses a polynomial decay schedule. + * + * <p> + * + * <p>It is commonly observed that a monotonically decreasing learning rate, whose degree of change + * is carefully chosen, results in a better performing model. This schedule applies a polynomial + * decay function to an optimizer step, given a provided `initial_learning_rate`, to reach an + * `end_learning_rate` in the given `decay_steps`. + * + * <p> + * + * <p>The schedule is a 1-arg callable that produces a decayed learning rate when passed the current + * optimizer step. This can be useful for changing the learning rate value across different + * invocations of optimizer functions. It is computed as: + * + * <pre> + * step = min(step, decay_steps) + * ((initialLearningRate - endLearningRate) * + * (1 - step / decaySteps) ^ (power) + * ) + endLearningRate + * </pre> + * + * <p> + * + * <p>If `cycle` is True then a multiple of `decay_steps` is used, the first one that is bigger than + * `step`. + */ +public class PolynomialDecay implements LearningRateSchedule { + private static final float END_LEARNING_RATE_DEFAULT = 0.0001f; + public static final float POWER_DEFAULT = 1.0f; + public static final boolean CYCLE_DEFAULT = false; + + protected final float initialLearningRate; + protected final float decaySteps; + protected final float endLearningRate; + protected final float power; + protected final boolean cycle; + + /** + * Create a PolynomialDecay + * + * @param initialLearningRate The initial learning rate. + * @param decaySteps How often to apply decay. + */ + public PolynomialDecay(float initialLearningRate, int decaySteps) { + this(initialLearningRate, decaySteps, END_LEARNING_RATE_DEFAULT, POWER_DEFAULT, CYCLE_DEFAULT); + } + + /** + * Create a PolynomialDecay + * + * @param initialLearningRate The initial learning rate. + * @param decaySteps How often to apply decay. + * @param cycle Whether or not it should cycle beyond decay_steps. Default is false. + */ + public PolynomialDecay(float initialLearningRate, int decaySteps, boolean cycle) { + this(initialLearningRate, decaySteps, END_LEARNING_RATE_DEFAULT, POWER_DEFAULT, cycle); + } + + /** + * Create a PolynomialDecay + * + * @param initialLearningRate The initial learning rate. + * @param decaySteps How often to apply decay. + * @param endLearningRate The end learning rate. Default is 0.0001. + */ + public PolynomialDecay(float initialLearningRate, int decaySteps, float endLearningRate) { + this(initialLearningRate, decaySteps, endLearningRate, POWER_DEFAULT, CYCLE_DEFAULT); + } + + /** + * Create a PolynomialDecay + * + * @param initialLearningRate The initial learning rate. + * @param decaySteps How often to apply decay. + * @param endLearningRate The end learning rate. Default is 0.0001. + * @param power The power of the polynomial. Defaults to linear, 1.0. + * @param cycle Whether or not it should cycle beyond decay_steps. Default is false. + */ + public PolynomialDecay( + float initialLearningRate, + int decaySteps, + float endLearningRate, + float power, + boolean cycle) { + this.initialLearningRate = initialLearningRate; + this.decaySteps = decaySteps; + this.endLearningRate = endLearningRate; + this.power = power; + this.cycle = cycle; + } + + @Override + public float call(int step) { + + float lDecaySteps = decaySteps; + float lStep = step; + if (cycle) { + float multipler = step == 0 ? 1.0f : (float) Math.ceil(step / decaySteps); + lDecaySteps = decaySteps * multipler; + } else { + lStep = Math.min(lStep, lDecaySteps); + } + + float p = lStep / lDecaySteps; + + float f = (this.initialLearningRate - this.endLearningRate) * (float) Math.pow(1.0f - p, power); + return f + endLearningRate; + } +} diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java similarity index 58% rename from tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java rename to tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java index 1cf20f1b0d2..ce5ad379629 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java @@ -12,11 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. =======================================================================*/ -package org.tensorflow.keras.optimizers; +package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; -import org.tensorflow.framework.optimizers.Optimizer; -import org.tensorflow.keras.utils.TestSession; +import org.tensorflow.Graph; +import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -24,25 +24,20 @@ import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.family.TType; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.tensorflow.framework.optimizers.Momentum.MOMENTUM; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; -import static org.tensorflow.keras.optimizers.SGD.*; /** Test cases for SGD Optimizer */ -public class SGDTest { +public class MomentumTest { - private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; - int index; - - public SGDTest() {} + public MomentumTest() {} @BeforeAll public static void setUpClass() {} @@ -56,29 +51,13 @@ public void setUp() {} @AfterEach public void tearDown() {} - /** Test of create method, of class SGD. */ - @Test - public void testCreate() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - Map<String, Object> config = new HashMap<>(); - config.put(NAME_KEY, "Ftrl"); - config.put(LEARNING_RATE_KEY, 2.0F); - config.put(MOMENTUM_KEY, MOMENTUM_DEFAULT); - config.put(NESTEROV_KEY, NESTEROV_DEFAULT); - SGD expResult = new SGD(tf, 2.0F); - SGD result = SGD.create(tf, config); - assertEquals(expResult.getConfig(), result.getConfig()); - } - } - /** Test of getOptimizerName method, of class SGD. */ @Test public void testGetOptimizerName() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - SGD instance = new SGD(tf); - String expResult = "SGD"; + try (TestSession session = TestSession.createTestSession(tfMode)) { + Graph graph = session.getGraph(); + Momentum instance = new Momentum(graph); + String expResult = "Momentum"; String result = instance.getOptimizerName(); assertEquals(expResult, result); } @@ -86,49 +65,47 @@ public void testGetOptimizerName() { @Test public void testBasic() { - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.1F}; - float[] grads1_init = {0.01F, 0.01F}; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; float learningRate = 3.0F; - float epsilon = 1e-6F; - float epsilon1 = 1e-2F; - - try (TestSession session = TestSession.createTestSession(tf_mode)) { + try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); + Graph graph = session.getGraph(); - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); - Constant<TFloat32> grads0 = tf.constant(grads0_init); - Constant<TFloat32> grads1 = tf.constant(grads1_init); + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - SGD instance = new SGD(tf, learningRate); + Momentum instance = new Momentum(graph, learningRate); Op update = instance.applyGradients(gradsAndVars, "SGDTest"); /* initialize the local variables */ session.run(var0Initializer); session.run(var1Initializer); - /** initialize the accumulators */ + /* initialize the accumulators */ session.run(tf.init()); - /** make sure the variables were initialized properly */ - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); - session.run(update, instance.getFeedDict()); // 1 step + session.run(update, instance.getFeedMap()); // 1 step float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F}; float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F}; @@ -139,37 +116,34 @@ public void testBasic() { @Test public void testMomentum() { - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.1F}; - float[] grads1_init = {0.01F, 0.01F}; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; float learningRate = 2.0F; float momentum = 0.9F; - float epsilon = 1e-6F; - float epsilon1 = 1e-2F; - - try (TestSession session = TestSession.createTestSession(tf_mode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Momentum instance = new Momentum(session.getGraph(), learningRate, momentum)) { Ops tf = session.getTF(); - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); - Constant<TFloat32> grads0 = tf.constant(grads0_init); - Constant<TFloat32> grads1 = tf.constant(grads1_init); + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - SGD instance = new SGD(tf, learningRate, momentum); Op update = instance.applyGradients(gradsAndVars, "SGDTest"); Variable<TFloat32> momentumSlot0 = instance.getSlot(var0.asOutput(), MOMENTUM).get(); @@ -181,14 +155,14 @@ public void testMomentum() { session.run(var0Initializer); session.run(var1Initializer); - /** initialize the accumulators */ + /* initialize the accumulators */ session.run(tf.init()); - /** make sure the variables were initialized properly */ - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); - session.run(update, instance.getFeedDict()); // 1 step + session.run(update, instance.getFeedMap()); // 1 step float[] expectedMomentum0 = {0.1F, 0.1F}; float[] expectedMomentum1 = {0.01F, 0.01F}; @@ -200,57 +174,55 @@ public void testMomentum() { session.evaluate(expectedVar0, var0); session.evaluate(expectedVar1, var1); - session.run(update, instance.getFeedDict()); // step 2 + session.run(update, instance.getFeedMap()); // step 2 - float[] expectedMomentum0_2 = {(0.9f * 0.1f + 0.1f), (0.9f * 0.1f + 0.1f)}; - float[] expectedMomentum1_2 = {(0.9f * 0.01f + 0.01f), (0.9f * 0.01f + 0.01f)}; - session.evaluate(expectedMomentum0_2, momentumSlot0); - session.evaluate(expectedMomentum1_2, momentumSlot1); + float[] expectedMomentum02 = {(0.9f * 0.1f + 0.1f), (0.9f * 0.1f + 0.1f)}; + float[] expectedMomentum12 = {(0.9f * 0.01f + 0.01f), (0.9f * 0.01f + 0.01f)}; + session.evaluate(expectedMomentum02, momentumSlot0); + session.evaluate(expectedMomentum12, momentumSlot1); - float[] expectedVar0_2 = { + float[] expectedVar02 = { 1.0F - (0.1F * 2.0F) - ((0.9F * 0.1F + 0.1F) * 2.0F), 2.0F - (0.1F * 2.0F) - ((0.9F * 0.1F + 0.1F) * 2.0F) }; - float[] expectedVar1_2 = { + float[] expectedVar12 = { 2.98F - ((0.9F * 0.01F + 0.01F) * 2.0F), 3.98F - ((0.9F * 0.01F + 0.01F) * 2.0F) }; - session.evaluate(expectedVar0_2, var0); - session.evaluate(expectedVar1_2, var1); + session.evaluate(expectedVar02, var0); + session.evaluate(expectedVar12, var1); } } @Test public void testWithLearningRateDecay() { int numSteps = 2; - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.1F}; - float[] grads1_init = {0.01F, 0.01F}; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; float learningRate = 3.0F; - float epsilon = 1e-6F; - float epsilon1 = 1e-2F; - try (TestSession session = TestSession.createTestSession(tf_mode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Momentum instance = new Momentum(session.getGraph(), learningRate)) { Ops tf = session.getTF(); - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); - Constant<TFloat32> grads0 = tf.constant(grads0_init); - Constant<TFloat32> grads1 = tf.constant(grads1_init); + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - SGD instance = new SGD(tf, learningRate); - Op update = instance.applyGradients(gradsAndVars, "SGDTest"); + Op update = instance.applyGradients(gradsAndVars, "MomentumTest"); Variable<TFloat32> momentumSlot0 = instance.getSlot(var0.asOutput(), MOMENTUM).get(); assertEquals(momentumSlot0.asOutput().shape(), var0.asOutput().shape()); @@ -261,12 +233,12 @@ public void testWithLearningRateDecay() { session.run(var0Initializer); session.run(var1Initializer); - /** initialize the accumulators */ + // initialize the accumulators session.run(tf.init()); - /** make sure the variables were initialized properly */ - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); + // make sure the variables were initialized properly + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); float[][] expectedVar0 = { {0.7F, 1.7F}, @@ -283,7 +255,9 @@ public void testWithLearningRateDecay() { {2.966667F, 3.966667F} }; for (int step = 0; step < numSteps; step++) { - session.run(update, instance.getFeedDict()); + assertEquals(learningRate, instance.getLearningRate(), 1e-6); + session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(update, instance.getFeedMap()); session.evaluate(expectedVar0[step], var0); session.evaluate(expectedVar1[step], var1); learningRate *= 0.1; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java similarity index 50% rename from tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java rename to tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java index 32d90ea91ed..fcdd1e3ef7c 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java @@ -12,13 +12,12 @@ See the License for the specific language governing permissions and limitations under the License. =======================================================================*/ -package org.tensorflow.keras.optimizers; +package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; import org.tensorflow.Tensor; -import org.tensorflow.framework.optimizers.Optimizer; -import org.tensorflow.keras.utils.ND; -import org.tensorflow.keras.utils.TestSession; +import org.tensorflow.framework.utils.ND; +import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -28,26 +27,21 @@ import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.family.TType; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.keras.optimizers.Adamax.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.Nadam.*; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; /** Test cases for Nadam Optimizer */ public class NadamTest { - private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; private static final int VAR = 0; private static final int M = 1; private static final int V = 2; - int index = 0; float momentum = 1; public NadamTest() {} @@ -64,29 +58,12 @@ public void setUp() {} @AfterEach public void tearDown() {} - /** Test of create method, of class Nadam. */ - @Test - public void testCreate() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - Map<String, Object> config = new HashMap<>(); - config.put(NAME_KEY, "AdaDelta"); - config.put(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); - config.put(BETA_ONE_KEY, BETA_ONE_DEFAULT); - config.put(BETA_TWO_KEY, BETA_TWO_DEFAULT); - config.put(EPSILON_KEY, EPSILON_DEFAULT); - AdaDelta expResult = new AdaDelta(tf); - AdaDelta result = AdaDelta.create(tf, config); - assertEquals(expResult.getConfig(), result.getConfig()); - } - } - /** Test of getOptimizerName method, of class Nadam. */ @Test public void testGetOptimizerName() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - Nadam instance = new Nadam(tf); + try (TestSession session = TestSession.createTestSession(tfMode); + Nadam instance = new Nadam(session.getGraph())) { + String expResult = "Nadam"; String result = instance.getOptimizerName(); assertEquals(expResult, result); @@ -99,10 +76,10 @@ public void testBasic() { int numSteps = 3; - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.1F}; - float[] grads1_init = {0.01F, 0.01F}; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; float[] zeros = {0.0F, 0.0F}; float[] ones = {1.0F, 1.0F}; @@ -111,63 +88,64 @@ public void testBasic() { FloatNdArray m1 = NdArrays.vectorOf(zeros); FloatNdArray v1 = NdArrays.vectorOf(zeros); FloatNdArray mcache = NdArrays.vectorOf(ones); - FloatNdArray var0_np = NdArrays.vectorOf(var0_init); - FloatNdArray var1_np = NdArrays.vectorOf(var1_init); - FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); - FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); - float epsilon = 1e-6f; float epsilon1 = 1e-3F; - try (TestSession session = TestSession.createTestSession(tf_mode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Nadam instance = new Nadam(session.getGraph())) { Ops tf = session.getTF(); - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); - Constant<TFloat32> grads0 = tf.constant(grads0_init); - Constant<TFloat32> grads1 = tf.constant(grads1_init); + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); - Nadam instance = new Nadam(tf); /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); Op update = instance.applyGradients(gradsAndVars, "AdamTest"); /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") Variable<TFloat32>[] secondMomentSlots = new Variable[2]; - firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.FIRST_MOMENT).get(); assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); - secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.SECOND_MOMENT).get(); assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); - firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.FIRST_MOMENT).get(); assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); - secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.SECOND_MOMENT).get(); assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); /* initialize the local variables */ session.run(var0Initializer); session.run(var1Initializer); - /** initialize the accumulators */ + /* initialize the accumulators */ session.run(tf.init()); session.setEpsilon(epsilon1); - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); try (Tensor<TFloat32> result = session @@ -177,19 +155,13 @@ public void testBasic() { .run() .get(0) .expect(TFloat32.DTYPE)) { - result - .data() - .scalars() - .forEach( - f -> { - assertEquals(1F, f.getFloat(), epsilon1); - }); + result.data().scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); } momentum = 1F; for (int step = 0; step < numSteps; step++) { - session.run(update, instance.getFeedDict()); + session.run(update, instance.getFeedMap()); float mut = Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); @@ -203,22 +175,16 @@ public void testBasic() { .run() .get(0) .expect(TFloat32.DTYPE)) { - result - .data() - .scalars() - .forEach( - f -> { - assertEquals(momentum, f.getFloat(), epsilon1); - }); + result.data().scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); } mcache = ND.mul(mcache, momentum); - FloatNdArray[] resultsNP = nadam_update_numpy(var0_np, grads0_np, step, m0, v0, mcache); - var0_np = resultsNP[VAR]; + FloatNdArray[] resultsNP = nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache); + var0Np = resultsNP[VAR]; m0 = resultsNP[M]; v0 = resultsNP[V]; - resultsNP = nadam_update_numpy(var1_np, grads1_np, step, m1, v1, mcache); - var1_np = resultsNP[VAR]; + resultsNP = nadamUpdateNdArray(var1Np, grads1Np, step, m1, v1, mcache); + var1Np = resultsNP[VAR]; m1 = resultsNP[M]; v1 = resultsNP[V]; @@ -231,8 +197,8 @@ public void testBasic() { session.evaluate(v1, secondMomentSlots[1]); // evaluate var0 and var1 - session.evaluate(var0_np, var0); - session.evaluate(var1_np, var1); + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); } } } @@ -241,10 +207,10 @@ public void testBasic() { public void testWithLearningRateDecay() { int numSteps = 3; - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.1F}; - float[] grads1_init = {0.01F, 0.01F}; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; float[] zeros = {0.0F, 0.0F}; float[] ones = {1.0F, 1.0F}; @@ -253,117 +219,115 @@ public void testWithLearningRateDecay() { FloatNdArray m1 = NdArrays.vectorOf(zeros); FloatNdArray v1 = NdArrays.vectorOf(zeros); FloatNdArray mcache = NdArrays.vectorOf(ones); - FloatNdArray var0_np = NdArrays.vectorOf(var0_init); - FloatNdArray var1_np = NdArrays.vectorOf(var1_init); - FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); - FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); - float epsilon = 1e-6f; float epsilon1 = 1e-3F; float learningRate = 0.001F; - try (TestSession session = TestSession.createTestSession(tf_mode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Nadam instance = new Nadam(session.getGraph(), learningRate)) { Ops tf = session.getTF(); - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); - Constant<TFloat32> grads0 = tf.constant(grads0_init); - Constant<TFloat32> grads1 = tf.constant(grads1_init); - Nadam instance = new Nadam(tf, learningRate); /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); Op update = instance.applyGradients(gradsAndVars, "AdamTest"); /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") Variable<TFloat32>[] secondMomentSlots = new Variable[2]; - firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.FIRST_MOMENT).get(); assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); - secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.SECOND_MOMENT).get(); assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); - firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.FIRST_MOMENT).get(); assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); - secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.SECOND_MOMENT).get(); assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); /* initialize the local variables */ session.run(var0Initializer); session.run(var1Initializer); - /** initialize the accumulators */ + // initialize the accumulators session.run(tf.init()); session.setEpsilon(epsilon1); - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); try (Tensor<TFloat32> result = - session - .getGraphSession() - .runner() - .fetch("momentum") - .run() - .get(0) - .expect(TFloat32.DTYPE)) { + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { result - .data() - .scalars() - .forEach( - f -> { - assertEquals(1F, f.getFloat(), epsilon1); - }); + .data() + .scalars() + .forEach( + f -> assertEquals(1F, f.getFloat(), epsilon1)); } momentum = 1F; for (int step = 0; step < numSteps; step++) { - - session.run(update, instance.getFeedDict()); + assertEquals(learningRate, instance.getLearningRate(), 1e-6f); + session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(update, instance.getFeedMap()); float mut = - Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); + Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); momentum = momentum * mut; try (Tensor<TFloat32> result = - session - .getGraphSession() - .runner() - .fetch("momentum") - .run() - .get(0) - .expect(TFloat32.DTYPE)) { + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { result - .data() - .scalars() - .forEach( - f -> { - assertEquals(momentum, f.getFloat(), epsilon1); - }); + .data() + .scalars() + .forEach( + f -> assertEquals(momentum, f.getFloat(), epsilon1)); } mcache = ND.mul(mcache, momentum); - FloatNdArray[] resultsNP = - nadam_update_numpy(var0_np, grads0_np, step, m0, v0, mcache, learningRate); - var0_np = resultsNP[VAR]; + FloatNdArray[] resultsNP = nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache, learningRate); + var0Np = resultsNP[VAR]; m0 = resultsNP[M]; v0 = resultsNP[V]; - resultsNP = nadam_update_numpy(var1_np, grads1_np, step, m1, v1, mcache, learningRate); - var1_np = resultsNP[VAR]; + resultsNP = nadamUpdateNdArray(var1Np, grads1Np, step, m1, v1, mcache, learningRate); + var1Np = resultsNP[VAR]; m1 = resultsNP[M]; v1 = resultsNP[V]; @@ -376,8 +340,8 @@ public void testWithLearningRateDecay() { session.evaluate(v1, secondMomentSlots[1]); // evaluate var0 and var1 - session.evaluate(var0_np, var0); - session.evaluate(var1_np, var1); + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); learningRate *= 0.9; instance.setLearningRate(learningRate); @@ -385,50 +349,45 @@ public void testWithLearningRateDecay() { } } - private FloatNdArray update_m_cache(FloatNdArray mcache, int t) { - float mu_t = 0.9F * (1.0F - 0.5F * (float) Math.pow(0.96, (0.004 * (t + 1)))); - return ND.mul(mu_t, mcache); - } - private FloatNdArray[] nadam_update_numpy( - FloatNdArray var_np, - FloatNdArray grads_np, - int t, - FloatNdArray m, - FloatNdArray v, - FloatNdArray m_cache) { - return nadam_update_numpy(var_np, grads_np, t, m, v, m_cache, 0.001F); + private FloatNdArray[] nadamUpdateNdArray( + FloatNdArray varNp, + FloatNdArray gradsNp, + int t, + FloatNdArray m, + FloatNdArray v, + FloatNdArray mCache) { + return nadamUpdateNdArray(varNp, gradsNp, t, m, v, mCache, 0.001F); } - - private FloatNdArray[] nadam_update_numpy( - FloatNdArray var_np, - FloatNdArray grads_np, - int t, - FloatNdArray m, - FloatNdArray v, - FloatNdArray m_cache, - float alpha) { + private FloatNdArray[] nadamUpdateNdArray( + FloatNdArray varNp, + FloatNdArray gradsNp, + int t, + FloatNdArray m, + FloatNdArray v, + FloatNdArray mCache, + float alpha) { float beta1 = 0.9F; float beta2 = 0.999F; float epsilon = 1e-8F; - float mu_t = beta1 * (1F - 0.5F * (float) Math.pow(0.96, 0.004 * (t + 1))); - float mu_t_1 = beta1 * (1F - 0.5F * (float) Math.pow(0.96, (0.004 * (t + 2)))); - FloatNdArray m_cache_t_1 = ND.mul(m_cache, mu_t_1); - FloatNdArray g_prime_t = ND.div(grads_np, ND.sub(1.0F, m_cache)); - FloatNdArray m_t = ND.add(ND.mul(beta1, m), ND.mul((1 - beta1), grads_np)); - FloatNdArray v_t = ND.add(ND.mul(beta2, v), ND.mul((1 - beta2), ND.square(grads_np))); - - FloatNdArray m_prime_t = ND.div(m_t, ND.sub(1.F, m_cache_t_1)); - FloatNdArray v_prime_t = ND.div(v_t, 1.F - (float) Math.pow(beta2, t + 1)); - FloatNdArray m_bar_t = ND.add(ND.mul((1 - mu_t), g_prime_t), ND.mul(mu_t_1, m_prime_t)); - FloatNdArray param_t = - ND.sub(var_np, ND.div(ND.mul(alpha, m_bar_t), ND.add(ND.sqrt(v_prime_t), epsilon))); + float muT = beta1 * (1F - 0.5F * (float) Math.pow(0.96, 0.004 * (t + 1))); + float muT1 = beta1 * (1F - 0.5F * (float) Math.pow(0.96, (0.004 * (t + 2)))); + FloatNdArray mCacheT1 = ND.mul(mCache, muT1); + FloatNdArray gPrimeT = ND.div(gradsNp, ND.sub(1.0F, mCache)); + FloatNdArray mT = ND.add(ND.mul(beta1, m), ND.mul((1 - beta1), gradsNp)); + FloatNdArray vT = ND.add(ND.mul(beta2, v), ND.mul((1 - beta2), ND.square(gradsNp))); + + FloatNdArray mPrimeT = ND.div(mT, ND.sub(1.F, mCacheT1)); + FloatNdArray vPrimeT = ND.div(vT, 1.F - (float) Math.pow(beta2, t + 1)); + FloatNdArray mBarT = ND.add(ND.mul((1 - muT), gPrimeT), ND.mul(muT1, mPrimeT)); + FloatNdArray paramT = + ND.sub(varNp, ND.div(ND.mul(alpha, mBarT), ND.add(ND.sqrt(vPrimeT), epsilon))); FloatNdArray[] results = new FloatNdArray[3]; - results[VAR] = param_t; - results[M] = m_t; - results[V] = v_t; + results[VAR] = paramT; + results[M] = mT; + results[V] = vT; return results; } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java new file mode 100644 index 00000000000..a0bf027abab --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java @@ -0,0 +1,134 @@ +package org.tensorflow.framework.optimizers; + +import org.junit.jupiter.api.*; +import org.tensorflow.framework.utils.TestSession; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class OptimizersTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + public OptimizersTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + /** Test ADADELTA enum */ + @Test + public void testADADELTA() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADADELTA.createOptimizer(session.getGraph())) { + String expResult = "Adadelta"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test ADAGRAD enum */ + @Test + public void testADAGRAD() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADAGRAD.createOptimizer(session.getGraph())) { + String expResult = "Adagrad"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test ADAGRAD_DA enum */ + @Test + public void testADAGRAD_DA() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADAGRAD_DA.createOptimizer(session.getGraph())) { + String expResult = "adagrad-da"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test ADAM enum */ + @Test + public void testADAM() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADAM.createOptimizer(session.getGraph())) { + String expResult = "Adam"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test ADAMAX enum */ + @Test + public void testADAMAX() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADAMAX.createOptimizer(session.getGraph())) { + String expResult = "Adamax"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test FTRL enum */ + @Test + public void testFTRL() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.FTRL.createOptimizer(session.getGraph())) { + String expResult = "Ftrl"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test NADAM enum */ + @Test + public void testNADAM() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.NADAM.createOptimizer(session.getGraph())) { + String expResult = "Nadam"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test RMSPROP enum */ + @Test + public void testRMSPROP() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.RMSPROP.createOptimizer(session.getGraph())) { + String expResult = "RMSProp"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test MOMENTUM enum */ + @Test + public void testMOMENTUM() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.MOMENTUM.createOptimizer(session.getGraph())) { + String expResult = "Momentum"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test GRADIENT_DESCENT enum */ + @Test + public void testGRADIENT_DESCENT() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.GRADIENT_DESCENT.createOptimizer(session.getGraph())) { + String expResult = "GradientDescent"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java new file mode 100644 index 00000000000..6d489951c77 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java @@ -0,0 +1,450 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.optimizers; + +import org.junit.jupiter.api.*; +import org.tensorflow.framework.utils.ND; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; +import static org.tensorflow.framework.optimizers.RMSProp.*; + +/** Test cases for RMSProp Optimizer */ +public class RMSPropTest { + final int VAR_T = 0; + final int MG_T = 1; + final int RMS_T = 2; + final int MOM_T = 3; + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + Object[][] testParamValues = { + // learningRate, rho (decay), momentum, epsilon, centered + {0.05F, 0.9F, 0.0F, 1e-3F, true}, + {0.05F, 0.9F, 0.0F, 1e-3F, false}, + {0.1F, 0.9F, 0.0F, 1e-3F, true}, + {0.01F, 0.9F, 0.0F, 1e-5F, true}, + {0.01F, 0.9F, 0.9F, 1e-5F, true} + }; + + public RMSPropTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + @Test + public void testDense() { + + int numSteps = 3; + + for (Object[] testParamValue : testParamValues) { + // learningRate, rho (decay), momentum, epsilon, centered + float learningRate = (float) (float) testParamValue[0]; + float decay = (float) testParamValue[1]; + float momentum = (float) testParamValue[2]; + float epsilon = (float) testParamValue[3]; + boolean centered = (boolean) testParamValue[4]; + try (TestSession session = TestSession.createTestSession(tfMode); + RMSProp instance = + new RMSProp(session.getGraph(), learningRate, decay, momentum, epsilon, centered)) { + Ops tf = session.getTF(); + + session.setEpsilon(1e-2f); + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.2F}; + float[] grads1Init = {0.01F, 0.2F}; + + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + Variable<TFloat32> mg0 = + centered && instance.getSlot(var0.asOutput(), MG).isPresent() + ? instance.getSlot(var0.asOutput(), MG).get() + : null; + Variable<TFloat32> mg1 = + centered && instance.getSlot(var1.asOutput(), MG).isPresent() + ? instance.getSlot(var1.asOutput(), MG).get() + : null; + Variable<TFloat32> mom0 = + momentum > 0.F && instance.getSlot(var0.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var0.asOutput(), MOMENTUM).get() + : null; + Variable<TFloat32> mom1 = + momentum > 0.F && instance.getSlot(var1.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var1.asOutput(), MOMENTUM).get() + : null; + Variable<TFloat32> rms0 = + instance.getSlot(var0.asOutput(), RMS).isPresent() + ? instance.getSlot(var0.asOutput(), RMS).get() + : null; + Variable<TFloat32> rms1 = + instance.getSlot(var1.asOutput(), RMS).isPresent() + ? instance.getSlot(var1.asOutput(), RMS).get() + : null; + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; // temp to match RMSProp + FloatNdArray mg0Np = NdArrays.vectorOf(zeros); + FloatNdArray mg1Np = NdArrays.vectorOf(zeros); + FloatNdArray rms0Np = NdArrays.vectorOf(ones); + FloatNdArray rms1Np = NdArrays.vectorOf(ones); + FloatNdArray mom0Np = NdArrays.vectorOf(zeros); + FloatNdArray mom1Np = NdArrays.vectorOf(zeros); + + for (int i = 0; i < numSteps; i++) { + session.run(update, instance.getFeedMap()); + FloatNdArray[] result0 = + calc( + var0Np, + grads0Np, + mg0Np, + rms0Np, + mom0Np, + learningRate, + decay, + momentum, + epsilon, + centered); + var0Np = result0[VAR_T]; + mg0Np = result0[MG_T]; + rms0Np = result0[RMS_T]; + mom0Np = result0[MOM_T]; + + FloatNdArray[] result1 = + calc( + var1Np, + grads1Np, + mg1Np, + rms1Np, + mom1Np, + learningRate, + decay, + momentum, + epsilon, + centered); + + var1Np = result1[VAR_T]; + mg1Np = result1[MG_T]; + rms1Np = result1[RMS_T]; + mom1Np = result1[MOM_T]; + + if (centered) { + if (mg0 != null) session.evaluate(mg0Np, mg0); + if (mg1 != null) session.evaluate(mg1Np, mg1); + } + + if (mom0 != null) session.evaluate(mom0Np, mom0); + if (mom1 != null) session.evaluate(mom1Np, mom1); + + /* TODO the values returned from rms slot, do not match what I see in the python test */ + if (rms0 != null) session.evaluate(rms0Np, rms0); + else fail("rms0 is null"); + if (rms1 != null) session.evaluate(rms1Np, rms1); + else fail("rms1 is null"); + + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + } + } + } + } + + @Test + public void testWithLearningRateDecay() { + int numSteps = 3; + + for (Object[] testParamValue : testParamValues) { + float learningRate = (float) testParamValue[0]; + float decay = (float) testParamValue[1]; + float momentum = (float) testParamValue[2]; + float epsilon = (float) testParamValue[3]; + boolean centered = (boolean) testParamValue[4]; + + try (TestSession session = TestSession.createTestSession(tfMode); + RMSProp instance = + new RMSProp(session.getGraph(), learningRate, decay, momentum, epsilon, centered)) { + Ops tf = session.getTF(); + session.setEpsilon(1e-2f); + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.2F}; + float[] grads1_init = {0.01F, 0.2F}; + + FloatNdArray var0_np = NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant<TFloat32> grads0 = tf.constant(grads0_init); + Constant<TFloat32> grads1 = tf.constant(grads1_init); + + /* build the GradsAnvVars */ + List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + // initialize the accumulators + session.run(tf.init()); + + // make sure the variables were initialized properly + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + Variable<TFloat32> mg0 = + centered && instance.getSlot(var0.asOutput(), MG).isPresent() + ? instance.getSlot(var0.asOutput(), MG).get() + : null; + Variable<TFloat32> mg1 = + centered && instance.getSlot(var1.asOutput(), MG).isPresent() + ? instance.getSlot(var1.asOutput(), MG).get() + : null; + Variable<TFloat32> mom0 = + momentum > 0.F && instance.getSlot(var0.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var0.asOutput(), MOMENTUM).get() + : null; + Variable<TFloat32> mom1 = + momentum > 0.F && instance.getSlot(var1.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var1.asOutput(), MOMENTUM).get() + : null; + Variable<TFloat32> rms0 = + instance.getSlot(var0.asOutput(), RMS).isPresent() + ? instance.getSlot(var0.asOutput(), RMS).get() + : null; + Variable<TFloat32> rms1 = + instance.getSlot(var1.asOutput(), RMS).isPresent() + ? instance.getSlot(var1.asOutput(), RMS).get() + : null; + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; // temp to match RMSProp + FloatNdArray mg0_np = NdArrays.vectorOf(zeros); + FloatNdArray mg1_np = NdArrays.vectorOf(zeros); + FloatNdArray rms0_np = NdArrays.vectorOf(ones); + FloatNdArray rms1_np = NdArrays.vectorOf(ones); + FloatNdArray mom0_np = NdArrays.vectorOf(zeros); + FloatNdArray mom1_np = NdArrays.vectorOf(zeros); + + for (int i = 0; i < numSteps; i++) { + assertEquals(learningRate, instance.getLearningRate(), epsilon); + session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(update, instance.getFeedMap()); + FloatNdArray[] result0 = + calc( + var0_np, + grads0_np, + mg0_np, + rms0_np, + mom0_np, + learningRate, + decay, + momentum, + epsilon, + centered); + var0_np = result0[VAR_T]; + mg0_np = result0[MG_T]; + rms0_np = result0[RMS_T]; + mom0_np = result0[MOM_T]; + + FloatNdArray[] result1 = + calc( + var1_np, + grads1_np, + mg1_np, + rms1_np, + mom1_np, + learningRate, + decay, + momentum, + epsilon, + centered); + + var1_np = result1[VAR_T]; + mg1_np = result1[MG_T]; + rms1_np = result1[RMS_T]; + mom1_np = result1[MOM_T]; + + if (centered) { + if (mg0 != null) session.evaluate(mg0_np, mg0); + else fail("mg0 is null"); + if (mg1 != null) session.evaluate(mg1_np, mg1); + else fail("mg1 is null"); + } + if (momentum > 0.F) { + if (mom0 != null) session.evaluate(mom0_np, mom0); + if (mom1 != null) session.evaluate(mom1_np, mom1); + } + + /* TODO the values returned from rms slot, do not match what I see in the python test */ + if (rms0 != null) session.evaluate(rms0_np, rms0); + else fail("rms0 is null"); + if (rms1 != null) session.evaluate(rms1_np, rms1); + else fail("rms1 is null"); + + session.evaluate(var0_np, var0); + session.evaluate(var1_np, var1); + + learningRate *= 0.9F; + instance.setLearningRate(learningRate); + } + } + } + } + + FloatNdArray[] calc( + FloatNdArray varNp, + FloatNdArray gradNp, + FloatNdArray mgNp, + FloatNdArray rmsNp, + FloatNdArray mom, + float lr, + float decay, + float momentum, + float epsilon, + boolean centered) { + + FloatNdArray[] result = new FloatNdArray[4]; // varT, mgT, rmsT, momT + result[RMS_T] = calcRMS(rmsNp, gradNp, decay); // RMS + + FloatNdArray denomT; + if (centered) { + result[MG_T] = calcMG(mgNp, gradNp, decay); + // rmsT - mgT * mgT + denomT = ND.sub(result[RMS_T], ND.square(result[MG_T])); + } else { + result[MG_T] = mgNp; + denomT = rmsNp; + } + if (momentum > 0.F) { + // momentum * mom + lr * g / (np.sqrt(denomT + epsilon)) + result[MOM_T] = calcMom(momentum, mom, lr, gradNp, denomT, epsilon); + // varT = var - momT + result[VAR_T] = ND.sub(varNp, result[MOM_T]); + } else { + result[MOM_T] = mom; + result[VAR_T] = calcVar(varNp, gradNp, lr, denomT, epsilon); + } + + return result; + } + + private FloatNdArray calcRMS(FloatNdArray rmsNp, FloatNdArray gradNp, float decay) { + // rms * rho + (1 - rho) * g * g + FloatNdArray rmsRho = ND.mul(rmsNp, decay); + FloatNdArray squareG = ND.square(gradNp); + float oneRHO = 1.0F - decay; + FloatNdArray decayG2 = ND.mul(oneRHO, squareG); + return ND.add(rmsRho, decayG2); + } + + private FloatNdArray calcMG(FloatNdArray mgNp, FloatNdArray gradNp, float decay) { + // mgT = mg * rho + (1 - rho) * g + FloatNdArray mgRho = ND.mul(mgNp, decay); + float oneRHO = 1.0F - decay; + FloatNdArray decayG = ND.mul(oneRHO, gradNp); + return ND.add(mgRho, decayG); + } + + private FloatNdArray calcMom( + float momentum, + FloatNdArray mom, + float lr, + FloatNdArray gradNp, + FloatNdArray denomT, + float epsilon) { + // momentum * mom + lr * g / (np.sqrt(denomT + epsilon)) + FloatNdArray moMo = ND.mul(momentum, mom); + FloatNdArray dividend = ND.mul(lr, gradNp); + FloatNdArray divisor = ND.sqrt(ND.add(denomT, epsilon)); + FloatNdArray quotient = ND.div(dividend, divisor); + return ND.add(moMo, quotient); + } + + private FloatNdArray calcVar( + FloatNdArray varNp, FloatNdArray gradNp, float lr, FloatNdArray denomT, float epsilon) { + // var - lr * g / (np.sqrt(denomT) + epsilon) + FloatNdArray dividend = ND.mul(lr, gradNp); + FloatNdArray divisor = ND.add(ND.sqrt(denomT), epsilon); + FloatNdArray quotient = ND.div(dividend, divisor); + return ND.sub(varNp, quotient); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java new file mode 100644 index 00000000000..dac8caa19a3 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java @@ -0,0 +1,16 @@ +package org.tensorflow.framework.optimizers.schedules; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +class PiecewiseConstantDecayTest { + + public PiecewiseConstantDecayTest() {} + + @Test + public void testDecay() { + + } + +} \ No newline at end of file diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java new file mode 100644 index 00000000000..a28e56ad7cb --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java @@ -0,0 +1,24 @@ +package org.tensorflow.framework.optimizers.schedules; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +class PolynomialDecayTest { + + public PolynomialDecayTest() {} + + @Test + public void testBeginWithCycle() { + float initialLearningRate = 0.1f; + int decaySteps = 10; + float decayRate = 0.96f; + float epsilon = 1e-6f; + PolynomialDecay instance = new PolynomialDecay(initialLearningRate, decaySteps, true); + float expected = initialLearningRate; + float actual = instance.call(0); + assertEquals(expected, actual, epsilon); + + } + +} \ No newline at end of file diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java similarity index 96% rename from tensorflow-keras/src/test/java/org/tensorflow/keras/utils/ND.java rename to tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java index 2855af5af25..0503a41dfc2 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/ND.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java @@ -12,18 +12,20 @@ See the License for the specific language governing permissions and limitations under the License. =======================================================================*/ -package org.tensorflow.keras.utils; +package org.tensorflow.framework.utils; -import java.util.Arrays; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + // TODO used in the Callbacks, this should be a part of NDArray? + /** NDArray math Utilities */ public class ND { @@ -126,7 +128,7 @@ public static FloatNdArray square(FloatNdArray a) { * @return the resulting array from the add operation */ public static FloatNdArray add(FloatNdArray a, FloatNdArray b) { - if(a.shape().size() != b.shape().size()) + if (a.shape().size() != b.shape().size()) throw new IllegalArgumentException("a and b muse have the same number of dimensions"); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); @@ -176,7 +178,7 @@ public static FloatNdArray add(float scalar, FloatNdArray a) { * @return the resulting array from the subtraction operation */ public static FloatNdArray sub(FloatNdArray a, FloatNdArray b) { - if(a.shape().size() != b.shape().size()) + if (a.shape().size() != b.shape().size()) throw new IllegalArgumentException("a and b muse have the same number of dimensions"); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); @@ -232,9 +234,10 @@ public static FloatNdArray sub(float scalar, FloatNdArray a) { * @return the resulting array from the muliply operation */ public static FloatNdArray mul(FloatNdArray a, FloatNdArray b) { - if(!a.shape().equals(b.shape())) - throw new IllegalArgumentException(String.format( - "ValueError: operands do not have same shapes %s %s ", a.shape(), b.shape())); + if (!a.shape().equals(b.shape())) + throw new IllegalArgumentException( + String.format( + "ValueError: operands do not have same shapes %s %s ", a.shape(), b.shape())); boolean sameSize = a.shape().size() == b.shape().size(); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); @@ -289,7 +292,7 @@ public static FloatNdArray mul(float scalar, FloatNdArray a) { * @return the resulting array from the Divide operation */ public static FloatNdArray div(FloatNdArray a, FloatNdArray b) { - if(a.shape().size() != b.shape().size()) + if (a.shape().size() != b.shape().size()) throw new IllegalArgumentException("a and b muse have the same number of dimensions"); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); @@ -309,8 +312,7 @@ public static FloatNdArray div(FloatNdArray a, FloatNdArray b) { * @return the resulting array from the Divide operation */ public static FloatNdArray div(FloatNdArray a, float scalar) { - if(scalar == 0) - throw new IllegalArgumentException("Cannot divide by zero"); + if (scalar == 0) throw new IllegalArgumentException("Cannot divide by zero"); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); a.elements(nDims - 1) @@ -348,7 +350,7 @@ public static FloatNdArray div(float scalar, FloatNdArray a) { * @return the array result of the power operation */ public static FloatNdArray pow(FloatNdArray a, FloatNdArray b) { - if(a.shape().size() != b.shape().size()) + if (a.shape().size() != b.shape().size()) throw new IllegalArgumentException("a and b muse have the same number of dimensions"); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); @@ -444,10 +446,10 @@ public static float min(FloatNdArray a) { * @param a the first array * @param a the second array * @return the resulting array with the maximum values between each element of the arrays. - * @throws java.lang.AssertionError if the two arrays are not the same size. + * @throws AssertionError if the two arrays are not the same size. */ public static FloatNdArray max(FloatNdArray a, FloatNdArray b) { - if(a.shape().size() != b.shape().size()) + if (a.shape().size() != b.shape().size()) throw new IllegalArgumentException("a and b muse have the same number of dimensions"); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); @@ -496,10 +498,10 @@ public static FloatNdArray max(float scalar, FloatNdArray a) { * @param a the first array * @param a the second array * @return the resulting array with the minimum values between each element of the arrays. - * @throws java.lang.AssertionError if the two arrays are not the same size. + * @throws AssertionError if the two arrays are not the same size. */ public static FloatNdArray min(FloatNdArray a, FloatNdArray b) { - if(a.shape().size() != b.shape().size()) + if (a.shape().size() != b.shape().size()) throw new IllegalArgumentException("a and b muse have the same number of dimensions"); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java similarity index 82% rename from tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java rename to tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java index cd4b891a039..47c39e820fc 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. =======================================================================*/ -package org.tensorflow.keras.utils; +package org.tensorflow.framework.utils; import org.tensorflow.*; import org.tensorflow.ndarray.FloatNdArray; @@ -41,7 +41,7 @@ public abstract class TestSession implements AutoCloseable { /** Enumerate between Eager and Graph Mode */ public enum Mode { EAGER, - GRAPH; + GRAPH } public static TestSession createEagerSession() { @@ -56,10 +56,21 @@ public static TestSession createTestSession(Mode mode) { return mode == Mode.EAGER ? createEagerSession() : createGraphSession(); } + /** + * Initializer any graph initializers, if in Graph mode, for Eager mode, this method does nothing. + */ public void initialize() { // empty } + /** + * Returns the Graph if in Graph mode, or null if in EagerMode + * @return the Graph if in Graph mode, or null if in EagerMode + */ + public Graph getGraph() { + return null; + } + /** * Perform session.run() * @@ -67,7 +78,10 @@ public void initialize() { * * @param op The Operation to run */ - public abstract void run(Op op); + public void run(Op op) { + run(op, null); + } + /** * Perform session.run() @@ -75,10 +89,10 @@ public void initialize() { * <p>If in eager mode, this does nothing. * * @param op The Operation to run - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ - public abstract void run(Op op, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + public abstract void run(Op op, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); /** * Evaluate the expected results versus the actual results @@ -96,15 +110,15 @@ public <U extends TNumber> void evaluate(Number expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( Number expected, Operand<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - evaluate(new Number[] {expected}, input, feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate(new Number[] {expected}, input, feedMap); } /** @@ -122,13 +136,12 @@ public void evaluate(Number expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. - * @param <T> the data type for the feedDict entries */ - public <T extends TType> void evaluate( - Number expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - evaluate(new Number[] {expected}, input, feedDict); + public void evaluate( + Number expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate(new Number[] {expected}, input, feedMap); } /** @@ -148,16 +161,16 @@ public <U extends TNumber> void evaluate(Number[] expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type for the input */ public <U extends TNumber> void evaluate( Number[] expected, Op input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { Output<U> output = input.op().output(0); - evaluate(expected, output, feedDict); + evaluate(expected, output, feedMap); } /** @@ -177,16 +190,16 @@ public <U extends TNumber> void evaluate(Number[] expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( Number[] expected, Operand<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - Output output = input.asOutput(); - evaluate(expected, output, feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + Output<U> output = input.asOutput(); + evaluate(expected, output, feedMap); } /** @@ -205,15 +218,15 @@ public <U extends TNumber> void evaluate(byte expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( byte expected, Operand<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - evaluate((double) expected, input, feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate((double) expected, input, feedMap); } /** @@ -232,15 +245,15 @@ public <U extends TNumber> void evaluate(int expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( int expected, Operand<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - evaluate((double) expected, input, feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate((double) expected, input, feedMap); } /** @@ -259,15 +272,15 @@ public <U extends TNumber> void evaluate(long expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( long expected, Operand<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - evaluate((double) expected, input, feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate((double) expected, input, feedMap); } /** @@ -286,15 +299,15 @@ public <U extends TNumber> void evaluate(float expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( float expected, Operand<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - evaluate((double) expected, input, feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate((double) expected, input, feedMap); } /** @@ -313,14 +326,14 @@ public <U extends TNumber> void evaluate(double expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public abstract <U extends TNumber> void evaluate( double expected, Operand<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); /** * Evaluate the expected results versus the actual results @@ -338,19 +351,19 @@ public <U extends TNumber> void evaluate(byte[] expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( byte[] expected, Operand<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { Byte[] iArray = new Byte[expected.length]; for (int i = 0; i < expected.length; i++) { iArray[i] = expected[i]; } - evaluate(iArray, input, feedDict); + evaluate(iArray, input, feedMap); } /** @@ -369,19 +382,19 @@ public <U extends TNumber> void evaluate(int[] expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( int[] expected, Operand<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { Integer[] iArray = new Integer[expected.length]; for (int i = 0; i < expected.length; i++) { iArray[i] = expected[i]; } - evaluate(iArray, input, feedDict); + evaluate(iArray, input, feedMap); } /** @@ -400,19 +413,19 @@ public <U extends TNumber> void evaluate(long[] expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( long[] expected, Operand<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { Long[] iArray = new Long[expected.length]; for (int i = 0; i < expected.length; i++) { iArray[i] = expected[i]; } - evaluate(iArray, input, feedDict); + evaluate(iArray, input, feedMap); } /** @@ -431,19 +444,19 @@ public <U extends TNumber> void evaluate(float[] expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( float[] expected, Operand<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { Float[] iArray = new Float[expected.length]; for (int i = 0; i < expected.length; i++) { iArray[i] = expected[i]; } - evaluate(iArray, input, feedDict); + evaluate(iArray, input, feedMap); } /** @@ -462,19 +475,19 @@ public <U extends TNumber> void evaluate(double[] expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( double[] expected, Operand<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { Double[] iArray = new Double[expected.length]; for (int i = 0; i < expected.length; i++) { iArray[i] = expected[i]; } - evaluate(iArray, input, feedDict); + evaluate(iArray, input, feedMap); } /** @@ -493,14 +506,14 @@ public <U extends TNumber> void evaluate(Number[] expected, Output<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public abstract <U extends TNumber> void evaluate( Number[] expected, Output<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); /** * Evaluate the expected results versus the actual results @@ -517,14 +530,14 @@ public void evaluate(String expected, Operand<TString> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void evaluate( String expected, Operand<TString> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - evaluate(new String[] {expected}, input, feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate(new String[] {expected}, input, feedMap); } /** @@ -542,12 +555,12 @@ public void evaluate(String expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void evaluate( - String expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - evaluate(new String[] {expected}, input, feedDict); + String expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate(new String[] {expected}, input, feedMap); } /** @@ -565,15 +578,15 @@ public void evaluate(String[] expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void evaluate( String[] expected, Op input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - Output output = input.op().output(0); - evaluate(expected, output, feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + Output<TString> output = input.op().output(0); + evaluate(expected, output, feedMap); } /** @@ -583,7 +596,7 @@ public void evaluate( * @param input the actual value */ public void evaluate(String[] expected, Operand<TString> input) { - Output output = input.asOutput(); + Output<TString> output = input.asOutput(); evaluate(expected, output, null); } @@ -592,13 +605,13 @@ public void evaluate(String[] expected, Operand<TString> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public abstract void evaluate( String[] expected, Output<TString> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); /** * Evaluate the expected results versus the actual results @@ -615,14 +628,14 @@ public void evaluate(Boolean expected, Operand<TBool> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void evaluate( Boolean expected, Operand<TBool> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - evaluate(new Boolean[] {expected}, input, feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate(new Boolean[] {expected}, input, feedMap); } /** @@ -640,12 +653,12 @@ public void evaluate(Boolean expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void evaluate( - Boolean expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - evaluate(new Boolean[] {expected}, input, feedDict); + Boolean expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate(new Boolean[] {expected}, input, feedMap); } /** @@ -655,7 +668,7 @@ public void evaluate( * @param input the actual value */ public void evaluate(Boolean[] expected, Op input) { - Output output = input.op().output(0); + Output<TBool> output = input.op().output(0); evaluate(expected, output, null); } @@ -664,15 +677,15 @@ public void evaluate(Boolean[] expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void evaluate( Boolean[] expected, Op input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - Output output = input.op().output(0); - evaluate(expected, output, feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + Output<TBool> output = input.op().output(0); + evaluate(expected, output, feedMap); } /** @@ -682,7 +695,7 @@ public void evaluate( * @param input the actual value */ public void evaluate(Boolean[] expected, Operand<TBool> input) { - Output output = input.asOutput(); + Output<TBool> output = input.asOutput(); evaluate(expected, output, null); } @@ -691,15 +704,15 @@ public void evaluate(Boolean[] expected, Operand<TBool> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void evaluate( Boolean[] expected, Operand<TBool> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - Output output = input.asOutput(); - evaluate(expected, output, feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + Output<TBool> output = input.asOutput(); + evaluate(expected, output, feedMap); } /** @@ -717,13 +730,13 @@ public void evaluate(Boolean[] expected, Output<TBool> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public abstract void evaluate( Boolean[] expected, Output<TBool> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); /** * Evaluate the expected results versus the actual results @@ -741,7 +754,7 @@ public <T extends TType> void evaluate(Operand<T> expected, Output<T> input) { * * @param expected the expected value * @param input the actual value - * @param <T> the data type for the feedDict entries + * @param <T> the data type for the feedMap entries */ public <T extends TType> void evaluate(Operand<T> expected, Operand<T> input) { evaluate(expected.asOutput(), input.asOutput(), null); @@ -752,14 +765,14 @@ public <T extends TType> void evaluate(Operand<T> expected, Operand<T> input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. - * @param <T> the data type for the feedDict entries + * @param <T> the data type for the feedMap entries */ public abstract <T extends TType> void evaluate( Output<T> expected, Output<T> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); /** * Evaluate the expected results versus the actual results @@ -777,15 +790,15 @@ public <U extends TNumber> void evaluate(FloatNdArray expected, Operand<U> input * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( FloatNdArray expected, Operand<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - evaluate(expected, input.asOutput(), feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate(expected, input.asOutput(), feedMap); } /** @@ -804,14 +817,14 @@ public <U extends TNumber> void evaluate(FloatNdArray expected, Output<U> input) * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public abstract <U extends TNumber> void evaluate( FloatNdArray expected, Output<U> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); /** * Evaluate the actual results using a predicate @@ -831,14 +844,14 @@ public <U extends TNumber> void evaluate(Operand<U> input, Predicate<Number> pre * @param input the actual value * @param predicate a predicate that accepts a Number as an argument, if the result of the * predicate is false, then the test will fail - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <U> the data type of the input */ public abstract <U extends TNumber> void evaluate( Output<U> input, Predicate<Number> predicate, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); /** * Evaluate the actual results using a predicate @@ -865,13 +878,13 @@ public <T extends TType> void print(Operand<T> input) { * Print the results to the "standard" output stream. * * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. - * @param <T> the data type for the feedDict entries + * @param <T> the data type for the feedMap entries */ public <T extends TType> void print( - Operand<T> input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - print(new PrintWriter(new OutputStreamWriter(System.out)), input.asOutput(), feedDict); + Operand<T> input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input.asOutput(), feedMap); } /** @@ -887,11 +900,11 @@ public void print(Op input) { * Print the results to the "standard" output stream. * * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ - public void print(Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - print(new PrintWriter(new OutputStreamWriter(System.out)), input.op().output(0), feedDict); + public void print(Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input.op().output(0), feedMap); } /** @@ -908,13 +921,13 @@ public <T extends TType> void print(Output<T> input) { * Print the results to the "standard" output stream. * * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <T> the data type for the input */ public <T extends TType> void print( - Output<T> input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - print(new PrintWriter(new OutputStreamWriter(System.out)), input, feedDict); + Output<T> input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input, feedMap); } /** @@ -933,15 +946,15 @@ public <T extends TType> void print(OutputStream out, Operand<T> input) { * * @param out the output stream * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. - * @param <T> the data type for the feedDict entries + * @param <T> the data type for the feedMap entries */ public <T extends TType> void print( OutputStream out, Operand<T> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput(), feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput(), feedMap); } /** @@ -959,12 +972,12 @@ public void print(OutputStream out, Op input) { * * @param out the output stream * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void print( - OutputStream out, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0), feedDict); + OutputStream out, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0), feedMap); } /** @@ -983,15 +996,15 @@ public <T extends TType> void print(OutputStream out, Output<T> input) { * * @param out the output stream * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <T> the data type for the input */ public <T extends TType> void print( OutputStream out, Output<T> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - print(new PrintWriter(new OutputStreamWriter(out)), input, feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(new OutputStreamWriter(out)), input, feedMap); } /** @@ -1010,15 +1023,15 @@ public <T extends TType> void print(Writer writer, Operand<T> input) { * * @param writer the character stream * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <T> the data type for the input */ public <T extends TType> void print( Writer writer, Operand<T> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - print(new PrintWriter(writer), input.asOutput(), feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(writer), input.asOutput(), feedMap); } /** @@ -1036,12 +1049,12 @@ public void print(Writer writer, Op input) { * * @param writer the character stream * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void print( - Writer writer, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - print(new PrintWriter(writer), input.op().output(0), feedDict); + Writer writer, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(writer), input.op().output(0), feedMap); } /** @@ -1060,15 +1073,15 @@ public <T extends TType> void print(Writer writer, Output<T> input) { * * @param writer the character stream * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <T> the data type for the input */ public <T extends TType> void print( Writer writer, Output<T> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { - print(new PrintWriter(writer), input, feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(writer), input, feedMap); } /** @@ -1087,14 +1100,14 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) { * * @param writer the character stream * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param <T> the data type for the input */ public abstract <T extends TType> void print( PrintWriter writer, Output<T> input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict); + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); /** * Get the TensorFlow Ops for this test session diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java deleted file mode 100644 index f9f796d7738..00000000000 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java +++ /dev/null @@ -1,429 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the ); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.keras.optimizers; - -import org.tensorflow.*; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Op; -import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; -import org.tensorflow.op.core.Constant; -import org.tensorflow.op.core.Placeholder; -import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TType; - -import java.util.*; - -import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; - -/** Nadam Optimizer that implements the NAdam algorithm. */ -public class Nadam extends org.tensorflow.framework.optimizers.Optimizer - implements OptimizerInterface, AutoCloseable { - - public static final String FIRST_MOMENT = "m"; - public static final String SECOND_MOMENT = "v"; - public static final String MOMENTUM = "momentum"; - - public static final String LEARNING_RATE_KEY = "learning_rate"; - public static final String EPSILON_KEY = "epsilon"; - public static final String BETA_ONE_KEY = "beta_1"; - public static final String BETA_TWO_KEY = "beta_2"; - - public static final float LEARNING_RATE_DEFAULT = 0.001F; - public static final float EPSILON_DEFAULT = 1e-07F; - public static final float BETA_ONE_DEFAULT = 0.9F; - public static final float BETA_TWO_DEFAULT = 0.999F; - - private final Map<String, Object> config = new HashMap<>(); - - private float learningRate; - private Tensor<TFloat32> learningRateTensor; - private final Placeholder<TFloat32> learningRatePlaceholder; - private Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict; - private final float betaOne; - private final float betaTwo; - private final float epsilon; - private final float decayBase = 0.96F; - private final float decay = 0.004F; - - private long iterations = 0; - - private Constant<TFloat32> betaOneConst; - private Constant<TFloat32> betaTwoConst; - private Constant<TInt64> localStepConst; - private Constant<TInt64> nextStepConst; - - private Constant<TFloat32> decayBaseConst; - private Constant<TFloat32> decayConst; - private Constant<TFloat32> epsilonConst; - - private Variable<TFloat32> betaOnePower; - private Variable<TFloat32> betaTwoPower; - private Variable<TFloat32> momentum; - - private Operand<TFloat32> m_t; - private Operand<TFloat32> m_t_1; - private Operand<TFloat32> m_schedule_new; - private Operand<TFloat32> m_schedule_next; - private Operand<TFloat32> one_minus_beta_1; - private Operand<TFloat32> one_minus_beta_2; - private Operand<TFloat32> one_minus_m_t; - private Operand<TFloat32> one_minus_m_schedule_new; - private Operand<TFloat32> one_minus_m_schedule_next; - private Operand<TFloat32> v_t_prime_denominator; - - /** - * Create an Optimizer that implements the NAdam algorithm. - * - * @param tf the TensorFlow Ops - */ - public Nadam(Ops tf) { - this(tf, LEARNING_RATE_DEFAULT, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); - } - - /** - * Create an Optimizer that implements the NAdam algorithm. - * - * @param tf the TensorFlow Ops - * @param name name for the operations created when applying gradients. Defaults to "Nadam". - */ - public Nadam(Ops tf, String name) { - this(tf, name, LEARNING_RATE_DEFAULT, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); - } - - /** - * Create an Optimizer that implements the NAdam algorithm. - * - * @param tf the TensorFlow Ops - * @param learningRate The learning rate. - */ - public Nadam(Ops tf, float learningRate) { - this(tf, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); - } - - /** - * Create an Optimizer that implements the NAdam algorithm. - * - * @param tf the TensorFlow Ops - * @param name name for the operations created when applying gradients. Defaults to "Adamax". - * @param learningRate The learning rate. - */ - public Nadam(Ops tf, String name, float learningRate) { - this(tf, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); - } - - /** - * Create an Optimizer that implements the NAdam algorithm. - * - * @param tf the TensorFlow Ops - * @param learningRate The learning rate. - * @param betaOne The exponential decay rate for the 1st moment estimates. - * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. - * @param epsilon A small constant for numerical stability. - */ - public Nadam(Ops tf, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(assertGraph(tf)); - this.learningRate = learningRate; - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.learningRatePlaceholder = - tf.withSubScope(LEARNING_RATE) - .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); - this.betaOne = betaOne; - this.betaTwo = betaTwo; - this.epsilon = epsilon; - initConfig(learningRate, betaOne, betaTwo, epsilon); - } - - /** - * Create an Optimizer that implements the NAdam algorithm. - * - * @param tf the TensorFlow Ops - * @param name name for the operations created when applying gradients. - * @param learningRate The learning rate. - * @param betaOne The exponential decay rate for the 1st moment estimates. - * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. - * @param epsilon A small constant for numerical stability. - */ - public Nadam( - Ops tf, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(assertGraph(tf), name); - this.learningRate = learningRate; - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.learningRatePlaceholder = - tf.withSubScope(LEARNING_RATE) - .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); - this.betaOne = betaOne; - this.betaTwo = betaTwo; - this.epsilon = epsilon; - - initConfig(learningRate, betaOne, betaTwo, epsilon); - } - - /** - * Create an Optimizer that implements the NAdam algorithm. - * - * @param tf the TensorFlow Ops - * @param config a config object to initialize - */ - public static Nadam create(Ops tf, Map<String, Object> config) { - String name = (String) config.get(NAME_KEY); - float learningRate = (float) config.getOrDefault(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); - float epsilon = (float) config.getOrDefault(EPSILON_KEY, EPSILON_DEFAULT); - float betaOne = (float) config.getOrDefault(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); - float betaTwo = (float) config.getOrDefault(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); - if (name == null) { - return new Nadam(tf, learningRate, betaOne, betaTwo, epsilon); - } else { - return new Nadam(tf, name, learningRate, betaOne, betaTwo, epsilon); - } - } - - /** {@inheritDoc} */ - @Override - public Map<String, Object> getConfig() { - return config; - } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public final void setLearningRate(float learningRate) { - this.learningRate = learningRate; - if (this.learningRateTensor != null) { - this.learningRateTensor.close(); - } - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); - } - - /** - * Get the Feed Dictionary for the run methods to set the Placeholder values(s) - * - * @return the current Feed Dictionary for the run methods - */ - public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedDict() { - return this.feedDict; - } - - /** {@inheritDoc} */ - @Override - public void close() throws Exception { - if (this.learningRateTensor != null) { - this.learningRateTensor.close(); - this.learningRateTensor = null; - } - } - - /** {@inheritDoc} */ - @Override - protected void createSlots(List<Output<? extends TType>> variables) { - for (Output<? extends TType> v : variables) { - createNadamSlot(v.asOutput()); - } - betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.DTYPE); - Assign<TFloat32> betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne)); - ((Graph) tf.scope().env()).addInitializer(betaOnePowerInit); - - betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.DTYPE); - Assign<TFloat32> betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo)); - ((Graph) tf.scope().env()).addInitializer(betaTwoPowerInit); - - momentum = tf.withName("momentum").variable(Shape.scalar(), TFloat32.DTYPE); - Assign<TFloat32> momentumInit = tf.assign(momentum, tf.constant(1.0F)); - ((Graph) tf.scope().env()).addInitializer(momentumInit); - } - - /** - * Create slots for first and second momements and momentum - * - * @param v the variable - * @param <T> the data type or the Variable - */ - private <T extends TType> void createNadamSlot(Output<T> v) { - Operand<T> firstMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); - createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); - Operand<T> secondMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); - createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); - - Operand<T> momentumInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.dataType())); - createSlot(v.asOutput(), MOMENTUM, momentumInitializer); - } - - /** {@inheritDoc} */ - @Override - protected Optional<Op> prepare(String scopeName) { - Constant<TFloat32> one = tf.constant(1.0F); - Constant<TFloat32> point5 = tf.constant(0.5F); - - betaOneConst = tf.constant(betaOne); - betaTwoConst = tf.constant(betaTwo); - localStepConst = tf.constant(this.iterations + 1); - nextStepConst = tf.constant(this.iterations + 2); - decayConst = tf.constant(decay); - decayBaseConst = tf.constant(this.decayBase); - epsilonConst = tf.constant(this.epsilon); - - // m_t = beta_1_t * (1. - 0.5 * ( math_ops.pow(decay_base, self._initial_decay * local_step))) - m_t = - tf.math.mul( - betaOneConst, - tf.math.sub( - one, - tf.math.mul( - point5, - tf.math.pow( - decayBaseConst, - tf.math.mul(decayConst, tf.dtypes.cast(localStepConst, TFloat32.DTYPE)))))); - // m_t_1 = beta_1_t * (1. - 0.5 * ( math_ops.pow(decay_base, self._initial_decay * next_step))) - m_t_1 = - tf.math.mul( - betaOneConst, - tf.math.sub( - one, - tf.math.mul( - point5, - tf.math.pow( - decayBaseConst, - tf.math.mul(decayConst, tf.dtypes.cast(nextStepConst, TFloat32.DTYPE)))))); - - // m_schedule_new = math_ops.cast(self._m_cache_read, var_dtype) * m_t - m_schedule_new = tf.math.mul(momentum, m_t); - // if var_dtype is self._m_cache.dtype: - // m_schedule_new = array_ops.identity(state_ops.assign( - // self._m_cache, m_schedule_new, use_locking=self._use_locking)) - m_schedule_new = tf.identity(tf.assign(momentum, m_schedule_new, Assign.useLocking(true))); - // m_schedule_next = m_schedule_new * m_t_1 - m_schedule_next = tf.math.mul(m_schedule_new, m_t_1); - - // 1 - beta_1_t - one_minus_beta_1 = tf.math.sub(one, betaOneConst); - // 1 - beta_2_t, - one_minus_beta_2 = tf.math.sub(one, betaTwoConst); - // 1. - m_t, - one_minus_m_t = tf.math.sub(one, m_t); - // 1. - m_schedule_new - one_minus_m_schedule_new = tf.math.sub(one, m_schedule_new); - // 1. - m_schedule_next - one_minus_m_schedule_next = tf.math.sub(one, m_schedule_next); - // 1. - math_ops.pow(beta_2_t, local_step) - v_t_prime_denominator = - tf.math.sub(one, tf.math.pow(betaTwoConst, tf.dtypes.cast(localStepConst, TFloat32.DTYPE))); - return Optional.empty(); - } - - /** {@inheritDoc} */ - @Override - protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) { - DataType dType = gradient.dataType(); - Variable<T> m = getSlot(variable, FIRST_MOMENT).get(); // first Moment - Variable<T> v = getSlot(variable, SECOND_MOMENT).get(); // Second Moment - - // g_prime = grad / coefficients['one_minus_m_schedule_new'] - Operand<T> g_prime = tf.math.div(gradient, tf.dtypes.cast(one_minus_m_schedule_new, dType)); - // m_t = (coefficients['beta_1_t'] * m + coefficients['one_minus_beta_1_t'] * grad) - Operand<T> m_t = - tf.math.add( - tf.math.mul(tf.dtypes.cast(betaOneConst, dType), m), - tf.math.mul(tf.dtypes.cast(one_minus_beta_1, dType), gradient)); - // m_t = state_ops.assign(m, m_t, use_locking=self._use_locking) - // update m - m_t = tf.assign(m, m_t, Assign.useLocking(true)); - - // m_t_prime = m_t / coefficients['one_minus_m_schedule_next'] - Operand<T> m_t_prime = tf.math.div(m_t, tf.dtypes.cast(one_minus_m_schedule_next, dType)); - - // v_t = (coefficients['beta_2_t'] * v + coefficients['one_minus_beta_2_t'] * - // math_ops.square(grad)) - Operand<T> v_t = - tf.math.add( - tf.math.mul(tf.dtypes.cast(betaTwoConst, dType), v), - tf.math.mul(tf.dtypes.cast(one_minus_beta_2, dType), tf.math.square(gradient))); - // v_t = state_ops.assign(v, v_t, use_locking=self._use_locking) - // update v - v_t = tf.assign(v, v_t, Assign.useLocking(true)); - - // v_t_prime = v_t / coefficients['v_t_prime_denominator'] - Operand<T> v_t_prime = tf.math.div(v_t, tf.dtypes.cast(v_t_prime_denominator, dType)); - - // m_t_bar = (coefficients['one_minus_m_t'] * g_prime + coefficients['m_t_1'] * m_t_prime) - Operand<T> m_t_bar = - tf.math.add( - tf.math.mul(tf.dtypes.cast(one_minus_m_t, dType), g_prime), - tf.math.mul(tf.dtypes.cast(m_t_1, dType), m_t_prime)); - // var_t = var - coefficients['lr_t'] * m_t_bar / (math_ops.sqrt(v_t_prime) + - // coefficients['epsilon']) - Operand<T> var_t = - tf.math.sub( - variable, - tf.math.div( - tf.math.mul(tf.dtypes.cast(this.learningRatePlaceholder, dType), m_t_bar), - tf.math.add(tf.math.sqrt(v_t_prime), tf.dtypes.cast(epsilonConst, dType)))); - // assign(var, var_t, use_locking=self._use_locking) - return tf.assign(variable, var_t, Assign.useLocking(true)); - } - - /** - * Gathers up the update operations into a single op that can be used as a run target. - * - * <p>Adds the betaOne, betaTwo and mu updates to the end of the updates list. - * - * @param updateOperations The update operations. - * @param name The name of the run target. - * @return A NoOp with a control dependency on each update operation. - */ - @Override - protected Op finish(List<Op> updateOperations, String name) { - iterations++; // increment the step; - updateOperations.add(tf.assign(betaOnePower, tf.math.mul(betaOnePower, betaOneConst))); - updateOperations.add(tf.assign(betaTwoPower, tf.math.mul(betaTwoPower, betaTwoConst))); - return super.finish(updateOperations, name); - } - - /** {@inheritDoc} */ - @Override - public String getOptimizerName() { - return "Nadam"; - } - - /** - * Sets the config object based on the current state of the Optmizer. - * - * @param learningRate The learning rate. Defaults to 0.001. - * @param betaOne The exponential decay rate for the 1st moment estimates. Defaults to 0.9. - * @param betaTwo The exponential decay rate for the 2nd moment estimates. Defaults to 0.999. - * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the - * Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm - * 1 of the paper. Defaults to 1e-7. - */ - private void initConfig(float learningRate, float betaOne, float betaTwo, float epsilon) { - config.put(NAME_KEY, this.getOptimizerName()); - config.put(LEARNING_RATE_KEY, learningRate); - config.put(EPSILON_KEY, epsilon); - config.put(BETA_ONE_KEY, betaOne); - config.put(BETA_TWO_KEY, betaTwo); - } -} diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java deleted file mode 100644 index 183c71dd976..00000000000 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.keras.optimizers; - -import org.tensorflow.Graph; -import org.tensorflow.op.Ops; - -import java.util.Map; - -/** The main Interface for Keras Optimizers */ -public interface OptimizerInterface { - - /** The value for the name key in the Config object */ - String NAME_KEY = "name"; - - /** - * Get a TensorFlow Graph from the Ops. - * - * @param tf the TensorFlow Ops - * @return the graph - * @throws java.lang.IllegalArgumentException if the TensorFlow Ops does not represent Graph mode - */ - static Graph assertGraph(Ops tf) { - if (!tf.scope().env().isGraph()) { - throw new IllegalArgumentException( - "Invalid environment, Optimizers can only be used in Graph Mode"); - } - return (Graph) tf.scope().env(); - } - - /** - * Return the config object used to initialize the Optimizer - * - * @return the config object used to initialize the Optimizer - */ - Map<String, Object> getConfig(); -} diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java deleted file mode 100644 index aecd8dcf537..00000000000 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java +++ /dev/null @@ -1,125 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.keras.optimizers; - -import org.tensorflow.framework.optimizers.Optimizer; -import org.tensorflow.op.Ops; - -import java.lang.reflect.Constructor; -import java.lang.reflect.InvocationTargetException; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; -import java.util.function.Supplier; -import java.util.logging.Level; -import java.util.logging.Logger; - -/** - * Functions to get an Optimizer based on String name, an Optimizer class, or lambda function. - * - * <p>Example: - * - * <pre> - * Adam instance = Optimizers.get(tf, "adam"); - * Ftrl instance = Optimizers.get(tf, ltf -> new Ftrl(ltf, 0.1f); - * </pre> - */ -public class Optimizers { - - static Map<String, Function<Ops, Optimizer>> map = - new HashMap<String, Function<Ops, Optimizer>>() { - { - put("adadelta", AdaDelta::new); - put("adagrad", AdaGrad::new); - put("adagrad-da", AdaGradDA::new); - put("adam", Adam::new); - put("adamax", Adamax::new); - put("ftrl", Ftrl::new); - put("nadam", Nadam::new); - put("rmsprop", RMSProp::new); - put("sgd", SGD::new); - } - }; - - /** - * Get an Optimizer - * - * @param optimizerFunction either a String that identifies the Optimizer, an Optimizer class, or - * an Optimizer object. - * @return the Optimizer object or null if not found. - */ - public static Optimizer get(Ops tf, Object optimizerFunction) { - return get(tf, optimizerFunction, null); - } - - /** - * Get an Optimizer - * - * @param func a lamda function that returns the Optimizer - * @return the Intializer object - */ - public static Optimizer get(Ops tf, Function<Ops, Optimizer> func) { - return func.apply(tf); - } - - /** - * Get an Optimizer - * - * @param optimizerFunction either a String that identifies the Optimizer, an Optimizer class, or - * * an Optimizer object. - * @param custom_functions a map of Optimizer lambdas that will be queried if the Optimizer is not - * found in the standard keys - * @return the Optimizer object - */ - public static Optimizer get( - Ops tf, Object optimizerFunction, Map<String, Function<Ops, Optimizer>> custom_functions) { - if (optimizerFunction != null) { - if (optimizerFunction instanceof String) { - String s = - optimizerFunction - .toString(); // do this for Java 8 rather than Pattern Matching for instanceof - Function<Ops, Optimizer> function = map.get(s); - if (function == null && custom_functions != null) { - function = custom_functions.get(s); - } - return function != null ? function.apply(tf) : null; - } else if (optimizerFunction instanceof Class) { - // do this for Java 8 rather than Pattern Matching for instanceof - Class<OptimizerInterface> c = (Class<OptimizerInterface>) optimizerFunction; - try { - Constructor<OptimizerInterface> ctor = c.getConstructor(Ops.class); - return (Optimizer) ctor.newInstance(tf); - } catch (NoSuchMethodException - | InstantiationException - | IllegalAccessException - | IllegalArgumentException - | InvocationTargetException ex) { - Logger.getLogger(Optimizers.class.getName()).log(Level.SEVERE, null, ex); - } - } else if (optimizerFunction instanceof Optimizer) { - return (Optimizer) optimizerFunction; - } else if (optimizerFunction instanceof Function) { - return ((Function<Ops, Optimizer>) optimizerFunction).apply(tf); - } else if (optimizerFunction instanceof Supplier) { - return ((Supplier<Optimizer>) optimizerFunction).get(); - } - } else { - return null; - } - - throw new IllegalArgumentException( - "optimizerFunction must be a symbolic name, Optimizer, Function<Ops, Optimizer>, Supplier<Optimizer> or a Class object"); - } -} diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java deleted file mode 100644 index 03fc4c01f71..00000000000 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java +++ /dev/null @@ -1,188 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the ); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.keras.optimizers; - -import org.tensorflow.op.Ops; - -import java.util.HashMap; -import java.util.Map; - -import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; - -/** RMSProp Optimizer that implements the RMSProp algorithm. */ -public class RMSProp extends org.tensorflow.framework.optimizers.RMSProp - implements OptimizerInterface { - - public static final String LEARNING_RATE_KEY = "learning_rate"; - public static final String DECAY_KEY = "decay"; - public static final String MOMENTUM_KEY = "momentum"; - public static final String EPSILON_KEY = "epsilon"; - public static final String CENTERED_KEY = "centered"; - - public static final float LEARNING_RATE_DEFAULT = 0.001F; - public static final float DECAY_DEFAULT = 0.9F; - public static final float MOMENTUM_DEFAULT = 0.0F; - public static final float EPSILON_DEFAULT = 1e-07F; - public static final boolean CENTERED_DEFAULT = false; - - private Map<String, Object> config = new HashMap<>(); - - /** - * Create an RMSProp Optimizer with the following defaults, name="RMSProp", learning_rate=0.001, - * decay=0.9, momentum=0.0, epsilon=1e-07, centered=false - * - * @param tf the TensorFlow Ops - */ - public RMSProp(Ops tf) { - this( - tf, - LEARNING_RATE_DEFAULT, - DECAY_DEFAULT, - MOMENTUM_DEFAULT, - EPSILON_DEFAULT, - CENTERED_DEFAULT); - } - - /** - * Create an RMSProp Optimizer with the following defaults, name="RMSProp", decay=0.9, - * momentum=0.0, epsilon=1e-07, centered=false - * - * @param tf the TensorFlow Ops - * @param learningRate The learning rate. - */ - public RMSProp(Ops tf, float learningRate) { - this(tf, learningRate, DECAY_DEFAULT, MOMENTUM_DEFAULT, EPSILON_DEFAULT, CENTERED_DEFAULT); - } - - /** - * Create an RMSProp Optimizer with the following defaults, decay=0.9, momentum=0.0, - * epsilon=1e-07, centered=false - * - * @param tf the TensorFlow Ops - * @param name prefix for the operations created when applying gradients. Defaults to "RMSProp" - * @param learningRate The learning rate. - */ - public RMSProp(Ops tf, String name, float learningRate) { - this( - tf, name, learningRate, DECAY_DEFAULT, MOMENTUM_DEFAULT, EPSILON_DEFAULT, CENTERED_DEFAULT); - } - - /** - * Create an RMSProp Optimizer - * - * @param tf the TensorFlow Ops - * @param learningRate The learning rate. Defaults to 0.001. - * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. - * @param momentum hyperparameter that accelerates descent in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - * @param epsilon A small constant for numerical stability. - * @param centered If True, gradients are normalized by the estimated variance of the gradient; if - * False, by the uncentered second moment. - */ - public RMSProp( - Ops tf, float learningRate, float decay, float momentum, float epsilon, boolean centered) { - super(assertGraph(tf), learningRate, decay, momentum, epsilon, centered); - initConfig(learningRate, decay, momentum, epsilon, centered); - } - - /** - * Create an RMSProp Optimizer - * - * @param tf the TensorFlow Ops - * @param name prefix for the operations created when applying gradients. Defaults to "RMSProp" - * @param learningRate The learning rate. Defaults to 0.001. - * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. - * @param momentum hyperparameter that accelerates descent in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - * @param epsilon A small constant for numerical stability. - * @param centered If True, gradients are normalized by the estimated variance of the gradient; if - * False, by the uncentered second moment. - */ - public RMSProp( - Ops tf, - String name, - float learningRate, - float decay, - float momentum, - float epsilon, - boolean centered) { - super(assertGraph(tf), name, learningRate, decay, momentum, epsilon, centered); - initConfig(learningRate, decay, momentum, epsilon, centered); - } - - /** - * Create a RMSProp Optimizer using a configuration - * - * @param tf the TensorFlow Ops - * @param config a config object to initialize the Optimizer, the config object has keys for - * "name", "learning_rate", "decay", "momentum", "epsilon" and "centered". If a key is missing - * the default value is used. - * @return the RMSProp optimizer - */ - public static RMSProp fromConfig(Ops tf, Map<String, Object> config) { - return create(tf, config); - } - - /** - * Create a RMSProp Optimizer using a configuration - * - * @param tf the TensorFlow Ops - * @param config a config object to initialize the Optimizer, the config object has keys for - * "name", "learning_rate", "decay", "momentum", "epsilon" and "centered". If a key is missing - * the default value is used. - * @return the RMSProp optimizer - */ - public static RMSProp create(Ops tf, Map<String, Object> config) { - - String name = (String) config.get(NAME_KEY); - float learningRate = (float) config.getOrDefault(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); - float decay = (float) config.getOrDefault(DECAY_KEY, DECAY_DEFAULT); - float momentum = (float) config.getOrDefault(MOMENTUM_KEY, MOMENTUM_DEFAULT); - float epsilon = (float) config.getOrDefault(EPSILON_KEY, EPSILON_DEFAULT); - boolean centered = (boolean) config.getOrDefault(CENTERED_KEY, CENTERED_DEFAULT); - if (name == null) { - return new RMSProp(tf, learningRate, decay, momentum, epsilon, centered); - } else { - return new RMSProp(tf, name, learningRate, decay, momentum, epsilon, centered); - } - } - - /** - * Initialize the configuration based on which constructor is called. - * - * @param learningRate The learning rate. Defaults to 0.001. - * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. - * @param momentum hyperparameter that accelerates descent in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - * @param epsilon A small constant for numerical stability. - * @param centered If True, gradients are normalized by the estimated variance of the gradient; if - * False, by the uncentered second moment. - */ - private void initConfig( - float learningRate, float decay, float momentum, float epsilon, boolean centered) { - config.put(NAME_KEY, this.getOptimizerName()); - config.put(LEARNING_RATE_KEY, learningRate); - config.put(DECAY_KEY, decay); - config.put(MOMENTUM_KEY, momentum); - config.put(EPSILON_KEY, epsilon); - config.put(CENTERED_KEY, centered); - } - - /** {@inheritDoc} */ - @Override - public Map<String, Object> getConfig() { - return config; - } -} diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java deleted file mode 100644 index 5e7155c2ab5..00000000000 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java +++ /dev/null @@ -1,188 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the ); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.keras.optimizers; - -import org.tensorflow.op.Ops; - -import java.util.HashMap; -import java.util.Map; - -import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; - -/** Stochastic Gradient Descent and momentum optimizer. */ -public class SGD extends org.tensorflow.framework.optimizers.Momentum - implements OptimizerInterface { - - public static final String LEARNING_RATE_KEY = "learning_rate"; - public static final String MOMENTUM_KEY = "momentum"; - public static final String NESTEROV_KEY = "nesterov"; - - public static final float LEARNING_RATE_DEFAULT = 0.01F; - public static final float MOMENTUM_DEFAULT = 0.0F; - public static final boolean NESTEROV_DEFAULT = false; - - private Map<String, Object> config = new HashMap<>(); - - /** - * Create a Stochastic Gradient Descent optimizer using defaults: name="SGD", learning_rate=0.01, - * momentum=0.0, and nesterov=false - * - * @param tf the TensorFlow Ops - */ - public SGD(Ops tf) { - this(tf, LEARNING_RATE_DEFAULT, MOMENTUM_DEFAULT, NESTEROV_DEFAULT); - } - - /** - * Create a Stochastic gradient descent optimizer using defaults: name="SGD", momentum=0.0, and - * nesterov=false - * - * @param tf the TensorFlow Ops - * @param learningRate The learning rate. Defaults to 0.01. - */ - public SGD(Ops tf, float learningRate) { - this(tf, learningRate, MOMENTUM_DEFAULT, NESTEROV_DEFAULT); - } - - /** - * Create a Stochastic gradient descent optimizer using defaults: name="SGD", and nesterov=false - * - * @param tf the TensorFlow Ops - * @param learningRate The learning rate. Defaults to 0.01. - * @param momentum hyperparameter that accelerates SGD in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - */ - public SGD(Ops tf, float learningRate, float momentum) { - this(tf, learningRate, momentum, NESTEROV_DEFAULT); - } - - /** - * Create a Stochastic gradient descent optimizer using defaults: momentum=0.0, and nesterov=false - * - * @param tf the TensorFlow Ops - * @param name prefix for the operations created when applying gradients - * @param learningRate The learning rate. Defaults to 0.01. - */ - public SGD(Ops tf, String name, float learningRate) { - this(tf, name, learningRate, MOMENTUM_DEFAULT, NESTEROV_DEFAULT); - } - - /** - * create a Stochastic gradient descent optimizer using defaults: momentum=0.0, and nesterov=false - * - * @param tf the TensorFlow Ops - * @param name prefix for the operations created when applying gradients - * @param learningRate The learning rate. Defaults to 0.01. - * @param momentum hyperparameter that accelerates SGD in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - */ - public SGD(Ops tf, String name, float learningRate, float momentum) { - this(tf, name, learningRate, momentum, NESTEROV_DEFAULT); - } - - /** - * Create a Stochastic gradient descent optimizer - * - * @param tf the TensorFlow Ops - * @param learningRate The learning rate. Defaults to 0.01. - * @param momentum hyperparameter that accelerates SGD in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - * @param useNesterov Whether to apply Nesterov momentum. Defaults to `false`. - */ - public SGD(Ops tf, float learningRate, float momentum, boolean useNesterov) { - super(assertGraph(tf), learningRate, momentum, useNesterov); - if (momentum < 0 || momentum > 1) - throw new IllegalArgumentException("\"momentum\" must be between [0, 1]."); - initConfig(learningRate, momentum, useNesterov); - } - - /** - * Create a Stochastic gradient descent optimizer - * - * @param tf the TensorFlow Ops - * @param name prefix for the operations created when applying gradients - * @param learningRate The learning rate. Defaults to 0.01. - * @param momentum hyperparameter that accelerates SGD in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - * @param useNesterov Whether to apply Nesterov momentum. Defaults to `false`. - */ - public SGD(Ops tf, String name, float learningRate, float momentum, boolean useNesterov) { - super(assertGraph(tf), name, learningRate, momentum, useNesterov); - if (momentum < 0 || momentum > 1) - throw new IllegalArgumentException("\"momentum\" must be between [0, 1]."); - initConfig(learningRate, momentum, useNesterov); - } - - /** - * Create a Stochastic gradient descent optimizer - * - * @param tf the TensorFlow Ops - * @param config a config object to initialize, the config object has keys for "name", - * "learning_rate", "momentum", and "nesterov". If a key is missing the default value is used. - * @return the Stochastic gradient descent optimizer - */ - public static SGD fromConfig(Ops tf, Map<String, Object> config) { - return create(tf, config); - } - - /** - * Create a Stochastic gradient descent optimizer - * - * @param tf the TensorFlow Ops - * @param config a config object to initialize, the config object has keys for "name", - * "learning_rate", "momentum", and "nesterov". If a key is missing the default value is used. - * @return the Stochastic gradient descent optimizer - */ - public static SGD create(Ops tf, Map<String, Object> config) { - - String name = (String) config.get(NAME_KEY); - float learningRate = (float) config.getOrDefault(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); - float momentum = (float) config.getOrDefault(MOMENTUM_KEY, MOMENTUM_DEFAULT); - boolean nesterov = (boolean) config.getOrDefault(NESTEROV_KEY, NESTEROV_DEFAULT); - if (name == null) { - return new SGD(tf, learningRate, momentum, nesterov); - } else { - return new SGD(tf, name, learningRate, momentum, nesterov); - } - } - - /** - * Initialize the configuration ased on which constructor is called. - * - * @param learningRate learningRate The learning rate. Defaults to 0.01. - * @param momentum hyperparameter that accelerates SGD in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - * @param useNesterov Whether to apply Nesterov momentum. Defaults to `false`. - */ - private void initConfig(float learningRate, float momentum, boolean useNesterov) { - config.put(NAME_KEY, this.getOptimizerName()); - config.put(LEARNING_RATE_KEY, learningRate); - config.put(MOMENTUM_KEY, momentum); - config.put(NESTEROV_KEY, useNesterov); - } - - /** { @inheritDoc } */ - @Override - public Map<String, Object> getConfig() { - return config; - } - - // overide the momentum name to return "SGD" - /** {@inheritDoc} */ - @Override - public String getOptimizerName() { - return "SGD"; - } -} diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java deleted file mode 100644 index 7651872643b..00000000000 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java +++ /dev/null @@ -1,444 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.keras.optimizers; - -import org.junit.jupiter.api.*; -import org.tensorflow.framework.optimizers.Optimizer; -import org.tensorflow.keras.utils.ND; -import org.tensorflow.keras.utils.TestSession; -import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Op; -import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; -import org.tensorflow.op.core.Constant; -import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat32; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.framework.optimizers.RMSProp.*; -import static org.tensorflow.keras.optimizers.Ftrl.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; -import static org.tensorflow.keras.optimizers.RMSProp.*; - -/** Test cases for RMSProp Optimizer */ -public class RMSPropTest { - - private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; - - final int VAR_T = 0; - final int MG_T = 1; - final int RMS_T = 2; - final int MOM_T = 3; - - int index; - - public RMSPropTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - /** Test of create method, of class RMSProp. */ - @Test - public void testCreate() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - Map<String, Object> config = new HashMap<>(); - config.put(NAME_KEY, "Ftrl"); - config.put(LEARNING_RATE_KEY, 2.0F); - config.put(DECAY_KEY, DECAY_DEFAULT); - config.put(MOMENTUM_KEY, MOMENTUM_DEFAULT); - config.put(EPSILON_KEY, EPSILON_DEFAULT); - config.put(CENTERED_KEY, CENTERED_DEFAULT); - Ftrl expResult = new Ftrl(tf, 2.0F); - Ftrl result = Ftrl.create(tf, config); - assertEquals(expResult.getConfig(), result.getConfig()); - } - } - - Object[][] _test_param_values = { - // learning_rate, rho (decay), momentum, epsilon, centered - {0.05F, 0.9F, 0.0F, 1e-3F, true}, - {0.05F, 0.9F, 0.0F, 1e-3F, false}, - {0.1F, 0.9F, 0.0F, 1e-3F, true}, - {0.01F, 0.9F, 0.0F, 1e-5F, true}, - {0.01F, 0.9F, 0.9F, 1e-5F, true} - }; - - @Test - public void testDense() { - - int numSteps = 3; - - for (int run = 0; run < _test_param_values.length; run++) { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - session.setEpsilon(1e-2f); - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.2F}; - float[] grads1_init = {0.01F, 0.2F}; - final float epsilon1 = 1e-2F; - - FloatNdArray var0_np = NdArrays.vectorOf(var0_init); - FloatNdArray var1_np = NdArrays.vectorOf(var1_init); - FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); - FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); - - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); - Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - - Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); - - Constant<TFloat32> grads0 = tf.constant(grads0_init); - Constant<TFloat32> grads1 = tf.constant(grads1_init); - - // learning_rate, rho (decay), momentum, epsilon, centered - float learningRate = (float) (float) _test_param_values[run][0]; - float decay = (float) _test_param_values[run][1]; - float momentum = (float) _test_param_values[run][2]; - float epsilon = (float) _test_param_values[run][3]; - boolean centered = (boolean) _test_param_values[run][4]; - - RMSProp instance = new RMSProp(tf, learningRate, decay, momentum, epsilon, centered); - - /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); - gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); - gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - - Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); - - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - - /** initialize the accumulators */ - session.run(tf.init()); - - /** make sure the variables were initialized properly */ - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); - - Variable<TFloat32> mg0 = centered ? instance.getSlot(var0.asOutput(), MG).get() : null; - Variable<TFloat32> mg1 = centered ? instance.getSlot(var1.asOutput(), MG).get() : null; - Variable<TFloat32> mom0 = - momentum > 0.F ? instance.getSlot(var0.asOutput(), MOMENTUM).get() : null; - Variable<TFloat32> mom1 = - momentum > 0.F ? instance.getSlot(var1.asOutput(), MOMENTUM).get() : null; - Variable<TFloat32> rms0 = instance.getSlot(var0.asOutput(), RMS).get(); - Variable<TFloat32> rms1 = instance.getSlot(var1.asOutput(), RMS).get(); - - float[] zeros = {0.0F, 0.0F}; - float[] ones = {1.0F, 1.0F}; // temp to match RMSProp - FloatNdArray mg0_np = NdArrays.vectorOf(zeros); - FloatNdArray mg1_np = NdArrays.vectorOf(zeros); - FloatNdArray rms0_np = NdArrays.vectorOf(ones); - FloatNdArray rms1_np = NdArrays.vectorOf(ones); - FloatNdArray mom0_np = NdArrays.vectorOf(zeros); - FloatNdArray mom1_np = NdArrays.vectorOf(zeros); - - for (int i = 0; i < numSteps; i++) { - session.run(update, instance.getFeedDict()); - FloatNdArray[] result0 = - calc( - var0_np, - grads0_np, - mg0_np, - rms0_np, - mom0_np, - learningRate, - decay, - momentum, - epsilon, - centered); - var0_np = result0[VAR_T]; - mg0_np = result0[MG_T]; - rms0_np = result0[RMS_T]; - mom0_np = result0[MOM_T]; - - FloatNdArray[] result1 = - calc( - var1_np, - grads1_np, - mg1_np, - rms1_np, - mom1_np, - learningRate, - decay, - momentum, - epsilon, - centered); - - var1_np = result1[VAR_T]; - mg1_np = result1[MG_T]; - rms1_np = result1[RMS_T]; - mom1_np = result1[MOM_T]; - - if (centered) { - session.evaluate(mg0_np, mg0); - session.evaluate(mg0_np, mg0); - } - if (momentum > 0.F) { - session.evaluate(mom0_np, mom0); - session.evaluate(mom1_np, mom1); - } - - /* TODO the values returned from rms slot, do not match what I see in the python test */ - session.evaluate(rms0_np, rms0); - session.evaluate(rms1_np, rms1); - - session.evaluate(var0_np, var0); - session.evaluate(var1_np, var1); - } - } - } - } - - @Test - public void testWithLearningRateDecay() { - int numSteps = 3; - - for (int run = 0; run < _test_param_values.length; run++) { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - session.setEpsilon(1e-2f); - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.2F}; - float[] grads1_init = {0.01F, 0.2F}; - final float epsilon1 = 1e-2F; - - FloatNdArray var0_np = NdArrays.vectorOf(var0_init); - FloatNdArray var1_np = NdArrays.vectorOf(var1_init); - FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); - FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); - - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); - Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - - Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); - - Constant<TFloat32> grads0 = tf.constant(grads0_init); - Constant<TFloat32> grads1 = tf.constant(grads1_init); - - // learning_rate, rho (decay), momentum, epsilon, centered - float learningRate = (float) (float) _test_param_values[run][0]; - float decay = (float) _test_param_values[run][1]; - float momentum = (float) _test_param_values[run][2]; - float epsilon = (float) _test_param_values[run][3]; - boolean centered = (boolean) _test_param_values[run][4]; - - RMSProp instance = new RMSProp(tf, learningRate, decay, momentum, epsilon, centered); - - /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); - gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); - gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - - Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); - - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - - /** initialize the accumulators */ - session.run(tf.init()); - - /** make sure the variables were initialized properly */ - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); - - Variable<TFloat32> mg0 = centered ? instance.getSlot(var0.asOutput(), MG).get() : null; - Variable<TFloat32> mg1 = centered ? instance.getSlot(var1.asOutput(), MG).get() : null; - Variable<TFloat32> mom0 = - momentum > 0.F ? instance.getSlot(var0.asOutput(), MOMENTUM).get() : null; - Variable<TFloat32> mom1 = - momentum > 0.F ? instance.getSlot(var1.asOutput(), MOMENTUM).get() : null; - Variable<TFloat32> rms0 = instance.getSlot(var0.asOutput(), RMS).get(); - Variable<TFloat32> rms1 = instance.getSlot(var1.asOutput(), RMS).get(); - - float[] zeros = {0.0F, 0.0F}; - float[] ones = {1.0F, 1.0F}; // temp to match RMSProp - FloatNdArray mg0_np = NdArrays.vectorOf(zeros); - FloatNdArray mg1_np = NdArrays.vectorOf(zeros); - FloatNdArray rms0_np = NdArrays.vectorOf(ones); - FloatNdArray rms1_np = NdArrays.vectorOf(ones); - FloatNdArray mom0_np = NdArrays.vectorOf(zeros); - FloatNdArray mom1_np = NdArrays.vectorOf(zeros); - - for (int i = 0; i < numSteps; i++) { - session.run(update, instance.getFeedDict()); - FloatNdArray[] result0 = - calc( - var0_np, - grads0_np, - mg0_np, - rms0_np, - mom0_np, - learningRate, - decay, - momentum, - epsilon, - centered); - var0_np = result0[VAR_T]; - mg0_np = result0[MG_T]; - rms0_np = result0[RMS_T]; - mom0_np = result0[MOM_T]; - - FloatNdArray[] result1 = - calc( - var1_np, - grads1_np, - mg1_np, - rms1_np, - mom1_np, - learningRate, - decay, - momentum, - epsilon, - centered); - - var1_np = result1[VAR_T]; - mg1_np = result1[MG_T]; - rms1_np = result1[RMS_T]; - mom1_np = result1[MOM_T]; - - if (centered) { - session.evaluate(mg0_np, mg0); - session.evaluate(mg0_np, mg0); - } - if (momentum > 0.F) { - session.evaluate(mom0_np, mom0); - session.evaluate(mom1_np, mom1); - } - - /* TODO the values returned from rms slot, do not match what I see in the python test */ - session.evaluate(rms0_np, rms0); - session.evaluate(rms1_np, rms1); - - session.evaluate(var0_np, var0); - session.evaluate(var1_np, var1); - - learningRate *= 0.9F; - instance.setLearningRate(learningRate); - } - } - } - } - - FloatNdArray[] calc( - FloatNdArray var_np, - FloatNdArray grad_np, - FloatNdArray mg_np, - FloatNdArray rms_np, - FloatNdArray mom, - float lr, - float decay, - float momentum, - float epsilon, - boolean centered) { - - FloatNdArray[] result = new FloatNdArray[4]; // var_t, mg_t, rms_t, mom_t - result[RMS_T] = calcRMS(rms_np, grad_np, decay); // RMS - - FloatNdArray denom_t; - if (centered) { - result[MG_T] = calcMG(mg_np, grad_np, decay); - // rms_t - mg_t * mg_t - denom_t = ND.sub(result[RMS_T], ND.square(result[MG_T])); - } else { - result[MG_T] = mg_np; - denom_t = rms_np; - } - if (momentum > 0.F) { - // momentum * mom + lr * g / (np.sqrt(denom_t + epsilon)) - result[MOM_T] = calcMom(momentum, mom, lr, grad_np, denom_t, epsilon); - // var_t = var - mom_t - result[VAR_T] = ND.sub(var_np, result[MOM_T]); - } else { - result[MOM_T] = mom; - result[VAR_T] = calcVar(var_np, grad_np, lr, denom_t, epsilon); - } - - return result; - } - - private FloatNdArray calcRMS(FloatNdArray rms_np, FloatNdArray grad_np, float decay) { - // rms * rho + (1 - rho) * g * g - FloatNdArray rms_rho = ND.mul(rms_np, decay); - FloatNdArray squareG = ND.square(grad_np); - float oneRHO = 1.0F - decay; - FloatNdArray decayG2 = ND.mul(oneRHO, squareG); - FloatNdArray result = ND.add(rms_rho, decayG2); - return result; - } - - private FloatNdArray calcMG(FloatNdArray mg_np, FloatNdArray grad_np, float decay) { - // mg_t = mg * rho + (1 - rho) * g - FloatNdArray mg_rho = ND.mul(mg_np, decay); - float oneRHO = 1.0F - decay; - FloatNdArray decayG = ND.mul(oneRHO, grad_np); - FloatNdArray result = ND.add(mg_rho, decayG); - return result; - } - - private FloatNdArray calcMom( - float momentum, - FloatNdArray mom, - float lr, - FloatNdArray grad_np, - FloatNdArray denom_t, - float epsilon) { - // momentum * mom + lr * g / (np.sqrt(denom_t + epsilon)) - FloatNdArray moMo = ND.mul(momentum, mom); - FloatNdArray dividend = ND.mul(lr, grad_np); - FloatNdArray divisor = ND.sqrt(ND.add(denom_t, epsilon)); - FloatNdArray quotient = ND.div(dividend, divisor); - FloatNdArray result = ND.add(moMo, quotient); - return result; - } - - private FloatNdArray calcVar( - FloatNdArray var_np, FloatNdArray grad_np, float lr, FloatNdArray denom_t, float epsilon) { - // var - lr * g / (np.sqrt(denom_t) + epsilon) - FloatNdArray dividend = ND.mul(lr, grad_np); - FloatNdArray divisor = ND.add(ND.sqrt(denom_t), epsilon); - FloatNdArray quotient = ND.div(dividend, divisor); - FloatNdArray result = ND.sub(var_np, quotient); - return result; - } -} From dddc2975922a8c3322b04a4cc991b06ab5d254a8 Mon Sep 17 00:00:00 2001 From: Jim Clarke <JimClarke5@me.com> Date: Mon, 14 Sep 2020 19:52:43 -0400 Subject: [PATCH 06/14] Reformatted code --- .../java/org/tensorflow/framework/optimizers/AdaDelta.java | 4 ---- .../java/org/tensorflow/framework/optimizers/AdaGrad.java | 5 +---- .../main/java/org/tensorflow/framework/optimizers/Ftrl.java | 4 +--- 3 files changed, 2 insertions(+), 11 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index 9453dce7343..be5b39a534d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -68,14 +68,10 @@ public class AdaDelta extends Optimizer { public static final float RHO_DEFAULT = 0.95f; public static final float EPSILON_DEFAULT = 1e-7f; - - private final float rho; private final float epsilon; - - public AdaDelta(Graph graph) { this(graph, LEARNING_RATE_DEFAULT, RHO_DEFAULT, EPSILON_DEFAULT); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 3aae6f71693..384c04e60bb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -139,10 +139,7 @@ private <T extends TType> void createAdaGradSlot(Output<T> v) { protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) { Variable<T> slot = getSlot(variable, ACCUMULATOR).get(); return tf.train.applyAdagrad( - variable, - slot, - tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), - gradient); + variable, slot, tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), gradient); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java index 8e7b638dc21..b455ae1f0be 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java @@ -176,7 +176,6 @@ public Ftrl( this.l2RegularizationStrength = l2Strength; this.l2ShrinkageRegularizationStrength = l2ShrinkageRegularizationStrength; validateParams(); - } /** Validates all the settings of the Frtl Optmizer */ @@ -248,8 +247,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable tf.dtypes.cast(this.getLearningRateOperand(), gradient.dataType()), tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.dataType()), tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.dataType()), - tf.dtypes.cast( - tf.constant(l2ShrinkageRegularizationStrength), gradient.dataType()), + tf.dtypes.cast(tf.constant(l2ShrinkageRegularizationStrength), gradient.dataType()), tf.dtypes.cast(tf.constant(learningRatePower), gradient.dataType()), ApplyFtrl.useLocking(true)); } From cb8104ce7f26456bb3acb3c40b7f4d67fd5d8b90 Mon Sep 17 00:00:00 2001 From: Jim Clarke <JimClarke5@me.com> Date: Mon, 14 Sep 2020 19:54:37 -0400 Subject: [PATCH 07/14] Reformatted code --- .../framework/optimizers/AdaDeltaTest.java | 3 +- .../framework/optimizers/AdaGradDATest.java | 14 ++-- .../framework/optimizers/AdamTest.java | 16 +--- .../framework/optimizers/AdamaxTest.java | 3 +- .../optimizers/GradientDescentTest.java | 5 +- .../framework/optimizers/MomentumTest.java | 3 +- .../framework/optimizers/NadamTest.java | 79 +++++++++---------- .../framework/optimizers/RMSPropTest.java | 3 +- 8 files changed, 58 insertions(+), 68 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java index 37c7bc5ded0..3547ea9a30e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java @@ -235,7 +235,8 @@ public void testWithLearningRateDecay() { float totUpdate = 0; for (int step = 0; step < numSteps; step++) { assertEquals(learningRate, instance.getLearningRate(), epsilon); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(adadeltaUpdate, instance.getFeedMap()); accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho); updates[step] = diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java index 8e44f7db0ed..1f8044c1168 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; -import static org.junit.jupiter.api.Assertions.assertEquals; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -29,6 +28,8 @@ import java.util.ArrayList; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; + /** Test cases for AdaGradDA Optimizer */ public class AdaGradDATest { @@ -140,14 +141,15 @@ public void testWithLearningRateDecay() { session.evaluate(var1Init, var1); float[][][] expected = { - {{ -2.121320f, -2.683281f},{ -0.298511f, -0.588348f}}, - {{ -3.680166f, -4.483282f}, { -0.565851f, -1.107964f}}, - {{ -4.895166f, -5.831203f}, { -0.805286f, -1.567190f}}, - {{ -5.873222f, -6.892054f}, { -1.019739f, -1.973306f}} + {{-2.121320f, -2.683281f}, {-0.298511f, -0.588348f}}, + {{-3.680166f, -4.483282f}, {-0.565851f, -1.107964f}}, + {{-4.895166f, -5.831203f}, {-0.805286f, -1.567190f}}, + {{-5.873222f, -6.892054f}, {-1.019739f, -1.973306f}} }; for (int i = 0; i < numSteps; i++) { assertEquals(learningRate, instance.getLearningRate(), epsilon); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(update, instance.getFeedMap()); session.evaluate(expected[i][0], var0); session.evaluate(expected[i][1], var1); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java index c6681fa1557..a8be65c3650 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java @@ -274,12 +274,7 @@ public void testWithLearningRateDecay() { .run() .get(0) .expect(TFloat32.DTYPE)) { - result - .data() - .scalars() - .forEach( - f -> assertEquals(powers[0], f.getFloat(), epsilon1) - ); + result.data().scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); } try (Tensor<TFloat32> result = session @@ -289,14 +284,11 @@ public void testWithLearningRateDecay() { .run() .get(0) .expect(TFloat32.DTYPE)) { - result - .data() - .scalars() - .forEach( - f -> assertEquals(powers[1], f.getFloat(), epsilon1)); + result.data().scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); } assertEquals(learningRate, instance.getLearningRate(), 1e-6f); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(update, instance.getFeedMap()); float lr_t = diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java index 1d9648b9b51..57d3cbdb70c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java @@ -269,7 +269,8 @@ public void testWithLearningRateDecay() { }); } assertEquals(learningRate, instance.getLearningRate(), epsilon); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(update, instance.getFeedMap()); FloatNdArray[] resultNP = calculate(var0Np, grads0Np, step, m0, v0, learningRate); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index ec59a046421..8e793e35d5f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -16,7 +16,6 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.framework.optimizers.Momentum.MOMENTUM; /** Test cases for GradientDescent Optimizer */ public class GradientDescentTest { @@ -129,7 +128,6 @@ public void testWithLearningRateDecay() { Op update = instance.applyGradients(gradsAndVars, "GradientDescentTest"); - /* initialize the local variables */ session.run(var0Initializer); session.run(var1Initializer); @@ -157,7 +155,8 @@ public void testWithLearningRateDecay() { }; for (int step = 0; step < numSteps; step++) { assertEquals(learningRate, instance.getLearningRate(), 1e-6f); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(update, instance.getFeedMap()); session.evaluate(expectedVar0[step], var0); session.evaluate(expectedVar1[step], var1); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java index ce5ad379629..b54e3b52a26 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java @@ -256,7 +256,8 @@ public void testWithLearningRateDecay() { }; for (int step = 0; step < numSteps; step++) { assertEquals(learningRate, instance.getLearningRate(), 1e-6); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(update, instance.getFeedMap()); session.evaluate(expectedVar0[step], var0); session.evaluate(expectedVar1[step], var1); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java index fcdd1e3ef7c..c7c17689a33 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java @@ -243,7 +243,6 @@ public void testWithLearningRateDecay() { Constant<TFloat32> grads0 = tf.constant(grads0Init); Constant<TFloat32> grads1 = tf.constant(grads1Init); - /* build the GradsAnvVars */ List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); @@ -282,46 +281,40 @@ public void testWithLearningRateDecay() { session.evaluate(var1Init, var1); try (Tensor<TFloat32> result = - session - .getGraphSession() - .runner() - .fetch("momentum") - .run() - .get(0) - .expect(TFloat32.DTYPE)) { - result - .data() - .scalars() - .forEach( - f -> assertEquals(1F, f.getFloat(), epsilon1)); + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); } momentum = 1F; for (int step = 0; step < numSteps; step++) { assertEquals(learningRate, instance.getLearningRate(), 1e-6f); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(update, instance.getFeedMap()); float mut = - Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); + Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); momentum = momentum * mut; try (Tensor<TFloat32> result = - session - .getGraphSession() - .runner() - .fetch("momentum") - .run() - .get(0) - .expect(TFloat32.DTYPE)) { - result - .data() - .scalars() - .forEach( - f -> assertEquals(momentum, f.getFloat(), epsilon1)); + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); } mcache = ND.mul(mcache, momentum); - FloatNdArray[] resultsNP = nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache, learningRate); + FloatNdArray[] resultsNP = + nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache, learningRate); var0Np = resultsNP[VAR]; m0 = resultsNP[M]; v0 = resultsNP[V]; @@ -349,24 +342,24 @@ public void testWithLearningRateDecay() { } } - private FloatNdArray[] nadamUpdateNdArray( - FloatNdArray varNp, - FloatNdArray gradsNp, - int t, - FloatNdArray m, - FloatNdArray v, - FloatNdArray mCache) { + FloatNdArray varNp, + FloatNdArray gradsNp, + int t, + FloatNdArray m, + FloatNdArray v, + FloatNdArray mCache) { return nadamUpdateNdArray(varNp, gradsNp, t, m, v, mCache, 0.001F); } + private FloatNdArray[] nadamUpdateNdArray( - FloatNdArray varNp, - FloatNdArray gradsNp, - int t, - FloatNdArray m, - FloatNdArray v, - FloatNdArray mCache, - float alpha) { + FloatNdArray varNp, + FloatNdArray gradsNp, + int t, + FloatNdArray m, + FloatNdArray v, + FloatNdArray mCache, + float alpha) { float beta1 = 0.9F; float beta2 = 0.999F; @@ -382,7 +375,7 @@ private FloatNdArray[] nadamUpdateNdArray( FloatNdArray vPrimeT = ND.div(vT, 1.F - (float) Math.pow(beta2, t + 1)); FloatNdArray mBarT = ND.add(ND.mul((1 - muT), gPrimeT), ND.mul(muT1, mPrimeT)); FloatNdArray paramT = - ND.sub(varNp, ND.div(ND.mul(alpha, mBarT), ND.add(ND.sqrt(vPrimeT), epsilon))); + ND.sub(varNp, ND.div(ND.mul(alpha, mBarT), ND.add(ND.sqrt(vPrimeT), epsilon))); FloatNdArray[] results = new FloatNdArray[3]; results[VAR] = paramT; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java index 6d489951c77..2a012ff0f99 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java @@ -306,7 +306,8 @@ public void testWithLearningRateDecay() { for (int i = 0; i < numSteps; i++) { assertEquals(learningRate, instance.getLearningRate(), epsilon); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(update, instance.getFeedMap()); FloatNdArray[] result0 = calc( From 15189b4c90e18eb859c6e708615bc741594eccad Mon Sep 17 00:00:00 2001 From: Jim Clarke <JimClarke5@me.com> Date: Mon, 14 Sep 2020 19:55:07 -0400 Subject: [PATCH 08/14] Reformatted code --- .../framework/utils/TestSession.java | 170 +++++++++--------- 1 file changed, 83 insertions(+), 87 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java index 47c39e820fc..713225a4962 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java @@ -65,6 +65,7 @@ public void initialize() { /** * Returns the Graph if in Graph mode, or null if in EagerMode + * * @return the Graph if in Graph mode, or null if in EagerMode */ public Graph getGraph() { @@ -78,19 +79,18 @@ public Graph getGraph() { * * @param op The Operation to run */ - public void run(Op op) { - run(op, null); + public void run(Op op) { + run(op, null); } - /** * Perform session.run() * * <p>If in eager mode, this does nothing. * * @param op The Operation to run - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public abstract void run(Op op, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); @@ -110,8 +110,8 @@ public <U extends TNumber> void evaluate(Number expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( @@ -136,8 +136,8 @@ public void evaluate(Number expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( Number expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { @@ -161,14 +161,12 @@ public <U extends TNumber> void evaluate(Number[] expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type for the input */ public <U extends TNumber> void evaluate( - Number[] expected, - Op input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + Number[] expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { Output<U> output = input.op().output(0); evaluate(expected, output, feedMap); } @@ -190,8 +188,8 @@ public <U extends TNumber> void evaluate(Number[] expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( @@ -218,8 +216,8 @@ public <U extends TNumber> void evaluate(byte expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( @@ -245,8 +243,8 @@ public <U extends TNumber> void evaluate(int expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( @@ -272,8 +270,8 @@ public <U extends TNumber> void evaluate(long expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( @@ -299,8 +297,8 @@ public <U extends TNumber> void evaluate(float expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( @@ -326,8 +324,8 @@ public <U extends TNumber> void evaluate(double expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public abstract <U extends TNumber> void evaluate( @@ -351,8 +349,8 @@ public <U extends TNumber> void evaluate(byte[] expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( @@ -382,8 +380,8 @@ public <U extends TNumber> void evaluate(int[] expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( @@ -413,8 +411,8 @@ public <U extends TNumber> void evaluate(long[] expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( @@ -444,8 +442,8 @@ public <U extends TNumber> void evaluate(float[] expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( @@ -475,8 +473,8 @@ public <U extends TNumber> void evaluate(double[] expected, Operand<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( @@ -506,8 +504,8 @@ public <U extends TNumber> void evaluate(Number[] expected, Output<U> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public abstract <U extends TNumber> void evaluate( @@ -530,8 +528,8 @@ public void evaluate(String expected, Operand<TString> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( String expected, @@ -555,8 +553,8 @@ public void evaluate(String expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( String expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { @@ -578,13 +576,11 @@ public void evaluate(String[] expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( - String[] expected, - Op input, - Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + String[] expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { Output<TString> output = input.op().output(0); evaluate(expected, output, feedMap); } @@ -605,8 +601,8 @@ public void evaluate(String[] expected, Operand<TString> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public abstract void evaluate( String[] expected, @@ -628,8 +624,8 @@ public void evaluate(Boolean expected, Operand<TBool> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( Boolean expected, @@ -653,8 +649,8 @@ public void evaluate(Boolean expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( Boolean expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { @@ -677,8 +673,8 @@ public void evaluate(Boolean[] expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( Boolean[] expected, @@ -704,8 +700,8 @@ public void evaluate(Boolean[] expected, Operand<TBool> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( Boolean[] expected, @@ -730,8 +726,8 @@ public void evaluate(Boolean[] expected, Output<TBool> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public abstract void evaluate( Boolean[] expected, @@ -765,8 +761,8 @@ public <T extends TType> void evaluate(Operand<T> expected, Operand<T> input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <T> the data type for the feedMap entries */ public abstract <T extends TType> void evaluate( @@ -790,8 +786,8 @@ public <U extends TNumber> void evaluate(FloatNdArray expected, Operand<U> input * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public <U extends TNumber> void evaluate( @@ -817,8 +813,8 @@ public <U extends TNumber> void evaluate(FloatNdArray expected, Output<U> input) * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public abstract <U extends TNumber> void evaluate( @@ -844,8 +840,8 @@ public <U extends TNumber> void evaluate(Operand<U> input, Predicate<Number> pre * @param input the actual value * @param predicate a predicate that accepts a Number as an argument, if the result of the * predicate is false, then the test will fail - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <U> the data type of the input */ public abstract <U extends TNumber> void evaluate( @@ -878,8 +874,8 @@ public <T extends TType> void print(Operand<T> input) { * Print the results to the "standard" output stream. * * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <T> the data type for the feedMap entries */ public <T extends TType> void print( @@ -900,8 +896,8 @@ public void print(Op input) { * Print the results to the "standard" output stream. * * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void print(Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { print(new PrintWriter(new OutputStreamWriter(System.out)), input.op().output(0), feedMap); @@ -921,8 +917,8 @@ public <T extends TType> void print(Output<T> input) { * Print the results to the "standard" output stream. * * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <T> the data type for the input */ public <T extends TType> void print( @@ -946,8 +942,8 @@ public <T extends TType> void print(OutputStream out, Operand<T> input) { * * @param out the output stream * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <T> the data type for the feedMap entries */ public <T extends TType> void print( @@ -972,8 +968,8 @@ public void print(OutputStream out, Op input) { * * @param out the output stream * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void print( OutputStream out, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { @@ -996,8 +992,8 @@ public <T extends TType> void print(OutputStream out, Output<T> input) { * * @param out the output stream * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <T> the data type for the input */ public <T extends TType> void print( @@ -1023,8 +1019,8 @@ public <T extends TType> void print(Writer writer, Operand<T> input) { * * @param writer the character stream * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <T> the data type for the input */ public <T extends TType> void print( @@ -1049,8 +1045,8 @@ public void print(Writer writer, Op input) { * * @param writer the character stream * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void print( Writer writer, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { @@ -1073,8 +1069,8 @@ public <T extends TType> void print(Writer writer, Output<T> input) { * * @param writer the character stream * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <T> the data type for the input */ public <T extends TType> void print( @@ -1100,8 +1096,8 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) { * * @param writer the character stream * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param <T> the data type for the input */ public abstract <T extends TType> void print( From eb2c48e0695b03fc8a71b880f5029c0ec8abb44c Mon Sep 17 00:00:00 2001 From: Jim Clarke <JimClarke5@me.com> Date: Tue, 15 Sep 2020 08:13:23 -0400 Subject: [PATCH 09/14] Remove premature commit --- .../schedules/PiecewiseConstantDecay.java | 58 -------- .../optimizers/schedules/PolynomialDecay.java | 127 ------------------ .../schedules/PiecewiseConstantDecayTest.java | 16 --- .../schedules/PolynomialDecayTest.java | 24 ---- 4 files changed, 225 deletions(-) delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java delete mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java delete mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java deleted file mode 100644 index 43f85fa0ff1..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.optimizers.schedules;; - -/** - * A LearningRateSchedule that uses a piecewise constant decay schedule. - * <p> - * <p>The function computes the piecewise constant - when passed the current optimizer step. This can be useful for changing the - learning rate value across different invocations of optimizer functions. - * <p> - * <p>Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5 - for the next 10000 steps, and 0.1 for any additional steps. - */ -public class PiecewiseConstantDecay implements LearningRateSchedule { - private float[] boundaries; - private float[] values; - - private int lastIndex = 0; - - /** - * Create an PiecewiseConstantDecay - * - * @param boundaries An array of with strictly increasing entries - * @param values An array that specifies the - values for the intervals defined by <code>boundaries</code>. It should have one - more element than <code>boundaries</code>. - * @throws java.lang.IllegalArgumentException if the the length of values does not have 1 more element than boundaries. - */ - public PiecewiseConstantDecay(float[] boundaries, float[] values) { - if(boundaries.length != values.length - 1) { - throw new IllegalArgumentException("The length of boundaries should be 1 less than the length of values"); - } - this.boundaries = boundaries; - this.values = values; - } - - - @Override - public float call(int step) { - if(lastIndex < boundaries.length && step > boundaries[lastIndex]) - lastIndex++; - return values[lastIndex]; - } - -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java deleted file mode 100644 index 0988577c38f..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.optimizers.schedules; - -/** - * A LearningRateSchedule that uses a polynomial decay schedule. - * - * <p> - * - * <p>It is commonly observed that a monotonically decreasing learning rate, whose degree of change - * is carefully chosen, results in a better performing model. This schedule applies a polynomial - * decay function to an optimizer step, given a provided `initial_learning_rate`, to reach an - * `end_learning_rate` in the given `decay_steps`. - * - * <p> - * - * <p>The schedule is a 1-arg callable that produces a decayed learning rate when passed the current - * optimizer step. This can be useful for changing the learning rate value across different - * invocations of optimizer functions. It is computed as: - * - * <pre> - * step = min(step, decay_steps) - * ((initialLearningRate - endLearningRate) * - * (1 - step / decaySteps) ^ (power) - * ) + endLearningRate - * </pre> - * - * <p> - * - * <p>If `cycle` is True then a multiple of `decay_steps` is used, the first one that is bigger than - * `step`. - */ -public class PolynomialDecay implements LearningRateSchedule { - private static final float END_LEARNING_RATE_DEFAULT = 0.0001f; - public static final float POWER_DEFAULT = 1.0f; - public static final boolean CYCLE_DEFAULT = false; - - protected final float initialLearningRate; - protected final float decaySteps; - protected final float endLearningRate; - protected final float power; - protected final boolean cycle; - - /** - * Create a PolynomialDecay - * - * @param initialLearningRate The initial learning rate. - * @param decaySteps How often to apply decay. - */ - public PolynomialDecay(float initialLearningRate, int decaySteps) { - this(initialLearningRate, decaySteps, END_LEARNING_RATE_DEFAULT, POWER_DEFAULT, CYCLE_DEFAULT); - } - - /** - * Create a PolynomialDecay - * - * @param initialLearningRate The initial learning rate. - * @param decaySteps How often to apply decay. - * @param cycle Whether or not it should cycle beyond decay_steps. Default is false. - */ - public PolynomialDecay(float initialLearningRate, int decaySteps, boolean cycle) { - this(initialLearningRate, decaySteps, END_LEARNING_RATE_DEFAULT, POWER_DEFAULT, cycle); - } - - /** - * Create a PolynomialDecay - * - * @param initialLearningRate The initial learning rate. - * @param decaySteps How often to apply decay. - * @param endLearningRate The end learning rate. Default is 0.0001. - */ - public PolynomialDecay(float initialLearningRate, int decaySteps, float endLearningRate) { - this(initialLearningRate, decaySteps, endLearningRate, POWER_DEFAULT, CYCLE_DEFAULT); - } - - /** - * Create a PolynomialDecay - * - * @param initialLearningRate The initial learning rate. - * @param decaySteps How often to apply decay. - * @param endLearningRate The end learning rate. Default is 0.0001. - * @param power The power of the polynomial. Defaults to linear, 1.0. - * @param cycle Whether or not it should cycle beyond decay_steps. Default is false. - */ - public PolynomialDecay( - float initialLearningRate, - int decaySteps, - float endLearningRate, - float power, - boolean cycle) { - this.initialLearningRate = initialLearningRate; - this.decaySteps = decaySteps; - this.endLearningRate = endLearningRate; - this.power = power; - this.cycle = cycle; - } - - @Override - public float call(int step) { - - float lDecaySteps = decaySteps; - float lStep = step; - if (cycle) { - float multipler = step == 0 ? 1.0f : (float) Math.ceil(step / decaySteps); - lDecaySteps = decaySteps * multipler; - } else { - lStep = Math.min(lStep, lDecaySteps); - } - - float p = lStep / lDecaySteps; - - float f = (this.initialLearningRate - this.endLearningRate) * (float) Math.pow(1.0f - p, power); - return f + endLearningRate; - } -} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java deleted file mode 100644 index dac8caa19a3..00000000000 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java +++ /dev/null @@ -1,16 +0,0 @@ -package org.tensorflow.framework.optimizers.schedules; - -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.*; - -class PiecewiseConstantDecayTest { - - public PiecewiseConstantDecayTest() {} - - @Test - public void testDecay() { - - } - -} \ No newline at end of file diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java deleted file mode 100644 index a28e56ad7cb..00000000000 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java +++ /dev/null @@ -1,24 +0,0 @@ -package org.tensorflow.framework.optimizers.schedules; - -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.*; - -class PolynomialDecayTest { - - public PolynomialDecayTest() {} - - @Test - public void testBeginWithCycle() { - float initialLearningRate = 0.1f; - int decaySteps = 10; - float decayRate = 0.96f; - float epsilon = 1e-6f; - PolynomialDecay instance = new PolynomialDecay(initialLearningRate, decaySteps, true); - float expected = initialLearningRate; - float actual = instance.call(0); - assertEquals(expected, actual, epsilon); - - } - -} \ No newline at end of file From d5edd353059839bcb207b5a88d1c53eb31cc8e3c Mon Sep 17 00:00:00 2001 From: Jim Clarke <JimClarke5@me.com> Date: Sat, 19 Sep 2020 18:45:40 -0400 Subject: [PATCH 10/14] Added JavaDoc back in, changed setLearningRate() to setLearningRate(newLearningRate), eliminated spurious "this." --- .../framework/optimizers/Optimizer.java | 71 +++++++++++++++---- 1 file changed, 57 insertions(+), 14 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index 8e0471dc0ba..7c5258348c2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -111,20 +111,47 @@ protected Optimizer(Graph graph, String name, float learningRate) { setLearningRate(learningRate); } + /** + * Creates a name by combining a variable name and a slot name + * + * @param variable the variable + * @param slotName the name of the slot + * @return the combined name + */ public static String createName(Output<? extends TType> variable, String slotName) { return variable.op().name() + "-" + slotName; } + /** + * Minimizes the loss by updating the variables + * + * @param loss the loss operation that returns the value to minimize + * @return returns op that minimizes the loss by updating the listed variables + */ public Op minimize(Operand<?> loss) { return minimize(loss, getOptimizerName() + "-minimize"); } + /** + * Minimizes the loss by updating the variables + * + * @param loss the loss operation that returns the value to minimize + * @param name the name for the minimize operation + * @return op that minimizes the loss by updating the listed variables + */ public Op minimize(Operand<?> loss, String name) { List<GradAndVar<?>> gradsAndVars = computeGradients(loss); return applyGradients(gradsAndVars, name); } + /** + * Computes the gradients based on a loss operand. + * + * @param loss the loss operation + * @param <T> the data type of the loss, gradients and variables. + * @return the computed gradients + */ public <T extends TType> List<GradAndVar<?>> computeGradients(Operand<?> loss) { List<Operation> variables = new ArrayList<>(); graph @@ -156,6 +183,13 @@ public <T extends TType> List<GradAndVar<?>> computeGradients(Operand<?> loss) { return gradVarPairs; } + /** + * Applies gradients to variables + * + * @param gradsAndVars the list of (gradient, variable) pairs. + * @param name the name of the apply gradients operation + * @return an Op that applies the gradients to the variables. + */ public Op applyGradients(List<GradAndVar<? extends TType>> gradsAndVars, String name) { List<Output<? extends TType>> variables = gradsAndVars.stream().map(GradAndVar::getVariable).collect(Collectors.toList()); @@ -242,6 +276,13 @@ protected Optional<Op> prepare(String scopeName) { */ protected void createSlots(List<Output<? extends TType>> variables) {} + /** + * Generates the gradient update operations for the specific variable and gradient. + * + * @param gradVarPair the list of (gradient, variable) pairs. + * @param <T> the datatype of the gradients and variables. + * @return An operand which applies the desired optimizer update to the variable. + */ private <T extends TType> Op applyDense(GradAndVar<T> gradVarPair) { return applyDense(gradVarPair.getGradient(), gradVarPair.getVariable()); } @@ -280,20 +321,20 @@ protected Op finish(List<Op> updateOperations, String name) { /** * Sets the learning rate * - * @param learningRate the learning rate + * @param newLearningRate the new earning rate */ - public final void setLearningRate(float learningRate) { - if (this.learningRatePlaceholder == null) { - this.learningRatePlaceholder = + public final void setLearningRate(float newLearningRate) { + if (learningRatePlaceholder == null) { + learningRatePlaceholder = tf.withSubScope(LEARNING_RATE) .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); } - if (this.learningRate != learningRate) { - if (this.learningRateTensor != null) this.learningRateTensor.close(); - this.learningRate = learningRate; - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.feedMap = Collections.singletonMap(this.learningRatePlaceholder, learningRateTensor); + if (learningRate != newLearningRate) { + if (learningRateTensor != null) learningRateTensor.close(); + learningRate = newLearningRate; + learningRateTensor = TFloat32.scalarOf(learningRate); + feedMap = Collections.singletonMap(learningRatePlaceholder, learningRateTensor); } } @@ -303,7 +344,7 @@ public final void setLearningRate(float learningRate) { * @return the learning rate */ public float getLearningRate() { - return this.learningRate; + return learningRate; } /** @@ -312,7 +353,7 @@ public float getLearningRate() { * @return the learning rate Operand */ protected Operand<TFloat32> getLearningRateOperand() { - return this.learningRatePlaceholder; + return learningRatePlaceholder; } /** @@ -323,13 +364,15 @@ protected Operand<TFloat32> getLearningRateOperand() { * Operand has been set. */ public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedMap() { - return this.feedMap; + return feedMap; } + /** {@inheritDoc} */ public void close() { // close the learningRate Tensor if it exists. - if (this.feedMap != null) { - this.feedMap.get(this.learningRatePlaceholder).close(); + if (learningRateTensor != null) { + learningRateTensor.close(); + learningRateTensor = null; } } From 2f57c1df6c3eaeec39a64b87da8cf5a66525904a Mon Sep 17 00:00:00 2001 From: Jim Clarke <JimClarke5@me.com> Date: Mon, 21 Sep 2020 15:00:45 -0400 Subject: [PATCH 11/14] Change Optimizer to only have one constructor, "protected Optimizer(Graph graph, String name, float learningRate)"", change all the subclass ctors to use this one. --- .../framework/optimizers/AdaDelta.java | 4 +- .../framework/optimizers/AdaGrad.java | 9 +--- .../framework/optimizers/AdaGradDA.java | 18 +------ .../tensorflow/framework/optimizers/Adam.java | 5 +- .../framework/optimizers/Adamax.java | 5 +- .../tensorflow/framework/optimizers/Ftrl.java | 16 +++--- .../framework/optimizers/GradientDescent.java | 2 +- .../framework/optimizers/Momentum.java | 4 +- .../framework/optimizers/Nadam.java | 5 +- .../framework/optimizers/Optimizer.java | 49 +------------------ .../framework/optimizers/RMSProp.java | 6 +-- 11 files changed, 20 insertions(+), 103 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index be5b39a534d..9f2e868ea1a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -95,9 +95,7 @@ public AdaDelta(Graph graph, float learningRate) { * @param epsilon A constant epsilon used to better conditioning the grad update */ public AdaDelta(Graph graph, float learningRate, float rho, float epsilon) { - super(graph, learningRate); - this.rho = rho; - this.epsilon = epsilon; + this(graph, null, learningRate, rho, epsilon); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 384c04e60bb..9a9498630a8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -74,14 +74,7 @@ public AdaGrad(Graph graph, float learningRate) { * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative */ public AdaGrad(Graph graph, float learningRate, float initialAccumulatorValue) { - super(graph, learningRate); - if (initialAccumulatorValue < 0F) { - throw new IllegalArgumentException( - String.format( - "initialAccumulatorValue must be non-negative: %f", initialAccumulatorValue)); - } - this.learningRate = learningRate; - this.initialAccumulatorValue = initialAccumulatorValue; + this(graph, null, learningRate, initialAccumulatorValue); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index 48823bf5fd8..af5c3737251 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -100,23 +100,7 @@ public AdaGradDA( float initialAccumulatorValue, float l1Strength, float l2Strength) { - super(graph, learningRate); - if (initialAccumulatorValue <= 0F) { - throw new IllegalArgumentException( - String.format( - "initialAccumulatorValue must be greater than zero: %f", initialAccumulatorValue)); - } - if (l1Strength < 0F) { - throw new IllegalArgumentException( - String.format("l1Strength must not be negative: %f", l1Strength)); - } - if (l2Strength < 0F) { - throw new IllegalArgumentException( - String.format("l2Strength must not be negative: %f", l2Strength)); - } - this.initialAccumulatorValue = initialAccumulatorValue; - this.l1Strength = l1Strength; - this.l2Strength = l2Strength; + this(graph, null, learningRate, initialAccumulatorValue, l1Strength, l2Strength); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java index b5915b4bd57..3ca9fbdab57 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java @@ -99,10 +99,7 @@ public Adam(Graph graph, float learningRate) { * 1 of the paper. Defaults to 1e-8. */ public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(graph, learningRate); - this.betaOne = betaOne; - this.betaTwo = betaTwo; - this.epsilon = epsilon; + this(graph, null, learningRate, betaOne, betaTwo, epsilon); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java index e568f881773..c381013e97c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java @@ -92,10 +92,7 @@ public Adamax(Graph graph, String name, float learningRate) { * @param epsilon A small constant for numerical stability. */ public Adamax(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(graph, learningRate); - this.betaOne = betaOne; - this.betaTwo = betaTwo; - this.epsilon = epsilon; + this(graph, null, learningRate, betaOne, betaTwo, epsilon); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java index b455ae1f0be..edbe91c62e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java @@ -132,13 +132,15 @@ public Ftrl( float l1Strength, float l2Strength, float l2ShrinkageRegularizationStrength) { - super(graph, learningRate); - this.learningRatePower = learningRatePower; - this.initialAccumulatorValue = initialAccumulatorValue; - this.l1RegularizationStrength = l1Strength; - this.l2RegularizationStrength = l2Strength; - this.l2ShrinkageRegularizationStrength = l2ShrinkageRegularizationStrength; - validateParams(); + this( + graph, + null, + learningRate, + learningRatePower, + initialAccumulatorValue, + l1Strength, + l2Strength, + l2ShrinkageRegularizationStrength); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java index 36f36057c26..f57503d3347 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java @@ -44,7 +44,7 @@ public GradientDescent(Graph graph) { * @param learningRate the learning rate, defaults to 0.01 */ public GradientDescent(Graph graph, float learningRate) { - super(graph, learningRate); + super(graph, null, learningRate); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java index a099eae53e8..19e3f275f1f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java @@ -84,9 +84,7 @@ public Momentum(Graph graph, float learningRate, float momentum) { * @param useNesterov Whether to apply Nesterov momentum. Defaults to false. */ public Momentum(Graph graph, float learningRate, float momentum, boolean useNesterov) { - super(graph, learningRate); - this.momentum = momentum; - this.useNesterov = useNesterov; + this(graph, null, learningRate, momentum, useNesterov); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java index d0228eb8b3a..ece7c024969 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -94,10 +94,7 @@ public Nadam(Graph graph, float learningRate) { * @param epsilon A small constant for numerical stability. Default is 1e-8. */ public Nadam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(graph, learningRate); - this.betaOne = betaOne; - this.betaTwo = betaTwo; - this.epsilon = epsilon; + this(graph, null, learningRate, betaOne, betaTwo, epsilon); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index 7c5258348c2..868d04672f1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -51,51 +51,6 @@ public abstract class Optimizer implements AutoCloseable { private Tensor<TFloat32> learningRateTensor; private Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap = null; - /** - * Builds an optimizer for the supplied graph. - * - * <p>Uses the name from {@link Optimizer#getOptimizerName()} to name the operations. - * - * @param graph The graph to optimize. - */ - protected Optimizer(Graph graph) { - this.graph = graph; - this.tf = Ops.create(graph).withName(getOptimizerName()); - this.slots = new HashMap<>(); - this.globals = new ArrayList<>(); - setLearningRate(LEARNING_RATE_DEFAULT); - } - - /** - * Builds an optimizer for the supplied graph. - * - * <p>Uses the name from {@link Optimizer#getOptimizerName()} to name the operations. - * - * @param graph The graph to optimize. - * @param learningRate the learning rate. - */ - protected Optimizer(Graph graph, float learningRate) { - this.graph = graph; - this.tf = Ops.create(graph).withName(getOptimizerName()); - this.slots = new HashMap<>(); - this.globals = new ArrayList<>(); - setLearningRate(learningRate); - } - - /** - * Builds an optimizer for the supplied graph. - * - * @param graph The graph to optimize. - * @param name The base name for the operations. - */ - protected Optimizer(Graph graph, String name) { - this.graph = graph; - this.tf = Ops.create(graph).withName(name); - this.slots = new HashMap<>(); - this.globals = new ArrayList<>(); - setLearningRate(LEARNING_RATE_DEFAULT); - } - /** * Builds an optimizer for the supplied graph. * @@ -105,7 +60,7 @@ protected Optimizer(Graph graph, String name) { */ protected Optimizer(Graph graph, String name, float learningRate) { this.graph = graph; - this.tf = Ops.create(graph).withName(name); + this.tf = Ops.create(graph).withName(name == null ? getOptimizerName() : name); this.slots = new HashMap<>(); this.globals = new ArrayList<>(); setLearningRate(learningRate); @@ -367,7 +322,7 @@ public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedMap() { return feedMap; } - /** {@inheritDoc} */ + /** {@inheritDoc} */ public void close() { // close the learningRate Tensor if it exists. if (learningRateTensor != null) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index face906d682..41b65a0ac01 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -106,11 +106,7 @@ public RMSProp( float momentum, float epsilon, boolean centered) { - super(graph, learningRate); - this.decay = decay; - this.momentum = momentum; - this.epsilon = epsilon; - this.centered = centered; + this(graph, null, learningRate, decay, momentum, epsilon, centered); } /** From e9e2b24608f83de551608aa935caffff328a71d9 Mon Sep 17 00:00:00 2001 From: Jim Clarke <JimClarke5@me.com> Date: Wed, 23 Sep 2020 10:53:52 -0400 Subject: [PATCH 12/14] Fixed close() routine to free up closed tensor in feedMap by setting feedMap to null. --- .../java/org/tensorflow/framework/optimizers/Optimizer.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index 868d04672f1..5194cb32e73 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -276,7 +276,7 @@ protected Op finish(List<Op> updateOperations, String name) { /** * Sets the learning rate * - * @param newLearningRate the new earning rate + * @param newLearningRate the new learning rate */ public final void setLearningRate(float newLearningRate) { if (learningRatePlaceholder == null) { @@ -315,7 +315,7 @@ protected Operand<TFloat32> getLearningRateOperand() { * Gets the Feed Map for the run methods to set the Placeholder value(s). Each entry in the Feed * Map contains a PlaceHolder and a Tensor with the value * - * @return the current Feed Map for the run methods, this may be null if an LearningRate as an + * @return the current Feed Map for the run methods, this may be null if the LearningRate is an * Operand has been set. */ public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedMap() { @@ -329,6 +329,7 @@ public void close() { learningRateTensor.close(); learningRateTensor = null; } + if (feedMap != null) feedMap = null; } /** Optional attributes for {@link org.tensorflow.framework.optimizers.Optimizer} */ From 62ff85c170983312b98e0d00561ae743db2b2e4c Mon Sep 17 00:00:00 2001 From: Jim Clarke <JimClarke5@me.com> Date: Fri, 25 Sep 2020 11:02:05 -0400 Subject: [PATCH 13/14] Fix javadoc for references to Default values, Add Operand<TFloat32> learningRateOperand as an option for learning rate. --- .../framework/optimizers/AdaDelta.java | 78 ++++++++- .../framework/optimizers/AdaGrad.java | 86 +++++++++- .../framework/optimizers/AdaGradDA.java | 121 ++++++++++++- .../tensorflow/framework/optimizers/Adam.java | 105 ++++++++++-- .../framework/optimizers/Adamax.java | 100 ++++++++++- .../tensorflow/framework/optimizers/Ftrl.java | 159 +++++++++++++++++- .../framework/optimizers/GradientDescent.java | 39 ++++- .../framework/optimizers/Momentum.java | 99 +++++++++-- .../framework/optimizers/Nadam.java | 111 +++++++++--- .../framework/optimizers/Optimizer.java | 40 ++++- .../framework/optimizers/RMSProp.java | 129 ++++++++++++-- 11 files changed, 964 insertions(+), 103 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index 9f2e868ea1a..30abd0fcbe3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -20,6 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -62,6 +63,7 @@ */ public class AdaDelta extends Optimizer { + public static final String DEFAULT_NAME = "Adadelta"; public static final String ACCUMULATOR = "accum"; public static final String ACCUMULATOR_UPDATE = "accum_update"; public static final float LEARNING_RATE_DEFAULT = 0.001f; @@ -72,12 +74,20 @@ public class AdaDelta extends Optimizer { private final float epsilon; + /** + * Creates an AdaDelta Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learningRate, {@link #RHO_DEFAULT} for the rho, and {@link + * #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph. + */ public AdaDelta(Graph graph) { this(graph, LEARNING_RATE_DEFAULT, RHO_DEFAULT, EPSILON_DEFAULT); } /** - * Creates an AdaDelta Optimizer + * Creates an AdaDelta Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #RHO_DEFAULT} for the rho, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -87,7 +97,19 @@ public AdaDelta(Graph graph, float learningRate) { } /** - * Creates an AdaDelta Optimizer + * Creates an AdaDelta Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #RHO_DEFAULT} for the rho, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaDelta(Graph graph, Operand<TFloat32> learningRateOperand) { + this(graph, learningRateOperand, RHO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates an AdaDelta Optimizer {@link #DEFAULT_NAME} for the Optimizer name * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -102,18 +124,45 @@ public AdaDelta(Graph graph, float learningRate, float rho, float epsilon) { * Creates an AdaDelta Optimizer * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'Adadelta') + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param rho The decay factor + * @param epsilon A constant epsilon used to better conditioning the grad update + */ + public AdaDelta(Graph graph, Operand<TFloat32> learningRateOperand, float rho, float epsilon) { + this(graph, null, learningRateOperand, rho, epsilon); + } + + /** + * Creates an AdaDelta Optimizer using {@link #RHO_DEFAULT} for the rho, and {@link * + * #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. * @param learningRate the learning rate */ public AdaDelta(Graph graph, String name, float learningRate) { - this(graph, name, learningRate, 0.95f, 1e-8f); + this(graph, name, learningRate, RHO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates an AdaDelta Optimizer using {@link #RHO_DEFAULT} for the rho, and {@link * + * #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaDelta(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this(graph, name, learningRateOperand, RHO_DEFAULT, EPSILON_DEFAULT); } /** * Creates an AdaDelta Optimizer * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'Adadelta') + * @param name the name for this Optimizer. * @param learningRate the learning rate * @param rho The decay factor * @param epsilon A constant epsilon used to better conditioning the grad update @@ -124,6 +173,23 @@ public AdaDelta(Graph graph, String name, float learningRate, float rho, float e this.epsilon = epsilon; } + /** + * Creates an AdaDelta Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param rho The decay factor + * @param epsilon A constant epsilon used to better conditioning the grad update + */ + public AdaDelta( + Graph graph, String name, Operand<TFloat32> learningRateOperand, float rho, float epsilon) { + super(graph, name, learningRateOperand); + this.rho = rho; + this.epsilon = epsilon; + } + /** {@inheritDoc} */ @Override protected void createSlots(List<Output<? extends TType>> variables) { @@ -178,6 +244,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Adadelta"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 9a9498630a8..c0cc47409d7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -20,6 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -40,6 +41,8 @@ */ public class AdaGrad extends Optimizer { + public static final String DEFAULT_NAME = "Adagrad"; + public static final String ACCUMULATOR = "accumulator"; public static final float LEARNING_RATE_DEFAULT = 0.001f; public static final float INITIAL_ACCUMULATOR_DEFAULT = 0.01f; @@ -47,7 +50,9 @@ public class AdaGrad extends Optimizer { private final float initialAccumulatorValue; /** - * Creates an AdaGrad Optimizer + * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, and {@link * #INITIAL_ACCUMULATOR_DEFAULT} for + * the initialAccumulatorValue. * * @param graph the TensorFlow Graph */ @@ -56,7 +61,8 @@ public AdaGrad(Graph graph) { } /** - * Creates an AdaGrad Optimizer + * Creates an AdaGrad Optimizer using using {@link #DEFAULT_NAME} for the Optimizer name, {@link * + * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -66,7 +72,19 @@ public AdaGrad(Graph graph, float learningRate) { } /** - * Creates an AdaGrad Optimizer + * Creates an AdaGrad Optimizer using using {@link #DEFAULT_NAME} for the Optimizer name, {@link * + * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaGrad(Graph graph, Operand<TFloat32> learningRateOperand) { + this(graph, learningRateOperand, INITIAL_ACCUMULATOR_DEFAULT); + } + + /** + * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -78,21 +96,49 @@ public AdaGrad(Graph graph, float learningRate, float initialAccumulatorValue) { } /** - * Creates an AdaGrad Optimizer + * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param initialAccumulatorValue Starting value for the accumulators, must be non-negative. + * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative + */ + public AdaGrad( + Graph graph, Operand<TFloat32> learningRateOperand, float initialAccumulatorValue) { + this(graph, null, learningRateOperand, initialAccumulatorValue); + } + + /** + * Creates an AdaGrad Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue. * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'Adagrad') + * @param name the name for this Optimizer . * @param learningRate the learning rate */ public AdaGrad(Graph graph, String name, float learningRate) { - this(graph, name, learningRate, 0.01f); + this(graph, name, learningRate, INITIAL_ACCUMULATOR_DEFAULT); + } + + /** + * Creates an AdaGrad Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaGrad(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this(graph, name, learningRateOperand, INITIAL_ACCUMULATOR_DEFAULT); } /** * Creates an AdaGrad Optimizer * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'Adagrad') + * @param name the name for this Optimizer * @param learningRate the learning rate * @param initialAccumulatorValue Starting value for the accumulators, must be non-negative. * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative @@ -107,6 +153,30 @@ public AdaGrad(Graph graph, String name, float learningRate, float initialAccumu this.initialAccumulatorValue = initialAccumulatorValue; } + /** + * Creates an AdaGrad Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param initialAccumulatorValue Starting value for the accumulators, must be non-negative. + * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative + */ + public AdaGrad( + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float initialAccumulatorValue) { + super(graph, name, learningRateOperand); + if (initialAccumulatorValue < 0F) { + throw new IllegalArgumentException( + String.format( + "initialAccumulatorValue must be non-negative: %f", initialAccumulatorValue)); + } + this.initialAccumulatorValue = initialAccumulatorValue; + } + /** {@inheritDoc} */ @Override protected void createSlots(List<Output<? extends TType>> variables) { @@ -149,6 +219,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Adagrad"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index af5c3737251..0e070f2f4fa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -22,6 +22,7 @@ import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -46,6 +47,7 @@ */ public class AdaGradDA extends Optimizer { + public static final String DEFAULT_NAME = "adagrad-da"; public static final String ACCUMULATOR = "gradient_accumulator"; public static final String SQUARED_ACCUMULATOR = "gradient_squared_accumulator"; public static final float LEARNING_RATE_DEFAULT = 0.001F; @@ -59,7 +61,10 @@ public class AdaGradDA extends Optimizer { private Variable<TInt64> globalStep; /** - * Creates an AdaGradDA Optimizer + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for the l1Strength, and {@link + * #L2_STRENGTH_DEFAULT} for the l2Strength. * * @param graph the TensorFlow Graph */ @@ -73,7 +78,9 @@ public AdaGradDA(Graph graph) { } /** - * Creates an AdaGradDA Optimizer + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for + * the l1Strength, and {@link #L2_STRENGTH_DEFAULT} for the l2Strength. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -84,7 +91,25 @@ public AdaGradDA(Graph graph, float learningRate) { } /** - * Creates an AdaGradDA Optimizer + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for + * the l1Strength, and {@link #L2_STRENGTH_DEFAULT} for the l2Strength. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaGradDA(Graph graph, Operand<TFloat32> learningRateOperand) { + this( + graph, + learningRateOperand, + INITIAL_ACCUMULATOR_DEFAULT, + L1_STRENGTH_DEFAULT, + L2_STRENGTH_DEFAULT); + } + + /** + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -104,10 +129,33 @@ public AdaGradDA( } /** - * Creates an AdaGradDA Optimizer + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'adagrad-da') + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param initialAccumulatorValue Starting value for the accumulators, must be greater than zero. + * @param l1Strength l1 regularization strength, must be greater than or equal to zero. + * @param l2Strength l2 regularization strength, must be greater than or equal to zero. + * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is not greater than zero, + * or l1Strength or l2Strength is less than zero + */ + public AdaGradDA( + Graph graph, + Operand<TFloat32> learningRateOperand, + float initialAccumulatorValue, + float l1Strength, + float l2Strength) { + this(graph, null, learningRateOperand, initialAccumulatorValue, l1Strength, l2Strength); + } + + /** + * Creates an AdaGradDA Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for the l1Strength, and {@link + * #L2_STRENGTH_DEFAULT} for the l2Strength. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. * @param learningRate the learning rate */ public AdaGradDA(Graph graph, String name, float learningRate) { @@ -120,11 +168,31 @@ public AdaGradDA(Graph graph, String name, float learningRate) { L2_STRENGTH_DEFAULT); } + /** + * Creates an AdaGradDA Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for the l1Strength, and {@link + * #L2_STRENGTH_DEFAULT} for the l2Strength. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaGradDA(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this( + graph, + name, + learningRateOperand, + INITIAL_ACCUMULATOR_DEFAULT, + L1_STRENGTH_DEFAULT, + L2_STRENGTH_DEFAULT); + } + /** * Creates an AdaGradDA Optimizer * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'adagrad-da') + * @param name the name for this Optimizer. * @param learningRate the learning rate * @param initialAccumulatorValue Starting value for the accumulators, must be positive * @param l1Strength l1 regularization strength, must be greater than or equal to zero. @@ -158,6 +226,45 @@ public AdaGradDA( this.l2Strength = l2Strength; } + /** + * Creates an AdaGradDA Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param initialAccumulatorValue Starting value for the accumulators, must be positive + * @param l1Strength l1 regularization strength, must be greater than or equal to zero. + * @param l2Strength l2 regularization strength, must be greater than or equal to zero. + * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is not greater than zero, + * or * l1Strength or l2Strength is less than zero + */ + public AdaGradDA( + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float initialAccumulatorValue, + float l1Strength, + float l2Strength) { + super(graph, name, learningRateOperand); + if (initialAccumulatorValue <= 0F) { + throw new IllegalArgumentException( + String.format( + "initialAccumulatorValue must be greater than zero: %f", initialAccumulatorValue)); + } + if (l1Strength < 0F) { + throw new IllegalArgumentException( + String.format("l1Strength must not be negative: %f", l1Strength)); + } + if (l2Strength < 0F) { + throw new IllegalArgumentException( + String.format("l2Strength must not be negative: %f", l2Strength)); + } + this.initialAccumulatorValue = initialAccumulatorValue; + this.l1Strength = l1Strength; + this.l2Strength = l2Strength; + } + /** {@inheritDoc} */ @Override protected Optional<Op> prepare(String name) { @@ -240,6 +347,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "adagrad-da"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java index 3ca9fbdab57..9e4f41f1039 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java @@ -48,6 +48,8 @@ @Operator public class Adam extends Optimizer { + public static final String DEFAULT_NAME = "Adam"; + public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; @@ -69,7 +71,9 @@ public class Adam extends Optimizer { private Variable<TFloat32> betaTwoPower; /** - * Creates an Adam optimizer + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value, + * {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph */ @@ -78,7 +82,9 @@ public Adam(Graph graph) { } /** - * Creates an Adam optimizer + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and + * {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph * @param learningRate the learning rate @@ -86,44 +92,91 @@ public Adam(Graph graph) { public Adam(Graph graph, float learningRate) { this(graph, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } + /** + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and + * {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Adam(Graph graph, Operand<TFloat32> learningRateOperand) { + this(graph, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } /** - * Creates an Adam optimizer + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow graph * @param learningRate the learning rate - * @param betaOne The exponential decay rate for the 1st moment estimates. Defaults to 0.9. - * @param betaTwo The exponential decay rate for the 2nd moment estimates. Defaults to 0.999. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the 2nd moment estimates. * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the * Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm - * 1 of the paper. Defaults to 1e-8. + * 1 of the paper.. */ public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { this(graph, null, learningRate, betaOne, betaTwo, epsilon); } /** - * Creates an Adam optimizer + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow graph - * @param name the Optimizer name, defaults to "Adam" + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the 2nd moment estimates. + * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the + * Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm + * 1 of the paper. + */ + public Adam( + Graph graph, + Operand<TFloat32> learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + this(graph, null, learningRateOperand, betaOne, betaTwo, epsilon); + } + + /** + * Creates an Adam optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link + * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param name the Optimizer name. * @param learningRate the learning rate */ public Adam(Graph graph, String name, float learningRate) { this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } + /** + * Creates an Adam optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link + * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param name the Optimizer name. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Adam(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this(graph, name, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + /** * Creates an Adam optimizer * * @param graph the TensorFlow graph - * @param name the Optimizer name, defaults to "Adam" + * @param name the Optimizer name. * @param learningRate the learning rate - * @param betaOne The exponential decay rate for the 1st moment estimates. Defaults to 0.9. - * @param betaTwo The exponential decay rate for the 2nd moment estimates. Defaults to 0.999. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the 2nd moment estimates. * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the * Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm - * 1 of the paper. Defaults to 1e-8. + * 1 of the paper. */ public Adam( Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { @@ -133,6 +186,32 @@ public Adam( this.epsilon = epsilon; } + /** + * Creates an Adam optimizer + * + * @param graph the TensorFlow graph + * @param name the Optimizer name. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the 2nd moment estimates. + * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the + * Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm + * 1 of the paper. + */ + public Adam( + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + super(graph, name, learningRateOperand); + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + /** * Creates the Operation that minimizes the loss * @@ -265,6 +344,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Adam"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java index c381013e97c..e33775db961 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java @@ -25,6 +25,7 @@ */ public class Adamax extends Optimizer { + public static final String DEFAULT_NAME = "Adamax"; public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; @@ -43,7 +44,10 @@ public class Adamax extends Optimizer { private Variable<TFloat32> betaOnePower; /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the + * Optimizer name, {@link #LEARNING_RATE_DEFAULT} for the learning rate, {@link #BETA_ONE_DEFAULT} + * for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link + * #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph */ @@ -52,17 +56,21 @@ public Adamax(Graph graph) { } /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, {@link #LEARNING_RATE_DEFAULT} for + * the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} + * for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph - * @param name name for the operations Created when applying gradients. Defaults to "Adamax". + * @param name name for the operations Created when applying gradients. */ public Adamax(Graph graph, String name) { this(graph, name, LEARNING_RATE_DEFAULT, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the + * Optimizer name, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for + * the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph * @param learningRate The learning rate. @@ -72,18 +80,48 @@ public Adamax(Graph graph, float learningRate) { } /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the + * Optimizer name, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for + * the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Adamax(Graph graph, Operand<TFloat32> learningRateOperand) { + this(graph, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates an Optimizer that implements the Adamax algorithm, using {@link #BETA_ONE_DEFAULT} for + * the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link + * #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph - * @param name name for the operations Created when applying gradients. Defaults to "Adamax". + * @param name name for the operations Created when applying gradients. * @param learningRate The learning rate. */ public Adamax(Graph graph, String name, float learningRate) { this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } + /** + * Creates an Optimizer that implements the Adamax algorithm, using {@link #BETA_ONE_DEFAULT} for + * the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link + * #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param name name for the operations Created when applying gradients. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Adamax(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this(graph, name, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, {@link #LEARNING_RATE_DEFAULT} for + * the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} + * for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph * @param learningRate The learning rate. @@ -94,12 +132,32 @@ public Adamax(Graph graph, String name, float learningRate) { public Adamax(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { this(graph, null, learningRate, betaOne, betaTwo, epsilon); } + /** + * Creates an Optimizer that implements the Adamax algorithm, {@link #LEARNING_RATE_DEFAULT} for + * the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} + * for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. + */ + public Adamax( + Graph graph, + Operand<TFloat32> learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + this(graph, null, learningRateOperand, betaOne, betaTwo, epsilon); + } /** * Creates an Optimizer that implements the Adamax algorithm. * * @param graph the TensorFlow graph - * @param name name for the operations Created when applying gradients. Defaults to "Adamax". + * @param name name for the operations Created when applying gradients. * @param learningRate The learning rate. * @param betaOne The exponential decay rate for the 1st moment estimates. * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. @@ -113,6 +171,30 @@ public Adamax( this.epsilon = epsilon; } + /** + * Creates an Optimizer that implements the Adamax algorithm. + * + * @param graph the TensorFlow graph + * @param name name for the operations Created when applying gradients. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. + */ + public Adamax( + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + super(graph, name, learningRateOperand); + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + /** {@inheritDoc} */ @Override protected Optional<Op> prepare(String scopeName) { @@ -177,6 +259,6 @@ protected Op finish(List<Op> updateOperations, String name) { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Adamax"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java index edbe91c62e9..35eeb7dc225 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java @@ -6,6 +6,7 @@ import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyFtrl; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -20,6 +21,8 @@ */ public class Ftrl extends Optimizer { + public static final String DEFAULT_NAME = "Ftrl"; + public static final String ACCUMULATOR = "gradient_accumulator"; public static final String LINEAR_ACCUMULATOR = "linear_accumulator"; @@ -37,7 +40,12 @@ public class Ftrl extends Optimizer { private final float l2ShrinkageRegularizationStrength; /** - * Creates a Ftrl Optimizer + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #LEARNING_RATE_POWER_DEFAULT} for the + * learningRatePower. {@link #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, + * {@link #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength + * and {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. * * @param graph the TensorFlow Graph */ @@ -53,7 +61,12 @@ public Ftrl(Graph graph) { } /** - * Creates a Ftrl Optimizer + * Creates an Ftrl Optimizer using {@link #LEARNING_RATE_DEFAULT} for the learning rate, {@link + * #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. {@link + * #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. * * @param graph the TensorFlow Graph * @param name the name of this Optimizer @@ -71,7 +84,12 @@ public Ftrl(Graph graph, String name) { } /** - * Creates a Ftrl Optimizer + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. {@link + * #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -88,7 +106,34 @@ public Ftrl(Graph graph, float learningRate) { } /** - * Creates a Ftrl Optimizer + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. {@link + * #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Ftrl(Graph graph, Operand<TFloat32> learningRateOperand) { + this( + graph, + learningRateOperand, + LEARNING_RATE_POWER_DEFAULT, + INITIAL_ACCUMULATOR_VALUE_DEFAULT, + L1STRENGTH_DEFAULT, + L2STRENGTH_DEFAULT, + L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT); + } + + /** + * Creates an Ftrl Optimizer using {@link #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. + * {@link #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. * * @param graph the TensorFlow Graph * @param name the name of this Optimizer @@ -107,7 +152,31 @@ public Ftrl(Graph graph, String name, float learningRate) { } /** - * Creates a Ftrl Optimizer + * Creates an Ftrl Optimizer using {@link #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. + * {@link #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Ftrl(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this( + graph, + name, + learningRateOperand, + LEARNING_RATE_POWER_DEFAULT, + INITIAL_ACCUMULATOR_VALUE_DEFAULT, + L1STRENGTH_DEFAULT, + L2STRENGTH_DEFAULT, + L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT); + } + + /** + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -143,6 +212,44 @@ public Ftrl( l2ShrinkageRegularizationStrength); } + /** + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param learningRatePower Controls how the learning rate decreases during training. Use zero for + * a fixed learning rate. + * @param initialAccumulatorValue The starting value for accumulators. Only zero or positive + * values are allowed. + * @param l1Strength the L1 Regularization strength, must be greater than or equal to zero. + * @param l2Strength the L2 Regularization strength, must be greater than or equal to zero. + * @param l2ShrinkageRegularizationStrength This differs from L2 above in that the L2 above is a + * stabilization penalty, whereas this L2 shrinkage is a magnitude penalty. must be greater + * than or equal to zero. + * @throws java.lang.IllegalArgumentException if the initialAccumulatorValue, + * l1RegularizationStrength, l2RegularizationStrength, or l2ShrinkageRegularizationStrength + * are less than 0.0, or learningRatePower is greater than 0.0. + */ + public Ftrl( + Graph graph, + Operand<TFloat32> learningRateOperand, + float learningRatePower, + float initialAccumulatorValue, + float l1Strength, + float l2Strength, + float l2ShrinkageRegularizationStrength) { + this( + graph, + null, + learningRateOperand, + learningRatePower, + initialAccumulatorValue, + l1Strength, + l2Strength, + l2ShrinkageRegularizationStrength); + } + /** * Creates a Ftrl Optimizer * @@ -180,7 +287,45 @@ public Ftrl( validateParams(); } - /** Validates all the settings of the Frtl Optmizer */ + /** + * Creates a Ftrl Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param learningRatePower Controls how the learning rate decreases during training. Use zero for + * a fixed learning rate. + * @param initialAccumulatorValue The starting value for accumulators. Only zero or positive + * values are allowed. + * @param l1Strength the L1 Regularization strength, must be greater than or equal to zero. + * @param l2Strength the L2 Regularization strength, must be greater than or equal to zero. + * @param l2ShrinkageRegularizationStrength This differs from L2 above in that the L2 above is a + * stabilization penalty, whereas this L2 shrinkage is a magnitude penalty. must be greater + * than or equal to zero. + * @throws java.lang.IllegalArgumentException if the initialAccumulatorValue, + * l1RegularizationStrength, l2RegularizationStrength, or l2ShrinkageRegularizationStrength + * are less than 0.0, or learningRatePower is greater than 0.0. + */ + public Ftrl( + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float learningRatePower, + float initialAccumulatorValue, + float l1Strength, + float l2Strength, + float l2ShrinkageRegularizationStrength) { + super(graph, name, learningRateOperand); + this.learningRatePower = learningRatePower; + this.initialAccumulatorValue = initialAccumulatorValue; + this.l1RegularizationStrength = l1Strength; + this.l2RegularizationStrength = l2Strength; + this.l2ShrinkageRegularizationStrength = l2ShrinkageRegularizationStrength; + validateParams(); + } + + /** Validates all the settings of the Ftrl Optimizer */ private void validateParams() { if (this.initialAccumulatorValue < 0.0F) { throw new IllegalArgumentException( @@ -257,6 +402,6 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Ftrl"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java index f57503d3347..efec399f40d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java @@ -16,8 +16,10 @@ package org.tensorflow.framework.optimizers; import org.tensorflow.Graph; +import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.op.Op; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; /** @@ -26,10 +28,12 @@ */ public class GradientDescent extends Optimizer { + public static final String DEFAULT_NAME = "GradientDescent"; public static final float LEARNING_RATE_DEFAULT = 0.01f; /** - * Creates a GradientDescent Optimizer + * Creates a GradientDescent Optimizer using {@link #DEFAULT_NAME} for the Optimizer name and + * {@link #LEARNING_RATE_DEFAULT} for the learning rate. * * @param graph the TensorFlow graph */ @@ -38,26 +42,49 @@ public GradientDescent(Graph graph) { } /** - * Creates a GradientDescent Optimizer + * Creates a GradientDescent Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow graph - * @param learningRate the learning rate, defaults to 0.01 + * @param learningRate the learning rate. */ public GradientDescent(Graph graph, float learningRate) { super(graph, null, learningRate); } + /** + * Creates a GradientDescent Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public GradientDescent(Graph graph, Operand<TFloat32> learningRateOperand) { + super(graph, null, learningRateOperand); + } + /** * Creates a GradientDescent Optimizer * * @param graph the TensorFlow graph - * @param name the name for this Optimizer, default is "GradientDescent" - * @param learningRate the learning rate, defaults to 0.01 + * @param name the name for this Optimizer. + * @param learningRate the learning rate. */ public GradientDescent(Graph graph, String name, float learningRate) { super(graph, name, learningRate); } + /** + * Creates a GradientDescent Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public GradientDescent(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + super(graph, name, learningRateOperand); + } + /** {@inheritDoc} */ @Override protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) { @@ -74,6 +101,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "GradientDescent"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java index 19e3f275f1f..436b587c353 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java @@ -21,6 +21,7 @@ import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyMomentum; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -33,6 +34,7 @@ */ public class Momentum extends Optimizer { + public static final String DEFAULT_NAME = "Momentum"; public static final float LEARNING_RATE_DEFAULT = 0.01F; public static final float MOMENTUM_DEFAULT = 0.0F; public static final boolean NESTEROV_DEFAULT = false; @@ -44,7 +46,9 @@ public class Momentum extends Optimizer { private final boolean useNesterov; /** - * Creates a Momentum Optimizer + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #MOMENTUM_DEFAULT} for the momentum, and + * {@link #NESTEROV_DEFAULT} for the Nesterov flag. * * @param graph the TensorFlow graph */ @@ -53,7 +57,8 @@ public Momentum(Graph graph) { } /** - * Creates a Momentum Optimizer + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #MOMENTUM_DEFAULT} for the momentum, and {@link #NESTEROV_DEFAULT} for the Nesterov flag. * * @param graph the TensorFlow graph * @param learningRate the learning rate @@ -63,30 +68,76 @@ public Momentum(Graph graph, float learningRate) { } /** - * Creates a Momentum Optimizer + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #MOMENTUM_DEFAULT} for the momentum, and {@link #NESTEROV_DEFAULT} for the Nesterov flag. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Momentum(Graph graph, Operand<TFloat32> learningRateOperand) { + this(graph, learningRateOperand, MOMENTUM_DEFAULT, NESTEROV_DEFAULT); + } + + /** + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name and {@link + * #NESTEROV_DEFAULT} for the Nesterov flag. * * @param graph the TensorFlow graph * @param learningRate the learning rate * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and - * dampens oscillations, Must be greater than or equal to zero. Default is 0. + * dampens oscillations, Must be greater than or equal to zero. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. */ public Momentum(Graph graph, float learningRate, float momentum) { this(graph, learningRate, momentum, NESTEROV_DEFAULT); } /** - * Creates a Momentum Optimizer + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name and {@link + * #NESTEROV_DEFAULT} for the Nesterov flag. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. + */ + public Momentum(Graph graph, Operand<TFloat32> learningRateOperand, float momentum) { + this(graph, learningRateOperand, momentum, NESTEROV_DEFAULT); + } + + /** + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow graph * @param learningRate the learning rate * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and - * dampens oscillations, Must be greater than or equal to zero. Default is 0. - * @param useNesterov Whether to apply Nesterov momentum. Defaults to false. + * dampens oscillations, Must be greater than or equal to zero. + * @param useNesterov Whether to apply Nesterov momentum. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. */ public Momentum(Graph graph, float learningRate, float momentum, boolean useNesterov) { this(graph, null, learningRate, momentum, useNesterov); } + /** + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. + * @param useNesterov Whether to apply Nesterov momentum. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. + */ + public Momentum( + Graph graph, Operand<TFloat32> learningRateOperand, float momentum, boolean useNesterov) { + this(graph, null, learningRateOperand, momentum, useNesterov); + } + /** * Creates a Momentum Optimizer * @@ -94,12 +145,40 @@ public Momentum(Graph graph, float learningRate, float momentum, boolean useNest * @param name the name for this Optimizer * @param learningRate the learning rate * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and - * dampens oscillations, Must be greater than or equal to zero. Default is 0. - * @param useNesterov Whether to apply Nesterov momentum. Defaults to false. + * dampens oscillations, Must be greater than or equal to zero. + * @param useNesterov Whether to apply Nesterov momentum. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. */ public Momentum( Graph graph, String name, float learningRate, float momentum, boolean useNesterov) { super(graph, name, learningRate); + if (momentum < 0) + throw new IllegalArgumentException("momentum must be greater than or equal to zero."); + this.momentum = momentum; + this.useNesterov = useNesterov; + } + + /** + * Creates a Momentum Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. + * @param useNesterov Whether to apply Nesterov momentum. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. + */ + public Momentum( + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float momentum, + boolean useNesterov) { + super(graph, name, learningRateOperand); + if (momentum < 0) + throw new IllegalArgumentException("momentum must be greater than or equal to zero."); this.momentum = momentum; this.useNesterov = useNesterov; } @@ -152,6 +231,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Momentum"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java index ece7c024969..202f5013e7d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -25,6 +25,7 @@ */ public class Nadam extends Optimizer { + public static final String DEFAULT_NAME = "Nadam"; private static final float DECAY_BASE = 0.96f; private static final float DECAY = 0.004f; public static final float LEARNING_RATE_DEFAULT = 0.001f; @@ -65,7 +66,9 @@ public class Nadam extends Optimizer { private Operand<TFloat32> vTPrimeDenominator; /** - * Creates a Nadam Optimizer + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value, + * {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph */ @@ -74,50 +77,96 @@ public Nadam(Graph graph) { } /** - * Creates a Nadam Optimizer + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and + * {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph - * @param learningRate the learning rate, defaults to 0.001 + * @param learningRate the learning rate. */ public Nadam(Graph graph, float learningRate) { this(graph, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } /** - * Creates a Nadam Optimizer + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and + * {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Nadam(Graph graph, Operand<TFloat32> learningRateOperand) { + this(graph, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow graph - * @param learningRate the learning rate, defaults to 0.001 - * @param betaOne The exponential decay rate for the 1st moment estimates. Default is 0.9. - * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. Default - * is 0.999. - * @param epsilon A small constant for numerical stability. Default is 1e-8. + * @param learningRate the learning rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. */ public Nadam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { this(graph, null, learningRate, betaOne, betaTwo, epsilon); } /** - * Creates a Nadam Optimizer + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. + */ + public Nadam( + Graph graph, + Operand<TFloat32> learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + this(graph, null, learningRateOperand, betaOne, betaTwo, epsilon); + } + + /** + * Creates a Nadam Optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link + * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph - * @param name the name for this Optimizer, defaults to "Nadam" - * @param learningRate the learning rate, defaults to 0.001 + * @param name the name for this Optimizer. + * @param learningRate the learning rate. */ public Nadam(Graph graph, String name, float learningRate) { this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } + /** + * Creates a Nadam Optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link + * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Nadam(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this(graph, name, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + /** * Creates a Nadam Optimizer * * @param graph the TensorFlow graph - * @param name the name for this Optimizer, defaults to "Nadam" - * @param learningRate the learning rate, defaults to 0.001 - * @param betaOne The exponential decay rate for the 1st moment estimates. Default is 0.9. - * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. Default - * is 0.999. - * @param epsilon A small constant for numerical stability. Default is 1e-8. + * @param name the name for this Optimizer. + * @param learningRate the learning rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. */ public Nadam( Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { @@ -127,6 +176,30 @@ public Nadam( this.epsilon = epsilon; } + /** + * Creates a Nadam Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. + */ + public Nadam( + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + super(graph, name, learningRateOperand); + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + /** {@inheritDoc} */ @Override protected void createSlots(List<Output<? extends TType>> variables) { @@ -287,6 +360,6 @@ protected Op finish(List<Op> updateOperations, String name) { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Nadam"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index 5194cb32e73..586fef28c1e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -50,6 +50,7 @@ public abstract class Optimizer implements AutoCloseable { protected Placeholder<TFloat32> learningRatePlaceholder = null; private Tensor<TFloat32> learningRateTensor; private Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap = null; + private Operand<TFloat32> learningRateOperand; /** * Builds an optimizer for the supplied graph. @@ -66,6 +67,21 @@ protected Optimizer(Graph graph, String name, float learningRate) { setLearningRate(learningRate); } + /** + * Builds an optimizer for the supplied graph. + * + * @param graph The graph to optimize. + * @param name The base name for the operations. + * @param learningRateOperand the learning rate. + */ + protected Optimizer(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this.graph = graph; + this.tf = Ops.create(graph).withName(name == null ? getOptimizerName() : name); + this.slots = new HashMap<>(); + this.globals = new ArrayList<>(); + setLearningRateOperand(learningRateOperand); + } + /** * Creates a name by combining a variable name and a slot name * @@ -293,6 +309,17 @@ public final void setLearningRate(float newLearningRate) { } } + /** + * Sets the learning rate Operand. The learning rate operand is an operand that is used to + * calculate the learning rate. + * + * @param newLearningRateOperand the new learning rate operand. + */ + public final void setLearningRateOperand(Operand<TFloat32> newLearningRateOperand) { + close(); // Cleanup the placeholder and tensor if they exist. + learningRateOperand = newLearningRateOperand; + } + /** * Gets the learning rate * @@ -303,20 +330,23 @@ public float getLearningRate() { } /** - * Gets the learning rate Operand, used by subclasses in their graph operations + * Gets the learning rate Operand, used by subclasses in their graph operations. If a float + * learning rate has been set using {@link #setLearningRate}, then this will be the learning rate + * Placeholder, otherwise the learning rate operand is returned as passed to {@link + * #setLearningRateOperand}. * * @return the learning rate Operand */ protected Operand<TFloat32> getLearningRateOperand() { - return learningRatePlaceholder; + return learningRatePlaceholder == null ? learningRateOperand : learningRatePlaceholder; } /** * Gets the Feed Map for the run methods to set the Placeholder value(s). Each entry in the Feed - * Map contains a PlaceHolder and a Tensor with the value + * Map contains a PlaceHolder and a Tensor with the value. * - * @return the current Feed Map for the run methods, this may be null if the LearningRate is an - * Operand has been set. + * @return the current Feed Map for the run methods, this will be null if the LearningRateOperand + * has been set. */ public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedMap() { return feedMap; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index 41b65a0ac01..7a03bec849d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -20,6 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -47,6 +48,7 @@ */ public class RMSProp extends Optimizer { + public static final String DEFAULT_NAME = "RMSProp"; public static final float LEARNING_RATE_DEFAULT = 0.001f; public static final float DECAY_DEFAULT = 0.9f; public static final float MOMENTUM_DEFAULT = 0.0f; @@ -62,7 +64,10 @@ public class RMSProp extends Optimizer { private final boolean centered; /** - * Creates an RMSPRrop Optimizer + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #DECAY_DEFAULT} for the decay, {@link + * #MOMENTUM_DEFAULT} for the momentum, {@link #EPSILON_DEFAULT} for the epsilon value and {@link + * #CENTERED_DEFAULT} for the centered flag. * * @param graph the TensorFlow Graph */ @@ -77,7 +82,9 @@ public RMSProp(Graph graph) { } /** - * Creates an RMSPRrop Optimizer + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #DECAY_DEFAULT} for the decay, {@link #MOMENTUM_DEFAULT} for the momentum, {@link + * #EPSILON_DEFAULT} for the epsilon value and {@link #CENTERED_DEFAULT} for the centered flag. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -87,17 +94,36 @@ public RMSProp(Graph graph, float learningRate) { } /** - * Creates an RMSPRrop Optimizer + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #DECAY_DEFAULT} for the decay, {@link #MOMENTUM_DEFAULT} for the momentum, {@link + * #EPSILON_DEFAULT} for the epsilon value and {@link #CENTERED_DEFAULT} for the centered flag. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public RMSProp(Graph graph, Operand<TFloat32> learningRateOperand) { + this( + graph, + learningRateOperand, + DECAY_DEFAULT, + MOMENTUM_DEFAULT, + EPSILON_DEFAULT, + CENTERED_DEFAULT); + } + + /** + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow Graph * @param learningRate the learning rate - * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. - * @param momentum the acceleration factor, default is 0. + * @param decay Discounting factor for the history/coming gradient. + * @param momentum the acceleration factor. * @param epsilon A small constant for numerical stability * @param centered If <code>true</code>, gradients are normalized by the estimated variance of the * gradient; if <code>false</code>>, by the uncentered second moment. Setting this to <code> * true</code>> may help with training, but is slightly more expensive in terms of computation - * and memory. Defaults to <code>false</code>. + * and memory. */ public RMSProp( Graph graph, @@ -110,10 +136,36 @@ public RMSProp( } /** - * Creates an RMSPRrop Optimizer + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param decay Discounting factor for the history/coming gradient. + * @param momentum the acceleration factor. + * @param epsilon A small constant for numerical stability + * @param centered If <code>true</code>, gradients are normalized by the estimated variance of the + * gradient; if <code>false</code>>, by the uncentered second moment. Setting this to <code> + * true</code>> may help with training, but is slightly more expensive in terms of computation + * and memory. + */ + public RMSProp( + Graph graph, + Operand<TFloat32> learningRateOperand, + float decay, + float momentum, + float epsilon, + boolean centered) { + this(graph, null, learningRateOperand, decay, momentum, epsilon, centered); + } + + /** + * Creates an RMSPRrop Optimizer using {@link #DECAY_DEFAULT} for the decay, {@link + * #MOMENTUM_DEFAULT} for the momentum, {@link #EPSILON_DEFAULT} for the epsilon value and {@link + * #CENTERED_DEFAULT} for the centered flag. * * @param graph the TensorFlow Graph - * @param name the name of this Optimizer. Defaults to "RMSProp". + * @param name the name of this Optimizer. * @param learningRate the learning rate */ public RMSProp(Graph graph, String name, float learningRate) { @@ -127,19 +179,40 @@ public RMSProp(Graph graph, String name, float learningRate) { CENTERED_DEFAULT); } + /** + * Creates an RMSPRrop Optimizer using {@link #DECAY_DEFAULT} for the decay, {@link + * #MOMENTUM_DEFAULT} for the momentum, {@link #EPSILON_DEFAULT} for the epsilon value and {@link + * #CENTERED_DEFAULT} for the centered flag. + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public RMSProp(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this( + graph, + name, + learningRateOperand, + DECAY_DEFAULT, + MOMENTUM_DEFAULT, + EPSILON_DEFAULT, + CENTERED_DEFAULT); + } + /** * Creates an RMSPRrop Optimizer * * @param graph the TensorFlow Graph - * @param name the name of this Optimizer. Defaults to "RMSProp". + * @param name the name of this Optimizer. * @param learningRate the learning rate - * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. - * @param momentum The acceleration factor, default is 0. + * @param decay Discounting factor for the history/coming gradient. + * @param momentum The acceleration factor,. * @param epsilon A small constant for numerical stability * @param centered If <code>true</code>, gradients are normalized by the estimated variance of the * gradient; if <code>false</code>>, by the uncentered second moment. Setting this to <code> * true</code>> may help with training, but is slightly more expensive in terms of computation - * and memory. Defaults to <code>false</code>. + * and memory. */ public RMSProp( Graph graph, @@ -156,6 +229,36 @@ public RMSProp( this.centered = centered; } + /** + * Creates an RMSPRrop Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param decay Discounting factor for the history/coming gradient. + * @param momentum The acceleration factor. + * @param epsilon A small constant for numerical stability + * @param centered If <code>true</code>, gradients are normalized by the estimated variance of the + * gradient; if <code>false</code>>, by the uncentered second moment. Setting this to <code> + * true</code>> may help with training, but is slightly more expensive in terms of computation + * and memory. + */ + public RMSProp( + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float decay, + float momentum, + float epsilon, + boolean centered) { + super(graph, name, learningRateOperand); + this.decay = decay; + this.momentum = momentum; + this.epsilon = epsilon; + this.centered = centered; + } + /** {@inheritDoc} */ @Override protected void createSlots(List<Output<? extends TType>> variables) { @@ -233,6 +336,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "RMSProp"; + return DEFAULT_NAME; } } From ca1395e45fc8db61c29d5fcabbde3c1109f19239 Mon Sep 17 00:00:00 2001 From: Jim Clarke <JimClarke5@me.com> Date: Fri, 25 Sep 2020 14:45:55 -0400 Subject: [PATCH 14/14] Added Operand<TFloat32> learningRateOperand test case for learning rate. --- .../framework/optimizers/AdaDeltaTest.java | 107 +++++++++++- .../framework/optimizers/AdaGradDATest.java | 51 +++++- .../framework/optimizers/AdaGradTest.java | 64 +++++++ .../framework/optimizers/AdamTest.java | 138 ++++++++++++++- .../framework/optimizers/AdamaxTest.java | 131 +++++++++++++-- .../framework/optimizers/FtrlTest.java | 63 +++++++ .../optimizers/GradientDescentTest.java | 55 +++++- .../framework/optimizers/MomentumTest.java | 56 +++++- .../framework/optimizers/NadamTest.java | 135 +++++++++++++++ .../framework/optimizers/RMSPropTest.java | 159 +++++++++++++++++- 10 files changed, 936 insertions(+), 23 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java index 3547ea9a30e..7653c99bc98 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java @@ -167,6 +167,109 @@ public void testBasic() { } } + @Test + public void testBasicWithLROperand() { + int numUpdates = 4; // # number of ADADELTA steps to perform + float[] grads = {0.2F, 0.1F, 0.01F}; + float[] lrs = {1.0F, 0.5F, 0.1F}; + + float rho = 0.95F; + float epsilon = 1e-8F; + + for (float grad : grads) { + for (float lr : lrs) { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // this just uses a trivial operand + try (AdaDelta instance = + new AdaDelta( + session.getGraph(), + tf.math.mul(tf.constant(lr), tf.constant(1.f)), + rho, + epsilon)) { + + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] fgrads = {grad, grad}; + Shape shape = Shape.of(var0Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> cgrads = tf.constant(fgrads); + float accum = 0.0F; + float accumUpdate = 0.0F; + + /* build the GradsAnvVars */ + List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var1.asOutput())); + + /*apply gradients */ + Op adadeltaUpdate = instance.applyGradients(gradsAndVars, "AdaDeltaTest"); + + /* Create and validate the shapes of the slota */ + @SuppressWarnings("unchecked") + Variable<TFloat32>[] slots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable<TFloat32>[] slotUpdates = new Variable[2]; + + slots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape()); + + slotUpdates[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape()); + + slots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape()); + + slotUpdates[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + float[] updates = new float[numUpdates]; + float totUpdate = 0; + for (int step = 0; step < numUpdates; step++) { + + session.run(adadeltaUpdate, instance.getFeedMap()); + accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho); + updates[step] = + ((float) Math.sqrt(accumUpdate + epsilon) + * (float) (1 / Math.sqrt(accum + epsilon)) + * grad); + accumUpdate = + (accumUpdate * rho + ((float) Math.pow(updates[step], 2) * (1.0F - rho))); + totUpdate += updates[step] * lr; + + for (int i = 0; i < 2; i++) { + session.evaluate(accum, slots[i]); + session.evaluate(accumUpdate, slotUpdates[i]); + } + + Float[] var0InitUpdate = {var0Init[0] - totUpdate, var0Init[1] - totUpdate}; + Float[] var1InitUpdate = {var1Init[0] - totUpdate, var1Init[1] - totUpdate}; + + session.evaluate(var0InitUpdate, var0); + session.evaluate(var1InitUpdate, var1); + } + } + } + } + } + } + @Test public void testWithLearningRateDecay() { int numSteps = 4; // # number of ADADELTA steps to perform @@ -224,10 +327,10 @@ public void testWithLearningRateDecay() { session.run(var0Initializer); session.run(var1Initializer); - /** initialize the accumulators */ + /* initialize the accumulators */ session.run(tf.init()); - /** make sure the variables were initialized properly */ + /* make sure the variables were initialized properly */ session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java index 1f8044c1168..9e67b4660df 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java @@ -97,6 +97,56 @@ public void testBasic() { } } + @Test + public void testBasicWithLROperand() { + float[] var0Init = {0.0F, 0.0F}; + float[] var1Init = {0.0F, 0.0F}; + float[] grads0Init = {0.1F, 0.2F}; + float[] grads1Init = {0.01F, 0.02F}; + float learningRate = 1.5F; + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (AdaGradDA instance = + new AdaGradDA( + session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(2.f)))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* initialize the local variables */ + + session.run(var0Initializer); + session.run(var1Initializer); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op adaUpdate = instance.applyGradients(gradsAndVars, "AdGradDATest"); + + /* initialize the accumulators */ + session.run(tf.init()); + + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + session.run(adaUpdate, instance.getFeedMap()); + float[] expected0 = {-0.904534F, -1.603567F}; + session.evaluate(expected0, var0); + float[] expected1 = {-0.094821f, -0.189358f}; + session.evaluate(expected1, var1); + } + } + } + @Test public void testWithLearningRateDecay() { float[] var0Init = {0.0F, 0.0F}; @@ -104,7 +154,6 @@ public void testWithLearningRateDecay() { float[] grads0Init = {0.1F, 0.2F}; float[] grads1Init = {0.01F, 0.02F}; float epsilon = 1e-8F; - float epsilon1 = 1e-5F; int numSteps = 4; float learningRate = 3.0F; try (TestSession session = TestSession.createTestSession(tfMode); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java index 4a71fe59ba0..dc05b9b8f81 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java @@ -111,6 +111,70 @@ public void testBasic() { } } + @Test + public void testBasicWithLROperand() { + int numSteps = 3; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + + float learningRate = 1.0F; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (AdaGrad instance = + new AdaGrad( + session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(3.f)), 0.1f)) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op adaUpdate = instance.applyGradients(gradsAndVars, "AdGradTest"); + + @SuppressWarnings("unchecked") + Variable<TFloat32>[] accumulatorSlots = new Variable[2]; + accumulatorSlots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get(); + assertEquals(accumulatorSlots[0].asOutput().shape(), var0.asOutput().shape()); + + accumulatorSlots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get(); + assertEquals(accumulatorSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + for (int step = 0; step < numSteps; step++) { + session.run(adaUpdate, instance.getFeedMap()); + } + float[] expected0 = {-1.6026098728179932f, -0.6026098728179932f}; + session.evaluate(expected0, var0); + float[] expected1 = {2.715679168701172f, 3.715679168701172f}; + session.evaluate(expected1, var1); + } + } + } + @Test public void testWithLearningRateDecay() { int numSteps = 3; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java index a8be65c3650..c5bb153d804 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java @@ -65,8 +65,8 @@ public void testBasic() { FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); - float epsilon1 = 1e-3F; - float learningRate = 0.001F; + float epsilon1 = 1e-3f; + float learningRate = 0.001f; try (TestSession session = TestSession.createTestSession(tfMode); Adam instance = new Adam(session.getGraph(), learningRate)) { @@ -185,6 +185,140 @@ public void testBasic() { } } + @Test + public void testBasicWithLROperand() { + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + float epsilon1 = 1e-3f; + float learningRate = 0.001f; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Adam instance = + new Adam( + session.getGraph(), tf.constant(learningRate))) { + + float beta1 = 0.9F; + float beta2 = 0.999F; + + session.setEpsilon(epsilon1); + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validate the shapes of the slots */ + @SuppressWarnings("unchecked") + Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable<TFloat32>[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the accumulators */ + session.run(tf.init(), instance.getFeedMap()); + + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + FloatNdArray m0Np = NdArrays.ofFloats(shape0); + FloatNdArray v0Np = NdArrays.ofFloats(shape0); + FloatNdArray m1Np = NdArrays.ofFloats(shape1); + FloatNdArray v1Np = NdArrays.ofFloats(shape1); + + for (int step = 0; step < 3; step++) { + + // Test powers + final float[] powers = { + (float) Math.pow(beta1, step + 1), (float) Math.pow(beta2, step + 1) + }; + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); + } + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("beta2_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); + } + session.run(update, instance.getFeedMap()); + + float lrT = + learningRate + * (float) Math.sqrt(1 - (float) Math.pow(beta2, (step + 1))) + / (1 - (float) Math.pow(beta1, (step + 1))); + + m0Np = calculateM(m0Np, grads0Np, beta1); + v0Np = calculateV(v0Np, grads0Np, beta2); + var0Np = calculateParam(var0Np, lrT, m0Np, v0Np, 1e-7F); + + m1Np = calculateM(m1Np, grads1Np, beta1); + v1Np = calculateV(v1Np, grads1Np, beta2); + var1Np = calculateParam(var1Np, lrT, m1Np, v1Np, 1e-7F); + + // evaluate var 0 and var1 + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + + // first moment + session.evaluate(m0Np, firstMomentSlots[0]); + session.evaluate(m1Np, firstMomentSlots[1]); + + // second moment + session.evaluate(v0Np, secondMomentSlots[0]); + session.evaluate(v1Np, secondMomentSlots[1]); + } + } + } + } + @Test public void testWithLearningRateDecay() { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java index 57d3cbdb70c..e900018ccad 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; -import org.tensorflow.Graph; import org.tensorflow.Tensor; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; @@ -61,10 +60,10 @@ public void tearDown() {} /** Test of getOptimizerName method, of class Adamax. */ @Test public void testGetOptimizerName() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - Adamax instance = new Adamax(graph); - String expResult = "Adamax"; + try (TestSession session = TestSession.createTestSession(tfMode); + Adamax instance = new Adamax(session.getGraph())) { + + String expResult = DEFAULT_NAME; String result = instance.getOptimizerName(); assertEquals(expResult, result); } @@ -178,6 +177,118 @@ public void testBasic() { } } + /** Test of applyDense method, of class Adamax. */ + @Test + public void testBasicWithLROperand() { + + int numSteps = 3; + + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + + float[] zeros = {0.0F, 0.0F}; + FloatNdArray m0 = NdArrays.vectorOf(zeros); + FloatNdArray v0 = NdArrays.vectorOf(zeros); + FloatNdArray m1 = NdArrays.vectorOf(zeros); + FloatNdArray v1 = NdArrays.vectorOf(zeros); + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + float epsilon1 = 1e-3f; + float learningRate = 1f; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Adamax instance = + new Adamax( + session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(1e-3f)))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* build the GradsAnvVars */ + List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") + Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable<TFloat32>[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + session.setEpsilon(epsilon1); + for (int step = 0; step < numSteps; step++) { + // Test powers + final float beta1Power = (float) Math.pow(BETA_ONE_DEFAULT, step + 1); + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(beta1Power, f.getFloat(), epsilon1)); + } + session.run(update, instance.getFeedMap()); + + FloatNdArray[] resultNP = calculate(var0Np, grads0Np, step, m0, v0); + var0Np = resultNP[VAR]; + m0 = resultNP[M]; + v0 = resultNP[V]; + + resultNP = calculate(var1Np, grads1Np, step, m1, v1); + var1Np = resultNP[VAR]; + m1 = resultNP[M]; + v1 = resultNP[V]; + + // evaluate var0 and var1 + + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + } + } + } + } + @Test public void testWithLearningRateDecay() { @@ -242,7 +353,7 @@ public void testWithLearningRateDecay() { secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); - /** initialize the accumulators */ + /* initialize the accumulators */ session.run(tf.init()); /* initialize the local variables */ @@ -260,13 +371,7 @@ public void testWithLearningRateDecay() { .run() .get(0) .expect(TFloat32.DTYPE)) { - result - .data() - .scalars() - .forEach( - f -> { - assertEquals(betaPower, f.getFloat(), epsilon1); - }); + result.data().scalars().forEach(f -> assertEquals(betaPower, f.getFloat(), epsilon1)); } assertEquals(learningRate, instance.getLearningRate(), epsilon); session.evaluate( diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java index 92b610e5951..c047cd4bf40 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java @@ -123,6 +123,69 @@ public void testFtrlWithL1L2L2Shrinkage() { } } + @Test + public void testFtrlWithL1L2L2ShrinkageWithLROperand() { + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {4.0F, 3.0F}; + float[] grads0Init = {0.1F, 0.2F}; + float[] grads1Init = {0.01F, 0.02F}; + float learningRate = 1.0F; + + int numSteps = 10; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Ftrl instance = + new Ftrl( + session.getGraph(), + tf.math.mul(tf.constant(learningRate), tf.constant(3f)), + -0.5F, // learningRatePower + 0.1F, // initialAccumulatorValue + 0.001F, // l1RegularizationStrength + 2.0F, // l2RegularizationStrength + 0.1F // l2ShrinkageRegularizationStrength + )) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op ftrlUpdate = instance.applyGradients(gradsAndVars, "FtrlTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + for (int i = 0; i < numSteps; i++) { + session.run(ftrlUpdate, instance.getFeedMap()); + } + + float[] expectedVar0 = {-0.22578995F, -0.44345796F}; + session.evaluate(expectedVar0, var0); + float[] expectedVar1 = {-0.14378493F, -0.13229476F}; + session.evaluate(expectedVar1, var1); + } + } + } + @Test public void testFtrlWithL1() { float[] var0Init = {1.0F, 2.0F}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index 8e793e35d5f..ce687186994 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -56,7 +56,7 @@ public void testBasic() { float learningRate = 3.0F; try (TestSession session = TestSession.createTestSession(tfMode); - GradientDescent instance = new GradientDescent(session.getGraph(), learningRate); ) { + GradientDescent instance = new GradientDescent(session.getGraph(), learningRate)) { Ops tf = session.getTF(); Shape shape0 = Shape.of(var0Init.length); @@ -97,6 +97,59 @@ public void testBasic() { } } + @Test + public void testBasicWithLROperand() { + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + float learningRate = 1.5f; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (GradientDescent instance = + new GradientDescent( + session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(2f)))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "SGDTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + session.run(update, instance.getFeedMap()); // 1 step + + float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F}; + float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F}; + session.evaluate(expectedVar0, var0); + session.evaluate(expectedVar1, var1); + } + } + } + @Test public void testWithLearningRateDecay() { int numSteps = 2; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java index b54e3b52a26..014c72e55e6 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java @@ -71,9 +71,9 @@ public void testBasic() { float[] grads1Init = {0.01F, 0.01F}; float learningRate = 3.0F; - try (TestSession session = TestSession.createTestSession(tfMode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Momentum instance = new Momentum(session.getGraph(), learningRate)) { Ops tf = session.getTF(); - Graph graph = session.getGraph(); Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); @@ -91,7 +91,6 @@ public void testBasic() { gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - Momentum instance = new Momentum(graph, learningRate); Op update = instance.applyGradients(gradsAndVars, "SGDTest"); /* initialize the local variables */ @@ -114,6 +113,57 @@ public void testBasic() { } } + @Test + public void testBasicWithLROperand() { + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + float learningRate = 3.0F; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Momentum instance = new Momentum(session.getGraph(), tf.constant(learningRate))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "SGDTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + session.run(update, instance.getFeedMap()); // 1 step + + float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F}; + float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F}; + session.evaluate(expectedVar0, var0); + session.evaluate(expectedVar1, var1); + } + } + } + @Test public void testMomentum() { float[] var0Init = {1.0F, 2.0F}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java index c7c17689a33..0832543c104 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java @@ -203,6 +203,141 @@ public void testBasic() { } } + /** Test of applyDense method, of class Nadam. */ + @Test + public void testBasicWithLROperand() { + + int numSteps = 3; + + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; + FloatNdArray m0 = NdArrays.vectorOf(zeros); + FloatNdArray v0 = NdArrays.vectorOf(zeros); + FloatNdArray m1 = NdArrays.vectorOf(zeros); + FloatNdArray v1 = NdArrays.vectorOf(zeros); + FloatNdArray mcache = NdArrays.vectorOf(ones); + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + float epsilon1 = 1e-3F; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Nadam instance = + new Nadam(session.getGraph(), tf.math.mul(tf.constant(1f), tf.constant(1e-3f)))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") + Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable<TFloat32>[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + session.setEpsilon(epsilon1); + + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); + } + momentum = 1F; + + for (int step = 0; step < numSteps; step++) { + + session.run(update, instance.getFeedMap()); + + float mut = + Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); + momentum = momentum * mut; + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); + } + mcache = ND.mul(mcache, momentum); + FloatNdArray[] resultsNP = nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache); + var0Np = resultsNP[VAR]; + m0 = resultsNP[M]; + v0 = resultsNP[V]; + + resultsNP = nadamUpdateNdArray(var1Np, grads1Np, step, m1, v1, mcache); + var1Np = resultsNP[VAR]; + m1 = resultsNP[M]; + v1 = resultsNP[V]; + + // evaluate m0 and m1 + session.evaluate(m0, firstMomentSlots[0]); + session.evaluate(m1, firstMomentSlots[1]); + + // evaluate v0 and v1 + session.evaluate(v0, secondMomentSlots[0]); + session.evaluate(v1, secondMomentSlots[1]); + + // evaluate var0 and var1 + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + } + } + } + } + @Test public void testWithLearningRateDecay() { int numSteps = 3; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java index 2a012ff0f99..711358e9222 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java @@ -73,7 +73,7 @@ public void testDense() { for (Object[] testParamValue : testParamValues) { // learningRate, rho (decay), momentum, epsilon, centered - float learningRate = (float) (float) testParamValue[0]; + float learningRate = (float) testParamValue[0]; float decay = (float) testParamValue[1]; float momentum = (float) testParamValue[2]; float epsilon = (float) testParamValue[3]; @@ -215,6 +215,163 @@ public void testDense() { } } + @Test + public void testDenseWithLROperand() { + + int numSteps = 3; + + for (Object[] testParamValue : testParamValues) { + // learningRate, rho (decay), momentum, epsilon, centered + float learningRate = (float) testParamValue[0]; + float decay = (float) testParamValue[1]; + float momentum = (float) testParamValue[2]; + float epsilon = (float) testParamValue[3]; + boolean centered = (boolean) testParamValue[4]; + try (TestSession session = TestSession.createTestSession(tfMode)) { + + Ops tf = session.getTF(); + try (RMSProp instance = + new RMSProp( + session.getGraph(), + tf.math.add(tf.constant(learningRate), tf.constant(0f)), + decay, + momentum, + epsilon, + centered)) { + + session.setEpsilon(1e-2f); + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.2F}; + float[] grads1Init = {0.01F, 0.2F}; + + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + Variable<TFloat32> mg0 = + centered && instance.getSlot(var0.asOutput(), MG).isPresent() + ? instance.getSlot(var0.asOutput(), MG).get() + : null; + Variable<TFloat32> mg1 = + centered && instance.getSlot(var1.asOutput(), MG).isPresent() + ? instance.getSlot(var1.asOutput(), MG).get() + : null; + Variable<TFloat32> mom0 = + momentum > 0.F && instance.getSlot(var0.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var0.asOutput(), MOMENTUM).get() + : null; + Variable<TFloat32> mom1 = + momentum > 0.F && instance.getSlot(var1.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var1.asOutput(), MOMENTUM).get() + : null; + Variable<TFloat32> rms0 = + instance.getSlot(var0.asOutput(), RMS).isPresent() + ? instance.getSlot(var0.asOutput(), RMS).get() + : null; + Variable<TFloat32> rms1 = + instance.getSlot(var1.asOutput(), RMS).isPresent() + ? instance.getSlot(var1.asOutput(), RMS).get() + : null; + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; // temp to match RMSProp + FloatNdArray mg0Np = NdArrays.vectorOf(zeros); + FloatNdArray mg1Np = NdArrays.vectorOf(zeros); + FloatNdArray rms0Np = NdArrays.vectorOf(ones); + FloatNdArray rms1Np = NdArrays.vectorOf(ones); + FloatNdArray mom0Np = NdArrays.vectorOf(zeros); + FloatNdArray mom1Np = NdArrays.vectorOf(zeros); + + for (int i = 0; i < numSteps; i++) { + session.run(update, instance.getFeedMap()); + FloatNdArray[] result0 = + calc( + var0Np, + grads0Np, + mg0Np, + rms0Np, + mom0Np, + learningRate, + decay, + momentum, + epsilon, + centered); + var0Np = result0[VAR_T]; + mg0Np = result0[MG_T]; + rms0Np = result0[RMS_T]; + mom0Np = result0[MOM_T]; + + FloatNdArray[] result1 = + calc( + var1Np, + grads1Np, + mg1Np, + rms1Np, + mom1Np, + learningRate, + decay, + momentum, + epsilon, + centered); + + var1Np = result1[VAR_T]; + mg1Np = result1[MG_T]; + rms1Np = result1[RMS_T]; + mom1Np = result1[MOM_T]; + + if (centered) { + if (mg0 != null) session.evaluate(mg0Np, mg0); + if (mg1 != null) session.evaluate(mg1Np, mg1); + } + + if (mom0 != null) session.evaluate(mom0Np, mom0); + if (mom1 != null) session.evaluate(mom1Np, mom1); + + /* TODO the values returned from rms slot, do not match what I see in the python test */ + if (rms0 != null) session.evaluate(rms0Np, rms0); + else fail("rms0 is null"); + if (rms1 != null) session.evaluate(rms1Np, rms1); + else fail("rms1 is null"); + + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + } + } + } + } + } + @Test public void testWithLearningRateDecay() { int numSteps = 3;