
Add Activations #123

Merged: 54 commits, merged on Oct 24, 2020

Changes from all commits (54 commits)
ef0ce67
Initial checkin of Keras Optimizers and helper classes.
JimClarke5 Jul 28, 2020
9c113a7
Added static final NAME to replace hardcoded String in the create met…
JimClarke5 Aug 20, 2020
824d487
Changed of method to use the DataType NAME attribute rather than hard…
JimClarke5 Aug 20, 2020
07a83a5
Added method WriteFieldWithInitializer to output a "final static Stri…
JimClarke5 Aug 20, 2020
3d26831
Added tf.nn.softmaxCrossEntropyWithLogits() and tf.nn.raw.softmaxCross…
JimClarke5 Aug 20, 2020
11cda5f
Moved SoftmaxCrossEntropyWithLogits and SparseSoftmaxCrossEntropyWit…
JimClarke5 Aug 20, 2020
9c7dfaa
Generated classes now have public static final String OP_NAME = "XXXX…
JimClarke5 Aug 20, 2020
84f49db
Generated classes now have public static final String OP_NAME = "XXXX…
JimClarke5 Aug 20, 2020
208b84a
fix dependencies for other Tensorflow Java modules
JimClarke5 Aug 20, 2020
3913161
formatting fix
JimClarke5 Aug 20, 2020
b5a7c0f
Fix ctors with name to properly pass the name to the super ctor.
JimClarke5 Aug 20, 2020
fcba0a5
change asserts to IllegalArgumentException
JimClarke5 Aug 20, 2020
960cfc3
change asserts to IllegalArgumentException
JimClarke5 Aug 20, 2020
d37298a
Moved back to tests
JimClarke5 Aug 20, 2020
c68812c
Moved SoftmaxCrossEntropyWithLogits.java and SparseSoftmaxCrossEntrop…
JimClarke5 Aug 20, 2020
6b8eb26
Deleted files that are not necessary yet
JimClarke5 Aug 20, 2020
6515c24
Added nn.raw group for softmaxCrossEntropyWithLogits() and sparseSoft…
JimClarke5 Aug 20, 2020
76d0fe5
Added nn.raw group for softmaxCrossEntropyWithLogits() and sparseSoft…
JimClarke5 Aug 20, 2020
d2201df
Merge branch 'master' into master
JimClarke5 Aug 20, 2020
ab379d1
Refactor NN into individual operations under org.tensorflow.op.nn. Fi…
JimClarke5 Sep 3, 2020
889d67e
Refactor NN into individual operations under org.tensorflow.op.nn. Fi…
JimClarke5 Sep 3, 2020
515b799
Reformatted code
JimClarke5 Sep 3, 2020
5a9fe37
Added sub scope
JimClarke5 Sep 3, 2020
8d21dd7
Miscellaneous fixes based on review comments.
JimClarke5 Sep 3, 2020
4c3cc78
Fixed op_generator.cc to remove a spurious new line in the generated …
JimClarke5 Sep 3, 2020
44f530f
Changed back to non-generic Operand until we resolve how to handle ge…
JimClarke5 Sep 3, 2020
b8d3ac2
Regenerated due to creation of SoftmaxCrossEntropyWithLogits.java, S…
JimClarke5 Sep 3, 2020
c32fc5b
change snake case to camel case. format code
JimClarke5 Sep 7, 2020
171cd2f
clean up warnings, format code
JimClarke5 Sep 7, 2020
e9c3134
Added Adamax, Ftrl, and Nadam Optimizers. Added Optimizers enum for e…
JimClarke5 Sep 9, 2020
5c30a72
Removed optimizer classes from tensorflow-keras, moved optimizer test …
JimClarke5 Sep 9, 2020
ebefc2e
Fixed generics
JimClarke5 Sep 9, 2020
7915e63
Fixed from Unit test results
JimClarke5 Sep 9, 2020
ec4f679
added @SuppressWarnings("unchecked") on Variable array
JimClarke5 Sep 9, 2020
c86d09b
Merge pull request #1 from tensorflow/master
JimClarke5 Sep 18, 2020
1a670ec
Added Support for evaluating TFloat16
JimClarke5 Sep 30, 2020
0cc9b9c
Add Activations
JimClarke5 Sep 30, 2020
ca77a0b
Remove no-arg CTORs
JimClarke5 Oct 1, 2020
73091be
Fix Unit Tests to include positive and negative numbers on input.
JimClarke5 Oct 1, 2020
946d1d5
Modify JavaDoc indicating Linear activation is also known as Identity…
JimClarke5 Oct 2, 2020
7c5cc4a
Changed DEFAULT values from private to public
JimClarke5 Oct 2, 2020
e32fe44
Fixed last sum to be over 'e' instead of 'input'
JimClarke5 Oct 2, 2020
0130914
Added tests for various parameter constructs.
JimClarke5 Oct 2, 2020
c7d0477
added tests for 1D and 3D input
JimClarke5 Oct 2, 2020
de0e610
Change snake case to camel case
JimClarke5 Oct 2, 2020
63c1f00
JavaDoc fixes
JimClarke5 Oct 4, 2020
2302cc5
Add TFloating family
JimClarke5 Oct 21, 2020
4c44c62
Add JavaDoc
JimClarke5 Oct 21, 2020
ef29af9
Changed to TFloating where appropriate.
JimClarke5 Oct 21, 2020
7519436
Remove the test of int arguments for those classes changed to TFloati…
JimClarke5 Oct 21, 2020
27c1126
Remove the test of int arguments for those classes changed to TFloati…
JimClarke5 Oct 21, 2020
b83f94f
Make LeakyRelu visible so that it is included in tf.nn.
JimClarke5 Oct 22, 2020
c59e905
Remove TNumber import
JimClarke5 Oct 22, 2020
ebbcc4f
Add tf.nn.leakyRelu operation
JimClarke5 Oct 22, 2020
File: LeakyRelu api_def (.pbtxt)
@@ -1,5 +1,6 @@
op {
graph_op_name: "LeakyRelu"
visibility: VISIBLE
endpoint {
name: "nn.LeakyRelu"
}
File: NnOps.java
@@ -59,6 +59,7 @@
import org.tensorflow.op.nn.FusedResizeAndPadConv2d;
import org.tensorflow.op.nn.InTopK;
import org.tensorflow.op.nn.L2Loss;
import org.tensorflow.op.nn.LeakyRelu;
import org.tensorflow.op.nn.LearnedUnigramCandidateSampler;
import org.tensorflow.op.nn.LocalResponseNormalization;
import org.tensorflow.op.nn.LogSoftmax;
@@ -1226,6 +1227,19 @@ public <T extends TNumber> L2Loss<T> l2Loss(Operand<T> t) {
return L2Loss.create(scope, t);
}

/**
* Computes rectified linear: `max(features, features * alpha)`.
*
* @param <T> data type for {@code activations()} output
* @param features
* @param options carries optional attributes values
* @return a new instance of LeakyRelu
*/
public <T extends TNumber> LeakyRelu<T> leakyRelu(Operand<T> features,
LeakyRelu.Options... options) {
return LeakyRelu.create(scope, features, options);
}

/**
* Generates labels for candidate sampling with a learned unigram distribution.
* <p>
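The new tf.nn.leakyRelu endpoint forwards to LeakyRelu.create and computes max(features, features * alpha). A minimal usage sketch, assuming eager execution via Ops.create() and the generated LeakyRelu.alpha(...) option setter (the 0.3f value is illustrative, not part of this diff):

Ops tf = Ops.create();  // eager execution
Operand<TFloat32> features = tf.constant(new float[] {-1.0f, 0.0f, 2.0f});
// Override the op's default negative-section slope with the alpha option.
Operand<TFloat32> out = tf.nn.leakyRelu(features, LeakyRelu.alpha(0.3f));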
File: LeakyRelu.java
@@ -33,6 +33,7 @@
*
* @param <T> data type for {@code activations()} output
*/
@Operator(group = "nn")
public final class LeakyRelu<T extends TNumber> extends RawOp implements Operand<T> {

/**
File: TBfloat16.java
@@ -30,7 +30,7 @@
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray;
- import org.tensorflow.types.family.TNumber;
+ import org.tensorflow.types.family.TFloating;

/**
* Brain 16-bit float tensor type.
@@ -48,7 +48,7 @@
* <p>Note that some CPUs support the bfloat16 format natively, which can result in faster
* computation compared to {@link TFloat16} when GPUs are not used.
*/
- public interface TBfloat16 extends FloatNdArray, TNumber {
+ public interface TBfloat16 extends FloatNdArray, TFloating {
/** readable-name for the data type */
static final String NAME = "BFLOAT16";

File: TFloat16.java
@@ -30,7 +30,7 @@
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray;
- import org.tensorflow.types.family.TNumber;
+ import org.tensorflow.types.family.TFloating;

/**
* IEEE-754 half-precision 16-bit float tensor type.
@@ -45,7 +45,7 @@
* most CPUs do not support this format natively. For CPU computation on 16-bit floats, the {@link
* TBfloat16} tensor type might be a better option.
*/
- public interface TFloat16 extends FloatNdArray, TNumber {
+ public interface TFloat16 extends FloatNdArray, TFloating {

/** readable-name for the data type */
static final String NAME = "FLOAT16";
File: TFloat32.java
@@ -29,10 +29,10 @@
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray;
- import org.tensorflow.types.family.TNumber;
+ import org.tensorflow.types.family.TFloating;

/** IEEE-754 single-precision 32-bit float tensor type. */
- public interface TFloat32 extends FloatNdArray, TNumber {
+ public interface TFloat32 extends FloatNdArray, TFloating {

/** readable-name for the data type */
static final String NAME = "FLOAT";
File: TFloat64.java
@@ -29,10 +29,11 @@
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray;
- import org.tensorflow.types.family.TNumber;
+ import org.tensorflow.types.family.TFloating;


/** IEEE-754 double-precision 64-bit float tensor type. */
- public interface TFloat64 extends DoubleNdArray, TNumber {
+ public interface TFloat64 extends DoubleNdArray, TFloating {

/** readable-name for the data type */
static final String NAME = "DOUBLE";
File: TFloating.java (new file)
@@ -0,0 +1,19 @@
package org.tensorflow.types.family;

/**
* Marker interface for floating point tensor types.
*
* <p>Operations that accept only floating point values for some of their operands require that the tensor
* types of those operands be bound to this interface. For example:
*
* <pre>{@code
* TFloat32 tensor1 = TFloat32.vectorOf(1, 2, 3);
* TBool tensor2 = TBool.vectorOf(true, false, true);
*
* Ops tf = Ops.create();
* Exponential<TFloat32> exp = new Exponential<>(tf);
* exp.call(tf.constant(tensor1)); // OK
* exp.call(tf.constant(tensor2)); // Compilation failure
* }</pre>
*/
public interface TFloating extends TNumber {}
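Because TFloating extends TNumber, existing TNumber-bounded code keeps working, while new signatures can restrict themselves to floating point tensor types at compile time. A hedged sketch of such a helper (the method name scaleByHalf is illustrative only), following the same cast pattern the ELU activation below uses:

static <T extends TFloating> Operand<T> scaleByHalf(Ops tf, Operand<T> x) {
  // Cast the scalar 0.5 to x's data type so the multiply is well typed.
  return tf.math.mul(x, tf.dtypes.cast(tf.constant(0.5), x.asOutput().dataType()));
}

Passing an Operand<TInt32> here fails to compile, since TInt32 is a TNumber but not a TFloating.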
File: Activation.java (new file)
@@ -0,0 +1,68 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.activations;

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/**
* Abstract base class for Activations
*
* <p><b>Note:</b> The {@link #tf} attribute must be set prior to invoking the call method. See
* {@link #setTF(Ops)} and the constructor {@link #Activation(Ops)}.
*
* @param <T> the data type of the activation
*/
public abstract class Activation<T extends TNumber> {

/** The TensorFlow Ops */
protected Ops tf;

/**
* Creates the abstract class for an Activation
*
* @param tf the TensorFlow Ops
*/
protected Activation(Ops tf) {
this.tf = tf;
}

/**
* Sets the TensorFlow Ops
*
* @param tf the TensorFlow Ops
*/
protected void setTF(Ops tf) {
this.tf = tf;
}

/**
* Gets the TensorFlow Ops
*
* @return the TensorFlow Ops
*/
protected Ops getTF() {
return this.tf;
}

/**
* Gets the calculation operation for the activation.
*
* @param input the input tensor
* @return The operand for the activation
*/
public abstract Operand<T> call(Operand<T> input);
}
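A concrete activation only has to implement call(...); the base class holds the Ops instance. As a hedged illustration (this Swish class is not part of the PR), a subclass can be as small as:

public class Swish<T extends TFloating> extends Activation<T> {

  /** Creates a Swish activation using the given TensorFlow Ops. */
  public Swish(Ops tf) {
    super(tf);
  }

  /** Computes x * sigmoid(x) with the Ops supplied at construction. */
  @Override
  public Operand<T> call(Operand<T> input) {
    return tf.math.mul(input, tf.math.sigmoid(input));
  }
}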
File: ELU.java (new file)
@@ -0,0 +1,98 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.activations;

import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TBool;
import org.tensorflow.types.family.TFloating;

/**
* Exponential linear unit.
*
* <p>The exponential linear unit (ELU) with <code>alpha &gt; 0</code> is:
*
* <p><code>x</code> if <code>x &gt; 0</code> and <code>alpha * (exp(x) -
* 1)</code> if <code>x &lt; 0</code>.
*
* <p>The ELU hyperparameter <code>alpha</code> controls the value to which an ELU saturates for
* negative net inputs. ELUs diminish the vanishing gradient effect.
*
* <p>ELUs have negative values, which push the mean of the activations closer to zero. Mean
* activations that are closer to zero enable faster learning as they bring the gradient closer to
* the natural gradient. ELUs saturate to a negative value when the argument gets smaller.
* Saturation means a small derivative which decreases the variation and the information that is
* propagated to the next layer.
*
* <p>Example Usage:
*
* <pre>
* Operand&lt;TFloat32&gt; input = &#46;&#46;&#46;;
* ELU&lt;TFloat32&gt; elu = new ELU&lt;&gt;(tf, 2.0f);
* Operand&lt;TFloat32&gt; result = elu.call(input);
* </pre>
*
* @param <T> the data type of the activation
* @see <a href="https://arxiv.org/abs/1511.07289">Clevert et al, 2016, Fast and Accurate Deep
* Network Learning by Exponential Linear Units (ELUs)</a>
*/
public class ELU<T extends TFloating> extends Activation<T> {

private static final double ALPHA_DEFAULT = 1.0;

/** A scalar, slope of negative section. */
private final double alpha;

/**
* Creates a new ELU with alpha={@link #ALPHA_DEFAULT}.
*
* @param tf the TensorFlow Ops
*/
public ELU(Ops tf) {
this(tf, ALPHA_DEFAULT);
}

/**
* Creates a new ELU
*
* @param tf the TensorFlow Ops
* @param alpha A scalar, slope of negative section. It controls the value to which an ELU
* saturates for negative net inputs.
*/
public ELU(Ops tf, double alpha) {
super(tf);
this.alpha = alpha;
}

/**
* Gets the calculation operation for the activation.
*
* @param input the input tensor
* @return The operand for the activation
*/
@Override
public Operand<T> call(Operand<T> input) {

Operand<T> result = tf.nn.elu(input);
if (alpha == 1.0) return result;
else {
DataType<T> dataType = input.asOutput().dataType();
Operand<T> y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), dataType));
Operand<TBool> cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), dataType));
return tf.select(cond, result, y);
}
}
}
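For alpha != 1 the call scales only the negative branch: tf.nn.elu handles the exponential part, and tf.select keeps positive results unchanged while multiplying the saturated values by alpha. A short usage sketch along the lines of the class JavaDoc, assuming an eager Ops instance:

Ops tf = Ops.create();
Operand<TFloat32> input = tf.constant(new float[] {-2.0f, -1.0f, 0.0f, 1.0f, 2.0f});
ELU<TFloat32> elu = new ELU<>(tf, 2.0f);
// Positive inputs pass through unchanged; negative inputs become 2.0 * (exp(x) - 1).
Operand<TFloat32> result = elu.call(input);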
File: Exponential.java (new file)
@@ -0,0 +1,57 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.activations;

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TFloating;

/**
* Exponential activation function.
*
* <p>For example:
*
* <pre>
* Operand&lt;TFloat32&gt; input = tf.constant(
* new float[] {-3.0f,-1.0f, 0.0f,1.0f,3.0f});
* Exponential&lt;TFloat32&gt; exp = new Exponential&lt;&gt;(tf);
* Operand&lt;TFloat32&gt; result = exp.call(input);
* // result is [0.04978707f, 0.36787945f, 1.f, 2.7182817f, 20.085537f]
* </pre>
*
* @param <T> the data type of the activation
*/
public class Exponential<T extends TFloating> extends Activation<T> {

/**
* Creates an Exponential activation.
*
* @param tf the TensorFlow Ops
*/
public Exponential(Ops tf) {
super(tf);
}

/**
* Calculates the Exponential activation.
*
* @param input the input tensor
* @return an Operand for the exponential activation: <code>exp(x)</code>.
*/
@Override
public Operand<T> call(Operand<T> input) {
return tf.math.exp(input);
}
}