Skip to content

Commit daeb257

Browse files
authored
Native functions v2 (#233)
* Initial native function use Signed-off-by: Ryan Nett <[email protected]> * Allow body constants Signed-off-by: Ryan Nett <[email protected]> * Fix body forbids Signed-off-by: Ryan Nett <[email protected]> * Use default eager session for tensor calls Signed-off-by: Ryan Nett <[email protected]> * Use default eager for single tensor call too Signed-off-by: Ryan Nett <[email protected]> * Get functions from graph Signed-off-by: Ryan Nett <[email protected]> * Start of saver support Signed-off-by: Ryan Nett <[email protected]> * Update loading, detect statefulness, use PartitionedCall Signed-off-by: Ryan Nett <[email protected]> * Start of dependencies Signed-off-by: Ryan Nett <[email protected]> * Support dependencies Signed-off-by: Ryan Nett <[email protected]> * Remove unwrapping Signed-off-by: Ryan Nett <[email protected]> * Proper attribute setters Signed-off-by: Ryan Nett <[email protected]> * Add ignored gradient test Signed-off-by: Ryan Nett <[email protected]> * Rebase fix Signed-off-by: Ryan Nett <[email protected]> * Op generation for functions Signed-off-by: Ryan Nett <[email protected]> * Rebase fix Signed-off-by: Ryan Nett <[email protected]> * SavedFunction for running functions from SavedModelBundles Signed-off-by: Ryan Nett <[email protected]> * Review fixes Signed-off-by: Ryan Nett <[email protected]> * Generation and better javadoc Signed-off-by: Ryan Nett <[email protected]> * Rework pointer scopes Signed-off-by: Ryan Nett <[email protected]> * SessionFunction instead of SavedModelBundle specific class Signed-off-by: Ryan Nett <[email protected]> * Add CallableFunction javadoc Signed-off-by: Ryan Nett <[email protected]> * Remove obsolete test Signed-off-by: Ryan Nett <[email protected]> * Rebase fix Signed-off-by: Ryan Nett <[email protected]> * Formatting fixes and nits Signed-off-by: Ryan Nett <[email protected]> * Add session function test, Signature.builder with name Signed-off-by: Ryan Nett <[email protected]> * Remove extra synchronization Signed-off-by: Ryan Nett <[email protected]> * Formatting Signed-off-by: Ryan Nett <[email protected]> * New names Signed-off-by: Ryan Nett <[email protected]> * Note on SavedModel functions Signed-off-by: Ryan Nett <[email protected]> * Fix tests Signed-off-by: Ryan Nett <[email protected]> * Rename name method Signed-off-by: Ryan Nett <[email protected]> * Re-add tests w/ SessionFunction Signed-off-by: Ryan Nett <[email protected]> * Helper methods for saving Signed-off-by: Ryan Nett <[email protected]>
1 parent 3b4533c commit daeb257

25 files changed

+2588
-864
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

+28
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import java.nio.charset.Charset;
2121
import java.util.List;
22+
import java.util.Map;
23+
import org.tensorflow.ConcreteFunction;
2224
import org.tensorflow.DeviceSpec;
2325
import org.tensorflow.EagerSession;
2426
import org.tensorflow.ExecutionEnvironment;
@@ -87,6 +89,7 @@
8789
import org.tensorflow.op.core.ExtractVolumePatches;
8890
import org.tensorflow.op.core.Fill;
8991
import org.tensorflow.op.core.Fingerprint;
92+
import org.tensorflow.op.core.Function;
9093
import org.tensorflow.op.core.Gather;
9194
import org.tensorflow.op.core.GatherNd;
9295
import org.tensorflow.op.core.GetSessionHandle;
@@ -1116,6 +1119,31 @@ public Bucketize bucketize(Operand<? extends TNumber> input, List<Float> boundar
11161119
return Bucketize.create(scope, input, boundaries);
11171120
}
11181121

1122+
/**
1123+
* Calls the function in an execution environment, adding its graph as a function if it isn't
1124+
* already present. Only works for functions with a single input and output.
1125+
*
1126+
* @param argument the argument to the call
1127+
* @return the output of the function
1128+
* @see ConcreteFunction#call(Ops, Operand)
1129+
*/
1130+
public Operand<?> call(ConcreteFunction function, Operand<?> argument) {
1131+
return Function.call(scope, function, argument);
1132+
}
1133+
1134+
/**
1135+
* Calls the function in an execution environment, adding its graph as a function if it isn't
1136+
* already present. The inputs and outputs are keyed by the names set in the {@code Signature}.
1137+
*
1138+
* @param arguments the arguments to the call
1139+
* @return the outputs of the function
1140+
* @see ConcreteFunction#call(Ops, Map)
1141+
*/
1142+
public Map<String, Operand<?>> call(ConcreteFunction function,
1143+
Map<String, Operand<?>> arguments) {
1144+
return Function.call(scope, function, arguments);
1145+
}
1146+
11191147
/**
11201148
* Clips tensor values to a specified min and max.
11211149
* Given a tensor {@code t}, this operation returns a tensor of the same type and

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Function.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// Once created and added to graphs, functions can be invoked by creating an
1414
// operation whose operation type matches the function name.
1515
@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
16-
public class TF_Function extends Pointer {
16+
public class TF_Function extends org.tensorflow.internal.c_api.AbstractTF_Function {
1717
/** Empty constructor. Calls {@code super((Pointer)null)}. */
1818
public TF_Function() { super((Pointer)null); }
1919
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */

0 commit comments

Comments
 (0)