diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java
index 1bdeffe4dcb..30abd0fcbe3 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java
@@ -20,6 +20,7 @@
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
 import org.tensorflow.op.core.Variable;
+import org.tensorflow.types.TFloat32;
 import org.tensorflow.types.family.TType;
 
 import java.util.List;
@@ -62,24 +63,31 @@
  */
 public class AdaDelta extends Optimizer {
 
+  public static final String DEFAULT_NAME = "Adadelta";
   public static final String ACCUMULATOR = "accum";
   public static final String ACCUMULATOR_UPDATE = "accum_update";
   public static final float LEARNING_RATE_DEFAULT = 0.001f;
   public static final float RHO_DEFAULT = 0.95f;
   public static final float EPSILON_DEFAULT = 1e-7f;
 
-  private final float learningRate;
-
   private final float rho;
 
   private final float epsilon;
 
+  /**
+   * Creates an AdaDelta Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #LEARNING_RATE_DEFAULT} for the learningRate, {@link #RHO_DEFAULT} for the rho, and {@link
+   * #EPSILON_DEFAULT} for the epsilon.
+   *
+   * @param graph the TensorFlow graph.
+   */
   public AdaDelta(Graph graph) {
     this(graph, LEARNING_RATE_DEFAULT, RHO_DEFAULT, EPSILON_DEFAULT);
   }
 
   /**
-   * Creates an AdaDelta Optimizer
+   * Creates an AdaDelta Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #RHO_DEFAULT} for the rho, and {@link #EPSILON_DEFAULT} for the epsilon.
    *
    * @param graph the TensorFlow Graph
    * @param learningRate the learning rate
@@ -89,7 +97,19 @@ public AdaDelta(Graph graph, float learningRate) {
   }
 
   /**
-   * Creates an AdaDelta Optimizer
+   * Creates an AdaDelta Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #RHO_DEFAULT} for the rho, and {@link #EPSILON_DEFAULT} for the epsilon.
+   *
+   * @param graph the TensorFlow Graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public AdaDelta(Graph graph, Operand<TFloat32> learningRateOperand) {
+    this(graph, learningRateOperand, RHO_DEFAULT, EPSILON_DEFAULT);
+  }
+
+  /**
+   * Creates an AdaDelta Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
    *
    * @param graph the TensorFlow Graph
    * @param learningRate the learning rate
@@ -97,35 +117,75 @@ public AdaDelta(Graph graph, float learningRate) {
    * @param epsilon A constant epsilon used to better conditioning the grad update
    */
   public AdaDelta(Graph graph, float learningRate, float rho, float epsilon) {
-    super(graph);
-    this.learningRate = learningRate;
-    this.rho = rho;
-    this.epsilon = epsilon;
+    this(graph, null, learningRate, rho, epsilon);
   }
 
   /**
    * Creates an AdaDelta Optimizer
    *
    * @param graph the TensorFlow Graph
-   * @param name the name for this Optimizer (defaults to 'Adadelta')
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param rho The decay factor
+   * @param epsilon A constant epsilon used to better conditioning the grad update
+   */
+  public AdaDelta(Graph graph, Operand<TFloat32> learningRateOperand, float rho, float epsilon) {
+    this(graph, null, learningRateOperand, rho, epsilon);
+  }
+
+  /**
+   * Creates an AdaDelta Optimizer using {@link #RHO_DEFAULT} for the rho, and {@link
+   * #EPSILON_DEFAULT} for the epsilon.
+   *
+   * @param graph the TensorFlow Graph
+   * @param name the name for this Optimizer.
    * @param learningRate the learning rate
    */
   public AdaDelta(Graph graph, String name, float learningRate) {
-    this(graph, name, learningRate, 0.95f, 1e-8f);
+    this(graph, name, learningRate, RHO_DEFAULT, EPSILON_DEFAULT);
+  }
+
+  /**
+   * Creates an AdaDelta Optimizer using {@link #RHO_DEFAULT} for the rho, and {@link
+   * #EPSILON_DEFAULT} for the epsilon.
+   *
+   * @param graph the TensorFlow Graph
+   * @param name the name for this Optimizer.
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public AdaDelta(Graph graph, String name, Operand<TFloat32> learningRateOperand) {
+    this(graph, name, learningRateOperand, RHO_DEFAULT, EPSILON_DEFAULT);
   }
 
   /**
    * Creates an AdaDelta Optimizer
    *
    * @param graph the TensorFlow Graph
-   * @param name the name for this Optimizer (defaults to 'Adadelta')
+   * @param name the name for this Optimizer.
    * @param learningRate the learning rate
    * @param rho The decay factor
    * @param epsilon A constant epsilon used to better conditioning the grad update
    */
   public AdaDelta(Graph graph, String name, float learningRate, float rho, float epsilon) {
-    super(graph, name);
-    this.learningRate = learningRate;
+    super(graph, name, learningRate);
+    this.rho = rho;
+    this.epsilon = epsilon;
+  }
+
+  /**
+   * Creates an AdaDelta Optimizer
+   *
+   * @param graph the TensorFlow Graph
+   * @param name the name for this Optimizer.
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param rho The decay factor
+   * @param epsilon A constant epsilon used to better conditioning the grad update
+   */
+  public AdaDelta(
+      Graph graph, String name, Operand<TFloat32> learningRateOperand, float rho, float epsilon) {
+    super(graph, name, learningRateOperand);
     this.rho = rho;
     this.epsilon = epsilon;
   }
@@ -162,7 +222,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
         variable,
         accumSlot,
         accumUpdateSlot,
-        tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()),
+        tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()),
         tf.dtypes.cast(tf.constant(rho), gradient.dataType()),
         tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()),
         gradient);
@@ -184,6 +244,6 @@ public String toString() {
   /** {@inheritDoc} */
   @Override
   public String getOptimizerName() {
-    return "Adadelta";
+    return DEFAULT_NAME;
   }
 }
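
The AdaDelta changes above drop the locally stored float learning rate in favor of the base
Optimizer's learning-rate handling and add constructors that accept an Operand<TFloat32>. The
following is a minimal usage sketch, not part of this patch; the class name, variable names, and
the Ops.create/tf.constant setup are illustrative assumptions about how the new constructors would
be called.

import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.framework.optimizers.AdaDelta;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class AdaDeltaConstructionSketch {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);

      // Existing style: a fixed scalar learning rate.
      AdaDelta fixedRate = new AdaDelta(graph, 0.001f);

      // New style: the learning rate is any TFloat32 operand in the same graph,
      // here a plain constant, but it could be produced by other ops.
      Operand<TFloat32> learningRate = tf.constant(0.001f);
      AdaDelta operandRate = new AdaDelta(graph, learningRate);
    }
  }
}
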
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java
index 6855be30759..c0cc47409d7 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java
@@ -20,6 +20,7 @@
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
 import org.tensorflow.op.core.Variable;
+import org.tensorflow.types.TFloat32;
 import org.tensorflow.types.family.TType;
 
 import java.util.List;
@@ -40,16 +41,18 @@
  */
 public class AdaGrad extends Optimizer {
 
+  public static final String DEFAULT_NAME = "Adagrad";
+
   public static final String ACCUMULATOR = "accumulator";
   public static final float LEARNING_RATE_DEFAULT = 0.001f;
   public static final float INITIAL_ACCUMULATOR_DEFAULT = 0.01f;
 
-  private final float learningRate;
-
   private final float initialAccumulatorValue;
 
   /**
-   * Creates an AdaGrad Optimizer
+   * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #LEARNING_RATE_DEFAULT} for the learning rate, and {@link #INITIAL_ACCUMULATOR_DEFAULT} for
+   * the initialAccumulatorValue.
    *
    * @param graph the TensorFlow Graph
    */
@@ -58,7 +61,8 @@ public AdaGrad(Graph graph) {
   }
 
   /**
-   * Creates an AdaGrad Optimizer
+   * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, and {@link
+   * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue.
    *
    * @param graph the TensorFlow Graph
    * @param learningRate the learning rate
@@ -68,7 +72,19 @@ public AdaGrad(Graph graph, float learningRate) {
   }
 
   /**
-   * Creates an AdaGrad Optimizer
+   * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, and {@link
+   * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue.
+   *
+   * @param graph the TensorFlow Graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public AdaGrad(Graph graph, Operand<TFloat32> learningRateOperand) {
+    this(graph, learningRateOperand, INITIAL_ACCUMULATOR_DEFAULT);
+  }
+
+  /**
+   * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
    *
    * @param graph the TensorFlow Graph
    * @param learningRate the learning rate
@@ -76,44 +92,88 @@ public AdaGrad(Graph graph, float learningRate) {
    * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative
    */
   public AdaGrad(Graph graph, float learningRate, float initialAccumulatorValue) {
-    super(graph);
-    if (initialAccumulatorValue < 0F) {
-      throw new IllegalArgumentException(
-          String.format(
-              "initialAccumulatorValue must be non-negative: %f", initialAccumulatorValue));
-    }
-    this.learningRate = learningRate;
-    this.initialAccumulatorValue = initialAccumulatorValue;
+    this(graph, null, learningRate, initialAccumulatorValue);
   }
 
   /**
-   * Creates an AdaGrad Optimizer
+   * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
    *
    * @param graph the TensorFlow Graph
-   * @param name the name for this Optimizer (defaults to 'Adagrad')
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param initialAccumulatorValue Starting value for the accumulators, must be non-negative.
+   * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative
+   */
+  public AdaGrad(
+      Graph graph, Operand<TFloat32> learningRateOperand, float initialAccumulatorValue) {
+    this(graph, null, learningRateOperand, initialAccumulatorValue);
+  }
+
+  /**
+   * Creates an AdaGrad Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the
+   * initialAccumulatorValue.
+   *
+   * @param graph the TensorFlow Graph
+   * @param name the name for this Optimizer.
    * @param learningRate the learning rate
    */
   public AdaGrad(Graph graph, String name, float learningRate) {
-    this(graph, name, learningRate, 0.01f);
+    this(graph, name, learningRate, INITIAL_ACCUMULATOR_DEFAULT);
+  }
+
+  /**
+   * Creates an AdaGrad Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the
+   * initialAccumulatorValue.
+   *
+   * @param graph the TensorFlow Graph
+   * @param name the name for this Optimizer.
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public AdaGrad(Graph graph, String name, Operand<TFloat32> learningRateOperand) {
+    this(graph, name, learningRateOperand, INITIAL_ACCUMULATOR_DEFAULT);
   }
 
   /**
    * Creates an AdaGrad Optimizer
    *
    * @param graph the TensorFlow Graph
-   * @param name the name for this Optimizer (defaults to 'Adagrad')
+   * @param name the name for this Optimizer
    * @param learningRate the learning rate
    * @param initialAccumulatorValue Starting value for the accumulators, must be non-negative.
    * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative
    */
   public AdaGrad(Graph graph, String name, float learningRate, float initialAccumulatorValue) {
-    super(graph, name);
+    super(graph, name, learningRate);
+    if (initialAccumulatorValue < 0F) {
+      throw new IllegalArgumentException(
+          String.format(
+              "initialAccumulatorValue must be non-negative: %f", initialAccumulatorValue));
+    }
+    this.initialAccumulatorValue = initialAccumulatorValue;
+  }
+
+  /**
+   * Creates an AdaGrad Optimizer
+   *
+   * @param graph the TensorFlow Graph
+   * @param name the name for this Optimizer
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param initialAccumulatorValue Starting value for the accumulators, must be non-negative.
+   * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative
+   */
+  public AdaGrad(
+      Graph graph,
+      String name,
+      Operand<TFloat32> learningRateOperand,
+      float initialAccumulatorValue) {
+    super(graph, name, learningRateOperand);
     if (initialAccumulatorValue < 0F) {
       throw new IllegalArgumentException(
           String.format(
               "initialAccumulatorValue must be non-negative: %f", initialAccumulatorValue));
     }
-    this.learningRate = learningRate;
     this.initialAccumulatorValue = initialAccumulatorValue;
   }
 
@@ -142,7 +202,7 @@ private <T extends TType> void createAdaGradSlot(Output<T> v) {
   protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
     Variable<T> slot = getSlot(variable, ACCUMULATOR).get();
     return tf.train.applyAdagrad(
-        variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), gradient);
+        variable, slot, tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), gradient);
   }
 
   /** {@inheritDoc} */
@@ -159,6 +219,6 @@ public String toString() {
   /** {@inheritDoc} */
   @Override
   public String getOptimizerName() {
-    return "Adagrad";
+    return DEFAULT_NAME;
   }
 }
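
AdaGrad follows the same pattern: the unnamed constructors delegate to the named ones, which now
carry the initialAccumulatorValue validation. A short sketch, assuming the same imports and
graph/tf setup as the AdaDelta example above:

// Named optimizer with an Operand learning rate and an explicit accumulator start value.
Operand<TFloat32> adagradRate = tf.constant(0.01f);
AdaGrad adagrad = new AdaGrad(graph, "MyAdagrad", adagradRate, 0.1f);

// A negative start value is rejected with an IllegalArgumentException:
// new AdaGrad(graph, "Bad", adagradRate, -1.0f);
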
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java
index e74a1d85359..0e070f2f4fa 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java
@@ -22,6 +22,7 @@
 import org.tensorflow.op.Op;
 import org.tensorflow.op.core.Assign;
 import org.tensorflow.op.core.Variable;
+import org.tensorflow.types.TFloat32;
 import org.tensorflow.types.TInt64;
 import org.tensorflow.types.family.TType;
 
@@ -46,6 +47,7 @@
  */
 public class AdaGradDA extends Optimizer {
 
+  public static final String DEFAULT_NAME = "adagrad-da";
   public static final String ACCUMULATOR = "gradient_accumulator";
   public static final String SQUARED_ACCUMULATOR = "gradient_squared_accumulator";
   public static final float LEARNING_RATE_DEFAULT = 0.001F;
@@ -53,14 +55,16 @@ public class AdaGradDA extends Optimizer {
   public static final float L1_STRENGTH_DEFAULT = 0.0F;
   public static final float L2_STRENGTH_DEFAULT = 0.0F;
 
-  private final float learningRate;
   private final float initialAccumulatorValue;
   private final float l1Strength;
   private final float l2Strength;
   private Variable<TInt64> globalStep;
 
   /**
-   * Creates an AdaGradDA Optimizer
+   * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #INITIAL_ACCUMULATOR_DEFAULT} for the
+   * initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for the l1Strength, and {@link
+   * #L2_STRENGTH_DEFAULT} for the l2Strength.
    *
    * @param graph the TensorFlow Graph
    */
@@ -74,7 +78,9 @@ public AdaGradDA(Graph graph) {
   }
 
   /**
-   * Creates an AdaGradDA Optimizer
+   * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for
+   * the l1Strength, and {@link #L2_STRENGTH_DEFAULT} for the l2Strength.
    *
    * @param graph the TensorFlow Graph
    * @param learningRate the learning rate
@@ -85,7 +91,25 @@ public AdaGradDA(Graph graph, float learningRate) {
   }
 
   /**
-   * Creates an AdaGradDA Optimizer
+   * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for
+   * the l1Strength, and {@link #L2_STRENGTH_DEFAULT} for the l2Strength.
+   *
+   * @param graph the TensorFlow Graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public AdaGradDA(Graph graph, Operand<TFloat32> learningRateOperand) {
+    this(
+        graph,
+        learningRateOperand,
+        INITIAL_ACCUMULATOR_DEFAULT,
+        L1_STRENGTH_DEFAULT,
+        L2_STRENGTH_DEFAULT);
+  }
+
+  /**
+   * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
    *
    * @param graph the TensorFlow Graph
    * @param learningRate the learning rate
@@ -101,31 +125,37 @@ public AdaGradDA(
       float initialAccumulatorValue,
       float l1Strength,
       float l2Strength) {
-    super(graph);
-    if (initialAccumulatorValue <= 0F) {
-      throw new IllegalArgumentException(
-          String.format(
-              "initialAccumulatorValue must be greater than zero: %f", initialAccumulatorValue));
-    }
-    if (l1Strength < 0F) {
-      throw new IllegalArgumentException(
-          String.format("l1Strength must not be negative: %f", l1Strength));
-    }
-    if (l2Strength < 0F) {
-      throw new IllegalArgumentException(
-          String.format("l2Strength must not be negative: %f", l2Strength));
-    }
-    this.learningRate = learningRate;
-    this.initialAccumulatorValue = initialAccumulatorValue;
-    this.l1Strength = l1Strength;
-    this.l2Strength = l2Strength;
+    this(graph, null, learningRate, initialAccumulatorValue, l1Strength, l2Strength);
   }
 
   /**
-   * Creates an AdaGradDA Optimizer
+   * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
+   *
+   * @param graph the TensorFlow Graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param initialAccumulatorValue Starting value for the accumulators, must be greater than zero.
+   * @param l1Strength l1 regularization strength, must be greater than or equal to zero.
+   * @param l2Strength l2 regularization strength, must be greater than or equal to zero.
+   * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is not greater than zero,
+   *     or l1Strength or l2Strength is less than zero
+   */
+  public AdaGradDA(
+      Graph graph,
+      Operand<TFloat32> learningRateOperand,
+      float initialAccumulatorValue,
+      float l1Strength,
+      float l2Strength) {
+    this(graph, null, learningRateOperand, initialAccumulatorValue, l1Strength, l2Strength);
+  }
+
+  /**
+   * Creates an AdaGradDA Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the
+   * initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for the l1Strength, and {@link
+   * #L2_STRENGTH_DEFAULT} for the l2Strength.
    *
    * @param graph the TensorFlow Graph
-   * @param name the name for this Optimizer (defaults to 'adagrad-da')
+   * @param name the name for this Optimizer.
    * @param learningRate the learning rate
    */
   public AdaGradDA(Graph graph, String name, float learningRate) {
@@ -138,11 +168,31 @@ public AdaGradDA(Graph graph, String name, float learningRate) {
         L2_STRENGTH_DEFAULT);
   }
 
+  /**
+   * Creates an AdaGradDA Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the
+   * initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for the l1Strength, and {@link
+   * #L2_STRENGTH_DEFAULT} for the l2Strength.
+   *
+   * @param graph the TensorFlow Graph
+   * @param name the name for this Optimizer.
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public AdaGradDA(Graph graph, String name, Operand<TFloat32> learningRateOperand) {
+    this(
+        graph,
+        name,
+        learningRateOperand,
+        INITIAL_ACCUMULATOR_DEFAULT,
+        L1_STRENGTH_DEFAULT,
+        L2_STRENGTH_DEFAULT);
+  }
+
   /**
    * Creates an AdaGradDA Optimizer
    *
    * @param graph the TensorFlow Graph
-   * @param name the name for this Optimizer (defaults to 'adagrad-da')
+   * @param name the name for this Optimizer.
    * @param learningRate the learning rate
    * @param initialAccumulatorValue Starting value for the accumulators, must be positive
    * @param l1Strength l1 regularization strength, must be greater than or equal to zero.
@@ -157,7 +207,46 @@ public AdaGradDA(
       float initialAccumulatorValue,
       float l1Strength,
       float l2Strength) {
-    super(graph, name);
+    super(graph, name, learningRate);
+    if (initialAccumulatorValue <= 0F) {
+      throw new IllegalArgumentException(
+          String.format(
+              "initialAccumulatorValue must be greater than zero: %f", initialAccumulatorValue));
+    }
+    if (l1Strength < 0F) {
+      throw new IllegalArgumentException(
+          String.format("l1Strength must not be negative: %f", l1Strength));
+    }
+    if (l2Strength < 0F) {
+      throw new IllegalArgumentException(
+          String.format("l2Strength must not be negative: %f", l2Strength));
+    }
+    this.initialAccumulatorValue = initialAccumulatorValue;
+    this.l1Strength = l1Strength;
+    this.l2Strength = l2Strength;
+  }
+
+  /**
+   * Creates an AdaGradDA Optimizer
+   *
+   * @param graph the TensorFlow Graph
+   * @param name the name for this Optimizer.
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param initialAccumulatorValue Starting value for the accumulators, must be positive
+   * @param l1Strength l1 regularization strength, must be greater than or equal to zero.
+   * @param l2Strength l2 regularization strength, must be greater than or equal to zero.
+   * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is not greater than zero,
+   *     or l1Strength or l2Strength is less than zero
+   */
+  public AdaGradDA(
+      Graph graph,
+      String name,
+      Operand<TFloat32> learningRateOperand,
+      float initialAccumulatorValue,
+      float l1Strength,
+      float l2Strength) {
+    super(graph, name, learningRateOperand);
     if (initialAccumulatorValue <= 0F) {
       throw new IllegalArgumentException(
           String.format(
@@ -171,7 +260,6 @@ public AdaGradDA(
       throw new IllegalArgumentException(
           String.format("l2Strength must not be negative: %f", l2Strength));
     }
-    this.learningRate = learningRate;
     this.initialAccumulatorValue = initialAccumulatorValue;
     this.l1Strength = l1Strength;
     this.l2Strength = l2Strength;
@@ -218,7 +306,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
         gradSlot,
         gradSquaredSlot,
         gradient,
-        tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()),
+        tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()),
         tf.dtypes.cast(tf.constant(l1Strength), gradient.dataType()),
         tf.dtypes.cast(tf.constant(l2Strength), gradient.dataType()),
         globalStep);
@@ -259,6 +347,6 @@ public String toString() {
   /** {@inheritDoc} */
   @Override
   public String getOptimizerName() {
-    return "adagrad-da";
+    return DEFAULT_NAME;
   }
 }
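
AdaGradDA is restructured the same way, with the full-parameter constructors holding the checks on
initialAccumulatorValue and the l1/l2 strengths. A hedged sketch of the Operand form, under the
same assumed setup as the examples above:

// initialAccumulatorValue must be > 0; l1Strength and l2Strength must be >= 0.
AdaGradDA adagradDa =
    new AdaGradDA(graph, "MyAdagradDA", tf.constant(0.001f), 0.1f, 0.0f, 0.0f);
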
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java
index 8f620678781..9e4f41f1039 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java
@@ -48,6 +48,8 @@
 @Operator
 public class Adam extends Optimizer {
 
+  public static final String DEFAULT_NAME = "Adam";
+
   public static final String FIRST_MOMENT = "m";
   public static final String SECOND_MOMENT = "v";
 
@@ -56,15 +58,12 @@ public class Adam extends Optimizer {
   public static final float BETA_ONE_DEFAULT = 0.9f;
   public static final float BETA_TWO_DEFAULT = 0.999f;
 
-  private final float learningRate;
-
   private final float betaOne;
 
   private final float betaTwo;
 
   private final float epsilon;
 
-  private Constant<TFloat32> learningRateConst;
   private Constant<TFloat32> epsilonConst;
   private Constant<TFloat32> betaOneConst;
   private Constant<TFloat32> betaTwoConst;
@@ -72,7 +71,9 @@ public class Adam extends Optimizer {
   private Variable<TFloat32> betaTwoPower;
 
   /**
-   * Creates an Adam optimizer
+   * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value,
+   * {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon.
    *
    * @param graph the TensorFlow graph
    */
@@ -81,7 +82,9 @@ public Adam(Graph graph) {
   }
 
   /**
-   * Creates an Adam optimizer
+   * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and
+   * {@link #EPSILON_DEFAULT} for the epsilon.
    *
    * @param graph the TensorFlow graph
    * @param learningRate the learning rate
@@ -89,53 +92,121 @@ public Adam(Graph graph) {
   public Adam(Graph graph, float learningRate) {
     this(graph, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
   }
+  /**
+   * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and
+   * {@link #EPSILON_DEFAULT} for the epsilon.
+   *
+   * @param graph the TensorFlow graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public Adam(Graph graph, Operand<TFloat32> learningRateOperand) {
+    this(graph, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
+  }
 
   /**
-   * Creates an Adam optimizer
+   * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
    *
    * @param graph the TensorFlow graph
    * @param learningRate the learning rate
-   * @param betaOne The exponential decay rate for the 1st moment estimates. Defaults to 0.9.
-   * @param betaTwo The exponential decay rate for the 2nd moment estimates. Defaults to 0.999.
+   * @param betaOne The exponential decay rate for the 1st moment estimates.
+   * @param betaTwo The exponential decay rate for the 2nd moment estimates.
    * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the
    *     Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm
-   *     1 of the paper. Defaults to 1e-8.
+   *     1 of the paper.
    */
   public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) {
-    super(graph);
-    this.learningRate = learningRate;
-    this.betaOne = betaOne;
-    this.betaTwo = betaTwo;
-    this.epsilon = epsilon;
+    this(graph, null, learningRate, betaOne, betaTwo, epsilon);
   }
 
   /**
-   * Creates an Adam optimizer
+   * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
    *
    * @param graph the TensorFlow graph
-   * @param name the Optimizer name, defaults to "Adam"
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param betaOne The exponential decay rate for the 1st moment estimates.
+   * @param betaTwo The exponential decay rate for the 2nd moment estimates.
+   * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the
+   *     Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm
+   *     1 of the paper.
+   */
+  public Adam(
+      Graph graph,
+      Operand<TFloat32> learningRateOperand,
+      float betaOne,
+      float betaTwo,
+      float epsilon) {
+    this(graph, null, learningRateOperand, betaOne, betaTwo, epsilon);
+  }
+
+  /**
+   * Creates an Adam optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link
+   * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon.
+   *
+   * @param graph the TensorFlow graph
+   * @param name the Optimizer name.
    * @param learningRate the learning rate
    */
   public Adam(Graph graph, String name, float learningRate) {
     this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
   }
 
+  /**
+   * Creates an Adam optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link
+   * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon.
+   *
+   * @param graph the TensorFlow graph
+   * @param name the Optimizer name.
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public Adam(Graph graph, String name, Operand<TFloat32> learningRateOperand) {
+    this(graph, name, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
+  }
+
   /**
    * Creates an Adam optimizer
    *
    * @param graph the TensorFlow graph
-   * @param name the Optimizer name, defaults to "Adam"
+   * @param name the Optimizer name.
    * @param learningRate the learning rate
-   * @param betaOne The exponential decay rate for the 1st moment estimates. Defaults to 0.9.
-   * @param betaTwo The exponential decay rate for the 2nd moment estimates. Defaults to 0.999.
+   * @param betaOne The exponential decay rate for the 1st moment estimates.
+   * @param betaTwo The exponential decay rate for the 2nd moment estimates.
    * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the
    *     Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm
-   *     1 of the paper. Defaults to 1e-8.
+   *     1 of the paper.
    */
   public Adam(
       Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) {
-    super(graph, name);
-    this.learningRate = learningRate;
+    super(graph, name, learningRate);
+    this.betaOne = betaOne;
+    this.betaTwo = betaTwo;
+    this.epsilon = epsilon;
+  }
+
+  /**
+   * Creates an Adam optimizer
+   *
+   * @param graph the TensorFlow graph
+   * @param name the Optimizer name.
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param betaOne The exponential decay rate for the 1st moment estimates.
+   * @param betaTwo The exponential decay rate for the 2nd moment estimates.
+   * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the
+   *     Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm
+   *     1 of the paper.
+   */
+  public Adam(
+      Graph graph,
+      String name,
+      Operand<TFloat32> learningRateOperand,
+      float betaOne,
+      float betaTwo,
+      float epsilon) {
+    super(graph, name, learningRateOperand);
     this.betaOne = betaOne;
     this.betaTwo = betaTwo;
     this.epsilon = epsilon;
@@ -202,7 +273,6 @@ protected void createSlots(List<Output<? extends TType>> variables) {
   protected Optional<Op> prepare(String scopeName) {
     betaOneConst = tf.constant(betaOne);
     betaTwoConst = tf.constant(betaTwo);
-    learningRateConst = tf.constant(learningRate);
     epsilonConst = tf.constant(epsilon);
     return Optional.empty();
   }
@@ -233,7 +303,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
         secondMomentSlot,
         tf.dtypes.cast(betaOnePower, gradient.dataType()),
         tf.dtypes.cast(betaTwoPower, gradient.dataType()),
-        tf.dtypes.cast(learningRateConst, gradient.dataType()),
+        tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()),
         tf.dtypes.cast(betaOneConst, gradient.dataType()),
         tf.dtypes.cast(betaTwoConst, gradient.dataType()),
         tf.dtypes.cast(epsilonConst, gradient.dataType()),
@@ -274,6 +344,6 @@ public String toString() {
   /** {@inheritDoc} */
   @Override
   public String getOptimizerName() {
-    return "Adam";
+    return DEFAULT_NAME;
   }
 }
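
Since the learning rate is now an Operand, it no longer has to be a literal constant; it can be the
output of other ops in the same graph. A sketch under the same assumed setup as above; halving the
base rate is only an illustration of a derived rate, not something this patch prescribes:

// Derive the effective learning rate from other graph ops, then hand it to Adam.
Operand<TFloat32> baseRate = tf.constant(1e-3f);
Operand<TFloat32> halvedRate = tf.math.div(baseRate, tf.constant(2.0f));
Adam adam = new Adam(graph, "Adam", halvedRate, 0.9f, 0.999f, 1e-7f);
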
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java
index 335d83cedfa..e33775db961 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java
@@ -25,6 +25,7 @@
  */
 public class Adamax extends Optimizer {
 
+  public static final String DEFAULT_NAME = "Adamax";
   public static final String FIRST_MOMENT = "m";
   public static final String SECOND_MOMENT = "v";
 
@@ -33,19 +34,20 @@ public class Adamax extends Optimizer {
   public static final float BETA_ONE_DEFAULT = 0.9f;
   public static final float BETA_TWO_DEFAULT = 0.999f;
 
-  private float learningRate;
   private final float betaOne;
   private final float betaTwo;
   private final float epsilon;
 
-  private Constant<TFloat32> learningRateConst;
   private Constant<TFloat32> epsilonConst;
   private Constant<TFloat32> betaOneConst;
   private Constant<TFloat32> betaTwoConst;
   private Variable<TFloat32> betaOnePower;
 
   /**
-   * Creates an Optimizer that implements the Adamax algorithm.
+   * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the
+   * Optimizer name, {@link #LEARNING_RATE_DEFAULT} for the learning rate, {@link #BETA_ONE_DEFAULT}
+   * for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link
+   * #EPSILON_DEFAULT} for the epsilon.
    *
    * @param graph the TensorFlow graph
    */
@@ -54,17 +56,21 @@ public Adamax(Graph graph) {
   }
 
   /**
-   * Creates an Optimizer that implements the Adamax algorithm.
+   * Creates an Optimizer that implements the Adamax algorithm, using {@link #LEARNING_RATE_DEFAULT}
+   * for the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link
+   * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon.
    *
    * @param graph the TensorFlow graph
-   * @param name name for the operations Created when applying gradients. Defaults to "Adamax".
+   * @param name name for the operations created when applying gradients.
    */
   public Adamax(Graph graph, String name) {
     this(graph, name, LEARNING_RATE_DEFAULT, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
   }
 
   /**
-   * Creates an Optimizer that implements the Adamax algorithm.
+   * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the
+   * Optimizer name, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for
+   * the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon.
    *
    * @param graph the TensorFlow graph
    * @param learningRate The learning rate.
@@ -74,18 +80,48 @@ public Adamax(Graph graph, float learningRate) {
   }
 
   /**
-   * Creates an Optimizer that implements the Adamax algorithm.
+   * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the
+   * Optimizer name, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for
+   * the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon.
+   *
+   * @param graph the TensorFlow graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public Adamax(Graph graph, Operand<TFloat32> learningRateOperand) {
+    this(graph, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
+  }
+
+  /**
+   * Creates an Optimizer that implements the Adamax algorithm, using {@link #BETA_ONE_DEFAULT} for
+   * the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link
+   * #EPSILON_DEFAULT} for the epsilon.
    *
    * @param graph the TensorFlow graph
-   * @param name name for the operations Created when applying gradients. Defaults to "Adamax".
+   * @param name name for the operations created when applying gradients.
    * @param learningRate The learning rate.
    */
   public Adamax(Graph graph, String name, float learningRate) {
     this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
   }
+  /**
+   * Creates an Optimizer that implements the Adamax algorithm, using {@link #BETA_ONE_DEFAULT} for
+   * the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link
+   * #EPSILON_DEFAULT} for the epsilon.
+   *
+   * @param graph the TensorFlow graph
+   * @param name name for the operations created when applying gradients.
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public Adamax(Graph graph, String name, Operand<TFloat32> learningRateOperand) {
+    this(graph, name, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
+  }
 
   /**
-   * Creates an Optimizer that implements the Adamax algorithm.
+   * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the
+   * Optimizer name, with the learning rate, betaOne, betaTwo, and epsilon values supplied
+   * explicitly.
    *
    * @param graph the TensorFlow graph
    * @param learningRate The learning rate.
@@ -94,8 +130,42 @@ public Adamax(Graph graph, String name, float learningRate) {
    * @param epsilon A small constant for numerical stability.
    */
   public Adamax(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) {
-    super(graph);
-    this.learningRate = learningRate;
+    this(graph, null, learningRate, betaOne, betaTwo, epsilon);
+  }
+  /**
+   * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the
+   * Optimizer name, with the learning rate Operand, betaOne, betaTwo, and epsilon values supplied
+   * explicitly.
+   *
+   * @param graph the TensorFlow graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param betaOne The exponential decay rate for the 1st moment estimates.
+   * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm.
+   * @param epsilon A small constant for numerical stability.
+   */
+  public Adamax(
+      Graph graph,
+      Operand<TFloat32> learningRateOperand,
+      float betaOne,
+      float betaTwo,
+      float epsilon) {
+    this(graph, null, learningRateOperand, betaOne, betaTwo, epsilon);
+  }
+
+  /**
+   * Creates an Optimizer that implements the Adamax algorithm.
+   *
+   * @param graph the TensorFlow graph
+   * @param name name for the operations created when applying gradients.
+   * @param learningRate The learning rate.
+   * @param betaOne The exponential decay rate for the 1st moment estimates.
+   * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm.
+   * @param epsilon A small constant for numerical stability.
+   */
+  public Adamax(
+      Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) {
+    super(graph, name, learningRate);
     this.betaOne = betaOne;
     this.betaTwo = betaTwo;
     this.epsilon = epsilon;
@@ -105,16 +175,21 @@ public Adamax(Graph graph, float learningRate, float betaOne, float betaTwo, flo
    * Creates an Optimizer that implements the Adamax algorithm.
    *
    * @param graph the TensorFlow graph
-   * @param name name for the operations Created when applying gradients. Defaults to "Adamax".
-   * @param learningRate The learning rate.
+   * @param name name for the operations created when applying gradients.
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
    * @param betaOne The exponential decay rate for the 1st moment estimates.
    * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm.
    * @param epsilon A small constant for numerical stability.
    */
   public Adamax(
-      Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) {
-    super(graph, name);
-    this.learningRate = learningRate;
+      Graph graph,
+      String name,
+      Operand<TFloat32> learningRateOperand,
+      float betaOne,
+      float betaTwo,
+      float epsilon) {
+    super(graph, name, learningRateOperand);
     this.betaOne = betaOne;
     this.betaTwo = betaTwo;
     this.epsilon = epsilon;
@@ -125,7 +200,6 @@ public Adamax(
   protected Optional<Op> prepare(String scopeName) {
     betaOneConst = tf.constant(betaOne);
     betaTwoConst = tf.constant(betaTwo);
-    learningRateConst = tf.constant(learningRate);
     epsilonConst = tf.constant(epsilon);
 
     return Optional.empty();
@@ -168,7 +242,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
         firstMomentSlot,
         secondMomentSlot,
         tf.dtypes.cast(betaOnePower, gradient.dataType()),
-        tf.dtypes.cast(learningRateConst, gradient.dataType()),
+        tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()),
         tf.dtypes.cast(betaOneConst, gradient.dataType()),
         tf.dtypes.cast(betaTwoConst, gradient.dataType()),
         tf.dtypes.cast(epsilonConst, gradient.dataType()),
@@ -185,6 +259,6 @@ protected Op finish(List<Op> updateOperations, String name) {
   /** {@inheritDoc} */
   @Override
   public String getOptimizerName() {
-    return "Adamax";
+    return DEFAULT_NAME;
   }
 }
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java
index 13f68e4bbef..35eeb7dc225 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java
@@ -6,6 +6,7 @@
 import org.tensorflow.op.Op;
 import org.tensorflow.op.core.Variable;
 import org.tensorflow.op.train.ApplyFtrl;
+import org.tensorflow.types.TFloat32;
 import org.tensorflow.types.family.TType;
 
 import java.util.List;
@@ -20,6 +21,8 @@
  */
 public class Ftrl extends Optimizer {
 
+  public static final String DEFAULT_NAME = "Ftrl";
+
   public static final String ACCUMULATOR = "gradient_accumulator";
   public static final String LINEAR_ACCUMULATOR = "linear_accumulator";
 
@@ -30,7 +33,6 @@ public class Ftrl extends Optimizer {
   public static final float L2STRENGTH_DEFAULT = 0.0f;
   public static final float L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT = 0.0f;
 
-  private float learningRate;
   private final float learningRatePower;
   private final float initialAccumulatorValue;
   private final float l1RegularizationStrength;
@@ -38,7 +40,12 @@ public class Ftrl extends Optimizer {
   private final float l2ShrinkageRegularizationStrength;
 
   /**
-   * Creates a Ftrl Optimizer
+   * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #LEARNING_RATE_POWER_DEFAULT} for the
+   * learningRatePower, {@link #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue,
+   * {@link #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength
+   * and {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the
+   * l2ShrinkageRegularizationStrength.
    *
    * @param graph the TensorFlow Graph
    */
@@ -54,7 +61,12 @@ public Ftrl(Graph graph) {
   }
 
   /**
-   * Creates a Ftrl Optimizer
+   * Creates an Ftrl Optimizer using {@link #LEARNING_RATE_DEFAULT} for the learning rate, {@link
+   * #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower, {@link
+   * #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link
+   * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and
+   * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the
+   * l2ShrinkageRegularizationStrength.
    *
    * @param graph the TensorFlow Graph
    * @param name the name of this Optimizer
@@ -72,7 +84,12 @@ public Ftrl(Graph graph, String name) {
   }
 
   /**
-   * Creates a Ftrl Optimizer
+   * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower, {@link
+   * #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link
+   * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and
+   * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the
+   * l2ShrinkageRegularizationStrength.
    *
    * @param graph the TensorFlow Graph
    * @param learningRate the learning rate
@@ -89,7 +106,34 @@ public Ftrl(Graph graph, float learningRate) {
   }
 
   /**
-   * Creates a Ftrl Optimizer
+   * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower, {@link
+   * #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link
+   * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and
+   * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the
+   * l2ShrinkageRegularizationStrength.
+   *
+   * @param graph the TensorFlow Graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public Ftrl(Graph graph, Operand<TFloat32> learningRateOperand) {
+    this(
+        graph,
+        learningRateOperand,
+        LEARNING_RATE_POWER_DEFAULT,
+        INITIAL_ACCUMULATOR_VALUE_DEFAULT,
+        L1STRENGTH_DEFAULT,
+        L2STRENGTH_DEFAULT,
+        L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT);
+  }
+
+  /**
+   * Creates an Ftrl Optimizer using {@link #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower,
+   * {@link #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link
+   * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and
+   * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the
+   * l2ShrinkageRegularizationStrength.
    *
    * @param graph the TensorFlow Graph
    * @param name the name of this Optimizer
@@ -107,10 +151,110 @@ public Ftrl(Graph graph, String name, float learningRate) {
         L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT);
   }
 
+  /**
+   * Creates an Ftrl Optimizer using {@link #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower,
+   * {@link #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link
+   * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and
+   * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the
+   * l2ShrinkageRegularizationStrength.
+   *
+   * @param graph the TensorFlow Graph
+   * @param name the name of this Optimizer
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public Ftrl(Graph graph, String name, Operand<TFloat32> learningRateOperand) {
+    this(
+        graph,
+        name,
+        learningRateOperand,
+        LEARNING_RATE_POWER_DEFAULT,
+        INITIAL_ACCUMULATOR_VALUE_DEFAULT,
+        L1STRENGTH_DEFAULT,
+        L2STRENGTH_DEFAULT,
+        L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT);
+  }
+
+  /**
+   * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
+   *
+   * @param graph the TensorFlow Graph
+   * @param learningRate the learning rate
+   * @param learningRatePower Controls how the learning rate decreases during training. Use zero for
+   *     a fixed learning rate.
+   * @param initialAccumulatorValue The starting value for accumulators. Only zero or positive
+   *     values are allowed.
+   * @param l1Strength the L1 Regularization strength, must be greater than or equal to zero.
+   * @param l2Strength the L2 Regularization strength, must be greater than or equal to zero.
+   * @param l2ShrinkageRegularizationStrength This differs from L2 above in that the L2 above is a
+   *     stabilization penalty, whereas this L2 shrinkage is a magnitude penalty; must be greater
+   *     than or equal to zero.
+   * @throws java.lang.IllegalArgumentException if the initialAccumulatorValue,
+   *     l1RegularizationStrength, l2RegularizationStrength, or l2ShrinkageRegularizationStrength
+   *     are less than 0.0, or learningRatePower is greater than 0.0.
+   */
+  public Ftrl(
+      Graph graph,
+      float learningRate,
+      float learningRatePower,
+      float initialAccumulatorValue,
+      float l1Strength,
+      float l2Strength,
+      float l2ShrinkageRegularizationStrength) {
+    this(
+        graph,
+        null,
+        learningRate,
+        learningRatePower,
+        initialAccumulatorValue,
+        l1Strength,
+        l2Strength,
+        l2ShrinkageRegularizationStrength);
+  }
+
+  /**
+   * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
+   *
+   * @param graph the TensorFlow Graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param learningRatePower Controls how the learning rate decreases during training. Use zero for
+   *     a fixed learning rate.
+   * @param initialAccumulatorValue The starting value for accumulators. Only zero or positive
+   *     values are allowed.
+   * @param l1Strength the L1 Regularization strength, must be greater than or equal to zero.
+   * @param l2Strength the L2 Regularization strength, must be greater than or equal to zero.
+   * @param l2ShrinkageRegularizationStrength This differs from L2 above in that the L2 above is a
+   *     stabilization penalty, whereas this L2 shrinkage is a magnitude penalty; must be greater
+   *     than or equal to zero.
+   * @throws java.lang.IllegalArgumentException if the initialAccumulatorValue,
+   *     l1RegularizationStrength, l2RegularizationStrength, or l2ShrinkageRegularizationStrength
+   *     are less than 0.0, or learningRatePower is greater than 0.0.
+   */
+  public Ftrl(
+      Graph graph,
+      Operand<TFloat32> learningRateOperand,
+      float learningRatePower,
+      float initialAccumulatorValue,
+      float l1Strength,
+      float l2Strength,
+      float l2ShrinkageRegularizationStrength) {
+    this(
+        graph,
+        null,
+        learningRateOperand,
+        learningRatePower,
+        initialAccumulatorValue,
+        l1Strength,
+        l2Strength,
+        l2ShrinkageRegularizationStrength);
+  }
+
   /**
    * Creates a Ftrl Optimizer
    *
    * @param graph the TensorFlow Graph
+   * @param name the name of this Optimizer
    * @param learningRate the learning rate
    * @param learningRatePower Controls how the learning rate decreases during training. Use zero for
    *     a fixed learning rate.
@@ -127,14 +271,14 @@ public Ftrl(Graph graph, String name, float learningRate) {
    */
   public Ftrl(
       Graph graph,
+      String name,
       float learningRate,
       float learningRatePower,
       float initialAccumulatorValue,
       float l1Strength,
       float l2Strength,
       float l2ShrinkageRegularizationStrength) {
-    super(graph);
-    this.learningRate = learningRate;
+    super(graph, name, learningRate);
     this.learningRatePower = learningRatePower;
     this.initialAccumulatorValue = initialAccumulatorValue;
     this.l1RegularizationStrength = l1Strength;
@@ -148,7 +292,8 @@ public Ftrl(
    *
    * @param graph the TensorFlow Graph
    * @param name the name of this Optimizer
-   * @param learningRate the learning rate
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
    * @param learningRatePower Controls how the learning rate decreases during training. Use zero for
    *     a fixed learning rate.
    * @param initialAccumulatorValue The starting value for accumulators. Only zero or positive
@@ -165,14 +310,13 @@ public Ftrl(
   public Ftrl(
       Graph graph,
       String name,
-      float learningRate,
+      Operand<TFloat32> learningRateOperand,
       float learningRatePower,
       float initialAccumulatorValue,
       float l1Strength,
       float l2Strength,
       float l2ShrinkageRegularizationStrength) {
-    super(graph, name);
-    this.learningRate = learningRate;
+    super(graph, name, learningRateOperand);
     this.learningRatePower = learningRatePower;
     this.initialAccumulatorValue = initialAccumulatorValue;
     this.l1RegularizationStrength = l1Strength;
@@ -181,7 +325,7 @@ public Ftrl(
     validateParams();
   }
 
-  /** Validates all the settings of the Frtl Optmizer */
+  /** Validates all the settings of the Ftrl Optimizer */
   private void validateParams() {
     if (this.initialAccumulatorValue < 0.0F) {
       throw new IllegalArgumentException(
@@ -242,24 +386,22 @@ private <T extends TType> void createFtrlSlot(Output<T> v) {
   protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
     Variable<T> accumSlot = getSlot(variable, ACCUMULATOR).get();
     Variable<T> linearSlot = getSlot(variable, LINEAR_ACCUMULATOR).get();
-    ApplyFtrl.Options options = ApplyFtrl.useLocking(true);
     return this.tf.train.applyFtrl(
         variable,
-        accumSlot, // accum
-        linearSlot, // linear
-        gradient, // gradient
-        tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), // lr
-        tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.dataType()), // l1
-        tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.dataType()), // l2
-        tf.dtypes.cast(
-            tf.constant(l2ShrinkageRegularizationStrength), gradient.dataType()), // l2Shrinkage
-        tf.dtypes.cast(tf.constant(learningRatePower), gradient.dataType()), // lrPower
-        options);
+        accumSlot,
+        linearSlot,
+        gradient,
+        tf.dtypes.cast(this.getLearningRateOperand(), gradient.dataType()),
+        tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.dataType()),
+        tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.dataType()),
+        tf.dtypes.cast(tf.constant(l2ShrinkageRegularizationStrength), gradient.dataType()),
+        tf.dtypes.cast(tf.constant(learningRatePower), gradient.dataType()),
+        ApplyFtrl.useLocking(true));
   }
 
   /** {@inheritDoc} */
   @Override
   public String getOptimizerName() {
-    return "Ftrl";
+    return DEFAULT_NAME;
   }
 }
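
The Ftrl changes add the name parameter to the full-parameter constructor and introduce an
Operand-based variant; validateParams() still enforces the ranges stated in the javadoc. A sketch
of the full Operand form, same assumed setup as above, with the constraints noted inline:

Ftrl ftrl =
    new Ftrl(
        graph,
        "MyFtrl",
        tf.constant(0.01f), // learning rate operand
        -0.5f,              // learningRatePower, must not be greater than 0 (0 = fixed rate)
        0.1f,               // initialAccumulatorValue, must be >= 0
        0.0f,               // l1Strength, must be >= 0
        0.0f,               // l2Strength, must be >= 0
        0.0f);              // l2ShrinkageRegularizationStrength, must be >= 0
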
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java
index e307855e636..efec399f40d 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java
@@ -16,8 +16,10 @@
 package org.tensorflow.framework.optimizers;
 
 import org.tensorflow.Graph;
+import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
+import org.tensorflow.types.TFloat32;
 import org.tensorflow.types.family.TType;
 
 /**
@@ -26,12 +28,12 @@
  */
 public class GradientDescent extends Optimizer {
 
+  public static final String DEFAULT_NAME = "GradientDescent";
   public static final float LEARNING_RATE_DEFAULT = 0.01f;
 
-  private final float learningRate;
-
   /**
-   * Creates a GradientDescent Optimizer
+   * Creates a GradientDescent Optimizer using {@link #DEFAULT_NAME} for the Optimizer name and
+   * {@link #LEARNING_RATE_DEFAULT} for the learning rate.
    *
    * @param graph the TensorFlow graph
    */
@@ -40,33 +42,54 @@ public GradientDescent(Graph graph) {
   }
 
   /**
-   * Creates a GradientDescent Optimizer
+   * Creates a GradientDescent Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
    *
    * @param graph the TensorFlow graph
-   * @param learningRate the learning rate, defaults to 0.01
+   * @param learningRate the learning rate.
    */
   public GradientDescent(Graph graph, float learningRate) {
-    super(graph);
-    this.learningRate = learningRate;
+    super(graph, null, learningRate);
+  }
+
+  /**
+   * Creates a GradientDescent Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
+   *
+   * @param graph the TensorFlow graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public GradientDescent(Graph graph, Operand<TFloat32> learningRateOperand) {
+    super(graph, null, learningRateOperand);
   }
 
   /**
    * Creates a GradientDescent Optimizer
    *
    * @param graph the TensorFlow graph
-   * @param name the name for this Optimizer, default is "GradientDescent"
-   * @param learningRate the learning rate, defaults to 0.01
+   * @param name the name for this Optimizer.
+   * @param learningRate the learning rate.
    */
   public GradientDescent(Graph graph, String name, float learningRate) {
-    super(graph, name);
-    this.learningRate = learningRate;
+    super(graph, name, learningRate);
+  }
+
+  /**
+   * Creates a GradientDescent Optimizer
+   *
+   * @param graph the TensorFlow graph
+   * @param name the name for this Optimizer.
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public GradientDescent(Graph graph, String name, Operand<TFloat32> learningRateOperand) {
+    super(graph, name, learningRateOperand);
   }
 
   /** {@inheritDoc} */
   @Override
   protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
     return tf.train.applyGradientDescent(
-        variable, tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), gradient);
+        variable, tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), gradient);
   }
 
   /** {@inheritDoc} */
@@ -78,6 +101,6 @@ public String toString() {
   /** {@inheritDoc} */
   @Override
   public String getOptimizerName() {
-    return "GradientDescent";
+    return DEFAULT_NAME;
   }
 }
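
A minimal usage sketch of the two constructor flavors added above (the optimizer names "sgd_float" and "sgd_operand" and the trivial mul operand are illustrative only; the graph setup mirrors the tests later in this change). With a float learning rate the optimizer owns a scalar placeholder and getFeedMap() supplies its value; with an Operand learning rate no placeholder is created and getFeedMap() returns null:

    import org.tensorflow.Graph;
    import org.tensorflow.framework.optimizers.GradientDescent;
    import org.tensorflow.op.Ops;

    public class GradientDescentUsageSketch {
      public static void main(String[] args) {
        try (Graph graph = new Graph()) {
          Ops tf = Ops.create(graph);

          // Float learning rate: a scalar placeholder is created internally and
          // getFeedMap() provides its value for each run.
          try (GradientDescent fixedLr = new GradientDescent(graph, "sgd_float", 0.01f)) {
            System.out.println(fixedLr.getLearningRate());    // 0.01
            System.out.println(fixedLr.getFeedMap() != null); // true
          }

          // Learning rate supplied as an Operand (for example a decay computation):
          // no placeholder is created, so getFeedMap() returns null.
          try (GradientDescent operandLr =
              new GradientDescent(
                  graph, "sgd_operand", tf.math.mul(tf.constant(0.01f), tf.constant(1.0f)))) {
            System.out.println(operandLr.getFeedMap() == null); // true
          }
        }
      }
    }
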
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java
index 111727d26fa..436b587c353 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java
@@ -21,6 +21,7 @@
 import org.tensorflow.op.Op;
 import org.tensorflow.op.core.Variable;
 import org.tensorflow.op.train.ApplyMomentum;
+import org.tensorflow.types.TFloat32;
 import org.tensorflow.types.family.TType;
 
 import java.util.List;
@@ -33,20 +34,21 @@
  */
 public class Momentum extends Optimizer {
 
+  public static final String DEFAULT_NAME = "Momentum";
   public static final float LEARNING_RATE_DEFAULT = 0.01F;
   public static final float MOMENTUM_DEFAULT = 0.0F;
   public static final boolean NESTEROV_DEFAULT = false;
 
   public static final String MOMENTUM = "momentum";
 
-  private final float learningRate;
-
   private final float momentum;
 
   private final boolean useNesterov;
 
   /**
-   * Creates a Momentum Optimizer
+   * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #MOMENTUM_DEFAULT} for the momentum, and
+   * {@link #NESTEROV_DEFAULT} for the Nesterov flag.
    *
    * @param graph the TensorFlow graph
    */
@@ -55,7 +57,8 @@ public Momentum(Graph graph) {
   }
 
   /**
-   * Creates a Momentum Optimizer
+   * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #MOMENTUM_DEFAULT} for the momentum, and {@link #NESTEROV_DEFAULT} for the Nesterov flag.
    *
    * @param graph the TensorFlow graph
    * @param learningRate the learning rate
@@ -65,31 +68,74 @@ public Momentum(Graph graph, float learningRate) {
   }
 
   /**
-   * Creates a Momentum Optimizer
+   * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #MOMENTUM_DEFAULT} for the momentum, and {@link #NESTEROV_DEFAULT} for the Nesterov flag.
+   *
+   * @param graph the TensorFlow graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public Momentum(Graph graph, Operand<TFloat32> learningRateOperand) {
+    this(graph, learningRateOperand, MOMENTUM_DEFAULT, NESTEROV_DEFAULT);
+  }
+
+  /**
+   * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name and {@link
+   * #NESTEROV_DEFAULT} for the Nesterov flag.
    *
    * @param graph the TensorFlow graph
    * @param learningRate the learning rate
    * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and
-   *     dampens oscillations, Must be greater than or equal to zero. Default is 0.
+   *     dampens oscillations. Must be greater than or equal to zero.
+   * @throws java.lang.IllegalArgumentException if momentum is less than zero.
    */
   public Momentum(Graph graph, float learningRate, float momentum) {
     this(graph, learningRate, momentum, NESTEROV_DEFAULT);
   }
 
   /**
-   * Creates a Momentum Optimizer
+   * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name and {@link
+   * #NESTEROV_DEFAULT} for the Nesterov flag.
+   *
+   * @param graph the TensorFlow graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and
+   *     dampens oscillations. Must be greater than or equal to zero.
+   * @throws java.lang.IllegalArgumentException if momentum is less than zero.
+   */
+  public Momentum(Graph graph, Operand<TFloat32> learningRateOperand, float momentum) {
+    this(graph, learningRateOperand, momentum, NESTEROV_DEFAULT);
+  }
+
+  /**
+   * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
    *
    * @param graph the TensorFlow graph
    * @param learningRate the learning rate
    * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and
-   *     dampens oscillations, Must be greater than or equal to zero. Default is 0.
-   * @param useNesterov Whether to apply Nesterov momentum. Defaults to false.
+   *     dampens oscillations. Must be greater than or equal to zero.
+   * @param useNesterov Whether to apply Nesterov momentum.
+   * @throws java.lang.IllegalArgumentException if momentum is less than zero.
    */
   public Momentum(Graph graph, float learningRate, float momentum, boolean useNesterov) {
-    super(graph);
-    this.learningRate = learningRate;
-    this.momentum = momentum;
-    this.useNesterov = useNesterov;
+    this(graph, null, learningRate, momentum, useNesterov);
+  }
+
+  /**
+   * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
+   *
+   * @param graph the TensorFlow graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and
+   *     dampens oscillations. Must be greater than or equal to zero.
+   * @param useNesterov Whether to apply Nesterov momentum.
+   * @throws java.lang.IllegalArgumentException if momentum is less than zero.
+   */
+  public Momentum(
+      Graph graph, Operand<TFloat32> learningRateOperand, float momentum, boolean useNesterov) {
+    this(graph, null, learningRateOperand, momentum, useNesterov);
   }
 
   /**
@@ -99,13 +145,40 @@ public Momentum(Graph graph, float learningRate, float momentum, boolean useNest
    * @param name the name for this Optimizer
    * @param learningRate the learning rate
    * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and
-   *     dampens oscillations, Must be greater than or equal to zero. Default is 0.
-   * @param useNesterov Whether to apply Nesterov momentum. Defaults to false.
+   *     dampens oscillations. Must be greater than or equal to zero.
+   * @param useNesterov Whether to apply Nesterov momentum.
+   * @throws java.lang.IllegalArgumentException if momentum is less than zero.
    */
   public Momentum(
       Graph graph, String name, float learningRate, float momentum, boolean useNesterov) {
-    super(graph, name);
-    this.learningRate = learningRate;
+    super(graph, name, learningRate);
+    if (momentum < 0)
+      throw new IllegalArgumentException("momentum must be greater than or equal to zero.");
+    this.momentum = momentum;
+    this.useNesterov = useNesterov;
+  }
+
+  /**
+   * Creates a Momentum Optimizer
+   *
+   * @param graph the TensorFlow graph
+   * @param name the name for this Optimizer
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and
+   *     dampens oscillations. Must be greater than or equal to zero.
+   * @param useNesterov Whether to apply Nesterov momentum.
+   * @throws java.lang.IllegalArgumentException if momentum is less than zero.
+   */
+  public Momentum(
+      Graph graph,
+      String name,
+      Operand<TFloat32> learningRateOperand,
+      float momentum,
+      boolean useNesterov) {
+    super(graph, name, learningRateOperand);
+    if (momentum < 0)
+      throw new IllegalArgumentException("momentum must be greater than or equal to zero.");
     this.momentum = momentum;
     this.useNesterov = useNesterov;
   }
@@ -136,7 +209,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
     return tf.train.applyMomentum(
         variable,
         slot,
-        tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()),
+        tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()),
         gradient,
         tf.dtypes.cast(tf.constant(momentum), gradient.dataType()),
         ApplyMomentum.useNesterov(useNesterov));
@@ -158,6 +231,6 @@ public String toString() {
   /** {@inheritDoc} */
   @Override
   public String getOptimizerName() {
-    return "Momentum";
+    return DEFAULT_NAME;
   }
 }
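
The momentum argument is now validated at construction; a small sketch of the failure mode, with literal values chosen only for illustration:

    import org.tensorflow.Graph;
    import org.tensorflow.framework.optimizers.Momentum;

    public class MomentumValidationSketch {
      public static void main(String[] args) {
        try (Graph graph = new Graph()) {
          try {
            // momentum must be >= 0; a negative value now fails fast.
            new Momentum(graph, 0.01f, -0.5f, false);
          } catch (IllegalArgumentException expected) {
            System.out.println(expected.getMessage());
          }
        }
      }
    }
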
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java
index 48e5135c952..202f5013e7d 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java
@@ -25,6 +25,7 @@
  */
 public class Nadam extends Optimizer {
 
+  public static final String DEFAULT_NAME = "Nadam";
   private static final float DECAY_BASE = 0.96f;
   private static final float DECAY = 0.004f;
   public static final float LEARNING_RATE_DEFAULT = 0.001f;
@@ -35,9 +36,6 @@ public class Nadam extends Optimizer {
   public static final String SECOND_MOMENT = "v";
   public static final String MOMENTUM = "momentum";
 
-  /** The learning rate. */
-  private final float learningRate;
-
   /** The exponential decay rate for the 1st moment estimates. */
   private final float betaOne;
 
@@ -47,7 +45,6 @@ public class Nadam extends Optimizer {
   /** A small constant for numerical stability. */
   private final float epsilon;
 
-  private Constant<TFloat32> learningRateConst;
   private Constant<TFloat32> epsilonConst;
   private Constant<TFloat32> betaOneConst;
   private Constant<TFloat32> betaTwoConst;
@@ -69,7 +66,9 @@ public class Nadam extends Optimizer {
   private Operand<TFloat32> vTPrimeDenominator;
 
   /**
-   * Creates a Nadam Optimizer
+   * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value,
+   * {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon.
    *
    * @param graph the TensorFlow graph
    */
@@ -78,59 +77,124 @@ public Nadam(Graph graph) {
   }
 
   /**
-   * Creates a Nadam Optimizer
+   * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and
+   * {@link #EPSILON_DEFAULT} for the epsilon.
    *
    * @param graph the TensorFlow graph
-   * @param learningRate the learning rate, defaults to 0.001
+   * @param learningRate the learning rate.
    */
   public Nadam(Graph graph, float learningRate) {
     this(graph, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
   }
 
   /**
-   * Creates a Nadam Optimizer
+   * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and
+   * {@link #EPSILON_DEFAULT} for the epsilon.
+   *
+   * @param graph the TensorFlow graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public Nadam(Graph graph, Operand<TFloat32> learningRateOperand) {
+    this(graph, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
+  }
+
+  /**
+   * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
    *
    * @param graph the TensorFlow graph
-   * @param learningRate the learning rate, defaults to 0.001
-   * @param betaOne The exponential decay rate for the 1st moment estimates. Default is 0.9.
-   * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. Default
-   *     is 0.999.
-   * @param epsilon A small constant for numerical stability. Default is 1e-8.
+   * @param learningRate the learning rate.
+   * @param betaOne The exponential decay rate for the 1st moment estimates.
+   * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm.
+   * @param epsilon A small constant for numerical stability.
    */
   public Nadam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) {
-    super(graph);
-    this.learningRate = learningRate;
-    this.betaOne = betaOne;
-    this.betaTwo = betaTwo;
-    this.epsilon = epsilon;
+    this(graph, null, learningRate, betaOne, betaTwo, epsilon);
   }
 
   /**
-   * Creates a Nadam Optimizer
+   * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
+   *
+   * @param graph the TensorFlow graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param betaOne The exponential decay rate for the 1st moment estimates.
+   * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm.
+   * @param epsilon A small constant for numerical stability.
+   */
+  public Nadam(
+      Graph graph,
+      Operand<TFloat32> learningRateOperand,
+      float betaOne,
+      float betaTwo,
+      float epsilon) {
+    this(graph, null, learningRateOperand, betaOne, betaTwo, epsilon);
+  }
+
+  /**
+   * Creates a Nadam Optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link
+   * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon.
    *
    * @param graph the TensorFlow graph
-   * @param name the name for this Optimizer, defaults to "Nadam"
-   * @param learningRate the learning rate, defaults to 0.001
+   * @param name the name for this Optimizer.
+   * @param learningRate the learning rate.
    */
   public Nadam(Graph graph, String name, float learningRate) {
     this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
   }
 
+  /**
+   * Creates a Nadam Optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link
+   * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon.
+   *
+   * @param graph the TensorFlow graph
+   * @param name the name for this Optimizer.
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public Nadam(Graph graph, String name, Operand<TFloat32> learningRateOperand) {
+    this(graph, name, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
+  }
+
   /**
    * Creates a Nadam Optimizer
    *
    * @param graph the TensorFlow graph
-   * @param name the name for this Optimizer, defaults to "Nadam"
-   * @param learningRate the learning rate, defaults to 0.001
-   * @param betaOne The exponential decay rate for the 1st moment estimates. Default is 0.9.
-   * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. Default
-   *     is 0.999.
-   * @param epsilon A small constant for numerical stability. Default is 1e-8.
+   * @param name the name for this Optimizer.
+   * @param learningRate the learning rate.
+   * @param betaOne The exponential decay rate for the 1st moment estimates.
+   * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm.
+   * @param epsilon A small constant for numerical stability.
    */
   public Nadam(
       Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) {
-    super(graph, name);
-    this.learningRate = learningRate;
+    super(graph, name, learningRate);
+    this.betaOne = betaOne;
+    this.betaTwo = betaTwo;
+    this.epsilon = epsilon;
+  }
+
+  /**
+   * Creates a Nadam Optimizer
+   *
+   * @param graph the TensorFlow graph
+   * @param name the name for this Optimizer.
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param betaOne The exponential decay rate for the 1st moment estimates.
+   * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm.
+   * @param epsilon A small constant for numerical stability.
+   */
+  public Nadam(
+      Graph graph,
+      String name,
+      Operand<TFloat32> learningRateOperand,
+      float betaOne,
+      float betaTwo,
+      float epsilon) {
+    super(graph, name, learningRateOperand);
     this.betaOne = betaOne;
     this.betaTwo = betaTwo;
     this.epsilon = epsilon;
@@ -180,7 +244,6 @@ protected Optional<Op> prepare(String scopeName) {
     Constant<TFloat32> one = tf.constant(1.0F);
     Constant<TFloat32> point5 = tf.constant(0.5F);
 
-    learningRateConst = tf.constant(learningRate);
     betaOneConst = tf.constant(betaOne);
     betaTwoConst = tf.constant(betaTwo);
     Constant<TInt64> localStepConst = tf.constant(this.iterations + 1);
@@ -271,7 +334,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
         tf.math.sub(
             variable,
             tf.math.div(
-                tf.math.mul(tf.dtypes.cast(learningRateConst, dType), m_t_bar),
+                tf.math.mul(tf.dtypes.cast(getLearningRateOperand(), dType), m_t_bar),
                 tf.math.add(tf.math.sqrt(vTPrime), tf.dtypes.cast(epsilonConst, dType))));
 
     return tf.assign(variable, varT, Assign.useLocking(true));
@@ -297,6 +360,6 @@ protected Op finish(List<Op> updateOperations, String name) {
   /** {@inheritDoc} */
   @Override
   public String getOptimizerName() {
-    return "Nadam";
+    return DEFAULT_NAME;
   }
 }
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java
index 933a54c7670..586fef28c1e 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java
@@ -15,25 +15,27 @@
  */
 package org.tensorflow.framework.optimizers;
 
-import org.tensorflow.Graph;
-import org.tensorflow.Operand;
-import org.tensorflow.Operation;
-import org.tensorflow.Output;
+import org.tensorflow.*;
+import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Op;
 import org.tensorflow.op.Ops;
 import org.tensorflow.op.Scope;
 import org.tensorflow.op.core.Assign;
 import org.tensorflow.op.core.NoOp;
+import org.tensorflow.op.core.Placeholder;
 import org.tensorflow.op.core.Variable;
+import org.tensorflow.types.TFloat32;
 import org.tensorflow.types.family.TType;
 
 import java.util.*;
 import java.util.stream.Collectors;
 
 /** Base class for gradient optimizers. */
-public abstract class Optimizer {
-
+public abstract class Optimizer implements AutoCloseable {
+  public static final String LEARNING_RATE = "learning_rate";
   public static final String VARIABLE_V2 = "VariableV2";
+  public static final float LEARNING_RATE_DEFAULT = 0.001f;
+
   /** Global state variables */
   // TODO make this be used.
   protected final List<Variable<?>> globals;
@@ -44,18 +46,25 @@ public abstract class Optimizer {
   /** Top level map key is the variable name, lower level map key is the slot name. */
   private final Map<String, Map<String, Variable<?>>> slots;
 
+  protected float learningRate;
+  protected Placeholder<TFloat32> learningRatePlaceholder = null;
+  private Tensor<TFloat32> learningRateTensor;
+  private Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap = null;
+  private Operand<TFloat32> learningRateOperand;
+
   /**
    * Builds an optimizer for the supplied graph.
    *
-   * <p>Uses the name from {@link Optimizer#getOptimizerName()} to name the operations.
-   *
    * @param graph The graph to optimize.
+   * @param name The base name for the operations, or null to use {@link #getOptimizerName()}.
+   * @param learningRate the learning rate.
    */
-  protected Optimizer(Graph graph) {
+  protected Optimizer(Graph graph, String name, float learningRate) {
     this.graph = graph;
-    this.tf = Ops.create(graph).withName(getOptimizerName());
+    this.tf = Ops.create(graph).withName(name == null ? getOptimizerName() : name);
     this.slots = new HashMap<>();
     this.globals = new ArrayList<>();
+    setLearningRate(learningRate);
   }
 
   /**
@@ -63,20 +72,14 @@ protected Optimizer(Graph graph) {
    *
    * @param graph The graph to optimize.
    * @param name The base name for the operations.
+   * @param learningRateOperand the learning rate Operand, used to calculate the learning rate.
    */
-  protected Optimizer(Graph graph, String name) {
+  protected Optimizer(Graph graph, String name, Operand<TFloat32> learningRateOperand) {
     this.graph = graph;
-    this.tf = Ops.create(graph).withName(name);
+    this.tf = Ops.create(graph).withName(name == null ? getOptimizerName() : name);
     this.slots = new HashMap<>();
     this.globals = new ArrayList<>();
-  }
-
-  /**
-   * Gets the Optimizer's Ops instance
-   * @return the Optimizer's Ops instance
-   */
-  public final Ops getTF() {
-    return tf;
+    setLearningRateOperand(learningRateOperand);
   }
 
   /**
@@ -229,7 +232,7 @@ protected <T extends TType> void createSlot(
   }
 
   /**
-   * Returns a No-op prepare.
+   * No-op prepare method.
    *
    * @param scopeName The scope name to use for any variable creations.
    */
@@ -238,7 +241,7 @@ protected Optional<Op> prepare(String scopeName) {
   }
 
   /**
-   * Performs a No-op slot creation method.
+   * No-op slot creation method.
    *
    * @param variables The variables to create slots for.
    */
@@ -280,12 +283,85 @@ protected Op finish(List<Op> updateOperations, String name) {
   }
 
   /**
-   * Get the Name of the optimizer.
+   * Gets the Name of the optimizer.
    *
    * @return The optimizer name.
    */
   public abstract String getOptimizerName();
 
+  /**
+   * Sets the learning rate
+   *
+   * @param newLearningRate the new learning rate
+   */
+  public final void setLearningRate(float newLearningRate) {
+    if (learningRatePlaceholder == null) {
+      learningRatePlaceholder =
+          tf.withSubScope(LEARNING_RATE)
+              .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar()));
+    }
+
+    if (learningRate != newLearningRate) {
+      if (learningRateTensor != null) learningRateTensor.close();
+      learningRate = newLearningRate;
+      learningRateTensor = TFloat32.scalarOf(learningRate);
+      feedMap = Collections.singletonMap(learningRatePlaceholder, learningRateTensor);
+    }
+  }
+
+  /**
+   * Sets the learning rate Operand. The learning rate operand is an operand that is used to
+   * calculate the learning rate.
+   *
+   * @param newLearningRateOperand the new learning rate operand.
+   */
+  public final void setLearningRateOperand(Operand<TFloat32> newLearningRateOperand) {
+    close(); // Clean up the learning rate tensor and feed map if they exist.
+    learningRateOperand = newLearningRateOperand;
+  }
+
+  /**
+   * Gets the learning rate
+   *
+   * @return the learning rate
+   */
+  public float getLearningRate() {
+    return learningRate;
+  }
+
+  /**
+   * Gets the learning rate Operand, used by subclasses in their graph operations. If a float
+   * learning rate has been set using {@link #setLearningRate}, then this will be the learning rate
+   * Placeholder, otherwise the learning rate operand is returned as passed to {@link
+   * #setLearningRateOperand}.
+   *
+   * @return the learning rate Operand
+   */
+  protected Operand<TFloat32> getLearningRateOperand() {
+    return learningRatePlaceholder == null ? learningRateOperand : learningRatePlaceholder;
+  }
+
+  /**
+   * Gets the Feed Map for the run methods to set the Placeholder value(s). Each entry in the Feed
+   * Map contains a Placeholder and a Tensor with its value.
+   *
+   * @return the current Feed Map for the run methods; this will be null if the learning rate
+   *     Operand has been set.
+   */
+  public Map<Operand<? extends TType>, Tensor<? extends TType>> getFeedMap() {
+    return feedMap;
+  }
+
+  /** {@inheritDoc} */
+  public void close() {
+    // close the learningRate Tensor if it exists.
+    if (learningRateTensor != null) {
+      learningRateTensor.close();
+      learningRateTensor = null;
+    }
+    feedMap = null;
+  }
+
   /** Optional attributes for {@link org.tensorflow.framework.optimizers.Optimizer} */
   public static class Options {
 
@@ -294,8 +370,6 @@ public static class Options {
     private Options() {}
 
     /**
-     * Sets the shared name
-     *
      * @param sharedName If non-empty, this variable is named in the given bucket with this
      *     shared_name. Otherwise, the node name is used instead.
      */
@@ -305,41 +379,20 @@ public Optimizer.Options sharedName(String sharedName) {
     }
   }
 
-  /**
-   * A class that holds a paired gradient and variable.
-   *
-   * @param <T> the data type for the gradient and variable
-   */
   public static class GradAndVar<T extends TType> {
 
     private final Output<T> gradient;
     private final Output<T> variable;
 
-    /**
-     * Creates a Gradient and Variable pair
-     *
-     * @param gradient the gradient
-     * @param variable the variable
-     */
     public GradAndVar(Output<T> gradient, Output<T> variable) {
       this.gradient = gradient;
       this.variable = variable;
     }
 
-    /**
-     * Gets the gradient
-     *
-     * @return the gradient
-     */
     public Output<T> getGradient() {
       return gradient;
     }
 
-    /**
-     * Gets the variable
-     *
-     * @return the variable
-     */
     public Output<T> getVariable() {
       return variable;
     }
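
For context, a sketch of how the feed map above is consumed when running a training op against a raw Session rather than the TestSession helper used by the tests below. The Runner#addTarget(Op) and Runner#feed(Operand, Tensor) overloads and the trainOp/optimizer parameter names are assumptions for illustration:

    import java.util.Map;
    import org.tensorflow.Operand;
    import org.tensorflow.Session;
    import org.tensorflow.Tensor;
    import org.tensorflow.framework.optimizers.Optimizer;
    import org.tensorflow.op.Op;
    import org.tensorflow.types.family.TType;

    public class FeedMapRunSketch {

      /** Runs one training step, feeding the learning-rate placeholder when one exists. */
      public static void step(Session session, Op trainOp, Optimizer optimizer) {
        Session.Runner runner = session.runner().addTarget(trainOp);
        Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap = optimizer.getFeedMap();
        if (feedMap != null) { // null when the learning rate was supplied as an Operand
          for (Map.Entry<Operand<? extends TType>, Tensor<? extends TType>> entry :
              feedMap.entrySet()) {
            runner.feed(entry.getKey(), entry.getValue());
          }
        }
        runner.run();
      }
    }

The AdaDelta and AdaGradDA test changes later in this diff exercise the same contract through TestSession, passing instance.getFeedMap() to session.run and calling setLearningRate between steps to decay the rate.
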
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java
index 8b71558e549..7a03bec849d 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java
@@ -20,6 +20,7 @@
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
 import org.tensorflow.op.core.Variable;
+import org.tensorflow.types.TFloat32;
 import org.tensorflow.types.family.TType;
 
 import java.util.List;
@@ -47,6 +48,7 @@
  */
 public class RMSProp extends Optimizer {
 
+  public static final String DEFAULT_NAME = "RMSProp";
   public static final float LEARNING_RATE_DEFAULT = 0.001f;
   public static final float DECAY_DEFAULT = 0.9f;
   public static final float MOMENTUM_DEFAULT = 0.0f;
@@ -56,14 +58,16 @@ public class RMSProp extends Optimizer {
   public static final String MG = "mg"; // mean gradient?
   public static final String MOMENTUM = "momentum";
 
-  private final float learningRate;
   private final float decay;
   private final float momentum;
   private final float epsilon;
   private final boolean centered;
 
   /**
-   * Creates an RMSPRrop Optimizer
+   * Creates an RMSProp Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #DECAY_DEFAULT} for the decay, {@link
+   * #MOMENTUM_DEFAULT} for the momentum, {@link #EPSILON_DEFAULT} for the epsilon value and {@link
+   * #CENTERED_DEFAULT} for the centered flag.
    *
    * @param graph the TensorFlow Graph
    */
@@ -78,7 +82,9 @@ public RMSProp(Graph graph) {
   }
 
   /**
-   * Creates an RMSPRrop Optimizer
+   * Creates an RMSProp Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #DECAY_DEFAULT} for the decay, {@link #MOMENTUM_DEFAULT} for the momentum, {@link
+   * #EPSILON_DEFAULT} for the epsilon value and {@link #CENTERED_DEFAULT} for the centered flag.
    *
    * @param graph the TensorFlow Graph
    * @param learningRate the learning rate
@@ -88,17 +94,36 @@ public RMSProp(Graph graph, float learningRate) {
   }
 
   /**
-   * Creates an RMSPRrop Optimizer
+   * Creates an RMSProp Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link
+   * #DECAY_DEFAULT} for the decay, {@link #MOMENTUM_DEFAULT} for the momentum, {@link
+   * #EPSILON_DEFAULT} for the epsilon value and {@link #CENTERED_DEFAULT} for the centered flag.
+   *
+   * @param graph the TensorFlow Graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public RMSProp(Graph graph, Operand<TFloat32> learningRateOperand) {
+    this(
+        graph,
+        learningRateOperand,
+        DECAY_DEFAULT,
+        MOMENTUM_DEFAULT,
+        EPSILON_DEFAULT,
+        CENTERED_DEFAULT);
+  }
+
+  /**
+   * Creates an RMSProp Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
    *
    * @param graph the TensorFlow Graph
    * @param learningRate the learning rate
-   * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9.
-   * @param momentum the acceleration factor, default is 0.
+   * @param decay Discounting factor for the history/coming gradient.
+   * @param momentum the acceleration factor.
    * @param epsilon A small constant for numerical stability
    * @param centered If <code>true</code>, gradients are normalized by the estimated variance of the
    *     gradient; if <code>false</code>>, by the uncentered second moment. Setting this to <code>
    *     true</code>> may help with training, but is slightly more expensive in terms of computation
-   *     and memory. Defaults to <code>false</code>.
+   *     and memory.
    */
   public RMSProp(
       Graph graph,
@@ -107,19 +132,40 @@ public RMSProp(
       float momentum,
       float epsilon,
       boolean centered) {
-    super(graph);
-    this.learningRate = learningRate;
-    this.decay = decay;
-    this.momentum = momentum;
-    this.epsilon = epsilon;
-    this.centered = centered;
+    this(graph, null, learningRate, decay, momentum, epsilon, centered);
   }
 
   /**
-   * Creates an RMSPRrop Optimizer
+   * Creates an RMSProp Optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
+   *
+   * @param graph the TensorFlow Graph
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param decay Discounting factor for the history/coming gradient.
+   * @param momentum the acceleration factor.
+   * @param epsilon A small constant for numerical stability
+   * @param centered If <code>true</code>, gradients are normalized by the estimated variance of the
+   *     gradient; if <code>false</code>, by the uncentered second moment. Setting this to <code>
+   *     true</code> may help with training, but is slightly more expensive in terms of computation
+   *     and memory.
+   */
+  public RMSProp(
+      Graph graph,
+      Operand<TFloat32> learningRateOperand,
+      float decay,
+      float momentum,
+      float epsilon,
+      boolean centered) {
+    this(graph, null, learningRateOperand, decay, momentum, epsilon, centered);
+  }
+
+  /**
+   * Creates an RMSProp Optimizer using {@link #DECAY_DEFAULT} for the decay, {@link
+   * #MOMENTUM_DEFAULT} for the momentum, {@link #EPSILON_DEFAULT} for the epsilon value and {@link
+   * #CENTERED_DEFAULT} for the centered flag.
    *
    * @param graph the TensorFlow Graph
-   * @param name the name of this Optimizer. Defaults to "RMSProp".
+   * @param name the name of this Optimizer.
    * @param learningRate the learning rate
    */
   public RMSProp(Graph graph, String name, float learningRate) {
@@ -133,19 +179,40 @@ public RMSProp(Graph graph, String name, float learningRate) {
         CENTERED_DEFAULT);
   }
 
+  /**
+   * Creates an RMSProp Optimizer using {@link #DECAY_DEFAULT} for the decay, {@link
+   * #MOMENTUM_DEFAULT} for the momentum, {@link #EPSILON_DEFAULT} for the epsilon value and {@link
+   * #CENTERED_DEFAULT} for the centered flag.
+   *
+   * @param graph the TensorFlow Graph
+   * @param name the name of this Optimizer.
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   */
+  public RMSProp(Graph graph, String name, Operand<TFloat32> learningRateOperand) {
+    this(
+        graph,
+        name,
+        learningRateOperand,
+        DECAY_DEFAULT,
+        MOMENTUM_DEFAULT,
+        EPSILON_DEFAULT,
+        CENTERED_DEFAULT);
+  }
+
   /**
    * Creates an RMSPRrop Optimizer
    *
    * @param graph the TensorFlow Graph
-   * @param name the name of this Optimizer. Defaults to "RMSProp".
+   * @param name the name of this Optimizer.
    * @param learningRate the learning rate
-   * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9.
-   * @param momentum The acceleration factor, default is 0.
+   * @param decay Discounting factor for the history/coming gradient.
+   * @param momentum The acceleration factor.
    * @param epsilon A small constant for numerical stability
    * @param centered If <code>true</code>, gradients are normalized by the estimated variance of the
    *     gradient; if <code>false</code>>, by the uncentered second moment. Setting this to <code>
    *     true</code>> may help with training, but is slightly more expensive in terms of computation
-   *     and memory. Defaults to <code>false</code>.
+   *     and memory.
    */
   public RMSProp(
       Graph graph,
@@ -155,8 +222,37 @@ public RMSProp(
       float momentum,
       float epsilon,
       boolean centered) {
-    super(graph, name);
-    this.learningRate = learningRate;
+    super(graph, name, learningRate);
+    this.decay = decay;
+    this.momentum = momentum;
+    this.epsilon = epsilon;
+    this.centered = centered;
+  }
+
+  /**
+   * Creates an RMSProp Optimizer
+   *
+   * @param graph the TensorFlow Graph
+   * @param name the name of this Optimizer.
+   * @param learningRateOperand the learning rate Operand, this is used to calculate the learning
+   *     rate.
+   * @param decay Discounting factor for the history/coming gradient.
+   * @param momentum The acceleration factor.
+   * @param epsilon A small constant for numerical stability
+   * @param centered If <code>true</code>, gradients are normalized by the estimated variance of the
+   *     gradient; if <code>false</code>, by the uncentered second moment. Setting this to <code>
+   *     true</code> may help with training, but is slightly more expensive in terms of computation
+   *     and memory.
+   */
+  public RMSProp(
+      Graph graph,
+      String name,
+      Operand<TFloat32> learningRateOperand,
+      float decay,
+      float momentum,
+      float epsilon,
+      boolean centered) {
+    super(graph, name, learningRateOperand);
     this.decay = decay;
     this.momentum = momentum;
     this.epsilon = epsilon;
@@ -171,10 +267,8 @@ protected void createSlots(List<Output<? extends TType>> variables) {
     }
   }
 
-
   /**
-   * Creates the RMSProp Slots for Root Mean Squared (RMS),
-   * MOMENTUM, and Mean Gradient (MG)
+   * Creates the RMSProp Slots for Root Mean Squared (RMS), MOMENTUM, and Mean Gradient (MG)
    *
    * @param v the variable to install in the slot
    * @param <T> the datatype of the variable.
@@ -205,7 +299,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
           mgSlot,
           rmsSlot,
           momentumSlot,
-          tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()),
+          tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()),
           tf.dtypes.cast(tf.constant(decay), gradient.dataType()),
           tf.dtypes.cast(tf.constant(momentum), gradient.dataType()),
           tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()),
@@ -215,7 +309,7 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
         variable,
         rmsSlot,
         momentumSlot,
-        tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()),
+        tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()),
         tf.dtypes.cast(tf.constant(decay), gradient.dataType()),
         tf.dtypes.cast(tf.constant(momentum), gradient.dataType()),
         tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()),
@@ -242,6 +336,6 @@ public String toString() {
   /** {@inheritDoc} */
   @Override
   public String getOptimizerName() {
-    return "RMSProp";
+    return DEFAULT_NAME;
   }
 }
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java
index 5c4ce542c65..7653c99bc98 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java
@@ -15,7 +15,6 @@
 package org.tensorflow.framework.optimizers;
 
 import org.junit.jupiter.api.*;
-import org.tensorflow.Graph;
 import org.tensorflow.framework.optimizers.Optimizer.GradAndVar;
 import org.tensorflow.framework.utils.TestSession;
 import org.tensorflow.ndarray.Shape;
@@ -55,11 +54,10 @@ public void tearDown() {}
 
   @Test
   public void testConstructAdadeltaWithLR() {
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
-      AdaDelta opt = new AdaDelta(graph, 1.0F, 0.9F, 1.F);
-      AdaDelta opt2 = new AdaDelta(graph, 0.1F, 0.9F, 1.F);
-      AdaDelta opt3 = new AdaDelta(graph, 0.1F, 0.9F, 1e-8F);
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        AdaDelta opt = new AdaDelta(session.getGraph(), "opt1", 1.0F, 0.9F, 1.F);
+        AdaDelta opt2 = new AdaDelta(session.getGraph(), "opt2", 0.1F, 0.9F, 1.F);
+        AdaDelta opt3 = new AdaDelta(session.getGraph(), "opt3", 0.1F, 0.9F, 1e-8F)) {
       String format = "AdaDelta{learningRate=%s, rho=%s, epsilon=%s}";
       String optExpected = String.format(format, 1.0F, 0.9F, 1.F);
       String opt2Expected = String.format(format, 0.1F, 0.9F, 1.F);
@@ -80,11 +78,15 @@ public void testBasic() {
     int numUpdates = 4; // # number of ADADELTA steps to perform
     float[] grads = {0.2F, 0.1F, 0.01F};
     float[] lrs = {1.0F, 0.5F, 0.1F};
+
+    float rho = 0.95F;
+    float epsilon = 1e-8F;
+
     for (float grad : grads) {
       for (float lr : lrs) {
-        try (TestSession session = TestSession.createTestSession(tfMode)) {
+        try (TestSession session = TestSession.createTestSession(tfMode);
+            AdaDelta instance = new AdaDelta(session.getGraph(), lr, rho, epsilon)) {
           Ops tf = session.getTF();
-          Graph graph = session.getGraph();
           float[] var0Init = {1.0F, 2.0F};
           float[] var1Init = {3.0F, 4.0F};
           float[] fgrads = {grad, grad};
@@ -96,37 +98,33 @@ public void testBasic() {
           Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
 
           Constant<TFloat32> cgrads = tf.constant(fgrads);
-
           float accum = 0.0F;
           float accumUpdate = 0.0F;
-          float rho = 0.95F;
-          float epsilon = 1e-8F;
 
           /* build the GradsAnvVars */
           List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
           gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var0.asOutput()));
           gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var1.asOutput()));
 
-          /* get the Optimizer */
-          AdaDelta adaDelta = new AdaDelta(graph, lr, rho, epsilon);
-
           /*apply gradients */
-          Op adadeltaUpdate = adaDelta.applyGradients(gradsAndVars, "AdaDeltaTest");
+          Op adadeltaUpdate = instance.applyGradients(gradsAndVars, "AdaDeltaTest");
 
           /* Create and validate the shapes of the slota */
+          @SuppressWarnings("unchecked")
           Variable<TFloat32>[] slots = new Variable[2];
+          @SuppressWarnings("unchecked")
           Variable<TFloat32>[] slotUpdates = new Variable[2];
 
-          slots[0] = adaDelta.getSlot(var0.asOutput(), ACCUMULATOR).get();
+          slots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get();
           assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape());
 
-          slotUpdates[0] = adaDelta.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get();
+          slotUpdates[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get();
           assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape());
 
-          slots[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR).get();
+          slots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get();
           assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape());
 
-          slotUpdates[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get();
+          slotUpdates[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get();
           assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape());
 
           /* initialize the local variables */
@@ -143,7 +141,8 @@ public void testBasic() {
           float[] updates = new float[numUpdates];
           float totUpdate = 0;
           for (int step = 0; step < numUpdates; step++) {
-            session.run(adadeltaUpdate);
+
+            session.run(adadeltaUpdate, instance.getFeedMap());
             accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho);
             updates[step] =
                 ((float) Math.sqrt(accumUpdate + epsilon)
@@ -167,4 +166,205 @@ public void testBasic() {
       }
     }
   }
+
+  @Test
+  public void testBasicWithLROperand() {
+    int numUpdates = 4; // # number of ADADELTA steps to perform
+    float[] grads = {0.2F, 0.1F, 0.01F};
+    float[] lrs = {1.0F, 0.5F, 0.1F};
+
+    float rho = 0.95F;
+    float epsilon = 1e-8F;
+
+    for (float grad : grads) {
+      for (float lr : lrs) {
+        try (TestSession session = TestSession.createTestSession(tfMode)) {
+          Ops tf = session.getTF();
+          // this just uses a trivial operand
+          try (AdaDelta instance =
+              new AdaDelta(
+                  session.getGraph(),
+                  tf.math.mul(tf.constant(lr), tf.constant(1.f)),
+                  rho,
+                  epsilon)) {
+
+            float[] var0Init = {1.0F, 2.0F};
+            float[] var1Init = {3.0F, 4.0F};
+            float[] fgrads = {grad, grad};
+            Shape shape = Shape.of(var0Init.length);
+            Variable<TFloat32> var0 = tf.withName("var0").variable(shape, TFloat32.DTYPE);
+            Variable<TFloat32> var1 = tf.withName("var1").variable(shape, TFloat32.DTYPE);
+
+            Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+            Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+            Constant<TFloat32> cgrads = tf.constant(fgrads);
+            float accum = 0.0F;
+            float accumUpdate = 0.0F;
+
+            /* build the GradsAndVars */
+            List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+            gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var0.asOutput()));
+            gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var1.asOutput()));
+
+            /*apply gradients */
+            Op adadeltaUpdate = instance.applyGradients(gradsAndVars, "AdaDeltaTest");
+
+            /* Create and validate the shapes of the slots */
+            @SuppressWarnings("unchecked")
+            Variable<TFloat32>[] slots = new Variable[2];
+            @SuppressWarnings("unchecked")
+            Variable<TFloat32>[] slotUpdates = new Variable[2];
+
+            slots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get();
+            assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape());
+
+            slotUpdates[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get();
+            assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape());
+
+            slots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get();
+            assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape());
+
+            slotUpdates[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get();
+            assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape());
+
+            /* initialize the local variables */
+            session.run(var0Initializer);
+            session.run(var1Initializer);
+
+            /* initialize the accumulators */
+            session.run(tf.init());
+
+            /* make sure the variables were initialized properly */
+            session.evaluate(var0Init, var0);
+            session.evaluate(var1Init, var1);
+
+            float[] updates = new float[numUpdates];
+            float totUpdate = 0;
+            for (int step = 0; step < numUpdates; step++) {
+
+              session.run(adadeltaUpdate, instance.getFeedMap());
+              accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho);
+              updates[step] =
+                  ((float) Math.sqrt(accumUpdate + epsilon)
+                      * (float) (1 / Math.sqrt(accum + epsilon))
+                      * grad);
+              accumUpdate =
+                  (accumUpdate * rho + ((float) Math.pow(updates[step], 2) * (1.0F - rho)));
+              totUpdate += updates[step] * lr;
+
+              for (int i = 0; i < 2; i++) {
+                session.evaluate(accum, slots[i]);
+                session.evaluate(accumUpdate, slotUpdates[i]);
+              }
+
+              Float[] var0InitUpdate = {var0Init[0] - totUpdate, var0Init[1] - totUpdate};
+              Float[] var1InitUpdate = {var1Init[0] - totUpdate, var1Init[1] - totUpdate};
+
+              session.evaluate(var0InitUpdate, var0);
+              session.evaluate(var1InitUpdate, var1);
+            }
+          }
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testWithLearningRateDecay() {
+    int numSteps = 4; // # number of ADADELTA steps to perform
+    float[] grads = {0.2F, 0.1F, 0.01F};
+
+    for (float grad : grads) {
+      float learningRate = 1.0F;
+      float rho = 0.95F;
+      float epsilon = 1e-8F;
+      try (TestSession session = TestSession.createTestSession(tfMode);
+          AdaDelta instance = new AdaDelta(session.getGraph(), learningRate, rho, epsilon)) {
+        Ops tf = session.getTF();
+
+        float[] var0Init = {1.0F, 2.0F};
+        float[] var1Init = {3.0F, 4.0F};
+        float[] fgrads = {grad, grad};
+        Shape shape = Shape.of(var0Init.length);
+        Variable<TFloat32> var0 = tf.withName("var0").variable(shape, TFloat32.DTYPE);
+        Variable<TFloat32> var1 = tf.withName("var1").variable(shape, TFloat32.DTYPE);
+
+        Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+        Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+        Constant<TFloat32> cgrads = tf.constant(fgrads);
+
+        float accum = 0.0F;
+        float accumUpdate = 0.0F;
+
+        /* build the GradsAndVars */
+        List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+        gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var0.asOutput()));
+        gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var1.asOutput()));
+
+        Op adadeltaUpdate = instance.applyGradients(gradsAndVars, "AdaDeltaTest");
+
+        /* Create and validate the shapes of the slots */
+        @SuppressWarnings("unchecked")
+        Variable<TFloat32>[] slots = new Variable[2];
+        @SuppressWarnings("unchecked")
+        Variable<TFloat32>[] slotUpdates = new Variable[2];
+
+        slots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get();
+        assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape());
+
+        slotUpdates[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get();
+        assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape());
+
+        slots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get();
+        assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape());
+
+        slotUpdates[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get();
+        assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape());
+
+        /* initialize the local variables */
+        session.run(var0Initializer);
+        session.run(var1Initializer);
+
+        /* initialize the accumulators */
+        session.run(tf.init());
+
+        /* make sure the variables were initialized properly */
+        session.evaluate(var0Init, var0);
+        session.evaluate(var1Init, var1);
+
+        float[] updates = new float[numSteps];
+        float totUpdate = 0;
+        for (int step = 0; step < numSteps; step++) {
+          assertEquals(learningRate, instance.getLearningRate(), epsilon);
+          session.evaluate(
+              learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap());
+          session.run(adadeltaUpdate, instance.getFeedMap());
+          accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho);
+          updates[step] =
+              ((float) Math.sqrt(accumUpdate + epsilon)
+                  * (float) (1 / Math.sqrt(accum + epsilon))
+                  * grad);
+          accumUpdate = (accumUpdate * rho + ((float) Math.pow(updates[step], 2) * (1.0F - rho)));
+          totUpdate += updates[step] * learningRate;
+
+          for (int i = 0; i < 2; i++) {
+            session.evaluate(accum, slots[i]);
+            session.evaluate(accumUpdate, slotUpdates[i]);
+          }
+
+          Float[] var0InitUpdate = {var0Init[0] - totUpdate, var0Init[1] - totUpdate};
+          Float[] var1InitUpdate = {var1Init[0] - totUpdate, var1Init[1] - totUpdate};
+
+          session.evaluate(var0InitUpdate, var0);
+          session.evaluate(var1InitUpdate, var1);
+
+          // Adjust learning rate
+          learningRate *= 0.9F;
+          instance.setLearningRate(learningRate);
+        }
+      }
+    }
+  }
 }
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java
index ef9053ff1eb..9e67b4660df 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java
@@ -15,7 +15,6 @@
 package org.tensorflow.framework.optimizers;
 
 import org.junit.jupiter.api.*;
-import org.tensorflow.Graph;
 import org.tensorflow.framework.utils.TestSession;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Op;
@@ -29,6 +28,8 @@
 import java.util.ArrayList;
 import java.util.List;
 
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
 /** Test cases for AdaGradDA Optimizer */
 public class AdaGradDATest {
 
@@ -54,13 +55,11 @@ public void testBasic() {
     float[] var1Init = {0.0F, 0.0F};
     float[] grads0Init = {0.1F, 0.2F};
     float[] grads1Init = {0.01F, 0.02F};
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
+    float learningRate = 3.0F;
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        AdaGradDA instance = new AdaGradDA(session.getGraph(), learningRate)) {
 
-      float learningRate = 3.0F;
-
-      AdaGradDA instance = new AdaGradDA(graph, learningRate);
-      Ops tf = instance.getTF();
+      Ops tf = session.getTF();
 
       Shape shape0 = Shape.of(var0Init.length);
       Shape shape1 = Shape.of(var1Init.length);
@@ -78,7 +77,6 @@ public void testBasic() {
       session.run(var0Initializer);
       session.run(var1Initializer);
 
-
       /* build the GradsAnvVars */
       List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
@@ -91,11 +89,122 @@ public void testBasic() {
 
       session.evaluate(var0Init, var0);
       session.evaluate(var1Init, var1);
-      session.run(adaUpdate);
+      session.run(adaUpdate, instance.getFeedMap());
       float[] expected0 = {-0.904534F, -1.603567F};
       session.evaluate(expected0, var0);
       float[] expected1 = {-0.094821f, -0.189358f};
       session.evaluate(expected1, var1);
     }
   }
+
+  @Test
+  public void testBasicWithLROperand() {
+    float[] var0Init = {0.0F, 0.0F};
+    float[] var1Init = {0.0F, 0.0F};
+    float[] grads0Init = {0.1F, 0.2F};
+    float[] grads1Init = {0.01F, 0.02F};
+    float learningRate = 1.5F;
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      try (AdaGradDA instance =
+          new AdaGradDA(
+              session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(2.f)))) {
+
+        Shape shape0 = Shape.of(var0Init.length);
+        Shape shape1 = Shape.of(var1Init.length);
+        Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+        Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+        Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+        Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+        Constant<TFloat32> grads0 = tf.constant(grads0Init);
+        Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+        /* initialize the local variables */
+
+        session.run(var0Initializer);
+        session.run(var1Initializer);
+
+        /* build the GradsAndVars */
+        List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+        Op adaUpdate = instance.applyGradients(gradsAndVars, "AdGradDATest");
+
+        /* initialize the accumulators */
+        session.run(tf.init());
+
+        session.evaluate(var0Init, var0);
+        session.evaluate(var1Init, var1);
+        session.run(adaUpdate, instance.getFeedMap());
+        float[] expected0 = {-0.904534F, -1.603567F};
+        session.evaluate(expected0, var0);
+        float[] expected1 = {-0.094821f, -0.189358f};
+        session.evaluate(expected1, var1);
+      }
+    }
+  }
+
+  @Test
+  public void testWithLearningRateDecay() {
+    float[] var0Init = {0.0F, 0.0F};
+    float[] var1Init = {0.0F, 0.0F};
+    float[] grads0Init = {0.1F, 0.2F};
+    float[] grads1Init = {0.01F, 0.02F};
+    float epsilon = 1e-8F;
+    int numSteps = 4;
+    float learningRate = 3.0F;
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        AdaGrad instance = new AdaGrad(session.getGraph(), learningRate)) {
+      Ops tf = session.getTF();
+
+      Shape shape0 = Shape.of(var0Init.length);
+      Shape shape1 = Shape.of(var1Init.length);
+      Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+      Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+      Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+      Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+      Constant<TFloat32> grads0 = tf.constant(grads0Init);
+      Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+      /* initialize the local variables */
+      session.run(var0Initializer);
+      session.run(var1Initializer);
+
+      /* build the GradsAndVars */
+      List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+      Op update = instance.applyGradients(gradsAndVars, "AdGradDATest");
+
+      /* initialize the accumulators */
+      session.run(tf.init());
+
+      session.evaluate(var0Init, var0);
+      session.evaluate(var1Init, var1);
+
+      float[][][] expected = {
+        {{-2.121320f, -2.683281f}, {-0.298511f, -0.588348f}},
+        {{-3.680166f, -4.483282f}, {-0.565851f, -1.107964f}},
+        {{-4.895166f, -5.831203f}, {-0.805286f, -1.567190f}},
+        {{-5.873222f, -6.892054f}, {-1.019739f, -1.973306f}}
+      };
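+      // each step: verify the learning rate, run the update with the feed map, then decay by 0.9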
+      for (int i = 0; i < numSteps; i++) {
+        assertEquals(learningRate, instance.getLearningRate(), epsilon);
+        session.evaluate(
+            learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap());
+        session.run(update, instance.getFeedMap());
+        session.evaluate(expected[i][0], var0);
+        session.evaluate(expected[i][1], var1);
+        learningRate *= 0.9;
+        instance.setLearningRate(learningRate);
+      }
+    }
+  }
 }
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java
index c5ae178b84c..dc05b9b8f81 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java
@@ -15,11 +15,7 @@
 package org.tensorflow.framework.optimizers;
 
 import org.junit.jupiter.api.*;
-import org.tensorflow.Graph;
-import org.tensorflow.framework.utils.ND;
 import org.tensorflow.framework.utils.TestSession;
-import org.tensorflow.ndarray.FloatNdArray;
-import org.tensorflow.ndarray.NdArrays;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Op;
 import org.tensorflow.op.Ops;
@@ -60,22 +56,13 @@ public void testBasic() {
     float[] var1Init = {3.0F, 4.0F};
     float[] grads0Init = {0.1F, 0.1F};
     float[] grads1Init = {0.01F, 0.01F};
-    float[] accum0 = {0.1f, 0.1f};
-    float[] accum1 = {0.1f, 0.1f};
 
-    FloatNdArray var0Np = NdArrays.vectorOf(var0Init);
-    FloatNdArray var1Np = NdArrays.vectorOf(var1Init);
-    FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init);
-    FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init);
-    FloatNdArray accum0Np = NdArrays.vectorOf(accum0);
-    FloatNdArray accum1Np = NdArrays.vectorOf(accum1);
+    float learningRate = 3.0F;
 
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        AdaGrad instance = new AdaGrad(session.getGraph(), learningRate, 0.1f)) {
 
-      float learningRate = 3.0F;
-      AdaGrad instance = new AdaGrad(graph, learningRate, 0.1f);
-      Ops tf = instance.getTF();
+      Ops tf = session.getTF();
 
       Shape shape0 = Shape.of(var0Init.length);
       Shape shape1 = Shape.of(var1Init.length);
@@ -88,8 +75,6 @@ public void testBasic() {
       Constant<TFloat32> grads0 = tf.constant(grads0Init);
       Constant<TFloat32> grads1 = tf.constant(grads1Init);
 
-
-
       /* build the GradsAnvVars */
       List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
@@ -97,6 +82,7 @@ public void testBasic() {
 
       Op adaUpdate = instance.applyGradients(gradsAndVars, "AdGradTest");
 
+      @SuppressWarnings("unchecked")
       Variable<TFloat32>[] accumulatorSlots = new Variable[2];
       accumulatorSlots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get();
       assertEquals(accumulatorSlots[0].asOutput().shape(), var0.asOutput().shape());
@@ -116,7 +102,7 @@ public void testBasic() {
       session.evaluate(var1Init, var1);
 
       for (int step = 0; step < numSteps; step++) {
-        session.run(adaUpdate);
+        session.run(adaUpdate, instance.getFeedMap());
       }
       float[] expected0 = {-1.6026098728179932f, -0.6026098728179932f};
       session.evaluate(expected0, var0);
@@ -125,18 +111,139 @@ public void testBasic() {
     }
   }
 
-  private FloatNdArray caclulateAccum(FloatNdArray accum, FloatNdArray grads) {
-    // accum + gT * gT
-    FloatNdArray squareG = ND.square(grads);
-    return ND.add(accum, squareG);
+  @Test
+  public void testBasicWithLROperand() {
+    int numSteps = 3;
+    float[] var0Init = {1.0F, 2.0F};
+    float[] var1Init = {3.0F, 4.0F};
+    float[] grads0Init = {0.1F, 0.1F};
+    float[] grads1Init = {0.01F, 0.01F};
+
+    float learningRate = 1.0F;
+
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      try (AdaGrad instance =
+          new AdaGrad(
+              session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(3.f)), 0.1f)) {
+
+        Shape shape0 = Shape.of(var0Init.length);
+        Shape shape1 = Shape.of(var1Init.length);
+        Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+        Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+        Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+        Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+        Constant<TFloat32> grads0 = tf.constant(grads0Init);
+        Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+        /* build the GradsAndVars */
+        List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+        Op adaUpdate = instance.applyGradients(gradsAndVars, "AdGradTest");
+
+        @SuppressWarnings("unchecked")
+        Variable<TFloat32>[] accumulatorSlots = new Variable[2];
+        accumulatorSlots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get();
+        assertEquals(accumulatorSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+        accumulatorSlots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get();
+        assertEquals(accumulatorSlots[1].asOutput().shape(), var1.asOutput().shape());
+
+        /* initialize the local variables */
+        session.run(var0Initializer);
+        session.run(var1Initializer);
+
+        /* initialize the accumulators */
+        session.run(tf.init());
+
+        /* make sure the variables were initialized properly */
+        session.evaluate(var0Init, var0);
+        session.evaluate(var1Init, var1);
+
+        for (int step = 0; step < numSteps; step++) {
+          session.run(adaUpdate, instance.getFeedMap());
+        }
+        float[] expected0 = {-1.6026098728179932f, -0.6026098728179932f};
+        session.evaluate(expected0, var0);
+        float[] expected1 = {2.715679168701172f, 3.715679168701172f};
+        session.evaluate(expected1, var1);
+      }
+    }
   }
 
-  private FloatNdArray calculate(
-      FloatNdArray param, FloatNdArray accum, FloatNdArray grads, float learningRate) {
-    // param - lr * gT / (np.sqrt(accumT) + epsilon)
-    FloatNdArray divisor = ND.add(ND.sqrt(accum), 1e-07f);
-    FloatNdArray dividend = ND.mul(learningRate, grads);
-    FloatNdArray quotient = ND.div(dividend, divisor);
-    return ND.sub(param, quotient);
+  @Test
+  public void testWithLearningRateDecay() {
+    int numSteps = 3;
+    float[] var0Init = {1.0F, 2.0F};
+    float[] var1Init = {3.0F, 4.0F};
+    float[] grads0Init = {0.1F, 0.1F};
+    float[] grads1Init = {0.01F, 0.01F};
+    float epsilon = 1e-8F;
+    float learningRate = 3.0F;
+
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        AdaGrad instance = new AdaGrad(session.getGraph(), learningRate)) {
+      Ops tf = session.getTF();
+
+      Shape shape0 = Shape.of(var0Init.length);
+      Shape shape1 = Shape.of(var1Init.length);
+      Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+      Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+      Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+      Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+      Constant<TFloat32> grads0 = tf.constant(grads0Init);
+      Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+      /* build the GradsAndVars */
+      List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+      Op adaUpdate = instance.applyGradients(gradsAndVars, "AdGradTest");
+
+      @SuppressWarnings("unchecked")
+      Variable<TFloat32>[] accumulatorSlots = new Variable[2];
+      accumulatorSlots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get();
+      assertEquals(accumulatorSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+      accumulatorSlots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get();
+      assertEquals(accumulatorSlots[1].asOutput().shape(), var1.asOutput().shape());
+
+      /* initialize the local variables */
+      session.run(var0Initializer);
+      session.run(var1Initializer);
+
+      // initialize the accumulators
+      session.run(tf.init());
+
+      // make sure the variables were initialized properly
+      session.evaluate(var0Init, var0);
+      session.evaluate(var1Init, var1);
+
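+      // expected var0/var1 values for each step as the learning rate decays by 0.9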
+      float[][][] expected = {
+        {{-1.121320f, -0.121320f}, {2.701489f, 3.701489f}},
+        {{-2.680166f, -1.680166f}, {2.434149f, 3.434149f}},
+        {{-3.895166f, -2.895166f}, {2.194714f, 3.194714f}}
+      };
+      for (int step = 0; step < numSteps; step++) {
+        assertEquals(learningRate, instance.getLearningRate(), epsilon);
+        session.evaluate(
+            learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap());
+        session.run(adaUpdate, instance.getFeedMap());
+
+        session.evaluate(expected[step][0], var0);
+        session.evaluate(expected[step][1], var1);
+
+        learningRate *= 0.9;
+        instance.setLearningRate(learningRate);
+      }
+    }
   }
 }
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java
index 461fa75397f..c5bb153d804 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java
@@ -15,7 +15,6 @@
 package org.tensorflow.framework.optimizers;
 
 import org.junit.jupiter.api.*;
-import org.tensorflow.Graph;
 import org.tensorflow.Tensor;
 import org.tensorflow.framework.utils.ND;
 import org.tensorflow.framework.utils.TestSession;
@@ -66,18 +65,18 @@ public void testBasic() {
     FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init);
     FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init);
 
-    float epsilon1 = 1e-3F;
+    float epsilon1 = 1e-3f;
+    float learningRate = 0.001f;
+
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Adam instance = new Adam(session.getGraph(), learningRate)) {
 
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      float learningRate = 0.001F;
       float beta1 = 0.9F;
       float beta2 = 0.999F;
-      Graph graph = session.getGraph();
 
       session.setEpsilon(epsilon1);
 
-      Adam instance = new Adam(graph, learningRate);
-      Ops tf = instance.getTF();
+      Ops tf = session.getTF();
       Shape shape0 = Shape.of(var0Init.length);
       Shape shape1 = Shape.of(var1Init.length);
       Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
@@ -93,15 +92,11 @@ public void testBasic() {
       session.run(var0Initializer);
       session.run(var1Initializer);
 
-
-
       /* build the GradsAnvVars */
       List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
 
-
-
       Op update = instance.applyGradients(gradsAndVars, "AdamTest");
 
       /* Create and validate the shapes of the slots */
@@ -123,7 +118,7 @@ public void testBasic() {
       assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape());
 
       /* initialize the accumulators */
-      session.run(tf.init());
+      session.run(tf.init(), instance.getFeedMap());
 
       session.evaluate(var0Init, var0);
       session.evaluate(var1Init, var1);
@@ -160,7 +155,7 @@ public void testBasic() {
                 .expect(TFloat32.DTYPE)) {
           result.data().scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1));
         }
-        session.run(update);
+        session.run(update, instance.getFeedMap());
 
         float lrT =
             learningRate
@@ -190,13 +185,282 @@ public void testBasic() {
     }
   }
 
+  @Test
+  public void testBasicWithLROperand() {
+    float[] var0Init = {1.0F, 2.0F};
+    float[] var1Init = {3.0F, 4.0F};
+    float[] grads0Init = {0.1F, 0.1F};
+    float[] grads1Init = {0.01F, 0.01F};
+    FloatNdArray var0Np = NdArrays.vectorOf(var0Init);
+    FloatNdArray var1Np = NdArrays.vectorOf(var1Init);
+    FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init);
+    FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init);
+
+    float epsilon1 = 1e-3f;
+    float learningRate = 0.001f;
+
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      try (Adam instance =
+          new Adam(session.getGraph(), tf.constant(learningRate))) {
+
+        float beta1 = 0.9F;
+        float beta2 = 0.999F;
+
+        session.setEpsilon(epsilon1);
+
+        Shape shape0 = Shape.of(var0Init.length);
+        Shape shape1 = Shape.of(var1Init.length);
+        Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+        Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+        Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+        Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+        Constant<TFloat32> grads0 = tf.constant(grads0Init);
+        Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+        /* initialize the local variables */
+        session.run(var0Initializer);
+        session.run(var1Initializer);
+
+        /* build the GradsAndVars */
+        List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+        Op update = instance.applyGradients(gradsAndVars, "AdamTest");
+
+        /* Create and validate the shapes of the slots */
+        @SuppressWarnings("unchecked")
+        Variable<TFloat32>[] firstMomentSlots = new Variable[2];
+        @SuppressWarnings("unchecked")
+        Variable<TFloat32>[] secondMomentSlots = new Variable[2];
+
+        firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get();
+        assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+        secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get();
+        assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+        firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get();
+        assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape());
+
+        secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get();
+        assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape());
+
+        /* initialize the accumulators */
+        session.run(tf.init(), instance.getFeedMap());
+
+        session.evaluate(var0Init, var0);
+        session.evaluate(var1Init, var1);
+
+        FloatNdArray m0Np = NdArrays.ofFloats(shape0);
+        FloatNdArray v0Np = NdArrays.ofFloats(shape0);
+        FloatNdArray m1Np = NdArrays.ofFloats(shape1);
+        FloatNdArray v1Np = NdArrays.ofFloats(shape1);
+
+        for (int step = 0; step < 3; step++) {
+
+          // Test powers
+          final float[] powers = {
+            (float) Math.pow(beta1, step + 1), (float) Math.pow(beta2, step + 1)
+          };
+
+          try (Tensor<TFloat32> result =
+              session
+                  .getGraphSession()
+                  .runner()
+                  .fetch("beta1_power")
+                  .run()
+                  .get(0)
+                  .expect(TFloat32.DTYPE)) {
+            result.data().scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1));
+          }
+          try (Tensor<TFloat32> result =
+              session
+                  .getGraphSession()
+                  .runner()
+                  .fetch("beta2_power")
+                  .run()
+                  .get(0)
+                  .expect(TFloat32.DTYPE)) {
+            result.data().scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1));
+          }
+          session.run(update, instance.getFeedMap());
+
+          float lrT =
+              learningRate
+                  * (float) Math.sqrt(1 - (float) Math.pow(beta2, (step + 1)))
+                  / (1 - (float) Math.pow(beta1, (step + 1)));
+
+          m0Np = calculateM(m0Np, grads0Np, beta1);
+          v0Np = calculateV(v0Np, grads0Np, beta2);
+          var0Np = calculateParam(var0Np, lrT, m0Np, v0Np, 1e-7F);
+
+          m1Np = calculateM(m1Np, grads1Np, beta1);
+          v1Np = calculateV(v1Np, grads1Np, beta2);
+          var1Np = calculateParam(var1Np, lrT, m1Np, v1Np, 1e-7F);
+
+          // evaluate var 0 and var1
+          session.evaluate(var0Np, var0);
+          session.evaluate(var1Np, var1);
+
+          // first moment
+          session.evaluate(m0Np, firstMomentSlots[0]);
+          session.evaluate(m1Np, firstMomentSlots[1]);
+
+          // second moment
+          session.evaluate(v0Np, secondMomentSlots[0]);
+          session.evaluate(v1Np, secondMomentSlots[1]);
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testWithLearningRateDecay() {
+
+    float[] var0Init = {1.0F, 2.0F};
+    float[] var1Init = {3.0F, 4.0F};
+    float[] grads0Init = {0.1F, 0.1F};
+    float[] grads1Init = {0.01F, 0.01F};
+    FloatNdArray var0Np = NdArrays.vectorOf(var0Init);
+    FloatNdArray var1Np = NdArrays.vectorOf(var1Init);
+    FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init);
+    FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init);
+
+    float epsilon1 = 1e-3F;
+    float learningRate = 0.001F;
+    float beta1 = 0.9F;
+    float beta2 = 0.999F;
+
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Adam instance = new Adam(session.getGraph(), learningRate)) {
+      Ops tf = session.getTF();
+
+      session.setEpsilon(epsilon1);
+
+      Shape shape0 = Shape.of(var0Init.length);
+      Shape shape1 = Shape.of(var1Init.length);
+      Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+      Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+      Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+      Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+      Constant<TFloat32> grads0 = tf.constant(grads0Init);
+      Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+      /* initialize the local variables */
+      session.run(var0Initializer);
+      session.run(var1Initializer);
+
+      /* build the GradsAndVars */
+      List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+      Op update = instance.applyGradients(gradsAndVars, "AdamTest");
+
+      /* Create and validate the shapes of the slots */
+      @SuppressWarnings("unchecked")
+      Variable<TFloat32>[] firstMomentSlots = new Variable[2];
+      @SuppressWarnings("unchecked")
+      Variable<TFloat32>[] secondMomentSlots = new Variable[2];
+
+      firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get();
+      assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+      secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get();
+      assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+      firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get();
+      assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape());
+
+      secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get();
+      assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape());
+
+      // initialize the accumulators
+      session.run(tf.init());
+
+      session.evaluate(var0Init, var0);
+      session.evaluate(var1Init, var1);
+
+      FloatNdArray m0Np = NdArrays.ofFloats(shape0);
+      FloatNdArray v0Np = NdArrays.ofFloats(shape0);
+      FloatNdArray m1Np = NdArrays.ofFloats(shape1);
+      FloatNdArray v1Np = NdArrays.ofFloats(shape1);
+
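+      // check the beta powers and learning rate, run the update, then decay the rate each step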
+      for (int step = 0; step < 3; step++) {
+
+        // Test powers
+        final float[] powers = {
+          (float) Math.pow(beta1, step + 1), (float) Math.pow(beta2, step + 1)
+        };
+
+        try (Tensor<TFloat32> result =
+            session
+                .getGraphSession()
+                .runner()
+                .fetch("beta1_power")
+                .run()
+                .get(0)
+                .expect(TFloat32.DTYPE)) {
+          result.data().scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1));
+        }
+        try (Tensor<TFloat32> result =
+            session
+                .getGraphSession()
+                .runner()
+                .fetch("beta2_power")
+                .run()
+                .get(0)
+                .expect(TFloat32.DTYPE)) {
+          result.data().scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1));
+        }
+        assertEquals(learningRate, instance.getLearningRate(), 1e-6f);
+        session.evaluate(
+            learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap());
+        session.run(update, instance.getFeedMap());
+
+        float lr_t =
+            learningRate
+                * (float) Math.sqrt(1 - (float) Math.pow(beta2, (step + 1)))
+                / (1 - (float) Math.pow(beta1, (step + 1)));
+
+        m0Np = calculateM(m0Np, grads0Np, beta1);
+        v0Np = calculateV(v0Np, grads0Np, beta2);
+        var0Np = calculateParam(var0Np, lr_t, m0Np, v0Np, 1e-7F);
+
+        m1Np = calculateM(m1Np, grads1Np, beta1);
+        v1Np = calculateV(v1Np, grads1Np, beta2);
+        var1Np = calculateParam(var1Np, lr_t, m1Np, v1Np, 1e-7F);
+
+        // evaluate var 0 and var1
+        session.evaluate(var0Np, var0);
+        session.evaluate(var1Np, var1);
+
+        // first moment
+        session.evaluate(m0Np, firstMomentSlots[0]);
+        session.evaluate(m1Np, firstMomentSlots[1]);
+
+        // second moment
+        session.evaluate(v0Np, secondMomentSlots[0]);
+        session.evaluate(v1Np, secondMomentSlots[1]);
+
+        learningRate *= 0.9;
+        instance.setLearningRate(learningRate);
+      }
+    }
+  }
+
   private FloatNdArray calculateM(FloatNdArray m, FloatNdArray gT, float beta) {
-    // mT = beta1 * m + (1 - beta1) * gT
     return ND.add(ND.mul(m, beta), ND.mul(gT, (1 - beta)));
   }
 
   private FloatNdArray calculateV(FloatNdArray v, FloatNdArray gT, float beta) {
-    // beta2 * v + (1 - beta2) * gT * gT
     FloatNdArray mul1 = ND.mul(v, beta);
     FloatNdArray squareG = ND.square(gT);
     FloatNdArray mul2 = ND.mul((1 - beta), squareG);
@@ -204,11 +468,10 @@ private FloatNdArray calculateV(FloatNdArray v, FloatNdArray gT, float beta) {
   }
 
   private FloatNdArray calculateParam(
-      FloatNdArray param, float lrT, FloatNdArray m, FloatNdArray v, float epsilon) {
-    //  param - lrT * mT / (np.sqrt(vT) + epsilon)
+      FloatNdArray param, float lr_t, FloatNdArray m, FloatNdArray v, float epsilon) {
     FloatNdArray sqrt = ND.sqrt(v);
     FloatNdArray divisor = ND.add(sqrt, epsilon);
-    FloatNdArray dividend = ND.mul(lrT, m);
+    FloatNdArray dividend = ND.mul(lr_t, m);
     FloatNdArray quotient = ND.div(dividend, divisor);
     return ND.sub(param, quotient);
   }
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java
index de17395f76a..e900018ccad 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java
@@ -15,7 +15,6 @@
 package org.tensorflow.framework.optimizers;
 
 import org.junit.jupiter.api.*;
-import org.tensorflow.Graph;
 import org.tensorflow.Tensor;
 import org.tensorflow.framework.utils.ND;
 import org.tensorflow.framework.utils.TestSession;
@@ -61,10 +60,10 @@ public void tearDown() {}
   /** Test of getOptimizerName method, of class Adamax. */
   @Test
   public void testGetOptimizerName() {
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
-      Adamax instance = new Adamax(graph);
-      String expResult = "Adamax";
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Adamax instance = new Adamax(session.getGraph())) {
+
+      String expResult = DEFAULT_NAME;
       String result = instance.getOptimizerName();
       assertEquals(expResult, result);
     }
@@ -93,11 +92,9 @@ public void testBasic() {
 
     float epsilon1 = 1e-3F;
 
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
-
-      Adamax instance = new Adamax(graph);
-      Ops tf = instance.getTF();
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Adamax instance = new Adamax(session.getGraph())) {
+      Ops tf = session.getTF();
 
       Shape shape0 = Shape.of(var0Init.length);
       Shape shape1 = Shape.of(var1Init.length);
@@ -114,7 +111,6 @@ public void testBasic() {
       session.run(var0Initializer);
       session.run(var1Initializer);
 
-
       /* build the GradsAnvVars */
       List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
       gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput()));
@@ -123,7 +119,9 @@ public void testBasic() {
       Op update = instance.applyGradients(gradsAndVars, "AdamTest");
 
       /* Create and validae the shapes of the slota */
+      @SuppressWarnings("unchecked")
       Variable<TFloat32>[] firstMomentSlots = new Variable[2];
+      @SuppressWarnings("unchecked")
       Variable<TFloat32>[] secondMomentSlots = new Variable[2];
 
       firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get();
@@ -159,7 +157,7 @@ public void testBasic() {
                 .expect(TFloat32.DTYPE)) {
           result.data().scalars().forEach(f -> assertEquals(beta1Power, f.getFloat(), epsilon1));
         }
-        session.run(update);
+        session.run(update, instance.getFeedMap());
 
         FloatNdArray[] resultNP = calculate(var0Np, grads0Np, step, m0, v0);
         var0Np = resultNP[VAR];
@@ -179,20 +177,253 @@ public void testBasic() {
     }
   }
 
+  /** Test of applyDense method, of class Adamax. */
+  @Test
+  public void testBasicWithLROperand() {
+
+    int numSteps = 3;
+
+    float[] var0Init = {1.0F, 2.0F};
+    float[] var1Init = {3.0F, 4.0F};
+    float[] grads0Init = {0.1F, 0.1F};
+    float[] grads1Init = {0.01F, 0.01F};
+
+    float[] zeros = {0.0F, 0.0F};
+    FloatNdArray m0 = NdArrays.vectorOf(zeros);
+    FloatNdArray v0 = NdArrays.vectorOf(zeros);
+    FloatNdArray m1 = NdArrays.vectorOf(zeros);
+    FloatNdArray v1 = NdArrays.vectorOf(zeros);
+    FloatNdArray var0Np = NdArrays.vectorOf(var0Init);
+    FloatNdArray var1Np = NdArrays.vectorOf(var1Init);
+    FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init);
+    FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init);
+
+    float epsilon1 = 1e-3f;
+    float learningRate = 1f;
+
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      try (Adamax instance =
+          new Adamax(
+              session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(1e-3f)))) {
+
+        Shape shape0 = Shape.of(var0Init.length);
+        Shape shape1 = Shape.of(var1Init.length);
+        Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+        Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+        Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+        Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+        Constant<TFloat32> grads0 = tf.constant(grads0Init);
+        Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+        /* initialize the local variables */
+        session.run(var0Initializer);
+        session.run(var1Initializer);
+
+        /* build the GradsAndVars */
+        List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+        gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+        gradsAndVars.add(new GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+        Op update = instance.applyGradients(gradsAndVars, "AdamTest");
+
+        /* Create and validate the shapes of the slots */
+        @SuppressWarnings("unchecked")
+        Variable<TFloat32>[] firstMomentSlots = new Variable[2];
+        @SuppressWarnings("unchecked")
+        Variable<TFloat32>[] secondMomentSlots = new Variable[2];
+
+        firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get();
+        assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+        secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get();
+        assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+        firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get();
+        assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape());
+
+        secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get();
+        assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape());
+
+        /* initialize the accumulators */
+        session.run(tf.init());
+
+        /* initialize the local variables */
+        session.run(var0Initializer);
+        session.run(var1Initializer);
+        session.setEpsilon(epsilon1);
+        for (int step = 0; step < numSteps; step++) {
+          // Test powers
+          final float beta1Power = (float) Math.pow(BETA_ONE_DEFAULT, step + 1);
+
+          try (Tensor<TFloat32> result =
+              session
+                  .getGraphSession()
+                  .runner()
+                  .fetch("beta1_power")
+                  .run()
+                  .get(0)
+                  .expect(TFloat32.DTYPE)) {
+            result.data().scalars().forEach(f -> assertEquals(beta1Power, f.getFloat(), epsilon1));
+          }
+          session.run(update, instance.getFeedMap());
+
+          FloatNdArray[] resultNP = calculate(var0Np, grads0Np, step, m0, v0);
+          var0Np = resultNP[VAR];
+          m0 = resultNP[M];
+          v0 = resultNP[V];
+
+          resultNP = calculate(var1Np, grads1Np, step, m1, v1);
+          var1Np = resultNP[VAR];
+          m1 = resultNP[M];
+          v1 = resultNP[V];
+
+          // evaluate  var0 and var1
+
+          session.evaluate(var0Np, var0);
+          session.evaluate(var1Np, var1);
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testWithLearningRateDecay() {
+
+    float epsilon = 1e-6f;
+    float epsilon1 = 1e-3F;
+    int numSteps = 3;
+    float learningRate = 0.001F;
+
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Adamax instance = new Adamax(session.getGraph(), learningRate)) {
+      Ops tf = session.getTF();
+      float[] zeros = {0.0F, 0.0F};
+      FloatNdArray m0 = NdArrays.vectorOf(zeros);
+      FloatNdArray v0 = NdArrays.vectorOf(zeros);
+      FloatNdArray m1 = NdArrays.vectorOf(zeros);
+      FloatNdArray v1 = NdArrays.vectorOf(zeros);
+      float[] var0Init = {1.0F, 2.0F};
+      float[] var1Init = {3.0F, 4.0F};
+      float[] grads0Init = {0.1F, 0.1F};
+      float[] grads1Init = {0.01F, 0.01F};
+      FloatNdArray var0Np = NdArrays.vectorOf(var0Init);
+      FloatNdArray var1Np = NdArrays.vectorOf(var1Init);
+      FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init);
+      FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init);
+      Shape shape0 = Shape.of(var0Init.length);
+      Shape shape1 = Shape.of(var1Init.length);
+      Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+      Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+      Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+      Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+      Constant<TFloat32> grads0 = tf.constant(grads0Init);
+      Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+      /* initialize the local variables */
+      session.run(var0Initializer);
+      session.run(var1Initializer);
+
+      /* build the GradsAndVars */
+      List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+      Op update = instance.applyGradients(gradsAndVars, "AdamTest");
+
+      /* Create and validate the shapes of the slots */
+      @SuppressWarnings("unchecked")
+      Variable<TFloat32>[] firstMomentSlots = new Variable[2];
+      @SuppressWarnings("unchecked")
+      Variable<TFloat32>[] secondMomentSlots = new Variable[2];
+
+      firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get();
+      assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+      secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get();
+      assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+      firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get();
+      assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape());
+
+      secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get();
+      assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape());
+
+      /* initialize the accumulators */
+      session.run(tf.init());
+
+      /* initialize the local variables */
+      session.run(var0Initializer);
+      session.run(var1Initializer);
+      session.setEpsilon(epsilon1);
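+      // check beta1_power and the learning rate, run the update, then decay the rate by 0.9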
+      for (int step = 0; step < numSteps; step++) {
+        final float betaPower = (float) Math.pow(BETA_ONE_DEFAULT, step + 1);
+
+        try (Tensor<TFloat32> result =
+            session
+                .getGraphSession()
+                .runner()
+                .fetch("beta1_power")
+                .run()
+                .get(0)
+                .expect(TFloat32.DTYPE)) {
+          result.data().scalars().forEach(f -> assertEquals(betaPower, f.getFloat(), epsilon1));
+        }
+        assertEquals(learningRate, instance.getLearningRate(), epsilon);
+        session.evaluate(
+            learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap());
+        session.run(update, instance.getFeedMap());
+
+        FloatNdArray[] resultNP = calculate(var0Np, grads0Np, step, m0, v0, learningRate);
+        var0Np = resultNP[VAR];
+        m0 = resultNP[M];
+        v0 = resultNP[V];
+
+        resultNP = calculate(var1Np, grads1Np, step, m1, v1, learningRate);
+        var1Np = resultNP[VAR];
+        m1 = resultNP[M];
+        v1 = resultNP[V];
+
+        // evaluate  var0 and var1
+        session.evaluate(var0Np, var0);
+        session.evaluate(var1Np, var1);
+
+        learningRate *= 0.9F;
+        instance.setLearningRate(learningRate);
+      }
+    }
+  }
+
   private FloatNdArray[] calculate(
       FloatNdArray varNp, FloatNdArray gradsNp, int step, FloatNdArray m, FloatNdArray v) {
-    float alpha = 0.001F;
+    return calculate(varNp, gradsNp, step, m, v, 0.001F);
+  }
+
+  private FloatNdArray[] calculate(
+      FloatNdArray varNp,
+      FloatNdArray gradsNp,
+      int step,
+      FloatNdArray m,
+      FloatNdArray v,
+      float alpha) {
+    float beta1 = BETA_ONE_DEFAULT;
+    float beta2 = BETA_TWO_DEFAULT;
     float espilon = 1e-8F;
 
-    float oneMinusBeta1 = 1.F - BETA_ONE_DEFAULT;
-    float oneMinusBeta1Pow = 1.F - (float) Math.pow(BETA_ONE_DEFAULT, step + 1);
+    float oneMinusBeta1 = 1.F - beta1;
+    float oneMinusBeta1Pow = 1.F - (float) Math.pow(beta1, step + 1);
     float alpha1 = alpha / oneMinusBeta1Pow;
 
-    // beta1 * m + (1 - beta1) * gT;
-    m = ND.add(ND.mul(BETA_ONE_DEFAULT, m), ND.mul(oneMinusBeta1, gradsNp));
-    // np.maximum(BETA_TWO_DEFAULT * v, np.abs(gT))
-    v = ND.max(ND.mul(BETA_TWO_DEFAULT, v), ND.abs(gradsNp));
-    // paramT = param - (alpha / (1 - beta1**(t + 1))) * (mT / (vT + epsilon))
+    // beta1 * m + (1 - beta1) * g_t;
+    m = ND.add(ND.mul(beta1, m), ND.mul(oneMinusBeta1, gradsNp));
+    // np.maximum(beta2 * v, np.abs(g_t))
+    v = ND.max(ND.mul(beta2, v), ND.abs(gradsNp));
+    // param_t = param - (alpha / (1 - beta1**(t + 1))) * (m_t / (v_t + epsilon))
     varNp = ND.sub(varNp, ND.mul(alpha1, ND.div(m, ND.add(v, espilon))));
 
     FloatNdArray[] result = new FloatNdArray[3];
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java
index 597f8e52bcd..c047cd4bf40 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java
@@ -67,12 +67,22 @@ public void testFtrlWithL1L2L2Shrinkage() {
     float[] var1Init = {4.0F, 3.0F};
     float[] grads0Init = {0.1F, 0.2F};
     float[] grads1Init = {0.01F, 0.02F};
+    float learningRate = 3.0F;
 
     int numSteps = 10;
 
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Ftrl instance =
+            new Ftrl(
+                session.getGraph(),
+                learningRate,
+                -0.5F, // learningRatePower
+                0.1F, // initialAccumulatorValue
+                0.001F, // l1RegularizationStrength
+                2.0F, // l2RegularizationStrength
+                0.1F // l2ShrinkageRegularizationStrength
+                )) {
       Ops tf = session.getTF();
-      Graph graph = session.getGraph();
 
       Shape shape0 = Shape.of(var0Init.length);
       Shape shape1 = Shape.of(var1Init.length);
@@ -85,19 +95,6 @@ public void testFtrlWithL1L2L2Shrinkage() {
       Constant<TFloat32> grads0 = tf.constant(grads0Init);
       Constant<TFloat32> grads1 = tf.constant(grads1Init);
 
-      float learningRate = 3.0F;
-
-      Ftrl instance =
-          new Ftrl(
-              graph,
-              learningRate,
-              -0.5F, // learningRatePower
-              0.1F, // initialAccumulatorValue
-              0.001F, // l1RegularizationStrength
-              2.0F, // l2RegularizationStrength
-              0.1F // l2ShrinkageRegularizationStrength
-              );
-
       /* build the GradsAnvVars */
       List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
@@ -116,7 +113,7 @@ public void testFtrlWithL1L2L2Shrinkage() {
       session.evaluate(var1Init, var1);
 
       for (int i = 0; i < numSteps; i++) {
-        session.run(ftrlUpdate);
+        session.run(ftrlUpdate, instance.getFeedMap());
       }
 
       float[] expectedVar0 = {-0.22578995F, -0.44345796F};
@@ -126,6 +123,69 @@ public void testFtrlWithL1L2L2Shrinkage() {
     }
   }
 
+  @Test
+  public void testFtrlWithL1L2L2ShrinkageWithLROperand() {
+    float[] var0Init = {1.0F, 2.0F};
+    float[] var1Init = {4.0F, 3.0F};
+    float[] grads0Init = {0.1F, 0.2F};
+    float[] grads1Init = {0.01F, 0.02F};
+    float learningRate = 1.0F;
+
+    int numSteps = 10;
+
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      try (Ftrl instance =
+          new Ftrl(
+              session.getGraph(),
+              tf.math.mul(tf.constant(learningRate), tf.constant(3f)),
+              -0.5F, // learningRatePower
+              0.1F, // initialAccumulatorValue
+              0.001F, // l1RegularizationStrength
+              2.0F, // l2RegularizationStrength
+              0.1F // l2ShrinkageRegularizationStrength
+              )) {
+
+        Shape shape0 = Shape.of(var0Init.length);
+        Shape shape1 = Shape.of(var1Init.length);
+        Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+        Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+        Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+        Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+        Constant<TFloat32> grads0 = tf.constant(grads0Init);
+        Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+        /* build the GradsAndVars */
+        List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+        Op ftrlUpdate = instance.applyGradients(gradsAndVars, "FtrlTest");
+
+        /* initialize the local variables */
+        session.run(var0Initializer);
+        session.run(var1Initializer);
+
+        /* initialize the accumulators */
+        session.run(tf.init());
+
+        session.evaluate(var0Init, var0);
+        session.evaluate(var1Init, var1);
+
+        for (int i = 0; i < numSteps; i++) {
+          session.run(ftrlUpdate, instance.getFeedMap());
+        }
+
+        float[] expectedVar0 = {-0.22578995F, -0.44345796F};
+        session.evaluate(expectedVar0, var0);
+        float[] expectedVar1 = {-0.14378493F, -0.13229476F};
+        session.evaluate(expectedVar1, var1);
+      }
+    }
+  }
+
   @Test
   public void testFtrlWithL1() {
     float[] var0Init = {1.0F, 2.0F};
@@ -181,7 +241,7 @@ public void testFtrlWithL1() {
       session.evaluate(var1Init, var1);
 
       for (int i = 0; i < numSteps; i++) {
-        session.run(ftrlUpdate);
+        session.run(ftrlUpdate, instance.getFeedMap());
       }
 
       float[] expectedVar0 = {-7.66718769F, -10.91273689F};
@@ -247,7 +307,7 @@ public void testFtrlWithL1L2() {
       session.evaluate(var1Init, var1);
 
       for (int i = 0; i < numSteps; i++) {
-        session.run(ftrlUpdate);
+        session.run(ftrlUpdate, instance.getFeedMap());
       }
 
       float[] expectedVar0 = {-0.24059935F, -0.46829352F};
@@ -258,6 +318,87 @@ public void testFtrlWithL1L2() {
     }
   }
 
+  @Test
+  public void testChangingLearningRate() {
+    float learningRate = 3.0F;
+    float epsilon = 1e-8f;
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Ftrl instance =
+            new Ftrl(
+                session.getGraph(),
+                learningRate,
+                Ftrl.LEARNING_RATE_POWER_DEFAULT,
+                0.1F,
+                0.001F,
+                2.0F,
+                Ftrl.L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT)) {
+      Ops tf = session.getTF();
+      int numSteps = 10;
+
+      float[] var0Init = {1.0F, 2.0F};
+      float[] var1Init = {4.0F, 3.0F};
+      float[] grads0Init = {0.1F, 0.2F};
+      float[] grads1Init = {0.01F, 0.02F};
+      Shape shape0 = Shape.of(var0Init.length);
+      Shape shape1 = Shape.of(var1Init.length);
+      Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+      Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+      Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+      Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+      Constant<TFloat32> grads0 = tf.constant(grads0Init);
+      Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+      /* build the GradsAndVars */
+      List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+      Op ftrlUpdate = instance.applyGradients(gradsAndVars, "FtrlTest");
+
+      /* initialize the local variables */
+      session.run(var0Initializer);
+      session.run(var1Initializer);
+
+      // initialize the accumulators
+      session.run(tf.init());
+      float[][][] expected = {
+        // Step: 0
+        {{-0.022833f, -0.038881f}, {-0.002141f, -0.004474f}},
+        // Step: 1
+        {{-0.203440f, -0.336012f}, {-0.021145f, -0.042048f}},
+        // Step: 2
+        {{-0.667457f, -0.962730f}, {-0.074745f, -0.147685f}},
+        // Step: 3
+        {{-0.854973f, -1.163154f}, {-0.099466f, -0.196172f}},
+        // Step: 4
+        {{-0.878825f, -1.186153f}, {-0.102859f, -0.202808f}},
+        // Step: 5
+        {{-0.881205f, -1.188360f}, {-0.103211f, -0.203495f}},
+        // Step: 6
+        {{-0.881436f, -1.188569f}, {-0.103246f, -0.203564f}},
+        // Step: 7
+        {{-0.881459f, -1.188589f}, {-0.103250f, -0.203571f}},
+        // Step: 8
+        {{-0.881461f, -1.188591f}, {-0.103250f, -0.203572f}},
+        // Step: 9
+        {{-0.881462f, -1.188591f}, {-0.103250f, -0.203572f}},
+      };
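+      // each step checks the learning rate, runs the Ftrl update, then divides the rate by 10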
+      for (int i = 0; i < numSteps; i++) {
+        assertEquals(learningRate, instance.getLearningRate(), epsilon);
+        session.evaluate(
+            learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap());
+        session.run(ftrlUpdate, instance.getFeedMap());
+        session.evaluate(expected[i][0], var0);
+        session.evaluate(expected[i][1], var1);
+        learningRate *= 0.1f;
+        instance.setLearningRate(learningRate);
+      }
+    }
+  }
+
   @Test
   public void doTestFtrlwithoutRegularization() {
     float[] var0Init = {0.0F, 0.0F};
@@ -266,10 +406,11 @@ public void doTestFtrlwithoutRegularization() {
     float[] grads1Init = {0.01F, 0.02F};
 
     int numSteps = 3;
+    float learningRate = 3.0f;
 
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Ftrl instance = new Ftrl(session.getGraph(), learningRate)) {
       Ops tf = session.getTF();
-      Graph graph = session.getGraph();
 
       Shape shape0 = Shape.of(var0Init.length);
       Shape shape1 = Shape.of(var1Init.length);
@@ -282,10 +423,6 @@ public void doTestFtrlwithoutRegularization() {
       Constant<TFloat32> grads0 = tf.constant(grads0Init);
       Constant<TFloat32> grads1 = tf.constant(grads1Init);
 
-      float learningRate = 3.0F;
-
-      Ftrl instance = new Ftrl(graph, learningRate);
-
       /* build the GradsAnvVars */
       List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
@@ -303,7 +440,7 @@ public void doTestFtrlwithoutRegularization() {
       session.evaluate(var1Init, var1);
 
       for (int i = 0; i < numSteps; i++) {
-        session.run(ftrlUpdate);
+        session.run(ftrlUpdate, instance.getFeedMap());
       }
 
       float[] expectedVar0 = {-2.60260963F, -4.29698515F};
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java
index 4362c54d815..ce687186994 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java
@@ -55,9 +55,9 @@ public void testBasic() {
     float[] grads1Init = {0.01F, 0.01F};
     float learningRate = 3.0F;
 
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        GradientDescent instance = new GradientDescent(session.getGraph(), learningRate)) {
       Ops tf = session.getTF();
-      Graph graph = session.getGraph();
 
       Shape shape0 = Shape.of(var0Init.length);
       Shape shape1 = Shape.of(var1Init.length);
@@ -75,7 +75,6 @@ public void testBasic() {
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
 
-      GradientDescent instance = new GradientDescent(graph, learningRate);
       Op update = instance.applyGradients(gradsAndVars, "SGDTest");
 
       /* initialize the local variables */
@@ -89,7 +88,7 @@ public void testBasic() {
       session.evaluate(var0Init, var0);
       session.evaluate(var1Init, var1);
 
-      session.run(update); // 1 step
+      session.run(update, instance.getFeedMap()); // 1 step
 
       float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F};
       float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F};
@@ -97,4 +96,127 @@ public void testBasic() {
       session.evaluate(expectedVar1, var1);
     }
   }
+
+  @Test
+  public void testBasicWithLROperand() {
+    float[] var0Init = {1.0F, 2.0F};
+    float[] var1Init = {3.0F, 4.0F};
+    float[] grads0Init = {0.1F, 0.1F};
+    float[] grads1Init = {0.01F, 0.01F};
+    float learningRate = 1.5f;
+
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      try (GradientDescent instance =
+          new GradientDescent(
+              session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(2f)))) {
+
+        Shape shape0 = Shape.of(var0Init.length);
+        Shape shape1 = Shape.of(var1Init.length);
+        Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+        Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+        Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+        Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+        Constant<TFloat32> grads0 = tf.constant(grads0Init);
+        Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+        /* build the GradsAndVars */
+        List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+        Op update = instance.applyGradients(gradsAndVars, "SGDTest");
+
+        /* initialize the local variables */
+        session.run(var0Initializer);
+        session.run(var1Initializer);
+
+        /* initialize the accumulators */
+        session.run(tf.init());
+
+        /* make sure the variables were initialized properly */
+        session.evaluate(var0Init, var0);
+        session.evaluate(var1Init, var1);
+
+        session.run(update, instance.getFeedMap()); // 1 step
+
+        float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F};
+        float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F};
+        session.evaluate(expectedVar0, var0);
+        session.evaluate(expectedVar1, var1);
+      }
+    }
+  }
+
+  @Test
+  public void testWithLearningRateDecay() {
+    int numSteps = 2;
+    float[] var0Init = {1.0F, 2.0F};
+    float[] var1Init = {3.0F, 4.0F};
+    float[] grads0Init = {0.1F, 0.1F};
+    float[] grads1Init = {0.01F, 0.01F};
+
+    float learningRate = 3.0F;
+
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        GradientDescent instance = new GradientDescent(session.getGraph(), learningRate)) {
+      Ops tf = session.getTF();
+      Shape shape0 = Shape.of(var0Init.length);
+      Shape shape1 = Shape.of(var1Init.length);
+      Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+      Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+      Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+      Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+      Constant<TFloat32> grads0 = tf.constant(grads0Init);
+      Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+      /* build the GradsAndVars */
+      List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+      Op update = instance.applyGradients(gradsAndVars, "GradientDescentTest");
+
+      /* initialize the local variables */
+      session.run(var0Initializer);
+      session.run(var1Initializer);
+
+      // initialize the accumulators
+      session.run(tf.init());
+
+      // make sure the variables were initialized properly
+      session.evaluate(var0Init, var0);
+      session.evaluate(var1Init, var1);
+
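+      // expected values per step; the learning rate is divided by 10 after each update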
+      float[][] expectedVar0 = {
+        {0.7F, 1.7F},
+        {0.66999996F, 1.6700001F},
+        {0.66699994F, 1.667F},
+        {0.66669995F, 1.6667F},
+        {0.66666996F, 1.66667F}
+      };
+      float[][] expectedVar1 = {
+        {2.97F, 3.97F},
+        {2.967F, 3.967F},
+        {2.9667F, 3.9667F},
+        {2.96667F, 3.96667F},
+        {2.966667F, 3.966667F}
+      };
+      for (int step = 0; step < numSteps; step++) {
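+        // Both the scalar learning rate and the graph-side learning rate operand
+        // should reflect the value set via setLearningRate in the previous iteration.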
+        assertEquals(learningRate, instance.getLearningRate(), 1e-6f);
+        session.evaluate(
+            learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap());
+        session.run(update, instance.getFeedMap());
+        session.evaluate(expectedVar0[step], var0);
+        session.evaluate(expectedVar1[step], var1);
+        learningRate *= 0.1;
+        instance.setLearningRate(learningRate);
+      }
+    }
+  }
 }
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java
index bcfff97773d..014c72e55e6 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java
@@ -71,9 +71,9 @@ public void testBasic() {
     float[] grads1Init = {0.01F, 0.01F};
     float learningRate = 3.0F;
 
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Momentum instance = new Momentum(session.getGraph(), learningRate)) {
       Ops tf = session.getTF();
-      Graph graph = session.getGraph();
 
       Shape shape0 = Shape.of(var0Init.length);
       Shape shape1 = Shape.of(var1Init.length);
@@ -91,7 +91,6 @@ public void testBasic() {
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
 
-      Momentum instance = new Momentum(graph, learningRate);
       Op update = instance.applyGradients(gradsAndVars, "SGDTest");
 
       /* initialize the local variables */
@@ -105,7 +104,7 @@ public void testBasic() {
       session.evaluate(var0Init, var0);
       session.evaluate(var1Init, var1);
 
-      session.run(update); // 1 step
+      session.run(update, instance.getFeedMap()); // 1 step
 
       float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F};
       float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F};
@@ -114,6 +113,57 @@ public void testBasic() {
     }
   }
 
+  @Test
+  public void testBasicWithLROperand() {
+    float[] var0Init = {1.0F, 2.0F};
+    float[] var1Init = {3.0F, 4.0F};
+    float[] grads0Init = {0.1F, 0.1F};
+    float[] grads1Init = {0.01F, 0.01F};
+    float learningRate = 3.0F;
+
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      try (Momentum instance = new Momentum(session.getGraph(), tf.constant(learningRate))) {
+
+        Shape shape0 = Shape.of(var0Init.length);
+        Shape shape1 = Shape.of(var1Init.length);
+        Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+        Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+        Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+        Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+        Constant<TFloat32> grads0 = tf.constant(grads0Init);
+        Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+        /* build the GradsAndVars */
+        List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+        Op update = instance.applyGradients(gradsAndVars, "SGDTest");
+
+        /* initialize the local variables */
+        session.run(var0Initializer);
+        session.run(var1Initializer);
+
+        /* initialize the accumulators */
+        session.run(tf.init());
+
+        /* make sure the variables were initialized properly */
+        session.evaluate(var0Init, var0);
+        session.evaluate(var1Init, var1);
+
+        session.run(update, instance.getFeedMap()); // 1 step
+
+        float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F};
+        float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F};
+        session.evaluate(expectedVar0, var0);
+        session.evaluate(expectedVar1, var1);
+      }
+    }
+  }
+
   @Test
   public void testMomentum() {
     float[] var0Init = {1.0F, 2.0F};
@@ -124,9 +174,9 @@ public void testMomentum() {
     float learningRate = 2.0F;
     float momentum = 0.9F;
 
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Momentum instance = new Momentum(session.getGraph(), learningRate, momentum)) {
       Ops tf = session.getTF();
-      Graph graph = session.getGraph();
 
       Shape shape0 = Shape.of(var0Init.length);
       Shape shape1 = Shape.of(var1Init.length);
@@ -144,7 +194,6 @@ public void testMomentum() {
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
 
-      Momentum instance = new Momentum(graph, learningRate, momentum);
       Op update = instance.applyGradients(gradsAndVars, "SGDTest");
 
       Variable<TFloat32> momentumSlot0 = instance.getSlot(var0.asOutput(), MOMENTUM).get();
@@ -163,7 +212,7 @@ public void testMomentum() {
       session.evaluate(var0Init, var0);
       session.evaluate(var1Init, var1);
 
-      session.run(update); // 1 step
+      session.run(update, instance.getFeedMap()); // 1 step
 
       float[] expectedMomentum0 = {0.1F, 0.1F};
       float[] expectedMomentum1 = {0.01F, 0.01F};
@@ -175,7 +224,7 @@ public void testMomentum() {
       session.evaluate(expectedVar0, var0);
       session.evaluate(expectedVar1, var1);
 
-      session.run(update); // step 2
+      session.run(update, instance.getFeedMap()); // step 2
 
       float[] expectedMomentum02 = {(0.9f * 0.1f + 0.1f), (0.9f * 0.1f + 0.1f)};
       float[] expectedMomentum12 = {(0.9f * 0.01f + 0.01f), (0.9f * 0.01f + 0.01f)};
@@ -193,4 +242,78 @@ public void testMomentum() {
       session.evaluate(expectedVar12, var1);
     }
   }
+
+  @Test
+  public void testWithLearningRateDecay() {
+    int numSteps = 2;
+    float[] var0Init = {1.0F, 2.0F};
+    float[] var1Init = {3.0F, 4.0F};
+    float[] grads0Init = {0.1F, 0.1F};
+    float[] grads1Init = {0.01F, 0.01F};
+
+    float learningRate = 3.0F;
+
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Momentum instance = new Momentum(session.getGraph(), learningRate)) {
+      Ops tf = session.getTF();
+      Shape shape0 = Shape.of(var0Init.length);
+      Shape shape1 = Shape.of(var1Init.length);
+      Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+      Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+      Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+      Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+      Constant<TFloat32> grads0 = tf.constant(grads0Init);
+      Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+      /* build the GradsAndVars */
+      List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+      Op update = instance.applyGradients(gradsAndVars, "MomentumTest");
+
+      Variable<TFloat32> momentumSlot0 = instance.getSlot(var0.asOutput(), MOMENTUM).get();
+      assertEquals(momentumSlot0.asOutput().shape(), var0.asOutput().shape());
+      Variable<TFloat32> momentumSlot1 = instance.getSlot(var1.asOutput(), MOMENTUM).get();
+      assertEquals(momentumSlot1.asOutput().shape(), var1.asOutput().shape());
+
+      /* initialize the local variables */
+      session.run(var0Initializer);
+      session.run(var1Initializer);
+
+      // initialize the accumulators
+      session.run(tf.init());
+
+      // make sure the variables were initialized properly
+      session.evaluate(var0Init, var0);
+      session.evaluate(var1Init, var1);
+
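+      // With the default (zero) momentum the update reduces to plain SGD
+      // (var -= learningRate * grad); the learning rate is divided by 10 after each
+      // step, and only the first numSteps rows are exercised.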
+      float[][] expectedVar0 = {
+        {0.7F, 1.7F},
+        {0.66999996F, 1.6700001F},
+        {0.66699994F, 1.667F},
+        {0.66669995F, 1.6667F},
+        {0.66666996F, 1.66667F}
+      };
+      float[][] expectedVar1 = {
+        {2.97F, 3.97F},
+        {2.967F, 3.967F},
+        {2.9667F, 3.9667F},
+        {2.96667F, 3.96667F},
+        {2.966667F, 3.966667F}
+      };
+      for (int step = 0; step < numSteps; step++) {
+        assertEquals(learningRate, instance.getLearningRate(), 1e-6);
+        session.evaluate(
+            learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap());
+        session.run(update, instance.getFeedMap());
+        session.evaluate(expectedVar0[step], var0);
+        session.evaluate(expectedVar1[step], var1);
+        learningRate *= 0.1;
+        instance.setLearningRate(learningRate);
+      }
+    }
+  }
 }
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java
index a583d74246b..0832543c104 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java
@@ -15,7 +15,6 @@
 package org.tensorflow.framework.optimizers;
 
 import org.junit.jupiter.api.*;
-import org.tensorflow.Graph;
 import org.tensorflow.Tensor;
 import org.tensorflow.framework.utils.ND;
 import org.tensorflow.framework.utils.TestSession;
@@ -62,9 +61,9 @@ public void tearDown() {}
   /** Test of getOptimizerName method, of class Nadam. */
   @Test
   public void testGetOptimizerName() {
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
-      Nadam instance = new Nadam(graph);
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Nadam instance = new Nadam(session.getGraph())) {
+
       String expResult = "Nadam";
       String result = instance.getOptimizerName();
       assertEquals(expResult, result);
@@ -96,9 +95,9 @@ public void testBasic() {
 
     float epsilon1 = 1e-3F;
 
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Nadam instance = new Nadam(session.getGraph())) {
       Ops tf = session.getTF();
-      Graph graph = session.getGraph();
 
       Shape shape0 = Shape.of(var0Init.length);
       Shape shape1 = Shape.of(var1Init.length);
@@ -111,7 +110,6 @@ public void testBasic() {
       Constant<TFloat32> grads0 = tf.constant(grads0Init);
       Constant<TFloat32> grads1 = tf.constant(grads1Init);
 
-      Nadam instance = new Nadam(graph);
       /* build the GradsAnvVars */
       List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
       gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
@@ -120,7 +118,9 @@ public void testBasic() {
       Op update = instance.applyGradients(gradsAndVars, "AdamTest");
 
       /* Create and validae the shapes of the slota */
+      @SuppressWarnings("unchecked")
       Variable<TFloat32>[] firstMomentSlots = new Variable[2];
+      @SuppressWarnings("unchecked")
       Variable<TFloat32>[] secondMomentSlots = new Variable[2];
 
       firstMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.FIRST_MOMENT).get();
@@ -161,7 +161,7 @@ public void testBasic() {
 
       for (int step = 0; step < numSteps; step++) {
 
-        session.run(update);
+        session.run(update, instance.getFeedMap());
 
         float mut =
             Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1))));
@@ -203,6 +203,280 @@ public void testBasic() {
     }
   }
 
+  /** Test of applyDense method, of class Nadam, using a learning rate Operand. */
+  @Test
+  public void testBasicWithLROperand() {
+
+    int numSteps = 3;
+
+    float[] var0Init = {1.0F, 2.0F};
+    float[] var1Init = {3.0F, 4.0F};
+    float[] grads0Init = {0.1F, 0.1F};
+    float[] grads1Init = {0.01F, 0.01F};
+
+    float[] zeros = {0.0F, 0.0F};
+    float[] ones = {1.0F, 1.0F};
+    FloatNdArray m0 = NdArrays.vectorOf(zeros);
+    FloatNdArray v0 = NdArrays.vectorOf(zeros);
+    FloatNdArray m1 = NdArrays.vectorOf(zeros);
+    FloatNdArray v1 = NdArrays.vectorOf(zeros);
+    FloatNdArray mcache = NdArrays.vectorOf(ones);
+    FloatNdArray var0Np = NdArrays.vectorOf(var0Init);
+    FloatNdArray var1Np = NdArrays.vectorOf(var1Init);
+    FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init);
+    FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init);
+
+    float epsilon1 = 1e-3F;
+
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      try (Nadam instance =
+          new Nadam(session.getGraph(), tf.math.mul(tf.constant(1f), tf.constant(1e-3f)))) {
+
+        Shape shape0 = Shape.of(var0Init.length);
+        Shape shape1 = Shape.of(var1Init.length);
+        Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+        Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+        Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+        Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+        Constant<TFloat32> grads0 = tf.constant(grads0Init);
+        Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+        /* build the GradsAndVars */
+        List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+        Op update = instance.applyGradients(gradsAndVars, "AdamTest");
+
+        /* Create and validate the shapes of the slots */
+        @SuppressWarnings("unchecked")
+        Variable<TFloat32>[] firstMomentSlots = new Variable[2];
+        @SuppressWarnings("unchecked")
+        Variable<TFloat32>[] secondMomentSlots = new Variable[2];
+
+        firstMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.FIRST_MOMENT).get();
+        assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+        secondMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.SECOND_MOMENT).get();
+        assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+        firstMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.FIRST_MOMENT).get();
+        assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape());
+
+        secondMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.SECOND_MOMENT).get();
+        assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape());
+
+        /* initialize the local variables */
+        session.run(var0Initializer);
+        session.run(var1Initializer);
+
+        /* initialize the accumulators */
+        session.run(tf.init());
+
+        session.setEpsilon(epsilon1);
+
+        session.evaluate(var0Init, var0);
+        session.evaluate(var1Init, var1);
+
+        try (Tensor<TFloat32> result =
+            session
+                .getGraphSession()
+                .runner()
+                .fetch("momentum")
+                .run()
+                .get(0)
+                .expect(TFloat32.DTYPE)) {
+          result.data().scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1));
+        }
+        momentum = 1F;
+
+        for (int step = 0; step < numSteps; step++) {
+
+          session.run(update, instance.getFeedMap());
+
+          float mut =
+              Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1))));
+          momentum = momentum * mut;
+
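+          // The optimizer's "momentum" variable should track the locally accumulated
+          // product of the per-step mu_t factors computed above.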
+          try (Tensor<TFloat32> result =
+              session
+                  .getGraphSession()
+                  .runner()
+                  .fetch("momentum")
+                  .run()
+                  .get(0)
+                  .expect(TFloat32.DTYPE)) {
+            result.data().scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1));
+          }
+          mcache = ND.mul(mcache, momentum);
+          FloatNdArray[] resultsNP = nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache);
+          var0Np = resultsNP[VAR];
+          m0 = resultsNP[M];
+          v0 = resultsNP[V];
+
+          resultsNP = nadamUpdateNdArray(var1Np, grads1Np, step, m1, v1, mcache);
+          var1Np = resultsNP[VAR];
+          m1 = resultsNP[M];
+          v1 = resultsNP[V];
+
+          // evaluate m0 and m1
+          session.evaluate(m0, firstMomentSlots[0]);
+          session.evaluate(m1, firstMomentSlots[1]);
+
+          // evaluate v0 and v1
+          session.evaluate(v0, secondMomentSlots[0]);
+          session.evaluate(v1, secondMomentSlots[1]);
+
+          // evaluate var0 and var1
+          session.evaluate(var0Np, var0);
+          session.evaluate(var1Np, var1);
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testWithLearningRateDecay() {
+    int numSteps = 3;
+
+    float[] var0Init = {1.0F, 2.0F};
+    float[] var1Init = {3.0F, 4.0F};
+    float[] grads0Init = {0.1F, 0.1F};
+    float[] grads1Init = {0.01F, 0.01F};
+
+    float[] zeros = {0.0F, 0.0F};
+    float[] ones = {1.0F, 1.0F};
+    FloatNdArray m0 = NdArrays.vectorOf(zeros);
+    FloatNdArray v0 = NdArrays.vectorOf(zeros);
+    FloatNdArray m1 = NdArrays.vectorOf(zeros);
+    FloatNdArray v1 = NdArrays.vectorOf(zeros);
+    FloatNdArray mcache = NdArrays.vectorOf(ones);
+    FloatNdArray var0Np = NdArrays.vectorOf(var0Init);
+    FloatNdArray var1Np = NdArrays.vectorOf(var1Init);
+    FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init);
+    FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init);
+
+    float epsilon1 = 1e-3F;
+
+    float learningRate = 0.001F;
+
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Nadam instance = new Nadam(session.getGraph(), learningRate)) {
+      Ops tf = session.getTF();
+
+      Shape shape0 = Shape.of(var0Init.length);
+      Shape shape1 = Shape.of(var1Init.length);
+      Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+      Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+      Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+      Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+      Constant<TFloat32> grads0 = tf.constant(grads0Init);
+      Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+      /* build the GradsAndVars */
+      List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+      Op update = instance.applyGradients(gradsAndVars, "AdamTest");
+
+      /* Create and validate the shapes of the slots */
+      @SuppressWarnings("unchecked")
+      Variable<TFloat32>[] firstMomentSlots = new Variable[2];
+      @SuppressWarnings("unchecked")
+      Variable<TFloat32>[] secondMomentSlots = new Variable[2];
+
+      firstMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.FIRST_MOMENT).get();
+      assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+      secondMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.SECOND_MOMENT).get();
+      assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+      firstMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.FIRST_MOMENT).get();
+      assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape());
+
+      secondMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.SECOND_MOMENT).get();
+      assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape());
+
+      /* initialize the local variables */
+      session.run(var0Initializer);
+      session.run(var1Initializer);
+
+      // initialize the accumulators
+      session.run(tf.init());
+
+      session.setEpsilon(epsilon1);
+
+      session.evaluate(var0Init, var0);
+      session.evaluate(var1Init, var1);
+
+      try (Tensor<TFloat32> result =
+          session
+              .getGraphSession()
+              .runner()
+              .fetch("momentum")
+              .run()
+              .get(0)
+              .expect(TFloat32.DTYPE)) {
+        result.data().scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1));
+      }
+      momentum = 1F;
+
+      for (int step = 0; step < numSteps; step++) {
+        assertEquals(learningRate, instance.getLearningRate(), 1e-6f);
+        session.evaluate(
+            learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap());
+        session.run(update, instance.getFeedMap());
+
+        float mut =
+            Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1))));
+        momentum = momentum * mut;
+
+        try (Tensor<TFloat32> result =
+            session
+                .getGraphSession()
+                .runner()
+                .fetch("momentum")
+                .run()
+                .get(0)
+                .expect(TFloat32.DTYPE)) {
+          result.data().scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1));
+        }
+        mcache = ND.mul(mcache, momentum);
+        FloatNdArray[] resultsNP =
+            nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache, learningRate);
+        var0Np = resultsNP[VAR];
+        m0 = resultsNP[M];
+        v0 = resultsNP[V];
+
+        resultsNP = nadamUpdateNdArray(var1Np, grads1Np, step, m1, v1, mcache, learningRate);
+        var1Np = resultsNP[VAR];
+        m1 = resultsNP[M];
+        v1 = resultsNP[V];
+
+        // evaluate m0 and m1
+        session.evaluate(m0, firstMomentSlots[0]);
+        session.evaluate(m1, firstMomentSlots[1]);
+
+        // evaluate v0 and v1
+        session.evaluate(v0, secondMomentSlots[0]);
+        session.evaluate(v1, secondMomentSlots[1]);
+
+        // evaluate var0 and var1
+        session.evaluate(var0Np, var0);
+        session.evaluate(var1Np, var1);
+
+        learningRate *= 0.9;
+        instance.setLearningRate(learningRate);
+      }
+    }
+  }
+
   private FloatNdArray[] nadamUpdateNdArray(
       FloatNdArray varNp,
       FloatNdArray gradsNp,
@@ -210,8 +484,18 @@ private FloatNdArray[] nadamUpdateNdArray(
       FloatNdArray m,
       FloatNdArray v,
       FloatNdArray mCache) {
+    return nadamUpdateNdArray(varNp, gradsNp, t, m, v, mCache, 0.001F);
+  }
+
+  private FloatNdArray[] nadamUpdateNdArray(
+      FloatNdArray varNp,
+      FloatNdArray gradsNp,
+      int t,
+      FloatNdArray m,
+      FloatNdArray v,
+      FloatNdArray mCache,
+      float alpha) {
 
-    float alpha = 0.001F;
     float beta1 = 0.9F;
     float beta2 = 0.999F;
     float epsilon = 1e-8F;
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java
index 78b56c8289e..a0bf027abab 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java
@@ -1,7 +1,6 @@
 package org.tensorflow.framework.optimizers;
 
 import org.junit.jupiter.api.*;
-import org.tensorflow.Graph;
 import org.tensorflow.framework.utils.TestSession;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -26,9 +25,8 @@ public void tearDown() {}
   /** Test ADADELTA enum */
   @Test
   public void testADADELTA() {
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
-      Optimizer instance = Optimizers.ADADELTA.createOptimizer(graph);
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Optimizer instance = Optimizers.ADADELTA.createOptimizer(session.getGraph())) {
       String expResult = "Adadelta";
       String result = instance.getOptimizerName();
       assertEquals(expResult, result);
@@ -38,9 +36,8 @@ public void testADADELTA() {
   /** Test ADAGRAD enum */
   @Test
   public void testADAGRAD() {
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
-      Optimizer instance = Optimizers.ADAGRAD.createOptimizer(graph);
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Optimizer instance = Optimizers.ADAGRAD.createOptimizer(session.getGraph())) {
       String expResult = "Adagrad";
       String result = instance.getOptimizerName();
       assertEquals(expResult, result);
@@ -50,9 +47,8 @@ public void testADAGRAD() {
   /** Test ADAGRAD_DA enum */
   @Test
   public void testADAGRAD_DA() {
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
-      Optimizer instance = Optimizers.ADAGRAD_DA.createOptimizer(graph);
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Optimizer instance = Optimizers.ADAGRAD_DA.createOptimizer(session.getGraph())) {
       String expResult = "adagrad-da";
       String result = instance.getOptimizerName();
       assertEquals(expResult, result);
@@ -62,9 +58,8 @@ public void testADAGRAD_DA() {
   /** Test ADAM enum */
   @Test
   public void testADAM() {
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
-      Optimizer instance = Optimizers.ADAM.createOptimizer(graph);
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Optimizer instance = Optimizers.ADAM.createOptimizer(session.getGraph())) {
       String expResult = "Adam";
       String result = instance.getOptimizerName();
       assertEquals(expResult, result);
@@ -74,9 +69,8 @@ public void testADAM() {
   /** Test ADAMAX enum */
   @Test
   public void testADAMAX() {
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
-      Optimizer instance = Optimizers.ADAMAX.createOptimizer(graph);
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Optimizer instance = Optimizers.ADAMAX.createOptimizer(session.getGraph())) {
       String expResult = "Adamax";
       String result = instance.getOptimizerName();
       assertEquals(expResult, result);
@@ -86,9 +80,8 @@ public void testADAMAX() {
   /** Test FTRL enum */
   @Test
   public void testFTRL() {
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
-      Optimizer instance = Optimizers.FTRL.createOptimizer(graph);
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Optimizer instance = Optimizers.FTRL.createOptimizer(session.getGraph())) {
       String expResult = "Ftrl";
       String result = instance.getOptimizerName();
       assertEquals(expResult, result);
@@ -98,9 +91,8 @@ public void testFTRL() {
   /** Test NADAM enum */
   @Test
   public void testNADAM() {
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
-      Optimizer instance = Optimizers.NADAM.createOptimizer(graph);
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Optimizer instance = Optimizers.NADAM.createOptimizer(session.getGraph())) {
       String expResult = "Nadam";
       String result = instance.getOptimizerName();
       assertEquals(expResult, result);
@@ -110,9 +102,8 @@ public void testNADAM() {
   /** Test RMSPROP enum */
   @Test
   public void testRMSPROP() {
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
-      Optimizer instance = Optimizers.RMSPROP.createOptimizer(graph);
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Optimizer instance = Optimizers.RMSPROP.createOptimizer(session.getGraph())) {
       String expResult = "RMSProp";
       String result = instance.getOptimizerName();
       assertEquals(expResult, result);
@@ -122,9 +113,8 @@ public void testRMSPROP() {
   /** Test MOMENTUM enum */
   @Test
   public void testMOMENTUM() {
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
-      Optimizer instance = Optimizers.MOMENTUM.createOptimizer(graph);
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Optimizer instance = Optimizers.MOMENTUM.createOptimizer(session.getGraph())) {
       String expResult = "Momentum";
       String result = instance.getOptimizerName();
       assertEquals(expResult, result);
@@ -134,9 +124,8 @@ public void testMOMENTUM() {
   /** Test GRADIENT_DESCENT enum */
   @Test
   public void testGRADIENT_DESCENT() {
-    try (TestSession session = TestSession.createTestSession(tfMode)) {
-      Graph graph = session.getGraph();
-      Optimizer instance = Optimizers.GRADIENT_DESCENT.createOptimizer(graph);
+    try (TestSession session = TestSession.createTestSession(tfMode);
+        Optimizer instance = Optimizers.GRADIENT_DESCENT.createOptimizer(session.getGraph())) {
       String expResult = "GradientDescent";
       String result = instance.getOptimizerName();
       assertEquals(expResult, result);
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java
index 202fb21ef68..711358e9222 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java
@@ -15,7 +15,6 @@
 package org.tensorflow.framework.optimizers;
 
 import org.junit.jupiter.api.*;
-import org.tensorflow.Graph;
 import org.tensorflow.framework.utils.ND;
 import org.tensorflow.framework.utils.TestSession;
 import org.tensorflow.ndarray.FloatNdArray;
@@ -32,6 +31,8 @@
 import java.util.ArrayList;
 import java.util.List;
 
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.fail;
 import static org.tensorflow.framework.optimizers.RMSProp.*;
 
 /** Test cases for RMSProp Optimizer */
@@ -42,7 +43,7 @@ public class RMSPropTest {
   final int MOM_T = 3;
   private final TestSession.Mode tfMode = TestSession.Mode.GRAPH;
 
-  Object[][] TestParamValues = {
+  Object[][] testParamValues = {
     // learningRate, rho (decay), momentum, epsilon, centered
     {0.05F, 0.9F, 0.0F, 1e-3F, true},
     {0.05F, 0.9F, 0.0F, 1e-3F, false},
@@ -70,10 +71,18 @@ public void testDense() {
 
     int numSteps = 3;
 
-    for (Object[] testParamValue : TestParamValues) {
-      try (TestSession session = TestSession.createTestSession(tfMode)) {
+    for (Object[] testParamValue : testParamValues) {
+      // learningRate, rho (decay), momentum, epsilon, centered
+      float learningRate = (float) testParamValue[0];
+      float decay = (float) testParamValue[1];
+      float momentum = (float) testParamValue[2];
+      float epsilon = (float) testParamValue[3];
+      boolean centered = (boolean) testParamValue[4];
+      try (TestSession session = TestSession.createTestSession(tfMode);
+          RMSProp instance =
+              new RMSProp(session.getGraph(), learningRate, decay, momentum, epsilon, centered)) {
         Ops tf = session.getTF();
-        Graph graph = session.getGraph();
+
         session.setEpsilon(1e-2f);
         float[] var0Init = {1.0F, 2.0F};
         float[] var1Init = {3.0F, 4.0F};
@@ -96,15 +105,6 @@ public void testDense() {
         Constant<TFloat32> grads0 = tf.constant(grads0Init);
         Constant<TFloat32> grads1 = tf.constant(grads1Init);
 
-        // learningRate, rho (decay), momentum, epsilon, centered
-        float learningRate = (float) (float) testParamValue[0];
-        float decay = (float) testParamValue[1];
-        float momentum = (float) testParamValue[2];
-        float epsilon = (float) testParamValue[3];
-        boolean centered = (boolean) testParamValue[4];
-
-        RMSProp instance = new RMSProp(graph, learningRate, decay, momentum, epsilon, centered);
-
         /* build the GradsAnvVars */
         List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
         gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput()));
@@ -123,14 +123,30 @@ public void testDense() {
         session.evaluate(var0Init, var0);
         session.evaluate(var1Init, var1);
 
-        Variable<TFloat32> mg0 = centered ? instance.getSlot(var0.asOutput(), MG).get() : null;
-        Variable<TFloat32> mg1 = centered ? instance.getSlot(var1.asOutput(), MG).get() : null;
+        Variable<TFloat32> mg0 =
+            centered && instance.getSlot(var0.asOutput(), MG).isPresent()
+                ? instance.getSlot(var0.asOutput(), MG).get()
+                : null;
+        Variable<TFloat32> mg1 =
+            centered && instance.getSlot(var1.asOutput(), MG).isPresent()
+                ? instance.getSlot(var1.asOutput(), MG).get()
+                : null;
         Variable<TFloat32> mom0 =
-            momentum > 0.F ? instance.getSlot(var0.asOutput(), MOMENTUM).get() : null;
+            momentum > 0.F && instance.getSlot(var0.asOutput(), MOMENTUM).isPresent()
+                ? instance.getSlot(var0.asOutput(), MOMENTUM).get()
+                : null;
         Variable<TFloat32> mom1 =
-            momentum > 0.F ? instance.getSlot(var1.asOutput(), MOMENTUM).get() : null;
-        Variable<TFloat32> rms0 = instance.getSlot(var0.asOutput(), RMS).get();
-        Variable<TFloat32> rms1 = instance.getSlot(var1.asOutput(), RMS).get();
+            momentum > 0.F && instance.getSlot(var1.asOutput(), MOMENTUM).isPresent()
+                ? instance.getSlot(var1.asOutput(), MOMENTUM).get()
+                : null;
+        Variable<TFloat32> rms0 =
+            instance.getSlot(var0.asOutput(), RMS).isPresent()
+                ? instance.getSlot(var0.asOutput(), RMS).get()
+                : null;
+        Variable<TFloat32> rms1 =
+            instance.getSlot(var1.asOutput(), RMS).isPresent()
+                ? instance.getSlot(var1.asOutput(), RMS).get()
+                : null;
 
         float[] zeros = {0.0F, 0.0F};
         float[] ones = {1.0F, 1.0F}; // temp to match RMSProp
@@ -142,7 +158,7 @@ public void testDense() {
         FloatNdArray mom1Np = NdArrays.vectorOf(zeros);
 
         for (int i = 0; i < numSteps; i++) {
-          session.run(update);
+          session.run(update, instance.getFeedMap());
           FloatNdArray[] result0 =
               calc(
                   var0Np,
@@ -179,16 +195,18 @@ public void testDense() {
           mom1Np = result1[MOM_T];
 
           if (centered) {
-            session.evaluate(mg0Np, mg0);
-            session.evaluate(mg1Np, mg1);
+            if (mg0 != null) session.evaluate(mg0Np, mg0);
+            if (mg1 != null) session.evaluate(mg1Np, mg1);
           }
 
           if (mom0 != null) session.evaluate(mom0Np, mom0);
           if (mom1 != null) session.evaluate(mom1Np, mom1);
 
           /*     TODO the values returned from rms slot, do not match what I see in the python test */
-          session.evaluate(rms0Np, rms0);
-          session.evaluate(rms1Np, rms1);
+          if (rms0 != null) session.evaluate(rms0Np, rms0);
+          else fail("rms0 is null");
+          if (rms1 != null) session.evaluate(rms1Np, rms1);
+          else fail("rms1 is null");
 
           session.evaluate(var0Np, var0);
           session.evaluate(var1Np, var1);
@@ -197,6 +215,319 @@ public void testDense() {
     }
   }
 
+  @Test
+  public void testDenseWithLROperand() {
+
+    int numSteps = 3;
+
+    for (Object[] testParamValue : testParamValues) {
+      // learningRate, rho (decay), momentum, epsilon, centered
+      float learningRate = (float) testParamValue[0];
+      float decay = (float) testParamValue[1];
+      float momentum = (float) testParamValue[2];
+      float epsilon = (float) testParamValue[3];
+      boolean centered = (boolean) testParamValue[4];
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+
+        Ops tf = session.getTF();
+        try (RMSProp instance =
+            new RMSProp(
+                session.getGraph(),
+                tf.math.add(tf.constant(learningRate), tf.constant(0f)),
+                decay,
+                momentum,
+                epsilon,
+                centered)) {
+
+          session.setEpsilon(1e-2f);
+          float[] var0Init = {1.0F, 2.0F};
+          float[] var1Init = {3.0F, 4.0F};
+          float[] grads0Init = {0.1F, 0.2F};
+          float[] grads1Init = {0.01F, 0.2F};
+
+          FloatNdArray var0Np = NdArrays.vectorOf(var0Init);
+          FloatNdArray var1Np = NdArrays.vectorOf(var1Init);
+          FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init);
+          FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init);
+
+          Shape shape0 = Shape.of(var0Init.length);
+          Shape shape1 = Shape.of(var1Init.length);
+          Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+          Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+          Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0Init));
+          Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+          Constant<TFloat32> grads0 = tf.constant(grads0Init);
+          Constant<TFloat32> grads1 = tf.constant(grads1Init);
+
+          /* build the GradsAndVars */
+          List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+          gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+          gradsAndVars.add(new GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+          Op update = instance.applyGradients(gradsAndVars, "RMSPropTest");
+
+          /* initialize the local variables */
+          session.run(var0Initializer);
+          session.run(var1Initializer);
+
+          /* initialize the accumulators */
+          session.run(tf.init());
+
+          /* make sure the variables were initialized properly */
+          session.evaluate(var0Init, var0);
+          session.evaluate(var1Init, var1);
+
+          Variable<TFloat32> mg0 =
+              centered && instance.getSlot(var0.asOutput(), MG).isPresent()
+                  ? instance.getSlot(var0.asOutput(), MG).get()
+                  : null;
+          Variable<TFloat32> mg1 =
+              centered && instance.getSlot(var1.asOutput(), MG).isPresent()
+                  ? instance.getSlot(var1.asOutput(), MG).get()
+                  : null;
+          Variable<TFloat32> mom0 =
+              momentum > 0.F && instance.getSlot(var0.asOutput(), MOMENTUM).isPresent()
+                  ? instance.getSlot(var0.asOutput(), MOMENTUM).get()
+                  : null;
+          Variable<TFloat32> mom1 =
+              momentum > 0.F && instance.getSlot(var1.asOutput(), MOMENTUM).isPresent()
+                  ? instance.getSlot(var1.asOutput(), MOMENTUM).get()
+                  : null;
+          Variable<TFloat32> rms0 =
+              instance.getSlot(var0.asOutput(), RMS).isPresent()
+                  ? instance.getSlot(var0.asOutput(), RMS).get()
+                  : null;
+          Variable<TFloat32> rms1 =
+              instance.getSlot(var1.asOutput(), RMS).isPresent()
+                  ? instance.getSlot(var1.asOutput(), RMS).get()
+                  : null;
+
+          float[] zeros = {0.0F, 0.0F};
+          float[] ones = {1.0F, 1.0F}; // temp to match RMSProp
+          FloatNdArray mg0Np = NdArrays.vectorOf(zeros);
+          FloatNdArray mg1Np = NdArrays.vectorOf(zeros);
+          FloatNdArray rms0Np = NdArrays.vectorOf(ones);
+          FloatNdArray rms1Np = NdArrays.vectorOf(ones);
+          FloatNdArray mom0Np = NdArrays.vectorOf(zeros);
+          FloatNdArray mom1Np = NdArrays.vectorOf(zeros);
+
+          for (int i = 0; i < numSteps; i++) {
+            session.run(update, instance.getFeedMap());
+            FloatNdArray[] result0 =
+                calc(
+                    var0Np,
+                    grads0Np,
+                    mg0Np,
+                    rms0Np,
+                    mom0Np,
+                    learningRate,
+                    decay,
+                    momentum,
+                    epsilon,
+                    centered);
+            var0Np = result0[VAR_T];
+            mg0Np = result0[MG_T];
+            rms0Np = result0[RMS_T];
+            mom0Np = result0[MOM_T];
+
+            FloatNdArray[] result1 =
+                calc(
+                    var1Np,
+                    grads1Np,
+                    mg1Np,
+                    rms1Np,
+                    mom1Np,
+                    learningRate,
+                    decay,
+                    momentum,
+                    epsilon,
+                    centered);
+
+            var1Np = result1[VAR_T];
+            mg1Np = result1[MG_T];
+            rms1Np = result1[RMS_T];
+            mom1Np = result1[MOM_T];
+
+            if (centered) {
+              if (mg0 != null) session.evaluate(mg0Np, mg0);
+              if (mg1 != null) session.evaluate(mg1Np, mg1);
+            }
+
+            if (mom0 != null) session.evaluate(mom0Np, mom0);
+            if (mom1 != null) session.evaluate(mom1Np, mom1);
+
+            /*     TODO the values returned from rms slot, do not match what I see in the python test */
+            if (rms0 != null) session.evaluate(rms0Np, rms0);
+            else fail("rms0 is null");
+            if (rms1 != null) session.evaluate(rms1Np, rms1);
+            else fail("rms1 is null");
+
+            session.evaluate(var0Np, var0);
+            session.evaluate(var1Np, var1);
+          }
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testWithLearningRateDecay() {
+    int numSteps = 3;
+
+    for (Object[] testParamValue : testParamValues) {
+      float learningRate = (float) testParamValue[0];
+      float decay = (float) testParamValue[1];
+      float momentum = (float) testParamValue[2];
+      float epsilon = (float) testParamValue[3];
+      boolean centered = (boolean) testParamValue[4];
+
+      try (TestSession session = TestSession.createTestSession(tfMode);
+          RMSProp instance =
+              new RMSProp(session.getGraph(), learningRate, decay, momentum, epsilon, centered)) {
+        Ops tf = session.getTF();
+        session.setEpsilon(1e-2f);
+        float[] var0_init = {1.0F, 2.0F};
+        float[] var1_init = {3.0F, 4.0F};
+        float[] grads0_init = {0.1F, 0.2F};
+        float[] grads1_init = {0.01F, 0.2F};
+
+        FloatNdArray var0_np = NdArrays.vectorOf(var0_init);
+        FloatNdArray var1_np = NdArrays.vectorOf(var1_init);
+        FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init);
+        FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init);
+
+        Shape shape0 = Shape.of(var0_init.length);
+        Shape shape1 = Shape.of(var1_init.length);
+        Variable<TFloat32> var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+        Variable<TFloat32> var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+        Assign<TFloat32> var0Initializer = tf.assign(var0, tf.constant(var0_init));
+        Assign<TFloat32> var1Initializer = tf.assign(var1, tf.constant(var1_init));
+
+        Constant<TFloat32> grads0 = tf.constant(grads0_init);
+        Constant<TFloat32> grads1 = tf.constant(grads1_init);
+
+        /* build the GradsAndVars */
+        List<GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
+        gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+        gradsAndVars.add(new GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+        Op update = instance.applyGradients(gradsAndVars, "RMSPropTest");
+
+        /* initialize the local variables */
+        session.run(var0Initializer);
+        session.run(var1Initializer);
+
+        // initialize the accumulators
+        session.run(tf.init());
+
+        // make sure the variables were initialized properly
+        session.evaluate(var0_init, var0);
+        session.evaluate(var1_init, var1);
+
+        Variable<TFloat32> mg0 =
+            centered && instance.getSlot(var0.asOutput(), MG).isPresent()
+                ? instance.getSlot(var0.asOutput(), MG).get()
+                : null;
+        Variable<TFloat32> mg1 =
+            centered && instance.getSlot(var1.asOutput(), MG).isPresent()
+                ? instance.getSlot(var1.asOutput(), MG).get()
+                : null;
+        Variable<TFloat32> mom0 =
+            momentum > 0.F && instance.getSlot(var0.asOutput(), MOMENTUM).isPresent()
+                ? instance.getSlot(var0.asOutput(), MOMENTUM).get()
+                : null;
+        Variable<TFloat32> mom1 =
+            momentum > 0.F && instance.getSlot(var1.asOutput(), MOMENTUM).isPresent()
+                ? instance.getSlot(var1.asOutput(), MOMENTUM).get()
+                : null;
+        Variable<TFloat32> rms0 =
+            instance.getSlot(var0.asOutput(), RMS).isPresent()
+                ? instance.getSlot(var0.asOutput(), RMS).get()
+                : null;
+        Variable<TFloat32> rms1 =
+            instance.getSlot(var1.asOutput(), RMS).isPresent()
+                ? instance.getSlot(var1.asOutput(), RMS).get()
+                : null;
+
+        float[] zeros = {0.0F, 0.0F};
+        float[] ones = {1.0F, 1.0F}; // temp to match RMSProp
+        FloatNdArray mg0_np = NdArrays.vectorOf(zeros);
+        FloatNdArray mg1_np = NdArrays.vectorOf(zeros);
+        FloatNdArray rms0_np = NdArrays.vectorOf(ones);
+        FloatNdArray rms1_np = NdArrays.vectorOf(ones);
+        FloatNdArray mom0_np = NdArrays.vectorOf(zeros);
+        FloatNdArray mom1_np = NdArrays.vectorOf(zeros);
+
+        for (int i = 0; i < numSteps; i++) {
+          assertEquals(learningRate, instance.getLearningRate(), epsilon);
+          session.evaluate(
+              learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap());
+          session.run(update, instance.getFeedMap());
+          FloatNdArray[] result0 =
+              calc(
+                  var0_np,
+                  grads0_np,
+                  mg0_np,
+                  rms0_np,
+                  mom0_np,
+                  learningRate,
+                  decay,
+                  momentum,
+                  epsilon,
+                  centered);
+          var0_np = result0[VAR_T];
+          mg0_np = result0[MG_T];
+          rms0_np = result0[RMS_T];
+          mom0_np = result0[MOM_T];
+
+          FloatNdArray[] result1 =
+              calc(
+                  var1_np,
+                  grads1_np,
+                  mg1_np,
+                  rms1_np,
+                  mom1_np,
+                  learningRate,
+                  decay,
+                  momentum,
+                  epsilon,
+                  centered);
+
+          var1_np = result1[VAR_T];
+          mg1_np = result1[MG_T];
+          rms1_np = result1[RMS_T];
+          mom1_np = result1[MOM_T];
+
+          if (centered) {
+            if (mg0 != null) session.evaluate(mg0_np, mg0);
+            else fail("mg0 is null");
+            if (mg1 != null) session.evaluate(mg1_np, mg1);
+            else fail("mg1 is null");
+          }
+          if (momentum > 0.F) {
+            if (mom0 != null) session.evaluate(mom0_np, mom0);
+            if (mom1 != null) session.evaluate(mom1_np, mom1);
+          }
+
+          /*     TODO the values returned from rms slot, do not match what I see in the python test */
+          if (rms0 != null) session.evaluate(rms0_np, rms0);
+          else fail("rms0 is null");
+          if (rms1 != null) session.evaluate(rms1_np, rms1);
+          else fail("rms1 is null");
+
+          session.evaluate(var0_np, var0);
+          session.evaluate(var1_np, var1);
+
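+          // Decay the learning rate by 10% and push the new value to the optimizer
+          // for the next step.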
+          learningRate *= 0.9F;
+          instance.setLearningRate(learningRate);
+        }
+      }
+    }
+  }
+
   FloatNdArray[] calc(
       FloatNdArray varNp,
       FloatNdArray gradNp,
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java
index 9fb9885505c..33fb11fc31f 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java
@@ -16,27 +16,28 @@
 
 import org.tensorflow.*;
 import org.tensorflow.ndarray.FloatNdArray;
-import org.tensorflow.ndarray.IntNdArray;
 import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Op;
 import org.tensorflow.op.Ops;
 import org.tensorflow.types.*;
 import org.tensorflow.types.family.TNumber;
 import org.tensorflow.types.family.TType;
 
 import java.io.PrintWriter;
+import java.util.Map;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Predicate;
 
 import static org.junit.jupiter.api.Assertions.*;
 
-/** Eager Mode Test Session */
+/** An Eager Mode Test Session */
 public class EagerTestSession extends TestSession {
 
   private final EagerSession session;
   private final Ops tf;
 
-  /** Create an Eager mode test session. */
+  /** Creates an EagerTestSession */
   public EagerTestSession() {
     this.session = EagerSession.create();
     this.tf = Ops.create(session).withName("test");
@@ -48,15 +49,6 @@ public Ops getTF() {
     return tf;
   }
 
-  /**
-   * Get the TensorFlow EagerSession instance
-   *
-   * @return the TensorFlow EagerSession instance
-   */
-  public EagerSession getSession() {
-    return session;
-  }
-
   /** {@inheritDoc} */
   @Override
   public void close() {
@@ -83,9 +75,19 @@ public EagerSession getEagerSession() {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TNumber> void evaluate(double expected, Operand<T> input) {
-    DataType<T> dtype = input.asOutput().dataType();
+  public final void run(Op op, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    // Do nothing in an EagerSession; run() only applies to graph mode
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public <U extends TNumber> void evaluate(
+      double expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    DataType<U> dtype = input.asOutput().dataType();
     if (dtype == TFloat32.DTYPE) {
+      @SuppressWarnings("unchecked")
       Operand<TFloat32> o = (Operand<TFloat32>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -96,6 +98,8 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) {
       index.set(0);
       o.data().scalars().forEach(f -> assertEquals(expected, f.getFloat(), epsilon));
     } else if (dtype == TFloat64.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Operand<TFloat64> o = (Operand<TFloat64>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -106,6 +110,8 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) {
       index.set(0);
       o.data().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon));
     } else if (dtype == TInt32.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Operand<TInt32> o = (Operand<TInt32>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -116,6 +122,8 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) {
       index.set(0);
       o.data().scalars().forEach(f -> assertEquals((int) expected, f.getInt()));
     } else if (dtype == TInt64.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Operand<TInt64> o = (Operand<TInt64>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -130,14 +138,19 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) {
+  public <U extends TNumber> void evaluate(
+      Number[] expected,
+      Output<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
     int size = input.shape().size() == 0 ? 1 : (int) input.shape().size();
     assertEquals(
         expected.length,
         size,
         () -> String.format("expected length (%d) != to input length (%d)", expected.length, size));
-    DataType<T> dtype = input.dataType();
+    DataType<U> dtype = input.dataType();
     if (dtype == TFloat32.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Output<TFloat32> o = (Output<TFloat32>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -153,6 +166,8 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) {
                   assertEquals(
                       expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon));
     } else if (dtype == TFloat64.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Output<TFloat64> o = (Output<TFloat64>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -168,6 +183,8 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) {
                   assertEquals(
                       expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon));
     } else if (dtype == TInt32.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Output<TInt32> o = (Output<TInt32>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -180,6 +197,8 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) {
           .scalars()
           .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt()));
     } else if (dtype == TInt64.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Output<TInt64> o = (Output<TInt64>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -196,9 +215,14 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) {
-    DataType<T> dtype = input.dataType();
+  public <U extends TNumber> void evaluate(
+      FloatNdArray expected,
+      Output<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    DataType<U> dtype = input.dataType();
     if (dtype == TFloat32.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Output<TFloat32> o = (Output<TFloat32>) input;
       AtomicLong index = new AtomicLong();
       if (debug) {
@@ -212,6 +236,8 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) {
           .forEach(
               f -> assertEquals(expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon));
     } else if (dtype == TFloat64.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Output<TFloat64> o = (Output<TFloat64>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -226,6 +252,8 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) {
               f ->
                   assertEquals(expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon));
     } else if (dtype == TInt32.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Output<TInt32> o = (Output<TInt32>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -234,10 +262,12 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) {
             .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt()));
       }
       index.set(0);
-      for (IntNdArray f : o.data().scalars()) {
-        assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt());
-      }
+      o.data()
+          .scalars()
+          .forEach(f -> assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt()));
     } else if (dtype == TInt64.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Output<TInt64> o = (Output<TInt64>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -255,11 +285,16 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predicate) {
+  public <U extends TNumber> void evaluate(
+      Output<U> input,
+      Predicate<Number> predicate,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
     AtomicInteger index = new AtomicInteger();
-    DataType<T> dtype = input.asOutput().dataType();
+    DataType<U> dtype = input.asOutput().dataType();
     boolean isScalar = input.shape().equals(Shape.scalar());
     if (dtype == TFloat32.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Output<TFloat32> o = (Output<TFloat32>) input;
       if (debug) {
         if (isScalar) {
@@ -284,6 +319,8 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic
             .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getFloat())));
       }
     } else if (dtype == TFloat64.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Output<TFloat64> o = (Output<TFloat64>) input;
       if (debug) {
         if (isScalar) {
@@ -308,6 +345,8 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic
             .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getDouble())));
       }
     } else if (dtype == TInt32.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Output<TInt32> o = (Output<TInt32>) input;
       if (debug) {
         if (isScalar) {
@@ -332,6 +371,8 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic
             .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getInt())));
       }
     } else if (dtype == TInt64.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Output<TInt64> o = (Output<TInt64>) input;
       if (debug) {
         if (isScalar) {
@@ -362,7 +403,10 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic
 
   /** {@inheritDoc} */
   @Override
-  public void evaluate(String[] expected, Output<TString> input) {
+  public void evaluate(
+      String[] expected,
+      Output<TString> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
     int size = input.shape().size() == 0 ? 1 : (int) input.shape().size();
     assertEquals(
         expected.length,
@@ -384,7 +428,10 @@ public void evaluate(String[] expected, Output<TString> input) {
 
   /** {@inheritDoc} */
   @Override
-  public void evaluate(Boolean[] expected, Output<TBool> input) {
+  public void evaluate(
+      Boolean[] expected,
+      Output<TBool> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
     int size = input.shape().size() == 0 ? 1 : (int) input.shape().size();
     assertEquals(
         expected.length,
@@ -406,7 +453,10 @@ public void evaluate(Boolean[] expected, Output<TBool> input) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
+  public <T extends TType> void evaluate(
+      Output<T> expected,
+      Output<T> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
     assert input.shape().equals(expected.shape())
         : String.format(
             "expected shape (%s) != to input shape (%s)",
@@ -414,7 +464,10 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
     DataType<T> dtype = input.asOutput().dataType();
     boolean isScalar = input.shape().equals(Shape.scalar());
     if (dtype == TFloat32.DTYPE) {
+
+      @SuppressWarnings("unchecked")
       Output<TFloat32> x = (Output<TFloat32>) expected;
+      @SuppressWarnings("unchecked")
       Output<TFloat32> o = (Output<TFloat32>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -440,7 +493,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
                 (idx, f) -> assertEquals(x.data().getFloat(idx), f.getFloat(), epsilon));
       }
     } else if (dtype == TFloat64.DTYPE) {
+      @SuppressWarnings("unchecked")
       Output<TFloat64> x = (Output<TFloat64>) expected;
+      @SuppressWarnings("unchecked")
       Output<TFloat64> o = (Output<TFloat64>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -466,7 +521,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
                 (idx, f) -> assertEquals(x.data().getDouble(idx), f.getDouble(), epsilon));
       }
     } else if (dtype == TInt32.DTYPE) {
+      @SuppressWarnings("unchecked")
       Output<TInt32> x = (Output<TInt32>) expected;
+      @SuppressWarnings("unchecked")
       Output<TInt32> o = (Output<TInt32>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -491,7 +548,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
             .forEachIndexed((idx, f) -> assertEquals(x.data().getInt(idx), f.getInt()));
       }
     } else if (dtype == TInt64.DTYPE) {
+      @SuppressWarnings("unchecked")
       Output<TInt64> x = (Output<TInt64>) expected;
+      @SuppressWarnings("unchecked")
       Output<TInt64> o = (Output<TInt64>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -516,7 +575,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
             .forEachIndexed((idx, f) -> assertEquals(x.data().getLong(idx), f.getLong()));
       }
     } else if (dtype == TString.DTYPE) {
+      @SuppressWarnings("unchecked")
       Output<TString> x = (Output<TString>) expected;
+      @SuppressWarnings("unchecked")
       Output<TString> o = (Output<TString>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -541,7 +602,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
             .forEachIndexed((idx, f) -> assertEquals(x.data().getObject(idx), f.getObject()));
       }
     } else if (dtype == TBool.DTYPE) {
+      @SuppressWarnings("unchecked")
       Output<TBool> x = (Output<TBool>) expected;
+      @SuppressWarnings("unchecked")
       Output<TBool> o = (Output<TBool>) input;
       AtomicInteger index = new AtomicInteger();
       if (debug) {
@@ -570,44 +633,53 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TType> void print(PrintWriter writer, Output<T> input) {
+  public <T extends TType> void print(
+      PrintWriter writer,
+      Output<T> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
     DataType<T> dtype = input.asOutput().dataType();
     if (dtype == TFloat32.DTYPE) {
+      @SuppressWarnings("unchecked")
       Output<TFloat32> o = (Output<TFloat32>) input;
       AtomicInteger index = new AtomicInteger();
       o.data()
           .scalars()
-          .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat()));
+          .forEach(f -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getFloat()));
     } else if (dtype == TFloat64.DTYPE) {
+      @SuppressWarnings("unchecked")
       Output<TFloat64> o = (Output<TFloat64>) input;
       AtomicInteger index = new AtomicInteger();
       o.data()
           .scalars()
-          .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble()));
+          .forEach(f -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getDouble()));
     } else if (dtype == TInt32.DTYPE) {
+      @SuppressWarnings("unchecked")
       Output<TInt32> o = (Output<TInt32>) input;
       AtomicInteger index = new AtomicInteger();
       o.data()
           .scalars()
-          .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt()));
+          .forEach(f -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getInt()));
     } else if (dtype == TInt64.DTYPE) {
+      @SuppressWarnings("unchecked")
       Output<TInt64> o = (Output<TInt64>) input;
       AtomicInteger index = new AtomicInteger();
       o.data()
           .scalars()
-          .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong()));
+          .forEach(f -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getLong()));
     } else if (dtype == TString.DTYPE) {
+      @SuppressWarnings("unchecked")
       Output<TString> o = (Output<TString>) input;
       AtomicInteger index = new AtomicInteger();
       o.data()
           .scalars()
-          .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject()));
+          .forEach(f -> writer.printf("%d). %s\n", index.getAndIncrement(), f.getObject()));
     } else if (dtype == TBool.DTYPE) {
+      @SuppressWarnings("unchecked")
       Output<TBool> o = (Output<TBool>) input;
       AtomicInteger index = new AtomicInteger();
       o.data()
           .scalars()
-          .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean()));
+          .forEach(f -> writer.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean()));
     } else {
       writer.println("Unexpected DataType: " + dtype);
     }
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java
index 0416267ae59..cc9140c5134 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java
@@ -15,6 +15,7 @@
 package org.tensorflow.framework.utils;
 
 import org.tensorflow.*;
+import org.tensorflow.Session.Runner;
 import org.tensorflow.ndarray.FloatNdArray;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Op;
@@ -24,20 +25,21 @@
 import org.tensorflow.types.family.TType;
 
 import java.io.PrintWriter;
+import java.util.Map;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Predicate;
 
 import static org.junit.jupiter.api.Assertions.*;
 
-/** Graph Mode Test Session */
+/** A Graph Mode Test Session */
 public class GraphTestSession extends TestSession {
 
   private final Graph graph;
   private final Session session;
   private final Ops tf;
 
-  /** Create a Graph mode test session. */
+  /** Create a Graph Test Session */
   public GraphTestSession() {
     graph = new Graph();
     session = new Session(graph);
@@ -50,20 +52,12 @@ public Ops getTF() {
     return tf;
   }
 
-  /** Get the Graph object that is represented by this Test Session */
+  /** {@inheritDoc} */
+  @Override
   public Graph getGraph() {
     return graph;
   }
 
-  /**
-   * Get the TensorFlow Session instance
-   *
-   * @return the TensorFlow Session instance
-   */
-  public Session getSession() {
-    return session;
-  }
-
   /** {@inheritDoc} */
   @Override
   public void close() {
@@ -92,24 +86,45 @@ public EagerSession getEagerSession() {
   /** {@inheritDoc} */
   @Override
   public void initialize() {
-    graph.initializers().forEach(initializer -> session.runner().addTarget(initializer).run());
+    session.run(tf.init());
   }
 
   /** {@inheritDoc} */
   @Override
-  public void run(Op op) {
-    session.run(op);
+  public final void run(Op op, Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) {
+    createRunner(op, feedDict).run();
+  }
+
+  /**
+   * Create a runner for the Operation
+   *
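+   * <p>A hedged usage sketch (here {@code x} is assumed to be a placeholder created elsewhere in
+   * the test graph and {@code xTensor} a tensor holding its value; neither is defined by this
+   * class):
+   *
+   * <pre>{@code
+   * Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap =
+   *     Collections.singletonMap(x, xTensor);
+   * // the returned runner already has op registered as a target and x fed with xTensor
+   * createRunner(op, feedMap).run();
+   * }</pre>
+   *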
+   * @param op the operation
+   * @param feedMap the dictionary of values to use for the runner's feed operations. Required when
+   *     placeholders are used.
+   * @return the runner
+   */
+  public final Runner createRunner(
+      Op op, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    Runner runner = session.runner();
+    runner.addTarget(op.op());
+    if (feedMap != null) {
+      feedMap.forEach(runner::feed);
+    }
+    return runner;
   }
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TNumber> void evaluate(double expected, Operand<T> input) {
-    DataType<T> dtype = input.asOutput().dataType();
+  public <U extends TNumber> void evaluate(
+      double expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) {
+    DataType<U> dtype = input.asOutput().dataType();
     if (dtype == TFloat32.DTYPE) {
       AtomicInteger index = new AtomicInteger();
       if (debug) {
         try (Tensor<TFloat32> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
           result
               .data()
               .scalars()
@@ -118,14 +133,14 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) {
       }
       index.set(0);
       try (Tensor<TFloat32> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
         result.data().scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon));
       }
     } else if (dtype == TFloat64.DTYPE) {
       AtomicInteger index = new AtomicInteger();
       if (debug) {
         try (Tensor<TFloat64> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
           result
               .data()
               .scalars()
@@ -134,14 +149,14 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) {
       }
       index.set(0);
       try (Tensor<TFloat64> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
         result.data().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon));
       }
     } else if (dtype == TInt32.DTYPE) {
       AtomicInteger index = new AtomicInteger();
       if (debug) {
         try (Tensor<TInt32> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) {
           result
               .data()
               .scalars()
@@ -150,14 +165,14 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) {
       }
       index.set(0);
       try (Tensor<TInt32> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) {
         result.data().scalars().forEach(f -> assertEquals((int) expected, f.getInt()));
       }
     } else if (dtype == TInt64.DTYPE) {
       AtomicInteger index = new AtomicInteger();
       if (debug) {
         try (Tensor<TInt64> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) {
           result
               .data()
               .scalars()
@@ -166,7 +181,7 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) {
       }
       index.set(0);
       try (Tensor<TInt64> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) {
         result.data().scalars().forEach(f -> assertEquals((long) expected, f.getLong()));
       }
     } else {
@@ -176,18 +191,24 @@ public <T extends TNumber> void evaluate(double expected, Operand<T> input) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) {
-    int size = input.shape().size() == 0 ? 1 : (int) input.shape().size();
-    assertEquals(
-        expected.length,
-        size,
-        () -> String.format("expected length (%d) != to input length (%d)", expected.length, size));
-    DataType<T> dtype = input.asOutput().dataType();
+  public <U extends TNumber> void evaluate(
+      Number[] expected,
+      Output<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) {
+    long size = input.shape().size() == 0 ? 1 : input.shape().size();
+    if (size != Shape.UNKNOWN_SIZE) {
+      assertEquals(
+          expected.length,
+          size,
+          () ->
+              String.format("expected length (%d) != to input length (%d)", expected.length, size));
+    }
+    DataType<U> dtype = input.asOutput().dataType();
     if (dtype == TFloat32.DTYPE) {
       AtomicInteger index = new AtomicInteger();
       if (debug) {
         try (Tensor<TFloat32> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
           result
               .data()
               .scalars()
@@ -196,7 +217,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) {
       }
       index.set(0);
       try (Tensor<TFloat32> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
         result
             .data()
             .scalars()
@@ -209,7 +230,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) {
       AtomicInteger index = new AtomicInteger();
       if (debug) {
         try (Tensor<TFloat64> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
           result
               .data()
               .scalars()
@@ -218,7 +239,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) {
       }
       index.set(0);
       try (Tensor<TFloat64> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
         result
             .data()
             .scalars()
@@ -231,7 +252,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) {
       AtomicInteger index = new AtomicInteger();
       if (debug) {
         try (Tensor<TInt32> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) {
           result
               .data()
               .scalars()
@@ -240,7 +261,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) {
       }
       index.set(0);
       try (Tensor<TInt32> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) {
         result
             .data()
             .scalars()
@@ -250,7 +271,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) {
       AtomicInteger index = new AtomicInteger();
       if (debug) {
         try (Tensor<TInt64> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) {
           result
               .data()
               .scalars()
@@ -259,7 +280,7 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) {
       }
       index.set(0);
       try (Tensor<TInt64> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) {
         result
             .data()
             .scalars()
@@ -272,13 +293,16 @@ public <T extends TNumber> void evaluate(Number[] expected, Output<T> input) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) {
-    DataType<T> dtype = input.asOutput().dataType();
+  public <U extends TNumber> void evaluate(
+      FloatNdArray expected,
+      Output<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) {
+    DataType<U> dtype = input.asOutput().dataType();
     if (dtype == TFloat32.DTYPE) {
       AtomicLong index = new AtomicLong();
       if (debug) {
         try (Tensor<TFloat32> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
           result
               .data()
               .scalars()
@@ -287,7 +311,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) {
       }
       index.set(0);
       try (Tensor<TFloat32> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
         result
             .data()
             .scalars()
@@ -300,7 +324,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) {
       AtomicInteger index = new AtomicInteger();
       if (debug) {
         try (Tensor<TFloat64> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
           result
               .data()
               .scalars()
@@ -309,7 +333,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) {
       }
       index.set(0);
       try (Tensor<TFloat64> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
         result
             .data()
             .scalars()
@@ -322,7 +346,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) {
       AtomicInteger index = new AtomicInteger();
       if (debug) {
         try (Tensor<TInt32> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) {
           result
               .data()
               .scalars()
@@ -331,7 +355,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) {
       }
       index.set(0);
       try (Tensor<TInt32> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) {
         result
             .data()
             .scalars()
@@ -342,7 +366,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) {
       AtomicInteger index = new AtomicInteger();
       if (debug) {
         try (Tensor<TInt64> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) {
           result
               .data()
               .scalars()
@@ -351,7 +375,7 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) {
       }
       index.set(0);
       try (Tensor<TInt64> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) {
         result
             .data()
             .scalars()
@@ -365,7 +389,10 @@ public <T extends TType> void evaluate(FloatNdArray expected, Output<T> input) {
 
   /** {@inheritDoc} */
   @Override
-  public void evaluate(String[] expected, Output<TString> input) {
+  public void evaluate(
+      String[] expected,
+      Output<TString> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) {
     int size = input.shape().size() == 0 ? 1 : (int) input.shape().size();
     assertEquals(
         expected.length,
@@ -374,7 +401,7 @@ public void evaluate(String[] expected, Output<TString> input) {
     AtomicInteger index = new AtomicInteger();
     if (debug) {
       try (Tensor<TString> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) {
         result
             .data()
             .scalars()
@@ -383,7 +410,7 @@ public void evaluate(String[] expected, Output<TString> input) {
     }
     index.set(0);
     try (Tensor<TString> result =
-        this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) {
+        createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) {
       result
           .data()
           .scalars()
@@ -393,7 +420,10 @@ public void evaluate(String[] expected, Output<TString> input) {
 
   /** {@inheritDoc} */
   @Override
-  public void evaluate(Boolean[] expected, Output<TBool> input) {
+  public void evaluate(
+      Boolean[] expected,
+      Output<TBool> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) {
     int size = input.shape().size() == 0 ? 1 : (int) input.shape().size();
     assertEquals(
         expected.length,
@@ -402,7 +432,7 @@ public void evaluate(Boolean[] expected, Output<TBool> input) {
     AtomicInteger index = new AtomicInteger();
     if (debug) {
       try (Tensor<TBool> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) {
         result
             .data()
             .scalars()
@@ -411,7 +441,7 @@ public void evaluate(Boolean[] expected, Output<TBool> input) {
     }
     index.set(0);
     try (Tensor<TBool> result =
-        this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) {
+        createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) {
       result
           .data()
           .scalars()
@@ -421,27 +451,25 @@ public void evaluate(Boolean[] expected, Output<TBool> input) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
+  public <T extends TType> void evaluate(
+      Output<T> expected,
+      Output<T> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) {
     assert input.shape().equals(expected.shape())
         : String.format(
             "expected shape (%s) != to input shape (%s)",
             expected.shape().toString(), input.shape().toString());
     AtomicInteger index = new AtomicInteger();
     DataType<T> dtype = input.asOutput().dataType();
-    if (!dtype.equals(expected.dataType())) {
-      throw new IllegalArgumentException(
-          String.format(
-              "Both data type must be equal, inout = %s, expected = %s",
-              dtype, expected.dataType()));
-    }
     boolean isScalar = input.shape().equals(Shape.scalar());
     if (dtype == TFloat32.DTYPE) {
+      @SuppressWarnings("unchecked")
       final Output<TFloat32> finalExpected = (Output<TFloat32>) expected;
       if (debug) {
         try (Tensor<TFloat32> result =
-                this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE);
+                createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE);
             Tensor<TFloat32> expectedResult =
-                this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
+                createRunner(finalExpected, feedDict).fetch(finalExpected).run().get(0).expect(TFloat32.DTYPE)) {
           if (isScalar) {
             System.out.printf(
                 "0). %f <==> %f\n", expectedResult.data().getFloat(), result.data().getFloat());
@@ -461,9 +489,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
       }
       index.set(0);
       try (Tensor<TFloat32> result =
-              this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE);
+              createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE);
           Tensor<TFloat32> expectedResult =
-              this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
+              createRunner(finalExpected, feedDict).fetch(finalExpected).run().get(0).expect(TFloat32.DTYPE)) {
         if (isScalar) {
           assertEquals(expectedResult.data().getFloat(), result.data().getFloat(), epsilon);
         } else {
@@ -476,12 +504,13 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
         }
       }
     } else if (dtype == TFloat64.DTYPE) {
+      @SuppressWarnings("unchecked")
       final Output<TFloat64> finalExpected = (Output<TFloat64>) expected;
       if (debug) {
         try (Tensor<TFloat64> result =
-                this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE);
+                createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE);
             Tensor<TFloat64> expectedResult =
-                this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
+                createRunner(finalExpected, feedDict).fetch(finalExpected).run().get(0).expect(TFloat64.DTYPE)) {
           if (isScalar) {
             System.out.printf(
                 "0). %f <==> %f\n", expectedResult.data().getDouble(), result.data().getDouble());
@@ -501,9 +530,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
       }
       index.set(0);
       try (Tensor<TFloat64> result =
-              this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE);
+              createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE);
           Tensor<TFloat64> expectedResult =
-              this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
+              createRunner(finalExpected, feedDict).fetch(finalExpected).run().get(0).expect(TFloat64.DTYPE)) {
         if (isScalar) {
           assertEquals(expectedResult.data().getDouble(), result.data().getDouble(), epsilon);
         } else {
@@ -516,12 +545,13 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
         }
       }
     } else if (dtype == TInt32.DTYPE) {
+      @SuppressWarnings("unchecked")
       final Output<TInt32> finalExpected = (Output<TInt32>) expected;
       if (debug) {
         try (Tensor<TInt32> result =
-                this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE);
+                createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE);
             Tensor<TInt32> expectedResult =
-                this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) {
+                createRunner(finalExpected, feedDict).fetch(finalExpected).run().get(0).expect(TInt32.DTYPE)) {
           if (isScalar) {
             System.out.printf(
                 "0). %d <==> %d\n", expectedResult.data().getInt(), result.data().getInt());
@@ -539,9 +569,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
       }
       index.set(0);
       try (Tensor<TInt32> result =
-              this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE);
+              createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE);
           Tensor<TInt32> expectedResult =
-              this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) {
+              createRunner(finalExpected, feedDict).fetch(finalExpected).run().get(0).expect(TInt32.DTYPE)) {
         if (isScalar) {
           assertEquals(expectedResult.data().getInt(), result.data().getInt(), epsilon);
         } else {
@@ -553,12 +583,13 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
         }
       }
     } else if (dtype == TInt64.DTYPE) {
+      @SuppressWarnings("unchecked")
       final Output<TInt64> finalExpected = (Output<TInt64>) expected;
       if (debug) {
         try (Tensor<TInt64> result =
-                this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE);
+                createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE);
             Tensor<TInt64> expectedResult =
-                this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) {
+                createRunner(finalExpected, feedDict).fetch(finalExpected).run().get(0).expect(TInt64.DTYPE)) {
           if (isScalar) {
             System.out.printf(
                 "0). %d <==> %d\n", expectedResult.data().getLong(), result.data().getLong());
@@ -578,9 +609,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
       }
       index.set(0);
       try (Tensor<TInt64> result =
-              this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE);
+              createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE);
           Tensor<TInt64> expectedResult =
-              this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) {
+              createRunner(finalExpected, feedDict).fetch(finalExpected).run().get(0).expect(TInt64.DTYPE)) {
         if (isScalar) {
           assertEquals(expectedResult.data().getLong(), result.data().getLong(), epsilon);
         } else {
@@ -593,12 +624,13 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
         }
       }
     } else if (dtype == TBool.DTYPE) {
+      @SuppressWarnings("unchecked")
       final Output<TBool> finalExpected = (Output<TBool>) expected;
       if (debug) {
         try (Tensor<TBool> result =
-                this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE);
+                createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE);
             Tensor<TBool> expectedResult =
-                this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) {
+                createRunner(finalExpected, feedDict).fetch(finalExpected).run().get(0).expect(TBool.DTYPE)) {
           if (isScalar) {
             System.out.printf(
                 "0). %b <==> %b\n", expectedResult.data().getBoolean(), result.data().getBoolean());
@@ -618,9 +650,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
       }
       index.set(0);
       try (Tensor<TBool> result =
-              this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE);
+              createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE);
           Tensor<TBool> expectedResult =
-              this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) {
+              createRunner(finalExpected, feedDict).fetch(finalExpected).run().get(0).expect(TBool.DTYPE)) {
         if (isScalar) {
           assertEquals(expectedResult.data().getBoolean(), result.data().getBoolean());
         } else {
@@ -632,12 +664,13 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
         }
       }
     } else if (dtype == TString.DTYPE) {
+      @SuppressWarnings("unchecked")
       final Output<TString> finalExpected = (Output<TString>) expected;
       if (debug) {
         try (Tensor<TString> result =
-                this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE);
+                createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE);
             Tensor<TString> expectedResult =
-                this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) {
+                createRunner(finalExpected, feedDict).fetch(finalExpected).run().get(0).expect(TString.DTYPE)) {
           if (isScalar) {
             System.out.printf(
                 "0). %s <==> %s\n", expectedResult.data().getObject(), result.data().getObject());
@@ -657,9 +690,9 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
       }
       index.set(0);
       try (Tensor<TString> result =
-              this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE);
+              createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE);
           Tensor<TString> expectedResult =
-              this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) {
+              createRunner(finalExpected, feedDict).fetch(finalExpected).run().get(0).expect(TString.DTYPE)) {
         if (isScalar) {
           assertEquals(expectedResult.data().getObject(), result.data().getObject());
         } else {
@@ -676,15 +709,17 @@ public <T extends TType> void evaluate(Output<T> expected, Output<T> input) {
   }
 
   /** {@inheritDoc} */
-  @Override
-  public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predicate) {
+  @Override
+  public <U extends TNumber> void evaluate(
+      Output<U> input,
+      Predicate<Number> predicate,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) {
     AtomicInteger index = new AtomicInteger();
-    DataType<T> dtype = input.asOutput().dataType();
+    DataType<U> dtype = input.asOutput().dataType();
     boolean isScalar = input.shape().equals(Shape.scalar());
     if (dtype == TFloat32.DTYPE) {
       if (debug) {
         try (Tensor<TFloat32> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
           if (isScalar) {
             System.out.printf(
                 "0). %b <==> %f\n",
@@ -703,7 +738,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic
       }
       index.set(0);
       try (Tensor<TFloat32> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
         if (isScalar) {
           assertTrue(predicate.test(result.data().getFloat()));
         } else {
@@ -716,7 +751,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic
     } else if (dtype == TFloat64.DTYPE) {
       if (debug) {
         try (Tensor<TFloat64> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
           if (isScalar) {
             System.out.printf(
                 "0). %b <==> %f\n",
@@ -735,7 +770,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic
       }
       index.set(0);
       try (Tensor<TFloat64> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
         if (isScalar) {
           assertTrue(predicate.test(result.data().getDouble()));
         } else {
@@ -748,7 +783,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic
     } else if (dtype == TInt32.DTYPE) {
       if (debug) {
         try (Tensor<TInt32> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) {
           if (isScalar) {
             System.out.printf(
                 "0). %b <==> %d\n", predicate.test(result.data().getInt()), result.data().getInt());
@@ -766,7 +801,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic
       }
       index.set(0);
       try (Tensor<TInt32> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) {
         if (isScalar) {
           assertTrue(predicate.test(result.data().getInt()));
         } else {
@@ -779,7 +814,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic
     } else if (dtype == TInt64.DTYPE) {
       if (debug) {
         try (Tensor<TInt64> result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) {
           if (isScalar) {
             System.out.printf(
                 "0). %b <==> %d\n",
@@ -798,7 +833,7 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic
       }
       index.set(0);
       try (Tensor<TInt64> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) {
         if (isScalar) {
           assertTrue(predicate.test(result.data().getLong()));
         } else {
@@ -815,14 +850,17 @@ public <T extends TType> void evaluate(Output<T> input, Predicate<Number> predic
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TType> void print(PrintWriter writer, Output<T> input) {
+  public <T extends TType> void print(
+      PrintWriter writer,
+      Output<T> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedDict) {
     boolean isScalar = input.asOutput().shape().size() == 1;
 
     DataType<T> dtype = input.dataType();
     if (dtype == TFloat32.DTYPE) {
       AtomicInteger index = new AtomicInteger();
       try (Tensor<TFloat32> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
         if (isScalar) {
           writer.printf("%d). %f\n", index.getAndIncrement(), result.data().getFloat());
         } else {
@@ -837,10 +875,11 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) {
       AtomicInteger index = new AtomicInteger();
 
       try (Tensor<TFloat64> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
         if (isScalar) {
-          writer.printf(
-              "%d). %f\n", index.getAndIncrement(), ((Output<TFloat64>) input).data().getDouble());
+          @SuppressWarnings("unchecked")
+          Output<TFloat64> o = (Output<TFloat64>) input;
+          writer.printf("%d). %f\n", index.getAndIncrement(), o.data().getDouble());
         } else {
           result
               .data()
@@ -853,10 +892,11 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) {
       AtomicInteger index = new AtomicInteger();
 
       try (Tensor<TInt32> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) {
         if (isScalar) {
-          writer.printf(
-              "%d). %d\n", index.getAndIncrement(), ((Output<TInt32>) input).data().getInt());
+          @SuppressWarnings("unchecked")
+          Output<TInt32> o = (Output<TInt32>) input;
+          writer.printf("%d). %d\n", index.getAndIncrement(), o.data().getInt());
         } else {
           result
               .data()
@@ -869,10 +909,11 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) {
       AtomicInteger index = new AtomicInteger();
 
       try (Tensor<TInt64> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) {
         if (isScalar) {
-          writer.printf(
-              "%d). %d\n", index.getAndIncrement(), ((Output<TInt64>) input).data().getLong());
+          @SuppressWarnings("unchecked")
+          Output<TInt64> o = (Output<TInt64>) input;
+          writer.printf("%d). %d\n", index.getAndIncrement(), o.data().getLong());
         } else {
           result
               .data()
@@ -885,10 +926,11 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) {
       AtomicInteger index = new AtomicInteger();
 
       try (Tensor<TBool> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) {
         if (isScalar) {
-          writer.printf(
-              "%d). %b\n", index.getAndIncrement(), ((Output<TBool>) input).data().getBoolean());
+          @SuppressWarnings("unchecked")
+          Output<TBool> o = (Output<TBool>) input;
+          writer.printf("%d). %b\n", index.getAndIncrement(), o.data().getBoolean());
         } else {
           result
               .data()
@@ -901,10 +943,11 @@ public <T extends TType> void print(PrintWriter writer, Output<T> input) {
       AtomicInteger index = new AtomicInteger();
 
       try (Tensor<TString> result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) {
         if (isScalar) {
-          writer.printf(
-              "%d). %s\n", index.getAndIncrement(), ((Output<TString>) input).data().getObject());
+          @SuppressWarnings("unchecked")
+          Output<TString> o = (Output<TString>) input;
+          writer.printf("%d). %s\n", index.getAndIncrement(), o.data().getObject());
         } else {
           result
               .data()
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java
index a0855eb6260..713225a4962 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java
@@ -27,583 +27,1136 @@
 import java.io.OutputStreamWriter;
 import java.io.PrintWriter;
 import java.io.Writer;
+import java.util.Map;
 import java.util.function.Predicate;
 
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
-/** Base class for Test Session */
+/** Abstract class for Test Sessions */
 public abstract class TestSession implements AutoCloseable {
 
   protected float epsilon = 1e-5F;
   protected boolean debug;
 
-  /** The Test Session mode, either Eager or Graph */
+  /** The Test Session mode: either Eager or Graph */
   public enum Mode {
     EAGER,
     GRAPH
   }
 
-  /**
-   * Creates an Eager Test Session
-   *
-   * @return the Eager Test Session
-   */
   public static TestSession createEagerSession() {
     return new EagerTestSession();
   }
 
-  /**
-   * Creates a Graph Test Session
-   *
-   * @return the Graph Test Session
-   */
   public static TestSession createGraphSession() {
     return new GraphTestSession();
   }
 
-  /**
-   * Creates a Test Session
-   *
-   * @param mode the type of Session, either eager or graph
-   * @return returns the test session
-   */
   public static TestSession createTestSession(Mode mode) {
     return mode == Mode.EAGER ? createEagerSession() : createGraphSession();
   }
 
-  /** Initializes the Test Session, default implementation is do nothing. */
+  /**
+   * Runs any graph initializers when in Graph mode; in Eager mode, this method does nothing.
+   */
   public void initialize() {
     // empty
   }
 
   /**
-   * Runs the Operation
+   * Returns the Graph if in Graph mode, or null if in Eager mode
+   *
+   * @return the Graph if in Graph mode, or null if in Eager mode
+   */
+  public Graph getGraph() {
+    return null;
+  }
+
+  /**
+   * Perform session.run()
+   *
+   * <p>If in eager mode, this does nothing.
    *
-   * @param op the Operation to run
+   * @param op The Operation to run
    */
   public void run(Op op) {
-    // empty
+    run(op, null);
   }
 
   /**
-   * Gets the Graph
+   * Perform session.run()
    *
-   * @return the graph if in Graph Mode, otherwise null.
+   * <p>If in eager mode, this does nothing.
+   *
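+   * <p>A hedged sketch (assuming {@code session} is a graph-mode TestSession, {@code trainOp} is
+   * some operation that depends on a placeholder {@code x}, and {@code xTensor} holds the value to
+   * feed; none of these are defined here):
+   *
+   * <pre>{@code
+   * session.run(trainOp, Collections.singletonMap(x, xTensor));
+   * }</pre>
+   *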
+   * @param op The Operation to run
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
    */
-  public Graph getGraph() {
-    return null;
+  public abstract void run(Op op, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap);
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
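+   * <p>For example, a test might assert a constant computation directly (a hedged sketch; {@code
+   * session} is this TestSession and {@code tf} the result of {@code getTF()}):
+   *
+   * <pre>{@code
+   * // 2 squared should evaluate to 4
+   * session.evaluate(4, tf.math.square(tf.constant(2)));
+   * }</pre>
+   *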
+   * @param expected the expected value
+   * @param input the actual value
+   * @param <U> the data type of the input
+   */
+  public <U extends TNumber> void evaluate(Number expected, Operand<U> input) {
+    evaluate(new Number[] {expected}, input, null);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
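+   * <p>The feed map is what allows a result that depends on a placeholder to be evaluated. A
+   * hedged sketch (assuming {@code tf} comes from {@code getTF()} and {@code xTensor} is a
+   * TFloat32 tensor holding 3f):
+   *
+   * <pre>{@code
+   * Placeholder<TFloat32> x = tf.placeholder(TFloat32.DTYPE);
+   * Operand<TFloat32> y = tf.math.mul(x, tf.constant(2f));
+   * // 3 * 2 == 6, with x fed from xTensor
+   * session.evaluate(6f, y, Collections.singletonMap(x, xTensor));
+   * }</pre>
+   *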
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
    */
-  public <T extends TNumber> void evaluate(Number expected, Operand<T> input) {
-    evaluate(new Number[] {expected}, input);
+  public <U extends TNumber> void evaluate(
+      Number expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    evaluate(new Number[] {expected}, input, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
    */
   public void evaluate(Number expected, Op input) {
-    evaluate(new Number[] {expected}, input);
+    evaluate(new Number[] {expected}, input, null);
   }
 
   /**
-   * Evaluates the input against the expected values
+   * Evaluate the expected results versus the actual results
    *
-   * @param expected the expected values
-   * @param input the operand to evaluate
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
    */
-  public void evaluate(Number[] expected, Op input) {
-    Output<? extends TNumber> output = input.op().output(0);
-    evaluate(expected, output);
+  public void evaluate(
+      Number expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    evaluate(new Number[] {expected}, input, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param <U> the data type for the input
    */
-  public <T extends TNumber> void evaluate(Number[] expected, Operand<T> input) {
-    Output<T> output = input.asOutput();
-    evaluate(expected, output);
+  public <U extends TNumber> void evaluate(Number[] expected, Op input) {
+    Output<U> output = input.op().output(0);
+    evaluate(expected, output, null);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type for the input
    */
-  public <T extends TNumber> void evaluate(byte expected, Operand<T> input) {
-    evaluate((double) expected, input);
+  public <U extends TNumber> void evaluate(
+      Number[] expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    Output<U> output = input.op().output(0);
+    evaluate(expected, output, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param <U> the data type of the input
    */
-  public <T extends TNumber> void evaluate(int expected, Operand<T> input) {
-    evaluate((double) expected, input);
+  public <U extends TNumber> void evaluate(Number[] expected, Operand<U> input) {
+    Output<U> output = input.asOutput();
+    evaluate(expected, output, null);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
    */
-  public <T extends TNumber> void evaluate(long expected, Operand<T> input) {
-    evaluate((double) expected, input);
+  public <U extends TNumber> void evaluate(
+      Number[] expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    Output<U> output = input.asOutput();
+    evaluate(expected, output, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param <U> the data type of the input
    */
-  public <T extends TNumber> void evaluate(float expected, Operand<T> input) {
-    evaluate((double) expected, input);
+  public <U extends TNumber> void evaluate(byte expected, Operand<U> input) {
+    evaluate((double) expected, input, null);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
    */
-  public abstract <T extends TNumber> void evaluate(double expected, Operand<T> input);
+  public <U extends TNumber> void evaluate(
+      byte expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    evaluate((double) expected, input, feedMap);
+  }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param <U> the data type of the input
+   */
+  public <U extends TNumber> void evaluate(int expected, Operand<U> input) {
+    evaluate((double) expected, input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
+   */
+  public <U extends TNumber> void evaluate(
+      int expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    evaluate((double) expected, input, feedMap);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param <U> the data type of the input
+   */
+  public <U extends TNumber> void evaluate(long expected, Operand<U> input) {
+    evaluate((double) expected, input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
+   */
+  public <U extends TNumber> void evaluate(
+      long expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    evaluate((double) expected, input, feedMap);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param <U> the data type of the input
+   */
+  public <U extends TNumber> void evaluate(float expected, Operand<U> input) {
+    evaluate((double) expected, input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
    */
-  public <T extends TNumber> void evaluate(byte[] expected, Operand<T> input) {
+  public <U extends TNumber> void evaluate(
+      float expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    evaluate((double) expected, input, feedMap);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param <U> the data type of the input
+   */
+  public <U extends TNumber> void evaluate(double expected, Operand<U> input) {
+    evaluate(expected, input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
+   */
+  public abstract <U extends TNumber> void evaluate(
+      double expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap);
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param <U> the data type of the input
+   */
+  public <U extends TNumber> void evaluate(byte[] expected, Operand<U> input) {
+    evaluate(expected, input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
+   */
+  public <U extends TNumber> void evaluate(
+      byte[] expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
     Byte[] iArray = new Byte[expected.length];
-    for (int i = 0; i < expected.length; i++) iArray[i] = expected[i];
-    evaluate(iArray, input);
+    for (int i = 0; i < expected.length; i++) {
+      iArray[i] = expected[i];
+    }
+    evaluate(iArray, input, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param <U> the data type of the input
    */
-  public <T extends TNumber> void evaluate(int[] expected, Operand<T> input) {
+  public <U extends TNumber> void evaluate(int[] expected, Operand<U> input) {
+    evaluate(expected, input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
+   */
+  public <U extends TNumber> void evaluate(
+      int[] expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
     Integer[] iArray = new Integer[expected.length];
-    for (int i = 0; i < expected.length; i++) iArray[i] = expected[i];
-    evaluate(iArray, input);
+    for (int i = 0; i < expected.length; i++) {
+      iArray[i] = expected[i];
+    }
+    evaluate(iArray, input, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param <U> the data type of the input
    */
-  public <T extends TNumber> void evaluate(long[] expected, Operand<T> input) {
+  public <U extends TNumber> void evaluate(long[] expected, Operand<U> input) {
+    evaluate(expected, input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
+   */
+  public <U extends TNumber> void evaluate(
+      long[] expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
     Long[] iArray = new Long[expected.length];
-    for (int i = 0; i < expected.length; i++) iArray[i] = expected[i];
-    evaluate(iArray, input);
+    for (int i = 0; i < expected.length; i++) {
+      iArray[i] = expected[i];
+    }
+    evaluate(iArray, input, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected values
+   * Evaluate the expected results versus the actual results
    *
-   * @param expected the expected values
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param expected the expected value
+   * @param input the actual value
+   * @param <U> the data type of the input
    */
-  public <T extends TNumber> void evaluate(float[] expected, Operand<T> input) {
+  public <U extends TNumber> void evaluate(float[] expected, Operand<U> input) {
+    evaluate(expected, input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
+   */
+  public <U extends TNumber> void evaluate(
+      float[] expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
     Float[] iArray = new Float[expected.length];
-    for (int i = 0; i < expected.length; i++) iArray[i] = expected[i];
-    evaluate(iArray, input);
+    for (int i = 0; i < expected.length; i++) {
+      iArray[i] = expected[i];
+    }
+    evaluate(iArray, input, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param <U> the data type of the input
    */
-  public <T extends TNumber> void evaluate(double[] expected, Operand<T> input) {
+  public <U extends TNumber> void evaluate(double[] expected, Operand<U> input) {
+    evaluate(expected, input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
+   */
+  public <U extends TNumber> void evaluate(
+      double[] expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
     Double[] iArray = new Double[expected.length];
-    for (int i = 0; i < expected.length; i++) iArray[i] = expected[i];
-    evaluate(iArray, input);
+    for (int i = 0; i < expected.length; i++) {
+      iArray[i] = expected[i];
+    }
+    evaluate(iArray, input, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param <U> the data type of the input
+   */
+  public <U extends TNumber> void evaluate(Number[] expected, Output<U> input) {
+    evaluate(expected, input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
    */
-  public abstract <T extends TNumber> void evaluate(Number[] expected, Output<T> input);
+  public abstract <U extends TNumber> void evaluate(
+      Number[] expected,
+      Output<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap);
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
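+   *
+   * <p>For example (a sketch; {@code session} is this test session and {@code tf} its Ops):
+   *
+   * <pre>{@code
+   * session.evaluate("Hello", tf.constant("Hello"));
+   * }</pre>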
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
    */
   public void evaluate(String expected, Operand<TString> input) {
-    evaluate(new String[] {expected}, input);
+    evaluate(expected, input, null);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   */
+  public void evaluate(
+      String expected,
+      Operand<TString> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    evaluate(new String[] {expected}, input, feedMap);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
    */
   public void evaluate(String expected, Op input) {
-    evaluate(new String[] {expected}, input);
+    evaluate(new String[] {expected}, input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   */
+  public void evaluate(
+      String expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    evaluate(new String[] {expected}, input, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
    */
   public void evaluate(String[] expected, Op input) {
+    evaluate(expected, input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   */
+  public void evaluate(
+      String[] expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
     Output<TString> output = input.op().output(0);
-    evaluate(expected, output);
+    evaluate(expected, output, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
    */
   public void evaluate(String[] expected, Operand<TString> input) {
     Output<TString> output = input.asOutput();
-    evaluate(expected, output);
+    evaluate(expected, output, null);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
    */
-  public abstract void evaluate(String[] expected, Output<TString> input);
+  public abstract void evaluate(
+      String[] expected,
+      Output<TString> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap);
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
    */
   public void evaluate(Boolean expected, Operand<TBool> input) {
-    evaluate(new Boolean[] {expected}, input);
+    evaluate(new Boolean[] {expected}, input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   */
+  public void evaluate(
+      Boolean expected,
+      Operand<TBool> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    evaluate(new Boolean[] {expected}, input, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
    */
   public void evaluate(Boolean expected, Op input) {
-    evaluate(new Boolean[] {expected}, input);
+    evaluate(new Boolean[] {expected}, input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   */
+  public void evaluate(
+      Boolean expected, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    evaluate(new Boolean[] {expected}, input, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
    */
   public void evaluate(Boolean[] expected, Op input) {
     Output<TBool> output = input.op().output(0);
-    evaluate(expected, output);
+    evaluate(expected, output, null);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   */
+  public void evaluate(
+      Boolean[] expected,
+      Op input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    Output<TBool> output = input.op().output(0);
+    evaluate(expected, output, feedMap);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
    */
   public void evaluate(Boolean[] expected, Operand<TBool> input) {
     Output<TBool> output = input.asOutput();
-    evaluate(expected, output);
+    evaluate(expected, output, null);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
    */
-  public abstract void evaluate(Boolean[] expected, Output<TBool> input);
+  public void evaluate(
+      Boolean[] expected,
+      Operand<TBool> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    Output<TBool> output = input.asOutput();
+    evaluate(expected, output, feedMap);
+  }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the expected Operand
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
    */
-  public <T extends TType> void evaluate(Operand<T> expected, Op input) {
-    Output<T> output = input.op().output(0);
-    evaluate(expected, output);
+  public void evaluate(Boolean[] expected, Output<TBool> input) {
+    evaluate(expected, input, null);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   */
+  public abstract void evaluate(
+      Boolean[] expected,
+      Output<TBool> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap);
+
+  /**
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
+   * @param input the actual value
    * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   */
+  public <T extends TType> void evaluate(Operand<T> expected, Output<T> input) {
+    evaluate(expected.asOutput(), input, null);
+  }
+
+  /**
+   * Evaluate the expected results versus the actual results
+   *
+   * @param expected the expected value
+   * @param input the actual value
+   * @param <T> the data type of the expected and actual values
    */
   public <T extends TType> void evaluate(Operand<T> expected, Operand<T> input) {
-    evaluate(expected.asOutput(), input.asOutput());
+    evaluate(expected.asOutput(), input.asOutput(), null);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <T> the data type of the expected and actual values
    */
-  public abstract <T extends TType> void evaluate(Output<T> expected, Output<T> input);
+  public abstract <T extends TType> void evaluate(
+      Output<T> expected,
+      Output<T> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap);
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
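+   *
+   * <p>For example, comparing against a {@code FloatNdArray} of expected values (a sketch;
+   * {@code session} and {@code tf} come from the enclosing test, and the
+   * {@code NdArrays.vectorOf} factory is assumed from the ndarray library):
+   *
+   * <pre>{@code
+   * FloatNdArray expected = NdArrays.vectorOf(2f, 4f, 6f);
+   * Operand<TFloat32> actual = tf.math.mul(tf.constant(new float[] {1f, 2f, 3f}), tf.constant(2f));
+   * session.evaluate(expected, actual);
+   * }</pre>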
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param <U> the data type of the input
    */
-  public <T extends TType> void evaluate(FloatNdArray expected, Operand<T> input) {
-    evaluate(expected, input.asOutput());
+  public <U extends TNumber> void evaluate(FloatNdArray expected, Operand<U> input) {
+    evaluate(expected, input.asOutput(), null);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
    * @param expected the expected value
-   * @param input the operand to evaluate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
    */
-  public abstract <T extends TType> void evaluate(FloatNdArray expected, Output<T> input);
+  public <U extends TNumber> void evaluate(
+      FloatNdArray expected,
+      Operand<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    evaluate(expected, input.asOutput(), feedMap);
+  }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
-   * @param input the operand to evaluate
-   * @param predicate the Predicate
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param expected the expected value
+   * @param input the actual value
+   * @param <U> the data type of the input
    */
-  public <T extends TType> void evaluate(Operand<T> input, Predicate<Number> predicate) {
-    evaluate(input.asOutput(), predicate);
+  public <U extends TNumber> void evaluate(FloatNdArray expected, Output<U> input) {
+    evaluate(expected, input, null);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the expected results versus the actual results
    *
-   * @param input the operand to evaluate
-   * @param predicate The Predicate that evaluates the each value from input
-   * @param <T> the data type of the input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param expected the expected value
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
+   */
+  public abstract <U extends TNumber> void evaluate(
+      FloatNdArray expected,
+      Output<U> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap);
+
+  /**
+   * Evaluate the actual results using a predicate
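+   *
+   * <p>For example, asserting that every element of a result is non-negative (a sketch;
+   * {@code session} and {@code tf} come from the enclosing test):
+   *
+   * <pre>{@code
+   * session.evaluate(tf.math.square(tf.constant(new float[] {-1f, 2f})), v -> v.floatValue() >= 0f);
+   * }</pre>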
+   *
+   * @param input the actual value
+   * @param predicate a predicate that accepts a Number as an argument; if the predicate returns
+   *     false, the test fails
+   * @param <U> the data type of the input
    */
-  public abstract <T extends TType> void evaluate(Output<T> input, Predicate<Number> predicate);
+  public <U extends TNumber> void evaluate(Operand<U> input, Predicate<Number> predicate) {
+    evaluate(input.asOutput(), predicate, null);
+  }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluate the actual results using a predicate
    *
-   * @param input the operand to evaluate
-   * @param predicate The Predicate that evaluates the each value from input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   * @param input the actual value
+   * @param predicate a predicate that accepts a Number as an argument; if the predicate returns
+   *     false, the test fails
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <U> the data type of the input
+   */
+  public abstract <U extends TNumber> void evaluate(
+      Output<U> input,
+      Predicate<Number> predicate,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap);
+
+  /**
+   * Evaluate the actual results using a predicate
+   *
+   * @param input the actual value
+   * @param predicate a predicate that accepts a Number as an argument; if the predicate returns
+   *     false, the test fails
    */
   public void evaluate(FloatNdArray input, Predicate<Number> predicate) {
     input.scalars().forEach(f -> assertTrue(predicate.test(f.getFloat())));
   }
 
   /**
-   * Print the input
+   * Print the results to the "standard" output stream.
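+   *
+   * <p>For example (a sketch; {@code session} and {@code tf} come from the enclosing test, and the
+   * {@code Writer} overload can be used to capture the output instead of writing to
+   * {@code System.out}):
+   *
+   * <pre>{@code
+   * session.print(tf.constant(new float[] {1f, 2f, 3f}));
+   *
+   * StringWriter captured = new StringWriter();
+   * session.print(captured, tf.constant(new float[] {1f, 2f, 3f}));
+   * }</pre>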
+   *
+   * @param input the actual value
+   * @param <T> the data type for the input
+   */
+  public <T extends TType> void print(Operand<T> input) {
+    print(System.out, input, null);
+  }
+
+  /**
+   * Print the results to the "standard" output stream.
+   *
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <T> the data type for the input
+   */
+  public <T extends TType> void print(
+      Operand<T> input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    print(new PrintWriter(new OutputStreamWriter(System.out)), input.asOutput(), feedMap);
+  }
+
+  /**
+   * Print the results to the "standard" output stream.
+   *
+   * @param input the actual value
+   */
+  public void print(Op input) {
+    print(System.out, input, null);
+  }
+
+  /**
+   * Print the results to the "standard" output stream.
+   *
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   */
+  public void print(Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    print(new PrintWriter(new OutputStreamWriter(System.out)), input.op().output(0), feedMap);
+  }
+
+  /**
+   * Print the results to the "standard" output stream.
+   *
+   * @param input the actual value
+   * @param <T> the data type for the input
+   */
+  public <T extends TType> void print(Output<T> input) {
+    print(System.out, input, null);
+  }
+
+  /**
+   * Print the results to the "standard" output stream.
+   *
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <T> the data type for the input
+   */
+  public <T extends TType> void print(
+      Output<T> input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    print(new PrintWriter(new OutputStreamWriter(System.out)), input, feedMap);
+  }
+
+  /**
+   * Print the results to the output stream
    *
    * @param out the output stream
-   * @param input the operand to print
-   * @param <T> the data type of the input
+   * @param input the actual value
+   * @param <T> the data type for the input
    */
   public <T extends TType> void print(OutputStream out, Operand<T> input) {
-    print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput());
+    print(out, input, null);
   }
 
   /**
-   * Print the input
+   * Print the results to the output stream
    *
    * @param out the output stream
-   * @param input the op to print
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <T> the data type for the input
+   */
+  public <T extends TType> void print(
+      OutputStream out,
+      Operand<T> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput(), feedMap);
+  }
+
+  /**
+   * Print the results to the output stream
+   *
+   * @param out the output stream
+   * @param input the actual value
    */
   public void print(OutputStream out, Op input) {
-    print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0));
+    print(out, input, null);
   }
 
   /**
-   * Print the input
+   * Print the results to the output stream
    *
    * @param out the output stream
-   * @param input the op to print
-   * @param <T> the data type of the input
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   */
+  public void print(
+      OutputStream out, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0), feedMap);
+  }
+
+  /**
+   * Print the results to the output stream
+   *
+   * @param out the output stream
+   * @param input the actual value
+   * @param <T> the data type for the input
    */
   public <T extends TType> void print(OutputStream out, Output<T> input) {
-    print(new PrintWriter(new OutputStreamWriter(out)), input);
+    print(out, input, null);
   }
 
   /**
-   * Print the input
+   * Print the results to the output stream
    *
-   * @param writer the output writer
-   * @param input the operand to print
-   * @param <T> the data type of the input
+   * @param out the output stream
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <T> the data type for the input
+   */
+  public <T extends TType> void print(
+      OutputStream out,
+      Output<T> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    print(new PrintWriter(new OutputStreamWriter(out)), input, feedMap);
+  }
+
+  /**
+   * Print the results to the character stream
+   *
+   * @param writer the character stream
+   * @param input the actual value
+   * @param <T> the data type for the input
    */
   public <T extends TType> void print(Writer writer, Operand<T> input) {
-    print(new PrintWriter(writer), input.asOutput());
+    print(writer, input, null);
   }
 
   /**
-   * Print the input
+   * Print the results to the character stream
    *
-   * @param writer the output writer
-   * @param input the op to print
+   * @param writer the character stream
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <T> the data type for the input
+   */
+  public <T extends TType> void print(
+      Writer writer,
+      Operand<T> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    print(new PrintWriter(writer), input.asOutput(), feedMap);
+  }
+
+  /**
+   * Print the results to the character stream
+   *
+   * @param writer the character stream
+   * @param input the actual value
    */
   public void print(Writer writer, Op input) {
-    print(new PrintWriter(writer), input.op().output(0));
+    print(writer, input, null);
   }
 
   /**
-   * Print the input
+   * Print the results to the character stream
    *
-   * @param writer the output writer
-   * @param input the op to print
-   * @param <T> the data type of the input
+   * @param writer the character stream
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   */
+  public void print(
+      Writer writer, Op input, Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    print(new PrintWriter(writer), input.op().output(0), feedMap);
+  }
+
+  /**
+   * Print the results to the character stream
+   *
+   * @param writer the character stream
+   * @param input the actual value
+   * @param <T> the data type for the input
    */
   public <T extends TType> void print(Writer writer, Output<T> input) {
-    print(new PrintWriter(writer), input);
+    print(writer, input, null);
+  }
+
+  /**
+   * Print the results to the character stream
+   *
+   * @param writer the character stream
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <T> the data type for the input
+   */
+  public <T extends TType> void print(
+      Writer writer,
+      Output<T> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap) {
+    print(new PrintWriter(writer), input, feedMap);
+  }
+
+  /**
+   * Print the results to the character stream
+   *
+   * @param writer the character stream
+   * @param input the actual value
+   * @param <T> the data type for the input
+   */
+  public <T extends TType> void print(PrintWriter writer, Output<T> input) {
+    print(writer, input, null);
   }
 
   /**
-   * Print the input
+   * Print the results to the character stream
    *
-   * @param writer the output writer
-   * @param input the op to print
+   * @param writer the character stream
+   * @param input the actual value
+   * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required
+   *     for placeholders.
+   * @param <T> the data type for the input
    */
-  public abstract <T extends TType> void print(PrintWriter writer, Output<T> input);
+  public abstract <T extends TType> void print(
+      PrintWriter writer,
+      Output<T> input,
+      Map<Operand<? extends TType>, Tensor<? extends TType>> feedMap);
 
   /**
-   * Get the TensorFlow Ops
+   * Get the TensorFlow Ops for this test session
    *
-   * @return the TensorFlow Ops
+   * @return the TensorFlow Ops for this test session
    */
   public abstract Ops getTF();
 
   /**
-   * Determine if this Test Session represents an Eager Session
+   * Determine whether this session is in Eager mode
    *
-   * @return true, if this Test Session represents an Eager Session
+   * @return true if this session is in Eager mode
    */
   public abstract boolean isEager();
 
   /**
-   * Determine if this Test Session represents a Graph Session
+   * Determine whether this session is in Graph mode
    *
-   * @return true, if this Test Session represents a Graph Session
+   * @return true if this session is in Graph mode
    */
   public boolean isGraph() {
     return !isEager();
   }
 
   /**
-   * Get the epsilon value for evaluating float values
+   * Get the current EPSILON value for floating point number comparison.
    *
-   * @return the epsilon value for evaluating float values
+   * @return the current EPSILON value for floating point number comparison.
    */
   public float getEpsilon() {
     return this.epsilon;
   }
 
   /**
-   * Set the epsilon value for evaluating float values
+   * Set the current EPSILON value for floating point number comparison.
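+   *
+   * <p>For example, loosening the tolerance before comparing a result that carries rounding error
+   * (a sketch; {@code session} and {@code tf} come from the enclosing test):
+   *
+   * <pre>{@code
+   * session.setEpsilon(1e-4f);
+   * session.evaluate(0.3333f, tf.math.div(tf.constant(1f), tf.constant(3f)));
+   * }</pre>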
    *
-   * @param epsilon the epsilon value for evaluating float values
+   * @param epsilon the new EPSILON value for floating point number comparison.
    */
   public void setEpsilon(float epsilon) {
     this.epsilon = epsilon;
   }
 
   /**
-   * Get the TensorFlow session object associated with this Test Session
+   * Get the TensorFlow Session object
    *
-   * @return a TensorFlow session if this is a Graph session, otherwise null
+   * @return the TensorFlow Session object, or null if this is not a Graph Test Session
    */
   public abstract Session getGraphSession();
 
   /**
-   * Get the TensorFlow eager session object associated with this Test Session
+   * Get the TensorFlow EagerSession object
    *
-   * @return a TensorFlow session if this is an eager session, otherwise null
+   * @return the TensorFlow EagerSession object, or null if this is not an Eager Test Session
    */
   public abstract EagerSession getEagerSession();
 
@@ -611,15 +1164,21 @@ public void setEpsilon(float epsilon) {
   @Override
   public abstract void close();
 
-  /** @return the debug setting */
+  /**
+   * Get the debug setting
+   *
+   * @return the debug setting
+   */
   public boolean isDebug() {
     return debug;
   }
 
   /**
-   * Set the debug flag
+   * Sets the debug setting.
+   *
+   * <p>If true, then evaluate methods will also print the Tensor values to System.out.
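+   *
+   * <p>For example (a sketch; {@code session} and {@code tf} come from the enclosing test):
+   *
+   * <pre>{@code
+   * session.setDebug(true);
+   * // this evaluation also prints the computed tensor values to System.out
+   * session.evaluate(4f, tf.math.square(tf.constant(2f)));
+   * session.setDebug(false);
+   * }</pre>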
    *
-   * @param debug the setting for debugging
+   * @param debug the debug setting
    */
   public void setDebug(boolean debug) {
     this.debug = debug;