Initial check-in of Keras Optimizers and helper classes. #91


Merged
merged 34 commits into tensorflow:master on Sep 15, 2020

Conversation

JimClarke5
Contributor

Added Keras Optimizers and test cases.

The general structure is to pass Ops tf in all the constructors. There is an assertGraph() method in org.tensorflow.keras.optimizers.OptimizerInterface class that ensures that the tf.scope().env() represents a Graph, and is not Eager.
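
A minimal sketch of what that guard can look like (illustrative only; the exact helper in the PR may differ in name and details):

import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.Graph;
import org.tensorflow.op.Ops;

final class GraphGuard {  // illustrative helper name, not the class used in the PR
  /** Returns the backing Graph, or throws if the Ops instance is eager. */
  static Graph assertGraph(Ops tf) {
    ExecutionEnvironment env = tf.scope().env();
    if (!(env instanceof Graph)) {
      throw new IllegalArgumentException("Keras Optimizers require Graph mode");
    }
    return (Graph) env;
  }
}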

The following Optimizers provide a facade to the org.tensorflow.framework.optimizers Optimizers.
AdaDelta, AdaGrad, AdaGradDA, Adam, RMSProp, and SGD (Momentum).

The remaining Optimizers were created from scratch, following the pattern of the org.tensorflow.framework.optimizers Optimizers.
Adamax, Ftrl, and Nadam

Internal support methods are found in org.tensorflow.keras.backend. The K class is a rewrite of most of the Python TensorFlow file keras/backend.py. Classes in org.tensorflow.keras.backend.tf should probably be incorporated into other modules of TensorFlow Java; this is currently under consideration. Internal utility classes are in org.tensorflow.keras.utils.

Test cases are in src/main/test/org/tensorflow/keras/optimizers.
Test case support classes are in src/main/test/org/tensorflow/keras/utils.
Here you will find a TestSession abstract base class, and implementations for Graph mode (GraphTestSession) and
Eager mode (EagerTestSession). These TestSession classes help streamline the test case code when evaluating Tensor values.
A factory method on TestSession allows the desired mode to be created based on
the enum TestSession.Mode. Of course, Optimizers are currently constrained to Graph mode.
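
A rough sketch of that factory shape, for readers who have not opened the test sources (the class bodies below are illustrative stand-ins, not the actual test utilities):

public abstract class TestSession implements AutoCloseable {

  /** The execution mode a test can request. */
  public enum Mode { EAGER, GRAPH }

  /** Factory method: returns the session flavour matching the requested mode. */
  public static TestSession createTestSession(Mode mode) {
    return mode == Mode.GRAPH ? new GraphTestSession() : new EagerTestSession();
  }

  // Minimal stand-ins for the real GraphTestSession / EagerTestSession implementations.
  private static final class GraphTestSession extends TestSession {
    @Override public void close() { /* release the Graph and Session here */ }
  }

  private static final class EagerTestSession extends TestSession {
    @Override public void close() { /* release the EagerSession here */ }
  }
}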

Fixed project dependencies in pom.xml to include the other required tensorflow-java modules, as well as
org.json:json and org.apache.commons:commons-csv, which are used by callbacks that will be added later.

Collaborator

@Craigacp Craigacp left a comment

I've gone through everything bar the ProgressBar class, as I'm not sure how java.io.Console works. It looks good; I've got a few comments, and we should probably merge a bunch of helpers & new ops into the appropriate classes in core.

if (tf.scope().env().isGraph()) {
try (Session session = new Session((Graph) tf.scope().env())) {
if (dtype.equals(TInt32.DTYPE)) {
try (Tensor<TInt32> result =
Collaborator

To make this usable, NdArraySequence must be typed. If we decide to go ahead with this PR, it could make things easier but I'll need to double-check before confirming this. Otherwise, we will need one method per type.

tf.expandDims(sampleWeight, tf.constant(-1)),
sampleWeight);
}
}
Collaborator

At some point we will need to document all these methods, but I guess it is fine to do it later when the interface is more stable.

Contributor Author

@JimClarke5 JimClarke5 Aug 13, 2020

I have added JavaDoc to my copy, that I haven't pushed yet.

As for NdArraySequence, could we have a common method on each numeric type that returns java.lang.Number?

Collaborator

I think that would improve the experience at the user level (while passing the burden to the author :) ). Most of the NdArray library has been written with that spirit.

With type erasure though, you will need to provide a distinct name for each variant, i.e. something like:

public NdArraySequence<FloatNdArray> getTensorFloats(Ops tf, Operand<TFloat32> operand)

Collaborator

@karllessard karllessard left a comment

@JimClarke5, this pull request is very long to review and, unfortunately, I don't have the time to go through it in detail these days.

So what I suggest is that you apply (or not) the changes @Craigacp and I have proposed so far, based on your judgement, and we merge it as is, so Keras development will not be blocked. We won't release the Keras lib anytime soon, so that gives us time to revisit it later and refactor it if needed. What do you think? And @Craigacp?

=======================================================================*/
package org.tensorflow.keras.backend.tf;

import org.tensorflow.keras.backend.*;
Collaborator

If this does not belong to Keras in Python, then we should move it either to tensorflow-framework or tensorflow-core-api. The way to decide between the two is whether we consider the operation to be something very close to what could have been a core operation (e.g. sparseSoftmaxCrossEntropyWithLogits) or whether it defines a higher-level concept. Hopefully that can help you decide where this one (and the others you are adding) could fit.

* @return A `Tensor` of the same shape as `labels` and of the same type as `logits` with the
* softmax cross entropy loss.
*/
public static Operand sparse_softmax_cross_entropy_with_logits(
Collaborator

We should add sparseSoftmaxCrossEntropyWithLogits as an operator accessible via the Ops API, probably under tf.sparse.*

Operand precise_logits = logits;
boolean convertToFloat32 =
logits.asOutput().dataType() == TFloat16.DTYPE
|| logits.asOutput().dataType() == TBfloat16.DTYPE;
Collaborator

This might get out of date if we add more float types in the future. We should add a type family called TFloating that is used to tag all our floating-point tensor types (that is what I did in https://github.com/karllessard/tensorflow-java/tree/tensor-as-ndarrays-3).
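
For illustration, a method bound by such a family might read as follows (a sketch only; TFloating is assumed here and exists only in the linked branch, not in this code base):

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TFloating;  // hypothetical here: only exists in the linked branch

final class FloatingOnlyExample {
  /**
   * The bound restricts callers to floating-point tensors at compile time,
   * instead of comparing dataType() against an explicit list of DTYPEs at runtime.
   */
  static <T extends TFloating> Operand<T> square(Ops tf, Operand<T> x) {
    return tf.math.mul(x, x);
  }
}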

Contributor Author

@JimClarke5 JimClarke5 Aug 20, 2020

This specific logic is primarily checking to see if the float type is smaller than TFloat32 and if so, flag it for casting up later. I don't think TFloating would capture this specific logic. I have modified my version of DataType to add isFloating(), isInteger(), isNumeric(), isBoolean(), and isString(). I did this based on Craig's suggestion earlier.

Also, sparseSoftmaxCrossEntropyWithLogits is defined under nn in Python TF ("nn.sparse_softmax_cross_entropy_with_logits"). It has nothing to do with SparseTensor, if that is what you thought.

*
* @param <T> the type of the SparseTensor
*/
public class SparseTensor<T extends TType> {
Collaborator

I really think SparseTensor should be a core concept of the library. I still don't know where and how it will fit exactly. Do you need this one to be checked in immediately, @JimClarke5?

…hod. This allows the NAME to be used elsewhere instead of hardcoding the string.
…coding the string.

added methods isFloating(), isInteger(), isNumeric(), isBoolean() and isString()
…EntropyWitLogits()

Added tf.nn.sparseSoftmaxCrossEntropyWithLogits() and
tf.nn.raw.sparseSoftmaxCrossEntropyWithLogits()

Added tf.nn.sigmoidCrossEntropyWithLogits()
@Craigacp
Collaborator

I agree with Karl that once the changes we've discussed are in we should merge it and revisit later if necessary.

@JimClarke5
Contributor Author

JimClarke5 commented Aug 20, 2020

I have gone through and simplified the PR, to only include classes directly related to Optimizers.

There will be changes to DataType and its implementations, like TFloat32, to change the hardcoded string names to a public static final String NAME attribute for each, per the suggestions.

I also changed the op generator code to generate public static final String OP_NAME = "xxxxxx"; for each op, so that we don't have to use hard-coded strings for op names elsewhere. This required me to update check-ins for all the Ops.
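
Roughly what each generated op class gains under that change (a sketch; the class below is only a placeholder for generated code):

/** Placeholder for a generated op wrapper; the generator emits one such constant per op. */
public final class AddN {
  /** The name of this op, as known by the TensorFlow core engine. */
  public static final String OP_NAME = "AddN";
  // ... generated factory methods, inputs and outputs ...
}

// Call sites can then refer to AddN.OP_NAME instead of repeating the "AddN" string literal.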

I should be pushing these changes later today.

The original PR has raised some issues that we will eventually need to address, but I think it is better to move forward in baby steps rather than try to address all those issues now.

@KartikChugh

It says over 1k files changed so I'm not sure how to review this. Any guidance?

@Craigacp
Collaborator

There are a lot of generated files, as this touched the Java ops generator to make each op have its name as a static final field. So if you want to review it, you'll need to use something other than GitHub and diff the branch vs master in specific folders.

@JimClarke5
Contributor Author

Yes, most of the checked in files are generated files for the TensorFlow Ops to add a static NAME field. If you want to focus on Keras, then look at the org.tensorflow.keras.optimizers package in the tensorflow-keras module under src and test.

Collaborator

@karllessard karllessard left a comment

Hi @JimClarke5 , thanks again for all this work,
I've reviewed everything under tensorflow-core-api and I'll review tensorflow-keras a bit later.

precise_logits = Cast.create(scope, logits, TFloat32.DTYPE);
}
/* cannot use generics on DataType because precis_logits may have been cast. */
DataType dtype = precise_logits.asOutput().dataType();
Collaborator

If the generic parameter cannot be carried, then we should carry the wildcard so IDEs or lint checks will not complain about it: DataType<?> dtype = ...

Contributor Author

I have tried all kinds of combinations: <?>, <? extends TType>, <? extends TNumber>, etc. The main hang-up is that when I call other methods that declare generics such as <T extends TType>, it won't compile. The only way I can get it to compile is to remove the generic.
Do you have any other suggestions?

Collaborator

Interesting. You are right that <?> is a bad pick, we need at least to be bound to TType, so <? extends TType>. Now you are saying that even with this generic, you are not able to call a method that accepts only <T extends TType>? Do you have an example?

Contributor Author

@JimClarke5 JimClarke5 Sep 3, 2020

Here is a code fragment that highlights the issue.

public static <T extends TNumber> Operand<T> foo(Ops tf, Operand<T> logits) {
        Operand<? extends TNumber> preciseLogits = tf.dtypes.cast(logits, TFloat32.DTYPE);
        return flattenOuterDims(tf.scope(), preciseLogits);
}
    
private static <T extends TNumber> Operand<T> flattenOuterDims(Scope scope, Operand<T> logits) {
    return null;
}

I get compile error on the line return flattenOuterDims:

GenericExample.java:[35,32] incompatible types: inferred type does not conform to equality constraint(s)
 inferred: T
 equality constraints(s): T,capture#1 of ? extends org.tensorflow.types.family.TNumber

BTW: I have also run into issues with consistency of <T> and <U>. Usually it is <T extends TType> and <U extends TNumber>, but sometimes I run into methods where <T extends TNumber>. This causes conflicts similar to what I see with the above code. It seems that the compiler is looking for consistency in the generic definitions.

If I change the above to:
Operand<T> preciseLogits = tf.dtypes.cast(logits, TFloat32.DTYPE);

I get an error complaining about mismatched generics on the line doing the cast.

GenericExample.java:[34,50] incompatible types: inferred type does not conform to equality constraint(s)
    inferred: T
    equality constraints(s): T,org.tensorflow.types.TFloat32

Collaborator

@karllessard karllessard Sep 5, 2020

OK, for this example specifically, I can see that it won't work because the type of the value returned by foo is T, while the one returned by flattenOuterDims is <? extends TNumber> (inferred from preciseLogits). The two Ts are unrelated.

It seems that in your code, you are always expecting to return a TFloat32 tensor as the result of foo with the explicit cast. Therefore, you should enforce this in the signature of the method as well, i.e.

public static <T extends TNumber> Operand<TFloat32> foo(Ops tf, Operand<T> logits) {
        Operand<TFloat32> preciseLogits = tf.dtypes.cast(logits, TFloat32.DTYPE);
        return flattenOuterDims(tf.scope(), preciseLogits);
}
    
private static <T extends TNumber> Operand<T> flattenOuterDims(Scope scope, Operand<T> logits) {
    return null;
}

Do you have other examples to share that we can look at? I can imagine though that we could sometimes lose track of the TNumber boundary for some operands, which then causes problems when calling a method that only accepts them. Having a concrete example showing this can help us figure out the best approach to take.

Contributor Author

This was a quick example; I could make this change for these methods, which are under my control, but I will still run into the underlying issue when I call other APIs in tensorflow-core-api. The real code in org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits only does the cast to TFloat32 when it is a 16-bit float type (TBfloat16 or TFloat16). No cast is required if it is already TFloat32, TFloat64, TInt32, or TInt64.

In org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits, if I change to Operand<? extends TNumber> preciseLogits, then I get an error when calling:

org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits smax =
        org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits.create(
            scope, preciseLogits, castLabels);

'create(org.tensorflow.op.Scope, org.tensorflow.Operand<T>, org.tensorflow.Operand<T>)' in 'org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits' cannot be applied to '(org.tensorflow.op.Scope, org.tensorflow.Operand<capture<? extends org.tensorflow.types.family.TNumber>>, org.tensorflow.Operand<capture<? extends org.tensorflow.types.family.TNumber>>)'

Collaborator

@karllessard karllessard Sep 7, 2020

In this case, the problem is that the signature of raw.SoftmaxCrossEntropyWithLogits enforces that both preciseLogits and castLabels should be of the same type. When I look at your current implementation, this information gets lost so the compiler fails (and TF runtime will also fail if this condition is not met).

I see a few things that could help, if you want to give it a try:

  • U must be bound to TNumber as well, since you assert it later with an explicit cast
  • You can get rid of the casting complexity for float types by recursively calling the same method when required, so the method can recapture the effective type:
    if (logits.asOutput().dataType() == TFloat16.DTYPE || logits.asOutput().dataType() == TBfloat16.DTYPE) {
        Operand<TFloat32> cost = softmaxCrossEntropyWithLogits(scope, labels, Cast.create(scope, logits, TFloat32.DTYPE));
        return Cast.create(scope, cost, logits.asOutput().dataType());
    }
  • Make sure all parameterized variables keep track of their type (DataType, Operand, ...), avoiding wildcards as much as possible

Contributor Author

Based on your suggestion, I did the following, using recursion, which does compile. I still need to validate that I don't get a runtime error.

public static <T extends TNumber> Operand<T> softmaxCrossEntropyWithLogits(
      Scope scope, Operand<T> labels, Operand<T> logits, int axis) {

...
if (convertToFloat32) {
      Operand<TFloat32> result =  softmaxCrossEntropyWithLogits(scope,
              Cast.create(scope, labels, TFloat32.DTYPE),
              Cast.create(scope, logits, TFloat32.DTYPE),
              axis);
      return Cast.create(scope, result, logits.asOutput().dataType());
} else if(!logits.asOutput().dataType().equals(labels.asOutput().dataType())) {
      return softmaxCrossEntropyWithLogits(scope,
              Cast.create(scope, labels, logits.asOutput().dataType()),
              logits,
              axis);
}
....

Contributor Author

I did have to cast labels when calling org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits.create

At this point, both args will be the same datatype.

org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits<T> smax =
        org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits.create(
            scope, logits, (Operand<T>)labels);

Contributor Author

I had to change the signature to:

public static <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntropyWithLogits(
      Scope scope, Operand<U> labels, Operand<T> logits, int axis) {

This is because labels and logits may be different types initially.

Cast.create(scope, Rank.create(scope, precise_logits), TInt64.DTYPE);
Shape shape = logits.asOutput().shape();

// Move the dim to the end if dim is not the last dimension.
Collaborator

".... if axis is not the last dimension"?

Contributor Author

Actually, it is dim_index, which I changed to dimIndex.
// Move the dim to the end if dimIndex is not the last dimension.

SoftmaxCrossEntropyWithLogits<T> smax =
SoftmaxCrossEntropyWithLogits.create(scope, precise_logits, labels);
/* cannot use generic on cost, because cost may be recast later. */
Operand cost = smax.loss();
Collaborator

Idem: Operand<?>

Contributor Author

It's the same issue with generics as mentioned earlier in this method.

}
if (productValid) {
org.tensorflow.ndarray.Shape outputShape = Shape.of(product, shape.size(ndims - 1));
return Reshape.create(scope, logits, Constant.vectorOf(scope, outputShape.asArray()));
Collaborator

just Constant.tensorOf(scope, shape.size(ndims - 1)) will do the job here

Contributor Author

@JimClarke5 JimClarke5 Sep 2, 2020

I don't understand your comment.

This method preserves the last dimension of shape, but takes the product of all the other dimensions. The resulting shape is a long[] of {product, lastSize}.
For example, shape [2,3,4] becomes [6,4].
This is a more concise way to do it though, getting rid of the intermediate shape:

return Reshape.create(scope, logits,
                        Constant.vectorOf(scope, new long[] { product, shape.size(-1)}));

Contributor Author

I just changed it to use Constant.arrayOf()

return Reshape.create(
            scope, logits, Constant.arrayOf(scope, product, shape.size(-1)));

Collaborator

Sorry, my comment was wrong as I ignored the product value. What you did by using Constant.arrayOf is basically what I meant, i.e. there is no need to convert an array to a Shape if we are just passing this array back to Constant.

Operand<TInt64> rank = Cast.create(scope, Rank.create(scope, logits), TInt64.DTYPE);
Operand<TInt64> rankMinusOne = Sub.create(scope, rank, one);

Operand<TInt64> last_dim_size =
Collaborator

camelCase variable names

Contributor Author

OK

Operand<TInt64> concat =
Concat.create(
scope,
Arrays.asList(Constant.vectorOf(scope, new long[] {-1}), last_dim_size),
Collaborator

...Constant.arrayOf(scope, -1L)...

Contributor Author

OK

* @return the reshaped input
*/
private static <T extends TNumber, U extends TNumber> Operand<T> moveDimToEnd(
Scope scope, Operand<T> input, int dim_index, Operand<U> rank) {
Collaborator

dimIndex

Contributor Author

OK

super(layout.applyTo(buffer), shape);
tensorBuffer = buffer;
}
}

Collaborator

I get a feeling that the auto-formatter was a bit aggressive here; can you please just double-check that the Google one was used when reformatting all the types classes?

Contributor Author

I have verified that the google-java-format settings are installed and enabled in IntelliJ.

Contributor Author

There was a format change on TString.java and TBool.java when I double checked all the formats on the types.

Collaborator

OK, it's fine, I just wanted to make sure that the right formatter was applied.

…x JavaDoc. Change from snake case to camel case.
case 1:
return Optional.of(initializers.get(0));
default:
return Optional.of( tf.withSubScope(name).withControlDependencies(initializers).noOp());
Collaborator

extra space before tf

Contributor Author

OK

String name = (String) config.get(NAME_KEY);
float learningRate = (float) config.getOrDefault(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT);
float rho = (float) config.getOrDefault(RHO_RATE_KEY, RHO_DEFAULT);
float epsilon = (float) config.getOrDefault(EPSILON_KEY, EPSILON_DEFAULT);
Collaborator

Can we create a concrete type for the config instead of having a map of untyped values? E.g. could it be something like an AdaDelta.Config subclass following a builder pattern similar to the ops Options classes like this one? I probably don't understand the use of this config format though.

The same comment applies for all other optimizers
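
For illustration, a typed config in that style could look roughly like this (a sketch only; the nested class, field names, and defaults are illustrative and not the shape that was eventually adopted):

// Sketch of a builder-style config, meant to be nested inside an optimizer class such as AdaDelta.
public static final class Config {
  private float learningRate = 0.001f;  // illustrative defaults
  private float rho = 0.95f;
  private float epsilon = 1e-7f;

  public Config learningRate(float learningRate) {
    this.learningRate = learningRate;
    return this;
  }

  public Config rho(float rho) {
    this.rho = rho;
    return this;
  }

  public Config epsilon(float epsilon) {
    this.epsilon = epsilon;
    return this;
  }

  public float learningRate() { return learningRate; }
  public float rho() { return rho; }
  public float epsilon() { return epsilon; }
}

// Hypothetical usage (no such constructor exists in the PR):
// new AdaDelta(tf, new AdaDelta.Config().learningRate(0.01f).rho(0.9f));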

Contributor Author

@JimClarke5 JimClarke5 Sep 3, 2020

This was used in Python for "serialization" of the Optimizers. The Config could be saved, then used to restore the state of the Optimizer. The key is always a String, but the value can be a String, a Float, or an Integer (but maybe also a Double or Boolean), so Map<String, Object> should suffice. Since I did this, I have not yet seen where this dictionary is actually used, so I assume the Python code expected some user software or tool to use it. I am not sure how this plays into saving/restoring the model, as I haven't looked at that logic yet.

We could pull it out for now, then revisit it when we start tackling saving/restoring the Keras model.

Or I can change it to the options pattern, as long as there is a way to read/write it, perhaps as a JSON string.

Collaborator

I vote for removing the serialization/deserialization support for now and add it later.

Contributor Author

OK

Collaborator

When we return to this we could make the config objects implement Map<String,Object> but specialise the types per optimiser and provide specific getters & setters?

Contributor Author

I have started to write the "options" pattern with internal classes on the Optimizers. Do we want to do this? Some of the optimizers' ctors are trivial, but others may have a more significant number of options with Keras-specified defaults.

Collaborator

I really think we should postpone this feature for now as there is no urgent need to support serialization/deserialization of the configurations, we can then brainstorm on the right approach to take.

Contributor Author

I have removed it for now and we can revisit later when we look at serializing the Keras Model

float rho = (float) config.getOrDefault(RHO_RATE_KEY, RHO_DEFAULT);
float epsilon = (float) config.getOrDefault(EPSILON_KEY, EPSILON_DEFAULT);
if (name == null) // doe this to get the default name
{
Collaborator

unrequired new line before bracket

Contributor Author

OK

this.learningRate = learningRate;
}


Collaborator

extra lines

Contributor Author

OK

public static final String INITIAL_ACCUM_KEY = "accumulator";

public static final float LEARNING_RATE_DEFAULT = 0.001F;
public static final float INITIAL_ACCUM__DEFAULT = 0.1f;
Collaborator

double underscores.

Contributor Author

OK

import org.tensorflow.types.family.TType;

/** Adamax Optimizer that implements the Adamax algorithm. */
public class Adamax extends org.tensorflow.framework.optimizers.Optimizer
Collaborator

Can we create an Adamax implementation in the framework first and then have its Keras wrapper extend it, like the other optimizers? The same applies to all other Keras optimizers that are not present in the framework.

Contributor Author

Yes.

(Operand) tf.dtypes.cast(betaOneConst, gradient.dataType()),
(Operand) tf.dtypes.cast(betaTwoConst, gradient.dataType()),
(Operand) tf.dtypes.cast(epsilonConst, gradient.dataType()),
(Operand) gradient);
Collaborator

Again, are you sure we need these explicit casts to Operand? If so, something is wrong with our generic signatures and we should address that.

Contributor Author

I removed the (Operand) casts. It must have been when I was just figuring out the framework.

import org.tensorflow.op.Ops;

/** The main Interface for Keras Optimizers */
public interface OptimizerInterface {
Collaborator

If possible, I would prefer we avoid suffixing this class name with Interface, since none of our other interfaces have this suffix.

Contributor Author

@JimClarke5 JimClarke5 Sep 3, 2020

I was trying to avoid a name clash with org.tensorflow.framework.optimizers.Optimizer. Some of OptimizerInterface could be refactored into org.tensorflow.framework.optimizers.Optimizer. At the time, I was trying to keep the two projects, keras and frameworks, totally separate.

Collaborator

We could call this KerasOptimizer? The fact that Python overloads all the names doesn't mean we need to. Or rename org.tensorflow.framework.optimizers.Optimizer to BaseOptimizer or FrameworkOptimizer?

Contributor Author

Actually, I have done away with the KerasOptimizer interface for now and put assertGraph as a static call in a helper class. However, this helper class will probably be useful in some of the other Keras packages that use Variables, so maybe it should be moved to a general util package that is only visible to the java module, wherever we decide to put it.

One question that remains for the Optimizer interface is where to put the method prototypes for get/setLearningRate when that feature comes in the next PR. The Keras code will need to treat these methods as belonging to a general Optimizer, rather than to individual types of Optimizer. There are Optimizers that I found on the Web that don't use a learning rate, but all of the currently defined Optimizers in the framework do use it.
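
One possible shape for those prototypes, sketched only to make the question concrete (the interface name and method signatures below are hypothetical):

/** Hypothetical sketch: a common home for the learning-rate accessors discussed above. */
public interface LearningRateOptimizer {

  /** Returns the learning rate currently in effect. */
  float getLearningRate();

  /** Changes the learning rate used by subsequent gradient-apply steps. */
  void setLearningRate(float learningRate);
}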

* @throws java.lang.IllegalArgumentException if the TensorFlow Ops does not represent Graph mode
*/
static Graph assertGraph(Ops tf) {
if(!tf.scope().env().isGraph()) {
Collaborator

missing space after if

Contributor Author

OK, on the formatting,

If I refactor OptimizerInterface into org.tensorflow.framework.optimizers.Optimizer, then where should the assertGraph() method go? It is basically a utility that can be called from a super() ctor, ensuring that Ops tf always represents a Graph, and then passing the Graph to the framework optimizer ctors.

Collaborator

I'm fine with either or.

I myself have some trouble figuring out how the framework and Keras layers should split the job, but at a high level, what I'm thinking is that most of the logic should occur in the framework and the core libraries, while the Keras library should add a Pythonic, Keras-like API as a facade to them. But I'm sure there will be exceptions. So in this case, if you don't think Keras users will need to directly access the graph of an optimizer, and no such getter exists in the Python Keras lib, then maybe we should move it to the framework.

Contributor Author

The main issue is consistency. The rest of the Keras APIs use Ops instead of Graph directly. Using Ops has an advantage, from my point of view, in that it insulates us from a tight binding to Graph when one wants Eager support.

Contributor Author

Your comment also brings up another consideration. If you take the premise that "Keras library should add a Pythonic Keras-like API as a facade to them", then why wouldn't activations, loss, metrics, initializers, etc. be primarily in frameworks rather than solely in Keras?

Collaborator

I think that we should have stronger typing information than exists in Python, as it's what would be expected from idiomatic Java and it helps IDEs & discoverability

I totally support that as well; we should avoid having users pass arbitrary string labels as parameters in the Java implementation. We could, though, have some sort of enum per component (optimizers, metrics, ...) which would restrict the possible choices without the user needing to explicitly instantiate that component, to get closer to what the Python implementation offers, e.g. model.compile(Optimizers.ADAM, Losses.BINARY_CROSSENTROPY).

Allows developers to selectively be more explicit by overriding defaults or dipping into the framework API

This is where I'm hesitating. If we allow users to access lower-level functionality from the Keras API, why do we need another API then? The current Python API is made up of two layers because it is historically the product of a merge between two different projects: the original TF API and the Keras project. I personally think it brings more confusion to users than benefits, and we don't need to follow this scheme if we think we can do better in Java, since we are starting from scratch.

I'm slowly leaning now towards the idea of having a single API that supports both "beginner" and "advanced" modes, whether we call it Keras or not. At the same time, I don't want to slow @JimClarke5 down any further with this PR. Maybe let's continue that discussion with a broader audience? We can raise a new issue in GitHub, start another discussion on the mailing list, or do this in the scope of a PR created by @deansher as he suggested (those descriptions sound good). What do you think?

Collaborator

@Craigacp Craigacp Sep 9, 2020

I think a PR to add that to the readme sounds good, and after that's merged we can discuss it in an issue. We definitely shouldn't hold up this optimizer PR any longer.

My worry about forcing everything through something that's like Keras is that I train models that won't fit into a keras style model.fit call. One of the things I'm currently looking at is USE, which is trained on multiple tasks, with each task using a different combination of dataset and loss function (e.g. train a question answering loss on one dataset, a textual similarity loss on another, and an MLM loss on a third, where apart from the top layer all layers are shared). However this isn't something that is well supported by any framework I've come across (unfortunately Google didn't release any training details or code on USE), and requires a lot of complexity that most users won't need. I can just about do it in bare TF or pytorch (in python, not tried using TF Java yet as we'd need TF text to work), but anything higher level requires fighting the framework at each step because I need control of the training loop.

Contributor Author

To move forward, may I suggest the following:

  1. Add the 3 new Optimizers, Adamax, Ftrl, and Nadam, to framework.
  2. Move the test cases to framework.
  3. Hold off on the remaining classes in org.tensorflow.keras.optimizers, which are just a facade on framework, until we decide what to do with Keras.

The next PR that I will work on will allow for a changing LearningRate and only impacts Optimizers in framework.
After that, I propose adding the other Keras elements, Initializers, Activations, etc., to framework.

Contributor Author

@JimClarke5 JimClarke5 Sep 9, 2020

One other issue: Keras defines the Optimizer SGD, which is actually just Momentum in framework. Should we just ignore this for the time being?

Collaborator

  • Add the 3 new Optimizers, Adamax, Ftrl, and Nadam, to framework.

  • Move the test cases to framework.

  • Hold off on the remaining classes in org.tensorflow.keras.optimizers, which are just a facade on framework, until we decide what to do with Keras.

Sounds good @JimClarke5; in both approaches we are currently looking at, none of the optimizer backends will end up in the Keras module. Let's continue this discussion on #109, as it is soon becoming a priority that we all agree on this.

SGD can be added later, yes.

/** {@inheritDoc} */
@Override
public Map<String, Object> getConfig() {
return config;
Collaborator

If we have a concrete type for our config, then either our Keras optimizer classes should be parameterized to return the right type of config, or it should be returned as an Object itself. I can't see in this PR what the use of getConfig() is, so I can't give a clear answer.

Contributor Author

See my previous comment on Config.

Contributor Author

I have removed config from this based on deferring the question to when we address serializing the Keras model.

…Java files for some Ops. This also resulted in new generated source that are also committed.
…gmoidCrossEntropyWithLogits.java, and SparseSoftmaxCrossEntropyWithLogits.java under package org.tensorflow.op.nn in
Contributor

@deansher deansher left a comment

Nits in some tests.

config.put(BETA_ONE_KEY, BETA_ONE_DEFAULT);
config.put(BETA_TWO_KEY, BETA_TWO_DEFAULT);
config.put(EPSILON_KEY, EPSILON_DEFAULT);
AdaDelta expResult = new AdaDelta(tf);
Contributor

AdaDelta -> Adam

Contributor Author

This test is being removed as the config has been removed until we figure out the best way to serialize a Keras model.

session.evaluate(var0_init, var0);
session.evaluate(var1_init, var1);

FloatNdArray m0_np = NdArrays.ofFloats(shape1);
Contributor

shape1 -> shape0 on this line and the next.

Contributor Author

OK


m0_np = calculateM(m0_np, grads0_np, beta1);
v0_np = calculateV(v0_np, grads0_np, beta2);
var0_np = calculateParam(var0_np, lr_t, m0_np, v0_np, 1e-7F);
Contributor

1e-7F -> epsilon here and a few lines down for var1_np.

The fact that the test doesn't notice this difference suggests using a substantially larger epsilon, but consistency with the Python test may be more important at the moment.

Collaborator

Also @JimClarke5, please don't forget to camel-case all these variables as well, or once we reactivate lint checks there will be a lot of failures.

Contributor Author

This test was borrowed from the Python test, and the calculation of the expected values is not as precise as the TF calculation. Maybe there is a way to make the expected calculation more precise.

Contributor Author

OK on camel-case

}

@Test
public void testBasic() {
Contributor

(I wouldn't propose holding up this PR for this issue.)

This test raises interesting questions about both our API layers and our testing strategy. It purports to be a test of keras.optimizers.Adam, but in fact it tests the actual underlying Adam logic that is exposed by framework.optimizers.Adam -- which is not, however, yet tested in framework.optimizers.

If we stick with the current layering strategy, I'd propose that in a future PR we move this test into framework.optimizers.

Contributor Author

As of now, I would vote that these tests be moved to framework, especially since we are moving an additional 3 optimizers to framework that are defined in Keras. Also, the JavaDoc on the existing Optimizers needs a lot of work.

Contributor

Given that I have limited time availability and am a decent technical writer, that's a likely area for me to contribute.

@JimClarke5
Contributor Author

I just pushed the following changes:

  1. Added 3 Optimizers to framework.
  2. Moved the Optimizer unit tests from Keras module to framework.
  3. Added Optimizers.java, which contains an enum to quickly identify a default Optimizer (see the sketch after this list). The intent is that the user could pass in the enum, and some other framework can use that to instantiate a default Optimizer. For example, Optimizers.RMSPROP would indicate that the default RMSProp optimizer is desired. These enums replace String identifiers that were used in Python TF. I just picked Optimizers as the enum class name; I am open to other suggestions. The benefit of this approach is that the actual Optimizer instantiation can be deferred until a Model, for example, is set up and ready to use it.
  4. I deleted the Keras files under this PR, as it is easy enough to add them back in later, if we decide to do the facade pattern for Keras.
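
As referenced in item 3, a sketch of the kind of selector enum described there (the enum constants, method, and mappings are illustrative; the actual Optimizers.java in the PR may differ):

import org.tensorflow.Graph;
import org.tensorflow.framework.optimizers.Adam;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.framework.optimizers.RMSProp;

/** Sketch only: a selector enum that defers optimizer creation until a Graph is available. */
public enum Optimizers {
  ADADELTA, ADAGRAD, ADAM, ADAMAX, FTRL, NADAM, RMSPROP, MOMENTUM, GRADIENT_DESCENT;

  /** Creates the default-configured optimizer for this selector. */
  public Optimizer createDefault(Graph graph) {
    switch (this) {
      case ADAM:
        return new Adam(graph);
      case RMSPROP:
        return new RMSProp(graph);
      // ... the remaining selectors would map to their framework classes ...
      default:
        throw new UnsupportedOperationException("No default mapping for " + name());
    }
  }
}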

Collaborator

@karllessard karllessard left a comment

Thanks @JimClarke5 , I think this one is ready! Anyway, like we've discussed, if there are little adjustments to do, we'll do them as we go.

@karllessard karllessard merged commit 2843138 into tensorflow:master Sep 15, 2020