Add Activations #123
Status: Merged

Changes from all commits (54 commits)
Commits (all by JimClarke5):

- ef0ce67  Initial checkin of Keras Optimzers and helper classes.
- 9c113a7  Added static final NAME to replace hardcoded String in the create met…
- 824d487  Changed of method to use the DataType NAME attribute rather than hard…
- 07a83a5  Added method WriteFieldWithInitializer to output a "final static Stri…
- 3d26831  Added tf.nn.softmaxCrossEntropyWitLogits() and tf.nn.raw.softmaxCross…
- 11cda5f  Moved SoftmaxCrossEntropyWithLogits and SparseSoftmaxCrossEntropyWit…
- 9c7dfaa  Generated classes now have public static final String OP_NAME = "XXXX…
- 84f49db  Generated classes now have public static final String OP_NAME = "XXXX…
- 208b84a  fix dependencies for other Tensorflow Java modules
- 3913161  formatting fix
- b5a7c0f  Fix ctors with name to properly pass the name to the the super ctor.
- fcba0a5  change asserts to IllegalArgumentException
- 960cfc3  change asserts to IllegalArgumentException
- d37298a  Moved back to tests
- c68812c  Moved SoftmaxCrossEntropyWithLogits.java and SparseSoftmaxCrossEntrop…
- 6b8eb26  Deleted files that are not necessary yet
- 6515c24  Added nn.raw group for softmaxCrossEntropyWithLogits() and sparseSoft…
- 76d0fe5  Added nn.raw group for softmaxCrossEntropyWithLogits() and sparseSoft…
- d2201df  Merge branch 'master' into master
- ab379d1  Refactor NN into individual operations under org.tensorflow.op.nn. Fi…
- 889d67e  Refactor NN into individual operations under org.tensorflow.op.nn. Fi…
- 515b799  Reformatted code
- 5a9fe37  Added sub scope
- 8d21dd7  Miscellaneous fixes based on review comments.
- 4c3cc78  Fixed op_generator.cc to remove a spurious new line in the generated …
- 44f530f  Changed back to non-generic Operand until we resolve how to handle ge…
- b8d3ac2  Regenerated due to creation of SoftmaxCrossEntropyWithLogits.java, S…
- c32fc5b  change snake case to camel case. format code
- 171cd2f  clean upd warning, format code
- e9c3134  Added Adamax, Ftrl, and Nadam Optimizers. Added Optimizers enum for e…
- 5c30a72  Removed optimize classes from tensorflow-keras, moved optimizer test …
- ebefc2e  Fixed generics
- 7915e63  Fixed from Unit test results
- ec4f679  added @SuppressWarnings("unchecked") on Variable array
- c86d09b  Merge pull request #1 from tensorflow/master
- 1a670ec  Added Support for evaluating TFloat16
- 0cc9b9c  Add Activations
- ca77a0b  Remove no-arg CTORs
- 73091be  Fix Unit Tests to include positive and negative numbers on input.
- 946d1d5  Modify JavaDoc indicating Linear activation is also known as Identity…
- 7c5cc4a  Changed DEFAULT values from private to public
- e32fe44  Fixed last sum to be over 'e' instead of 'input'
- 0130914  Added tests for various parameter constructs.
- c7d0477  added tests for 1D and 3D input
- de0e610  Change snake case to camel case
- 63c1f00  JavaDoc fixes
- 2302cc5  Add TFloating family
- 4c44c62  Add JavaDoc
- ef29af9  Changed to TFloating where appropriate.
- 7519436  Remove the test of int arguments for those classes changed to TFloati…
- 27c1126  Remove the test of int arguments for those classes changed to TFloati…
- b83f94f  Make LeakyRelu visible so that it is included in tf.nn.
- c59e905  Remove TNumber import
- ebbcc4f  Add tf.nn.leakyRelu operation
tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LeakyRelu.pbtxt (1 addition, 0 deletions)

```diff
@@ -1,5 +1,6 @@
 op {
   graph_op_name: "LeakyRelu"
+  visibility: VISIBLE
   endpoint {
     name: "nn.LeakyRelu"
   }
```
tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TFloating.java (19 additions, 0 deletions)

```java
package org.tensorflow.types.family;

/**
 * Marker interface for floating point tensor types.
 *
 * <p>Operations that accept only floating point values for some of their operands enforce that the
 * tensor types of those operands are bound to this interface. For example:
 *
 * <pre>{@code
 * TFloat32 tensor1 = TFloat32.vectorOf(1, 2, 3);
 * TBool tensor2 = TBool.vectorOf(true, false, true);
 *
 * Ops tf = Ops.create();
 * Exponential<TFloat32> exp = new Exponential<>(tf);
 * exp.call(tf.constant(tensor1)); // OK
 * exp.call(tf.constant(tensor2)); // Compilation failure
 * }</pre>
 */
public interface TFloating extends TNumber {}
```
tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java (68 additions, 0 deletions)

```java
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.activations;

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/**
 * Abstract base class for Activations.
 *
 * <p><b>Note:</b> The {@link #tf} attribute must be set prior to invoking the call method. See
 * {@link #setTF(Ops)} and the constructor {@link #Activation(Ops)}.
 *
 * @param <T> the data type of the activation
 */
public abstract class Activation<T extends TNumber> {

  /** The TensorFlow Ops */
  protected Ops tf;

  /**
   * Creates the abstract class for an Activation
   *
   * @param tf the TensorFlow Ops
   */
  protected Activation(Ops tf) {
    this.tf = tf;
  }

  /**
   * Sets the TensorFlow Ops
   *
   * @param tf the TensorFlow Ops
   */
  protected void setTF(Ops tf) {
    this.tf = tf;
  }

  /**
   * Gets the TensorFlow Ops
   *
   * @return the TensorFlow Ops
   */
  protected Ops getTF() {
    return this.tf;
  }

  /**
   * Gets the calculation operation for the activation.
   *
   * @param input the input tensor
   * @return The operand for the activation
   */
  public abstract Operand<T> call(Operand<T> input);
}
```
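The pattern above (an abstract base class exposing a single `call` method, with each subclass supplying the math) can be illustrated with a plain-Java analogy that has no TensorFlow dependency. The class names here are hypothetical, invented for this sketch; they are not part of the TensorFlow Java API:

```java
// Plain-Java sketch of the Activation pattern: an abstract base class
// declares call(...), and a subclass implements the element-wise math.
abstract class SimpleActivation {
  /** Applies the activation element-wise to the input array. */
  abstract double[] call(double[] input);
}

class SimpleRelu extends SimpleActivation {
  @Override
  double[] call(double[] input) {
    double[] out = new double[input.length];
    for (int i = 0; i < input.length; i++) {
      out[i] = Math.max(0.0, input[i]); // relu(x) = max(0, x)
    }
    return out;
  }
}

public class ActivationSketch {
  public static void main(String[] args) {
    // Polymorphic dispatch through the abstract type, as with Activation<T>.
    SimpleActivation activation = new SimpleRelu();
    double[] result = activation.call(new double[] {-1.0, 0.0, 2.5});
    for (double v : result) {
      System.out.println(v); // 0.0, 0.0, 2.5
    }
  }
}
```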
tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java (98 additions, 0 deletions)

```java
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.activations;

import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TBool;
import org.tensorflow.types.family.TFloating;

/**
 * Exponential linear unit.
 *
 * <p>The exponential linear unit (ELU) with <code>alpha &gt; 0</code> is:
 *
 * <p><code>x</code> if <code>x &gt; 0</code> and <code>alpha * (exp(x) -
 * 1)</code> if <code>x &lt; 0</code>.
 *
 * <p>The ELU hyperparameter <code>alpha</code> controls the value to which an ELU saturates for
 * negative net inputs. ELUs diminish the vanishing gradient effect.
 *
 * <p>ELUs have negative values which pushes the mean of the activations closer to zero. Mean
 * activations that are closer to zero enable faster learning as they bring the gradient closer to
 * the natural gradient. ELUs saturate to a negative value when the argument gets smaller.
 * Saturation means a small derivative which decreases the variation and the information that is
 * propagated to the next layer.
 *
 * <p>Example Usage:
 *
 * <pre>
 * Operand&lt;TFloat32&gt; input = ...;
 * ELU&lt;TFloat32&gt; elu = new ELU&lt;&gt;(tf, 2.0f);
 * Operand&lt;TFloat32&gt; result = elu.call(input);
 * </pre>
 *
 * @param <T> the data type of the activation
 * @see <a href="https://arxiv.org/abs/1511.07289">Clevert et al, 2016, Fast and Accurate Deep
 *     Network Learning by Exponential Linear Units (ELUs)</a>
 */
public class ELU<T extends TFloating> extends Activation<T> {

  private static final double ALPHA_DEFAULT = 1.0;

  /** A scalar, slope of negative section. */
  private final double alpha;

  /**
   * Creates a new ELU with alpha={@link #ALPHA_DEFAULT}.
   *
   * @param tf the TensorFlow Ops
   */
  public ELU(Ops tf) {
    this(tf, ALPHA_DEFAULT);
  }

  /**
   * Creates a new ELU
   *
   * @param tf the TensorFlow Ops
   * @param alpha A scalar, slope of negative section. It controls the value to which an ELU
   *     saturates for negative net inputs.
   */
  public ELU(Ops tf, double alpha) {
    super(tf);
    this.alpha = alpha;
  }

  /**
   * Gets the calculation operation for the activation.
   *
   * @param input the input tensor
   * @return The operand for the activation
   */
  @Override
  public Operand<T> call(Operand<T> input) {
    Operand<T> result = tf.nn.elu(input);
    if (alpha == 1.0) return result;
    else {
      DataType<T> dataType = input.asOutput().dataType();
      Operand<T> y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), dataType));
      Operand<TBool> cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), dataType));
      return tf.select(cond, result, y);
    }
  }
}
```
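The branch in `call()` mirrors the scalar definition of ELU given in the javadoc. A minimal standalone sketch of that arithmetic, in plain Java with no TensorFlow dependency (the class and method names here are hypothetical, chosen for this sketch only):

```java
public class EluSketch {
  // Scalar version of the ELU definition implemented by ELU.call() above:
  //   x                      if x > 0
  //   alpha * (exp(x) - 1)   if x <= 0
  static double eluScalar(double x, double alpha) {
    return x > 0 ? x : alpha * (Math.exp(x) - 1.0);
  }

  public static void main(String[] args) {
    System.out.println(eluScalar(2.0, 1.0));  // positive inputs pass through: 2.0
    System.out.println(eluScalar(-1.0, 2.0)); // 2 * (exp(-1) - 1), about -1.2642
  }
}
```

Note that the tensor implementation reuses `tf.nn.elu` (which already computes `exp(x) - 1` on the negative side) and only rescales the negative branch by `alpha` via `tf.select`, which is why the `alpha == 1.0` case can return early.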
tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java (57 additions, 0 deletions)

```java
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.activations;

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TFloating;

/**
 * Exponential activation function.
 *
 * <p>For example:
 *
 * <pre>
 * Operand&lt;TFloat32&gt; input = tf.constant(
 *     new float[] {-3.0f, -1.0f, 0.0f, 1.0f, 3.0f});
 * Exponential&lt;TFloat32&gt; exp = new Exponential&lt;&gt;(tf);
 * Operand&lt;TFloat32&gt; result = exp.call(input);
 * // result is [0.04978707f, 0.36787945f, 1.f, 2.7182817f, 20.085537f]
 * </pre>
 *
 * @param <T> the data type of the activation
 */
public class Exponential<T extends TFloating> extends Activation<T> {

  /**
   * Creates an Exponential activation.
   *
   * @param tf the TensorFlow Ops
   */
  public Exponential(Ops tf) {
    super(tf);
  }

  /**
   * Calculates the Exponential activation.
   *
   * @param input the input tensor
   * @return an Operand for the exponential activation: <code>exp(x)</code>.
   */
  @Override
  public Operand<T> call(Operand<T> input) {
    return tf.math.exp(input);
  }
}
```
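The expected values in the javadoc example can be cross-checked with plain `Math.exp`, since the activation is a direct element-wise exponential. A small standalone sketch (the class name is hypothetical, used only here):

```java
public class ExponentialSketch {
  public static void main(String[] args) {
    // Same inputs as the Exponential javadoc example, evaluated in double precision.
    double[] input = {-3.0, -1.0, 0.0, 1.0, 3.0};
    for (double x : input) {
      // Matches the documented float results up to precision:
      // 0.049787..., 0.367879..., 1.0, 2.718281..., 20.085536...
      System.out.println(Math.exp(x));
    }
  }
}
```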