diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index 1bdeffe4dcb..30abd0fcbe3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -20,6 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -62,24 +63,31 @@ */ public class AdaDelta extends Optimizer { + public static final String DEFAULT_NAME = "Adadelta"; public static final String ACCUMULATOR = "accum"; public static final String ACCUMULATOR_UPDATE = "accum_update"; public static final float LEARNING_RATE_DEFAULT = 0.001f; public static final float RHO_DEFAULT = 0.95f; public static final float EPSILON_DEFAULT = 1e-7f; - private final float learningRate; - private final float rho; private final float epsilon; + /** + * Creates an AdaDelta Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learningRate, {@link #RHO_DEFAULT} for the rho, and {@link + * #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph. + */ public AdaDelta(Graph graph) { this(graph, LEARNING_RATE_DEFAULT, RHO_DEFAULT, EPSILON_DEFAULT); } /** - * Creates an AdaDelta Optimizer + * Creates an AdaDelta Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #RHO_DEFAULT} for the rho, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -89,7 +97,19 @@ public AdaDelta(Graph graph, float learningRate) { } /** - * Creates an AdaDelta Optimizer + * Creates an AdaDelta Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #RHO_DEFAULT} for the rho, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaDelta(Graph graph, Operand<TFloat32> learningRateOperand) { + this(graph, learningRateOperand, RHO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates an AdaDelta Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -97,35 +117,75 @@ public AdaDelta(Graph graph, float learningRate) { * @param epsilon A constant epsilon used to better conditioning the grad update */ public AdaDelta(Graph graph, float learningRate, float rho, float epsilon) { - super(graph); - this.learningRate = learningRate; - this.rho = rho; - this.epsilon = epsilon; + this(graph, null, learningRate, rho, epsilon); } /** * Creates an AdaDelta Optimizer * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'Adadelta') + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate.
+ * @param rho The decay factor + * @param epsilon A constant epsilon used to better conditioning the grad update + */ + public AdaDelta(Graph graph, Operand<TFloat32> learningRateOperand, float rho, float epsilon) { + this(graph, null, learningRateOperand, rho, epsilon); + } + + /** + * Creates an AdaDelta Optimizer using {@link #RHO_DEFAULT} for the rho, and {@link + * #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. * @param learningRate the learning rate */ public AdaDelta(Graph graph, String name, float learningRate) { - this(graph, name, learningRate, 0.95f, 1e-8f); + this(graph, name, learningRate, RHO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates an AdaDelta Optimizer using {@link #RHO_DEFAULT} for the rho, and {@link + * #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaDelta(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this(graph, name, learningRateOperand, RHO_DEFAULT, EPSILON_DEFAULT); + } /** * Creates an AdaDelta Optimizer * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'Adadelta') + * @param name the name for this Optimizer. * @param learningRate the learning rate * @param rho The decay factor * @param epsilon A constant epsilon used to better conditioning the grad update */ public AdaDelta(Graph graph, String name, float learningRate, float rho, float epsilon) { - super(graph, name); - this.learningRate = learningRate; + super(graph, name, learningRate); + this.rho = rho; + this.epsilon = epsilon; + } + + /** + * Creates an AdaDelta Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate.
+ * @param rho The decay factor + * @param epsilon A constant epsilon used to better conditioning the grad update + */ + public AdaDelta( + Graph graph, String name, Operand<TFloat32> learningRateOperand, float rho, float epsilon) { + super(graph, name, learningRateOperand); this.rho = rho; this.epsilon = epsilon; } @@ -162,7 +222,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable variable, accumSlot, accumUpdateSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), tf.dtypes.cast(tf.constant(rho), gradient.dataType()), tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), gradient); @@ -184,6 +244,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Adadelta"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 6855be30759..c0cc47409d7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -20,6 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -40,16 +41,18 @@ */ public class AdaGrad extends Optimizer { + public static final String DEFAULT_NAME = "Adagrad"; + public static final String ACCUMULATOR = "accumulator"; public static final float LEARNING_RATE_DEFAULT = 0.001f; public static final float INITIAL_ACCUMULATOR_DEFAULT = 0.01f; - private final float learningRate; - private final float initialAccumulatorValue; /** - * Creates an AdaGrad Optimizer + * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, and {@link #INITIAL_ACCUMULATOR_DEFAULT} for + * the initialAccumulatorValue. * * @param graph the TensorFlow Graph */ @@ -58,7 +61,8 @@ public AdaGrad(Graph graph) { } /** - * Creates an AdaGrad Optimizer + * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -68,7 +72,19 @@ public AdaGrad(Graph graph, float learningRate) { } /** - * Creates an AdaGrad Optimizer + * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate.
+ */ + public AdaGrad(Graph graph, Operand<TFloat32> learningRateOperand) { + this(graph, learningRateOperand, INITIAL_ACCUMULATOR_DEFAULT); + } + + /** + * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -76,44 +92,88 @@ public AdaGrad(Graph graph, float learningRate) { * @param initialAccumulatorValue Starting value for the accumulators, must be non-negative. * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative */ public AdaGrad(Graph graph, float learningRate, float initialAccumulatorValue) { - super(graph); - if (initialAccumulatorValue < 0F) { - throw new IllegalArgumentException( - String.format( - "initialAccumulatorValue must be non-negative: %f", initialAccumulatorValue)); - } - this.learningRate = learningRate; - this.initialAccumulatorValue = initialAccumulatorValue; + this(graph, null, learningRate, initialAccumulatorValue); } /** - * Creates an AdaGrad Optimizer + * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'Adagrad') + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param initialAccumulatorValue Starting value for the accumulators, must be non-negative. + * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative + */ + public AdaGrad( + Graph graph, Operand<TFloat32> learningRateOperand, float initialAccumulatorValue) { + this(graph, null, learningRateOperand, initialAccumulatorValue); + } + + /** + * Creates an AdaGrad Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. * @param learningRate the learning rate */ public AdaGrad(Graph graph, String name, float learningRate) { - this(graph, name, learningRate, 0.01f); + this(graph, name, learningRate, INITIAL_ACCUMULATOR_DEFAULT); + } + + /** + * Creates an AdaGrad Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaGrad(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this(graph, name, learningRateOperand, INITIAL_ACCUMULATOR_DEFAULT); } /** * Creates an AdaGrad Optimizer * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'Adagrad') + * @param name the name for this Optimizer * @param learningRate the learning rate * @param initialAccumulatorValue Starting value for the accumulators, must be non-negative. * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative */ public AdaGrad(Graph graph, String name, float learningRate, float initialAccumulatorValue) { - super(graph, name); + super(graph, name, learningRate); + if (initialAccumulatorValue < 0F) { + throw new IllegalArgumentException( + String.format( + "initialAccumulatorValue must be non-negative: %f", initialAccumulatorValue)); + } + this.initialAccumulatorValue = initialAccumulatorValue; + } + + /** + * Creates an AdaGrad Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate.
+ * @param initialAccumulatorValue Starting value for the accumulators, must be non-negative. + * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative + */ + public AdaGrad( + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float initialAccumulatorValue) { + super(graph, name, learningRateOperand); if (initialAccumulatorValue < 0F) { throw new IllegalArgumentException( String.format( "initialAccumulatorValue must be non-negative: %f", initialAccumulatorValue)); } - this.learningRate = learningRate; this.initialAccumulatorValue = initialAccumulatorValue; } @@ -142,7 +202,7 @@ private <T extends TType> void createAdaGradSlot(Output<T> v) { protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) { Variable<T> slot = getSlot(variable, ACCUMULATOR).get(); return tf.train.applyAdagrad( - variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), gradient); + variable, slot, tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), gradient); } /** {@inheritDoc} */ @@ -159,6 +219,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Adagrad"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index e74a1d85359..0e070f2f4fa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -22,6 +22,7 @@ import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -46,6 +47,7 @@ */ public class AdaGradDA extends Optimizer { + public static final String DEFAULT_NAME = "adagrad-da"; public static final String ACCUMULATOR = "gradient_accumulator"; public static final String SQUARED_ACCUMULATOR = "gradient_squared_accumulator"; public static final float LEARNING_RATE_DEFAULT = 0.001F; @@ -53,14 +55,16 @@ public class AdaGradDA extends Optimizer { public static final float L1_STRENGTH_DEFAULT = 0.0F; public static final float L2_STRENGTH_DEFAULT = 0.0F; - private final float learningRate; private final float initialAccumulatorValue; private final float l1Strength; private final float l2Strength; private Variable<TInt64> globalStep; /** - * Creates an AdaGradDA Optimizer + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for the l1Strength, and {@link + * #L2_STRENGTH_DEFAULT} for the l2Strength. * * @param graph the TensorFlow Graph */ @@ -74,7 +78,9 @@ public AdaGradDA(Graph graph) { } /** - * Creates an AdaGradDA Optimizer + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for + * the l1Strength, and {@link #L2_STRENGTH_DEFAULT} for the l2Strength. 
* * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -85,7 +91,25 @@ public AdaGradDA(Graph graph, float learningRate) { } /** - * Creates an AdaGradDA Optimizer + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for + * the l1Strength, and {@link #L2_STRENGTH_DEFAULT} for the l2Strength. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaGradDA(Graph graph, Operand<TFloat32> learningRateOperand) { + this( + graph, + learningRateOperand, + INITIAL_ACCUMULATOR_DEFAULT, + L1_STRENGTH_DEFAULT, + L2_STRENGTH_DEFAULT); + } + + /** + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -101,31 +125,37 @@ public AdaGradDA( float initialAccumulatorValue, float l1Strength, float l2Strength) { - super(graph); - if (initialAccumulatorValue <= 0F) { - throw new IllegalArgumentException( - String.format( - "initialAccumulatorValue must be greater than zero: %f", initialAccumulatorValue)); - } - if (l1Strength < 0F) { - throw new IllegalArgumentException( - String.format("l1Strength must not be negative: %f", l1Strength)); - } - if (l2Strength < 0F) { - throw new IllegalArgumentException( - String.format("l2Strength must not be negative: %f", l2Strength)); - } - this.learningRate = learningRate; - this.initialAccumulatorValue = initialAccumulatorValue; - this.l1Strength = l1Strength; - this.l2Strength = l2Strength; + this(graph, null, learningRate, initialAccumulatorValue, l1Strength, l2Strength); } /** - * Creates an AdaGradDA Optimizer + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param initialAccumulatorValue Starting value for the accumulators, must be greater than zero. + * @param l1Strength l1 regularization strength, must be greater than or equal to zero. + * @param l2Strength l2 regularization strength, must be greater than or equal to zero. + * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is not greater than zero, + * or l1Strength or l2Strength is less than zero + */ + public AdaGradDA( + Graph graph, + Operand<TFloat32> learningRateOperand, + float initialAccumulatorValue, + float l1Strength, + float l2Strength) { + this(graph, null, learningRateOperand, initialAccumulatorValue, l1Strength, l2Strength); + } + + /** + * Creates an AdaGradDA Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for the l1Strength, and {@link + * #L2_STRENGTH_DEFAULT} for the l2Strength. * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'adagrad-da') + * @param name the name for this Optimizer. 
* @param learningRate the learning rate */ public AdaGradDA(Graph graph, String name, float learningRate) { @@ -138,11 +168,31 @@ public AdaGradDA(Graph graph, String name, float learningRate) { L2_STRENGTH_DEFAULT); } + /** + * Creates an AdaGradDA Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for the l1Strength, and {@link + * #L2_STRENGTH_DEFAULT} for the l2Strength. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaGradDA(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this( + graph, + name, + learningRateOperand, + INITIAL_ACCUMULATOR_DEFAULT, + L1_STRENGTH_DEFAULT, + L2_STRENGTH_DEFAULT); + } + /** * Creates an AdaGradDA Optimizer * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'adagrad-da') + * @param name the name for this Optimizer. * @param learningRate the learning rate * @param initialAccumulatorValue Starting value for the accumulators, must be positive * @param l1Strength l1 regularization strength, must be greater than or equal to zero. @@ -157,7 +207,46 @@ public AdaGradDA( float initialAccumulatorValue, float l1Strength, float l2Strength) { - super(graph, name); + super(graph, name, learningRate); + if (initialAccumulatorValue <= 0F) { + throw new IllegalArgumentException( + String.format( + "initialAccumulatorValue must be greater than zero: %f", initialAccumulatorValue)); + } + if (l1Strength < 0F) { + throw new IllegalArgumentException( + String.format("l1Strength must not be negative: %f", l1Strength)); + } + if (l2Strength < 0F) { + throw new IllegalArgumentException( + String.format("l2Strength must not be negative: %f", l2Strength)); + } + this.initialAccumulatorValue = initialAccumulatorValue; + this.l1Strength = l1Strength; + this.l2Strength = l2Strength; + } + + /** + * Creates an AdaGradDA Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param initialAccumulatorValue Starting value for the accumulators, must be positive + * @param l1Strength l1 regularization strength, must be greater than or equal to zero. + * @param l2Strength l2 regularization strength, must be greater than or equal to zero. 
+ * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is not greater than zero, + * or * l1Strength or l2Strength is less than zero + */ + public AdaGradDA( + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float initialAccumulatorValue, + float l1Strength, + float l2Strength) { + super(graph, name, learningRateOperand); if (initialAccumulatorValue <= 0F) { throw new IllegalArgumentException( String.format( @@ -171,7 +260,6 @@ public AdaGradDA( throw new IllegalArgumentException( String.format("l2Strength must not be negative: %f", l2Strength)); } - this.learningRate = learningRate; this.initialAccumulatorValue = initialAccumulatorValue; this.l1Strength = l1Strength; this.l2Strength = l2Strength; @@ -218,7 +306,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable gradSlot, gradSquaredSlot, gradient, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), tf.dtypes.cast(tf.constant(l1Strength), gradient.dataType()), tf.dtypes.cast(tf.constant(l2Strength), gradient.dataType()), globalStep); @@ -259,6 +347,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "adagrad-da"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java index 8f620678781..9e4f41f1039 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java @@ -48,6 +48,8 @@ @Operator public class Adam extends Optimizer { + public static final String DEFAULT_NAME = "Adam"; + public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; @@ -56,15 +58,12 @@ public class Adam extends Optimizer { public static final float BETA_ONE_DEFAULT = 0.9f; public static final float BETA_TWO_DEFAULT = 0.999f; - private final float learningRate; - private final float betaOne; private final float betaTwo; private final float epsilon; - private Constant<TFloat32> learningRateConst; private Constant<TFloat32> epsilonConst; private Constant<TFloat32> betaOneConst; private Constant<TFloat32> betaTwoConst; @@ -72,7 +71,9 @@ public class Adam extends Optimizer { private Variable<TFloat32> betaTwoPower; /** - * Creates an Adam optimizer + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value, + * {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph */ @@ -81,7 +82,9 @@ public Adam(Graph graph) { } /** - * Creates an Adam optimizer + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and + * {@link #EPSILON_DEFAULT} for the epsilon. 
* * @param graph the TensorFlow graph * @param learningRate the learning rate @@ -89,53 +92,121 @@ public Adam(Graph graph) { public Adam(Graph graph, float learningRate) { this(graph, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } + /** + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and + * {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Adam(Graph graph, Operand<TFloat32> learningRateOperand) { + this(graph, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } /** - * Creates an Adam optimizer + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow graph * @param learningRate the learning rate - * @param betaOne The exponential decay rate for the 1st moment estimates. Defaults to 0.9. - * @param betaTwo The exponential decay rate for the 2nd moment estimates. Defaults to 0.999. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the 2nd moment estimates. * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the * Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm - * 1 of the paper. Defaults to 1e-8. + * 1 of the paper.. */ public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(graph); - this.learningRate = learningRate; - this.betaOne = betaOne; - this.betaTwo = betaTwo; - this.epsilon = epsilon; + this(graph, null, learningRate, betaOne, betaTwo, epsilon); } /** - * Creates an Adam optimizer + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow graph - * @param name the Optimizer name, defaults to "Adam" + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the 2nd moment estimates. + * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the + * Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm + * 1 of the paper. + */ + public Adam( + Graph graph, + Operand<TFloat32> learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + this(graph, null, learningRateOperand, betaOne, betaTwo, epsilon); + } + + /** + * Creates an Adam optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link + * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param name the Optimizer name. * @param learningRate the learning rate */ public Adam(Graph graph, String name, float learningRate) { this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } + /** + * Creates an Adam optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link + * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param name the Optimizer name. 
+ * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Adam(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this(graph, name, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + /** * Creates an Adam optimizer * * @param graph the TensorFlow graph - * @param name the Optimizer name, defaults to "Adam" + * @param name the Optimizer name. * @param learningRate the learning rate - * @param betaOne The exponential decay rate for the 1st moment estimates. Defaults to 0.9. - * @param betaTwo The exponential decay rate for the 2nd moment estimates. Defaults to 0.999. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the 2nd moment estimates. * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the * Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm - * 1 of the paper. Defaults to 1e-8. + * 1 of the paper. */ public Adam( Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(graph, name); - this.learningRate = learningRate; + super(graph, name, learningRate); + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + + /** + * Creates an Adam optimizer + * + * @param graph the TensorFlow graph + * @param name the Optimizer name. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the 2nd moment estimates. + * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the + * Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm + * 1 of the paper. + */ + public Adam( + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + super(graph, name, learningRateOperand); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; @@ -202,7 +273,6 @@ protected void createSlots(List<Output<? 
extends TType>> variables) { protected Optional<Op> prepare(String scopeName) { betaOneConst = tf.constant(betaOne); betaTwoConst = tf.constant(betaTwo); - learningRateConst = tf.constant(learningRate); epsilonConst = tf.constant(epsilon); return Optional.empty(); } @@ -233,7 +303,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable secondMomentSlot, tf.dtypes.cast(betaOnePower, gradient.dataType()), tf.dtypes.cast(betaTwoPower, gradient.dataType()), - tf.dtypes.cast(learningRateConst, gradient.dataType()), + tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), tf.dtypes.cast(betaOneConst, gradient.dataType()), tf.dtypes.cast(betaTwoConst, gradient.dataType()), tf.dtypes.cast(epsilonConst, gradient.dataType()), @@ -274,6 +344,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Adam"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java index 335d83cedfa..e33775db961 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java @@ -25,6 +25,7 @@ */ public class Adamax extends Optimizer { + public static final String DEFAULT_NAME = "Adamax"; public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; @@ -33,19 +34,20 @@ public class Adamax extends Optimizer { public static final float BETA_ONE_DEFAULT = 0.9f; public static final float BETA_TWO_DEFAULT = 0.999f; - private float learningRate; private final float betaOne; private final float betaTwo; private final float epsilon; - private Constant<TFloat32> learningRateConst; private Constant<TFloat32> epsilonConst; private Constant<TFloat32> betaOneConst; private Constant<TFloat32> betaTwoConst; private Variable<TFloat32> betaOnePower; /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the + * Optimizer name, {@link #LEARNING_RATE_DEFAULT} for the learning rate, {@link #BETA_ONE_DEFAULT} + * for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link + * #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph */ @@ -54,17 +56,21 @@ public Adamax(Graph graph) { } /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, {@link #LEARNING_RATE_DEFAULT} for + * the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} + * for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph - * @param name name for the operations Created when applying gradients. Defaults to "Adamax". + * @param name name for the operations Created when applying gradients. */ public Adamax(Graph graph, String name) { this(graph, name, LEARNING_RATE_DEFAULT, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the + * Optimizer name, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for + * the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. 
* * @param graph the TensorFlow graph * @param learningRate The learning rate. @@ -74,18 +80,48 @@ public Adamax(Graph graph, float learningRate) { } /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the + * Optimizer name, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for + * the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Adamax(Graph graph, Operand<TFloat32> learningRateOperand) { + this(graph, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates an Optimizer that implements the Adamax algorithm, using {@link #BETA_ONE_DEFAULT} for + * the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link + * #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph - * @param name name for the operations Created when applying gradients. Defaults to "Adamax". + * @param name name for the operations Created when applying gradients. * @param learningRate The learning rate. */ public Adamax(Graph graph, String name, float learningRate) { this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } + /** + * Creates an Optimizer that implements the Adamax algorithm, using {@link #BETA_ONE_DEFAULT} for + * the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link + * #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param name name for the operations Created when applying gradients. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Adamax(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this(graph, name, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the + * Optimizer name. * * @param graph the TensorFlow graph * @param learningRate The learning rate. @@ -94,8 +130,42 @@ public Adamax(Graph graph, String name, float learningRate) { * @param epsilon A small constant for numerical stability. */ public Adamax(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(graph); - this.learningRate = learningRate; + this(graph, null, learningRate, betaOne, betaTwo, epsilon); + } + /** + * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the + * Optimizer name. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability.
+ */ + public Adamax( + Graph graph, + Operand<TFloat32> learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + this(graph, null, learningRateOperand, betaOne, betaTwo, epsilon); + } + + /** + * Creates an Optimizer that implements the Adamax algorithm. + * + * @param graph the TensorFlow graph + * @param name name for the operations Created when applying gradients. + * @param learningRate The learning rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. + */ + public Adamax( + Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { + super(graph, name, learningRate); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; @@ -105,16 +175,21 @@ public Adamax(Graph graph, float learningRate, float betaOne, float betaTwo, flo * Creates an Optimizer that implements the Adamax algorithm. * * @param graph the TensorFlow graph - * @param name name for the operations Created when applying gradients. Defaults to "Adamax". - * @param learningRate The learning rate. + * @param name name for the operations Created when applying gradients. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. * @param betaOne The exponential decay rate for the 1st moment estimates. * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. * @param epsilon A small constant for numerical stability. */ public Adamax( - Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(graph, name); - this.learningRate = learningRate; + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + super(graph, name, learningRateOperand); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; @@ -125,7 +200,6 @@ public Adamax( protected Optional<Op> prepare(String scopeName) { betaOneConst = tf.constant(betaOne); betaTwoConst = tf.constant(betaTwo); - learningRateConst = tf.constant(learningRate); epsilonConst = tf.constant(epsilon); return Optional.empty(); @@ -168,7 +242,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable firstMomentSlot, secondMomentSlot, tf.dtypes.cast(betaOnePower, gradient.dataType()), - tf.dtypes.cast(learningRateConst, gradient.dataType()), + tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), tf.dtypes.cast(betaOneConst, gradient.dataType()), tf.dtypes.cast(betaTwoConst, gradient.dataType()), tf.dtypes.cast(epsilonConst, gradient.dataType()), @@ -185,6 +259,6 @@ protected Op finish(List<Op> updateOperations, String name) { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Adamax"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java index 13f68e4bbef..35eeb7dc225 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java @@ -6,6 +6,7 @@ import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyFtrl; +import org.tensorflow.types.TFloat32; import 
org.tensorflow.types.family.TType; import java.util.List; @@ -20,6 +21,8 @@ */ public class Ftrl extends Optimizer { + public static final String DEFAULT_NAME = "Ftrl"; + public static final String ACCUMULATOR = "gradient_accumulator"; public static final String LINEAR_ACCUMULATOR = "linear_accumulator"; @@ -30,7 +33,6 @@ public class Ftrl extends Optimizer { public static final float L2STRENGTH_DEFAULT = 0.0f; public static final float L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT = 0.0f; - private float learningRate; private final float learningRatePower; private final float initialAccumulatorValue; private final float l1RegularizationStrength; @@ -38,7 +40,12 @@ public class Ftrl extends Optimizer { private final float l2ShrinkageRegularizationStrength; /** - * Creates a Ftrl Optimizer + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #LEARNING_RATE_POWER_DEFAULT} for the + * learningRatePower. {@link #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, + * {@link #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength + * and {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. * * @param graph the TensorFlow Graph */ @@ -54,7 +61,12 @@ public Ftrl(Graph graph) { } /** - * Creates a Ftrl Optimizer + * Creates an Ftrl Optimizer using {@link #LEARNING_RATE_DEFAULT} for the learning rate, {@link + * #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. {@link + * #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. * * @param graph the TensorFlow Graph * @param name the name of this Optimizer @@ -72,7 +84,12 @@ public Ftrl(Graph graph, String name) { } /** - * Creates a Ftrl Optimizer + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. {@link + * #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -89,7 +106,34 @@ public Ftrl(Graph graph, float learningRate) { } /** - * Creates a Ftrl Optimizer + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. {@link + * #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. 
+ */ + public Ftrl(Graph graph, Operand<TFloat32> learningRateOperand) { + this( + graph, + learningRateOperand, + LEARNING_RATE_POWER_DEFAULT, + INITIAL_ACCUMULATOR_VALUE_DEFAULT, + L1STRENGTH_DEFAULT, + L2STRENGTH_DEFAULT, + L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT); + } + + /** + * Creates an Ftrl Optimizer using {@link #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. + * {@link #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. * * @param graph the TensorFlow Graph * @param name the name of this Optimizer @@ -107,10 +151,110 @@ public Ftrl(Graph graph, String name, float learningRate) { L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT); } + /** + * Creates an Ftrl Optimizer using {@link #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. + * {@link #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Ftrl(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this( + graph, + name, + learningRateOperand, + LEARNING_RATE_POWER_DEFAULT, + INITIAL_ACCUMULATOR_VALUE_DEFAULT, + L1STRENGTH_DEFAULT, + L2STRENGTH_DEFAULT, + L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT); + } + + /** + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow Graph + * @param learningRate the learning rate + * @param learningRatePower Controls how the learning rate decreases during training. Use zero for + * a fixed learning rate. + * @param initialAccumulatorValue The starting value for accumulators. Only zero or positive + * values are allowed. + * @param l1Strength the L1 Regularization strength, must be greater than or equal to zero. + * @param l2Strength the L2 Regularization strength, must be greater than or equal to zero. + * @param l2ShrinkageRegularizationStrength This differs from L2 above in that the L2 above is a + * stabilization penalty, whereas this L2 shrinkage is a magnitude penalty. must be greater + * than or equal to zero. + * @throws java.lang.IllegalArgumentException if the initialAccumulatorValue, + * l1RegularizationStrength, l2RegularizationStrength, or l2ShrinkageRegularizationStrength + * are less than 0.0, or learningRatePower is greater than 0.0. + */ + public Ftrl( + Graph graph, + float learningRate, + float learningRatePower, + float initialAccumulatorValue, + float l1Strength, + float l2Strength, + float l2ShrinkageRegularizationStrength) { + this( + graph, + null, + learningRate, + learningRatePower, + initialAccumulatorValue, + l1Strength, + l2Strength, + l2ShrinkageRegularizationStrength); + } + + /** + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param learningRatePower Controls how the learning rate decreases during training. 
Use zero for + * a fixed learning rate. + * @param initialAccumulatorValue The starting value for accumulators. Only zero or positive + * values are allowed. + * @param l1Strength the L1 Regularization strength, must be greater than or equal to zero. + * @param l2Strength the L2 Regularization strength, must be greater than or equal to zero. + * @param l2ShrinkageRegularizationStrength This differs from L2 above in that the L2 above is a + * stabilization penalty, whereas this L2 shrinkage is a magnitude penalty. must be greater + * than or equal to zero. + * @throws java.lang.IllegalArgumentException if the initialAccumulatorValue, + * l1RegularizationStrength, l2RegularizationStrength, or l2ShrinkageRegularizationStrength + * are less than 0.0, or learningRatePower is greater than 0.0. + */ + public Ftrl( + Graph graph, + Operand<TFloat32> learningRateOperand, + float learningRatePower, + float initialAccumulatorValue, + float l1Strength, + float l2Strength, + float l2ShrinkageRegularizationStrength) { + this( + graph, + null, + learningRateOperand, + learningRatePower, + initialAccumulatorValue, + l1Strength, + l2Strength, + l2ShrinkageRegularizationStrength); + } + /** * Creates a Ftrl Optimizer * * @param graph the TensorFlow Graph + * @param name the name of this Optimizer * @param learningRate the learning rate * @param learningRatePower Controls how the learning rate decreases during training. Use zero for * a fixed learning rate. @@ -127,14 +271,14 @@ public Ftrl(Graph graph, String name, float learningRate) { */ public Ftrl( Graph graph, + String name, float learningRate, float learningRatePower, float initialAccumulatorValue, float l1Strength, float l2Strength, float l2ShrinkageRegularizationStrength) { - super(graph); - this.learningRate = learningRate; + super(graph, name, learningRate); this.learningRatePower = learningRatePower; this.initialAccumulatorValue = initialAccumulatorValue; this.l1RegularizationStrength = l1Strength; @@ -148,7 +292,8 @@ public Ftrl( * * @param graph the TensorFlow Graph * @param name the name of this Optimizer - * @param learningRate the learning rate + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. * @param learningRatePower Controls how the learning rate decreases during training. Use zero for * a fixed learning rate. * @param initialAccumulatorValue The starting value for accumulators. 
Only zero or positive @@ -165,14 +310,13 @@ public Ftrl( public Ftrl( Graph graph, String name, - float learningRate, + Operand<TFloat32> learningRateOperand, float learningRatePower, float initialAccumulatorValue, float l1Strength, float l2Strength, float l2ShrinkageRegularizationStrength) { - super(graph, name); - this.learningRate = learningRate; + super(graph, name, learningRateOperand); this.learningRatePower = learningRatePower; this.initialAccumulatorValue = initialAccumulatorValue; this.l1RegularizationStrength = l1Strength; @@ -181,7 +325,7 @@ public Ftrl( validateParams(); } - /** Validates all the settings of the Frtl Optmizer */ + /** Validates all the settings of the Ftrl Optimizer */ private void validateParams() { if (this.initialAccumulatorValue < 0.0F) { throw new IllegalArgumentException( @@ -242,24 +386,22 @@ private <T extends TType> void createFtrlSlot(Output<T> v) { protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) { Variable<T> accumSlot = getSlot(variable, ACCUMULATOR).get(); Variable<T> linearSlot = getSlot(variable, LINEAR_ACCUMULATOR).get(); - ApplyFtrl.Options options = ApplyFtrl.useLocking(true); return this.tf.train.applyFtrl( variable, - accumSlot, // accum - linearSlot, // linear - gradient, // gradient - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), // lr - tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.dataType()), // l1 - tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.dataType()), // l2 - tf.dtypes.cast( - tf.constant(l2ShrinkageRegularizationStrength), gradient.dataType()), // l2Shrinkage - tf.dtypes.cast(tf.constant(learningRatePower), gradient.dataType()), // lrPower - options); + accumSlot, + linearSlot, + gradient, + tf.dtypes.cast(this.getLearningRateOperand(), gradient.dataType()), + tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.dataType()), + tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.dataType()), + tf.dtypes.cast(tf.constant(l2ShrinkageRegularizationStrength), gradient.dataType()), + tf.dtypes.cast(tf.constant(learningRatePower), gradient.dataType()), + ApplyFtrl.useLocking(true)); } /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Ftrl"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java index e307855e636..efec399f40d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java @@ -16,8 +16,10 @@ package org.tensorflow.framework.optimizers; import org.tensorflow.Graph; +import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.op.Op; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; /** @@ -26,12 +28,12 @@ */ public class GradientDescent extends Optimizer { + public static final String DEFAULT_NAME = "GradientDescent"; public static final float LEARNING_RATE_DEFAULT = 0.01f; - private final float learningRate; - /** - * Creates a GradientDescent Optimizer + * Creates a GradientDescent Optimizer using {@link #DEFAULT_NAME} for the Optimizer name and + * {@link #LEARNING_RATE_DEFAULT} for the learning rate. 
* * @param graph the TensorFlow graph */ @@ -40,33 +42,54 @@ public GradientDescent(Graph graph) { } /** - * Creates a GradientDescent Optimizer + * Creates a GradientDescent Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow graph - * @param learningRate the learning rate, defaults to 0.01 + * @param learningRate the learning rate. */ public GradientDescent(Graph graph, float learningRate) { - super(graph); - this.learningRate = learningRate; + super(graph, null, learningRate); + } + + /** + * Creates a GradientDescent Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public GradientDescent(Graph graph, Operand<TFloat32> learningRateOperand) { + super(graph, null, learningRateOperand); } /** * Creates a GradientDescent Optimizer * * @param graph the TensorFlow graph - * @param name the name for this Optimizer, default is "GradientDescent" - * @param learningRate the learning rate, defaults to 0.01 + * @param name the name for this Optimizer. + * @param learningRate the learning rate. */ public GradientDescent(Graph graph, String name, float learningRate) { - super(graph, name); - this.learningRate = learningRate; + super(graph, name, learningRate); + } + + /** + * Creates a GradientDescent Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public GradientDescent(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + super(graph, name, learningRateOperand); } /** {@inheritDoc} */ @Override protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) { return tf.train.applyGradientDescent( - variable, tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), gradient); + variable, tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), gradient); } /** {@inheritDoc} */ @@ -78,6 +101,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "GradientDescent"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java index 111727d26fa..436b587c353 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java @@ -21,6 +21,7 @@ import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyMomentum; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -33,20 +34,21 @@ */ public class Momentum extends Optimizer { + public static final String DEFAULT_NAME = "Momentum"; public static final float LEARNING_RATE_DEFAULT = 0.01F; public static final float MOMENTUM_DEFAULT = 0.0F; public static final boolean NESTEROV_DEFAULT = false; public static final String MOMENTUM = "momentum"; - private final float learningRate; - private final float momentum; private final boolean useNesterov; /** - * Creates a Momentum Optimizer + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link 
#MOMENTUM_DEFAULT} for the momentum, and + * {@link #NESTEROV_DEFAULT} for the Nesterov flag. * * @param graph the TensorFlow graph */ @@ -55,7 +57,8 @@ public Momentum(Graph graph) { } /** - * Creates a Momentum Optimizer + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #MOMENTUM_DEFAULT} for the momentum, and {@link #NESTEROV_DEFAULT} for the Nesterov flag. * * @param graph the TensorFlow graph * @param learningRate the learning rate @@ -65,31 +68,74 @@ public Momentum(Graph graph, float learningRate) { } /** - * Creates a Momentum Optimizer + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #MOMENTUM_DEFAULT} for the momentum, and {@link #NESTEROV_DEFAULT} for the Nesterov flag. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Momentum(Graph graph, Operand<TFloat32> learningRateOperand) { + this(graph, learningRateOperand, MOMENTUM_DEFAULT, NESTEROV_DEFAULT); + } + + /** + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name and {@link + * #NESTEROV_DEFAULT} for the Nesterov flag. * * @param graph the TensorFlow graph * @param learningRate the learning rate * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and - * dampens oscillations, Must be greater than or equal to zero. Default is 0. + * dampens oscillations, Must be greater than or equal to zero. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. */ public Momentum(Graph graph, float learningRate, float momentum) { this(graph, learningRate, momentum, NESTEROV_DEFAULT); } /** - * Creates a Momentum Optimizer + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name and {@link + * #NESTEROV_DEFAULT} for the Nesterov flag. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. + */ + public Momentum(Graph graph, Operand<TFloat32> learningRateOperand, float momentum) { + this(graph, learningRateOperand, momentum, NESTEROV_DEFAULT); + } + + /** + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow graph * @param learningRate the learning rate * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and - * dampens oscillations, Must be greater than or equal to zero. Default is 0. - * @param useNesterov Whether to apply Nesterov momentum. Defaults to false. + * dampens oscillations, Must be greater than or equal to zero. + * @param useNesterov Whether to apply Nesterov momentum. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. */ public Momentum(Graph graph, float learningRate, float momentum, boolean useNesterov) { - super(graph); - this.learningRate = learningRate; - this.momentum = momentum; - this.useNesterov = useNesterov; + this(graph, null, learningRate, momentum, useNesterov); + } + + /** + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. 
+ * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. + * @param useNesterov Whether to apply Nesterov momentum. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. + */ + public Momentum( + Graph graph, Operand<TFloat32> learningRateOperand, float momentum, boolean useNesterov) { + this(graph, null, learningRateOperand, momentum, useNesterov); } /** @@ -99,13 +145,40 @@ public Momentum(Graph graph, float learningRate, float momentum, boolean useNest * @param name the name for this Optimizer * @param learningRate the learning rate * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and - * dampens oscillations, Must be greater than or equal to zero. Default is 0. - * @param useNesterov Whether to apply Nesterov momentum. Defaults to false. + * dampens oscillations, Must be greater than or equal to zero. + * @param useNesterov Whether to apply Nesterov momentum. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. */ public Momentum( Graph graph, String name, float learningRate, float momentum, boolean useNesterov) { - super(graph, name); - this.learningRate = learningRate; + super(graph, name, learningRate); + if (momentum < 0) + throw new IllegalArgumentException("momentum must be greater than or equal to zero."); + this.momentum = momentum; + this.useNesterov = useNesterov; + } + + /** + * Creates a Momentum Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. + * @param useNesterov Whether to apply Nesterov momentum. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. 
+ */ + public Momentum( + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float momentum, + boolean useNesterov) { + super(graph, name, learningRateOperand); + if (momentum < 0) + throw new IllegalArgumentException("momentum must be greater than or equal to zero."); this.momentum = momentum; this.useNesterov = useNesterov; } @@ -136,7 +209,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable return tf.train.applyMomentum( variable, slot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), gradient, tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), ApplyMomentum.useNesterov(useNesterov)); @@ -158,6 +231,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Momentum"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java index 48e5135c952..202f5013e7d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -25,6 +25,7 @@ */ public class Nadam extends Optimizer { + public static final String DEFAULT_NAME = "Nadam"; private static final float DECAY_BASE = 0.96f; private static final float DECAY = 0.004f; public static final float LEARNING_RATE_DEFAULT = 0.001f; @@ -35,9 +36,6 @@ public class Nadam extends Optimizer { public static final String SECOND_MOMENT = "v"; public static final String MOMENTUM = "momentum"; - /** The learning rate. */ - private final float learningRate; - /** The exponential decay rate for the 1st moment estimates. */ private final float betaOne; @@ -47,7 +45,6 @@ public class Nadam extends Optimizer { /** A small constant for numerical stability. */ private final float epsilon; - private Constant<TFloat32> learningRateConst; private Constant<TFloat32> epsilonConst; private Constant<TFloat32> betaOneConst; private Constant<TFloat32> betaTwoConst; @@ -69,7 +66,9 @@ public class Nadam extends Optimizer { private Operand<TFloat32> vTPrimeDenominator; /** - * Creates a Nadam Optimizer + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value, + * {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph */ @@ -78,59 +77,124 @@ public Nadam(Graph graph) { } /** - * Creates a Nadam Optimizer + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and + * {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph - * @param learningRate the learning rate, defaults to 0.001 + * @param learningRate the learning rate. */ public Nadam(Graph graph, float learningRate) { this(graph, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } /** - * Creates a Nadam Optimizer + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and + * {@link #EPSILON_DEFAULT} for the epsilon. 
+ * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Nadam(Graph graph, Operand<TFloat32> learningRateOperand) { + this(graph, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow graph - * @param learningRate the learning rate, defaults to 0.001 - * @param betaOne The exponential decay rate for the 1st moment estimates. Default is 0.9. - * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. Default - * is 0.999. - * @param epsilon A small constant for numerical stability. Default is 1e-8. + * @param learningRate the learning rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. */ public Nadam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(graph); - this.learningRate = learningRate; - this.betaOne = betaOne; - this.betaTwo = betaTwo; - this.epsilon = epsilon; + this(graph, null, learningRate, betaOne, betaTwo, epsilon); } /** - * Creates a Nadam Optimizer + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. + */ + public Nadam( + Graph graph, + Operand<TFloat32> learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + this(graph, null, learningRateOperand, betaOne, betaTwo, epsilon); + } + + /** + * Creates a Nadam Optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link + * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph - * @param name the name for this Optimizer, defaults to "Nadam" - * @param learningRate the learning rate, defaults to 0.001 + * @param name the name for this Optimizer. + * @param learningRate the learning rate. */ public Nadam(Graph graph, String name, float learningRate) { this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } + /** + * Creates a Nadam Optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link + * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Nadam(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this(graph, name, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + /** * Creates a Nadam Optimizer * * @param graph the TensorFlow graph - * @param name the name for this Optimizer, defaults to "Nadam" - * @param learningRate the learning rate, defaults to 0.001 - * @param betaOne The exponential decay rate for the 1st moment estimates. Default is 0.9. 
- * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. Default - * is 0.999. - * @param epsilon A small constant for numerical stability. Default is 1e-8. + * @param name the name for this Optimizer. + * @param learningRate the learning rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. */ public Nadam( Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(graph, name); - this.learningRate = learningRate; + super(graph, name, learningRate); + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + + /** + * Creates a Nadam Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. + */ + public Nadam( + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + super(graph, name, learningRateOperand); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; @@ -180,7 +244,6 @@ protected Optional<Op> prepare(String scopeName) { Constant<TFloat32> one = tf.constant(1.0F); Constant<TFloat32> point5 = tf.constant(0.5F); - learningRateConst = tf.constant(learningRate); betaOneConst = tf.constant(betaOne); betaTwoConst = tf.constant(betaTwo); Constant<TInt64> localStepConst = tf.constant(this.iterations + 1); @@ -271,7 +334,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable tf.math.sub( variable, tf.math.div( - tf.math.mul(tf.dtypes.cast(learningRateConst, dType), m_t_bar), + tf.math.mul(tf.dtypes.cast(getLearningRateOperand(), dType), m_t_bar), tf.math.add(tf.math.sqrt(vTPrime), tf.dtypes.cast(epsilonConst, dType)))); return tf.assign(variable, varT, Assign.useLocking(true)); @@ -297,6 +360,6 @@ protected Op finish(List<Op> updateOperations, String name) { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Nadam"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index 933a54c7670..586fef28c1e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -15,25 +15,27 @@ */ package org.tensorflow.framework.optimizers; -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.Output; +import org.tensorflow.*; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.NoOp; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.*; import java.util.stream.Collectors; /** Base class for gradient 
optimizers. */ -public abstract class Optimizer { - +public abstract class Optimizer implements AutoCloseable { + public static final String LEARNING_RATE = "learning_rate"; public static final String VARIABLE_V2 = "VariableV2"; + public static final float LEARNING_RATE_DEFAULT = 0.001f; + /** Global state variables */ // TODO make this be used. protected final List<Variable<?>> globals; @@ -44,18 +46,25 @@ public abstract class Optimizer { /** Top level map key is the variable name, lower level map key is the slot name. */ private final Map<String, Map<String, Variable<?>>> slots; + protected float learningRate; + protected Placeholder<TFloat32> learningRatePlaceholder = null; + private Tensor<TFloat32> learningRateTensor; + private Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap = null; + private Operand<TFloat32> learningRateOperand; + /** * Builds an optimizer for the supplied graph. * - * <p>Uses the name from {@link Optimizer#getOptimizerName()} to name the operations. - * * @param graph The graph to optimize. + * @param name The base name for the operations. + * @param learningRate the learning rate. */ - protected Optimizer(Graph graph) { + protected Optimizer(Graph graph, String name, float learningRate) { this.graph = graph; - this.tf = Ops.create(graph).withName(getOptimizerName()); + this.tf = Ops.create(graph).withName(name == null ? getOptimizerName() : name); this.slots = new HashMap<>(); this.globals = new ArrayList<>(); + setLearningRate(learningRate); } /** @@ -63,20 +72,14 @@ protected Optimizer(Graph graph) { * * @param graph The graph to optimize. * @param name The base name for the operations. + * @param learningRateOperand the learning rate. */ - protected Optimizer(Graph graph, String name) { + protected Optimizer(Graph graph, String name, Operand<TFloat32> learningRateOperand) { this.graph = graph; - this.tf = Ops.create(graph).withName(name); + this.tf = Ops.create(graph).withName(name == null ? getOptimizerName() : name); this.slots = new HashMap<>(); this.globals = new ArrayList<>(); - } - - /** - * Gets the Optimizer's Ops instance - * @return the Optimizer's Ops instance - */ - public final Ops getTF() { - return tf; + setLearningRateOperand(learningRateOperand); } /** @@ -229,7 +232,7 @@ protected <T extends TType> void createSlot( } /** - * Returns a No-op prepare. + * No-op prepare method. * * @param scopeName The scope name to use for any variable creations. */ @@ -238,7 +241,7 @@ protected Optional<Op> prepare(String scopeName) { } /** - * Performs a No-op slot creation method. + * No-op slot creation method. * * @param variables The variables to create slots for. */ @@ -280,12 +283,85 @@ protected Op finish(List<Op> updateOperations, String name) { } /** - * Get the Name of the optimizer. + * Gets the Name of the optimizer. * * @return The optimizer name. 
*/ public abstract String getOptimizerName(); + /** + * Sets the learning rate + * + * @param newLearningRate the new learning rate + */ + public final void setLearningRate(float newLearningRate) { + if (learningRatePlaceholder == null) { + learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + } + + if (learningRate != newLearningRate) { + if (learningRateTensor != null) learningRateTensor.close(); + learningRate = newLearningRate; + learningRateTensor = TFloat32.scalarOf(learningRate); + feedMap = Collections.singletonMap(learningRatePlaceholder, learningRateTensor); + } + } + + /** + * Sets the learning rate Operand. The learning rate operand is an operand that is used to + * calculate the learning rate. + * + * @param newLearningRateOperand the new learning rate operand. + */ + public final void setLearningRateOperand(Operand<TFloat32> newLearningRateOperand) { + close(); // Cleanup the placeholder and tensor if they exist. + learningRateOperand = newLearningRateOperand; + } + + /** + * Gets the learning rate + * + * @return the learning rate + */ + public float getLearningRate() { + return learningRate; + } + + /** + * Gets the learning rate Operand, used by subclasses in their graph operations. If a float + * learning rate has been set using {@link #setLearningRate}, then this will be the learning rate + * Placeholder, otherwise the learning rate operand is returned as passed to {@link + * #setLearningRateOperand}. + * + * @return the learning rate Operand + */ + protected Operand<TFloat32> getLearningRateOperand() { + return learningRatePlaceholder == null ? learningRateOperand : learningRatePlaceholder; + } + + /** + * Gets the Feed Map for the run methods to set the Placeholder value(s). Each entry in the Feed + * Map contains a PlaceHolder and a Tensor with the value. + * + * @return the current Feed Map for the run methods, this will be null if the LearningRateOperand + * has been set. + */ + public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedMap() { + return feedMap; + } + + /** {@inheritDoc} */ + public void close() { + // close the learningRate Tensor if it exists. + if (learningRateTensor != null) { + learningRateTensor.close(); + learningRateTensor = null; + } + if (feedMap != null) feedMap = null; + } + /** Optional attributes for {@link org.tensorflow.framework.optimizers.Optimizer} */ public static class Options { @@ -294,8 +370,6 @@ public static class Options { private Options() {} /** - * Sets the shared name - * * @param sharedName If non-empty, this variable is named in the given bucket with this * shared_name. Otherwise, the node name is used instead. */ @@ -305,41 +379,20 @@ public Optimizer.Options sharedName(String sharedName) { } } - /** - * A class that holds a paired gradient and variable. 
- * - * @param <T> the data type for the gradient and variable - */ public static class GradAndVar<T extends TType> { private final Output<T> gradient; private final Output<T> variable; - /** - * Creates a Gradient and Variable pair - * - * @param gradient the gradient - * @param variable the variable - */ public GradAndVar(Output<T> gradient, Output<T> variable) { this.gradient = gradient; this.variable = variable; } - /** - * Gets the gradient - * - * @return the gradient - */ public Output<T> getGradient() { return gradient; } - /** - * Gets the variable - * - * @return the variable - */ public Output<T> getVariable() { return variable; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index 8b71558e549..7a03bec849d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -20,6 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -47,6 +48,7 @@ */ public class RMSProp extends Optimizer { + public static final String DEFAULT_NAME = "RMSProp"; public static final float LEARNING_RATE_DEFAULT = 0.001f; public static final float DECAY_DEFAULT = 0.9f; public static final float MOMENTUM_DEFAULT = 0.0f; @@ -56,14 +58,16 @@ public class RMSProp extends Optimizer { public static final String MG = "mg"; // mean gradient? public static final String MOMENTUM = "momentum"; - private final float learningRate; private final float decay; private final float momentum; private final float epsilon; private final boolean centered; /** - * Creates an RMSPRrop Optimizer + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #DECAY_DEFAULT} for the decay, {@link + * #MOMENTUM_DEFAULT} for the momentum, {@link #EPSILON_DEFAULT} for the epsilon value and {@link + * #CENTERED_DEFAULT} for the centered flag. * * @param graph the TensorFlow Graph */ @@ -78,7 +82,9 @@ public RMSProp(Graph graph) { } /** - * Creates an RMSPRrop Optimizer + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #DECAY_DEFAULT} for the decay, {@link #MOMENTUM_DEFAULT} for the momentum, {@link + * #EPSILON_DEFAULT} for the epsilon value and {@link #CENTERED_DEFAULT} for the centered flag. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -88,17 +94,36 @@ public RMSProp(Graph graph, float learningRate) { } /** - * Creates an RMSPRrop Optimizer + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #DECAY_DEFAULT} for the decay, {@link #MOMENTUM_DEFAULT} for the momentum, {@link + * #EPSILON_DEFAULT} for the epsilon value and {@link #CENTERED_DEFAULT} for the centered flag. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public RMSProp(Graph graph, Operand<TFloat32> learningRateOperand) { + this( + graph, + learningRateOperand, + DECAY_DEFAULT, + MOMENTUM_DEFAULT, + EPSILON_DEFAULT, + CENTERED_DEFAULT); + } + + /** + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. 
* * @param graph the TensorFlow Graph * @param learningRate the learning rate - * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. - * @param momentum the acceleration factor, default is 0. + * @param decay Discounting factor for the history/coming gradient. + * @param momentum the acceleration factor. * @param epsilon A small constant for numerical stability * @param centered If <code>true</code>, gradients are normalized by the estimated variance of the * gradient; if <code>false</code>>, by the uncentered second moment. Setting this to <code> * true</code>> may help with training, but is slightly more expensive in terms of computation - * and memory. Defaults to <code>false</code>. + * and memory. */ public RMSProp( Graph graph, @@ -107,19 +132,40 @@ public RMSProp( float momentum, float epsilon, boolean centered) { - super(graph); - this.learningRate = learningRate; - this.decay = decay; - this.momentum = momentum; - this.epsilon = epsilon; - this.centered = centered; + this(graph, null, learningRate, decay, momentum, epsilon, centered); } /** - * Creates an RMSPRrop Optimizer + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param decay Discounting factor for the history/coming gradient. + * @param momentum the acceleration factor. + * @param epsilon A small constant for numerical stability + * @param centered If <code>true</code>, gradients are normalized by the estimated variance of the + * gradient; if <code>false</code>>, by the uncentered second moment. Setting this to <code> + * true</code>> may help with training, but is slightly more expensive in terms of computation + * and memory. + */ + public RMSProp( + Graph graph, + Operand<TFloat32> learningRateOperand, + float decay, + float momentum, + float epsilon, + boolean centered) { + this(graph, null, learningRateOperand, decay, momentum, epsilon, centered); + } + + /** + * Creates an RMSPRrop Optimizer using {@link #DECAY_DEFAULT} for the decay, {@link + * #MOMENTUM_DEFAULT} for the momentum, {@link #EPSILON_DEFAULT} for the epsilon value and {@link + * #CENTERED_DEFAULT} for the centered flag. * * @param graph the TensorFlow Graph - * @param name the name of this Optimizer. Defaults to "RMSProp". + * @param name the name of this Optimizer. * @param learningRate the learning rate */ public RMSProp(Graph graph, String name, float learningRate) { @@ -133,19 +179,40 @@ public RMSProp(Graph graph, String name, float learningRate) { CENTERED_DEFAULT); } + /** + * Creates an RMSPRrop Optimizer using {@link #DECAY_DEFAULT} for the decay, {@link + * #MOMENTUM_DEFAULT} for the momentum, {@link #EPSILON_DEFAULT} for the epsilon value and {@link + * #CENTERED_DEFAULT} for the centered flag. + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public RMSProp(Graph graph, String name, Operand<TFloat32> learningRateOperand) { + this( + graph, + name, + learningRateOperand, + DECAY_DEFAULT, + MOMENTUM_DEFAULT, + EPSILON_DEFAULT, + CENTERED_DEFAULT); + } + /** * Creates an RMSPRrop Optimizer * * @param graph the TensorFlow Graph - * @param name the name of this Optimizer. Defaults to "RMSProp". + * @param name the name of this Optimizer. 
* @param learningRate the learning rate - * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. - * @param momentum The acceleration factor, default is 0. + * @param decay Discounting factor for the history/coming gradient. + * @param momentum The acceleration factor. * @param epsilon A small constant for numerical stability * @param centered If <code>true</code>, gradients are normalized by the estimated variance of the * gradient; if <code>false</code>>, by the uncentered second moment. Setting this to <code> * true</code>> may help with training, but is slightly more expensive in terms of computation - and memory. Defaults to <code>false</code>. + * and memory. */ public RMSProp( Graph graph, String name, float learningRate, float decay, float momentum, float epsilon, boolean centered) { - super(graph, name); - this.learningRate = learningRate; + super(graph, name, learningRate); + this.decay = decay; + this.momentum = momentum; + this.epsilon = epsilon; + this.centered = centered; + } + + /** + * Creates an RMSProp Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param decay Discounting factor for the history/coming gradient. + * @param momentum The acceleration factor. + * @param epsilon A small constant for numerical stability + * @param centered If <code>true</code>, gradients are normalized by the estimated variance of the + * gradient; if <code>false</code>, by the uncentered second moment. Setting this to <code> + * true</code> may help with training, but is slightly more expensive in terms of computation + * and memory. + */ + public RMSProp( + Graph graph, + String name, + Operand<TFloat32> learningRateOperand, + float decay, + float momentum, + float epsilon, + boolean centered) { + super(graph, name, learningRateOperand); this.decay = decay; this.momentum = momentum; this.epsilon = epsilon; @@ -171,10 +267,8 @@ protected void createSlots(List<Output<? extends TType>> variables) { } } - /** - * Creates the RMSProp Slots for Root Mean Squared (RMS), - * MOMENTUM, and Mean Gradient (MG) + * Creates the RMSProp Slots for Root Mean Squared (RMS), MOMENTUM, and Mean Gradient (MG) * * @param v the variable to install in the slot * @param <T> the datatype of the variable.
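Taken together, the Optimizer changes above mean that any optimizer in this package can now be driven in one of two ways: with a plain float learning rate, which is backed internally by a scalar TFloat32 placeholder whose current value must be supplied through getFeedMap() on every run and can be changed between steps with setLearningRate(float); or with an arbitrary TFloat32 Operand (for example a decay schedule computed in the graph), in which case getFeedMap() returns null and nothing needs to be fed. The following sketch illustrates the float-based flow only; it is a minimal, hypothetical example rather than part of this change. The class and variable names are invented, AdaGrad is used simply because its float constructor appears in the tests further down, and the direct Session.Runner feed/addTarget calls are assumed to be available in this version of TF Java (the bundled tests route the same feed map through the TestSession helper instead).

import java.util.ArrayList;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.framework.optimizers.AdaGrad;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TType;

public class LearningRateFeedSketch {
  public static void main(String[] args) {
    try (Graph graph = new Graph();
        // A float learning rate is backed by the scalar placeholder plus feed map described above.
        AdaGrad optimizer = new AdaGrad(graph, 0.1f)) {
      Ops tf = Ops.create(graph);

      // One trainable variable and a fixed "gradient", mirroring the tests later in this change.
      Variable<TFloat32> weights = tf.withName("weights").variable(Shape.of(2), TFloat32.DTYPE);
      Op weightsInit = tf.assign(weights, tf.constant(new float[] {1f, 2f}));
      Operand<TFloat32> grad = tf.constant(new float[] {0.1f, 0.1f});

      List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
      gradsAndVars.add(new Optimizer.GradAndVar<>(grad.asOutput(), weights.asOutput()));
      Op train = optimizer.applyGradients(gradsAndVars, "train");

      try (Session session = new Session(graph)) {
        session.runner().addTarget(weightsInit).run();
        session.runner().addTarget(tf.init()).run(); // initializes the optimizer slot variables

        for (int step = 0; step < 3; step++) {
          // The learning-rate placeholder must be fed on every run; getFeedMap() supplies it.
          Session.Runner runner = session.runner();
          optimizer.getFeedMap().forEach(runner::feed);
          runner.addTarget(train).run();

          // Decay the learning rate between steps: only the fed tensor changes, not the graph.
          optimizer.setLearningRate(optimizer.getLearningRate() * 0.9f);
        }
      }
    }
  }
}

With an Operand-based learning rate the loop body simplifies: the feed step disappears (getFeedMap() is null) and the effective rate is whatever the operand evaluates to at run time, since getLearningRateOperand() is what the applyDense implementations cast and use.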
@@ -205,7 +299,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable mgSlot, rmsSlot, momentumSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), tf.dtypes.cast(tf.constant(decay), gradient.dataType()), tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), @@ -215,7 +309,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable variable, rmsSlot, momentumSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), tf.dtypes.cast(tf.constant(decay), gradient.dataType()), tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), @@ -242,6 +336,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "RMSProp"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java index 5c4ce542c65..7653c99bc98 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; -import org.tensorflow.Graph; import org.tensorflow.framework.optimizers.Optimizer.GradAndVar; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; @@ -55,11 +54,10 @@ public void tearDown() {} @Test public void testConstructAdadeltaWithLR() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - AdaDelta opt = new AdaDelta(graph, 1.0F, 0.9F, 1.F); - AdaDelta opt2 = new AdaDelta(graph, 0.1F, 0.9F, 1.F); - AdaDelta opt3 = new AdaDelta(graph, 0.1F, 0.9F, 1e-8F); + try (TestSession session = TestSession.createTestSession(tfMode); + AdaDelta opt = new AdaDelta(session.getGraph(), "opt1", 1.0F, 0.9F, 1.F); + AdaDelta opt2 = new AdaDelta(session.getGraph(), "opt2", 0.1F, 0.9F, 1.F); + AdaDelta opt3 = new AdaDelta(session.getGraph(), "opt3", 0.1F, 0.9F, 1e-8F)) { String format = "AdaDelta{learningRate=%s, rho=%s, epsilon=%s}"; String optExpected = String.format(format, 1.0F, 0.9F, 1.F); String opt2Expected = String.format(format, 0.1F, 0.9F, 1.F); @@ -80,11 +78,15 @@ public void testBasic() { int numUpdates = 4; // # number of ADADELTA steps to perform float[] grads = {0.2F, 0.1F, 0.01F}; float[] lrs = {1.0F, 0.5F, 0.1F}; + + float rho = 0.95F; + float epsilon = 1e-8F; + for (float grad : grads) { for (float lr : lrs) { - try (TestSession session = TestSession.createTestSession(tfMode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + AdaDelta instance = new AdaDelta(session.getGraph(), lr, rho, epsilon)) { Ops tf = session.getTF(); - Graph graph = session.getGraph(); float[] var0Init = {1.0F, 2.0F}; float[] var1Init = {3.0F, 4.0F}; float[] fgrads = {grad, grad}; @@ -96,37 +98,33 @@ public void testBasic() { Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); Constant<TFloat32> cgrads = tf.constant(fgrads); - float accum = 0.0F; float accumUpdate = 0.0F; - float rho = 0.95F; - float epsilon = 1e-8F; /* build the GradsAnvVars */ 
List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var0.asOutput())); gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var1.asOutput())); - /* get the Optimizer */ - AdaDelta adaDelta = new AdaDelta(graph, lr, rho, epsilon); - /*apply gradients */ - Op adadeltaUpdate = adaDelta.applyGradients(gradsAndVars, "AdaDeltaTest"); + Op adadeltaUpdate = instance.applyGradients(gradsAndVars, "AdaDeltaTest"); /* Create and validate the shapes of the slota */ + @SuppressWarnings("unchecked") Variable<TFloat32>[] slots = new Variable[2]; + @SuppressWarnings("unchecked") Variable<TFloat32>[] slotUpdates = new Variable[2]; - slots[0] = adaDelta.getSlot(var0.asOutput(), ACCUMULATOR).get(); + slots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get(); assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape()); - slotUpdates[0] = adaDelta.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get(); + slotUpdates[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get(); assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape()); - slots[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR).get(); + slots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get(); assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape()); - slotUpdates[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get(); + slotUpdates[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get(); assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape()); /* initialize the local variables */ @@ -143,7 +141,8 @@ public void testBasic() { float[] updates = new float[numUpdates]; float totUpdate = 0; for (int step = 0; step < numUpdates; step++) { - session.run(adadeltaUpdate); + + session.run(adadeltaUpdate, instance.getFeedMap()); accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho); updates[step] = ((float) Math.sqrt(accumUpdate + epsilon) @@ -167,4 +166,205 @@ public void testBasic() { } } } + + @Test + public void testBasicWithLROperand() { + int numUpdates = 4; // # number of ADADELTA steps to perform + float[] grads = {0.2F, 0.1F, 0.01F}; + float[] lrs = {1.0F, 0.5F, 0.1F}; + + float rho = 0.95F; + float epsilon = 1e-8F; + + for (float grad : grads) { + for (float lr : lrs) { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // this just uses a trivial operand + try (AdaDelta instance = + new AdaDelta( + session.getGraph(), + tf.math.mul(tf.constant(lr), tf.constant(1.f)), + rho, + epsilon)) { + + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] fgrads = {grad, grad}; + Shape shape = Shape.of(var0Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> cgrads = tf.constant(fgrads); + float accum = 0.0F; + float accumUpdate = 0.0F; + + /* build the GradsAnvVars */ + List<GradAndVar<? 
extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var1.asOutput())); + + /*apply gradients */ + Op adadeltaUpdate = instance.applyGradients(gradsAndVars, "AdaDeltaTest"); + + /* Create and validate the shapes of the slota */ + @SuppressWarnings("unchecked") + Variable<TFloat32>[] slots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable<TFloat32>[] slotUpdates = new Variable[2]; + + slots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape()); + + slotUpdates[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape()); + + slots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape()); + + slotUpdates[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + float[] updates = new float[numUpdates]; + float totUpdate = 0; + for (int step = 0; step < numUpdates; step++) { + + session.run(adadeltaUpdate, instance.getFeedMap()); + accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho); + updates[step] = + ((float) Math.sqrt(accumUpdate + epsilon) + * (float) (1 / Math.sqrt(accum + epsilon)) + * grad); + accumUpdate = + (accumUpdate * rho + ((float) Math.pow(updates[step], 2) * (1.0F - rho))); + totUpdate += updates[step] * lr; + + for (int i = 0; i < 2; i++) { + session.evaluate(accum, slots[i]); + session.evaluate(accumUpdate, slotUpdates[i]); + } + + Float[] var0InitUpdate = {var0Init[0] - totUpdate, var0Init[1] - totUpdate}; + Float[] var1InitUpdate = {var1Init[0] - totUpdate, var1Init[1] - totUpdate}; + + session.evaluate(var0InitUpdate, var0); + session.evaluate(var1InitUpdate, var1); + } + } + } + } + } + } + + @Test + public void testWithLearningRateDecay() { + int numSteps = 4; // # number of ADADELTA steps to perform + float[] grads = {0.2F, 0.1F, 0.01F}; + + for (float grad : grads) { + float learningRate = 1.0F; + float rho = 0.95F; + float epsilon = 1e-8F; + try (TestSession session = TestSession.createTestSession(tfMode); + AdaDelta instance = new AdaDelta(session.getGraph(), learningRate, rho, epsilon)) { + Ops tf = session.getTF(); + + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] fgrads = {grad, grad}; + Shape shape = Shape.of(var0Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> cgrads = tf.constant(fgrads); + + float accum = 0.0F; + float accumUpdate = 0.0F; + + /* build the GradsAnvVars */ + List<GradAndVar<? 
extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var1.asOutput())); + + Op adadeltaUpdate = instance.applyGradients(gradsAndVars, "AdaDeltaTest"); + + /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") + Variable<TFloat32>[] slots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable<TFloat32>[] slotUpdates = new Variable[2]; + + slots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape()); + + slotUpdates[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape()); + + slots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape()); + + slotUpdates[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + float[] updates = new float[numSteps]; + float totUpdate = 0; + for (int step = 0; step < numSteps; step++) { + assertEquals(learningRate, instance.getLearningRate(), epsilon); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(adadeltaUpdate, instance.getFeedMap()); + accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho); + updates[step] = + ((float) Math.sqrt(accumUpdate + epsilon) + * (float) (1 / Math.sqrt(accum + epsilon)) + * grad); + accumUpdate = (accumUpdate * rho + ((float) Math.pow(updates[step], 2) * (1.0F - rho))); + totUpdate += updates[step] * learningRate; + + for (int i = 0; i < 2; i++) { + session.evaluate(accum, slots[i]); + session.evaluate(accumUpdate, slotUpdates[i]); + } + + Float[] var0InitUpdate = {var0Init[0] - totUpdate, var0Init[1] - totUpdate}; + Float[] var1InitUpdate = {var1Init[0] - totUpdate, var1Init[1] - totUpdate}; + + session.evaluate(var0InitUpdate, var0); + session.evaluate(var1InitUpdate, var1); + + // Adjust learning rate + learningRate *= 0.9F; + instance.setLearningRate(learningRate); + } + } + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java index ef9053ff1eb..9e67b4660df 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; -import org.tensorflow.Graph; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -29,6 +28,8 @@ import java.util.ArrayList; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; + /** Test cases for AdaGradDA Optimizer */ public class AdaGradDATest { @@ -54,13 +55,11 @@ public void testBasic() { float[] var1Init = {0.0F, 0.0F}; float[] grads0Init = {0.1F, 0.2F}; float[] grads1Init = {0.01F, 
0.02F}; - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); + float learningRate = 3.0F; + try (TestSession session = TestSession.createTestSession(tfMode); + AdaGradDA instance = new AdaGradDA(session.getGraph(), learningRate)) { - float learningRate = 3.0F; - - AdaGradDA instance = new AdaGradDA(graph, learningRate); - Ops tf = instance.getTF(); + Ops tf = session.getTF(); Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); @@ -78,7 +77,6 @@ public void testBasic() { session.run(var0Initializer); session.run(var1Initializer); - /* build the GradsAnvVars */ List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); @@ -91,11 +89,122 @@ public void testBasic() { session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); - session.run(adaUpdate); + session.run(adaUpdate, instance.getFeedMap()); float[] expected0 = {-0.904534F, -1.603567F}; session.evaluate(expected0, var0); float[] expected1 = {-0.094821f, -0.189358f}; session.evaluate(expected1, var1); } } + + @Test + public void testBasicWithLROperand() { + float[] var0Init = {0.0F, 0.0F}; + float[] var1Init = {0.0F, 0.0F}; + float[] grads0Init = {0.1F, 0.2F}; + float[] grads1Init = {0.01F, 0.02F}; + float learningRate = 1.5F; + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (AdaGradDA instance = + new AdaGradDA( + session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(2.f)))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* initialize the local variables */ + + session.run(var0Initializer); + session.run(var1Initializer); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? 
extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op adaUpdate = instance.applyGradients(gradsAndVars, "AdGradDATest"); + + /* initialize the accumulators */ + session.run(tf.init()); + + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + session.run(adaUpdate, instance.getFeedMap()); + float[] expected0 = {-0.904534F, -1.603567F}; + session.evaluate(expected0, var0); + float[] expected1 = {-0.094821f, -0.189358f}; + session.evaluate(expected1, var1); + } + } + } + + @Test + public void testWithLearningRateDecay() { + float[] var0Init = {0.0F, 0.0F}; + float[] var1Init = {0.0F, 0.0F}; + float[] grads0Init = {0.1F, 0.2F}; + float[] grads1Init = {0.01F, 0.02F}; + float epsilon = 1e-8F; + int numSteps = 4; + float learningRate = 3.0F; + try (TestSession session = TestSession.createTestSession(tfMode); + AdaGrad instance = new AdaGrad(session.getGraph(), learningRate)) { + Ops tf = session.getTF(); + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdGradDATest"); + + /* initialize the accumulators */ + session.run(tf.init()); + + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + float[][][] expected = { + {{-2.121320f, -2.683281f}, {-0.298511f, -0.588348f}}, + {{-3.680166f, -4.483282f}, {-0.565851f, -1.107964f}}, + {{-4.895166f, -5.831203f}, {-0.805286f, -1.567190f}}, + {{-5.873222f, -6.892054f}, {-1.019739f, -1.973306f}} + }; + for (int i = 0; i < numSteps; i++) { + assertEquals(learningRate, instance.getLearningRate(), epsilon); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(update, instance.getFeedMap()); + session.evaluate(expected[i][0], var0); + session.evaluate(expected[i][1], var1); + learningRate *= 0.9; + instance.setLearningRate(learningRate); + } + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java index c5ae178b84c..dc05b9b8f81 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java @@ -15,11 +15,7 @@ package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; -import org.tensorflow.Graph; -import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; -import
org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -60,22 +56,13 @@ public void testBasic() { float[] var1Init = {3.0F, 4.0F}; float[] grads0Init = {0.1F, 0.1F}; float[] grads1Init = {0.01F, 0.01F}; - float[] accum0 = {0.1f, 0.1f}; - float[] accum1 = {0.1f, 0.1f}; - FloatNdArray var0Np = NdArrays.vectorOf(var0Init); - FloatNdArray var1Np = NdArrays.vectorOf(var1Init); - FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); - FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); - FloatNdArray accum0Np = NdArrays.vectorOf(accum0); - FloatNdArray accum1Np = NdArrays.vectorOf(accum1); + float learningRate = 3.0F; - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); + try (TestSession session = TestSession.createTestSession(tfMode); + AdaGrad instance = new AdaGrad(session.getGraph(), learningRate, 0.1f)) { - float learningRate = 3.0F; - AdaGrad instance = new AdaGrad(graph, learningRate, 0.1f); - Ops tf = instance.getTF(); + Ops tf = session.getTF(); Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); @@ -88,8 +75,6 @@ public void testBasic() { Constant<TFloat32> grads0 = tf.constant(grads0Init); Constant<TFloat32> grads1 = tf.constant(grads1Init); - - /* build the GradsAnvVars */ List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); @@ -97,6 +82,7 @@ public void testBasic() { Op adaUpdate = instance.applyGradients(gradsAndVars, "AdGradTest"); + @SuppressWarnings("unchecked") Variable<TFloat32>[] accumulatorSlots = new Variable[2]; accumulatorSlots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get(); assertEquals(accumulatorSlots[0].asOutput().shape(), var0.asOutput().shape()); @@ -116,7 +102,7 @@ public void testBasic() { session.evaluate(var1Init, var1); for (int step = 0; step < numSteps; step++) { - session.run(adaUpdate); + session.run(adaUpdate, instance.getFeedMap()); } float[] expected0 = {-1.6026098728179932f, -0.6026098728179932f}; session.evaluate(expected0, var0); @@ -125,18 +111,138 @@ public void testBasic() { } } - private FloatNdArray caclulateAccum(FloatNdArray accum, FloatNdArray grads) { - // accum + gT * gT - FloatNdArray squareG = ND.square(grads); - return ND.add(accum, squareG); + @Test + public void testBasicWithLROperand() { + int numSteps = 3; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + + float learningRate = 1.0F; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (AdaGrad instance = + new AdaGrad( + session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(3.f)), 0.1f)) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? 
extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op adaUpdate = instance.applyGradients(gradsAndVars, "AdGradTest"); + + @SuppressWarnings("unchecked") + Variable<TFloat32>[] accumulatorSlots = new Variable[2]; + accumulatorSlots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get(); + assertEquals(accumulatorSlots[0].asOutput().shape(), var0.asOutput().shape()); + + accumulatorSlots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get(); + assertEquals(accumulatorSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + for (int step = 0; step < numSteps; step++) { + session.run(adaUpdate, instance.getFeedMap()); + } + float[] expected0 = {-1.6026098728179932f, -0.6026098728179932f}; + session.evaluate(expected0, var0); + float[] expected1 = {2.715679168701172f, 3.715679168701172f}; + session.evaluate(expected1, var1); + } + } } - private FloatNdArray calculate( - FloatNdArray param, FloatNdArray accum, FloatNdArray grads, float learningRate) { - // param - lr * gT / (np.sqrt(accumT) + epsilon) - FloatNdArray divisor = ND.add(ND.sqrt(accum), 1e-07f); - FloatNdArray dividend = ND.mul(learningRate, grads); - FloatNdArray quotient = ND.div(dividend, divisor); - return ND.sub(param, quotient); + @Test + public void testWithLearningRateDecay() { + int numSteps = 3; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + float epsilon = 1e-8F; + float learningRate = 3.0F; + + try (TestSession session = TestSession.createTestSession(tfMode); + AdaGrad instance = new AdaGrad(session.getGraph(), learningRate)) { + Ops tf = session.getTF(); + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? 
extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op adaUpdate = instance.applyGradients(gradsAndVars, "AdGradTest"); + + @SuppressWarnings("unchecked") + Variable<TFloat32>[] accumulatorSlots = new Variable[2]; + accumulatorSlots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get(); + assertEquals(accumulatorSlots[0].asOutput().shape(), var0.asOutput().shape()); + + accumulatorSlots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get(); + assertEquals(accumulatorSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + // initialize the accumulators + session.run(tf.init()); + + // make sure the variables were initialized properly + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + float[][][] expected = { + {{-1.121320f, -0.121320f}, {2.701489f, 3.701489f}}, + {{-2.680166f, -1.680166f}, {2.434149f, 3.434149f}}, + {{-3.895166f, -2.895166f}, {2.194714f, 3.194714f}} + }; + for (int step = 0; step < numSteps; step++) { + assertEquals(learningRate, instance.getLearningRate(), epsilon); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(adaUpdate, instance.getFeedMap()); + + session.evaluate(expected[step][0], var0); + session.evaluate(expected[step][1], var1); + + learningRate *= 0.9; + instance.setLearningRate(learningRate); + } + } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java index 461fa75397f..c5bb153d804 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; -import org.tensorflow.Graph; import org.tensorflow.Tensor; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; @@ -66,18 +65,18 @@ public void testBasic() { FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); - float epsilon1 = 1e-3F; + float epsilon1 = 1e-3f; + float learningRate = 0.001f; + + try (TestSession session = TestSession.createTestSession(tfMode); + Adam instance = new Adam(session.getGraph(), learningRate)) { - try (TestSession session = TestSession.createTestSession(tfMode)) { - float learningRate = 0.001F; float beta1 = 0.9F; float beta2 = 0.999F; - Graph graph = session.getGraph(); session.setEpsilon(epsilon1); - Adam instance = new Adam(graph, learningRate); - Ops tf = instance.getTF(); + Ops tf = session.getTF(); Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); @@ -93,15 +92,11 @@ public void testBasic() { session.run(var0Initializer); session.run(var1Initializer); - - /* build the GradsAnvVars */ List<Optimizer.GradAndVar<? 
extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - - Op update = instance.applyGradients(gradsAndVars, "AdamTest"); /* Create and validate the shapes of the slots */ @@ -123,7 +118,7 @@ public void testBasic() { assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); /* initialize the accumulators */ - session.run(tf.init()); + session.run(tf.init(), instance.getFeedMap()); session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); @@ -160,7 +155,7 @@ public void testBasic() { .expect(TFloat32.DTYPE)) { result.data().scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); } - session.run(update); + session.run(update, instance.getFeedMap()); float lrT = learningRate @@ -190,13 +185,282 @@ public void testBasic() { } } + @Test + public void testBasicWithLROperand() { + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + float epsilon1 = 1e-3f; + float learningRate = 0.001f; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Adam instance = + new Adam( + session.getGraph(), tf.constant(learningRate))) { + + float beta1 = 0.9F; + float beta2 = 0.999F; + + session.setEpsilon(epsilon1); + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? 
extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validate the shapes of the slots */ + @SuppressWarnings("unchecked") + Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable<TFloat32>[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the accumulators */ + session.run(tf.init(), instance.getFeedMap()); + + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + FloatNdArray m0Np = NdArrays.ofFloats(shape0); + FloatNdArray v0Np = NdArrays.ofFloats(shape0); + FloatNdArray m1Np = NdArrays.ofFloats(shape1); + FloatNdArray v1Np = NdArrays.ofFloats(shape1); + + for (int step = 0; step < 3; step++) { + + // Test powers + final float[] powers = { + (float) Math.pow(beta1, step + 1), (float) Math.pow(beta2, step + 1) + }; + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); + } + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("beta2_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); + } + session.run(update, instance.getFeedMap()); + + float lrT = + learningRate + * (float) Math.sqrt(1 - (float) Math.pow(beta2, (step + 1))) + / (1 - (float) Math.pow(beta1, (step + 1))); + + m0Np = calculateM(m0Np, grads0Np, beta1); + v0Np = calculateV(v0Np, grads0Np, beta2); + var0Np = calculateParam(var0Np, lrT, m0Np, v0Np, 1e-7F); + + m1Np = calculateM(m1Np, grads1Np, beta1); + v1Np = calculateV(v1Np, grads1Np, beta2); + var1Np = calculateParam(var1Np, lrT, m1Np, v1Np, 1e-7F); + + // evaluate var 0 and var1 + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + + // first moment + session.evaluate(m0Np, firstMomentSlots[0]); + session.evaluate(m1Np, firstMomentSlots[1]); + + // second moment + session.evaluate(v0Np, secondMomentSlots[0]); + session.evaluate(v1Np, secondMomentSlots[1]); + } + } + } + } + + @Test + public void testWithLearningRateDecay() { + + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + float epsilon1 = 1e-3F; + float learningRate = 0.001F; + float beta1 = 0.9F; + float beta2 = 0.999F; + + try 
(TestSession session = TestSession.createTestSession(tfMode); + Adam instance = new Adam(session.getGraph(), learningRate)) { + Ops tf = session.getTF(); + + session.setEpsilon(epsilon1); + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validate the shapes of the slots */ + @SuppressWarnings("unchecked") + Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable<TFloat32>[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + // initialize the accumulators + session.run(tf.init()); + + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + FloatNdArray m0Np = NdArrays.ofFloats(shape1); + FloatNdArray v0Np = NdArrays.ofFloats(shape1); + FloatNdArray m1Np = NdArrays.ofFloats(shape1); + FloatNdArray v1Np = NdArrays.ofFloats(shape1); + + for (int step = 0; step < 3; step++) { + + // Test powers + final float[] powers = { + (float) Math.pow(beta1, step + 1), (float) Math.pow(beta2, step + 1) + }; + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); + } + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("beta2_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); + } + assertEquals(learningRate, instance.getLearningRate(), 1e-6f); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(update, instance.getFeedMap()); + + float lr_t = + learningRate + * (float) Math.sqrt(1 - (float) Math.pow(beta2, (step + 1))) + / (1 - (float) Math.pow(beta1, (step + 1))); + + m0Np = calculateM(m0Np, grads0Np, beta1); + v0Np = calculateV(v0Np, grads0Np, beta2); + var0Np = calculateParam(var0Np, lr_t, m0Np, v0Np, 1e-7F); + + 
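// repeat the expected-value computation for the second variable/gradient pair using the same decayed lr_t before comparing against the graph results +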
m1Np = calculateM(m1Np, grads1Np, beta1); + v1Np = calculateV(v1Np, grads1Np, beta2); + var1Np = calculateParam(var1Np, lr_t, m1Np, v1Np, 1e-7F); + + // evaluate var 0 and var1 + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + + // first moment + session.evaluate(m0Np, firstMomentSlots[0]); + session.evaluate(m1Np, firstMomentSlots[1]); + + // second moment + session.evaluate(v0Np, secondMomentSlots[0]); + session.evaluate(v1Np, secondMomentSlots[1]); + + learningRate *= 0.9; + instance.setLearningRate(learningRate); + } + } + } + private FloatNdArray calculateM(FloatNdArray m, FloatNdArray gT, float beta) { - // mT = beta1 * m + (1 - beta1) * gT return ND.add(ND.mul(m, beta), ND.mul(gT, (1 - beta))); } private FloatNdArray calculateV(FloatNdArray v, FloatNdArray gT, float beta) { - // beta2 * v + (1 - beta2) * gT * gT FloatNdArray mul1 = ND.mul(v, beta); FloatNdArray squareG = ND.square(gT); FloatNdArray mul2 = ND.mul((1 - beta), squareG); @@ -204,11 +468,10 @@ private FloatNdArray calculateV(FloatNdArray v, FloatNdArray gT, float beta) { } private FloatNdArray calculateParam( - FloatNdArray param, float lrT, FloatNdArray m, FloatNdArray v, float epsilon) { - // param - lrT * mT / (np.sqrt(vT) + epsilon) + FloatNdArray param, float lr_t, FloatNdArray m, FloatNdArray v, float epsilon) { FloatNdArray sqrt = ND.sqrt(v); FloatNdArray divisor = ND.add(sqrt, epsilon); - FloatNdArray dividend = ND.mul(lrT, m); + FloatNdArray dividend = ND.mul(lr_t, m); FloatNdArray quotient = ND.div(dividend, divisor); return ND.sub(param, quotient); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java index de17395f76a..e900018ccad 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; -import org.tensorflow.Graph; import org.tensorflow.Tensor; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; @@ -61,10 +60,10 @@ public void tearDown() {} /** Test of getOptimizerName method, of class Adamax. */ @Test public void testGetOptimizerName() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - Adamax instance = new Adamax(graph); - String expResult = "Adamax"; + try (TestSession session = TestSession.createTestSession(tfMode); + Adamax instance = new Adamax(session.getGraph())) { + + String expResult = DEFAULT_NAME; String result = instance.getOptimizerName(); assertEquals(expResult, result); } @@ -93,11 +92,9 @@ public void testBasic() { float epsilon1 = 1e-3F; - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - - Adamax instance = new Adamax(graph); - Ops tf = instance.getTF(); + try (TestSession session = TestSession.createTestSession(tfMode); + Adamax instance = new Adamax(session.getGraph())) { + Ops tf = session.getTF(); Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); @@ -114,7 +111,6 @@ public void testBasic() { session.run(var0Initializer); session.run(var1Initializer); - /* build the GradsAnvVars */ List<GradAndVar<? 
extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput())); @@ -123,7 +119,9 @@ public void testBasic() { Op update = instance.applyGradients(gradsAndVars, "AdamTest"); /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") Variable<TFloat32>[] secondMomentSlots = new Variable[2]; firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); @@ -159,7 +157,7 @@ public void testBasic() { .expect(TFloat32.DTYPE)) { result.data().scalars().forEach(f -> assertEquals(beta1Power, f.getFloat(), epsilon1)); } - session.run(update); + session.run(update, instance.getFeedMap()); FloatNdArray[] resultNP = calculate(var0Np, grads0Np, step, m0, v0); var0Np = resultNP[VAR]; @@ -179,20 +177,252 @@ public void testBasic() { } } + /** Test of applyDense method, of class Adamax. */ + @Test + public void testBasicWithLROperand() { + + int numSteps = 3; + + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + + float[] zeros = {0.0F, 0.0F}; + FloatNdArray m0 = NdArrays.vectorOf(zeros); + FloatNdArray v0 = NdArrays.vectorOf(zeros); + FloatNdArray m1 = NdArrays.vectorOf(zeros); + FloatNdArray v1 = NdArrays.vectorOf(zeros); + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + float epsilon1 = 1e-3f; + float learningRate = 1f; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Adamax instance = + new Adamax( + session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(1e-3f)))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* build the GradsAnvVars */ + List<GradAndVar<? 
extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") + Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable<TFloat32>[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + session.setEpsilon(epsilon1); + for (int step = 0; step < numSteps; step++) { + // Test powers + final float beta1Power = (float) Math.pow(BETA_ONE_DEFAULT, step + 1); + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(beta1Power, f.getFloat(), epsilon1)); + } + session.run(update, instance.getFeedMap()); + + FloatNdArray[] resultNP = calculate(var0Np, grads0Np, step, m0, v0); + var0Np = resultNP[VAR]; + m0 = resultNP[M]; + v0 = resultNP[V]; + + resultNP = calculate(var1Np, grads1Np, step, m1, v1); + var1Np = resultNP[VAR]; + m1 = resultNP[M]; + v1 = resultNP[V]; + + // evaluate var0 and var1 + + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + } + } + } + } + + @Test + public void testWithLearningRateDecay() { + + float epsilon = 1e-6f; + float epsilon1 = 1e-3F; + int numSteps = 3; + float learningRate = 0.001F; + + try (TestSession session = TestSession.createTestSession(tfMode); + Adamax instance = new Adamax(session.getGraph(), learningRate)) { + Ops tf = session.getTF(); + float[] zeros = {0.0F, 0.0F}; + FloatNdArray m0 = NdArrays.vectorOf(zeros); + FloatNdArray v0 = NdArrays.vectorOf(zeros); + FloatNdArray m1 = NdArrays.vectorOf(zeros); + FloatNdArray v1 = NdArrays.vectorOf(zeros); + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + 
Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") + Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable<TFloat32>[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + session.setEpsilon(epsilon1); + for (int step = 0; step < numSteps; step++) { + final float betaPower = (float) Math.pow(BETA_ONE_DEFAULT, step + 1); + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(betaPower, f.getFloat(), epsilon1)); + } + assertEquals(learningRate, instance.getLearningRate(), epsilon); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(update, instance.getFeedMap()); + + FloatNdArray[] resultNP = calculate(var0Np, grads0Np, step, m0, v0, learningRate); + var0Np = resultNP[VAR]; + m0 = resultNP[M]; + v0 = resultNP[V]; + + resultNP = calculate(var1Np, grads1Np, step, m1, v1, learningRate); + var1Np = resultNP[VAR]; + m1 = resultNP[M]; + v1 = resultNP[V]; + + // evaluate var0 and var1 + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + + learningRate *= 0.9F; + instance.setLearningRate(learningRate); + } + } + } + private FloatNdArray[] calculate( FloatNdArray varNp, FloatNdArray gradsNp, int step, FloatNdArray m, FloatNdArray v) { - float alpha = 0.001F; + return calculate(varNp, gradsNp, step, m, v, 0.001F); + } + + private FloatNdArray[] calculate( + FloatNdArray varNp, + FloatNdArray gradsNp, + int step, + FloatNdArray m, + FloatNdArray v, + float alpha) { + float beta1 = BETA_ONE_DEFAULT; + float beta2 = BETA_TWO_DEFAULT; float espilon = 1e-8F; - float oneMinusBeta1 = 1.F - BETA_ONE_DEFAULT; - float oneMinusBeta1Pow = 1.F - (float) Math.pow(BETA_ONE_DEFAULT, step + 1); + float oneMinusBeta1 = 1.F - beta1; + float oneMinusBeta1Pow = 1.F - (float) Math.pow(beta1, step + 1); float alpha1 = alpha / oneMinusBeta1Pow; - // beta1 * m + (1 - beta1) * gT; - m = ND.add(ND.mul(BETA_ONE_DEFAULT, m), ND.mul(oneMinusBeta1, gradsNp)); - // np.maximum(BETA_TWO_DEFAULT * v, np.abs(gT)) - v = 
ND.max(ND.mul(BETA_TWO_DEFAULT, v), ND.abs(gradsNp)); - // paramT = param - (alpha / (1 - beta1**(t + 1))) * (mT / (vT + epsilon)) + // beta1 * m + (1 - beta1) * g_t; + m = ND.add(ND.mul(beta1, m), ND.mul(oneMinusBeta1, gradsNp)); + // np.maximum(beta2 * v, np.abs(g_t)) + v = ND.max(ND.mul(beta2, v), ND.abs(gradsNp)); + // param_t = param - (alpha / (1 - beta1**(t + 1))) * (m_t / (v_t + epsilon)) varNp = ND.sub(varNp, ND.mul(alpha1, ND.div(m, ND.add(v, espilon)))); FloatNdArray[] result = new FloatNdArray[3]; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java index 597f8e52bcd..c047cd4bf40 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java @@ -67,12 +67,22 @@ public void testFtrlWithL1L2L2Shrinkage() { float[] var1Init = {4.0F, 3.0F}; float[] grads0Init = {0.1F, 0.2F}; float[] grads1Init = {0.01F, 0.02F}; + float learningRate = 3.0F; int numSteps = 10; - try (TestSession session = TestSession.createTestSession(tfMode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Ftrl instance = + new Ftrl( + session.getGraph(), + learningRate, + -0.5F, // learningRatePower + 0.1F, // initialAccumulatorValue + 0.001F, // l1RegularizationStrength + 2.0F, // l2RegularizationStrength + 0.1F // l2ShrinkageRegularizationStrength + )) { Ops tf = session.getTF(); - Graph graph = session.getGraph(); Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); @@ -85,19 +95,6 @@ public void testFtrlWithL1L2L2Shrinkage() { Constant<TFloat32> grads0 = tf.constant(grads0Init); Constant<TFloat32> grads1 = tf.constant(grads1Init); - float learningRate = 3.0F; - - Ftrl instance = - new Ftrl( - graph, - learningRate, - -0.5F, // learningRatePower - 0.1F, // initialAccumulatorValue - 0.001F, // l1RegularizationStrength - 2.0F, // l2RegularizationStrength - 0.1F // l2ShrinkageRegularizationStrength - ); - /* build the GradsAnvVars */ List<Optimizer.GradAndVar<? 
extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); @@ -116,7 +113,7 @@ public void testFtrlWithL1L2L2Shrinkage() { session.evaluate(var1Init, var1); for (int i = 0; i < numSteps; i++) { - session.run(ftrlUpdate); + session.run(ftrlUpdate, instance.getFeedMap()); } float[] expectedVar0 = {-0.22578995F, -0.44345796F}; @@ -126,6 +123,69 @@ public void testFtrlWithL1L2L2Shrinkage() { } } + @Test + public void testFtrlWithL1L2L2ShrinkageWithLROperand() { + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {4.0F, 3.0F}; + float[] grads0Init = {0.1F, 0.2F}; + float[] grads1Init = {0.01F, 0.02F}; + float learningRate = 1.0F; + + int numSteps = 10; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Ftrl instance = + new Ftrl( + session.getGraph(), + tf.math.mul(tf.constant(learningRate), tf.constant(3f)), + -0.5F, // learningRatePower + 0.1F, // initialAccumulatorValue + 0.001F, // l1RegularizationStrength + 2.0F, // l2RegularizationStrength + 0.1F // l2ShrinkageRegularizationStrength + )) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op ftrlUpdate = instance.applyGradients(gradsAndVars, "FtrlTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + for (int i = 0; i < numSteps; i++) { + session.run(ftrlUpdate, instance.getFeedMap()); + } + + float[] expectedVar0 = {-0.22578995F, -0.44345796F}; + session.evaluate(expectedVar0, var0); + float[] expectedVar1 = {-0.14378493F, -0.13229476F}; + session.evaluate(expectedVar1, var1); + } + } + } + @Test public void testFtrlWithL1() { float[] var0Init = {1.0F, 2.0F}; @@ -181,7 +241,7 @@ public void testFtrlWithL1() { session.evaluate(var1Init, var1); for (int i = 0; i < numSteps; i++) { - session.run(ftrlUpdate); + session.run(ftrlUpdate, instance.getFeedMap()); } float[] expectedVar0 = {-7.66718769F, -10.91273689F}; @@ -247,7 +307,7 @@ public void testFtrlWithL1L2() { session.evaluate(var1Init, var1); for (int i = 0; i < numSteps; i++) { - session.run(ftrlUpdate); + session.run(ftrlUpdate, instance.getFeedMap()); } float[] expectedVar0 = {-0.24059935F, -0.46829352F}; @@ -258,6 +318,86 @@ public void testFtrlWithL1L2() { } } + @Test + public void testChangingLearningRate() { + float learningRate = 3.0F; + float epsilon = 1e-8f; + try (TestSession session = TestSession.createTestSession(tfMode); + Ftrl instance = + new Ftrl( + session.getGraph(), + learningRate, + Ftrl.LEARNING_RATE_POWER_DEFAULT, + 0.1F, + 0.001F, + 2.0F, + Ftrl.L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT)) { 
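+ // decay the learning rate by 10x each step and check that getLearningRate() and the learning-rate Operand (fed via getFeedMap()) both track the change before applying the next update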
+ Ops tf = session.getTF(); + int numSteps = 10; + + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {4.0F, 3.0F}; + float[] grads0Init = {0.1F, 0.2F}; + float[] grads1Init = {0.01F, 0.02F}; + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op ftrlUpdate = instance.applyGradients(gradsAndVars, "FtrlTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + // initialize the accumulators + session.run(tf.init()); + float expected[][][] = { + // Step: 0 + {{-0.022833f, -0.038881f}, {-0.002141f, -0.004474f}}, + // Step: 1 + {{-0.203440f, -0.336012f}, {-0.021145f, -0.042048f}}, + // Step: 2 + {{-0.667457f, -0.962730f}, {-0.074745f, -0.147685f}}, + // Step: 3 + {{-0.854973f, -1.163154f}, {-0.099466f, -0.196172f}}, + // Step: 4 + {{-0.878825f, -1.186153f}, {-0.102859f, -0.202808f}}, + // Step: 5 + {{-0.881205f, -1.188360f}, {-0.103211f, -0.203495f}}, + // Step: 6 + {{-0.881436f, -1.188569f}, {-0.103246f, -0.203564f}}, + // Step: 7 + {{-0.881459f, -1.188589f}, {-0.103250f, -0.203571f}}, + // Step: 8 + {{-0.881461f, -1.188591f}, {-0.103250f, -0.203572f}}, + // Step: 9 + {{-0.881462f, -1.188591f}, {-0.103250f, -0.203572f}}, + }; + for (int i = 0; i < numSteps; i++) { + assertEquals(learningRate, instance.getLearningRate(), epsilon); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(ftrlUpdate, instance.getFeedMap()); + session.evaluate(expected[i][0], var0); + session.evaluate(expected[i][1], var1); + learningRate *= 0.1f; + instance.setLearningRate(learningRate); + } + } + } + @Test public void doTestFtrlwithoutRegularization() { float[] var0Init = {0.0F, 0.0F}; @@ -266,10 +406,11 @@ public void doTestFtrlwithoutRegularization() { float[] grads1Init = {0.01F, 0.02F}; int numSteps = 3; + float learningRate = 3.0f; - try (TestSession session = TestSession.createTestSession(tfMode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Ftrl instance = new Ftrl(session.getGraph(), learningRate)) { Ops tf = session.getTF(); - Graph graph = session.getGraph(); Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); @@ -282,10 +423,6 @@ public void doTestFtrlwithoutRegularization() { Constant<TFloat32> grads0 = tf.constant(grads0Init); Constant<TFloat32> grads1 = tf.constant(grads1Init); - float learningRate = 3.0F; - - Ftrl instance = new Ftrl(graph, learningRate); - /* build the GradsAnvVars */ List<Optimizer.GradAndVar<? 
extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); @@ -303,7 +440,7 @@ public void doTestFtrlwithoutRegularization() { session.evaluate(var1Init, var1); for (int i = 0; i < numSteps; i++) { - session.run(ftrlUpdate); + session.run(ftrlUpdate, instance.getFeedMap()); } float[] expectedVar0 = {-2.60260963F, -4.29698515F}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index 4362c54d815..ce687186994 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -55,9 +55,9 @@ public void testBasic() { float[] grads1Init = {0.01F, 0.01F}; float learningRate = 3.0F; - try (TestSession session = TestSession.createTestSession(tfMode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + GradientDescent instance = new GradientDescent(session.getGraph(), learningRate)) { Ops tf = session.getTF(); - Graph graph = session.getGraph(); Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); @@ -75,7 +75,6 @@ public void testBasic() { gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - GradientDescent instance = new GradientDescent(graph, learningRate); Op update = instance.applyGradients(gradsAndVars, "SGDTest"); /* initialize the local variables */ @@ -89,7 +88,7 @@ public void testBasic() { session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); - session.run(update); // 1 step + session.run(update, instance.getFeedMap()); // 1 step float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F}; float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F}; @@ -97,4 +96,126 @@ public void testBasic() { session.evaluate(expectedVar1, var1); } } + + @Test + public void testBasicWithLROperand() { + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + float learningRate = 1.5f; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (GradientDescent instance = + new GradientDescent( + session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(2f)))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? 
extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "SGDTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + session.run(update, instance.getFeedMap()); // 1 step + + float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F}; + float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F}; + session.evaluate(expectedVar0, var0); + session.evaluate(expectedVar1, var1); + } + } + } + + @Test + public void testWithLearningRateDecay() { + int numSteps = 2; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + + float learningRate = 3.0F; + + try (TestSession session = TestSession.createTestSession(tfMode); + GradientDescent instance = new GradientDescent(session.getGraph(), learningRate)) { + Ops tf = session.getTF(); + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? 
extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "GradientDescentTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + // initialize the accumulators + session.run(tf.init()); + + // make sure the variables were initialized properly + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + float[][] expectedVar0 = { + {0.7F, 1.7F}, + {0.66999996F, 1.6700001F}, + {0.66699994F, 1.667F}, + {0.66669995F, 1.6667F}, + {0.66666996F, 1.66667F} + }; + float[][] expectedVar1 = { + {2.97F, 3.97F}, + {2.967F, 3.967F}, + {2.9667F, 3.9667F}, + {2.96667F, 3.96667F}, + {2.966667F, 3.966667F} + }; + for (int step = 0; step < numSteps; step++) { + assertEquals(learningRate, instance.getLearningRate(), 1e-6f); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(update, instance.getFeedMap()); + session.evaluate(expectedVar0[step], var0); + session.evaluate(expectedVar1[step], var1); + learningRate *= 0.1; + instance.setLearningRate(learningRate); + } + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java index bcfff97773d..014c72e55e6 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java @@ -71,9 +71,9 @@ public void testBasic() { float[] grads1Init = {0.01F, 0.01F}; float learningRate = 3.0F; - try (TestSession session = TestSession.createTestSession(tfMode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Momentum instance = new Momentum(session.getGraph(), learningRate)) { Ops tf = session.getTF(); - Graph graph = session.getGraph(); Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); @@ -91,7 +91,6 @@ public void testBasic() { gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - Momentum instance = new Momentum(graph, learningRate); Op update = instance.applyGradients(gradsAndVars, "SGDTest"); /* initialize the local variables */ @@ -105,7 +104,7 @@ public void testBasic() { session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); - session.run(update); // 1 step + session.run(update, instance.getFeedMap()); // 1 step float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F}; float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F}; @@ -114,6 +113,57 @@ public void testBasic() { } } + @Test + public void testBasicWithLROperand() { + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + float learningRate = 3.0F; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Momentum instance = new Momentum(session.getGraph(), tf.constant(learningRate))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + 
Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "SGDTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + session.run(update, instance.getFeedMap()); // 1 step + + float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F}; + float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F}; + session.evaluate(expectedVar0, var0); + session.evaluate(expectedVar1, var1); + } + } + } + @Test public void testMomentum() { float[] var0Init = {1.0F, 2.0F}; @@ -124,9 +174,9 @@ public void testMomentum() { float learningRate = 2.0F; float momentum = 0.9F; - try (TestSession session = TestSession.createTestSession(tfMode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Momentum instance = new Momentum(session.getGraph(), learningRate, momentum)) { Ops tf = session.getTF(); - Graph graph = session.getGraph(); Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); @@ -144,7 +194,6 @@ public void testMomentum() { gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - Momentum instance = new Momentum(graph, learningRate, momentum); Op update = instance.applyGradients(gradsAndVars, "SGDTest"); Variable<TFloat32> momentumSlot0 = instance.getSlot(var0.asOutput(), MOMENTUM).get(); @@ -163,7 +212,7 @@ public void testMomentum() { session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); - session.run(update); // 1 step + session.run(update, instance.getFeedMap()); // 1 step float[] expectedMomentum0 = {0.1F, 0.1F}; float[] expectedMomentum1 = {0.01F, 0.01F}; @@ -175,7 +224,7 @@ public void testMomentum() { session.evaluate(expectedVar0, var0); session.evaluate(expectedVar1, var1); - session.run(update); // step 2 + session.run(update, instance.getFeedMap()); // step 2 float[] expectedMomentum02 = {(0.9f * 0.1f + 0.1f), (0.9f * 0.1f + 0.1f)}; float[] expectedMomentum12 = {(0.9f * 0.01f + 0.01f), (0.9f * 0.01f + 0.01f)}; @@ -193,4 +242,78 @@ public void testMomentum() { session.evaluate(expectedVar12, var1); } } + + @Test + public void testWithLearningRateDecay() { + int numSteps = 2; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + + float learningRate = 3.0F; + + try (TestSession session = TestSession.createTestSession(tfMode); + Momentum instance = new Momentum(session.getGraph(), learningRate)) { + Ops tf = session.getTF(); + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = 
tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "MomentumTest"); + + Variable<TFloat32> momentumSlot0 = instance.getSlot(var0.asOutput(), MOMENTUM).get(); + assertEquals(momentumSlot0.asOutput().shape(), var0.asOutput().shape()); + Variable<TFloat32> momentumSlot1 = instance.getSlot(var1.asOutput(), MOMENTUM).get(); + assertEquals(momentumSlot1.asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + // initialize the accumulators + session.run(tf.init()); + + // make sure the variables were initialized properly + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + float[][] expectedVar0 = { + {0.7F, 1.7F}, + {0.66999996F, 1.6700001F}, + {0.66699994F, 1.667F}, + {0.66669995F, 1.6667F}, + {0.66666996F, 1.66667F} + }; + float[][] expectedVar1 = { + {2.97F, 3.97F}, + {2.967F, 3.967F}, + {2.9667F, 3.9667F}, + {2.96667F, 3.96667F}, + {2.966667F, 3.966667F} + }; + for (int step = 0; step < numSteps; step++) { + assertEquals(learningRate, instance.getLearningRate(), 1e-6); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(update, instance.getFeedMap()); + session.evaluate(expectedVar0[step], var0); + session.evaluate(expectedVar1[step], var1); + learningRate *= 0.1; + instance.setLearningRate(learningRate); + } + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java index a583d74246b..0832543c104 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; -import org.tensorflow.Graph; import org.tensorflow.Tensor; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; @@ -62,9 +61,9 @@ public void tearDown() {} /** Test of getOptimizerName method, of class Nadam. 
*/ @Test public void testGetOptimizerName() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - Nadam instance = new Nadam(graph); + try (TestSession session = TestSession.createTestSession(tfMode); + Nadam instance = new Nadam(session.getGraph())) { + String expResult = "Nadam"; String result = instance.getOptimizerName(); assertEquals(expResult, result); @@ -96,9 +95,9 @@ public void testBasic() { float epsilon1 = 1e-3F; - try (TestSession session = TestSession.createTestSession(tfMode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Nadam instance = new Nadam(session.getGraph())) { Ops tf = session.getTF(); - Graph graph = session.getGraph(); Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); @@ -111,7 +110,6 @@ public void testBasic() { Constant<TFloat32> grads0 = tf.constant(grads0Init); Constant<TFloat32> grads1 = tf.constant(grads1Init); - Nadam instance = new Nadam(graph); /* build the GradsAnvVars */ List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); @@ -120,7 +118,9 @@ public void testBasic() { Op update = instance.applyGradients(gradsAndVars, "AdamTest"); /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") Variable<TFloat32>[] secondMomentSlots = new Variable[2]; firstMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.FIRST_MOMENT).get(); @@ -161,7 +161,7 @@ public void testBasic() { for (int step = 0; step < numSteps; step++) { - session.run(update); + session.run(update, instance.getFeedMap()); float mut = Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); @@ -203,6 +203,280 @@ public void testBasic() { } } + /** Test of applyDense method, of class Nadam. */ + @Test + public void testBasicWithLROperand() { + + int numSteps = 3; + + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; + FloatNdArray m0 = NdArrays.vectorOf(zeros); + FloatNdArray v0 = NdArrays.vectorOf(zeros); + FloatNdArray m1 = NdArrays.vectorOf(zeros); + FloatNdArray v1 = NdArrays.vectorOf(zeros); + FloatNdArray mcache = NdArrays.vectorOf(ones); + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + float epsilon1 = 1e-3F; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Nadam instance = + new Nadam(session.getGraph(), tf.math.mul(tf.constant(1f), tf.constant(1e-3f)))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? 
extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") + Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable<TFloat32>[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + session.setEpsilon(epsilon1); + + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); + } + momentum = 1F; + + for (int step = 0; step < numSteps; step++) { + + session.run(update, instance.getFeedMap()); + + float mut = + Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); + momentum = momentum * mut; + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); + } + mcache = ND.mul(mcache, momentum); + FloatNdArray[] resultsNP = nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache); + var0Np = resultsNP[VAR]; + m0 = resultsNP[M]; + v0 = resultsNP[V]; + + resultsNP = nadamUpdateNdArray(var1Np, grads1Np, step, m1, v1, mcache); + var1Np = resultsNP[VAR]; + m1 = resultsNP[M]; + v1 = resultsNP[V]; + + // evaluate m0 and m1 + session.evaluate(m0, firstMomentSlots[0]); + session.evaluate(m1, firstMomentSlots[1]); + + // evaluate v0 and v1 + session.evaluate(v0, secondMomentSlots[0]); + session.evaluate(v1, secondMomentSlots[1]); + + // evaluate var0 and var1 + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + } + } + } + } + + @Test + public void testWithLearningRateDecay() { + int numSteps = 3; + + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; + FloatNdArray m0 = NdArrays.vectorOf(zeros); + FloatNdArray v0 = NdArrays.vectorOf(zeros); + FloatNdArray m1 = NdArrays.vectorOf(zeros); + FloatNdArray v1 = NdArrays.vectorOf(zeros); + FloatNdArray mcache = NdArrays.vectorOf(ones); + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = 
NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + float epsilon1 = 1e-3F; + + float learningRate = 0.001F; + + try (TestSession session = TestSession.createTestSession(tfMode); + Nadam instance = new Nadam(session.getGraph(), learningRate)) { + Ops tf = session.getTF(); + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") + Variable<TFloat32>[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable<TFloat32>[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + // initialize the accumulators + session.run(tf.init()); + + session.setEpsilon(epsilon1); + + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); + } + momentum = 1F; + + for (int step = 0; step < numSteps; step++) { + assertEquals(learningRate, instance.getLearningRate(), 1e-6f); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(update, instance.getFeedMap()); + + float mut = + Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); + momentum = momentum * mut; + + try (Tensor<TFloat32> result = + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); + } + mcache = ND.mul(mcache, momentum); + FloatNdArray[] resultsNP = + nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache, learningRate); + var0Np = resultsNP[VAR]; + m0 = resultsNP[M]; + v0 = resultsNP[V]; + + resultsNP = nadamUpdateNdArray(var1Np, grads1Np, step, m1, v1, mcache, 
learningRate); + var1Np = resultsNP[VAR]; + m1 = resultsNP[M]; + v1 = resultsNP[V]; + + // evaluate m0 and m1 + session.evaluate(m0, firstMomentSlots[0]); + session.evaluate(m1, firstMomentSlots[1]); + + // evaluate v0 and v1 + session.evaluate(v0, secondMomentSlots[0]); + session.evaluate(v1, secondMomentSlots[1]); + + // evaluate var0 and var1 + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + + learningRate *= 0.9; + instance.setLearningRate(learningRate); + } + } + } + private FloatNdArray[] nadamUpdateNdArray( FloatNdArray varNp, FloatNdArray gradsNp, @@ -210,8 +484,18 @@ private FloatNdArray[] nadamUpdateNdArray( FloatNdArray m, FloatNdArray v, FloatNdArray mCache) { + return nadamUpdateNdArray(varNp, gradsNp, t, m, v, mCache, 0.001F); + } + + private FloatNdArray[] nadamUpdateNdArray( + FloatNdArray varNp, + FloatNdArray gradsNp, + int t, + FloatNdArray m, + FloatNdArray v, + FloatNdArray mCache, + float alpha) { - float alpha = 0.001F; float beta1 = 0.9F; float beta2 = 0.999F; float epsilon = 1e-8F; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java index 78b56c8289e..a0bf027abab 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java @@ -1,7 +1,6 @@ package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; -import org.tensorflow.Graph; import org.tensorflow.framework.utils.TestSession; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -26,9 +25,8 @@ public void tearDown() {} /** Test ADADELTA enum */ @Test public void testADADELTA() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - Optimizer instance = Optimizers.ADADELTA.createOptimizer(graph); + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADADELTA.createOptimizer(session.getGraph())) { String expResult = "Adadelta"; String result = instance.getOptimizerName(); assertEquals(expResult, result); @@ -38,9 +36,8 @@ public void testADADELTA() { /** Test ADAGRAD enum */ @Test public void testADAGRAD() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - Optimizer instance = Optimizers.ADAGRAD.createOptimizer(graph); + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADAGRAD.createOptimizer(session.getGraph())) { String expResult = "Adagrad"; String result = instance.getOptimizerName(); assertEquals(expResult, result); @@ -50,9 +47,8 @@ public void testADAGRAD() { /** Test ADAGRAD_DA enum */ @Test public void testADAGRAD_DA() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - Optimizer instance = Optimizers.ADAGRAD_DA.createOptimizer(graph); + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADAGRAD_DA.createOptimizer(session.getGraph())) { String expResult = "adagrad-da"; String result = instance.getOptimizerName(); assertEquals(expResult, result); @@ -62,9 +58,8 @@ public void testADAGRAD_DA() { /** Test ADAM enum */ @Test public void testADAM() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - Optimizer 
instance = Optimizers.ADAM.createOptimizer(graph); + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADAM.createOptimizer(session.getGraph())) { String expResult = "Adam"; String result = instance.getOptimizerName(); assertEquals(expResult, result); @@ -74,9 +69,8 @@ public void testADAM() { /** Test ADAMAX enum */ @Test public void testADAMAX() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - Optimizer instance = Optimizers.ADAMAX.createOptimizer(graph); + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADAMAX.createOptimizer(session.getGraph())) { String expResult = "Adamax"; String result = instance.getOptimizerName(); assertEquals(expResult, result); @@ -86,9 +80,8 @@ public void testADAMAX() { /** Test FTRL enum */ @Test public void testFTRL() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - Optimizer instance = Optimizers.FTRL.createOptimizer(graph); + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.FTRL.createOptimizer(session.getGraph())) { String expResult = "Ftrl"; String result = instance.getOptimizerName(); assertEquals(expResult, result); @@ -98,9 +91,8 @@ public void testFTRL() { /** Test NADAM enum */ @Test public void testNADAM() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - Optimizer instance = Optimizers.NADAM.createOptimizer(graph); + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.NADAM.createOptimizer(session.getGraph())) { String expResult = "Nadam"; String result = instance.getOptimizerName(); assertEquals(expResult, result); @@ -110,9 +102,8 @@ public void testNADAM() { /** Test RMSPROP enum */ @Test public void testRMSPROP() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - Optimizer instance = Optimizers.RMSPROP.createOptimizer(graph); + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.RMSPROP.createOptimizer(session.getGraph())) { String expResult = "RMSProp"; String result = instance.getOptimizerName(); assertEquals(expResult, result); @@ -122,9 +113,8 @@ public void testRMSPROP() { /** Test MOMENTUM enum */ @Test public void testMOMENTUM() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - Optimizer instance = Optimizers.MOMENTUM.createOptimizer(graph); + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.MOMENTUM.createOptimizer(session.getGraph())) { String expResult = "Momentum"; String result = instance.getOptimizerName(); assertEquals(expResult, result); @@ -134,9 +124,8 @@ public void testMOMENTUM() { /** Test GRADIENT_DESCENT enum */ @Test public void testGRADIENT_DESCENT() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - Optimizer instance = Optimizers.GRADIENT_DESCENT.createOptimizer(graph); + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.GRADIENT_DESCENT.createOptimizer(session.getGraph())) { String expResult = "GradientDescent"; String result = instance.getOptimizerName(); assertEquals(expResult, result); diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java index 202fb21ef68..711358e9222 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; -import org.tensorflow.Graph; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -32,6 +31,8 @@ import java.util.ArrayList; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; import static org.tensorflow.framework.optimizers.RMSProp.*; /** Test cases for RMSProp Optimizer */ @@ -42,7 +43,7 @@ public class RMSPropTest { final int MOM_T = 3; private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; - Object[][] TestParamValues = { + Object[][] testParamValues = { // learningRate, rho (decay), momentum, epsilon, centered {0.05F, 0.9F, 0.0F, 1e-3F, true}, {0.05F, 0.9F, 0.0F, 1e-3F, false}, @@ -70,10 +71,18 @@ public void testDense() { int numSteps = 3; - for (Object[] testParamValue : TestParamValues) { - try (TestSession session = TestSession.createTestSession(tfMode)) { + for (Object[] testParamValue : testParamValues) { + // learningRate, rho (decay), momentum, epsilon, centered + float learningRate = (float) testParamValue[0]; + float decay = (float) testParamValue[1]; + float momentum = (float) testParamValue[2]; + float epsilon = (float) testParamValue[3]; + boolean centered = (boolean) testParamValue[4]; + try (TestSession session = TestSession.createTestSession(tfMode); + RMSProp instance = + new RMSProp(session.getGraph(), learningRate, decay, momentum, epsilon, centered)) { Ops tf = session.getTF(); - Graph graph = session.getGraph(); + session.setEpsilon(1e-2f); float[] var0Init = {1.0F, 2.0F}; float[] var1Init = {3.0F, 4.0F}; @@ -96,15 +105,6 @@ public void testDense() { Constant<TFloat32> grads0 = tf.constant(grads0Init); Constant<TFloat32> grads1 = tf.constant(grads1Init); - // learningRate, rho (decay), momentum, epsilon, centered - float learningRate = (float) (float) testParamValue[0]; - float decay = (float) testParamValue[1]; - float momentum = (float) testParamValue[2]; - float epsilon = (float) testParamValue[3]; - boolean centered = (boolean) testParamValue[4]; - - RMSProp instance = new RMSProp(graph, learningRate, decay, momentum, epsilon, centered); - /* build the GradsAnvVars */ List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput())); @@ -123,14 +123,30 @@ public void testDense() { session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); - Variable<TFloat32> mg0 = centered ? instance.getSlot(var0.asOutput(), MG).get() : null; - Variable<TFloat32> mg1 = centered ? instance.getSlot(var1.asOutput(), MG).get() : null; + Variable<TFloat32> mg0 = + centered && instance.getSlot(var0.asOutput(), MG).isPresent() + ? instance.getSlot(var0.asOutput(), MG).get() + : null; + Variable<TFloat32> mg1 = + centered && instance.getSlot(var1.asOutput(), MG).isPresent() + ? instance.getSlot(var1.asOutput(), MG).get() + : null; Variable<TFloat32> mom0 = - momentum > 0.F ? 
instance.getSlot(var0.asOutput(), MOMENTUM).get() : null; + momentum > 0.F && instance.getSlot(var0.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var0.asOutput(), MOMENTUM).get() + : null; Variable<TFloat32> mom1 = - momentum > 0.F ? instance.getSlot(var1.asOutput(), MOMENTUM).get() : null; - Variable<TFloat32> rms0 = instance.getSlot(var0.asOutput(), RMS).get(); - Variable<TFloat32> rms1 = instance.getSlot(var1.asOutput(), RMS).get(); + momentum > 0.F && instance.getSlot(var1.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var1.asOutput(), MOMENTUM).get() + : null; + Variable<TFloat32> rms0 = + instance.getSlot(var0.asOutput(), RMS).isPresent() + ? instance.getSlot(var0.asOutput(), RMS).get() + : null; + Variable<TFloat32> rms1 = + instance.getSlot(var1.asOutput(), RMS).isPresent() + ? instance.getSlot(var1.asOutput(), RMS).get() + : null; float[] zeros = {0.0F, 0.0F}; float[] ones = {1.0F, 1.0F}; // temp to match RMSProp @@ -142,7 +158,7 @@ public void testDense() { FloatNdArray mom1Np = NdArrays.vectorOf(zeros); for (int i = 0; i < numSteps; i++) { - session.run(update); + session.run(update, instance.getFeedMap()); FloatNdArray[] result0 = calc( var0Np, @@ -179,16 +195,18 @@ public void testDense() { mom1Np = result1[MOM_T]; if (centered) { - session.evaluate(mg0Np, mg0); - session.evaluate(mg1Np, mg1); + if (mg0 != null) session.evaluate(mg0Np, mg0); + if (mg1 != null) session.evaluate(mg1Np, mg1); } if (mom0 != null) session.evaluate(mom0Np, mom0); if (mom1 != null) session.evaluate(mom1Np, mom1); /* TODO the values returned from rms slot, do not match what I see in the python test */ - session.evaluate(rms0Np, rms0); - session.evaluate(rms1Np, rms1); + if (rms0 != null) session.evaluate(rms0Np, rms0); + else fail("rms0 is null"); + if (rms1 != null) session.evaluate(rms1Np, rms1); + else fail("rms1 is null"); session.evaluate(var0Np, var0); session.evaluate(var1Np, var1); @@ -197,6 +215,319 @@ public void testDense() { } } + @Test + public void testDenseWithLROperand() { + + int numSteps = 3; + + for (Object[] testParamValue : testParamValues) { + // learningRate, rho (decay), momentum, epsilon, centered + float learningRate = (float) testParamValue[0]; + float decay = (float) testParamValue[1]; + float momentum = (float) testParamValue[2]; + float epsilon = (float) testParamValue[3]; + boolean centered = (boolean) testParamValue[4]; + try (TestSession session = TestSession.createTestSession(tfMode)) { + + Ops tf = session.getTF(); + try (RMSProp instance = + new RMSProp( + session.getGraph(), + tf.math.add(tf.constant(learningRate), tf.constant(0f)), + decay, + momentum, + epsilon, + centered)) { + + session.setEpsilon(1e-2f); + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.2F}; + float[] grads1Init = {0.01F, 0.2F}; + + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant<TFloat32> grads0 = tf.constant(grads0Init); + Constant<TFloat32> 
grads1 = tf.constant(grads1Init); + + /* build the GradsAndVars */ + List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + Variable<TFloat32> mg0 = + centered && instance.getSlot(var0.asOutput(), MG).isPresent() + ? instance.getSlot(var0.asOutput(), MG).get() + : null; + Variable<TFloat32> mg1 = + centered && instance.getSlot(var1.asOutput(), MG).isPresent() + ? instance.getSlot(var1.asOutput(), MG).get() + : null; + Variable<TFloat32> mom0 = + momentum > 0.F && instance.getSlot(var0.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var0.asOutput(), MOMENTUM).get() + : null; + Variable<TFloat32> mom1 = + momentum > 0.F && instance.getSlot(var1.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var1.asOutput(), MOMENTUM).get() + : null; + Variable<TFloat32> rms0 = + instance.getSlot(var0.asOutput(), RMS).isPresent() + ? instance.getSlot(var0.asOutput(), RMS).get() + : null; + Variable<TFloat32> rms1 = + instance.getSlot(var1.asOutput(), RMS).isPresent() + ? instance.getSlot(var1.asOutput(), RMS).get() + : null; + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; // temp to match RMSProp + FloatNdArray mg0Np = NdArrays.vectorOf(zeros); + FloatNdArray mg1Np = NdArrays.vectorOf(zeros); + FloatNdArray rms0Np = NdArrays.vectorOf(ones); + FloatNdArray rms1Np = NdArrays.vectorOf(ones); + FloatNdArray mom0Np = NdArrays.vectorOf(zeros); + FloatNdArray mom1Np = NdArrays.vectorOf(zeros); + + for (int i = 0; i < numSteps; i++) { + session.run(update, instance.getFeedMap()); + FloatNdArray[] result0 = + calc( + var0Np, + grads0Np, + mg0Np, + rms0Np, + mom0Np, + learningRate, + decay, + momentum, + epsilon, + centered); + var0Np = result0[VAR_T]; + mg0Np = result0[MG_T]; + rms0Np = result0[RMS_T]; + mom0Np = result0[MOM_T]; + + FloatNdArray[] result1 = + calc( + var1Np, + grads1Np, + mg1Np, + rms1Np, + mom1Np, + learningRate, + decay, + momentum, + epsilon, + centered); + + var1Np = result1[VAR_T]; + mg1Np = result1[MG_T]; + rms1Np = result1[RMS_T]; + mom1Np = result1[MOM_T]; + + if (centered) { + if (mg0 != null) session.evaluate(mg0Np, mg0); + if (mg1 != null) session.evaluate(mg1Np, mg1); + } + + if (mom0 != null) session.evaluate(mom0Np, mom0); + if (mom1 != null) session.evaluate(mom1Np, mom1); + + /* TODO the values returned from rms slot, do not match what I see in the python test */ + if (rms0 != null) session.evaluate(rms0Np, rms0); + else fail("rms0 is null"); + if (rms1 != null) session.evaluate(rms1Np, rms1); + else fail("rms1 is null"); + + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + } + } + } + } + + @Test + public void testWithLearningRateDecay() { + int numSteps = 3; + + for (Object[] testParamValue : testParamValues) { + float learningRate = (float) testParamValue[0]; + float decay = (float) testParamValue[1]; + float momentum = (float) testParamValue[2]; + float epsilon = (float) testParamValue[3]; + boolean centered = (boolean) testParamValue[4]; + + try (TestSession session =
TestSession.createTestSession(tfMode); + RMSProp instance = + new RMSProp(session.getGraph(), learningRate, decay, momentum, epsilon, centered)) { + Ops tf = session.getTF(); + session.setEpsilon(1e-2f); + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.2F}; + float[] grads1_init = {0.01F, 0.2F}; + + FloatNdArray var0_np = NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant<TFloat32> grads0 = tf.constant(grads0_init); + Constant<TFloat32> grads1 = tf.constant(grads1_init); + + /* build the GradsAndVars */ + List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + // initialize the accumulators + session.run(tf.init()); + + // make sure the variables were initialized properly + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + Variable<TFloat32> mg0 = + centered && instance.getSlot(var0.asOutput(), MG).isPresent() + ? instance.getSlot(var0.asOutput(), MG).get() + : null; + Variable<TFloat32> mg1 = + centered && instance.getSlot(var1.asOutput(), MG).isPresent() + ? instance.getSlot(var1.asOutput(), MG).get() + : null; + Variable<TFloat32> mom0 = + momentum > 0.F && instance.getSlot(var0.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var0.asOutput(), MOMENTUM).get() + : null; + Variable<TFloat32> mom1 = + momentum > 0.F && instance.getSlot(var1.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var1.asOutput(), MOMENTUM).get() + : null; + Variable<TFloat32> rms0 = + instance.getSlot(var0.asOutput(), RMS).isPresent() + ? instance.getSlot(var0.asOutput(), RMS).get() + : null; + Variable<TFloat32> rms1 = + instance.getSlot(var1.asOutput(), RMS).isPresent() + ?
instance.getSlot(var1.asOutput(), RMS).get() + : null; + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; // temp to match RMSProp + FloatNdArray mg0_np = NdArrays.vectorOf(zeros); + FloatNdArray mg1_np = NdArrays.vectorOf(zeros); + FloatNdArray rms0_np = NdArrays.vectorOf(ones); + FloatNdArray rms1_np = NdArrays.vectorOf(ones); + FloatNdArray mom0_np = NdArrays.vectorOf(zeros); + FloatNdArray mom1_np = NdArrays.vectorOf(zeros); + + for (int i = 0; i < numSteps; i++) { + assertEquals(learningRate, instance.getLearningRate(), epsilon); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(update, instance.getFeedMap()); + FloatNdArray[] result0 = + calc( + var0_np, + grads0_np, + mg0_np, + rms0_np, + mom0_np, + learningRate, + decay, + momentum, + epsilon, + centered); + var0_np = result0[VAR_T]; + mg0_np = result0[MG_T]; + rms0_np = result0[RMS_T]; + mom0_np = result0[MOM_T]; + + FloatNdArray[] result1 = + calc( + var1_np, + grads1_np, + mg1_np, + rms1_np, + mom1_np, + learningRate, + decay, + momentum, + epsilon, + centered); + + var1_np = result1[VAR_T]; + mg1_np = result1[MG_T]; + rms1_np = result1[RMS_T]; + mom1_np = result1[MOM_T]; + + if (centered) { + if (mg0 != null) session.evaluate(mg0_np, mg0); + else fail("mg0 is null"); + if (mg1 != null) session.evaluate(mg1_np, mg1); + else fail("mg1 is null"); + } + if (momentum > 0.F) { + if (mom0 != null) session.evaluate(mom0_np, mom0); + if (mom1 != null) session.evaluate(mom1_np, mom1); + } + + /* TODO the values returned from rms slot, do not match what I see in the python test */ + if (rms0 != null) session.evaluate(rms0_np, rms0); + else fail("rms0 is null"); + if (rms1 != null) session.evaluate(rms1_np, rms1); + else fail("rms1 is null"); + + session.evaluate(var0_np, var0); + session.evaluate(var1_np, var1); + + learningRate *= 0.9F; + instance.setLearningRate(learningRate); + } + } + } + } + FloatNdArray[] calc( FloatNdArray varNp, FloatNdArray gradNp, diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java index 9fb9885505c..33fb11fc31f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java @@ -16,27 +16,28 @@ import org.tensorflow.*; import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.IntNdArray; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.types.*; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; import java.io.PrintWriter; +import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Predicate; import static org.junit.jupiter.api.Assertions.*; -/** Eager Mode Test Session */ +/** An Eager Mode Test Session */ public class EagerTestSession extends TestSession { private final EagerSession session; private final Ops tf; - /** Create an Eager mode test session. 
*/ + /** Create an EagerTestSession */ public EagerTestSession() { this.session = EagerSession.create(); this.tf = Ops.create(session).withName("test"); @@ -48,15 +49,6 @@ public Ops getTF() { return tf; } - /** - * Get the TensorFlow EagerSession instance - * - * @return the TensorFlow EagerSession instance - */ - public EagerSession getSession() { - return session; - } - /** {@inheritDoc} */ @Override public void close() { @@ -83,9 +75,19 @@ public EagerSession getEagerSession() { /** {@inheritDoc} */ @Override - public <T extends TNumber> void evaluate(double expected, Operand<T> input) { - DataType<T> dtype = input.asOutput().dataType(); + public final void run(Op op, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + // Do nothing in EagerSession, run() only applies to Graph + } + + /** {@inheritDoc} */ + @Override + public <U extends TNumber> void evaluate( + double expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + DataType<U> dtype = input.asOutput().dataType(); if (dtype == TFloat32.DTYPE) { + @SuppressWarnings("unchecked") Operand<TFloat32> o = (Operand<TFloat32>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -96,6 +98,8 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { index.set(0); o.data().scalars().forEach(f -> assertEquals(expected, f.getFloat(), epsilon)); } else if (dtype == TFloat64.DTYPE) { + + @SuppressWarnings("unchecked") Operand<TFloat64> o = (Operand<TFloat64>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -106,6 +110,8 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { index.set(0); o.data().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); } else if (dtype == TInt32.DTYPE) { + + @SuppressWarnings("unchecked") Operand<TInt32> o = (Operand<TInt32>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -116,6 +122,8 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { index.set(0); o.data().scalars().forEach(f -> assertEquals((int) expected, f.getInt())); } else if (dtype == TInt64.DTYPE) { + + @SuppressWarnings("unchecked") Operand<TInt64> o = (Operand<TInt64>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -130,14 +138,19 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { /** {@inheritDoc} */ @Override - public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { + public <U extends TNumber> void evaluate( + Number[] expected, + Output<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { int size = input.shape().size() == 0 ?
1 : (int) input.shape().size(); assertEquals( expected.length, size, () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - DataType<T> dtype = input.dataType(); + DataType<U> dtype = input.dataType(); if (dtype == TFloat32.DTYPE) { + + @SuppressWarnings("unchecked") Output<TFloat32> o = (Output<TFloat32>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -153,6 +166,8 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { assertEquals( expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); } else if (dtype == TFloat64.DTYPE) { + + @SuppressWarnings("unchecked") Output<TFloat64> o = (Output<TFloat64>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -168,6 +183,8 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { assertEquals( expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); } else if (dtype == TInt32.DTYPE) { + + @SuppressWarnings("unchecked") Output<TInt32> o = (Output<TInt32>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -180,6 +197,8 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); } else if (dtype == TInt64.DTYPE) { + + @SuppressWarnings("unchecked") Output<TInt64> o = (Output<TInt64>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -196,9 +215,14 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { /** {@inheritDoc} */ @Override - public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { - DataType<T> dtype = input.dataType(); + public <U extends TNumber> void evaluate( + FloatNdArray expected, + Output<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + DataType<U> dtype = input.dataType(); if (dtype == TFloat32.DTYPE) { + + @SuppressWarnings("unchecked") Output<TFloat32> o = (Output<TFloat32>) input; AtomicLong index = new AtomicLong(); if (debug) { @@ -212,6 +236,8 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { .forEach( f -> assertEquals(expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); } else if (dtype == TFloat64.DTYPE) { + + @SuppressWarnings("unchecked") Output<TFloat64> o = (Output<TFloat64>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -226,6 +252,8 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { f -> assertEquals(expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); } else if (dtype == TInt32.DTYPE) { + + @SuppressWarnings("unchecked") Output<TInt32> o = (Output<TInt32>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -234,10 +262,12 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { .forEach(f -> System.out.printf("%d). 
%d\n", index.getAndIncrement(), f.getInt())); } index.set(0); - for (IntNdArray f : o.data().scalars()) { - assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt()); - } + o.data() + .scalars() + .forEach(f -> assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt())); } else if (dtype == TInt64.DTYPE) { + + @SuppressWarnings("unchecked") Output<TInt64> o = (Output<TInt64>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -255,11 +285,16 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { /** {@inheritDoc} */ @Override - public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predicate) { + public <U extends TNumber> void evaluate( + Output<U> input, + Predicate<Number> predicate, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { AtomicInteger index = new AtomicInteger(); - DataType<T> dtype = input.asOutput().dataType(); + DataType<U> dtype = input.asOutput().dataType(); boolean isScalar = input.shape().equals(Shape.scalar()); if (dtype == TFloat32.DTYPE) { + + @SuppressWarnings("unchecked") Output<TFloat32> o = (Output<TFloat32>) input; if (debug) { if (isScalar) { @@ -284,6 +319,8 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getFloat()))); } } else if (dtype == TFloat64.DTYPE) { + + @SuppressWarnings("unchecked") Output<TFloat64> o = (Output<TFloat64>) input; if (debug) { if (isScalar) { @@ -308,6 +345,8 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getDouble()))); } } else if (dtype == TInt32.DTYPE) { + + @SuppressWarnings("unchecked") Output<TInt32> o = (Output<TInt32>) input; if (debug) { if (isScalar) { @@ -332,6 +371,8 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getInt()))); } } else if (dtype == TInt64.DTYPE) { + + @SuppressWarnings("unchecked") Output<TInt64> o = (Output<TInt64>) input; if (debug) { if (isScalar) { @@ -362,7 +403,10 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic /** {@inheritDoc} */ @Override - public void evaluate(String[] expected, Output<TString> input) { + public void evaluate( + String[] expected, + Output<TString> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); assertEquals( expected.length, @@ -384,7 +428,10 @@ public void evaluate(String[] expected, Output<TString> input) { /** {@inheritDoc} */ @Override - public void evaluate(Boolean[] expected, Output<TBool> input) { + public void evaluate( + Boolean[] expected, + Output<TBool> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); assertEquals( expected.length, @@ -406,7 +453,10 @@ public void evaluate(Boolean[] expected, Output<TBool> input) { /** {@inheritDoc} */ @Override - public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { + public <T extends TType> void evaluate( + Output<T> expected, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedMap) { assert input.shape().equals(expected.shape()) : String.format( "expected shape (%s) != to input shape (%s)", @@ -414,7 +464,10 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { DataType<T> dtype = input.asOutput().dataType(); boolean isScalar = input.shape().equals(Shape.scalar()); if (dtype == TFloat32.DTYPE) { + + @SuppressWarnings("unchecked") Output<TFloat32> x = (Output<TFloat32>) expected; + @SuppressWarnings("unchecked") Output<TFloat32> o = (Output<TFloat32>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -440,7 +493,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { (idx, f) -> assertEquals(x.data().getFloat(idx), f.getFloat(), epsilon)); } } else if (dtype == TFloat64.DTYPE) { + @SuppressWarnings("unchecked") Output<TFloat64> x = (Output<TFloat64>) expected; + @SuppressWarnings("unchecked") Output<TFloat64> o = (Output<TFloat64>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -466,7 +521,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { (idx, f) -> assertEquals(x.data().getDouble(idx), f.getDouble(), epsilon)); } } else if (dtype == TInt32.DTYPE) { + @SuppressWarnings("unchecked") Output<TInt32> x = (Output<TInt32>) expected; + @SuppressWarnings("unchecked") Output<TInt32> o = (Output<TInt32>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -491,7 +548,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { .forEachIndexed((idx, f) -> assertEquals(x.data().getInt(idx), f.getInt())); } } else if (dtype == TInt64.DTYPE) { + @SuppressWarnings("unchecked") Output<TInt64> x = (Output<TInt64>) expected; + @SuppressWarnings("unchecked") Output<TInt64> o = (Output<TInt64>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -516,7 +575,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { .forEachIndexed((idx, f) -> assertEquals(x.data().getLong(idx), f.getLong())); } } else if (dtype == TString.DTYPE) { + @SuppressWarnings("unchecked") Output<TString> x = (Output<TString>) expected; + @SuppressWarnings("unchecked") Output<TString> o = (Output<TString>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -541,7 +602,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { .forEachIndexed((idx, f) -> assertEquals(x.data().getObject(idx), f.getObject())); } } else if (dtype == TBool.DTYPE) { + @SuppressWarnings("unchecked") Output<TBool> x = (Output<TBool>) expected; + @SuppressWarnings("unchecked") Output<TBool> o = (Output<TBool>) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -570,44 +633,53 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { /** {@inheritDoc} */ @Override - public <T extends TType> void print(PrintWriter writer, Output<T> input) { + public <T extends TType> void print( + PrintWriter writer, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { DataType<T> dtype = input.asOutput().dataType(); if (dtype == TFloat32.DTYPE) { + @SuppressWarnings("unchecked") Output<TFloat32> o = (Output<TFloat32>) input; AtomicInteger index = new AtomicInteger(); o.data() .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + .forEach(f -> writer.printf("%d). 
%f\n", index.getAndIncrement(), f.getFloat())); } else if (dtype == TFloat64.DTYPE) { + @SuppressWarnings("unchecked") Output<TFloat64> o = (Output<TFloat64>) input; AtomicInteger index = new AtomicInteger(); o.data() .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + .forEach(f -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } else if (dtype == TInt32.DTYPE) { + @SuppressWarnings("unchecked") Output<TInt32> o = (Output<TInt32>) input; AtomicInteger index = new AtomicInteger(); o.data() .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + .forEach(f -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } else if (dtype == TInt64.DTYPE) { + @SuppressWarnings("unchecked") Output<TInt64> o = (Output<TInt64>) input; AtomicInteger index = new AtomicInteger(); o.data() .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + .forEach(f -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } else if (dtype == TString.DTYPE) { + @SuppressWarnings("unchecked") Output<TString> o = (Output<TString>) input; AtomicInteger index = new AtomicInteger(); o.data() .scalars() - .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + .forEach(f -> writer.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); } else if (dtype == TBool.DTYPE) { + @SuppressWarnings("unchecked") Output<TBool> o = (Output<TBool>) input; AtomicInteger index = new AtomicInteger(); o.data() .scalars() - .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); + .forEach(f -> writer.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); } else { writer.println("Unexpected DataType: " + dtype); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java index 0416267ae59..cc9140c5134 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.utils; import org.tensorflow.*; +import org.tensorflow.Session.Runner; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -24,20 +25,21 @@ import org.tensorflow.types.family.TType; import java.io.PrintWriter; +import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Predicate; import static org.junit.jupiter.api.Assertions.*; -/** Graph Mode Test Session */ +/** A Graph Mode Test Session */ public class GraphTestSession extends TestSession { private final Graph graph; private final Session session; private final Ops tf; - /** Create a Graph mode test session. 
*/ + /** Create a Graph Test Session */ public GraphTestSession() { graph = new Graph(); session = new Session(graph); @@ -50,20 +52,12 @@ public Ops getTF() { return tf; } - /** Get the Graph object that is represented by this Test Session */ + /** {@inheritDoc} */ + @Override public Graph getGraph() { return graph; } - /** - * Get the TensorFlow Session instance - * - * @return the TensorFlow Session instance - */ - public Session getSession() { - return session; - } - /** {@inheritDoc} */ @Override public void close() { @@ -92,24 +86,45 @@ public EagerSession getEagerSession() { /** {@inheritDoc} */ @Override public void initialize() { - graph.initializers().forEach(initializer -> session.runner().addTarget(initializer).run()); + session.run(tf.init()); } /** {@inheritDoc} */ @Override - public void run(Op op) { - session.run(op); + public final void run(Op op, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + createRunner(op, feedDict).run(); + } + + /** + * Create a runner for the Operation + * + * @param op the operation + * @param feedMap the dictionary of values to use for the runner's feed operations. Required when + * placeholders are used. + * @return the runner + */ + public final Runner createRunner( + Op op, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + Runner runner = session.runner(); + runner.addTarget(op.op()); + if (feedMap != null) { + feedMap.forEach(runner::feed); + } + return runner; } /** {@inheritDoc} */ @Override - public <T extends TNumber> void evaluate(double expected, Operand<T> input) { - DataType<T> dtype = input.asOutput().dataType(); + public <U extends TNumber> void evaluate( + double expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + DataType<U> dtype = input.asOutput().dataType(); if (dtype == TFloat32.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -118,14 +133,14 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { } index.set(0); try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result.data().scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon)); } } else if (dtype == TFloat64.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -134,14 +149,14 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { } index.set(0); try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result.data().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); } } else if (dtype == TInt32.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, 
feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -150,14 +165,14 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { } index.set(0); try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result.data().scalars().forEach(f -> assertEquals((int) expected, f.getInt())); } } else if (dtype == TInt64.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -166,7 +181,7 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { } index.set(0); try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result.data().scalars().forEach(f -> assertEquals((long) expected, f.getLong())); } } else { @@ -176,18 +191,24 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) { /** {@inheritDoc} */ @Override - public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - DataType<T> dtype = input.asOutput().dataType(); + public <U extends TNumber> void evaluate( + Number[] expected, + Output<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { + long size = input.shape().size() == 0 ? 
1 : input.shape().size(); + if (size != Shape.UNKNOWN_SIZE) { + assertEquals( + expected.length, + size, + () -> + String.format("expected length (%d) != to input length (%d)", expected.length, size)); + } + DataType<U> dtype = input.asOutput().dataType(); if (dtype == TFloat32.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -196,7 +217,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { } index.set(0); try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -209,7 +230,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -218,7 +239,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { } index.set(0); try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -231,7 +252,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -240,7 +261,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { } index.set(0); try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -250,7 +271,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -259,7 +280,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { } index.set(0); try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -272,13 +293,16 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) { /** {@inheritDoc} */ @Override - public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { - DataType<T> dtype = input.asOutput().dataType(); + public <U extends TNumber> void evaluate( + FloatNdArray expected, + Output<U> input, + Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedDict) { + DataType<U> dtype = input.asOutput().dataType(); if (dtype == TFloat32.DTYPE) { AtomicLong index = new AtomicLong(); if (debug) { try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -287,7 +311,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { } index.set(0); try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -300,7 +324,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -309,7 +333,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { } index.set(0); try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -322,7 +346,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -331,7 +355,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { } index.set(0); try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -342,7 +366,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -351,7 +375,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { } index.set(0); try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -365,7 +389,10 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) { /** {@inheritDoc} */ @Override - public void evaluate(String[] expected, Output<TString> input) { + public void evaluate( + String[] expected, + Output<TString> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { int size = input.shape().size() == 0 ? 
1 : (int) input.shape().size(); assertEquals( expected.length, @@ -374,7 +401,7 @@ public void evaluate(String[] expected, Output<TString> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TString> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { result .data() .scalars() @@ -383,7 +410,7 @@ public void evaluate(String[] expected, Output<TString> input) { } index.set(0); try (Tensor<TString> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { result .data() .scalars() @@ -393,7 +420,10 @@ public void evaluate(String[] expected, Output<TString> input) { /** {@inheritDoc} */ @Override - public void evaluate(Boolean[] expected, Output<TBool> input) { + public void evaluate( + Boolean[] expected, + Output<TBool> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); assertEquals( expected.length, @@ -402,7 +432,7 @@ public void evaluate(Boolean[] expected, Output<TBool> input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor<TBool> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { result .data() .scalars() @@ -411,7 +441,7 @@ public void evaluate(Boolean[] expected, Output<TBool> input) { } index.set(0); try (Tensor<TBool> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { result .data() .scalars() @@ -421,27 +451,25 @@ public void evaluate(Boolean[] expected, Output<TBool> input) { /** {@inheritDoc} */ @Override - public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { + public <T extends TType> void evaluate( + Output<T> expected, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { assert input.shape().equals(expected.shape()) : String.format( "expected shape (%s) != to input shape (%s)", expected.shape().toString(), input.shape().toString()); AtomicInteger index = new AtomicInteger(); DataType<T> dtype = input.asOutput().dataType(); - if (!dtype.equals(expected.dataType())) { - throw new IllegalArgumentException( - String.format( - "Both data type must be equal, inout = %s, expected = %s", - dtype, expected.dataType())); - } boolean isScalar = input.shape().equals(Shape.scalar()); if (dtype == TFloat32.DTYPE) { + @SuppressWarnings("unchecked") final Output<TFloat32> finalExpected = (Output<TFloat32>) expected; if (debug) { try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE); Tensor<TFloat32> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { System.out.printf( "0). 
%f <==> %f\n", expectedResult.data().getFloat(), result.data().getFloat()); @@ -461,9 +489,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } index.set(0); try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE); Tensor<TFloat32> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getFloat(), result.data().getFloat(), epsilon); } else { @@ -476,12 +504,13 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } } } else if (dtype == TFloat64.DTYPE) { + @SuppressWarnings("unchecked") final Output<TFloat64> finalExpected = (Output<TFloat64>) expected; if (debug) { try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE); Tensor<TFloat64> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { System.out.printf( "0). %f <==> %f\n", expectedResult.data().getDouble(), result.data().getDouble()); @@ -501,9 +530,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } index.set(0); try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE); Tensor<TFloat64> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getDouble(), result.data().getDouble(), epsilon); } else { @@ -516,12 +545,13 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } } } else if (dtype == TInt32.DTYPE) { + @SuppressWarnings("unchecked") final Output<TInt32> finalExpected = (Output<TInt32>) expected; if (debug) { try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE); Tensor<TInt32> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { System.out.printf( "0). 
%d <==> %d\n", expectedResult.data().getInt(), result.data().getInt()); @@ -539,9 +569,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } index.set(0); try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE); Tensor<TInt32> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getInt(), result.data().getInt(), epsilon); } else { @@ -553,12 +583,13 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } } } else if (dtype == TInt64.DTYPE) { + @SuppressWarnings("unchecked") final Output<TInt64> finalExpected = (Output<TInt64>) expected; if (debug) { try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE); Tensor<TInt64> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { System.out.printf( "0). %d <==> %d\n", expectedResult.data().getLong(), result.data().getLong()); @@ -578,9 +609,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } index.set(0); try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE); Tensor<TInt64> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getLong(), result.data().getLong(), epsilon); } else { @@ -593,12 +624,13 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } } } else if (dtype == TBool.DTYPE) { + @SuppressWarnings("unchecked") final Output<TBool> finalExpected = (Output<TBool>) expected; if (debug) { try (Tensor<TBool> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE); Tensor<TBool> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { if (isScalar) { System.out.printf( "0). 
%b <==> %b\n", expectedResult.data().getBoolean(), result.data().getBoolean()); @@ -618,9 +650,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } index.set(0); try (Tensor<TBool> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE); Tensor<TBool> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getBoolean(), result.data().getBoolean()); } else { @@ -632,12 +664,13 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } } } else if (dtype == TString.DTYPE) { + @SuppressWarnings("unchecked") final Output<TString> finalExpected = (Output<TString>) expected; if (debug) { try (Tensor<TString> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE); Tensor<TString> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { if (isScalar) { System.out.printf( "0). %s <==> %s\n", expectedResult.data().getObject(), result.data().getObject()); @@ -657,9 +690,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } index.set(0); try (Tensor<TString> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE); Tensor<TString> expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getObject(), result.data().getObject()); } else { @@ -676,15 +709,17 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) { } /** {@inheritDoc} */ - @Override - public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predicate) { + public <U extends TNumber> void evaluate( + Output<U> input, + Predicate<Number> predicate, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) { AtomicInteger index = new AtomicInteger(); - DataType<T> dtype = input.asOutput().dataType(); + DataType<U> dtype = input.asOutput().dataType(); boolean isScalar = input.shape().equals(Shape.scalar()); if (dtype == TFloat32.DTYPE) { if (debug) { try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { System.out.printf( "0). 
%b <==> %f\n", @@ -703,7 +738,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic } index.set(0); try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { assertTrue(predicate.test(result.data().getFloat())); } else { @@ -716,7 +751,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic } else if (dtype == TFloat64.DTYPE) { if (debug) { try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { System.out.printf( "0). %b <==> %f\n", @@ -735,7 +770,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic } index.set(0); try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { assertTrue(predicate.test(result.data().getDouble())); } else { @@ -748,7 +783,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic } else if (dtype == TInt32.DTYPE) { if (debug) { try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { System.out.printf( "0). %b <==> %d\n", predicate.test(result.data().getInt()), result.data().getInt()); @@ -766,7 +801,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic } index.set(0); try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { assertTrue(predicate.test(result.data().getInt())); } else { @@ -779,7 +814,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic } else if (dtype == TInt64.DTYPE) { if (debug) { try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { System.out.printf( "0). %b <==> %d\n", @@ -798,7 +833,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic } index.set(0); try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { assertTrue(predicate.test(result.data().getLong())); } else { @@ -815,14 +850,17 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic /** {@inheritDoc} */ @Override - public <T extends TType> void print(PrintWriter writer, Output<T> input) { + public <T extends TType> void print( + PrintWriter writer, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedDict) { boolean isScalar = input.asOutput().shape().size() == 1; DataType<T> dtype = input.dataType(); if (dtype == TFloat32.DTYPE) { AtomicInteger index = new AtomicInteger(); try (Tensor<TFloat32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { writer.printf("%d). %f\n", index.getAndIncrement(), result.data().getFloat()); } else { @@ -837,10 +875,11 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) { AtomicInteger index = new AtomicInteger(); try (Tensor<TFloat64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { - writer.printf( - "%d). %f\n", index.getAndIncrement(), ((Output<TFloat64>) input).data().getDouble()); + @SuppressWarnings("unchecked") + Output<TFloat64> o = (Output<TFloat64>) input; + writer.printf("%d). %f\n", index.getAndIncrement(), o.data().getDouble()); } else { result .data() @@ -853,10 +892,11 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) { AtomicInteger index = new AtomicInteger(); try (Tensor<TInt32> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { - writer.printf( - "%d). %d\n", index.getAndIncrement(), ((Output<TInt32>) input).data().getInt()); + @SuppressWarnings("unchecked") + Output<TInt32> o = (Output<TInt32>) input; + writer.printf("%d). %d\n", index.getAndIncrement(), o.data().getInt()); } else { result .data() @@ -869,10 +909,11 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) { AtomicInteger index = new AtomicInteger(); try (Tensor<TInt64> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { - writer.printf( - "%d). %d\n", index.getAndIncrement(), ((Output<TInt64>) input).data().getLong()); + @SuppressWarnings("unchecked") + Output<TInt64> o = (Output<TInt64>) input; + writer.printf("%d). %d\n", index.getAndIncrement(), o.data().getLong()); } else { result .data() @@ -885,10 +926,11 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) { AtomicInteger index = new AtomicInteger(); try (Tensor<TBool> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { if (isScalar) { - writer.printf( - "%d). %b\n", index.getAndIncrement(), ((Output<TBool>) input).data().getBoolean()); + @SuppressWarnings("unchecked") + Output<TBool> o = (Output<TBool>) input; + writer.printf("%d). %b\n", index.getAndIncrement(), o.data().getBoolean()); } else { result .data() @@ -901,10 +943,11 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) { AtomicInteger index = new AtomicInteger(); try (Tensor<TString> result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { if (isScalar) { - writer.printf( - "%d). 
%s\n", index.getAndIncrement(), ((Output<TString>) input).data().getObject()); + @SuppressWarnings("unchecked") + Output<TString> o = (Output<TString>) input; + writer.printf("%d). %s\n", index.getAndIncrement(), o.data().getObject()); } else { result .data() diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java index a0855eb6260..713225a4962 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java @@ -27,583 +27,1136 @@ import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.io.Writer; +import java.util.Map; import java.util.function.Predicate; import static org.junit.jupiter.api.Assertions.assertTrue; -/** Base class for Test Session */ +/** Abstract class for Test Sessions */ public abstract class TestSession implements AutoCloseable { protected float epsilon = 1e-5F; protected boolean debug; - /** The Test Session mode, either Eager or Graph */ + /** Enumerate between Eager and Graph Mode */ public enum Mode { EAGER, GRAPH } - /** - * Creates an Eager Test Session - * - * @return the Eager Test Session - */ public static TestSession createEagerSession() { return new EagerTestSession(); } - /** - * Creates a Graph Test Session - * - * @return the Graph Test Session - */ public static TestSession createGraphSession() { return new GraphTestSession(); } - /** - * Creates a Test Session - * - * @param mode the type of Session, either eager or graph - * @return returns the test session - */ public static TestSession createTestSession(Mode mode) { return mode == Mode.EAGER ? createEagerSession() : createGraphSession(); } - /** Initializes the Test Session, default implementation is do nothing. */ + /** + * Initializer any graph initializers, if in Graph mode, for Eager mode, this method does nothing. + */ public void initialize() { // empty } /** - * Runs the Operation + * Returns the Graph if in Graph mode, or null if in EagerMode + * + * @return the Graph if in Graph mode, or null if in EagerMode + */ + public Graph getGraph() { + return null; + } + + /** + * Perform session.run() + * + * <p>If in eager mode, this does nothing. * - * @param op the Operation to run + * @param op The Operation to run */ public void run(Op op) { - // empty + run(op, null); } /** - * Gets the Graph + * Perform session.run() * - * @return the graph if in Graph Mode, otherwise null. + * <p>If in eager mode, this does nothing. + * + * @param op The Operation to run + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ - public Graph getGraph() { - return null; + public abstract void run(Op op, Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedMap); + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate(Number expected, Operand<U> input) { + evaluate(new Number[] {expected}, input, null); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(Number expected, Operand<T> input) { - evaluate(new Number[] {expected}, input); + public <U extends TNumber> void evaluate( + Number expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate(new Number[] {expected}, input, feedMap); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value */ public void evaluate(Number expected, Op input) { - evaluate(new Number[] {expected}, input); + evaluate(new Number[] {expected}, input, null); } /** - * Evaluates the input against the expected values + * Evaluate the expected results versus the actual results * - * @param expected the expected values - * @param input the operand to evaluate - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ - public void evaluate(Number[] expected, Op input) { - Output<? extends TNumber> output = input.op().output(0); - evaluate(expected, output); + public void evaluate( + Number expected, Op input, Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedMap) { + evaluate(new Number[] {expected}, input, feedMap); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param <U> the data type for the input */ - public <T extends TNumber> void evaluate(Number[] expected, Operand<T> input) { - Output<T> output = input.asOutput(); - evaluate(expected, output); + public <U extends TNumber> void evaluate(Number[] expected, Op input) { + Output<U> output = input.op().output(0); + evaluate(expected, output, null); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type for the input */ - public <T extends TNumber> void evaluate(byte expected, Operand<T> input) { - evaluate((double) expected, input); + public <U extends TNumber> void evaluate( + Number[] expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + Output<U> output = input.op().output(0); + evaluate(expected, output, feedMap); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(int expected, Operand<T> input) { - evaluate((double) expected, input); + public <U extends TNumber> void evaluate(Number[] expected, Operand<U> input) { + Output<U> output = input.asOutput(); + evaluate(expected, output, null); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(long expected, Operand<T> input) { - evaluate((double) expected, input); + public <U extends TNumber> void evaluate( + Number[] expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedMap) { + Output<U> output = input.asOutput(); + evaluate(expected, output, feedMap); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(float expected, Operand<T> input) { - evaluate((double) expected, input); + public <U extends TNumber> void evaluate(byte expected, Operand<U> input) { + evaluate((double) expected, input, null); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input */ - public abstract <T extends TNumber> void evaluate(double expected, Operand<T> input); + public <U extends TNumber> void evaluate( + byte expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate((double) expected, input, feedMap); + } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate(int expected, Operand<U> input) { + evaluate((double) expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate( + int expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate((double) expected, input, feedMap); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate(long expected, Operand<U> input) { + evaluate((double) expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate( + long expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedMap) { + evaluate((double) expected, input, feedMap); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate(float expected, Operand<U> input) { + evaluate((double) expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(byte[] expected, Operand<T> input) { + public <U extends TNumber> void evaluate( + float expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate((double) expected, input, feedMap); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate(double expected, Operand<U> input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input + */ + public abstract <U extends TNumber> void evaluate( + double expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate(byte[] expected, Operand<U> input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate( + byte[] expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedMap) { Byte[] iArray = new Byte[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedMap); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(int[] expected, Operand<T> input) { + public <U extends TNumber> void evaluate(int[] expected, Operand<U> input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate( + int[] expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { Integer[] iArray = new Integer[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedMap); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(long[] expected, Operand<T> input) { + public <U extends TNumber> void evaluate(long[] expected, Operand<U> input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate( + long[] expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedMap) { Long[] iArray = new Long[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedMap); } /** - * Evaluates the input against the expected values + * Evaluate the expected results versus the actual results * - * @param expected the expected values - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param expected the expected value + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(float[] expected, Operand<T> input) { + public <U extends TNumber> void evaluate(float[] expected, Operand<U> input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate( + float[] expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { Float[] iArray = new Float[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedMap); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TNumber> void evaluate(double[] expected, Operand<T> input) { + public <U extends TNumber> void evaluate(double[] expected, Operand<U> input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate( + double[] expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedMap) { Double[] iArray = new Double[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedMap); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param <U> the data type of the input + */ + public <U extends TNumber> void evaluate(Number[] expected, Output<U> input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input */ - public abstract <T extends TNumber> void evaluate(Number[] expected, Output<T> input); + public abstract <U extends TNumber> void evaluate( + Number[] expected, + Output<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value */ public void evaluate(String expected, Operand<TString> input) { - evaluate(new String[] {expected}, input); + evaluate(expected, input, null); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + */ + public void evaluate( + String expected, + Operand<TString> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate(new String[] {expected}, input, feedMap); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value */ public void evaluate(String expected, Op input) { - evaluate(new String[] {expected}, input); + evaluate(new String[] {expected}, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + */ + public void evaluate( + String expected, Op input, Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedMap) { + evaluate(new String[] {expected}, input, feedMap); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value */ public void evaluate(String[] expected, Op input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + */ + public void evaluate( + String[] expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { Output<TString> output = input.op().output(0); - evaluate(expected, output); + evaluate(expected, output, feedMap); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value */ public void evaluate(String[] expected, Operand<TString> input) { Output<TString> output = input.asOutput(); - evaluate(expected, output); + evaluate(expected, output, null); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ - public abstract void evaluate(String[] expected, Output<TString> input); + public abstract void evaluate( + String[] expected, + Output<TString> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value */ public void evaluate(Boolean expected, Operand<TBool> input) { - evaluate(new Boolean[] {expected}, input); + evaluate(new Boolean[] {expected}, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + */ + public void evaluate( + Boolean expected, + Operand<TBool> input, + Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedMap) { + evaluate(new Boolean[] {expected}, input, feedMap); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value */ public void evaluate(Boolean expected, Op input) { - evaluate(new Boolean[] {expected}, input); + evaluate(new Boolean[] {expected}, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + */ + public void evaluate( + Boolean expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate(new Boolean[] {expected}, input, feedMap); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value */ public void evaluate(Boolean[] expected, Op input) { Output<TBool> output = input.op().output(0); - evaluate(expected, output); + evaluate(expected, output, null); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + */ + public void evaluate( + Boolean[] expected, + Op input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + Output<TBool> output = input.op().output(0); + evaluate(expected, output, feedMap); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value */ public void evaluate(Boolean[] expected, Operand<TBool> input) { Output<TBool> output = input.asOutput(); - evaluate(expected, output); + evaluate(expected, output, null); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ - public abstract void evaluate(Boolean[] expected, Output<TBool> input); + public void evaluate( + Boolean[] expected, + Operand<TBool> input, + Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedMap) { + Output<TBool> output = input.asOutput(); + evaluate(expected, output, feedMap); + } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the expected Operand - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value */ - public <T extends TType> void evaluate(Operand<T> expected, Op input) { - Output<T> output = input.op().output(0); - evaluate(expected, output); + public void evaluate(Boolean[] expected, Output<TBool> input) { + evaluate(expected, input, null); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + */ + public abstract void evaluate( + Boolean[] expected, + Output<TBool> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); + + /** + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate + * @param input the actual value * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + */ + public <T extends TType> void evaluate(Operand<T> expected, Output<T> input) { + evaluate(expected.asOutput(), input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param <T> the data type for the feedMap entries */ public <T extends TType> void evaluate(Operand<T> expected, Operand<T> input) { - evaluate(expected.asOutput(), input.asOutput()); + evaluate(expected.asOutput(), input.asOutput(), null); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <T> the data type for the feedMap entries */ - public abstract <T extends TType> void evaluate(Output<T> expected, Output<T> input); + public abstract <T extends TType> void evaluate( + Output<T> expected, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedMap); /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TType> void evaluate(FloatNdArray expected, Operand<T> input) { - evaluate(expected, input.asOutput()); + public <U extends TNumber> void evaluate(FloatNdArray expected, Operand<U> input) { + evaluate(expected, input.asOutput(), null); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input */ - public abstract <T extends TType> void evaluate(FloatNdArray expected, Output<T> input); + public <U extends TNumber> void evaluate( + FloatNdArray expected, + Operand<U> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + evaluate(expected, input.asOutput(), feedMap); + } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * - * @param input the operand to evaluate - * @param predicate the Predicate - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param expected the expected value + * @param input the actual value + * @param <U> the data type of the input */ - public <T extends TType> void evaluate(Operand<T> input, Predicate<Number> predicate) { - evaluate(input.asOutput(), predicate); + public <U extends TNumber> void evaluate(FloatNdArray expected, Output<U> input) { + evaluate(expected, input, null); } /** - * Evaluates the input against the expected value + * Evaluate the expected results versus the actual results * - * @param input the operand to evaluate - * @param predicate The Predicate that evaluates the each value from input - * @param <T> the data type of the input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param expected the expected value + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input + */ + public abstract <U extends TNumber> void evaluate( + FloatNdArray expected, + Output<U> input, + Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedMap); + + /** + * Evaluate the actual results using a predicate + * + * @param input the actual value + * @param predicate a predicate that accepts a Number as an argument, if the result of the + * predicate is false, then the test will fail + * @param <U> the data type of the input */ - public abstract <T extends TType> void evaluate(Output<T> input, Predicate<Number> predicate); + public <U extends TNumber> void evaluate(Operand<U> input, Predicate<Number> predicate) { + evaluate(input.asOutput(), predicate, null); + } /** - * Evaluates the input against the expected value + * Evaluate the actual results using a predicate * - * @param input the operand to evaluate - * @param predicate The Predicate that evaluates the each value from input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails + * @param input the actual value + * @param predicate a predicate that accepts a Number as an argument, if the result of the + * predicate is false, then the test will fail + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <U> the data type of the input + */ + public abstract <U extends TNumber> void evaluate( + Output<U> input, + Predicate<Number> predicate, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap); + + /** + * Evaluate the actual results using a predicate + * + * @param input the actual value + * @param predicate a predicate that accepts a Number as an argument, if the result of the + * predicate is false, then the test will fail */ public void evaluate(FloatNdArray input, Predicate<Number> predicate) { input.scalars().forEach(f -> assertTrue(predicate.test(f.getFloat()))); } /** - * Print the input + * Print the results to the "standard" output stream. + * + * @param input the actual value + * @param <T> the data type for the input + */ + public <T extends TType> void print(Operand<T> input) { + print(System.out, input, null); + } + + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <T> the data type for the feedMap entries + */ + public <T extends TType> void print( + Operand<T> input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input.asOutput(), feedMap); + } + + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + */ + public void print(Op input) { + print(System.out, input, null); + } + + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + */ + public void print(Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input.op().output(0), feedMap); + } + + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + * @param <T> the data type for the input + */ + public <T extends TType> void print(Output<T> input) { + print(System.out, input, null); + } + + /** + * Print the results to the "standard" output stream. 
+ * + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <T> the data type for the input + */ + public <T extends TType> void print( + Output<T> input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input, feedMap); + } + + /** + * Print the results to output stream * * @param out the output stream - * @param input the operand to print - * @param <T> the data type of the input + * @param input the actual value + * @param <T> the data type for the input */ public <T extends TType> void print(OutputStream out, Operand<T> input) { - print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput()); + print(out, input, null); } /** - * Print the input + * Print the results to output stream * * @param out the output stream - * @param input the op to print + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <T> the data type for the feedMap entries + */ + public <T extends TType> void print( + OutputStream out, + Operand<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput(), feedMap); + } + + /** + * Print the results to output stream + * + * @param out the output stream + * @param input the actual value */ public void print(OutputStream out, Op input) { - print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0)); + print(out, input, null); } /** - * Print the input + * Print the results to output stream * * @param out the output stream - * @param input the op to print - * @param <T> the data type of the input + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + */ + public void print( + OutputStream out, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0), feedMap); + } + + /** + * Print the results to output stream + * + * @param out the output stream + * @param input the actual value + * @param <T> the data type for the input */ public <T extends TType> void print(OutputStream out, Output<T> input) { - print(new PrintWriter(new OutputStreamWriter(out)), input); + print(out, input, null); } /** - * Print the input + * Print the results to output stream * - * @param writer the output writer - * @param input the operand to print - * @param <T> the data type of the input + * @param out the output stream + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <T> the data type for the input + */ + public <T extends TType> void print( + OutputStream out, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? 
extends TType>> feedMap) { + print(new PrintWriter(new OutputStreamWriter(out)), input, feedMap); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value + * @param <T> the data type for the input */ public <T extends TType> void print(Writer writer, Operand<T> input) { - print(new PrintWriter(writer), input.asOutput()); + print(writer, input, null); } /** - * Print the input + * Print the results to the character stream * - * @param writer the output writer - * @param input the op to print + * @param writer the character stream + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <T> the data type for the input + */ + public <T extends TType> void print( + Writer writer, + Operand<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(writer), input.asOutput(), feedMap); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value */ public void print(Writer writer, Op input) { - print(new PrintWriter(writer), input.op().output(0)); + print(writer, input, null); } /** - * Print the input + * Print the results to the character stream * - * @param writer the output writer - * @param input the op to print - * @param <T> the data type of the input + * @param writer the character stream + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + */ + public void print( + Writer writer, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(writer), input.op().output(0), feedMap); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value + * @param <T> the data type for the input */ public <T extends TType> void print(Writer writer, Output<T> input) { - print(new PrintWriter(writer), input); + print(writer, input, null); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <T> the data type for the input + */ + public <T extends TType> void print( + Writer writer, + Output<T> input, + Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) { + print(new PrintWriter(writer), input, feedMap); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value + * @param <T> the data type for the input + */ + public <T extends TType> void print(PrintWriter writer, Output<T> input) { + print(writer, input, null); } /** - * Print the input + * Print the results to the character stream * - * @param writer the output writer - * @param input the op to print + * @param writer the character stream + * @param input the actual value + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. + * @param <T> the data type for the input */ - public abstract <T extends TType> void print(PrintWriter writer, Output<T> input); + public abstract <T extends TType> void print( + PrintWriter writer, + Output<T> input, + Map<Operand<? 
extends TType>, Tensor<? extends TType>> feedMap); /** - * Get the TensorFlow Ops + * Get the TensorFlow Ops for this test session * - * @return the TensorFlow Ops + * @return the TensorFlow Ops for this test session */ public abstract Ops getTF(); /** - * Determine if this Test Session represents an Eager Session + * Determine whether this session is in Eager mode * - * @return true, if this Test Session represents an Eager Session + * @return true if this session is in Eager mode */ public abstract boolean isEager(); /** - * Determine if this Test Session represents a Graph Session + * Determine whether this session is in Graph mode * - * @return true, if this Test Session represents a Graph Session + * @return true if this session is in Graph mode */ public boolean isGraph() { return !isEager(); } /** - * Get the epsilon value for evaluating float values + * Get the current EPSILON value for floating point number comparison. * - * @return the epsilon value for evaluating float values + * @return the current EPSILON value for floating point number comparison. */ public float getEpsilon() { return this.epsilon; } /** - * Set the epsilon value for evaluating float values + * Set the current EPSILON value for floating point number comparison. * - * @param epsilon the epsilon value for evaluating float values + * @param epsilon the new EPSILON value for floating point number comparison. */ public void setEpsilon(float epsilon) { this.epsilon = epsilon; } /** - * Get the TensorFlow session object associated with this Test Session + * Get the TensorFlow Session object * - * @return a TensorFlow session if this is a Graph session, otherwise null + * @return the TensorFlow Session object, or null if this is not a Graph Test Session */ public abstract Session getGraphSession(); /** - * Get the TensorFlow eager session object associated with this Test Session + * Get the TensorFlow EagerSession object * - * @return a TensorFlow session if this is an eager session, otherwise null + * @return the TensorFlow EagerSession object, or null if this is not an Eager Test Session */ public abstract EagerSession getEagerSession(); @@ -611,15 +1164,21 @@ public void setEpsilon(float epsilon) { @Override public abstract void close(); - /** @return the debug setting */ + /** + * Get the debug setting + * + * @return the debug setting + */ public boolean isDebug() { return debug; } /** - * Set the debug flag + * Sets the debug setting. + * + * <p>If true, then evaluate methods will also print the Tensor values to System.out. * - * @param debug the setting for debugging + * @param debug the new debug setting */ public void setDebug(boolean debug) { this.debug = debug;
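
For reviewers, a minimal usage sketch of the new feedMap-aware evaluate API follows. It assumes a Graph-mode test with a TFloat32 placeholder; the placeholder and tensor factories (tf.placeholder, TFloat32.scalarOf) are taken from the current 0.2.x Java API, and the test class name is illustrative only, not part of this patch:

    import java.util.HashMap;
    import java.util.Map;
    import org.junit.jupiter.api.Test;
    import org.tensorflow.Operand;
    import org.tensorflow.Tensor;
    import org.tensorflow.framework.utils.TestSession;
    import org.tensorflow.op.Ops;
    import org.tensorflow.op.core.Placeholder;
    import org.tensorflow.types.TFloat32;
    import org.tensorflow.types.family.TType;

    // Sketch only: exercises evaluate(expected, input, feedMap) introduced by this change.
    public class FeedMapUsageExample {
      @Test
      public void evaluateWithPlaceholder() {
        try (TestSession session = TestSession.createTestSession(TestSession.Mode.GRAPH)) {
          Ops tf = session.getTF();
          Placeholder<TFloat32> x = tf.placeholder(TFloat32.DTYPE); // value supplied at run time
          Operand<TFloat32> y = tf.math.mul(x, tf.constant(2.0f));

          // The feedMap entries are passed to the runner's feed() calls before fetching.
          Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap = new HashMap<>();
          feedMap.put(x, TFloat32.scalarOf(3.0f)); // tensor factory assumed from the 0.2.x API

          session.setEpsilon(1e-4f);          // tolerance for the floating point comparison
          session.evaluate(6.0f, y, feedMap); // 3 * 2 == 6, checked against the fetched result
        }
      }
    }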