Skip to content

Fetch resource variable fix #276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tensorflow-core/tensorflow-core-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,10 @@
</execution>
</executions>
<configuration>
<!-- Activate the use of TCP to transmit events to the plugin -->
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
<!-- Activate the use of TCP to transmit events to the plugin -->
<!-- disabled as it appears to cause intermittent test failures in GitHub Actions
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
-->
<additionalClasspathElements>
<additionalClasspathElement>${project.build.directory}/${project.artifactId}-${project.version}-${native.classifier}.jar</additionalClasspathElement>
<!-- Note: the following path is not accessible in deploying profile, so other libraries like
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
* <p>Usage of this class is reserved for internal purposes only.
*
* @param <T> tensor type mapped by this object
* @see {@link TType}
* @see TType
*/
public abstract class TensorMapper<T extends TType> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
* An operator creating a constant initialized with zeros of the shape given by `dims`.
*
* <p>For example, the following expression
* <pre>{@code tf.zeros(tf.constant(shape), TFloat32.class)</pre>
* <pre>{@code tf.zeros(tf.constant(shape), TFloat32.class)}</pre>
* is the equivalent of
* <pre>{@code tf.fill(tf.constant(shape), tf.constant(0.0f))</pre>
* <pre>{@code tf.fill(tf.constant(shape), tf.constant(0.0f))}</pre>
*
* @param <T> constant type
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
*
* <p>Subinterfaces of {@code TType} are propagated as a generic parameter to various entities of
* TensorFlow to identify the type of the tensor they carry. For example, a
* {@link org.tensorflow.Operand Operand<TFloat32>} is an operand which outputs a 32-bit floating
* {@link org.tensorflow.Operand Operand&lt;TFloat32&gt;} is an operand which outputs a 32-bit floating
* point tensor. This parameter ensure type-compatibility between operands of a computation at
* compile-time. For example:
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ public void pythonTfFunction() {
* Signature name used for saving 'add', argument names 'a' and 'b'
*/
ConcreteFunction add = bundle.function("add");
Map<String, Tensor> args = new HashMap();
Map<String, Tensor> args = new HashMap<>();
try (TFloat32 a = TFloat32.scalarOf(10.0f);
TFloat32 b = TFloat32.scalarOf(15.5f)) {
args.put("a", a);
Expand All @@ -301,14 +301,19 @@ public void pythonTfFunction() {
assertEquals(25.5f, c.getFloat());
}
}
args.clear();

// variable unwrapping happens in Session, which is used by ConcreteFunction.call
ConcreteFunction getVariable = bundle.function("get_variable");
try (TFloat32 v = (TFloat32) getVariable.call(new HashMap<>())
.get(getVariable.signature().outputNames().iterator().next())) {
assertEquals(2f, v.getFloat());
try (TFloat32 dummy = TFloat32.scalarOf(1.0f)) {
args.put("dummy",dummy);
// TF functions always require an input, so we supply a dummy one here
// This test actually checks that resource variables can be loaded correctly.
try (TFloat32 v = (TFloat32) getVariable.call(args)
.get(getVariable.signature().outputNames().iterator().next())) {
assertEquals(2f, v.getFloat());
}
}

}
}

Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import tensorflow as tf


class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
Expand All @@ -42,8 +43,7 @@ def get_scalar(self, x):
def get_vector(self, x):
return self.const_vector + x

@tf.function(input_signature=[
tf.TensorSpec(shape=None, dtype=tf.float32, name='input')])
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='input')])
def get_matrix(self, x):
return self.const_matrix + x

Expand All @@ -53,10 +53,14 @@ def get_matrix(self, x):
def add(self, a, b):
return a + b

@tf.function(input_signature=[])
def get_variable(self):
#TF functions always require an input
@tf.function(input_signature=[
tf.TensorSpec(shape=None, dtype=tf.float32, name='dummy')
])
def get_variable(self, dummy):
return self.variable


model = MyModel()

signatures = {
Expand Down