From 032db2ab006865744db8f8badfb8be0eb3969d96 Mon Sep 17 00:00:00 2001 From: Ryan Nett <JNett96@gmail.com> Date: Mon, 31 May 2021 15:54:47 -0700 Subject: [PATCH 01/14] Start of function ops Signed-off-by: Ryan Nett <JNett96@gmail.com> --- .../org/tensorflow/op/core/BatchFunction.java | 353 ++++++++++++++++++ .../gen/java/org/tensorflow/op/core/Case.java | 185 +++++++++ .../gen/java/org/tensorflow/op/core/For.java | 104 ++++++ .../op/core/GroupByReducerDataset.java | 119 ++++++ .../gen/java/org/tensorflow/op/core/If.java | 175 +++++++++ .../java/org/tensorflow/op/core/MapDefun.java | 158 ++++++++ .../tensorflow/op/core/PartitionedCall.java | 188 ++++++++++ .../org/tensorflow/op/core/ReduceDataset.java | 144 +++++++ .../org/tensorflow/op/core/RemoteCall.java | 95 +++++ .../op/core/StatefulPartitionedCall.java | 189 ++++++++++ .../org/tensorflow/op/core/StatelessCase.java | 186 +++++++++ .../org/tensorflow/op/core/StatelessIf.java | 178 +++++++++ .../tensorflow/op/core/StatelessWhile.java | 198 ++++++++++ .../java/org/tensorflow/op/core/While.java | 195 ++++++++++ .../op/data/ChooseFastestBranchDataset.java | 115 ++++++ .../org/tensorflow/op/data/FilterDataset.java | 107 ++++++ .../tensorflow/op/data/FlatMapDataset.java | 105 ++++++ .../tensorflow/op/data/GeneratorDataset.java | 103 +++++ .../op/data/GroupByWindowDataset.java | 108 ++++++ .../tensorflow/op/data/InterleaveDataset.java | 112 ++++++ .../org/tensorflow/op/data/LoadDataset.java | 136 +++++++ .../org/tensorflow/op/data/MapDataset.java | 165 ++++++++ .../tensorflow/op/data/OneShotIterator.java | 178 +++++++++ .../op/data/ParallelMapDataset.java | 193 ++++++++++ .../org/tensorflow/op/data/SaveDataset.java | 133 +++++++ .../org/tensorflow/op/data/ScanDataset.java | 163 ++++++++ .../tensorflow/op/data/SnapshotDataset.java | 200 ++++++++++ .../tensorflow/op/data/TakeWhileDataset.java | 103 +++++ .../experimental/GroupByReducerDataset.java | 119 ++++++ .../experimental/GroupByWindowDataset.java | 108 ++++++ .../LegacyParallelInterleaveDataset.java | 155 ++++++++ .../data/experimental/MapAndBatchDataset.java | 154 ++++++++ .../op/data/experimental/MapDataset.java | 161 ++++++++ .../ParallelInterleaveDataset.java | 177 +++++++++ .../op/data/experimental/ScanDataset.java | 137 +++++++ .../data/experimental/TakeWhileDataset.java | 103 +++++ .../java/org/tensorflow/op/tpu/Compile.java | 134 +++++++ .../tensorflow/op/tpu/PartitionedCall.java | 133 +++++++ .../tensorflow/op/train/SymbolicGradient.java | 107 ++++++ .../gen/java/org/tensorflow/op/xla/If.java | 102 +++++ .../java/org/tensorflow/op/xla/Reduce.java | 97 +++++ .../org/tensorflow/op/xla/ReduceWindow.java | 104 ++++++ .../java/org/tensorflow/op/xla/Scatter.java | 100 +++++ .../tensorflow/op/xla/SelectAndScatter.java | 104 ++++++ .../gen/java/org/tensorflow/op/xla/While.java | 101 +++++ .../org/tensorflow/op/xla/XlaHostCompute.java | 177 +++++++++ .../java/org/tensorflow/op/xla/XlaLaunch.java | 99 +++++ .../tensorflow/op/xla/XlaVariadicReduce.java | 105 ++++++ .../op/generator/ClassGenerator.java | 1 - 49 files changed, 6865 insertions(+), 1 deletion(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BatchFunction.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/For.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GroupByReducerDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapDefun.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteCall.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulPartitionedCall.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessCase.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessIf.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessWhile.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestBranchDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FlatMapDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GeneratorDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByWindowDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InterleaveDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LoadDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MapDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OneShotIterator.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelMapDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SaveDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ScanDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeWhileDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByReducerDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByWindowDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LegacyParallelInterleaveDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapAndBatchDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParallelInterleaveDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ScanDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/TakeWhileDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/Compile.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/PartitionedCall.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SymbolicGradient.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/If.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Reduce.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/ReduceWindow.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Scatter.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SelectAndScatter.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/While.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaHostCompute.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaLaunch.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicReduce.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BatchFunction.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BatchFunction.java new file mode 100644 index 00000000000..80841c9b28f --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BatchFunction.java @@ -0,0 +1,353 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * Batches all the inputs tensors to the computation done by the function. + * So, for example, in the following code + * <pre> + * + * # This input will be captured. + * y = tf.placeholder_with_default(1.0, shape=[]) + * + * {@literal @}tf.Defun(tf.float32) + * def computation(a): + * return tf.matmul(a, a) + y + * + * b = gen_batch_ops.batch_function( + * f=computation + * in_tensors=[a], + * captured_tensors=computation.captured_inputs, + * Tout=[o.type for o in computation.definition.signature.output_arg], + * num_batch_threads=1, + * max_batch_size=10, + * batch_timeout_micros=100000, # 100ms + * allowed_batch_sizes=[3, 10], + * batching_queue="") + * </pre> + * <p>If more than one session.run call is simultaneously trying to compute {@code b} + * the values of {@code a} will be gathered, non-deterministically concatenated + * along the first axis, and only one thread will run the computation. + * <p>Assumes that all arguments of the function are Tensors which will be batched + * along their first dimension. + * <p>Arguments that are captured, are not batched. The session.run call which does + * the concatenation, will use the values of the captured tensors available to it. + * Therefore, typical uses of captured tensors should involve values which remain + * unchanged across session.run calls. Inference is a good example of this. + * <p>SparseTensor is not supported. The return value of the decorated function + * must be a Tensor or a list/tuple of Tensors. + */ +@Operator +public final class BatchFunction extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "BatchFunction"; + + private List<Output<?>> outTensors; + + @SuppressWarnings("unchecked") + private BatchFunction(Operation operation) { + super(operation); + int outputIdx = 0; + int outTensorsLength = operation.outputListLength("out_tensors"); + outTensors = Arrays.asList(operation.outputList(outputIdx, outTensorsLength)); + outputIdx += outTensorsLength; + } + + /** + * Factory method to create a class wrapping a new BatchFunction operation. + * + * @param scope current scope + * @param inTensors The tensors to be batched. + * @param capturedTensors The tensors which are captured in the function, and don't need + * to be batched. + * @param f the value of the f property + * @param numBatchThreads Number of scheduling threads for processing batches of work. + * Determines the number of batches processed in parallel. + * @param maxBatchSize Batch sizes will never be bigger than this. + * @param batchTimeoutMicros Maximum number of microseconds to wait before outputting + * an incomplete batch. + * @param Tout the types of the output tensors. + * @param options carries optional attribute values + * @return a new instance of BatchFunction + */ + @Endpoint( + describeByClass = true + ) + public static BatchFunction create(Scope scope, Iterable<Operand<?>> inTensors, + Iterable<Operand<?>> capturedTensors, ConcreteFunction f, Long numBatchThreads, + Long maxBatchSize, Long batchTimeoutMicros, List<Class<? extends TType>> Tout, + Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("BatchFunction", scope.makeOpName("BatchFunction")); + opBuilder.addInputList(Operands.asOutputs(inTensors)); + opBuilder.addInputList(Operands.asOutputs(capturedTensors)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("f", f); + opBuilder.setAttr("num_batch_threads", numBatchThreads); + opBuilder.setAttr("max_batch_size", maxBatchSize); + opBuilder.setAttr("batch_timeout_micros", batchTimeoutMicros); + opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); + if (options != null) { + for (Options opts : options) { + if (opts.maxEnqueuedBatches != null) { + opBuilder.setAttr("max_enqueued_batches", opts.maxEnqueuedBatches); + } + if (opts.allowedBatchSizes != null) { + long[] allowedBatchSizesArray = new long[opts.allowedBatchSizes.size()]; + for (int i = 0 ; i < allowedBatchSizesArray.length ; i++) { + allowedBatchSizesArray[i] = opts.allowedBatchSizes.get(i); + } + opBuilder.setAttr("allowed_batch_sizes", allowedBatchSizesArray); + } + if (opts.container != null) { + opBuilder.setAttr("container", opts.container); + } + if (opts.sharedName != null) { + opBuilder.setAttr("shared_name", opts.sharedName); + } + if (opts.batchingQueue != null) { + opBuilder.setAttr("batching_queue", opts.batchingQueue); + } + if (opts.enableLargeBatchSplitting != null) { + opBuilder.setAttr("enable_large_batch_splitting", opts.enableLargeBatchSplitting); + } + } + } + return new BatchFunction(opBuilder.build()); + } + + /** + * Sets the maxEnqueuedBatches option. + * + * @param maxEnqueuedBatches Maximum number of batches enqueued. Default: 10. + * @return this Options instance. + */ + public static Options maxEnqueuedBatches(Long maxEnqueuedBatches) { + return new Options().maxEnqueuedBatches(maxEnqueuedBatches); + } + + /** + * Sets the allowedBatchSizes option. + * + * @param allowedBatchSizes Optional list of allowed batch sizes. If left empty, does + * nothing. Otherwise, supplies a list of batch sizes, causing the op to pad + * batches up to one of those sizes. The entries must increase monotonically. + * If enable_large_batch_splitting is false (i.e., large-input-split is not + * enabled) the final entry must equal max_batch_size. + * @return this Options instance. + */ + public static Options allowedBatchSizes(List<Long> allowedBatchSizes) { + return new Options().allowedBatchSizes(allowedBatchSizes); + } + + /** + * Sets the allowedBatchSizes option. + * + * @param allowedBatchSizes Optional list of allowed batch sizes. If left empty, does + * nothing. Otherwise, supplies a list of batch sizes, causing the op to pad + * batches up to one of those sizes. The entries must increase monotonically. + * If enable_large_batch_splitting is false (i.e., large-input-split is not + * enabled) the final entry must equal max_batch_size. + * @return this Options instance. + */ + public static Options allowedBatchSizes(Long[] allowedBatchSizes) { + return new Options().allowedBatchSizes(allowedBatchSizes); + } + + /** + * Sets the container option. + * + * @param container Controls the scope of sharing of this batch. + * @return this Options instance. + */ + public static Options container(String container) { + return new Options().container(container); + } + + /** + * Sets the sharedName option. + * + * @param sharedName Concurrently running instances of batch in the same device with the + * same container and shared_name will batch their elements together. If left + * empty, the op name will be used as the shared name. + * @return this Options instance. + */ + public static Options sharedName(String sharedName) { + return new Options().sharedName(sharedName); + } + + /** + * Sets the batchingQueue option. + * + * @param batchingQueue the batchingQueue option + * @return this Options instance. + */ + public static Options batchingQueue(String batchingQueue) { + return new Options().batchingQueue(batchingQueue); + } + + /** + * Sets the enableLargeBatchSplitting option. + * + * @param enableLargeBatchSplitting input with a large size (i.e., larger than the largest value of + * {@code allowed_batch_sizes}) will be splitted into multiple batches with batch size. + * @return this Options instance. + */ + public static Options enableLargeBatchSplitting(Boolean enableLargeBatchSplitting) { + return new Options().enableLargeBatchSplitting(enableLargeBatchSplitting); + } + + /** + * Gets outTensors. + * The output tensors. + * @return outTensors. + */ + public List<Output<?>> outTensors() { + return outTensors; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) outTensors.iterator(); + } + + /** + * Optional attributes for {@link org.tensorflow.op.core.BatchFunction} + */ + public static class Options { + private Long maxEnqueuedBatches; + + private List<Long> allowedBatchSizes; + + private String container; + + private String sharedName; + + private String batchingQueue; + + private Boolean enableLargeBatchSplitting; + + private Options() { + } + + /** + * Sets the maxEnqueuedBatches option. + * + * @param maxEnqueuedBatches Maximum number of batches enqueued. Default: 10. + * @return this Options instance. + */ + public Options maxEnqueuedBatches(Long maxEnqueuedBatches) { + this.maxEnqueuedBatches = maxEnqueuedBatches; + return this; + } + + /** + * Sets the allowedBatchSizes option. + * + * @param allowedBatchSizes Optional list of allowed batch sizes. If left empty, does + * nothing. Otherwise, supplies a list of batch sizes, causing the op to pad + * batches up to one of those sizes. The entries must increase monotonically. + * If enable_large_batch_splitting is false (i.e., large-input-split is not + * enabled) the final entry must equal max_batch_size. + * @return this Options instance. + */ + public Options allowedBatchSizes(List<Long> allowedBatchSizes) { + this.allowedBatchSizes = allowedBatchSizes; + return this; + } + + /** + * Sets the allowedBatchSizes option. + * + * @param allowedBatchSizes Optional list of allowed batch sizes. If left empty, does + * nothing. Otherwise, supplies a list of batch sizes, causing the op to pad + * batches up to one of those sizes. The entries must increase monotonically. + * If enable_large_batch_splitting is false (i.e., large-input-split is not + * enabled) the final entry must equal max_batch_size. + * @return this Options instance. + */ + public Options allowedBatchSizes(Long... allowedBatchSizes) { + this.allowedBatchSizes = Arrays.asList(allowedBatchSizes); + return this; + } + + /** + * Sets the container option. + * + * @param container Controls the scope of sharing of this batch. + * @return this Options instance. + */ + public Options container(String container) { + this.container = container; + return this; + } + + /** + * Sets the sharedName option. + * + * @param sharedName Concurrently running instances of batch in the same device with the + * same container and shared_name will batch their elements together. If left + * empty, the op name will be used as the shared name. + * @return this Options instance. + */ + public Options sharedName(String sharedName) { + this.sharedName = sharedName; + return this; + } + + /** + * Sets the batchingQueue option. + * + * @param batchingQueue the batchingQueue option + * @return this Options instance. + */ + public Options batchingQueue(String batchingQueue) { + this.batchingQueue = batchingQueue; + return this; + } + + /** + * Sets the enableLargeBatchSplitting option. + * + * @param enableLargeBatchSplitting input with a large size (i.e., larger than the largest value of + * {@code allowed_batch_sizes}) will be splitted into multiple batches with batch size. + * @return this Options instance. + */ + public Options enableLargeBatchSplitting(Boolean enableLargeBatchSplitting) { + this.enableLargeBatchSplitting = enableLargeBatchSplitting; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java new file mode 100644 index 00000000000..8ea144ece5b --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java @@ -0,0 +1,185 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; + +/** + * An n-way switch statement which calls a single branch function. + * <pre> + * An n-way switch statement, implementing the following: + * ``` + * switch (branch_index) { + * case 0: + * output = branches[0](input); + * break; + * case 1: + * output = branches[1](input); + * break; + * ... + * case [[nbranches-1]]: + * default: + * output = branches[nbranches-1](input); + * break; + * } + * ``` + * </pre> + */ +@Operator +public final class Case extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "Case"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private Case(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new Case operation. + * + * @param scope current scope + * @param branchIndex The branch selector, an int32 Tensor. + * @param input A list of input tensors passed to the branch function. + * @param Tout A list of output types. + * @param branches <pre> + * A list of functions each of which takes 'inputs' and returns a list of + * tensors, whose types are the same as what every other branch returns. + * </pre> + * @param options carries optional attribute values + * @return a new instance of Case + */ + @Endpoint( + describeByClass = true + ) + public static Case create(Scope scope, Operand<TInt32> branchIndex, Iterable<Operand<?>> input, + List<Class<? extends TType>> Tout, List<ConcreteFunction> branches, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("Case", scope.makeOpName("Case")); + opBuilder.addInput(branchIndex.asOutput()); + opBuilder.addInputList(Operands.asOutputs(input)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); + ConcreteFunction[] branchesArray = new ConcreteFunction[branches.size()]; + for (int i = 0 ; i < branchesArray.length ; i++) { + branchesArray[i] = branches.get(i); + } + opBuilder.setAttr("branches", branchesArray); + if (options != null) { + for (Options opts : options) { + if (opts.outputShapes != null) { + Shape[] outputShapesArray = new Shape[opts.outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = opts.outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + } + } + } + return new Case(opBuilder.build()); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public static Options outputShapes(List<Shape> outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public static Options outputShapes(Shape[] outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Gets output. + * A list of return values. + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } + + /** + * Optional attributes for {@link org.tensorflow.op.core.Case} + */ + public static class Options { + private List<Shape> outputShapes; + + private Options() { + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(List<Shape> outputShapes) { + this.outputShapes = outputShapes; + return this; + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(Shape... outputShapes) { + this.outputShapes = Arrays.asList(outputShapes); + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/For.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/For.java new file mode 100644 index 00000000000..4ce4fc1da35 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/For.java @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; + +/** + * <pre> + * output = input; + * for i in range(start, limit, delta) + * output = body(i, output); + * </pre> + */ +@Operator +public final class For extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "For"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private For(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new For operation. + * + * @param scope current scope + * @param start The lower bound. An int32 + * @param limit The upper bound. An int32 + * @param delta The increment. An int32 + * @param input A list of input tensors whose types are T. + * @param body <pre> + * A function that takes a list of tensors (int32, T) and returns another + * list of tensors (T). + * </pre> + * @return a new instance of For + */ + @Endpoint( + describeByClass = true + ) + public static For create(Scope scope, Operand<TInt32> start, Operand<TInt32> limit, + Operand<TInt32> delta, Iterable<Operand<?>> input, ConcreteFunction body) { + OperationBuilder opBuilder = scope.env().opBuilder("For", scope.makeOpName("For")); + opBuilder.addInput(start.asOutput()); + opBuilder.addInput(limit.asOutput()); + opBuilder.addInput(delta.asOutput()); + opBuilder.addInputList(Operands.asOutputs(input)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("body", body); + return new For(opBuilder.build()); + } + + /** + * Gets output. + * A list of output tensors whose types are T. + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GroupByReducerDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GroupByReducerDataset.java new file mode 100644 index 00000000000..b0e9ba81d19 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GroupByReducerDataset.java @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that computes a group-by on {@code input_dataset}. + * Creates a dataset that computes a group-by on {@code input_dataset}. + */ +public final class GroupByReducerDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "GroupByReducerDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private GroupByReducerDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new GroupByReducerDataset operation. + * + * @param scope current scope + * @param inputDataset A variant tensor representing the input dataset. + * @param keyFuncOtherArguments A list of tensors, typically values that were captured when + * building a closure for {@code key_func}. + * @param initFuncOtherArguments A list of tensors, typically values that were captured when + * building a closure for {@code init_func}. + * @param reduceFuncOtherArguments A list of tensors, typically values that were captured when + * building a closure for {@code reduce_func}. + * @param finalizeFuncOtherArguments A list of tensors, typically values that were captured when + * building a closure for {@code finalize_func}. + * @param keyFunc A function mapping an element of {@code input_dataset}, concatenated + * with {@code key_func_other_arguments} to a scalar value of type DT_INT64. + * @param initFunc A function mapping a key of type DT_INT64, concatenated with + * {@code init_func_other_arguments} to the initial reducer state. + * @param reduceFunc A function mapping the current reducer state and an element of {@code input_dataset}, + * concatenated with {@code reduce_func_other_arguments} to a new reducer state. + * @param finalizeFunc A function mapping the final reducer state to an output element. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of GroupByReducerDataset + */ + @Endpoint( + describeByClass = true + ) + public static GroupByReducerDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> keyFuncOtherArguments, Iterable<Operand<?>> initFuncOtherArguments, + Iterable<Operand<?>> reduceFuncOtherArguments, + Iterable<Operand<?>> finalizeFuncOtherArguments, ConcreteFunction keyFunc, + ConcreteFunction initFunc, ConcreteFunction reduceFunc, ConcreteFunction finalizeFunc, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + OperationBuilder opBuilder = scope.env().opBuilder("GroupByReducerDataset", scope.makeOpName("GroupByReducerDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(keyFuncOtherArguments)); + opBuilder.addInputList(Operands.asOutputs(initFuncOtherArguments)); + opBuilder.addInputList(Operands.asOutputs(reduceFuncOtherArguments)); + opBuilder.addInputList(Operands.asOutputs(finalizeFuncOtherArguments)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("key_func", keyFunc); + opBuilder.setAttr("init_func", initFunc); + opBuilder.setAttr("reduce_func", reduceFunc); + opBuilder.setAttr("finalize_func", finalizeFunc); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + return new GroupByReducerDataset(opBuilder.build()); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java new file mode 100644 index 00000000000..c178111e27b --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java @@ -0,0 +1,175 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * output = cond ? then_branch(input) : else_branch(input) + */ +@Operator +public final class If extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "If"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private If(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new If operation. + * + * @param scope current scope + * @param cond <pre> + * A Tensor. If the tensor is a scalar of non-boolean type, the + * scalar is converted to a boolean according to the + * following rule: if the scalar is a numerical value, non-zero means + * `True` and zero means False; if the scalar is a string, non-empty + * means `True` and empty means `False`. If the tensor is not a scalar, + * being empty means False and being non-empty means True. + * </pre> + * @param input A list of input tensors. + * @param Tout A list of output types. + * @param thenBranch <pre> + * A function that takes 'inputs' and returns a list of tensors, whose + * types are the same as what else_branch returns. + * </pre> + * @param elseBranch <pre> + * A function that takes 'inputs' and returns a list of tensors, whose + * types are the same as what then_branch returns. + * </pre> + * @param options carries optional attribute values + * @return a new instance of If + */ + @Endpoint( + describeByClass = true + ) + public static If create(Scope scope, Operand<? extends TType> cond, Iterable<Operand<?>> input, + List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, ConcreteFunction elseBranch, + Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("If", scope.makeOpName("If")); + opBuilder.addInput(cond.asOutput()); + opBuilder.addInputList(Operands.asOutputs(input)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); + opBuilder.setAttr("then_branch", thenBranch); + opBuilder.setAttr("else_branch", elseBranch); + if (options != null) { + for (Options opts : options) { + if (opts.outputShapes != null) { + Shape[] outputShapesArray = new Shape[opts.outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = opts.outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + } + } + } + return new If(opBuilder.build()); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public static Options outputShapes(List<Shape> outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public static Options outputShapes(Shape[] outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Gets output. + * A list of return values. + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } + + /** + * Optional attributes for {@link org.tensorflow.op.core.If} + */ + public static class Options { + private List<Shape> outputShapes; + + private Options() { + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(List<Shape> outputShapes) { + this.outputShapes = outputShapes; + return this; + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(Shape... outputShapes) { + this.outputShapes = Arrays.asList(outputShapes); + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapDefun.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapDefun.java new file mode 100644 index 00000000000..7a99f926eb2 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapDefun.java @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.family.TType; + +/** + * Maps a function on the list of tensors unpacked from arguments on dimension 0. + * The function given by {@code f} is assumed to be stateless, and is executed + * concurrently on all the slices; up to batch_size (i.e. the size of the 0th + * dimension of each argument) functions will be scheduled at once. + * <p>The {@code max_intra_op_parallelism} attr, which defaults to 1, can be used to + * limit the intra op parallelism. To limit inter-op parallelism, a user can + * set a private threadpool on the dataset using {@code tf.data.Options}'s + * {@code ThreadingOptions}. + * <p>Note that this op is not exposed to users directly, but is invoked in tf.data + * rewrites. + */ +public final class MapDefun extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "MapDefun"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private MapDefun(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new MapDefun operation. + * + * @param scope current scope + * @param arguments <pre> + * A list of tensors whose types are `Targuments`, corresponding to the inputs + * the function should be mapped over. + * </pre> + * @param capturedInputs <pre> + * A list of tensors whose types are `Tcaptured`, corresponding to the captured + * inputs of the defun. + * </pre> + * @param outputTypes A list of types. + * @param outputShapes A list of shapes. + * @param f the value of the f property + * @param options carries optional attribute values + * @return a new instance of MapDefun + */ + @Endpoint( + describeByClass = true + ) + public static MapDefun create(Scope scope, Iterable<Operand<?>> arguments, + Iterable<Operand<?>> capturedInputs, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes, ConcreteFunction f, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("MapDefun", scope.makeOpName("MapDefun")); + opBuilder.addInputList(Operands.asOutputs(arguments)); + opBuilder.addInputList(Operands.asOutputs(capturedInputs)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + opBuilder.setAttr("f", f); + if (options != null) { + for (Options opts : options) { + if (opts.maxIntraOpParallelism != null) { + opBuilder.setAttr("max_intra_op_parallelism", opts.maxIntraOpParallelism); + } + } + } + return new MapDefun(opBuilder.build()); + } + + /** + * Sets the maxIntraOpParallelism option. + * + * @param maxIntraOpParallelism the maxIntraOpParallelism option + * @return this Options instance. + */ + public static Options maxIntraOpParallelism(Long maxIntraOpParallelism) { + return new Options().maxIntraOpParallelism(maxIntraOpParallelism); + } + + /** + * Gets output. + * <pre> + * A list of output tensors whose types are `output_types` and whose dimensions + * 0 are the same as the dimensions 0 of the tensors in `arguments`, and whose + * remaining dimensions correspond to those in `output_shapes`. + * </pre> + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } + + /** + * Optional attributes for {@link org.tensorflow.op.core.MapDefun} + */ + public static class Options { + private Long maxIntraOpParallelism; + + private Options() { + } + + /** + * Sets the maxIntraOpParallelism option. + * + * @param maxIntraOpParallelism the maxIntraOpParallelism option + * @return this Options instance. + */ + public Options maxIntraOpParallelism(Long maxIntraOpParallelism) { + this.maxIntraOpParallelism = maxIntraOpParallelism; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java new file mode 100644 index 00000000000..5b67a6129fb --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java @@ -0,0 +1,188 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned. + */ +@Operator +public final class PartitionedCall extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "PartitionedCall"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private PartitionedCall(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new PartitionedCall operation. + * + * @param scope current scope + * @param args A list of input tensors. + * @param Tout A list of output types. + * @param f <pre> + * A function that takes 'args', a list of tensors, and returns 'output', + * another list of tensors. Input and output types are specified by 'Tin' + * and 'Tout'. The function body of f will be placed and partitioned across + * devices, setting this op apart from the regular Call op. + * </pre> + * @param options carries optional attribute values + * @return a new instance of PartitionedCall + */ + @Endpoint( + describeByClass = true + ) + public static PartitionedCall create(Scope scope, Iterable<Operand<?>> args, + List<Class<? extends TType>> Tout, ConcreteFunction f, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("PartitionedCall", scope.makeOpName("PartitionedCall")); + opBuilder.addInputList(Operands.asOutputs(args)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); + opBuilder.setAttr("f", f); + if (options != null) { + for (Options opts : options) { + if (opts.config != null) { + opBuilder.setAttr("config", opts.config); + } + if (opts.configProto != null) { + opBuilder.setAttr("config_proto", opts.configProto); + } + if (opts.executorType != null) { + opBuilder.setAttr("executor_type", opts.executorType); + } + } + } + return new PartitionedCall(opBuilder.build()); + } + + /** + * Sets the config option. + * + * @param config the config option + * @return this Options instance. + */ + public static Options config(String config) { + return new Options().config(config); + } + + /** + * Sets the configProto option. + * + * @param configProto the configProto option + * @return this Options instance. + */ + public static Options configProto(String configProto) { + return new Options().configProto(configProto); + } + + /** + * Sets the executorType option. + * + * @param executorType the executorType option + * @return this Options instance. + */ + public static Options executorType(String executorType) { + return new Options().executorType(executorType); + } + + /** + * Gets output. + * A list of return values. + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } + + /** + * Optional attributes for {@link org.tensorflow.op.core.PartitionedCall} + */ + public static class Options { + private String config; + + private String configProto; + + private String executorType; + + private Options() { + } + + /** + * Sets the config option. + * + * @param config the config option + * @return this Options instance. + */ + public Options config(String config) { + this.config = config; + return this; + } + + /** + * Sets the configProto option. + * + * @param configProto the configProto option + * @return this Options instance. + */ + public Options configProto(String configProto) { + this.configProto = configProto; + return this; + } + + /** + * Sets the executorType option. + * + * @param executorType the executorType option + * @return this Options instance. + */ + public Options executorType(String executorType) { + this.executorType = executorType; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceDataset.java new file mode 100644 index 00000000000..3e15fc24817 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceDataset.java @@ -0,0 +1,144 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.family.TType; + +/** + * Reduces the input dataset to a singleton using a reduce function. + */ +public final class ReduceDataset extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ReduceDataset"; + + private List<Output<?>> components; + + @SuppressWarnings("unchecked") + private ReduceDataset(Operation operation) { + super(operation); + int outputIdx = 0; + int componentsLength = operation.outputListLength("components"); + components = Arrays.asList(operation.outputList(outputIdx, componentsLength)); + outputIdx += componentsLength; + } + + /** + * Factory method to create a class wrapping a new ReduceDataset operation. + * + * @param scope current scope + * @param inputDataset A variant tensor representing the input dataset. + * @param initialState A nested structure of tensors, representing the initial state of the + * transformation. + * @param otherArguments the otherArguments value + * @param f A function that maps {@code (old_state, input_element)} to {@code new_state}. It must take + * two arguments and return a nested structures of tensors. The structure of + * {@code new_state} must match the structure of {@code initial_state}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ReduceDataset + */ + @Endpoint( + describeByClass = true + ) + public static ReduceDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> initialState, Iterable<Operand<?>> otherArguments, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("ReduceDataset", scope.makeOpName("ReduceDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(initialState)); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("f", f); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + if (options != null) { + for (Options opts : options) { + if (opts.useInterOpParallelism != null) { + opBuilder.setAttr("use_inter_op_parallelism", opts.useInterOpParallelism); + } + } + } + return new ReduceDataset(opBuilder.build()); + } + + /** + * Sets the useInterOpParallelism option. + * + * @param useInterOpParallelism the useInterOpParallelism option + * @return this Options instance. + */ + public static Options useInterOpParallelism(Boolean useInterOpParallelism) { + return new Options().useInterOpParallelism(useInterOpParallelism); + } + + /** + * Gets components. + * + * @return components. + */ + public List<Output<?>> components() { + return components; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) components.iterator(); + } + + /** + * Optional attributes for {@link org.tensorflow.op.core.ReduceDataset} + */ + public static class Options { + private Boolean useInterOpParallelism; + + private Options() { + } + + /** + * Sets the useInterOpParallelism option. + * + * @param useInterOpParallelism the useInterOpParallelism option + * @return this Options instance. + */ + public Options useInterOpParallelism(Boolean useInterOpParallelism) { + this.useInterOpParallelism = useInterOpParallelism; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteCall.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteCall.java new file mode 100644 index 00000000000..416c7f222fc --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteCall.java @@ -0,0 +1,95 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; + +/** + * Runs function {@code f} on a remote device indicated by {@code target}. + */ +@Operator +public final class RemoteCall extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "RemoteCall"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private RemoteCall(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new RemoteCall operation. + * + * @param scope current scope + * @param target A fully specified device name where we want to run the function. + * @param args A list of arguments for the function. + * @param Tout The type list for the return values. + * @param f The function to run remotely. + * @return a new instance of RemoteCall + */ + @Endpoint( + describeByClass = true + ) + public static RemoteCall create(Scope scope, Operand<TString> target, Iterable<Operand<?>> args, + List<Class<? extends TType>> Tout, ConcreteFunction f) { + OperationBuilder opBuilder = scope.env().opBuilder("RemoteCall", scope.makeOpName("RemoteCall")); + opBuilder.addInput(target.asOutput()); + opBuilder.addInputList(Operands.asOutputs(args)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); + opBuilder.setAttr("f", f); + return new RemoteCall(opBuilder.build()); + } + + /** + * Gets output. + * A list of return values. + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulPartitionedCall.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulPartitionedCall.java new file mode 100644 index 00000000000..70090176588 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulPartitionedCall.java @@ -0,0 +1,189 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned. + */ +@Operator +public final class StatefulPartitionedCall extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "StatefulPartitionedCall"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private StatefulPartitionedCall(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new StatefulPartitionedCall operation. + * + * @param scope current scope + * @param args A list of input tensors. + * @param Tout A list of output types. + * @param f <pre> + * A function that takes 'args', a list of tensors, and returns 'output', + * another list of tensors. Input and output types are specified by 'Tin' + * and 'Tout'. The function body of f will be placed and partitioned across + * devices, setting this op apart from the regular Call op. This op is + * stateful. + * </pre> + * @param options carries optional attribute values + * @return a new instance of StatefulPartitionedCall + */ + @Endpoint( + describeByClass = true + ) + public static StatefulPartitionedCall create(Scope scope, Iterable<Operand<?>> args, + List<Class<? extends TType>> Tout, ConcreteFunction f, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("StatefulPartitionedCall", scope.makeOpName("StatefulPartitionedCall")); + opBuilder.addInputList(Operands.asOutputs(args)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); + opBuilder.setAttr("f", f); + if (options != null) { + for (Options opts : options) { + if (opts.config != null) { + opBuilder.setAttr("config", opts.config); + } + if (opts.configProto != null) { + opBuilder.setAttr("config_proto", opts.configProto); + } + if (opts.executorType != null) { + opBuilder.setAttr("executor_type", opts.executorType); + } + } + } + return new StatefulPartitionedCall(opBuilder.build()); + } + + /** + * Sets the config option. + * + * @param config the config option + * @return this Options instance. + */ + public static Options config(String config) { + return new Options().config(config); + } + + /** + * Sets the configProto option. + * + * @param configProto the configProto option + * @return this Options instance. + */ + public static Options configProto(String configProto) { + return new Options().configProto(configProto); + } + + /** + * Sets the executorType option. + * + * @param executorType the executorType option + * @return this Options instance. + */ + public static Options executorType(String executorType) { + return new Options().executorType(executorType); + } + + /** + * Gets output. + * A list of return values. + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } + + /** + * Optional attributes for {@link org.tensorflow.op.core.StatefulPartitionedCall} + */ + public static class Options { + private String config; + + private String configProto; + + private String executorType; + + private Options() { + } + + /** + * Sets the config option. + * + * @param config the config option + * @return this Options instance. + */ + public Options config(String config) { + this.config = config; + return this; + } + + /** + * Sets the configProto option. + * + * @param configProto the configProto option + * @return this Options instance. + */ + public Options configProto(String configProto) { + this.configProto = configProto; + return this; + } + + /** + * Sets the executorType option. + * + * @param executorType the executorType option + * @return this Options instance. + */ + public Options executorType(String executorType) { + this.executorType = executorType; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessCase.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessCase.java new file mode 100644 index 00000000000..284f36c0db0 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessCase.java @@ -0,0 +1,186 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; + +/** + * An n-way switch statement which calls a single branch function. + * <pre> + * An n-way switch statement, implementing the following: + * ``` + * switch (branch_index) { + * case 0: + * output = branches[0](input); + * break; + * case 1: + * output = branches[1](input); + * break; + * ... + * case [[nbranches-1]]: + * default: + * output = branches[nbranches-1](input); + * break; + * } + * ``` + * + * This should only be used when the none of branches has stateful ops. + * </pre> + */ +public final class StatelessCase extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "StatelessCase"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private StatelessCase(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new StatelessCase operation. + * + * @param scope current scope + * @param branchIndex The branch selector, an int32 Tensor. + * @param input A list of input tensors passed to the branch function. + * @param Tout A list of output types. + * @param branches <pre> + * A list of functions each of which takes 'inputs' and returns a list of + * tensors, whose types are the same as what every other branch returns. + * </pre> + * @param options carries optional attribute values + * @return a new instance of StatelessCase + */ + @Endpoint( + describeByClass = true + ) + public static StatelessCase create(Scope scope, Operand<TInt32> branchIndex, + Iterable<Operand<?>> input, List<Class<? extends TType>> Tout, + List<ConcreteFunction> branches, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("StatelessCase", scope.makeOpName("StatelessCase")); + opBuilder.addInput(branchIndex.asOutput()); + opBuilder.addInputList(Operands.asOutputs(input)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); + ConcreteFunction[] branchesArray = new ConcreteFunction[branches.size()]; + for (int i = 0 ; i < branchesArray.length ; i++) { + branchesArray[i] = branches.get(i); + } + opBuilder.setAttr("branches", branchesArray); + if (options != null) { + for (Options opts : options) { + if (opts.outputShapes != null) { + Shape[] outputShapesArray = new Shape[opts.outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = opts.outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + } + } + } + return new StatelessCase(opBuilder.build()); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public static Options outputShapes(List<Shape> outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public static Options outputShapes(Shape[] outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Gets output. + * A list of return values. + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } + + /** + * Optional attributes for {@link org.tensorflow.op.core.StatelessCase} + */ + public static class Options { + private List<Shape> outputShapes; + + private Options() { + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(List<Shape> outputShapes) { + this.outputShapes = outputShapes; + return this; + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(Shape... outputShapes) { + this.outputShapes = Arrays.asList(outputShapes); + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessIf.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessIf.java new file mode 100644 index 00000000000..8806806cdad --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessIf.java @@ -0,0 +1,178 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * output = cond ? then_branch(input) : else_branch(input) + */ +@Operator +public final class StatelessIf extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "StatelessIf"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private StatelessIf(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new StatelessIf operation. + * + * @param scope current scope + * @param cond <pre> + * A Tensor. If the tensor is a scalar of non-boolean type, the + * scalar is converted to a boolean according to the + * following rule: if the scalar is a numerical value, non-zero means + * `True` and zero means False; if the scalar is a string, non-empty + * means `True` and empty means `False`. If the tensor is not a scalar, + * being empty means False and being non-empty means True. + * + * This should only be used when the if then/else body functions do not + * have stateful ops. + * </pre> + * @param input A list of input tensors. + * @param Tout A list of output types. + * @param thenBranch <pre> + * A function that takes 'inputs' and returns a list of tensors, whose + * types are the same as what else_branch returns. + * </pre> + * @param elseBranch <pre> + * A function that takes 'inputs' and returns a list of tensors, whose + * types are the same as what then_branch returns. + * </pre> + * @param options carries optional attribute values + * @return a new instance of StatelessIf + */ + @Endpoint( + describeByClass = true + ) + public static StatelessIf create(Scope scope, Operand<? extends TType> cond, + Iterable<Operand<?>> input, List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, + ConcreteFunction elseBranch, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("StatelessIf", scope.makeOpName("StatelessIf")); + opBuilder.addInput(cond.asOutput()); + opBuilder.addInputList(Operands.asOutputs(input)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); + opBuilder.setAttr("then_branch", thenBranch); + opBuilder.setAttr("else_branch", elseBranch); + if (options != null) { + for (Options opts : options) { + if (opts.outputShapes != null) { + Shape[] outputShapesArray = new Shape[opts.outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = opts.outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + } + } + } + return new StatelessIf(opBuilder.build()); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public static Options outputShapes(List<Shape> outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public static Options outputShapes(Shape[] outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Gets output. + * A list of return values. + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } + + /** + * Optional attributes for {@link org.tensorflow.op.core.StatelessIf} + */ + public static class Options { + private List<Shape> outputShapes; + + private Options() { + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(List<Shape> outputShapes) { + this.outputShapes = outputShapes; + return this; + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(Shape... outputShapes) { + this.outputShapes = Arrays.asList(outputShapes); + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessWhile.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessWhile.java new file mode 100644 index 00000000000..8a806f1a9b1 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessWhile.java @@ -0,0 +1,198 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * output = input; While (Cond(output)) { output = Body(output) } + */ +@Operator +public final class StatelessWhile extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "StatelessWhile"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private StatelessWhile(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new StatelessWhile operation. + * + * @param scope current scope + * @param input A list of input tensors whose types are T. + * @param cond <pre> + * A function takes 'input' and returns a tensor. If the tensor is + * a scalar of non-boolean, the scalar is converted to a boolean + * according to the following rule: if the scalar is a numerical + * value, non-zero means True and zero means False; if the scalar is + * a string, non-empty means True and empty means False. If the + * tensor is not a scalar, non-emptiness means True and False + * otherwise. + * + * This should only be used when the while condition and body functions + * do not have stateful ops. + * </pre> + * @param body <pre> + * A function that takes a list of tensors and returns another + * list of tensors. Both lists have the same types as specified + * by T. + * </pre> + * @param options carries optional attribute values + * @return a new instance of StatelessWhile + */ + @Endpoint( + describeByClass = true + ) + public static StatelessWhile create(Scope scope, Iterable<Operand<?>> input, + ConcreteFunction cond, ConcreteFunction body, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("StatelessWhile", scope.makeOpName("StatelessWhile")); + opBuilder.addInputList(Operands.asOutputs(input)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("cond", cond); + opBuilder.setAttr("body", body); + if (options != null) { + for (Options opts : options) { + if (opts.outputShapes != null) { + Shape[] outputShapesArray = new Shape[opts.outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = opts.outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + } + if (opts.parallelIterations != null) { + opBuilder.setAttr("parallel_iterations", opts.parallelIterations); + } + } + } + return new StatelessWhile(opBuilder.build()); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public static Options outputShapes(List<Shape> outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public static Options outputShapes(Shape[] outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Sets the parallelIterations option. + * + * @param parallelIterations the parallelIterations option + * @return this Options instance. + */ + public static Options parallelIterations(Long parallelIterations) { + return new Options().parallelIterations(parallelIterations); + } + + /** + * Gets output. + * A list of output tensors whose types are T. + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } + + /** + * Optional attributes for {@link org.tensorflow.op.core.StatelessWhile} + */ + public static class Options { + private List<Shape> outputShapes; + + private Long parallelIterations; + + private Options() { + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(List<Shape> outputShapes) { + this.outputShapes = outputShapes; + return this; + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(Shape... outputShapes) { + this.outputShapes = Arrays.asList(outputShapes); + return this; + } + + /** + * Sets the parallelIterations option. + * + * @param parallelIterations the parallelIterations option + * @return this Options instance. + */ + public Options parallelIterations(Long parallelIterations) { + this.parallelIterations = parallelIterations; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java new file mode 100644 index 00000000000..ac3b4e7a791 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java @@ -0,0 +1,195 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * output = input; While (Cond(output)) { output = Body(output) } + */ +@Operator +public final class While extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "While"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private While(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new While operation. + * + * @param scope current scope + * @param input A list of input tensors whose types are T. + * @param cond <pre> + * A function takes 'input' and returns a tensor. If the tensor is + * a scalar of non-boolean, the scalar is converted to a boolean + * according to the following rule: if the scalar is a numerical + * value, non-zero means True and zero means False; if the scalar is + * a string, non-empty means True and empty means False. If the + * tensor is not a scalar, non-emptiness means True and False + * otherwise. + * </pre> + * @param body <pre> + * A function that takes a list of tensors and returns another + * list of tensors. Both lists have the same types as specified + * by T. + * </pre> + * @param options carries optional attribute values + * @return a new instance of While + */ + @Endpoint( + describeByClass = true + ) + public static While create(Scope scope, Iterable<Operand<?>> input, ConcreteFunction cond, + ConcreteFunction body, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("While", scope.makeOpName("While")); + opBuilder.addInputList(Operands.asOutputs(input)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("cond", cond); + opBuilder.setAttr("body", body); + if (options != null) { + for (Options opts : options) { + if (opts.outputShapes != null) { + Shape[] outputShapesArray = new Shape[opts.outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = opts.outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + } + if (opts.parallelIterations != null) { + opBuilder.setAttr("parallel_iterations", opts.parallelIterations); + } + } + } + return new While(opBuilder.build()); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public static Options outputShapes(List<Shape> outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public static Options outputShapes(Shape[] outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Sets the parallelIterations option. + * + * @param parallelIterations the parallelIterations option + * @return this Options instance. + */ + public static Options parallelIterations(Long parallelIterations) { + return new Options().parallelIterations(parallelIterations); + } + + /** + * Gets output. + * A list of output tensors whose types are T. + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } + + /** + * Optional attributes for {@link org.tensorflow.op.core.While} + */ + public static class Options { + private List<Shape> outputShapes; + + private Long parallelIterations; + + private Options() { + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(List<Shape> outputShapes) { + this.outputShapes = outputShapes; + return this; + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(Shape... outputShapes) { + this.outputShapes = Arrays.asList(outputShapes); + return this; + } + + /** + * Sets the parallelIterations option. + * + * @param parallelIterations the parallelIterations option + * @return this Options instance. + */ + public Options parallelIterations(Long parallelIterations) { + this.parallelIterations = parallelIterations; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestBranchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestBranchDataset.java new file mode 100644 index 00000000000..67f309f9d62 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestBranchDataset.java @@ -0,0 +1,115 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +/** + * The ChooseFastestBranchDataset operation + */ +public final class ChooseFastestBranchDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ChooseFastestBranchDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private ChooseFastestBranchDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new ChooseFastestBranchDataset operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param ratioNumerator the ratioNumerator value + * @param ratioDenominator the ratioDenominator value + * @param otherArguments the otherArguments value + * @param numElementsPerBranch the value of the numElementsPerBranch property + * @param branches the value of the branches property + * @param otherArgumentsLengths the value of the otherArgumentsLengths property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of ChooseFastestBranchDataset + */ + @Endpoint( + describeByClass = true + ) + public static ChooseFastestBranchDataset create(Scope scope, + Operand<? extends TType> inputDataset, Operand<TInt64> ratioNumerator, + Operand<TInt64> ratioDenominator, Iterable<Operand<?>> otherArguments, + Long numElementsPerBranch, List<ConcreteFunction> branches, List<Long> otherArgumentsLengths, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + OperationBuilder opBuilder = scope.env().opBuilder("ChooseFastestBranchDataset", scope.makeOpName("ChooseFastestBranchDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInput(ratioNumerator.asOutput()); + opBuilder.addInput(ratioDenominator.asOutput()); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("num_elements_per_branch", numElementsPerBranch); + ConcreteFunction[] branchesArray = new ConcreteFunction[branches.size()]; + for (int i = 0 ; i < branchesArray.length ; i++) { + branchesArray[i] = branches.get(i); + } + opBuilder.setAttr("branches", branchesArray); + long[] otherArgumentsLengthsArray = new long[otherArgumentsLengths.size()]; + for (int i = 0 ; i < otherArgumentsLengthsArray.length ; i++) { + otherArgumentsLengthsArray[i] = otherArgumentsLengths.get(i); + } + opBuilder.setAttr("other_arguments_lengths", otherArgumentsLengthsArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + return new ChooseFastestBranchDataset(opBuilder.build()); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterDataset.java new file mode 100644 index 00000000000..e0d8d172db5 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterDataset.java @@ -0,0 +1,107 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset containing elements of {@code input_dataset} matching {@code predicate}. + * The {@code predicate} function must return a scalar boolean and accept the + * following arguments: + * <ul> + * <li>One tensor for each component of an element of {@code input_dataset}.</li> + * <li>One tensor for each value in {@code other_arguments}.</li> + * </ul> + */ +@Operator( + group = "data" +) +public final class FilterDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "FilterDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private FilterDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new FilterDataset operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param otherArguments A list of tensors, typically values that were captured when + * building a closure for {@code predicate}. + * @param predicate A function returning a scalar boolean. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of FilterDataset + */ + @Endpoint( + describeByClass = true + ) + public static FilterDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, ConcreteFunction predicate, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + OperationBuilder opBuilder = scope.env().opBuilder("FilterDataset", scope.makeOpName("FilterDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("predicate", predicate); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + return new FilterDataset(opBuilder.build()); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FlatMapDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FlatMapDataset.java new file mode 100644 index 00000000000..f9022a2e3f2 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FlatMapDataset.java @@ -0,0 +1,105 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. + * Unlike MapDataset, the {@code f} in FlatMapDataset is expected to return a + * Dataset variant, and FlatMapDataset will flatten successive results + * into a single Dataset. + */ +@Operator( + group = "data" +) +public final class FlatMapDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "FlatMapDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private FlatMapDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new FlatMapDataset operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param otherArguments the otherArguments value + * @param f A function mapping elements of {@code input_dataset}, concatenated with + * {@code other_arguments}, to a Dataset variant that contains elements matching + * {@code output_types} and {@code output_shapes}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of FlatMapDataset + */ + @Endpoint( + describeByClass = true + ) + public static FlatMapDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + OperationBuilder opBuilder = scope.env().opBuilder("FlatMapDataset", scope.makeOpName("FlatMapDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("f", f); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + return new FlatMapDataset(opBuilder.build()); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GeneratorDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GeneratorDataset.java new file mode 100644 index 00000000000..d19c70f72bd --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GeneratorDataset.java @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that invokes a function to generate elements. + */ +public final class GeneratorDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "GeneratorDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private GeneratorDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new GeneratorDataset operation. + * + * @param scope current scope + * @param initFuncOtherArgs the initFuncOtherArgs value + * @param nextFuncOtherArgs the nextFuncOtherArgs value + * @param finalizeFuncOtherArgs the finalizeFuncOtherArgs value + * @param initFunc the value of the initFunc property + * @param nextFunc the value of the nextFunc property + * @param finalizeFunc the value of the finalizeFunc property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of GeneratorDataset + */ + @Endpoint( + describeByClass = true + ) + public static GeneratorDataset create(Scope scope, Iterable<Operand<?>> initFuncOtherArgs, + Iterable<Operand<?>> nextFuncOtherArgs, Iterable<Operand<?>> finalizeFuncOtherArgs, + ConcreteFunction initFunc, ConcreteFunction nextFunc, ConcreteFunction finalizeFunc, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + OperationBuilder opBuilder = scope.env().opBuilder("GeneratorDataset", scope.makeOpName("GeneratorDataset")); + opBuilder.addInputList(Operands.asOutputs(initFuncOtherArgs)); + opBuilder.addInputList(Operands.asOutputs(nextFuncOtherArgs)); + opBuilder.addInputList(Operands.asOutputs(finalizeFuncOtherArgs)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("init_func", initFunc); + opBuilder.setAttr("next_func", nextFunc); + opBuilder.setAttr("finalize_func", finalizeFunc); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + return new GeneratorDataset(opBuilder.build()); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByWindowDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByWindowDataset.java new file mode 100644 index 00000000000..4e7811fd31d --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByWindowDataset.java @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that computes a windowed group-by on {@code input_dataset}. + * // TODO(mrry): Support non-int64 keys. + */ +public final class GroupByWindowDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "GroupByWindowDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private GroupByWindowDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new GroupByWindowDataset operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param keyFuncOtherArguments the keyFuncOtherArguments value + * @param reduceFuncOtherArguments the reduceFuncOtherArguments value + * @param windowSizeFuncOtherArguments the windowSizeFuncOtherArguments value + * @param keyFunc A function mapping an element of {@code input_dataset}, concatenated + * with {@code key_func_other_arguments} to a scalar value of type DT_INT64. + * @param reduceFunc the value of the reduceFunc property + * @param windowSizeFunc the value of the windowSizeFunc property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of GroupByWindowDataset + */ + @Endpoint( + describeByClass = true + ) + public static GroupByWindowDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> keyFuncOtherArguments, Iterable<Operand<?>> reduceFuncOtherArguments, + Iterable<Operand<?>> windowSizeFuncOtherArguments, ConcreteFunction keyFunc, + ConcreteFunction reduceFunc, ConcreteFunction windowSizeFunc, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + OperationBuilder opBuilder = scope.env().opBuilder("GroupByWindowDataset", scope.makeOpName("GroupByWindowDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(keyFuncOtherArguments)); + opBuilder.addInputList(Operands.asOutputs(reduceFuncOtherArguments)); + opBuilder.addInputList(Operands.asOutputs(windowSizeFuncOtherArguments)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("key_func", keyFunc); + opBuilder.setAttr("reduce_func", reduceFunc); + opBuilder.setAttr("window_size_func", windowSizeFunc); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + return new GroupByWindowDataset(opBuilder.build()); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InterleaveDataset.java new file mode 100644 index 00000000000..b8e0b43187e --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InterleaveDataset.java @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. + * Unlike MapDataset, the {@code f} in InterleaveDataset is expected to return + * a Dataset variant, and InterleaveDataset will flatten successive + * results into a single Dataset. Unlike FlatMapDataset, + * InterleaveDataset will interleave sequences of up to {@code block_length} + * consecutive elements from {@code cycle_length} input elements. + */ +@Operator( + group = "data" +) +public final class InterleaveDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "InterleaveDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private InterleaveDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new InterleaveDataset operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param otherArguments the otherArguments value + * @param cycleLength the cycleLength value + * @param blockLength the blockLength value + * @param f A function mapping elements of {@code input_dataset}, concatenated with + * {@code other_arguments}, to a Dataset variant that contains elements matching + * {@code output_types} and {@code output_shapes}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of InterleaveDataset + */ + @Endpoint( + describeByClass = true + ) + public static InterleaveDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, Operand<TInt64> cycleLength, Operand<TInt64> blockLength, + ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + OperationBuilder opBuilder = scope.env().opBuilder("InterleaveDataset", scope.makeOpName("InterleaveDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder.addInput(cycleLength.asOutput()); + opBuilder.addInput(blockLength.asOutput()); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("f", f); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + return new InterleaveDataset(opBuilder.build()); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LoadDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LoadDataset.java new file mode 100644 index 00000000000..f46937b36f2 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LoadDataset.java @@ -0,0 +1,136 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; + +/** + * The LoadDataset operation + */ +public final class LoadDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "LoadDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private LoadDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new LoadDataset operation. + * + * @param scope current scope + * @param path the path value + * @param readerFuncOtherArgs the readerFuncOtherArgs value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param readerFunc the value of the readerFunc property + * @param options carries optional attribute values + * @return a new instance of LoadDataset + */ + @Endpoint( + describeByClass = true + ) + public static LoadDataset create(Scope scope, Operand<TString> path, + Iterable<Operand<?>> readerFuncOtherArgs, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes, ConcreteFunction readerFunc, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("LoadDataset", scope.makeOpName("LoadDataset")); + opBuilder.addInput(path.asOutput()); + opBuilder.addInputList(Operands.asOutputs(readerFuncOtherArgs)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + opBuilder.setAttr("reader_func", readerFunc); + if (options != null) { + for (Options opts : options) { + if (opts.compression != null) { + opBuilder.setAttr("compression", opts.compression); + } + } + } + return new LoadDataset(opBuilder.build()); + } + + /** + * Sets the compression option. + * + * @param compression the compression option + * @return this Options instance. + */ + public static Options compression(String compression) { + return new Options().compression(compression); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.LoadDataset} + */ + public static class Options { + private String compression; + + private Options() { + } + + /** + * Sets the compression option. + * + * @param compression the compression option + * @return this Options instance. + */ + public Options compression(String compression) { + this.compression = compression; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MapDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MapDataset.java new file mode 100644 index 00000000000..5f4ab8e6776 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MapDataset.java @@ -0,0 +1,165 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. + */ +@Operator( + group = "data" +) +public final class MapDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "MapDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private MapDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new MapDataset operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param otherArguments the otherArguments value + * @param f the value of the f property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of MapDataset + */ + @Endpoint( + describeByClass = true + ) + public static MapDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("MapDataset", scope.makeOpName("MapDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("f", f); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + if (options != null) { + for (Options opts : options) { + if (opts.useInterOpParallelism != null) { + opBuilder.setAttr("use_inter_op_parallelism", opts.useInterOpParallelism); + } + if (opts.preserveCardinality != null) { + opBuilder.setAttr("preserve_cardinality", opts.preserveCardinality); + } + } + } + return new MapDataset(opBuilder.build()); + } + + /** + * Sets the useInterOpParallelism option. + * + * @param useInterOpParallelism the useInterOpParallelism option + * @return this Options instance. + */ + public static Options useInterOpParallelism(Boolean useInterOpParallelism) { + return new Options().useInterOpParallelism(useInterOpParallelism); + } + + /** + * Sets the preserveCardinality option. + * + * @param preserveCardinality the preserveCardinality option + * @return this Options instance. + */ + public static Options preserveCardinality(Boolean preserveCardinality) { + return new Options().preserveCardinality(preserveCardinality); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.MapDataset} + */ + public static class Options { + private Boolean useInterOpParallelism; + + private Boolean preserveCardinality; + + private Options() { + } + + /** + * Sets the useInterOpParallelism option. + * + * @param useInterOpParallelism the useInterOpParallelism option + * @return this Options instance. + */ + public Options useInterOpParallelism(Boolean useInterOpParallelism) { + this.useInterOpParallelism = useInterOpParallelism; + return this; + } + + /** + * Sets the preserveCardinality option. + * + * @param preserveCardinality the preserveCardinality option + * @return this Options instance. + */ + public Options preserveCardinality(Boolean preserveCardinality) { + this.preserveCardinality = preserveCardinality; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OneShotIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OneShotIterator.java new file mode 100644 index 00000000000..01330924f62 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OneShotIterator.java @@ -0,0 +1,178 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * Makes a "one-shot" iterator that can be iterated only once. + * A one-shot iterator bundles the logic for defining the dataset and + * the state of the iterator in a single op, which allows simple input + * pipelines to be defined without an additional initialization + * ("MakeIterator") step. + * <p>One-shot iterators have the following limitations: + * <ul> + * <li>They do not support parameterization: all logic for creating the underlying + * dataset must be bundled in the {@code dataset_factory} function.</li> + * <li>They are not resettable. Once a one-shot iterator reaches the end of its + * underlying dataset, subsequent "IteratorGetNext" operations on that + * iterator will always produce an {@code OutOfRange} error.</li> + * </ul> + * <p>For greater flexibility, use "Iterator" and "MakeIterator" to define + * an iterator using an arbitrary subgraph, which may capture tensors + * (including fed values) as parameters, and which may be reset multiple + * times by rerunning "MakeIterator". + */ +@Operator( + group = "data" +) +public final class OneShotIterator extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "OneShotIterator"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private OneShotIterator(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new OneShotIterator operation. + * + * @param scope current scope + * @param datasetFactory A function of type {@code () -> DT_VARIANT}, where the returned + * DT_VARIANT is a dataset. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of OneShotIterator + */ + @Endpoint( + describeByClass = true + ) + public static OneShotIterator create(Scope scope, ConcreteFunction datasetFactory, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("OneShotIterator", scope.makeOpName("OneShotIterator")); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("dataset_factory", datasetFactory); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + if (options != null) { + for (Options opts : options) { + if (opts.container != null) { + opBuilder.setAttr("container", opts.container); + } + if (opts.sharedName != null) { + opBuilder.setAttr("shared_name", opts.sharedName); + } + } + } + return new OneShotIterator(opBuilder.build()); + } + + /** + * Sets the container option. + * + * @param container the container option + * @return this Options instance. + */ + public static Options container(String container) { + return new Options().container(container); + } + + /** + * Sets the sharedName option. + * + * @param sharedName the sharedName option + * @return this Options instance. + */ + public static Options sharedName(String sharedName) { + return new Options().sharedName(sharedName); + } + + /** + * Gets handle. + * A handle to the iterator that can be passed to an "IteratorGetNext" + * op. + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.OneShotIterator} + */ + public static class Options { + private String container; + + private String sharedName; + + private Options() { + } + + /** + * Sets the container option. + * + * @param container the container option + * @return this Options instance. + */ + public Options container(String container) { + this.container = container; + return this; + } + + /** + * Sets the sharedName option. + * + * @param sharedName the sharedName option + * @return this Options instance. + */ + public Options sharedName(String sharedName) { + this.sharedName = sharedName; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelMapDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelMapDataset.java new file mode 100644 index 00000000000..b8488d7a1ac --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelMapDataset.java @@ -0,0 +1,193 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. + * Unlike a "MapDataset", which applies {@code f} sequentially, this dataset invokes up + * to {@code num_parallel_calls} copies of {@code f} in parallel. + */ +public final class ParallelMapDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ParallelMapDatasetV2"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private ParallelMapDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new ParallelMapDatasetV2 operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param otherArguments the otherArguments value + * @param numParallelCalls The number of concurrent invocations of {@code f} that process + * elements from {@code input_dataset} in parallel. + * @param f the value of the f property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ParallelMapDataset + */ + @Endpoint( + describeByClass = true + ) + public static ParallelMapDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, Operand<TInt64> numParallelCalls, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("ParallelMapDatasetV2", scope.makeOpName("ParallelMapDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder.addInput(numParallelCalls.asOutput()); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("f", f); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + if (options != null) { + for (Options opts : options) { + if (opts.useInterOpParallelism != null) { + opBuilder.setAttr("use_inter_op_parallelism", opts.useInterOpParallelism); + } + if (opts.deterministic != null) { + opBuilder.setAttr("deterministic", opts.deterministic); + } + if (opts.preserveCardinality != null) { + opBuilder.setAttr("preserve_cardinality", opts.preserveCardinality); + } + } + } + return new ParallelMapDataset(opBuilder.build()); + } + + /** + * Sets the useInterOpParallelism option. + * + * @param useInterOpParallelism the useInterOpParallelism option + * @return this Options instance. + */ + public static Options useInterOpParallelism(Boolean useInterOpParallelism) { + return new Options().useInterOpParallelism(useInterOpParallelism); + } + + /** + * Sets the deterministic option. + * + * @param deterministic the deterministic option + * @return this Options instance. + */ + public static Options deterministic(String deterministic) { + return new Options().deterministic(deterministic); + } + + /** + * Sets the preserveCardinality option. + * + * @param preserveCardinality the preserveCardinality option + * @return this Options instance. + */ + public static Options preserveCardinality(Boolean preserveCardinality) { + return new Options().preserveCardinality(preserveCardinality); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.ParallelMapDataset} + */ + public static class Options { + private Boolean useInterOpParallelism; + + private String deterministic; + + private Boolean preserveCardinality; + + private Options() { + } + + /** + * Sets the useInterOpParallelism option. + * + * @param useInterOpParallelism the useInterOpParallelism option + * @return this Options instance. + */ + public Options useInterOpParallelism(Boolean useInterOpParallelism) { + this.useInterOpParallelism = useInterOpParallelism; + return this; + } + + /** + * Sets the deterministic option. + * + * @param deterministic the deterministic option + * @return this Options instance. + */ + public Options deterministic(String deterministic) { + this.deterministic = deterministic; + return this; + } + + /** + * Sets the preserveCardinality option. + * + * @param preserveCardinality the preserveCardinality option + * @return this Options instance. + */ + public Options preserveCardinality(Boolean preserveCardinality) { + this.preserveCardinality = preserveCardinality; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SaveDataset.java new file mode 100644 index 00000000000..9d2760b0424 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SaveDataset.java @@ -0,0 +1,133 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; + +/** + * The SaveDataset operation + */ +public final class SaveDataset extends RawOp { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "SaveDataset"; + + private SaveDataset(Operation operation) { + super(operation); + } + + /** + * Factory method to create a class wrapping a new SaveDataset operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param path the path value + * @param shardFuncOtherArgs the shardFuncOtherArgs value + * @param shardFunc the value of the shardFunc property + * @param options carries optional attribute values + * @return a new instance of SaveDataset + */ + @Endpoint( + describeByClass = true + ) + public static SaveDataset create(Scope scope, Operand<? extends TType> inputDataset, + Operand<TString> path, Iterable<Operand<?>> shardFuncOtherArgs, ConcreteFunction shardFunc, + Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("SaveDataset", scope.makeOpName("SaveDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInput(path.asOutput()); + opBuilder.addInputList(Operands.asOutputs(shardFuncOtherArgs)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("shard_func", shardFunc); + if (options != null) { + for (Options opts : options) { + if (opts.compression != null) { + opBuilder.setAttr("compression", opts.compression); + } + if (opts.useShardFunc != null) { + opBuilder.setAttr("use_shard_func", opts.useShardFunc); + } + } + } + return new SaveDataset(opBuilder.build()); + } + + /** + * Sets the compression option. + * + * @param compression the compression option + * @return this Options instance. + */ + public static Options compression(String compression) { + return new Options().compression(compression); + } + + /** + * Sets the useShardFunc option. + * + * @param useShardFunc the useShardFunc option + * @return this Options instance. + */ + public static Options useShardFunc(Boolean useShardFunc) { + return new Options().useShardFunc(useShardFunc); + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.SaveDataset} + */ + public static class Options { + private String compression; + + private Boolean useShardFunc; + + private Options() { + } + + /** + * Sets the compression option. + * + * @param compression the compression option + * @return this Options instance. + */ + public Options compression(String compression) { + this.compression = compression; + return this; + } + + /** + * Sets the useShardFunc option. + * + * @param useShardFunc the useShardFunc option + * @return this Options instance. + */ + public Options useShardFunc(Boolean useShardFunc) { + this.useShardFunc = useShardFunc; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ScanDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ScanDataset.java new file mode 100644 index 00000000000..4ed398fb214 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ScanDataset.java @@ -0,0 +1,163 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset successively reduces {@code f} over the elements of {@code input_dataset}. + */ +public final class ScanDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ScanDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private ScanDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new ScanDataset operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param initialState the initialState value + * @param otherArguments the otherArguments value + * @param f the value of the f property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ScanDataset + */ + @Endpoint( + describeByClass = true + ) + public static ScanDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> initialState, Iterable<Operand<?>> otherArguments, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("ScanDataset", scope.makeOpName("ScanDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(initialState)); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("f", f); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + if (options != null) { + for (Options opts : options) { + if (opts.preserveCardinality != null) { + opBuilder.setAttr("preserve_cardinality", opts.preserveCardinality); + } + if (opts.useDefaultDevice != null) { + opBuilder.setAttr("use_default_device", opts.useDefaultDevice); + } + } + } + return new ScanDataset(opBuilder.build()); + } + + /** + * Sets the preserveCardinality option. + * + * @param preserveCardinality the preserveCardinality option + * @return this Options instance. + */ + public static Options preserveCardinality(Boolean preserveCardinality) { + return new Options().preserveCardinality(preserveCardinality); + } + + /** + * Sets the useDefaultDevice option. + * + * @param useDefaultDevice the useDefaultDevice option + * @return this Options instance. + */ + public static Options useDefaultDevice(Boolean useDefaultDevice) { + return new Options().useDefaultDevice(useDefaultDevice); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.ScanDataset} + */ + public static class Options { + private Boolean preserveCardinality; + + private Boolean useDefaultDevice; + + private Options() { + } + + /** + * Sets the preserveCardinality option. + * + * @param preserveCardinality the preserveCardinality option + * @return this Options instance. + */ + public Options preserveCardinality(Boolean preserveCardinality) { + this.preserveCardinality = preserveCardinality; + return this; + } + + /** + * Sets the useDefaultDevice option. + * + * @param useDefaultDevice the useDefaultDevice option + * @return this Options instance. + */ + public Options useDefaultDevice(Boolean useDefaultDevice) { + this.useDefaultDevice = useDefaultDevice; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java new file mode 100644 index 00000000000..426020c2f17 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java @@ -0,0 +1,200 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that will write to / read from a snapshot. + * This dataset attempts to determine whether a valid snapshot exists at the + * {@code snapshot_path}, and reads from the snapshot in lieu of using {@code input_dataset}. + * If not, it will run the preprocessing pipeline as usual, and write out a + * snapshot of the data processed for future use. + */ +public final class SnapshotDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "SnapshotDatasetV2"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private SnapshotDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new SnapshotDatasetV2 operation. + * + * @param scope current scope + * @param inputDataset A variant tensor representing the input dataset. + * @param path The path we should write snapshots to / read snapshots from. + * @param readerFuncOtherArgs the readerFuncOtherArgs value + * @param shardFuncOtherArgs the shardFuncOtherArgs value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param readerFunc Optional. A function to control how to read data from snapshot shards. + * @param shardFunc Optional. A function to control how to shard data when writing a snapshot. + * @param options carries optional attribute values + * @return a new instance of SnapshotDataset + */ + @Endpoint( + describeByClass = true + ) + public static SnapshotDataset create(Scope scope, Operand<? extends TType> inputDataset, + Operand<TString> path, Iterable<Operand<?>> readerFuncOtherArgs, + Iterable<Operand<?>> shardFuncOtherArgs, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes, ConcreteFunction readerFunc, ConcreteFunction shardFunc, + Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("SnapshotDatasetV2", scope.makeOpName("SnapshotDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInput(path.asOutput()); + opBuilder.addInputList(Operands.asOutputs(readerFuncOtherArgs)); + opBuilder.addInputList(Operands.asOutputs(shardFuncOtherArgs)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + opBuilder.setAttr("reader_func", readerFunc); + opBuilder.setAttr("shard_func", shardFunc); + if (options != null) { + for (Options opts : options) { + if (opts.compression != null) { + opBuilder.setAttr("compression", opts.compression); + } + if (opts.readerPrefix != null) { + opBuilder.setAttr("reader_prefix", opts.readerPrefix); + } + if (opts.writerPrefix != null) { + opBuilder.setAttr("writer_prefix", opts.writerPrefix); + } + } + } + return new SnapshotDataset(opBuilder.build()); + } + + /** + * Sets the compression option. + * + * @param compression The type of compression to be applied to the saved snapshot files. + * @return this Options instance. + */ + public static Options compression(String compression) { + return new Options().compression(compression); + } + + /** + * Sets the readerPrefix option. + * + * @param readerPrefix the readerPrefix option + * @return this Options instance. + */ + public static Options readerPrefix(String readerPrefix) { + return new Options().readerPrefix(readerPrefix); + } + + /** + * Sets the writerPrefix option. + * + * @param writerPrefix the writerPrefix option + * @return this Options instance. + */ + public static Options writerPrefix(String writerPrefix) { + return new Options().writerPrefix(writerPrefix); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.SnapshotDataset} + */ + public static class Options { + private String compression; + + private String readerPrefix; + + private String writerPrefix; + + private Options() { + } + + /** + * Sets the compression option. + * + * @param compression The type of compression to be applied to the saved snapshot files. + * @return this Options instance. + */ + public Options compression(String compression) { + this.compression = compression; + return this; + } + + /** + * Sets the readerPrefix option. + * + * @param readerPrefix the readerPrefix option + * @return this Options instance. + */ + public Options readerPrefix(String readerPrefix) { + this.readerPrefix = readerPrefix; + return this; + } + + /** + * Sets the writerPrefix option. + * + * @param writerPrefix the writerPrefix option + * @return this Options instance. + */ + public Options writerPrefix(String writerPrefix) { + this.writerPrefix = writerPrefix; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeWhileDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeWhileDataset.java new file mode 100644 index 00000000000..23bc41932d8 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeWhileDataset.java @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that stops iteration when predicate` is false. + * The {@code predicate} function must return a scalar boolean and accept the + * following arguments: + * <ul> + * <li>One tensor for each component of an element of {@code input_dataset}.</li> + * <li>One tensor for each value in {@code other_arguments}.</li> + * </ul> + */ +public final class TakeWhileDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "TakeWhileDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private TakeWhileDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new TakeWhileDataset operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param otherArguments A list of tensors, typically values that were captured when + * building a closure for {@code predicate}. + * @param predicate A function returning a scalar boolean. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of TakeWhileDataset + */ + @Endpoint( + describeByClass = true + ) + public static TakeWhileDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, ConcreteFunction predicate, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + OperationBuilder opBuilder = scope.env().opBuilder("TakeWhileDataset", scope.makeOpName("TakeWhileDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("predicate", predicate); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + return new TakeWhileDataset(opBuilder.build()); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByReducerDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByReducerDataset.java new file mode 100644 index 00000000000..822d5e03c51 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByReducerDataset.java @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data.experimental; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that computes a group-by on {@code input_dataset}. + * Creates a dataset that computes a group-by on {@code input_dataset}. + */ +public final class GroupByReducerDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ExperimentalGroupByReducerDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private GroupByReducerDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new ExperimentalGroupByReducerDataset operation. + * + * @param scope current scope + * @param inputDataset A variant tensor representing the input dataset. + * @param keyFuncOtherArguments A list of tensors, typically values that were captured when + * building a closure for {@code key_func}. + * @param initFuncOtherArguments A list of tensors, typically values that were captured when + * building a closure for {@code init_func}. + * @param reduceFuncOtherArguments A list of tensors, typically values that were captured when + * building a closure for {@code reduce_func}. + * @param finalizeFuncOtherArguments A list of tensors, typically values that were captured when + * building a closure for {@code finalize_func}. + * @param keyFunc A function mapping an element of {@code input_dataset}, concatenated + * with {@code key_func_other_arguments} to a scalar value of type DT_INT64. + * @param initFunc A function mapping a key of type DT_INT64, concatenated with + * {@code init_func_other_arguments} to the initial reducer state. + * @param reduceFunc A function mapping the current reducer state and an element of {@code input_dataset}, + * concatenated with {@code reduce_func_other_arguments} to a new reducer state. + * @param finalizeFunc A function mapping the final reducer state to an output element. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of GroupByReducerDataset + */ + @Endpoint( + describeByClass = true + ) + public static GroupByReducerDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> keyFuncOtherArguments, Iterable<Operand<?>> initFuncOtherArguments, + Iterable<Operand<?>> reduceFuncOtherArguments, + Iterable<Operand<?>> finalizeFuncOtherArguments, ConcreteFunction keyFunc, + ConcreteFunction initFunc, ConcreteFunction reduceFunc, ConcreteFunction finalizeFunc, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalGroupByReducerDataset", scope.makeOpName("GroupByReducerDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(keyFuncOtherArguments)); + opBuilder.addInputList(Operands.asOutputs(initFuncOtherArguments)); + opBuilder.addInputList(Operands.asOutputs(reduceFuncOtherArguments)); + opBuilder.addInputList(Operands.asOutputs(finalizeFuncOtherArguments)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("key_func", keyFunc); + opBuilder.setAttr("init_func", initFunc); + opBuilder.setAttr("reduce_func", reduceFunc); + opBuilder.setAttr("finalize_func", finalizeFunc); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + return new GroupByReducerDataset(opBuilder.build()); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByWindowDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByWindowDataset.java new file mode 100644 index 00000000000..fe598c19cad --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByWindowDataset.java @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data.experimental; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that computes a windowed group-by on {@code input_dataset}. + * // TODO(mrry): Support non-int64 keys. + */ +public final class GroupByWindowDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ExperimentalGroupByWindowDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private GroupByWindowDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new ExperimentalGroupByWindowDataset operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param keyFuncOtherArguments the keyFuncOtherArguments value + * @param reduceFuncOtherArguments the reduceFuncOtherArguments value + * @param windowSizeFuncOtherArguments the windowSizeFuncOtherArguments value + * @param keyFunc A function mapping an element of {@code input_dataset}, concatenated + * with {@code key_func_other_arguments} to a scalar value of type DT_INT64. + * @param reduceFunc the value of the reduceFunc property + * @param windowSizeFunc the value of the windowSizeFunc property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of GroupByWindowDataset + */ + @Endpoint( + describeByClass = true + ) + public static GroupByWindowDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> keyFuncOtherArguments, Iterable<Operand<?>> reduceFuncOtherArguments, + Iterable<Operand<?>> windowSizeFuncOtherArguments, ConcreteFunction keyFunc, + ConcreteFunction reduceFunc, ConcreteFunction windowSizeFunc, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalGroupByWindowDataset", scope.makeOpName("GroupByWindowDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(keyFuncOtherArguments)); + opBuilder.addInputList(Operands.asOutputs(reduceFuncOtherArguments)); + opBuilder.addInputList(Operands.asOutputs(windowSizeFuncOtherArguments)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("key_func", keyFunc); + opBuilder.setAttr("reduce_func", reduceFunc); + opBuilder.setAttr("window_size_func", windowSizeFunc); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + return new GroupByWindowDataset(opBuilder.build()); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LegacyParallelInterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LegacyParallelInterleaveDataset.java new file mode 100644 index 00000000000..4a33fd1657c --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LegacyParallelInterleaveDataset.java @@ -0,0 +1,155 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data.experimental; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. + * The resulting dataset is similar to the {@code InterleaveDataset}, with the exception + * that if retrieving the next value from a dataset would cause the requester to + * block, it will skip that input dataset. This dataset is especially useful + * when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it + * allows the training step to proceed so long as some data is available. + * <p>!! WARNING !! This dataset is not deterministic! + */ +public final class LegacyParallelInterleaveDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "LegacyParallelInterleaveDatasetV2"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private LegacyParallelInterleaveDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new LegacyParallelInterleaveDatasetV2 operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param otherArguments the otherArguments value + * @param cycleLength the cycleLength value + * @param blockLength the blockLength value + * @param bufferOutputElements the bufferOutputElements value + * @param prefetchInputElements the prefetchInputElements value + * @param f A function mapping elements of {@code input_dataset}, concatenated with + * {@code other_arguments}, to a Dataset variant that contains elements matching + * {@code output_types} and {@code output_shapes}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of LegacyParallelInterleaveDataset + */ + @Endpoint( + describeByClass = true + ) + public static LegacyParallelInterleaveDataset create(Scope scope, + Operand<? extends TType> inputDataset, Iterable<Operand<?>> otherArguments, + Operand<TInt64> cycleLength, Operand<TInt64> blockLength, + Operand<TInt64> bufferOutputElements, Operand<TInt64> prefetchInputElements, + ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("LegacyParallelInterleaveDatasetV2", scope.makeOpName("LegacyParallelInterleaveDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder.addInput(cycleLength.asOutput()); + opBuilder.addInput(blockLength.asOutput()); + opBuilder.addInput(bufferOutputElements.asOutput()); + opBuilder.addInput(prefetchInputElements.asOutput()); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("f", f); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + if (options != null) { + for (Options opts : options) { + if (opts.deterministic != null) { + opBuilder.setAttr("deterministic", opts.deterministic); + } + } + } + return new LegacyParallelInterleaveDataset(opBuilder.build()); + } + + /** + * Sets the deterministic option. + * + * @param deterministic the deterministic option + * @return this Options instance. + */ + public static Options deterministic(String deterministic) { + return new Options().deterministic(deterministic); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.experimental.LegacyParallelInterleaveDataset} + */ + public static class Options { + private String deterministic; + + private Options() { + } + + /** + * Sets the deterministic option. + * + * @param deterministic the deterministic option + * @return this Options instance. + */ + public Options deterministic(String deterministic) { + this.deterministic = deterministic; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapAndBatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapAndBatchDataset.java new file mode 100644 index 00000000000..07f5a18ac79 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapAndBatchDataset.java @@ -0,0 +1,154 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data.experimental; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that fuses mapping with batching. + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset} and then + * batches {@code batch_size} of them. + * <p>Unlike a "MapDataset", which applies {@code f} sequentially, this dataset invokes up + * to {@code batch_size * num_parallel_batches} copies of {@code f} in parallel. + */ +public final class MapAndBatchDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ExperimentalMapAndBatchDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private MapAndBatchDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new ExperimentalMapAndBatchDataset operation. + * + * @param scope current scope + * @param inputDataset A variant tensor representing the input dataset. + * @param otherArguments A list of tensors, typically values that were captured when building a closure + * for {@code f}. + * @param batchSize A scalar representing the number of elements to accumulate in a + * batch. It determines the number of concurrent invocations of {@code f} that process + * elements from {@code input_dataset} in parallel. + * @param numParallelCalls A scalar representing the maximum number of parallel invocations of the {@code map_fn} + * function. Applying the {@code map_fn} on consecutive input elements in parallel has + * the potential to improve input pipeline throughput. + * @param dropRemainder A scalar representing whether the last batch should be dropped in case its size + * is smaller than desired. + * @param f A function to apply to the outputs of {@code input_dataset}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of MapAndBatchDataset + */ + @Endpoint( + describeByClass = true + ) + public static MapAndBatchDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, Operand<TInt64> batchSize, + Operand<TInt64> numParallelCalls, Operand<TBool> dropRemainder, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalMapAndBatchDataset", scope.makeOpName("MapAndBatchDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder.addInput(batchSize.asOutput()); + opBuilder.addInput(numParallelCalls.asOutput()); + opBuilder.addInput(dropRemainder.asOutput()); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("f", f); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + if (options != null) { + for (Options opts : options) { + if (opts.preserveCardinality != null) { + opBuilder.setAttr("preserve_cardinality", opts.preserveCardinality); + } + } + } + return new MapAndBatchDataset(opBuilder.build()); + } + + /** + * Sets the preserveCardinality option. + * + * @param preserveCardinality the preserveCardinality option + * @return this Options instance. + */ + public static Options preserveCardinality(Boolean preserveCardinality) { + return new Options().preserveCardinality(preserveCardinality); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.experimental.MapAndBatchDataset} + */ + public static class Options { + private Boolean preserveCardinality; + + private Options() { + } + + /** + * Sets the preserveCardinality option. + * + * @param preserveCardinality the preserveCardinality option + * @return this Options instance. + */ + public Options preserveCardinality(Boolean preserveCardinality) { + this.preserveCardinality = preserveCardinality; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapDataset.java new file mode 100644 index 00000000000..19667eb06bf --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapDataset.java @@ -0,0 +1,161 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data.experimental; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. + */ +public final class MapDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ExperimentalMapDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private MapDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new ExperimentalMapDataset operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param otherArguments the otherArguments value + * @param f the value of the f property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of MapDataset + */ + @Endpoint( + describeByClass = true + ) + public static MapDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalMapDataset", scope.makeOpName("MapDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("f", f); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + if (options != null) { + for (Options opts : options) { + if (opts.useInterOpParallelism != null) { + opBuilder.setAttr("use_inter_op_parallelism", opts.useInterOpParallelism); + } + if (opts.preserveCardinality != null) { + opBuilder.setAttr("preserve_cardinality", opts.preserveCardinality); + } + } + } + return new MapDataset(opBuilder.build()); + } + + /** + * Sets the useInterOpParallelism option. + * + * @param useInterOpParallelism the useInterOpParallelism option + * @return this Options instance. + */ + public static Options useInterOpParallelism(Boolean useInterOpParallelism) { + return new Options().useInterOpParallelism(useInterOpParallelism); + } + + /** + * Sets the preserveCardinality option. + * + * @param preserveCardinality the preserveCardinality option + * @return this Options instance. + */ + public static Options preserveCardinality(Boolean preserveCardinality) { + return new Options().preserveCardinality(preserveCardinality); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.experimental.MapDataset} + */ + public static class Options { + private Boolean useInterOpParallelism; + + private Boolean preserveCardinality; + + private Options() { + } + + /** + * Sets the useInterOpParallelism option. + * + * @param useInterOpParallelism the useInterOpParallelism option + * @return this Options instance. + */ + public Options useInterOpParallelism(Boolean useInterOpParallelism) { + this.useInterOpParallelism = useInterOpParallelism; + return this; + } + + /** + * Sets the preserveCardinality option. + * + * @param preserveCardinality the preserveCardinality option + * @return this Options instance. + */ + public Options preserveCardinality(Boolean preserveCardinality) { + this.preserveCardinality = preserveCardinality; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParallelInterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParallelInterleaveDataset.java new file mode 100644 index 00000000000..5ec8ad9819a --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParallelInterleaveDataset.java @@ -0,0 +1,177 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data.experimental; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. + * The resulting dataset is similar to the {@code InterleaveDataset}, except that the + * dataset will fetch records from the interleaved datasets in parallel. + * <p>The {@code tf.data} Python API creates instances of this op from + * {@code Dataset.interleave()} when the {@code num_parallel_calls} parameter of that method + * is set to any value other than {@code None}. + * <p>By default, the output of this dataset will be deterministic, which may result + * in the dataset blocking if the next data item to be returned isn't available. + * In order to avoid head-of-line blocking, one can either set the {@code deterministic} + * attribute to "false", or leave it as "default" and set the + * {@code experimental_deterministic} parameter of {@code tf.data.Options} to {@code False}. + * This can improve performance at the expense of non-determinism. + */ +public final class ParallelInterleaveDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ParallelInterleaveDatasetV4"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private ParallelInterleaveDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new ParallelInterleaveDatasetV4 operation. + * + * @param scope current scope + * @param inputDataset Dataset that produces a stream of arguments for the function {@code f}. + * @param otherArguments Additional arguments to pass to {@code f} beyond those produced by {@code input_dataset}. + * Evaluated once when the dataset is instantiated. + * @param cycleLength Number of datasets (each created by applying {@code f} to the elements of + * {@code input_dataset}) among which the {@code ParallelInterleaveDatasetV2} will cycle in a + * round-robin fashion. + * @param blockLength Number of elements at a time to produce from each interleaved invocation of a + * dataset returned by {@code f}. + * @param bufferOutputElements The number of elements each iterator being interleaved should buffer (similar + * to the {@code .prefetch()} transformation for each interleaved iterator). + * @param prefetchInputElements Determines the number of iterators to prefetch, allowing buffers to warm up and + * data to be pre-fetched without blocking the main thread. + * @param numParallelCalls Determines the number of threads that should be used for fetching data from + * input datasets in parallel. The Python API {@code tf.data.experimental.AUTOTUNE} + * constant can be used to indicate that the level of parallelism should be autotuned. + * @param f A function mapping elements of {@code input_dataset}, concatenated with + * {@code other_arguments}, to a Dataset variant that contains elements matching + * {@code output_types} and {@code output_shapes}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ParallelInterleaveDataset + */ + @Endpoint( + describeByClass = true + ) + public static ParallelInterleaveDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, Operand<TInt64> cycleLength, Operand<TInt64> blockLength, + Operand<TInt64> bufferOutputElements, Operand<TInt64> prefetchInputElements, + Operand<TInt64> numParallelCalls, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("ParallelInterleaveDatasetV4", scope.makeOpName("ParallelInterleaveDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder.addInput(cycleLength.asOutput()); + opBuilder.addInput(blockLength.asOutput()); + opBuilder.addInput(bufferOutputElements.asOutput()); + opBuilder.addInput(prefetchInputElements.asOutput()); + opBuilder.addInput(numParallelCalls.asOutput()); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("f", f); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + if (options != null) { + for (Options opts : options) { + if (opts.deterministic != null) { + opBuilder.setAttr("deterministic", opts.deterministic); + } + } + } + return new ParallelInterleaveDataset(opBuilder.build()); + } + + /** + * Sets the deterministic option. + * + * @param deterministic A string indicating the op-level determinism to use. Deterministic controls + * whether the interleave is allowed to return elements out of order if the next + * element to be returned isn't available, but a later element is. Options are + * "true", "false", and "default". "default" indicates that determinism should be + * decided by the {@code experimental_deterministic} parameter of {@code tf.data.Options}. + * @return this Options instance. + */ + public static Options deterministic(String deterministic) { + return new Options().deterministic(deterministic); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.experimental.ParallelInterleaveDataset} + */ + public static class Options { + private String deterministic; + + private Options() { + } + + /** + * Sets the deterministic option. + * + * @param deterministic A string indicating the op-level determinism to use. Deterministic controls + * whether the interleave is allowed to return elements out of order if the next + * element to be returned isn't available, but a later element is. Options are + * "true", "false", and "default". "default" indicates that determinism should be + * decided by the {@code experimental_deterministic} parameter of {@code tf.data.Options}. + * @return this Options instance. + */ + public Options deterministic(String deterministic) { + this.deterministic = deterministic; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ScanDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ScanDataset.java new file mode 100644 index 00000000000..bbe5d3e07af --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ScanDataset.java @@ -0,0 +1,137 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data.experimental; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset successively reduces {@code f} over the elements of {@code input_dataset}. + */ +public final class ScanDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ExperimentalScanDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private ScanDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new ExperimentalScanDataset operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param initialState the initialState value + * @param otherArguments the otherArguments value + * @param f the value of the f property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ScanDataset + */ + @Endpoint( + describeByClass = true + ) + public static ScanDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> initialState, Iterable<Operand<?>> otherArguments, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalScanDataset", scope.makeOpName("ScanDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(initialState)); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("f", f); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + if (options != null) { + for (Options opts : options) { + if (opts.preserveCardinality != null) { + opBuilder.setAttr("preserve_cardinality", opts.preserveCardinality); + } + } + } + return new ScanDataset(opBuilder.build()); + } + + /** + * Sets the preserveCardinality option. + * + * @param preserveCardinality the preserveCardinality option + * @return this Options instance. + */ + public static Options preserveCardinality(Boolean preserveCardinality) { + return new Options().preserveCardinality(preserveCardinality); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.experimental.ScanDataset} + */ + public static class Options { + private Boolean preserveCardinality; + + private Options() { + } + + /** + * Sets the preserveCardinality option. + * + * @param preserveCardinality the preserveCardinality option + * @return this Options instance. + */ + public Options preserveCardinality(Boolean preserveCardinality) { + this.preserveCardinality = preserveCardinality; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/TakeWhileDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/TakeWhileDataset.java new file mode 100644 index 00000000000..a03264f0f4c --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/TakeWhileDataset.java @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data.experimental; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that stops iteration when predicate` is false. + * The {@code predicate} function must return a scalar boolean and accept the + * following arguments: + * <ul> + * <li>One tensor for each component of an element of {@code input_dataset}.</li> + * <li>One tensor for each value in {@code other_arguments}.</li> + * </ul> + */ +public final class TakeWhileDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ExperimentalTakeWhileDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private TakeWhileDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new ExperimentalTakeWhileDataset operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param otherArguments A list of tensors, typically values that were captured when + * building a closure for {@code predicate}. + * @param predicate A function returning a scalar boolean. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of TakeWhileDataset + */ + @Endpoint( + describeByClass = true + ) + public static TakeWhileDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, ConcreteFunction predicate, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalTakeWhileDataset", scope.makeOpName("TakeWhileDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("predicate", predicate); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + return new TakeWhileDataset(opBuilder.build()); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/Compile.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/Compile.java new file mode 100644 index 00000000000..6e01d8f37e8 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/Compile.java @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.tpu; + +import java.util.Arrays; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; + +/** + * Compiles a computations for execution on one or more TPU devices. + * For the internal use of the distributed TPU compiler. + * <p>'num_computations' is the number of computations to be compiled. + * 'function' is a function containing the computation to compile. + * 'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not + * known statically at TPUReplication rewrite time. + * 'guaranteed_constants' is a list of tensors which have been guaranteed to not + * change their values during the session lifetime. These contain tensors marked as + * constant using the GuaranteeConstOp. + * 'metadata' is a serialized TPUCompileMetadataProto describing + * the shapes and types of the inputs to the computation, as well as a mapping onto + * the TPU pod topology. + * Each 'program' output is a string key that is passed to the _TPUExecute op and + * used to look up the program in the compilation cache. + * 'may_modify_variables' indicates whether variables may be modified. + */ +@Operator( + group = "tpu" +) +public final class Compile extends RawOp { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "TPUCompile"; + + private Output<TString> compilationStatus; + + private List<Output<TString>> program; + + private List<Output<TBool>> mayModifyVariables; + + @SuppressWarnings("unchecked") + private Compile(Operation operation) { + super(operation); + int outputIdx = 0; + compilationStatus = operation.output(outputIdx++); + int programLength = operation.outputListLength("program"); + program = Arrays.asList((Output<TString>[]) operation.outputList(outputIdx, programLength)); + outputIdx += programLength; + int mayModifyVariablesLength = operation.outputListLength("may_modify_variables"); + mayModifyVariables = Arrays.asList((Output<TBool>[]) operation.outputList(outputIdx, mayModifyVariablesLength)); + outputIdx += mayModifyVariablesLength; + } + + /** + * Factory method to create a class wrapping a new TPUCompile operation. + * + * @param scope current scope + * @param dynamicShapes the dynamicShapes value + * @param guaranteedConstants the guaranteedConstants value + * @param numComputations the value of the numComputations property + * @param function the value of the function property + * @param metadata the value of the metadata property + * @return a new instance of Compile + */ + @Endpoint( + describeByClass = true + ) + public static Compile create(Scope scope, Iterable<Operand<TInt64>> dynamicShapes, + Iterable<Operand<?>> guaranteedConstants, Long numComputations, ConcreteFunction function, + String metadata) { + OperationBuilder opBuilder = scope.env().opBuilder("TPUCompile", scope.makeOpName("Compile")); + opBuilder.addInputList(Operands.asOutputs(dynamicShapes)); + opBuilder.addInputList(Operands.asOutputs(guaranteedConstants)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("num_computations", numComputations); + opBuilder.setAttr("function", function); + opBuilder.setAttr("metadata", metadata); + return new Compile(opBuilder.build()); + } + + /** + * Gets compilationStatus. + * + * @return compilationStatus. + */ + public Output<TString> compilationStatus() { + return compilationStatus; + } + + /** + * Gets program. + * + * @return program. + */ + public List<Output<TString>> program() { + return program; + } + + /** + * Gets mayModifyVariables. + * + * @return mayModifyVariables. + */ + public List<Output<TBool>> mayModifyVariables() { + return mayModifyVariables; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/PartitionedCall.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/PartitionedCall.java new file mode 100644 index 00000000000..90a536c99e2 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/PartitionedCall.java @@ -0,0 +1,133 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.tpu; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; + +/** + * Calls a function placed on a specified TPU device. + */ +public final class PartitionedCall extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "TPUPartitionedCall"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private PartitionedCall(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new TPUPartitionedCall operation. + * + * @param scope current scope + * @param args The arguments to the function. + * @param deviceOrdinal The TPU device ordinal to run the function on. + * @param Tout The types of the outputs of the function. + * @param f The function to call. + * @param options carries optional attribute values + * @return a new instance of PartitionedCall + */ + @Endpoint( + describeByClass = true + ) + public static PartitionedCall create(Scope scope, Iterable<Operand<?>> args, + Operand<TInt32> deviceOrdinal, List<Class<? extends TType>> Tout, ConcreteFunction f, + Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("TPUPartitionedCall", scope.makeOpName("PartitionedCall")); + opBuilder.addInputList(Operands.asOutputs(args)); + opBuilder.addInput(deviceOrdinal.asOutput()); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); + opBuilder.setAttr("f", f); + if (options != null) { + for (Options opts : options) { + if (opts.autotunerThresh != null) { + opBuilder.setAttr("autotuner_thresh", opts.autotunerThresh); + } + } + } + return new PartitionedCall(opBuilder.build()); + } + + /** + * Sets the autotunerThresh option. + * + * @param autotunerThresh the autotunerThresh option + * @return this Options instance. + */ + public static Options autotunerThresh(Long autotunerThresh) { + return new Options().autotunerThresh(autotunerThresh); + } + + /** + * Gets output. + * The output of the function call. + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } + + /** + * Optional attributes for {@link org.tensorflow.op.tpu.PartitionedCall} + */ + public static class Options { + private Long autotunerThresh; + + private Options() { + } + + /** + * Sets the autotunerThresh option. + * + * @param autotunerThresh the autotunerThresh option + * @return this Options instance. + */ + public Options autotunerThresh(Long autotunerThresh) { + this.autotunerThresh = autotunerThresh; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SymbolicGradient.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SymbolicGradient.java new file mode 100644 index 00000000000..be3fccf6f90 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SymbolicGradient.java @@ -0,0 +1,107 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.train; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * Computes the gradient function for function f via backpropagation. + */ +@Operator( + group = "train" +) +public final class SymbolicGradient extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "SymbolicGradient"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private SymbolicGradient(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new SymbolicGradient operation. + * + * @param scope current scope + * @param input a list of input tensors of size N + M; + * @param Tout the type list for the input list. + * @param f The function we want to compute the gradient for. + * <p>The function 'f' must be a numerical function which takes N inputs and + * produces M outputs. Its gradient function 'g', which is computed by + * this SymbolicGradient op is a function taking N + M inputs and + * produces N outputs. + * <p>I.e. if we have + * (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), + * then, g is + * (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, + * dL/dy1, dL/dy2, ..., dL/dy_M), + * <p>where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the + * loss function). dL/dx_i is the partial derivative of L with respect + * to x_i. + * <p>(Needs some math expert to say the comment above better.) + * @return a new instance of SymbolicGradient + */ + @Endpoint( + describeByClass = true + ) + public static SymbolicGradient create(Scope scope, Iterable<Operand<?>> input, + List<Class<? extends TType>> Tout, ConcreteFunction f) { + OperationBuilder opBuilder = scope.env().opBuilder("SymbolicGradient", scope.makeOpName("SymbolicGradient")); + opBuilder.addInputList(Operands.asOutputs(input)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); + opBuilder.setAttr("f", f); + return new SymbolicGradient(opBuilder.build()); + } + + /** + * Gets output. + * a list of output tensors of size N; + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/If.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/If.java new file mode 100644 index 00000000000..a5d13be15ac --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/If.java @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * output = cond ? then_branch(inputs) : else_branch(inputs). + */ +@Operator( + group = "xla" +) +public final class If extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaIf"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private If(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new XlaIf operation. + * + * @param scope current scope + * @param cond A boolean scalar. + * @param inputs A list of input tensors. + * @param thenBranch A function takes 'inputs' and returns a list of tensors, + * whose types are the same as what else_branch returns. + * @param elseBranch A function takes 'inputs' and returns a list of tensors. + * whose types are the same as what then_branch returns. + * @param Tout the value of the Tout property + * @return a new instance of If + */ + @Endpoint( + describeByClass = true + ) + public static If create(Scope scope, Operand<? extends TType> cond, Iterable<Operand<?>> inputs, + ConcreteFunction thenBranch, ConcreteFunction elseBranch, List<Class<? extends TType>> Tout) { + OperationBuilder opBuilder = scope.env().opBuilder("XlaIf", scope.makeOpName("If")); + opBuilder.addInput(cond.asOutput()); + opBuilder.addInputList(Operands.asOutputs(inputs)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("then_branch", thenBranch); + opBuilder.setAttr("else_branch", elseBranch); + opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); + return new If(opBuilder.build()); + } + + /** + * Gets output. + * A list of tensors returned by either then_branch(inputs) or + * else_branch(inputs). The input shapes of the then_branch and + * else_branch must match. + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Reduce.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Reduce.java new file mode 100644 index 00000000000..a0f70a32564 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Reduce.java @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * Wraps the XLA Reduce operator, documented at + * https://www.tensorflow.org/performance/xla/operation_semantics#reduce . + * + * @param <T> data type for {@code output} output + */ +@Operator( + group = "xla" +) +public final class Reduce<T extends TType> extends RawOp implements Operand<T> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaReduce"; + + private Output<T> output; + + private Reduce(Operation operation) { + super(operation); + int outputIdx = 0; + output = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaReduce operation. + * + * @param scope current scope + * @param input the input tensor + * @param initValue a scalar representing the initial value for the reduction + * @param dimensionsToReduce dimension numbers over which to reduce + * @param reducer a reducer function to apply + * @param <T> data type for {@code XlaReduce} output and operands + * @return a new instance of Reduce + */ + @Endpoint( + describeByClass = true + ) + public static <T extends TType> Reduce<T> create(Scope scope, Operand<T> input, + Operand<T> initValue, List<Long> dimensionsToReduce, ConcreteFunction reducer) { + OperationBuilder opBuilder = scope.env().opBuilder("XlaReduce", scope.makeOpName("Reduce")); + opBuilder.addInput(input.asOutput()); + opBuilder.addInput(initValue.asOutput()); + opBuilder = scope.apply(opBuilder); + long[] dimensionsToReduceArray = new long[dimensionsToReduce.size()]; + for (int i = 0 ; i < dimensionsToReduceArray.length ; i++) { + dimensionsToReduceArray[i] = dimensionsToReduce.get(i); + } + opBuilder.setAttr("dimensions_to_reduce", dimensionsToReduceArray); + opBuilder.setAttr("reducer", reducer); + return new Reduce<>(opBuilder.build()); + } + + /** + * Gets output. + * + * @return output. + */ + public Output<T> output() { + return output; + } + + @Override + public Output<T> asOutput() { + return output; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/ReduceWindow.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/ReduceWindow.java new file mode 100644 index 00000000000..40d94d31a06 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/ReduceWindow.java @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +/** + * Wraps the XLA ReduceWindow operator, documented at + * https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . + * + * @param <T> data type for {@code output} output + */ +@Operator( + group = "xla" +) +public final class ReduceWindow<T extends TType> extends RawOp implements Operand<T> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaReduceWindow"; + + private Output<T> output; + + private ReduceWindow(Operation operation) { + super(operation); + int outputIdx = 0; + output = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaReduceWindow operation. + * + * @param scope current scope + * @param input the input tensor + * @param initValue a scalar representing the initial value for the reduction + * @param windowDimensions the shape of the window + * @param windowStrides the inter-window strides + * @param baseDilations the baseDilations value + * @param windowDilations the windowDilations value + * @param padding the padding to apply at the start and end of each input dimensions + * @param computation a reducer function to apply + * @param <T> data type for {@code XlaReduceWindow} output and operands + * @param <U> data type for {@code XlaReduceWindow} output and operands + * @return a new instance of ReduceWindow + */ + @Endpoint( + describeByClass = true + ) + public static <T extends TType, U extends TNumber> ReduceWindow<T> create(Scope scope, + Operand<T> input, Operand<T> initValue, Operand<U> windowDimensions, Operand<U> windowStrides, + Operand<U> baseDilations, Operand<U> windowDilations, Operand<U> padding, + ConcreteFunction computation) { + OperationBuilder opBuilder = scope.env().opBuilder("XlaReduceWindow", scope.makeOpName("ReduceWindow")); + opBuilder.addInput(input.asOutput()); + opBuilder.addInput(initValue.asOutput()); + opBuilder.addInput(windowDimensions.asOutput()); + opBuilder.addInput(windowStrides.asOutput()); + opBuilder.addInput(baseDilations.asOutput()); + opBuilder.addInput(windowDilations.asOutput()); + opBuilder.addInput(padding.asOutput()); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("computation", computation); + return new ReduceWindow<>(opBuilder.build()); + } + + /** + * Gets output. + * + * @return output. + */ + public Output<T> output() { + return output; + } + + @Override + public Output<T> asOutput() { + return output; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Scatter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Scatter.java new file mode 100644 index 00000000000..7148f19f805 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Scatter.java @@ -0,0 +1,100 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +/** + * Wraps the XLA Scatter operator documented at + * https://www.tensorflow.org/xla/operation_semantics#scatter. + * + * @param <T> data type for {@code output} output + */ +@Operator( + group = "xla" +) +public final class Scatter<T extends TType> extends RawOp implements Operand<T> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaScatter"; + + private Output<T> output; + + private Scatter(Operation operation) { + super(operation); + int outputIdx = 0; + output = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaScatter operation. + * + * @param scope current scope + * @param operand Array to be scattered into. + * @param scatterIndices Array containing the starting indices of the slices that must + * be scattered to. + * @param updates Array containing the values that must be used for scattering. + * @param updateComputation Computation to be used for combining the existing values in + * the input array and the updates during scatter. + * @param dimensionNumbers A serialized xla::ScatterDimensionNumbers proto. + * @param indicesAreSorted Boolean indicating if the indices are sorted. + * @param <T> data type for {@code XlaScatter} output and operands + * @return a new instance of Scatter + */ + @Endpoint( + describeByClass = true + ) + public static <T extends TType> Scatter<T> create(Scope scope, Operand<T> operand, + Operand<? extends TNumber> scatterIndices, Operand<T> updates, + ConcreteFunction updateComputation, String dimensionNumbers, Boolean indicesAreSorted) { + OperationBuilder opBuilder = scope.env().opBuilder("XlaScatter", scope.makeOpName("Scatter")); + opBuilder.addInput(operand.asOutput()); + opBuilder.addInput(scatterIndices.asOutput()); + opBuilder.addInput(updates.asOutput()); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("update_computation", updateComputation); + opBuilder.setAttr("dimension_numbers", dimensionNumbers); + opBuilder.setAttr("indices_are_sorted", indicesAreSorted); + return new Scatter<>(opBuilder.build()); + } + + /** + * Gets output. + * + * @return output. + */ + public Output<T> output() { + return output; + } + + @Override + public Output<T> asOutput() { + return output; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SelectAndScatter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SelectAndScatter.java new file mode 100644 index 00000000000..6912fd70677 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SelectAndScatter.java @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +/** + * Wraps the XLA SelectAndScatter operator, documented at + * https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter + * . + * + * @param <T> data type for {@code output} output + */ +@Operator( + group = "xla" +) +public final class SelectAndScatter<T extends TType> extends RawOp implements Operand<T> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaSelectAndScatter"; + + private Output<T> output; + + private SelectAndScatter(Operation operation) { + super(operation); + int outputIdx = 0; + output = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaSelectAndScatter operation. + * + * @param scope current scope + * @param operand the input tensor + * @param windowDimensions the shape of the window + * @param windowStrides the inter-window strides + * @param padding the padding to apply at the start and end of each input dimensions + * @param source a tensor of values to scatter + * @param initValue a scalar representing the initial value for the output tensor + * @param select a selection function to apply + * @param scatter a scatter function to apply + * @param <T> data type for {@code XlaSelectAndScatter} output and operands + * @param <U> data type for {@code XlaSelectAndScatter} output and operands + * @return a new instance of SelectAndScatter + */ + @Endpoint( + describeByClass = true + ) + public static <T extends TType, U extends TNumber> SelectAndScatter<T> create(Scope scope, + Operand<T> operand, Operand<U> windowDimensions, Operand<U> windowStrides, Operand<U> padding, + Operand<T> source, Operand<T> initValue, ConcreteFunction select, ConcreteFunction scatter) { + OperationBuilder opBuilder = scope.env().opBuilder("XlaSelectAndScatter", scope.makeOpName("SelectAndScatter")); + opBuilder.addInput(operand.asOutput()); + opBuilder.addInput(windowDimensions.asOutput()); + opBuilder.addInput(windowStrides.asOutput()); + opBuilder.addInput(padding.asOutput()); + opBuilder.addInput(source.asOutput()); + opBuilder.addInput(initValue.asOutput()); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("select", select); + opBuilder.setAttr("scatter", scatter); + return new SelectAndScatter<>(opBuilder.build()); + } + + /** + * Gets output. + * + * @return output. + */ + public Output<T> output() { + return output; + } + + @Override + public Output<T> asOutput() { + return output; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/While.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/While.java new file mode 100644 index 00000000000..e1a69eb1c6d --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/While.java @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * output = input; While (Cond(output)) { output = Body(output) } + */ +@Operator( + group = "xla" +) +public final class While extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaWhile"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private While(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new XlaWhile operation. + * + * @param scope current scope + * @param input A list of input tensors whose types are T. + * @param cond A function takes 'input' and returns a tensor. If the tensor is + * a scalar of non-boolean, the scalar is converted to a boolean + * according to the following rule: if the scalar is a numerical + * value, non-zero means True and zero means False; if the scalar is + * a string, non-empty means True and empty means False. If the + * tensor is not a scalar, non-emptiness means True and False + * otherwise. + * @param body A function that takes a list of tensors and returns another + * list of tensors. Both lists have the same types as specified by T. + * @return a new instance of While + */ + @Endpoint( + describeByClass = true + ) + public static While create(Scope scope, Iterable<Operand<?>> input, ConcreteFunction cond, + ConcreteFunction body) { + OperationBuilder opBuilder = scope.env().opBuilder("XlaWhile", scope.makeOpName("While")); + opBuilder.addInputList(Operands.asOutputs(input)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("cond", cond); + opBuilder.setAttr("body", body); + return new While(opBuilder.build()); + } + + /** + * Gets output. + * A list of output tensors whose types are T. + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaHostCompute.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaHostCompute.java new file mode 100644 index 00000000000..c45377d32e1 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaHostCompute.java @@ -0,0 +1,177 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * A pseudo-op to represent host-side computation in an XLA program. + */ +@Operator( + group = "xla" +) +public final class XlaHostCompute extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaHostCompute"; + + private List<Output<?>> outputs; + + @SuppressWarnings("unchecked") + private XlaHostCompute(Operation operation) { + super(operation); + int outputIdx = 0; + int outputsLength = operation.outputListLength("outputs"); + outputs = Arrays.asList(operation.outputList(outputIdx, outputsLength)); + outputIdx += outputsLength; + } + + /** + * Factory method to create a class wrapping a new XlaHostCompute operation. + * + * @param scope current scope + * @param inputs A list of tensors that will be sent to the host. + * @param Toutputs The element types of each element in {@code outputs}. + * @param ancestors A list of names of HostCompute computations that must be + * sequenced before this computation. + * @param shapes If shape_inference_graph is empty, a list of the shapes of {@code outputs}. + * @param shapeInferenceGraph If non-empty, a serialized GraphDef representing a graph + * that must be analyzed at compile time to determine the shapes of the outputs. + * @param key A unique identifier for this region used to match up host transfers. + * @param options carries optional attribute values + * @return a new instance of XlaHostCompute + */ + @Endpoint( + describeByClass = true + ) + public static XlaHostCompute create(Scope scope, Iterable<Operand<?>> inputs, + List<Class<? extends TType>> Toutputs, List<String> ancestors, List<Shape> shapes, + ConcreteFunction shapeInferenceGraph, String key, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("XlaHostCompute", scope.makeOpName("XlaHostCompute")); + opBuilder.addInputList(Operands.asOutputs(inputs)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("Toutputs", Operands.toDataTypes(Toutputs)); + String[] ancestorsArray = new String[ancestors.size()]; + for (int i = 0 ; i < ancestorsArray.length ; i++) { + ancestorsArray[i] = ancestors.get(i); + } + opBuilder.setAttr("ancestors", ancestorsArray); + Shape[] shapesArray = new Shape[shapes.size()]; + for (int i = 0 ; i < shapesArray.length ; i++) { + shapesArray[i] = shapes.get(i); + } + opBuilder.setAttr("shapes", shapesArray); + opBuilder.setAttr("shape_inference_graph", shapeInferenceGraph); + opBuilder.setAttr("key", key); + if (options != null) { + for (Options opts : options) { + if (opts.costEstimateNs != null) { + opBuilder.setAttr("cost_estimate_ns", opts.costEstimateNs); + } + if (opts.tpuCore != null) { + opBuilder.setAttr("tpu_core", opts.tpuCore); + } + } + } + return new XlaHostCompute(opBuilder.build()); + } + + /** + * Sets the costEstimateNs option. + * + * @param costEstimateNs Estimated duration of the host computation in nanoseconds. + * @return this Options instance. + */ + public static Options costEstimateNs(Long costEstimateNs) { + return new Options().costEstimateNs(costEstimateNs); + } + + /** + * Sets the tpuCore option. + * + * @param tpuCore Default core to use for host to device transfers. + * @return this Options instance. + */ + public static Options tpuCore(Long tpuCore) { + return new Options().tpuCore(tpuCore); + } + + /** + * Gets outputs. + * A list of tensors that will be returned to the device. + * @return outputs. + */ + public List<Output<?>> outputs() { + return outputs; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) outputs.iterator(); + } + + /** + * Optional attributes for {@link org.tensorflow.op.xla.XlaHostCompute} + */ + public static class Options { + private Long costEstimateNs; + + private Long tpuCore; + + private Options() { + } + + /** + * Sets the costEstimateNs option. + * + * @param costEstimateNs Estimated duration of the host computation in nanoseconds. + * @return this Options instance. + */ + public Options costEstimateNs(Long costEstimateNs) { + this.costEstimateNs = costEstimateNs; + return this; + } + + /** + * Sets the tpuCore option. + * + * @param tpuCore Default core to use for host to device transfers. + * @return this Options instance. + */ + public Options tpuCore(Long tpuCore) { + this.tpuCore = tpuCore; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaLaunch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaLaunch.java new file mode 100644 index 00000000000..da119954460 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaLaunch.java @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * XLA Launch Op. For use by the XLA JIT only. + */ +@Operator( + group = "xla" +) +public final class XlaLaunch extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaLaunch"; + + private List<Output<?>> results; + + @SuppressWarnings("unchecked") + private XlaLaunch(Operation operation) { + super(operation); + int outputIdx = 0; + int resultsLength = operation.outputListLength("results"); + results = Arrays.asList(operation.outputList(outputIdx, resultsLength)); + outputIdx += resultsLength; + } + + /** + * Factory method to create a class wrapping a new XlaLaunch operation. + * + * @param scope current scope + * @param constants the constants value + * @param args the args value + * @param resources the resources value + * @param Tresults the value of the Tresults property + * @param function the value of the function property + * @return a new instance of XlaLaunch + */ + @Endpoint( + describeByClass = true + ) + public static XlaLaunch create(Scope scope, Iterable<Operand<?>> constants, + Iterable<Operand<?>> args, Iterable<Operand<? extends TType>> resources, + List<Class<? extends TType>> Tresults, ConcreteFunction function) { + OperationBuilder opBuilder = scope.env().opBuilder("XlaLaunch", scope.makeOpName("XlaLaunch")); + opBuilder.addInputList(Operands.asOutputs(constants)); + opBuilder.addInputList(Operands.asOutputs(args)); + opBuilder.addInputList(Operands.asOutputs(resources)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("Tresults", Operands.toDataTypes(Tresults)); + opBuilder.setAttr("function", function); + return new XlaLaunch(opBuilder.build()); + } + + /** + * Gets results. + * + * @return results. + */ + public List<Output<?>> results() { + return results; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) results.iterator(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicReduce.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicReduce.java new file mode 100644 index 00000000000..6ce2ac4ff3e --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicReduce.java @@ -0,0 +1,105 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * Wraps the variadic XLA Reduce operator, documented at + * https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce. + * + * @param <T> data type for {@code output} output + */ +@Operator( + group = "xla" +) +public final class XlaVariadicReduce<T extends TType> extends RawOp implements Iterable<Operand<T>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaVariadicReduce"; + + private List<Output<T>> output; + + @SuppressWarnings("unchecked") + private XlaVariadicReduce(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList((Output<T>[]) operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new XlaVariadicReduce operation. + * + * @param scope current scope + * @param input the input tensor(s) + * @param initValue scalar initial value(s) for the reduction + * @param dimensionsToReduce dimension numbers over which to reduce + * @param reducer a reducer function to apply + * @param <T> data type for {@code XlaVariadicReduce} output and operands + * @return a new instance of XlaVariadicReduce + */ + @Endpoint( + describeByClass = true + ) + public static <T extends TType> XlaVariadicReduce<T> create(Scope scope, + Iterable<Operand<T>> input, Iterable<Operand<T>> initValue, List<Long> dimensionsToReduce, + ConcreteFunction reducer) { + OperationBuilder opBuilder = scope.env().opBuilder("XlaVariadicReduce", scope.makeOpName("XlaVariadicReduce")); + opBuilder.addInputList(Operands.asOutputs(input)); + opBuilder.addInputList(Operands.asOutputs(initValue)); + opBuilder = scope.apply(opBuilder); + long[] dimensionsToReduceArray = new long[dimensionsToReduce.size()]; + for (int i = 0 ; i < dimensionsToReduceArray.length ; i++) { + dimensionsToReduceArray[i] = dimensionsToReduce.get(i); + } + opBuilder.setAttr("dimensions_to_reduce", dimensionsToReduceArray); + opBuilder.setAttr("reducer", reducer); + return new XlaVariadicReduce<>(opBuilder.build()); + } + + /** + * Gets output. + * + * @return output. + */ + public List<Output<T>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<T>> iterator() { + return (Iterator) output.iterator(); + } +} diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java index be6d0a32392..e5aab27bb69 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java @@ -57,7 +57,6 @@ final class ClassGenerator { /** Return true if we can generate the operation class for {@code op}. */ static boolean canGenerateOp(OpDef op, ApiDef apiDef) { return apiDef.getVisibility() != Visibility.SKIP - && !op.getAttrList().stream().anyMatch(x -> x.getType().contains("func")) && !op.getName() .startsWith("_"); // TODO do I want this? Some interesting ops like _XlaCompile } From 2d03c4771786c159aee6e21bc24bb29b235b4c7f Mon Sep 17 00:00:00 2001 From: Ryan Nett <JNett96@gmail.com> Date: Mon, 31 May 2021 16:45:22 -0700 Subject: [PATCH 02/14] Apply format profile Signed-off-by: Ryan Nett <JNett96@gmail.com> --- pom.xml | 39 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index f9b94e8164a..ee498135d2a 100644 --- a/pom.xml +++ b/pom.xml @@ -1,4 +1,6 @@ -<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns="http://maven.apache.org/POM/4.0.0" +<project + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" + xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"> <modelVersion>4.0.0</modelVersion> @@ -86,7 +88,9 @@ </snapshotRepository> <repository> <id>ossrh</id> - <url>https://oss.sonatype.org/service/local/staging/deployByRepositoryId/${stagingRepositoryId}/</url> + <url> + https://oss.sonatype.org/service/local/staging/deployByRepositoryId/${stagingRepositoryId}/ + </url> </repository> </distributionManagement> @@ -156,7 +160,9 @@ <repository> <id>ossrh-staging</id> <name>OSSRH Sonatype Staging</name> - <url>https://oss.sonatype.org/service/local/staging/deployByRepositoryId/${stagingRepositoryId}/</url> + <url> + https://oss.sonatype.org/service/local/staging/deployByRepositoryId/${stagingRepositoryId}/ + </url> <releases> <enabled>true</enabled> </releases> @@ -257,6 +263,33 @@ </plugins> </build> </profile> + + <!-- + Profile to run spotless:apply on builds. Will run before format's check. + --> + <profile> + <id>apply-format</id> + <build> + <plugins> + <plugin> + <groupId>com.diffplug.spotless</groupId> + <artifactId>spotless-maven-plugin</artifactId> + <version>${spotless.version}</version> + + <executions> + <execution> + <!-- Runs in initialize phase to fail fast in case of formatting issues (should be before codegen).--> + <id>spotless-check</id> + <phase>initialize</phase> + <goals> + <goal>apply</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> + </profile> </profiles> <!-- http://central.sonatype.org/pages/requirements.html#developer-information --> From ca4ee0103ef8418fd370629921fedc589f283971 Mon Sep 17 00:00:00 2001 From: Ryan Nett <JNett96@gmail.com> Date: Mon, 31 May 2021 16:46:12 -0700 Subject: [PATCH 03/14] Fix names Signed-off-by: Ryan Nett <JNett96@gmail.com> --- .../org/tensorflow/op/DataOps.java | 125 +++++++ .../annotations/org/tensorflow/op/Ops.java | 317 ++++++++++++++++++ .../annotations/org/tensorflow/op/TpuOps.java | 33 ++ .../org/tensorflow/op/TrainOps.java | 28 ++ .../annotations/org/tensorflow/op/XlaOps.java | 187 +++++++++++ .../gen/java/org/tensorflow/op/core/Case.java | 3 +- .../gen/java/org/tensorflow/op/core/For.java | 3 +- .../gen/java/org/tensorflow/op/core/If.java | 3 +- .../java/org/tensorflow/op/core/While.java | 3 +- .../gen/java/org/tensorflow/op/xla/If.java | 3 +- .../gen/java/org/tensorflow/op/xla/While.java | 3 +- .../op/generator/ClassGenerator.java | 38 ++- .../tensorflow/op/generator/FullOpDef.java | 119 +++++++ .../op/generator/GeneratorUtils.java | 56 ++-- .../tensorflow/op/generator/OpGenerator.java | 165 +++++---- .../tensorflow/op/generator/StatefulPair.java | 27 ++ 16 files changed, 982 insertions(+), 131 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java create mode 100644 tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java index 4197dac5fee..edf3d88f8ed 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java @@ -18,6 +18,7 @@ package org.tensorflow.op; import java.util.List; +import org.tensorflow.ConcreteFunction; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.data.AnonymousIterator; @@ -25,12 +26,17 @@ import org.tensorflow.op.data.ConcatenateDataset; import org.tensorflow.op.data.DeleteIterator; import org.tensorflow.op.data.DeserializeIterator; +import org.tensorflow.op.data.FilterDataset; +import org.tensorflow.op.data.FlatMapDataset; +import org.tensorflow.op.data.InterleaveDataset; import org.tensorflow.op.data.Iterator; import org.tensorflow.op.data.IteratorGetNext; import org.tensorflow.op.data.IteratorGetNextAsOptional; import org.tensorflow.op.data.IteratorGetNextSync; import org.tensorflow.op.data.IteratorToStringHandle; import org.tensorflow.op.data.MakeIterator; +import org.tensorflow.op.data.MapDataset; +import org.tensorflow.op.data.OneShotIterator; import org.tensorflow.op.data.OptionalFromValue; import org.tensorflow.op.data.OptionalGetValue; import org.tensorflow.op.data.OptionalHasValue; @@ -134,6 +140,75 @@ public DeserializeIterator deserializeIterator(Operand<? extends TType> resource return DeserializeIterator.create(scope, resourceHandle, serialized); } + /** + * Creates a dataset containing elements of {@code input_dataset} matching {@code predicate}. + * The {@code predicate} function must return a scalar boolean and accept the + * following arguments: + * <ul> + * <li>One tensor for each component of an element of {@code input_dataset}.</li> + * <li>One tensor for each value in {@code other_arguments}.</li> + * </ul> + * + * @param inputDataset the inputDataset value + * @param otherArguments A list of tensors, typically values that were captured when + * building a closure for {@code predicate}. + * @param predicate A function returning a scalar boolean. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of FilterDataset + */ + public FilterDataset filterDataset(Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, ConcreteFunction predicate, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return FilterDataset.create(scope, inputDataset, otherArguments, predicate, outputTypes, outputShapes); + } + + /** + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. + * Unlike MapDataset, the {@code f} in FlatMapDataset is expected to return a + * Dataset variant, and FlatMapDataset will flatten successive results + * into a single Dataset. + * + * @param inputDataset the inputDataset value + * @param otherArguments the otherArguments value + * @param f A function mapping elements of {@code input_dataset}, concatenated with + * {@code other_arguments}, to a Dataset variant that contains elements matching + * {@code output_types} and {@code output_shapes}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of FlatMapDataset + */ + public FlatMapDataset flatMapDataset(Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return FlatMapDataset.create(scope, inputDataset, otherArguments, f, outputTypes, outputShapes); + } + + /** + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. + * Unlike MapDataset, the {@code f} in InterleaveDataset is expected to return + * a Dataset variant, and InterleaveDataset will flatten successive + * results into a single Dataset. Unlike FlatMapDataset, + * InterleaveDataset will interleave sequences of up to {@code block_length} + * consecutive elements from {@code cycle_length} input elements. + * + * @param inputDataset the inputDataset value + * @param otherArguments the otherArguments value + * @param cycleLength the cycleLength value + * @param blockLength the blockLength value + * @param f A function mapping elements of {@code input_dataset}, concatenated with + * {@code other_arguments}, to a Dataset variant that contains elements matching + * {@code output_types} and {@code output_shapes}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of InterleaveDataset + */ + public InterleaveDataset interleaveDataset(Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, Operand<TInt64> cycleLength, Operand<TInt64> blockLength, + ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return InterleaveDataset.create(scope, inputDataset, otherArguments, cycleLength, blockLength, f, outputTypes, outputShapes); + } + /** * The IteratorV2 operation * @@ -215,6 +290,56 @@ public MakeIterator makeIterator(Operand<? extends TType> dataset, return MakeIterator.create(scope, dataset, iterator); } + /** + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. + * + * @param inputDataset the inputDataset value + * @param otherArguments the otherArguments value + * @param f the value of the f property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of MapDataset + */ + public MapDataset mapDataset(Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + MapDataset.Options... options) { + return MapDataset.create(scope, inputDataset, otherArguments, f, outputTypes, outputShapes, options); + } + + /** + * Makes a "one-shot" iterator that can be iterated only once. + * A one-shot iterator bundles the logic for defining the dataset and + * the state of the iterator in a single op, which allows simple input + * pipelines to be defined without an additional initialization + * ("MakeIterator") step. + * <p>One-shot iterators have the following limitations: + * <ul> + * <li>They do not support parameterization: all logic for creating the underlying + * dataset must be bundled in the {@code dataset_factory} function.</li> + * <li>They are not resettable. Once a one-shot iterator reaches the end of its + * underlying dataset, subsequent "IteratorGetNext" operations on that + * iterator will always produce an {@code OutOfRange} error.</li> + * </ul> + * <p>For greater flexibility, use "Iterator" and "MakeIterator" to define + * an iterator using an arbitrary subgraph, which may capture tensors + * (including fed values) as parameters, and which may be reset multiple + * times by rerunning "MakeIterator". + * + * @param datasetFactory A function of type {@code () -> DT_VARIANT}, where the returned + * DT_VARIANT is a dataset. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of OneShotIterator + */ + public OneShotIterator oneShotIterator(ConcreteFunction datasetFactory, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + OneShotIterator.Options... options) { + return OneShotIterator.create(scope, datasetFactory, outputTypes, outputShapes, options); + } + /** * Constructs an Optional variant from a tuple of tensors. * diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index a4a7f5d6dbc..d6402581da5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -58,6 +58,7 @@ import org.tensorflow.op.core.BarrierReadySize; import org.tensorflow.op.core.BarrierTakeMany; import org.tensorflow.op.core.Batch; +import org.tensorflow.op.core.BatchFunction; import org.tensorflow.op.core.BatchToSpace; import org.tensorflow.op.core.BatchToSpaceNd; import org.tensorflow.op.core.Bitcast; @@ -66,6 +67,7 @@ import org.tensorflow.op.core.BroadcastDynamicShape; import org.tensorflow.op.core.BroadcastTo; import org.tensorflow.op.core.Bucketize; +import org.tensorflow.op.core.Case; import org.tensorflow.op.core.ClipByValue; import org.tensorflow.op.core.Concat; import org.tensorflow.op.core.Constant; @@ -89,6 +91,7 @@ import org.tensorflow.op.core.ExtractVolumePatches; import org.tensorflow.op.core.Fill; import org.tensorflow.op.core.Fingerprint; +import org.tensorflow.op.core.For; import org.tensorflow.op.core.Function; import org.tensorflow.op.core.Gather; import org.tensorflow.op.core.GatherNd; @@ -101,6 +104,7 @@ import org.tensorflow.op.core.HistogramFixedWidth; import org.tensorflow.op.core.Identity; import org.tensorflow.op.core.IdentityN; +import org.tensorflow.op.core.If; import org.tensorflow.op.core.ImmutableConst; import org.tensorflow.op.core.Init; import org.tensorflow.op.core.InitializeTable; @@ -149,6 +153,7 @@ import org.tensorflow.op.core.Pad; import org.tensorflow.op.core.ParallelConcat; import org.tensorflow.op.core.ParallelDynamicStitch; +import org.tensorflow.op.core.PartitionedCall; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.PlaceholderWithDefault; import org.tensorflow.op.core.Print; @@ -166,6 +171,8 @@ import org.tensorflow.op.core.RefNextIteration; import org.tensorflow.op.core.RefSelect; import org.tensorflow.op.core.RefSwitch; +import org.tensorflow.op.core.RemoteCall; +import org.tensorflow.op.core.RemoteFusedGraphExecute; import org.tensorflow.op.core.Reshape; import org.tensorflow.op.core.ResourceCountUpTo; import org.tensorflow.op.core.ResourceGather; @@ -215,6 +222,9 @@ import org.tensorflow.op.core.StageClear; import org.tensorflow.op.core.StagePeek; import org.tensorflow.op.core.StageSize; +import org.tensorflow.op.core.StatefulPartitionedCall; +import org.tensorflow.op.core.StatelessIf; +import org.tensorflow.op.core.StatelessWhile; import org.tensorflow.op.core.StopGradient; import org.tensorflow.op.core.StridedSlice; import org.tensorflow.op.core.StridedSliceAssign; @@ -281,6 +291,7 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.op.core.VariableShape; import org.tensorflow.op.core.Where; +import org.tensorflow.op.core.While; import org.tensorflow.op.core.XlaConvV2; import org.tensorflow.op.core.XlaDotV2; import org.tensorflow.op.core.XlaSetDynamicDimensionSize; @@ -776,6 +787,61 @@ public Batch batch(Iterable<Operand<?>> inTensors, Long numBatchThreads, Long ma return Batch.create(scope, inTensors, numBatchThreads, maxBatchSize, batchTimeoutMicros, gradTimeoutMicros, options); } + /** + * Batches all the inputs tensors to the computation done by the function. + * So, for example, in the following code + * <pre> + * + * # This input will be captured. + * y = tf.placeholder_with_default(1.0, shape=[]) + * + * {@literal @}tf.Defun(tf.float32) + * def computation(a): + * return tf.matmul(a, a) + y + * + * b = gen_batch_ops.batch_function( + * f=computation + * in_tensors=[a], + * captured_tensors=computation.captured_inputs, + * Tout=[o.type for o in computation.definition.signature.output_arg], + * num_batch_threads=1, + * max_batch_size=10, + * batch_timeout_micros=100000, # 100ms + * allowed_batch_sizes=[3, 10], + * batching_queue="") + * </pre> + * <p>If more than one session.run call is simultaneously trying to compute {@code b} + * the values of {@code a} will be gathered, non-deterministically concatenated + * along the first axis, and only one thread will run the computation. + * <p>Assumes that all arguments of the function are Tensors which will be batched + * along their first dimension. + * <p>Arguments that are captured, are not batched. The session.run call which does + * the concatenation, will use the values of the captured tensors available to it. + * Therefore, typical uses of captured tensors should involve values which remain + * unchanged across session.run calls. Inference is a good example of this. + * <p>SparseTensor is not supported. The return value of the decorated function + * must be a Tensor or a list/tuple of Tensors. + * + * @param inTensors The tensors to be batched. + * @param capturedTensors The tensors which are captured in the function, and don't need + * to be batched. + * @param f the value of the f property + * @param numBatchThreads Number of scheduling threads for processing batches of work. + * Determines the number of batches processed in parallel. + * @param maxBatchSize Batch sizes will never be bigger than this. + * @param batchTimeoutMicros Maximum number of microseconds to wait before outputting + * an incomplete batch. + * @param Tout the types of the output tensors. + * @param options carries optional attribute values + * @return a new instance of BatchFunction + */ + public BatchFunction batchFunction(Iterable<Operand<?>> inTensors, + Iterable<Operand<?>> capturedTensors, ConcreteFunction f, Long numBatchThreads, + Long maxBatchSize, Long batchTimeoutMicros, List<Class<? extends TType>> Tout, + BatchFunction.Options... options) { + return BatchFunction.create(scope, inTensors, capturedTensors, f, numBatchThreads, maxBatchSize, batchTimeoutMicros, Tout, options); + } + /** * BatchToSpace for 4-D tensors of type T. * This is a legacy version of the more general BatchToSpaceND. @@ -1144,6 +1210,42 @@ public Map<String, Operand<?>> call(ConcreteFunction function, return Function.call(scope, function, arguments); } + /** + * An n-way switch statement which calls a single branch function. + * <pre> + * An n-way switch statement, implementing the following: + * ``` + * switch (branch_index) { + * case 0: + * output = branches[0](input); + * break; + * case 1: + * output = branches[1](input); + * break; + * ... + * case [[nbranches-1]]: + * default: + * output = branches[nbranches-1](input); + * break; + * } + * ``` + * </pre> + * + * @param branchIndex The branch selector, an int32 Tensor. + * @param input A list of input tensors passed to the branch function. + * @param Tout A list of output types. + * @param branches <pre> + * A list of functions each of which takes 'inputs' and returns a list of + * tensors, whose types are the same as what every other branch returns. + * </pre> + * @param options carries optional attribute values + * @return a new instance of Case + */ + public Case caseOp(Operand<TInt32> branchIndex, Iterable<Operand<?>> input, + List<Class<? extends TType>> Tout, List<ConcreteFunction> branches, Case.Options... options) { + return Case.create(scope, branchIndex, input, Tout, branches, options); + } + /** * Clips tensor values to a specified min and max. * Given a tensor {@code t}, this operation returns a tensor of the same type and @@ -2475,6 +2577,28 @@ public Fingerprint fingerprint(Operand<? extends TType> data, Operand<TString> m return Fingerprint.create(scope, data, method); } + /** + * <pre> + * output = input; + * for i in range(start, limit, delta) + * output = body(i, output); + * </pre> + * + * @param start The lower bound. An int32 + * @param limit The upper bound. An int32 + * @param delta The increment. An int32 + * @param input A list of input tensors whose types are T. + * @param body <pre> + * A function that takes a list of tensors (int32, T) and returns another + * list of tensors (T). + * </pre> + * @return a new instance of For + */ + public For forOp(Operand<TInt32> start, Operand<TInt32> limit, Operand<TInt32> delta, + Iterable<Operand<?>> input, ConcreteFunction body) { + return For.create(scope, start, limit, delta, input, body); + } + /** * Gather slices from {@code params} axis {@code axis} according to {@code indices}. * {@code indices} must be an integer tensor of any dimension (usually 0-D or 1-D). @@ -2820,6 +2944,36 @@ public IdentityN identityN(Iterable<Operand<?>> input) { return IdentityN.create(scope, input); } + /** + * output = cond ? then_branch(input) : else_branch(input) + * + * @param cond <pre> + * A Tensor. If the tensor is a scalar of non-boolean type, the + * scalar is converted to a boolean according to the + * following rule: if the scalar is a numerical value, non-zero means + * `True` and zero means False; if the scalar is a string, non-empty + * means `True` and empty means `False`. If the tensor is not a scalar, + * being empty means False and being non-empty means True. + * </pre> + * @param input A list of input tensors. + * @param Tout A list of output types. + * @param thenBranch <pre> + * A function that takes 'inputs' and returns a list of tensors, whose + * types are the same as what else_branch returns. + * </pre> + * @param elseBranch <pre> + * A function that takes 'inputs' and returns a list of tensors, whose + * types are the same as what then_branch returns. + * </pre> + * @param options carries optional attribute values + * @return a new instance of If + */ + public If ifOp(Operand<? extends TType> cond, Iterable<Operand<?>> input, + List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, ConcreteFunction elseBranch, + If.Options... options) { + return If.create(scope, cond, input, Tout, thenBranch, elseBranch, options); + } + /** * Returns immutable tensor from memory region. * The current implementation memmaps the tensor from a file. @@ -3859,6 +4013,25 @@ public <T extends TType> ParallelDynamicStitch<T> parallelDynamicStitch( return ParallelDynamicStitch.create(scope, indices, data); } + /** + * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned. + * + * @param args A list of input tensors. + * @param Tout A list of output types. + * @param f <pre> + * A function that takes 'args', a list of tensors, and returns 'output', + * another list of tensors. Input and output types are specified by 'Tin' + * and 'Tout'. The function body of f will be placed and partitioned across + * devices, setting this op apart from the regular Call op. + * </pre> + * @param options carries optional attribute values + * @return a new instance of PartitionedCall + */ + public PartitionedCall partitionedCall(Iterable<Operand<?>> args, + List<Class<? extends TType>> Tout, ConcreteFunction f, PartitionedCall.Options... options) { + return PartitionedCall.create(scope, args, Tout, f, options); + } + /** * A placeholder op for a value that will be fed into the computation. * N.B. This operation will fail with an error if it is executed. It is @@ -4158,6 +4331,41 @@ public <T extends TType> RefSwitch<T> refSwitch(Operand<T> data, Operand<TBool> return RefSwitch.create(scope, data, pred); } + /** + * Runs function {@code f} on a remote device indicated by {@code target}. + * + * @param target A fully specified device name where we want to run the function. + * @param args A list of arguments for the function. + * @param Tout The type list for the return values. + * @param f The function to run remotely. + * @return a new instance of RemoteCall + */ + public RemoteCall remoteCall(Operand<TString> target, Iterable<Operand<?>> args, + List<Class<? extends TType>> Tout, ConcreteFunction f) { + return RemoteCall.create(scope, target, args, Tout, f); + } + + /** + * Execute a sub graph on a remote processor. + * The graph specifications(such as graph itself, input tensors and output names) + * are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo + * as serialized_remote_fused_graph_execute_info. + * The specifications will be passed to a dedicated registered + * remote fused graph executor. The executor will send the graph specifications + * to a remote processor and execute that graph. The execution results + * will be passed to consumer nodes as outputs of this node. + * + * @param inputs Arbitrary number of tensors with arbitrary data types + * @param Toutputs the value of the Toutputs property + * @param serializedRemoteFusedGraphExecuteInfo Serialized protocol buffer + * of RemoteFusedGraphExecuteInfo which contains graph specifications. + * @return a new instance of RemoteFusedGraphExecute + */ + public RemoteFusedGraphExecute remoteFusedGraphExecute(Iterable<Operand<?>> inputs, + List<Class<? extends TType>> Toutputs, String serializedRemoteFusedGraphExecuteInfo) { + return RemoteFusedGraphExecute.create(scope, inputs, Toutputs, serializedRemoteFusedGraphExecuteInfo); + } + /** * Reshapes a tensor. * Given {@code tensor}, this operation returns a tensor that has the same values @@ -5840,6 +6048,89 @@ public StageSize stageSize(List<Class<? extends TType>> dtypes, StageSize.Option return StageSize.create(scope, dtypes, options); } + /** + * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned. + * + * @param args A list of input tensors. + * @param Tout A list of output types. + * @param f <pre> + * A function that takes 'args', a list of tensors, and returns 'output', + * another list of tensors. Input and output types are specified by 'Tin' + * and 'Tout'. The function body of f will be placed and partitioned across + * devices, setting this op apart from the regular Call op. This op is + * stateful. + * </pre> + * @param options carries optional attribute values + * @return a new instance of StatefulPartitionedCall + */ + public StatefulPartitionedCall statefulPartitionedCall(Iterable<Operand<?>> args, + List<Class<? extends TType>> Tout, ConcreteFunction f, + StatefulPartitionedCall.Options... options) { + return StatefulPartitionedCall.create(scope, args, Tout, f, options); + } + + /** + * output = cond ? then_branch(input) : else_branch(input) + * + * @param cond <pre> + * A Tensor. If the tensor is a scalar of non-boolean type, the + * scalar is converted to a boolean according to the + * following rule: if the scalar is a numerical value, non-zero means + * `True` and zero means False; if the scalar is a string, non-empty + * means `True` and empty means `False`. If the tensor is not a scalar, + * being empty means False and being non-empty means True. + * + * This should only be used when the if then/else body functions do not + * have stateful ops. + * </pre> + * @param input A list of input tensors. + * @param Tout A list of output types. + * @param thenBranch <pre> + * A function that takes 'inputs' and returns a list of tensors, whose + * types are the same as what else_branch returns. + * </pre> + * @param elseBranch <pre> + * A function that takes 'inputs' and returns a list of tensors, whose + * types are the same as what then_branch returns. + * </pre> + * @param options carries optional attribute values + * @return a new instance of StatelessIf + */ + public StatelessIf statelessIf(Operand<? extends TType> cond, Iterable<Operand<?>> input, + List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, ConcreteFunction elseBranch, + StatelessIf.Options... options) { + return StatelessIf.create(scope, cond, input, Tout, thenBranch, elseBranch, options); + } + + /** + * output = input; While (Cond(output)) { output = Body(output) } + * + * @param input A list of input tensors whose types are T. + * @param cond <pre> + * A function takes 'input' and returns a tensor. If the tensor is + * a scalar of non-boolean, the scalar is converted to a boolean + * according to the following rule: if the scalar is a numerical + * value, non-zero means True and zero means False; if the scalar is + * a string, non-empty means True and empty means False. If the + * tensor is not a scalar, non-emptiness means True and False + * otherwise. + * + * This should only be used when the while condition and body functions + * do not have stateful ops. + * </pre> + * @param body <pre> + * A function that takes a list of tensors and returns another + * list of tensors. Both lists have the same types as specified + * by T. + * </pre> + * @param options carries optional attribute values + * @return a new instance of StatelessWhile + */ + public StatelessWhile statelessWhile(Iterable<Operand<?>> input, ConcreteFunction cond, + ConcreteFunction body, StatelessWhile.Options... options) { + return StatelessWhile.create(scope, input, cond, body, options); + } + /** * Stops gradient computation. * When executed in a graph, this op outputs its input tensor as-is. @@ -7697,6 +7988,32 @@ public Where where(Operand<? extends TType> condition) { return Where.create(scope, condition); } + /** + * output = input; While (Cond(output)) { output = Body(output) } + * + * @param input A list of input tensors whose types are T. + * @param cond <pre> + * A function takes 'input' and returns a tensor. If the tensor is + * a scalar of non-boolean, the scalar is converted to a boolean + * according to the following rule: if the scalar is a numerical + * value, non-zero means True and zero means False; if the scalar is + * a string, non-empty means True and empty means False. If the + * tensor is not a scalar, non-emptiness means True and False + * otherwise. + * </pre> + * @param body <pre> + * A function that takes a list of tensors and returns another + * list of tensors. Both lists have the same types as specified + * by T. + * </pre> + * @param options carries optional attribute values + * @return a new instance of While + */ + public While whileOp(Iterable<Operand<?>> input, ConcreteFunction cond, ConcreteFunction body, + While.Options... options) { + return While.create(scope, input, cond, body, options); + } + /** * Wraps the XLA ConvGeneralDilated operator, documented at * https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TpuOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TpuOps.java index 1278494020f..6cd4a872851 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TpuOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TpuOps.java @@ -18,12 +18,15 @@ package org.tensorflow.op; import java.util.List; +import org.tensorflow.ConcreteFunction; import org.tensorflow.Operand; +import org.tensorflow.op.tpu.Compile; import org.tensorflow.op.tpu.CompileSucceededAssert; import org.tensorflow.op.tpu.Execute; import org.tensorflow.op.tpu.ExecuteAndUpdateVariables; import org.tensorflow.op.tpu.PartitionedInput; import org.tensorflow.op.tpu.PartitionedOutput; +import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; @@ -42,6 +45,36 @@ public final class TpuOps { this.ops = ops; } + /** + * Compiles a computations for execution on one or more TPU devices. + * For the internal use of the distributed TPU compiler. + * <p>'num_computations' is the number of computations to be compiled. + * 'function' is a function containing the computation to compile. + * 'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not + * known statically at TPUReplication rewrite time. + * 'guaranteed_constants' is a list of tensors which have been guaranteed to not + * change their values during the session lifetime. These contain tensors marked as + * constant using the GuaranteeConstOp. + * 'metadata' is a serialized TPUCompileMetadataProto describing + * the shapes and types of the inputs to the computation, as well as a mapping onto + * the TPU pod topology. + * Each 'program' output is a string key that is passed to the _TPUExecute op and + * used to look up the program in the compilation cache. + * 'may_modify_variables' indicates whether variables may be modified. + * + * @param dynamicShapes the dynamicShapes value + * @param guaranteedConstants the guaranteedConstants value + * @param numComputations the value of the numComputations property + * @param function the value of the function property + * @param metadata the value of the metadata property + * @return a new instance of Compile + */ + public Compile compile(Iterable<Operand<TInt64>> dynamicShapes, + Iterable<Operand<?>> guaranteedConstants, Long numComputations, ConcreteFunction function, + String metadata) { + return Compile.create(scope, dynamicShapes, guaranteedConstants, numComputations, function, metadata); + } + /** * Asserts that compilation succeeded. This op produces no output and closes the * device during failure to ensure all pending device interactions fail. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TrainOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TrainOps.java index 305b973c139..2661165e101 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TrainOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TrainOps.java @@ -18,6 +18,7 @@ package org.tensorflow.op; import java.util.List; +import org.tensorflow.ConcreteFunction; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.train.AccumulatorApplyGradient; @@ -81,6 +82,7 @@ import org.tensorflow.op.train.SparseApplyProximalAdagrad; import org.tensorflow.op.train.SparseApplyProximalGradientDescent; import org.tensorflow.op.train.SparseApplyRmsProp; +import org.tensorflow.op.train.SymbolicGradient; import org.tensorflow.op.train.TileGrad; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -1610,6 +1612,32 @@ public <T extends TType> SparseApplyRmsProp<T> sparseApplyRmsProp(Operand<T> var return SparseApplyRmsProp.create(scope, var, ms, mom, lr, rho, momentum, epsilon, grad, indices, options); } + /** + * Computes the gradient function for function f via backpropagation. + * + * @param input a list of input tensors of size N + M; + * @param Tout the type list for the input list. + * @param f The function we want to compute the gradient for. + * <p>The function 'f' must be a numerical function which takes N inputs and + * produces M outputs. Its gradient function 'g', which is computed by + * this SymbolicGradient op is a function taking N + M inputs and + * produces N outputs. + * <p>I.e. if we have + * (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), + * then, g is + * (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, + * dL/dy1, dL/dy2, ..., dL/dy_M), + * <p>where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the + * loss function). dL/dx_i is the partial derivative of L with respect + * to x_i. + * <p>(Needs some math expert to say the comment above better.) + * @return a new instance of SymbolicGradient + */ + public SymbolicGradient symbolicGradient(Iterable<Operand<?>> input, + List<Class<? extends TType>> Tout, ConcreteFunction f) { + return SymbolicGradient.create(scope, input, Tout, f); + } + /** * Returns the gradient of {@code Tile}. * Since {@code Tile} takes an input and repeats the input {@code multiples} times diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java index 99caae1fdc2..ba43401e0fe 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java @@ -17,6 +17,8 @@ // package org.tensorflow.op; +import java.util.List; +import org.tensorflow.ConcreteFunction; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.xla.BroadcastHelper; @@ -28,18 +30,27 @@ import org.tensorflow.op.xla.DynamicUpdateSlice; import org.tensorflow.op.xla.Einsum; import org.tensorflow.op.xla.Gather; +import org.tensorflow.op.xla.If; import org.tensorflow.op.xla.KeyValueSort; import org.tensorflow.op.xla.Pad; import org.tensorflow.op.xla.Recv; +import org.tensorflow.op.xla.Reduce; +import org.tensorflow.op.xla.ReduceWindow; import org.tensorflow.op.xla.ReplicaId; +import org.tensorflow.op.xla.Scatter; +import org.tensorflow.op.xla.SelectAndScatter; import org.tensorflow.op.xla.SelfAdjointEig; import org.tensorflow.op.xla.Send; import org.tensorflow.op.xla.Sharding; import org.tensorflow.op.xla.Sort; import org.tensorflow.op.xla.Svd; +import org.tensorflow.op.xla.While; +import org.tensorflow.op.xla.XlaHostCompute; +import org.tensorflow.op.xla.XlaLaunch; import org.tensorflow.op.xla.XlaRecvFromHost; import org.tensorflow.op.xla.XlaSendToHost; import org.tensorflow.op.xla.XlaSetBound; +import org.tensorflow.op.xla.XlaVariadicReduce; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; @@ -234,6 +245,23 @@ public <T extends TType, U extends TNumber> Gather<T> gather(Operand<T> operand, return Gather.create(scope, operand, startIndices, sliceSizes, dimensionNumbers, indicesAreSorted); } + /** + * output = cond ? then_branch(inputs) : else_branch(inputs). + * + * @param cond A boolean scalar. + * @param inputs A list of input tensors. + * @param thenBranch A function takes 'inputs' and returns a list of tensors, + * whose types are the same as what else_branch returns. + * @param elseBranch A function takes 'inputs' and returns a list of tensors. + * whose types are the same as what then_branch returns. + * @param Tout the value of the Tout property + * @return a new instance of If + */ + public If ifOp(Operand<? extends TType> cond, Iterable<Operand<?>> inputs, + ConcreteFunction thenBranch, ConcreteFunction elseBranch, List<Class<? extends TType>> Tout) { + return If.create(scope, cond, inputs, thenBranch, elseBranch, Tout); + } + /** * Wraps the XLA Sort operator, documented at * https://www.tensorflow.org/performance/xla/operation_semantics#sort @@ -293,6 +321,47 @@ public <T extends TType> Recv<T> recv(Class<T> dtype, String tensorName, Shape s return Recv.create(scope, dtype, tensorName, shape); } + /** + * Wraps the XLA Reduce operator, documented at + * https://www.tensorflow.org/performance/xla/operation_semantics#reduce . + * + * @param <T> data type for {@code output} output + * @param input the input tensor + * @param initValue a scalar representing the initial value for the reduction + * @param dimensionsToReduce dimension numbers over which to reduce + * @param reducer a reducer function to apply + * @param <T> data type for {@code XlaReduce} output and operands + * @return a new instance of Reduce + */ + public <T extends TType> Reduce<T> reduce(Operand<T> input, Operand<T> initValue, + List<Long> dimensionsToReduce, ConcreteFunction reducer) { + return Reduce.create(scope, input, initValue, dimensionsToReduce, reducer); + } + + /** + * Wraps the XLA ReduceWindow operator, documented at + * https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . + * + * @param <T> data type for {@code output} output + * @param input the input tensor + * @param initValue a scalar representing the initial value for the reduction + * @param windowDimensions the shape of the window + * @param windowStrides the inter-window strides + * @param baseDilations the baseDilations value + * @param windowDilations the windowDilations value + * @param padding the padding to apply at the start and end of each input dimensions + * @param computation a reducer function to apply + * @param <T> data type for {@code XlaReduceWindow} output and operands + * @param <U> data type for {@code XlaReduceWindow} output and operands + * @return a new instance of ReduceWindow + */ + public <T extends TType, U extends TNumber> ReduceWindow<T> reduceWindow(Operand<T> input, + Operand<T> initValue, Operand<U> windowDimensions, Operand<U> windowStrides, + Operand<U> baseDilations, Operand<U> windowDilations, Operand<U> padding, + ConcreteFunction computation) { + return ReduceWindow.create(scope, input, initValue, windowDimensions, windowStrides, baseDilations, windowDilations, padding, computation); + } + /** * Replica ID. * @@ -302,6 +371,52 @@ public ReplicaId replicaId() { return ReplicaId.create(scope); } + /** + * Wraps the XLA Scatter operator documented at + * https://www.tensorflow.org/xla/operation_semantics#scatter. + * + * @param <T> data type for {@code output} output + * @param operand Array to be scattered into. + * @param scatterIndices Array containing the starting indices of the slices that must + * be scattered to. + * @param updates Array containing the values that must be used for scattering. + * @param updateComputation Computation to be used for combining the existing values in + * the input array and the updates during scatter. + * @param dimensionNumbers A serialized xla::ScatterDimensionNumbers proto. + * @param indicesAreSorted Boolean indicating if the indices are sorted. + * @param <T> data type for {@code XlaScatter} output and operands + * @return a new instance of Scatter + */ + public <T extends TType> Scatter<T> scatter(Operand<T> operand, + Operand<? extends TNumber> scatterIndices, Operand<T> updates, + ConcreteFunction updateComputation, String dimensionNumbers, Boolean indicesAreSorted) { + return Scatter.create(scope, operand, scatterIndices, updates, updateComputation, dimensionNumbers, indicesAreSorted); + } + + /** + * Wraps the XLA SelectAndScatter operator, documented at + * https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter + * . + * + * @param <T> data type for {@code output} output + * @param operand the input tensor + * @param windowDimensions the shape of the window + * @param windowStrides the inter-window strides + * @param padding the padding to apply at the start and end of each input dimensions + * @param source a tensor of values to scatter + * @param initValue a scalar representing the initial value for the output tensor + * @param select a selection function to apply + * @param scatter a scatter function to apply + * @param <T> data type for {@code XlaSelectAndScatter} output and operands + * @param <U> data type for {@code XlaSelectAndScatter} output and operands + * @return a new instance of SelectAndScatter + */ + public <T extends TType, U extends TNumber> SelectAndScatter<T> selectAndScatter( + Operand<T> operand, Operand<U> windowDimensions, Operand<U> windowStrides, Operand<U> padding, + Operand<T> source, Operand<T> initValue, ConcreteFunction select, ConcreteFunction scatter) { + return SelectAndScatter.create(scope, operand, windowDimensions, windowStrides, padding, source, initValue, select, scatter); + } + /** * Computes the eigen decomposition of a batch of self-adjoint matrices * (Note: Only real inputs are supported). @@ -389,6 +504,61 @@ public <T extends TType> Svd<T> svd(Operand<T> a, Long maxIter, Float epsilon, return Svd.create(scope, a, maxIter, epsilon, precisionConfig); } + /** + * output = input; While (Cond(output)) { output = Body(output) } + * + * @param input A list of input tensors whose types are T. + * @param cond A function takes 'input' and returns a tensor. If the tensor is + * a scalar of non-boolean, the scalar is converted to a boolean + * according to the following rule: if the scalar is a numerical + * value, non-zero means True and zero means False; if the scalar is + * a string, non-empty means True and empty means False. If the + * tensor is not a scalar, non-emptiness means True and False + * otherwise. + * @param body A function that takes a list of tensors and returns another + * list of tensors. Both lists have the same types as specified by T. + * @return a new instance of While + */ + public While whileOp(Iterable<Operand<?>> input, ConcreteFunction cond, ConcreteFunction body) { + return While.create(scope, input, cond, body); + } + + /** + * A pseudo-op to represent host-side computation in an XLA program. + * + * @param inputs A list of tensors that will be sent to the host. + * @param Toutputs The element types of each element in {@code outputs}. + * @param ancestors A list of names of HostCompute computations that must be + * sequenced before this computation. + * @param shapes If shape_inference_graph is empty, a list of the shapes of {@code outputs}. + * @param shapeInferenceGraph If non-empty, a serialized GraphDef representing a graph + * that must be analyzed at compile time to determine the shapes of the outputs. + * @param key A unique identifier for this region used to match up host transfers. + * @param options carries optional attribute values + * @return a new instance of XlaHostCompute + */ + public XlaHostCompute xlaHostCompute(Iterable<Operand<?>> inputs, + List<Class<? extends TType>> Toutputs, List<String> ancestors, List<Shape> shapes, + ConcreteFunction shapeInferenceGraph, String key, XlaHostCompute.Options... options) { + return XlaHostCompute.create(scope, inputs, Toutputs, ancestors, shapes, shapeInferenceGraph, key, options); + } + + /** + * XLA Launch Op. For use by the XLA JIT only. + * + * @param constants the constants value + * @param args the args value + * @param resources the resources value + * @param Tresults the value of the Tresults property + * @param function the value of the function property + * @return a new instance of XlaLaunch + */ + public XlaLaunch xlaLaunch(Iterable<Operand<?>> constants, Iterable<Operand<?>> args, + Iterable<Operand<? extends TType>> resources, List<Class<? extends TType>> Tresults, + ConcreteFunction function) { + return XlaLaunch.create(scope, constants, args, resources, Tresults, function); + } + /** * An op to receive a tensor from the host. * output: the tensor that will be received from the host. @@ -436,6 +606,23 @@ public XlaSetBound xlaSetBound(Operand<TInt32> input, Operand<TInt32> bound) { return XlaSetBound.create(scope, input, bound); } + /** + * Wraps the variadic XLA Reduce operator, documented at + * https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce. + * + * @param <T> data type for {@code output} output + * @param input the input tensor(s) + * @param initValue scalar initial value(s) for the reduction + * @param dimensionsToReduce dimension numbers over which to reduce + * @param reducer a reducer function to apply + * @param <T> data type for {@code XlaVariadicReduce} output and operands + * @return a new instance of XlaVariadicReduce + */ + public <T extends TType> XlaVariadicReduce<T> xlaVariadicReduce(Iterable<Operand<T>> input, + Iterable<Operand<T>> initValue, List<Long> dimensionsToReduce, ConcreteFunction reducer) { + return XlaVariadicReduce.create(scope, input, initValue, dimensionsToReduce, reducer); + } + /** * Get the parent {@link Ops} object. */ diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java index 8ea144ece5b..73cd64b11bd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java @@ -88,7 +88,8 @@ private Case(Operation operation) { * @return a new instance of Case */ @Endpoint( - describeByClass = true + describeByClass = true, + name = "caseOp" ) public static Case create(Scope scope, Operand<TInt32> branchIndex, Iterable<Operand<?>> input, List<Class<? extends TType>> Tout, List<ConcreteFunction> branches, Options... options) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/For.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/For.java index 4ce4fc1da35..7d714de0879 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/For.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/For.java @@ -73,7 +73,8 @@ private For(Operation operation) { * @return a new instance of For */ @Endpoint( - describeByClass = true + describeByClass = true, + name = "forOp" ) public static For create(Scope scope, Operand<TInt32> start, Operand<TInt32> limit, Operand<TInt32> delta, Iterable<Operand<?>> input, ConcreteFunction body) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java index c178111e27b..7e1acb6691b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java @@ -80,7 +80,8 @@ private If(Operation operation) { * @return a new instance of If */ @Endpoint( - describeByClass = true + describeByClass = true, + name = "ifOp" ) public static If create(Scope scope, Operand<? extends TType> cond, Iterable<Operand<?>> input, List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, ConcreteFunction elseBranch, diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java index ac3b4e7a791..1f5a0a273f4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java @@ -77,7 +77,8 @@ private While(Operation operation) { * @return a new instance of While */ @Endpoint( - describeByClass = true + describeByClass = true, + name = "whileOp" ) public static While create(Scope scope, Iterable<Operand<?>> input, ConcreteFunction cond, ConcreteFunction body, Options... options) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/If.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/If.java index a5d13be15ac..0f35595d405 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/If.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/If.java @@ -69,7 +69,8 @@ private If(Operation operation) { * @return a new instance of If */ @Endpoint( - describeByClass = true + describeByClass = true, + name = "ifOp" ) public static If create(Scope scope, Operand<? extends TType> cond, Iterable<Operand<?>> inputs, ConcreteFunction thenBranch, ConcreteFunction elseBranch, List<Class<? extends TType>> Tout) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/While.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/While.java index e1a69eb1c6d..985aabed588 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/While.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/While.java @@ -72,7 +72,8 @@ private While(Operation operation) { * @return a new instance of While */ @Endpoint( - describeByClass = true + describeByClass = true, + name = "whileOp" ) public static While create(Scope scope, Iterable<Operand<?>> input, ConcreteFunction cond, ConcreteFunction body) { diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java index e5aab27bb69..7327afab81c 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -======================================================================= -*/ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow.op.generator; import static org.tensorflow.op.generator.GeneratorUtils.javaizeMemberName; @@ -443,8 +443,16 @@ private void buildFactoryMethods() { } factoryBuilder.returns(returnType); - factoryBuilder.addAnnotation( - AnnotationSpec.builder(Names.Endpoint).addMember("describeByClass", "true").build()); + AnnotationSpec.Builder endpointAnnotation = + AnnotationSpec.builder(Names.Endpoint).addMember("describeByClass", "true"); + + String methodName = GeneratorUtils.getOpMethodName(className); + + if (methodName != null) { + endpointAnnotation.addMember("name", "$S", methodName); + } + + factoryBuilder.addAnnotation(endpointAnnotation.build()); factoryBuilder.addJavadoc( "Factory method to create a class wrapping a new $L operation.\n", op.getName()); diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java new file mode 100644 index 00000000000..2f39e466719 --- /dev/null +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java @@ -0,0 +1,119 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.op.generator; + +import com.squareup.javapoet.TypeSpec; +import org.tensorflow.proto.framework.ApiDef; +import org.tensorflow.proto.framework.ApiDef.Endpoint; +import org.tensorflow.proto.framework.OpDef; + +public final class FullOpDef { + public final OpDef opDef; + public final ApiDef apiDef; + public final String basePackage; + public final String packageName; + public final String group; + public final String className; + public final Endpoint endpoint; + + public FullOpDef( + OpDef opDef, + ApiDef apiDef, + String basePackage, + String packageName, + String group, + String className, + Endpoint endpoint) { + this.group = group; + this.endpoint = endpoint; + if (opDef == null) { + throw new IllegalArgumentException("Can't have a null OpDef"); + } + if (apiDef == null) { + throw new IllegalArgumentException("Can't have a null ApiDef"); + } + this.opDef = opDef; + this.apiDef = apiDef; + this.basePackage = basePackage; + this.packageName = packageName; + this.className = className; + } + + public boolean isStateful() { + return opDef.getIsStateful(); + } + + public boolean equalOtherThanState(FullOpDef other) { + OpDef copy = + opDef.toBuilder().setName(other.opDef.getName()).setIsStateful(other.isStateful()).build(); + return copy.equals(other.opDef); + } + + public TypeSpec buildOpClass() { + TypeSpec.Builder cls = TypeSpec.classBuilder(className); + try { + new ClassGenerator(cls, opDef, apiDef, basePackage, packageName, group, className, endpoint) + .buildClass(); + } catch (Exception e) { + throw new IllegalStateException("Failed to generate class for op " + opDef.getName(), e); + } + return cls.build(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FullOpDef fullOpDef = (FullOpDef) o; + + if (!opDef.equals(fullOpDef.opDef)) { + return false; + } + if (!apiDef.equals(fullOpDef.apiDef)) { + return false; + } + if (!basePackage.equals(fullOpDef.basePackage)) { + return false; + } + if (!packageName.equals(fullOpDef.packageName)) { + return false; + } + if (!group.equals(fullOpDef.group)) { + return false; + } + if (!className.equals(fullOpDef.className)) { + return false; + } + return endpoint.equals(fullOpDef.endpoint); + } + + @Override + public int hashCode() { + int result = opDef.hashCode(); + result = 31 * result + apiDef.hashCode(); + result = 31 * result + basePackage.hashCode(); + result = 31 * result + packageName.hashCode(); + result = 31 * result + group.hashCode(); + result = 31 * result + className.hashCode(); + result = 31 * result + endpoint.hashCode(); + return result; + } +} diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java index 80d6698fe36..dfe993c96ca 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java @@ -1,19 +1,18 @@ -/* - Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ============================================================================== - */ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow.op.generator; import org.commonmark.node.Node; @@ -21,9 +20,7 @@ import org.tensorflow.op.generator.javadoc.JavaDocRenderer; import org.tensorflow.proto.framework.OpDef.ArgDef; -/** - * Utilities for op generation - */ +/** Utilities for op generation */ final class GeneratorUtils { private static final Parser parser = Parser.builder().build(); @@ -31,11 +28,11 @@ final class GeneratorUtils { /** * Convert a Python style name to a Java style name. * - * Does snake_case -> camelCase and handles keywords. + * <p>Does snake_case -> camelCase and handles keywords. * - * Not valid for class names, meant for fields and methods. + * <p>Not valid for class names, meant for fields and methods. * - * Generally you should use {@link ClassGenerator#getJavaName(ArgDef)}. + * <p>Generally you should use {@link ClassGenerator#getJavaName(ArgDef)}. */ static String javaizeMemberName(String name) { StringBuilder result = new StringBuilder(); @@ -68,13 +65,28 @@ static String javaizeMemberName(String name) { } /** - * Convert markdown descriptions to JavaDocs. + * Get the name of the Ops method, or null to not specify one (the decapitalized class name will + * be used). */ + static String getOpMethodName(String className) { + switch (className) { + case "If": + return "ifOp"; + case "While": + return "whileOp"; + case "For": + return "forOp"; + case "Case": + return "caseOp"; + default: + return null; + } + } + + /** Convert markdown descriptions to JavaDocs. */ static String parseDocumentation(String docs) { Node document = parser.parse(docs); JavaDocRenderer renderer = JavaDocRenderer.builder().build(); return renderer.render(document).trim(); } - - } diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java index da4f63405cc..5e7b0aa324a 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java @@ -1,28 +1,24 @@ -/* - Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -============================================================================== -*/ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow.op.generator; import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.UnknownFieldSet; import com.squareup.javapoet.JavaFile; import com.squareup.javapoet.TypeSpec; -import org.tensorflow.proto.framework.ApiDef; -import org.tensorflow.proto.framework.OpDef; -import org.tensorflow.proto.framework.OpList; - import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; @@ -35,7 +31,12 @@ import java.nio.file.StandardOpenOption; import java.nio.file.attribute.BasicFileAttributes; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import org.tensorflow.proto.framework.ApiDef; +import org.tensorflow.proto.framework.OpDef; +import org.tensorflow.proto.framework.OpList; public final class OpGenerator { @@ -169,7 +170,8 @@ private static void generate(File outputDir, String packageName, File opDefs) { .getField(API_DEF_FIELD_NUMBER) .getLengthDelimitedList() .get(0)); - defs.put(op, api); + defs.put( + op.toBuilder().setUnknownFields(UnknownFieldSet.newBuilder().build()).build(), api); } catch (InvalidProtocolBufferException e) { throw new RuntimeException( "Could not parse attached ApiDef for op " @@ -186,78 +188,65 @@ private static void generate(File outputDir, String packageName, File opDefs) { /** Generate all the ops that pass {@link ClassGenerator#canGenerateOp(OpDef, ApiDef)}. */ private static void generate(File outputDir, String basePackage, Map<OpDef, ApiDef> ops) { - ops.entrySet().stream() - .filter(e -> ClassGenerator.canGenerateOp(e.getKey(), e.getValue())) - .forEach( - (entry) -> { - entry - .getValue() - .getEndpointList() - .forEach( - (endpoint) -> { - String name; - String pack; - - int pos = endpoint.getName().lastIndexOf('.'); - if (pos > -1) { - pack = endpoint.getName().substring(0, pos); - name = endpoint.getName().substring(pos + 1); - } else { - pack = "core"; - name = endpoint.getName(); - } - - TypeSpec.Builder cls = TypeSpec.classBuilder(name); - try { - new ClassGenerator( - cls, + List<FullOpDef> fullOps = + ops.entrySet().stream() + .filter(e -> ClassGenerator.canGenerateOp(e.getKey(), e.getValue())) + .flatMap( + (entry) -> + entry.getValue().getEndpointList().stream() + .map( + (endpoint) -> { + String name; + String pack; + + int pos = endpoint.getName().lastIndexOf('.'); + if (pos > -1) { + pack = endpoint.getName().substring(0, pos); + name = endpoint.getName().substring(pos + 1); + } else { + pack = "core"; + name = endpoint.getName(); + } + + return new FullOpDef( entry.getKey(), entry.getValue(), basePackage, basePackage + "." + pack, pack, name, - endpoint) - .buildClass(); - } catch (Exception e) { - throw new IllegalStateException( - "Failed to generate class for op " + entry.getKey().getName(), e); - } - TypeSpec spec = cls.build(); - - JavaFile file = - JavaFile.builder(basePackage + "." + pack, spec) - .indent(" ") - .skipJavaLangImports(true) - .build(); - - File outputFile = - new File( - outputDir, - basePackage.replace('.', '/') - + '/' - + pack.replace('.', '/') - + '/' - + spec.name - + ".java"); - outputFile.getParentFile().mkdirs(); - try { - StringBuilder builder = new StringBuilder(); - builder.append(LICENSE + '\n'); - builder.append("// This class has been generated, DO NOT EDIT!\n\n"); - file.writeTo(builder); - - Files.write( - outputFile.toPath(), - builder.toString().getBytes(StandardCharsets.UTF_8), - StandardOpenOption.WRITE, - StandardOpenOption.CREATE, - StandardOpenOption.TRUNCATE_EXISTING); - } catch (IOException ioException) { - throw new IllegalStateException( - "Failed to write file " + outputFile, ioException); - } - }); - }); + endpoint); + })) + .collect(Collectors.toList()); + + fullOps.forEach( + (def) -> { + TypeSpec spec = def.buildOpClass(); + + JavaFile file = + JavaFile.builder(def.packageName, spec) + .indent(" ") + .skipJavaLangImports(true) + .build(); + + File outputFile = + new File(outputDir, def.packageName.replace('.', '/') + '/' + spec.name + ".java"); + outputFile.getParentFile().mkdirs(); + try { + StringBuilder builder = new StringBuilder(); + builder.append(LICENSE + '\n'); + builder.append("// This class has been generated, DO NOT EDIT!\n\n"); + file.writeTo(builder); + + Files.write( + outputFile.toPath(), + builder.toString().getBytes(StandardCharsets.UTF_8), + StandardOpenOption.WRITE, + StandardOpenOption.CREATE, + StandardOpenOption.TRUNCATE_EXISTING); + } catch (IOException ioException) { + throw new IllegalStateException("Failed to write file " + outputFile, ioException); + } + }); } } diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java new file mode 100644 index 00000000000..194ab51937d --- /dev/null +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java @@ -0,0 +1,27 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.op.generator; + +public final class StatefulPair { + public final FullOpDef statefulOp; + + public final FullOpDef statelessOp; + + public StatefulPair(FullOpDef statefulOp, FullOpDef statelessOp) { + this.statefulOp = statefulOp; + this.statelessOp = statelessOp; + } +} From db51573307b84535aba9f574e7b09233f8fa3a2b Mon Sep 17 00:00:00 2001 From: Ryan Nett <JNett96@gmail.com> Date: Mon, 31 May 2021 16:50:48 -0700 Subject: [PATCH 04/14] Fix copyright formatting Signed-off-by: Ryan Nett <JNett96@gmail.com> --- .../op/generator/ClassGenerator.java | 26 +++++++++---------- .../tensorflow/op/generator/FullOpDef.java | 22 ++++++++-------- .../op/generator/GeneratorUtils.java | 22 ++++++++-------- .../tensorflow/op/generator/OpGenerator.java | 26 +++++++++---------- .../tensorflow/op/generator/StatefulPair.java | 26 +++++++++++-------- 5 files changed, 63 insertions(+), 59 deletions(-) diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java index 7327afab81c..303e1dcc296 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow.op.generator; import static org.tensorflow.op.generator.GeneratorUtils.javaizeMemberName; diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java index 2f39e466719..1fd9d862a7c 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow.op.generator; import com.squareup.javapoet.TypeSpec; diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java index dfe993c96ca..f19d88416f5 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow.op.generator; import org.commonmark.node.Node; diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java index 5e7b0aa324a..640b79e8e39 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow.op.generator; import com.google.protobuf.InvalidProtocolBufferException; diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java index 194ab51937d..f02c0e24068 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow.op.generator; public final class StatefulPair { @@ -24,4 +24,8 @@ public StatefulPair(FullOpDef statefulOp, FullOpDef statelessOp) { this.statefulOp = statefulOp; this.statelessOp = statelessOp; } + + // public static List<StatefulPair> extractStatefulPairs(List<FullOpDef> ops) { + // // List<> + // } } From bdde8284d12e4de3ee21aa69d76997248c83ef9b Mon Sep 17 00:00:00 2001 From: Ryan Nett <JNett96@gmail.com> Date: Mon, 31 May 2021 17:20:13 -0700 Subject: [PATCH 05/14] Start of stateful/stateless processing Signed-off-by: Ryan Nett <JNett96@gmail.com> --- pom.xml | 2 +- .../op/core/{Case.java => StatefulCase.java} | 20 ++-- .../op/core/{If.java => StatefulIf.java} | 21 ++-- .../core/{While.java => StatefulWhile.java} | 17 ++- ...all.java => StatelessPartitionedCall.java} | 14 +-- .../tensorflow/op/generator/FullOpDef.java | 26 ++++- .../tensorflow/op/generator/OpGenerator.java | 57 ++++++---- .../tensorflow/op/generator/StatefulPair.java | 107 +++++++++++++++++- 8 files changed, 196 insertions(+), 68 deletions(-) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/{Case.java => StatefulCase.java} (90%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/{If.java => StatefulIf.java} (89%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/{While.java => StatefulWhile.java} (92%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/{PartitionedCall.java => StatelessPartitionedCall.java} (91%) diff --git a/pom.xml b/pom.xml index ee498135d2a..faf005732dd 100644 --- a/pom.xml +++ b/pom.xml @@ -252,7 +252,7 @@ <executions> <execution> <!-- Runs in initialize phase to fail fast in case of formatting issues (should be before codegen).--> - <id>spotless-check</id> + <id>spotless-apply</id> <phase>initialize</phase> <goals> <goal>check</goal> diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulCase.java similarity index 90% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulCase.java index 73cd64b11bd..beac85f315d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulCase.java @@ -56,7 +56,7 @@ * </pre> */ @Operator -public final class Case extends RawOp implements Iterable<Operand<TType>> { +public final class StatefulCase extends RawOp implements Iterable<Operand<TType>> { /** * The name of this op, as known by TensorFlow core engine */ @@ -65,7 +65,7 @@ public final class Case extends RawOp implements Iterable<Operand<TType>> { private List<Output<?>> output; @SuppressWarnings("unchecked") - private Case(Operation operation) { + private StatefulCase(Operation operation) { super(operation); int outputIdx = 0; int outputLength = operation.outputListLength("output"); @@ -85,15 +85,15 @@ private Case(Operation operation) { * tensors, whose types are the same as what every other branch returns. * </pre> * @param options carries optional attribute values - * @return a new instance of Case + * @return a new instance of StatefulCase */ @Endpoint( - describeByClass = true, - name = "caseOp" + describeByClass = true ) - public static Case create(Scope scope, Operand<TInt32> branchIndex, Iterable<Operand<?>> input, - List<Class<? extends TType>> Tout, List<ConcreteFunction> branches, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("Case", scope.makeOpName("Case")); + public static StatefulCase create(Scope scope, Operand<TInt32> branchIndex, + Iterable<Operand<?>> input, List<Class<? extends TType>> Tout, + List<ConcreteFunction> branches, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("Case", scope.makeOpName("StatefulCase")); opBuilder.addInput(branchIndex.asOutput()); opBuilder.addInputList(Operands.asOutputs(input)); opBuilder = scope.apply(opBuilder); @@ -114,7 +114,7 @@ public static Case create(Scope scope, Operand<TInt32> branchIndex, Iterable<Ope } } } - return new Case(opBuilder.build()); + return new StatefulCase(opBuilder.build()); } /** @@ -153,7 +153,7 @@ public Iterator<Operand<TType>> iterator() { } /** - * Optional attributes for {@link org.tensorflow.op.core.Case} + * Optional attributes for {@link org.tensorflow.op.core.StatefulCase} */ public static class Options { private List<Shape> outputShapes; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulIf.java similarity index 89% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulIf.java index 7e1acb6691b..7f773e18c4b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulIf.java @@ -37,7 +37,7 @@ * output = cond ? then_branch(input) : else_branch(input) */ @Operator -public final class If extends RawOp implements Iterable<Operand<TType>> { +public final class StatefulIf extends RawOp implements Iterable<Operand<TType>> { /** * The name of this op, as known by TensorFlow core engine */ @@ -46,7 +46,7 @@ public final class If extends RawOp implements Iterable<Operand<TType>> { private List<Output<?>> output; @SuppressWarnings("unchecked") - private If(Operation operation) { + private StatefulIf(Operation operation) { super(operation); int outputIdx = 0; int outputLength = operation.outputListLength("output"); @@ -77,16 +77,15 @@ private If(Operation operation) { * types are the same as what then_branch returns. * </pre> * @param options carries optional attribute values - * @return a new instance of If + * @return a new instance of StatefulIf */ @Endpoint( - describeByClass = true, - name = "ifOp" + describeByClass = true ) - public static If create(Scope scope, Operand<? extends TType> cond, Iterable<Operand<?>> input, - List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, ConcreteFunction elseBranch, - Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("If", scope.makeOpName("If")); + public static StatefulIf create(Scope scope, Operand<? extends TType> cond, + Iterable<Operand<?>> input, List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, + ConcreteFunction elseBranch, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder("If", scope.makeOpName("StatefulIf")); opBuilder.addInput(cond.asOutput()); opBuilder.addInputList(Operands.asOutputs(input)); opBuilder = scope.apply(opBuilder); @@ -104,7 +103,7 @@ public static If create(Scope scope, Operand<? extends TType> cond, Iterable<Ope } } } - return new If(opBuilder.build()); + return new StatefulIf(opBuilder.build()); } /** @@ -143,7 +142,7 @@ public Iterator<Operand<TType>> iterator() { } /** - * Optional attributes for {@link org.tensorflow.op.core.If} + * Optional attributes for {@link org.tensorflow.op.core.StatefulIf} */ public static class Options { private List<Shape> outputShapes; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulWhile.java similarity index 92% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulWhile.java index 1f5a0a273f4..10edf7df77d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulWhile.java @@ -37,7 +37,7 @@ * output = input; While (Cond(output)) { output = Body(output) } */ @Operator -public final class While extends RawOp implements Iterable<Operand<TType>> { +public final class StatefulWhile extends RawOp implements Iterable<Operand<TType>> { /** * The name of this op, as known by TensorFlow core engine */ @@ -46,7 +46,7 @@ public final class While extends RawOp implements Iterable<Operand<TType>> { private List<Output<?>> output; @SuppressWarnings("unchecked") - private While(Operation operation) { + private StatefulWhile(Operation operation) { super(operation); int outputIdx = 0; int outputLength = operation.outputListLength("output"); @@ -74,15 +74,14 @@ private While(Operation operation) { * by T. * </pre> * @param options carries optional attribute values - * @return a new instance of While + * @return a new instance of StatefulWhile */ @Endpoint( - describeByClass = true, - name = "whileOp" + describeByClass = true ) - public static While create(Scope scope, Iterable<Operand<?>> input, ConcreteFunction cond, + public static StatefulWhile create(Scope scope, Iterable<Operand<?>> input, ConcreteFunction cond, ConcreteFunction body, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("While", scope.makeOpName("While")); + OperationBuilder opBuilder = scope.env().opBuilder("While", scope.makeOpName("StatefulWhile")); opBuilder.addInputList(Operands.asOutputs(input)); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("cond", cond); @@ -101,7 +100,7 @@ public static While create(Scope scope, Iterable<Operand<?>> input, ConcreteFunc } } } - return new While(opBuilder.build()); + return new StatefulWhile(opBuilder.build()); } /** @@ -150,7 +149,7 @@ public Iterator<Operand<TType>> iterator() { } /** - * Optional attributes for {@link org.tensorflow.op.core.While} + * Optional attributes for {@link org.tensorflow.op.core.StatefulWhile} */ public static class Options { private List<Shape> outputShapes; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessPartitionedCall.java similarity index 91% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessPartitionedCall.java index 5b67a6129fb..2fe53c2d398 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessPartitionedCall.java @@ -36,7 +36,7 @@ * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned. */ @Operator -public final class PartitionedCall extends RawOp implements Iterable<Operand<TType>> { +public final class StatelessPartitionedCall extends RawOp implements Iterable<Operand<TType>> { /** * The name of this op, as known by TensorFlow core engine */ @@ -45,7 +45,7 @@ public final class PartitionedCall extends RawOp implements Iterable<Operand<TTy private List<Output<?>> output; @SuppressWarnings("unchecked") - private PartitionedCall(Operation operation) { + private StatelessPartitionedCall(Operation operation) { super(operation); int outputIdx = 0; int outputLength = operation.outputListLength("output"); @@ -66,14 +66,14 @@ private PartitionedCall(Operation operation) { * devices, setting this op apart from the regular Call op. * </pre> * @param options carries optional attribute values - * @return a new instance of PartitionedCall + * @return a new instance of StatelessPartitionedCall */ @Endpoint( describeByClass = true ) - public static PartitionedCall create(Scope scope, Iterable<Operand<?>> args, + public static StatelessPartitionedCall create(Scope scope, Iterable<Operand<?>> args, List<Class<? extends TType>> Tout, ConcreteFunction f, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("PartitionedCall", scope.makeOpName("PartitionedCall")); + OperationBuilder opBuilder = scope.env().opBuilder("PartitionedCall", scope.makeOpName("StatelessPartitionedCall")); opBuilder.addInputList(Operands.asOutputs(args)); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); @@ -91,7 +91,7 @@ public static PartitionedCall create(Scope scope, Iterable<Operand<?>> args, } } } - return new PartitionedCall(opBuilder.build()); + return new StatelessPartitionedCall(opBuilder.build()); } /** @@ -140,7 +140,7 @@ public Iterator<Operand<TType>> iterator() { } /** - * Optional attributes for {@link org.tensorflow.op.core.PartitionedCall} + * Optional attributes for {@link org.tensorflow.op.core.StatelessPartitionedCall} */ public static class Options { private String config; diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java index 1fd9d862a7c..bea598ca37a 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java @@ -16,6 +16,7 @@ package org.tensorflow.op.generator; import com.squareup.javapoet.TypeSpec; +import java.util.StringJoiner; import org.tensorflow.proto.framework.ApiDef; import org.tensorflow.proto.framework.ApiDef.Endpoint; import org.tensorflow.proto.framework.OpDef; @@ -56,13 +57,21 @@ public boolean isStateful() { return opDef.getIsStateful(); } - public boolean equalOtherThanState(FullOpDef other) { + public boolean isStateVariant(FullOpDef other) { + if (this.equals(other)) return false; + + if (this.isStateful() == other.isStateful()) return false; + OpDef copy = opDef.toBuilder().setName(other.opDef.getName()).setIsStateful(other.isStateful()).build(); - return copy.equals(other.opDef); + return copy.equals(other.opDef) && packageName.equals(other.packageName); } public TypeSpec buildOpClass() { + return buildOpClass(className); + } + + public TypeSpec buildOpClass(String className) { TypeSpec.Builder cls = TypeSpec.classBuilder(className); try { new ClassGenerator(cls, opDef, apiDef, basePackage, packageName, group, className, endpoint) @@ -116,4 +125,17 @@ public int hashCode() { result = 31 * result + endpoint.hashCode(); return result; } + + @Override + public String toString() { + return new StringJoiner(", ", FullOpDef.class.getSimpleName() + "(", ")") + .add("opDef=" + opDef) + .add("apiDef=" + apiDef) + .add("basePackage='" + basePackage + "'") + .add("packageName='" + packageName + "'") + .add("group='" + group + "'") + .add("className='" + className + "'") + .add("endpoint=" + endpoint) + .toString(); + } } diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java index 640b79e8e39..b72cf664279 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java @@ -186,6 +186,30 @@ private static void generate(File outputDir, String packageName, File opDefs) { generate(outputDir, packageName, defs); } + private static void writeToFile(TypeSpec spec, File outputDir, String packageName) { + JavaFile file = + JavaFile.builder(packageName, spec).indent(" ").skipJavaLangImports(true).build(); + + File outputFile = + new File(outputDir, packageName.replace('.', '/') + '/' + spec.name + ".java"); + outputFile.getParentFile().mkdirs(); + try { + StringBuilder builder = new StringBuilder(); + builder.append(LICENSE + '\n'); + builder.append("// This class has been generated, DO NOT EDIT!\n\n"); + file.writeTo(builder); + + Files.write( + outputFile.toPath(), + builder.toString().getBytes(StandardCharsets.UTF_8), + StandardOpenOption.WRITE, + StandardOpenOption.CREATE, + StandardOpenOption.TRUNCATE_EXISTING); + } catch (IOException ioException) { + throw new IllegalStateException("Failed to write file " + outputFile, ioException); + } + } + /** Generate all the ops that pass {@link ClassGenerator#canGenerateOp(OpDef, ApiDef)}. */ private static void generate(File outputDir, String basePackage, Map<OpDef, ApiDef> ops) { List<FullOpDef> fullOps = @@ -219,34 +243,19 @@ private static void generate(File outputDir, String basePackage, Map<OpDef, ApiD })) .collect(Collectors.toList()); + List<StatefulPair> statefulPairs = StatefulPair.extractStatefulPairs(fullOps); + fullOps.forEach( (def) -> { TypeSpec spec = def.buildOpClass(); - JavaFile file = - JavaFile.builder(def.packageName, spec) - .indent(" ") - .skipJavaLangImports(true) - .build(); - - File outputFile = - new File(outputDir, def.packageName.replace('.', '/') + '/' + spec.name + ".java"); - outputFile.getParentFile().mkdirs(); - try { - StringBuilder builder = new StringBuilder(); - builder.append(LICENSE + '\n'); - builder.append("// This class has been generated, DO NOT EDIT!\n\n"); - file.writeTo(builder); - - Files.write( - outputFile.toPath(), - builder.toString().getBytes(StandardCharsets.UTF_8), - StandardOpenOption.WRITE, - StandardOpenOption.CREATE, - StandardOpenOption.TRUNCATE_EXISTING); - } catch (IOException ioException) { - throw new IllegalStateException("Failed to write file " + outputFile, ioException); - } + writeToFile(spec, outputDir, def.packageName); + }); + + statefulPairs.forEach( + (pair) -> { + pair.buildOpClasses() + .forEach((spec) -> writeToFile(spec, outputDir, pair.getPackageName())); }); } } diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java index f02c0e24068..da82dd237e9 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java @@ -15,17 +15,116 @@ */ package org.tensorflow.op.generator; +import com.squareup.javapoet.TypeSpec; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.StringJoiner; + public final class StatefulPair { public final FullOpDef statefulOp; - public final FullOpDef statelessOp; + public final String selectorClassName; + public final String statefulClassName; + public final String statelessClassName; public StatefulPair(FullOpDef statefulOp, FullOpDef statelessOp) { this.statefulOp = statefulOp; this.statelessOp = statelessOp; + + this.selectorClassName = statefulOp.className.replace("stateful", "").replace("Stateful", ""); + + if (statefulOp.className.toLowerCase().contains("stateful")) { + statefulClassName = statefulOp.className; + } else { + statefulClassName = "Stateful" + statefulOp.className; + } + + if (statelessOp.className.toLowerCase().contains("stateless")) { + statelessClassName = statelessOp.className; + } else { + statelessClassName = "Stateless" + statelessOp.className; + } + } + + public static List<StatefulPair> extractStatefulPairs(List<FullOpDef> ops) { + List<StatefulPair> pairs = new ArrayList<>(10); + for (FullOpDef stateful : ops) { + if (!stateful.isStateful()) { + continue; + } + + for (FullOpDef stateless : ops) { + if (stateful.isStateVariant(stateless) && !stateful.equals(stateless)) { + if (stateful.opDef.getName().toLowerCase().contains("stateful") + || stateless.opDef.getName().toLowerCase().contains("stateless")) { + pairs.add(new StatefulPair(stateful, stateless)); + } + } + } + } + for (StatefulPair pair : pairs) { + ops.remove(pair.statefulOp); + ops.remove(pair.statelessOp); + } + return pairs; + } + + public String getPackageName() { + return statefulOp.packageName; } - // public static List<StatefulPair> extractStatefulPairs(List<FullOpDef> ops) { - // // List<> - // } + public List<TypeSpec> buildOpClasses() { + TypeSpec stateful = statefulOp.buildOpClass(statefulClassName); + TypeSpec stateless = statelessOp.buildOpClass(statelessClassName); + + return Arrays.asList(stateful, stateless); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + StatefulPair that = (StatefulPair) o; + + if (!statefulOp.equals(that.statefulOp)) { + return false; + } + if (!statelessOp.equals(that.statelessOp)) { + return false; + } + if (!selectorClassName.equals(that.selectorClassName)) { + return false; + } + if (!statefulClassName.equals(that.statefulClassName)) { + return false; + } + return statelessClassName.equals(that.statelessClassName); + } + + @Override + public int hashCode() { + int result = statefulOp.hashCode(); + result = 31 * result + statelessOp.hashCode(); + result = 31 * result + selectorClassName.hashCode(); + result = 31 * result + statefulClassName.hashCode(); + result = 31 * result + statelessClassName.hashCode(); + return result; + } + + @Override + public String toString() { + return new StringJoiner(", ", StatefulPair.class.getSimpleName() + "(", ")") + .add("statefulOp=" + statefulOp) + .add("statelessOp=" + statelessOp) + .add("selectorClassName='" + selectorClassName + "'") + .add("statefulClassName='" + statefulClassName + "'") + .add("statelessClassName='" + statelessClassName + "'") + .toString(); + } } From 3bee40b15b4d8c28ca5ab764a62a0cb3c23e53d6 Mon Sep 17 00:00:00 2001 From: Ryan Nett <JNett96@gmail.com> Date: Mon, 31 May 2021 18:40:25 -0700 Subject: [PATCH 06/14] Stateful/stateless processing, selector op wrappers Signed-off-by: Ryan Nett <JNett96@gmail.com> --- .../annotations/org/tensorflow/op/Ops.java | 129 ++++++++++- .../org/tensorflow/op/core/BatchFunction.java | 2 +- .../gen/java/org/tensorflow/op/core/Case.java | 150 ++++++++++++ .../gen/java/org/tensorflow/op/core/For.java | 2 +- .../op/core/GroupByReducerDataset.java | 2 +- .../gen/java/org/tensorflow/op/core/If.java | 146 ++++++++++++ .../java/org/tensorflow/op/core/MapDefun.java | 2 +- .../tensorflow/op/core/PartitionedCall.java | 155 +++++++++++++ .../org/tensorflow/op/core/ReduceDataset.java | 2 +- .../org/tensorflow/op/core/RemoteCall.java | 2 +- .../org/tensorflow/op/core/StatefulCase.java | 61 +---- .../org/tensorflow/op/core/StatefulIf.java | 61 +---- .../op/core/StatefulPartitionedCall.java | 86 +------ .../org/tensorflow/op/core/StatefulWhile.java | 84 +------ .../org/tensorflow/op/core/StatelessCase.java | 61 +---- .../org/tensorflow/op/core/StatelessIf.java | 61 +---- .../op/core/StatelessPartitionedCall.java | 86 +------ .../tensorflow/op/core/StatelessWhile.java | 84 +------ .../java/org/tensorflow/op/core/While.java | 165 +++++++++++++ .../op/data/ChooseFastestBranchDataset.java | 2 +- .../org/tensorflow/op/data/FilterDataset.java | 2 +- .../tensorflow/op/data/FlatMapDataset.java | 2 +- .../tensorflow/op/data/GeneratorDataset.java | 2 +- .../op/data/GroupByWindowDataset.java | 2 +- .../tensorflow/op/data/InterleaveDataset.java | 2 +- .../org/tensorflow/op/data/LoadDataset.java | 2 +- .../org/tensorflow/op/data/MapDataset.java | 2 +- .../tensorflow/op/data/OneShotIterator.java | 2 +- .../op/data/ParallelMapDataset.java | 2 +- .../org/tensorflow/op/data/SaveDataset.java | 2 +- .../org/tensorflow/op/data/ScanDataset.java | 2 +- .../tensorflow/op/data/SnapshotDataset.java | 2 +- .../tensorflow/op/data/TakeWhileDataset.java | 2 +- .../experimental/GroupByReducerDataset.java | 2 +- .../experimental/GroupByWindowDataset.java | 2 +- .../LegacyParallelInterleaveDataset.java | 2 +- .../data/experimental/MapAndBatchDataset.java | 2 +- .../op/data/experimental/MapDataset.java | 2 +- .../ParallelInterleaveDataset.java | 2 +- .../op/data/experimental/ScanDataset.java | 2 +- .../data/experimental/TakeWhileDataset.java | 2 +- .../java/org/tensorflow/op/tpu/Compile.java | 2 +- .../tensorflow/op/tpu/PartitionedCall.java | 2 +- .../tensorflow/op/train/SymbolicGradient.java | 2 +- .../gen/java/org/tensorflow/op/xla/If.java | 2 +- .../java/org/tensorflow/op/xla/Reduce.java | 2 +- .../org/tensorflow/op/xla/ReduceWindow.java | 2 +- .../java/org/tensorflow/op/xla/Scatter.java | 2 +- .../tensorflow/op/xla/SelectAndScatter.java | 2 +- .../gen/java/org/tensorflow/op/xla/While.java | 2 +- .../org/tensorflow/op/xla/XlaHostCompute.java | 2 +- .../java/org/tensorflow/op/xla/XlaLaunch.java | 2 +- .../tensorflow/op/xla/XlaVariadicReduce.java | 2 +- .../op/generator/ClassGenerator.java | 217 ++++++++++++++---- .../tensorflow/op/generator/FullOpDef.java | 11 +- .../tensorflow/op/generator/StatefulPair.java | 40 +++- 56 files changed, 1032 insertions(+), 645 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index d6402581da5..7b6aa4a1679 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -222,8 +222,12 @@ import org.tensorflow.op.core.StageClear; import org.tensorflow.op.core.StagePeek; import org.tensorflow.op.core.StageSize; +import org.tensorflow.op.core.StatefulCase; +import org.tensorflow.op.core.StatefulIf; import org.tensorflow.op.core.StatefulPartitionedCall; +import org.tensorflow.op.core.StatefulWhile; import org.tensorflow.op.core.StatelessIf; +import org.tensorflow.op.core.StatelessPartitionedCall; import org.tensorflow.op.core.StatelessWhile; import org.tensorflow.op.core.StopGradient; import org.tensorflow.op.core.StridedSlice; @@ -1230,6 +1234,7 @@ public Map<String, Operand<?>> call(ConcreteFunction function, * } * ``` * </pre> + * Selects between {@link StatefulCase} and {@link StatelessCase} based on the statefulness of the function arguments. * * @param branchIndex The branch selector, an int32 Tensor. * @param input A list of input tensors passed to the branch function. @@ -2946,6 +2951,7 @@ public IdentityN identityN(Iterable<Operand<?>> input) { /** * output = cond ? then_branch(input) : else_branch(input) + * Selects between {@link StatefulIf} and {@link StatelessIf} based on the statefulness of the function arguments. * * @param cond <pre> * A Tensor. If the tensor is a scalar of non-boolean type, the @@ -4015,6 +4021,7 @@ public <T extends TType> ParallelDynamicStitch<T> parallelDynamicStitch( /** * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned. + * Selects between {@link StatefulPartitionedCall} and {@link StatelessPartitionedCall} based on the statefulness of the function arguments. * * @param args A list of input tensors. * @param Tout A list of output types. @@ -4022,7 +4029,8 @@ public <T extends TType> ParallelDynamicStitch<T> parallelDynamicStitch( * A function that takes 'args', a list of tensors, and returns 'output', * another list of tensors. Input and output types are specified by 'Tin' * and 'Tout'. The function body of f will be placed and partitioned across - * devices, setting this op apart from the regular Call op. + * devices, setting this op apart from the regular Call op. This op is + * stateful. * </pre> * @param options carries optional attribute values * @return a new instance of PartitionedCall @@ -6048,6 +6056,72 @@ public StageSize stageSize(List<Class<? extends TType>> dtypes, StageSize.Option return StageSize.create(scope, dtypes, options); } + /** + * An n-way switch statement which calls a single branch function. + * <pre> + * An n-way switch statement, implementing the following: + * ``` + * switch (branch_index) { + * case 0: + * output = branches[0](input); + * break; + * case 1: + * output = branches[1](input); + * break; + * ... + * case [[nbranches-1]]: + * default: + * output = branches[nbranches-1](input); + * break; + * } + * ``` + * </pre> + * + * @param branchIndex The branch selector, an int32 Tensor. + * @param input A list of input tensors passed to the branch function. + * @param Tout A list of output types. + * @param branches <pre> + * A list of functions each of which takes 'inputs' and returns a list of + * tensors, whose types are the same as what every other branch returns. + * </pre> + * @param options carries optional attribute values + * @return a new instance of StatefulCase + */ + public StatefulCase statefulCase(Operand<TInt32> branchIndex, Iterable<Operand<?>> input, + List<Class<? extends TType>> Tout, List<ConcreteFunction> branches, Case.Options... options) { + return StatefulCase.create(scope, branchIndex, input, Tout, branches, options); + } + + /** + * output = cond ? then_branch(input) : else_branch(input) + * + * @param cond <pre> + * A Tensor. If the tensor is a scalar of non-boolean type, the + * scalar is converted to a boolean according to the + * following rule: if the scalar is a numerical value, non-zero means + * `True` and zero means False; if the scalar is a string, non-empty + * means `True` and empty means `False`. If the tensor is not a scalar, + * being empty means False and being non-empty means True. + * </pre> + * @param input A list of input tensors. + * @param Tout A list of output types. + * @param thenBranch <pre> + * A function that takes 'inputs' and returns a list of tensors, whose + * types are the same as what else_branch returns. + * </pre> + * @param elseBranch <pre> + * A function that takes 'inputs' and returns a list of tensors, whose + * types are the same as what then_branch returns. + * </pre> + * @param options carries optional attribute values + * @return a new instance of StatefulIf + */ + public StatefulIf statefulIf(Operand<? extends TType> cond, Iterable<Operand<?>> input, + List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, ConcreteFunction elseBranch, + If.Options... options) { + return StatefulIf.create(scope, cond, input, Tout, thenBranch, elseBranch, options); + } + /** * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned. * @@ -6064,11 +6138,36 @@ public StageSize stageSize(List<Class<? extends TType>> dtypes, StageSize.Option * @return a new instance of StatefulPartitionedCall */ public StatefulPartitionedCall statefulPartitionedCall(Iterable<Operand<?>> args, - List<Class<? extends TType>> Tout, ConcreteFunction f, - StatefulPartitionedCall.Options... options) { + List<Class<? extends TType>> Tout, ConcreteFunction f, PartitionedCall.Options... options) { return StatefulPartitionedCall.create(scope, args, Tout, f, options); } + /** + * output = input; While (Cond(output)) { output = Body(output) } + * + * @param input A list of input tensors whose types are T. + * @param cond <pre> + * A function takes 'input' and returns a tensor. If the tensor is + * a scalar of non-boolean, the scalar is converted to a boolean + * according to the following rule: if the scalar is a numerical + * value, non-zero means True and zero means False; if the scalar is + * a string, non-empty means True and empty means False. If the + * tensor is not a scalar, non-emptiness means True and False + * otherwise. + * </pre> + * @param body <pre> + * A function that takes a list of tensors and returns another + * list of tensors. Both lists have the same types as specified + * by T. + * </pre> + * @param options carries optional attribute values + * @return a new instance of StatefulWhile + */ + public StatefulWhile statefulWhile(Iterable<Operand<?>> input, ConcreteFunction cond, + ConcreteFunction body, While.Options... options) { + return StatefulWhile.create(scope, input, cond, body, options); + } + /** * output = cond ? then_branch(input) : else_branch(input) * @@ -6098,10 +6197,29 @@ public StatefulPartitionedCall statefulPartitionedCall(Iterable<Operand<?>> args */ public StatelessIf statelessIf(Operand<? extends TType> cond, Iterable<Operand<?>> input, List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, ConcreteFunction elseBranch, - StatelessIf.Options... options) { + If.Options... options) { return StatelessIf.create(scope, cond, input, Tout, thenBranch, elseBranch, options); } + /** + * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned. + * + * @param args A list of input tensors. + * @param Tout A list of output types. + * @param f <pre> + * A function that takes 'args', a list of tensors, and returns 'output', + * another list of tensors. Input and output types are specified by 'Tin' + * and 'Tout'. The function body of f will be placed and partitioned across + * devices, setting this op apart from the regular Call op. + * </pre> + * @param options carries optional attribute values + * @return a new instance of StatelessPartitionedCall + */ + public StatelessPartitionedCall statelessPartitionedCall(Iterable<Operand<?>> args, + List<Class<? extends TType>> Tout, ConcreteFunction f, PartitionedCall.Options... options) { + return StatelessPartitionedCall.create(scope, args, Tout, f, options); + } + /** * output = input; While (Cond(output)) { output = Body(output) } * @@ -6127,7 +6245,7 @@ public StatelessIf statelessIf(Operand<? extends TType> cond, Iterable<Operand<? * @return a new instance of StatelessWhile */ public StatelessWhile statelessWhile(Iterable<Operand<?>> input, ConcreteFunction cond, - ConcreteFunction body, StatelessWhile.Options... options) { + ConcreteFunction body, While.Options... options) { return StatelessWhile.create(scope, input, cond, body, options); } @@ -7990,6 +8108,7 @@ public Where where(Operand<? extends TType> condition) { /** * output = input; While (Cond(output)) { output = Body(output) } + * Selects between {@link StatefulWhile} and {@link StatelessWhile} based on the statefulness of the function arguments. * * @param input A list of input tensors whose types are T. * @param cond <pre> diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BatchFunction.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BatchFunction.java index 80841c9b28f..db64e937645 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BatchFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BatchFunction.java @@ -109,7 +109,7 @@ public static BatchFunction create(Scope scope, Iterable<Operand<?>> inTensors, Iterable<Operand<?>> capturedTensors, ConcreteFunction f, Long numBatchThreads, Long maxBatchSize, Long batchTimeoutMicros, List<Class<? extends TType>> Tout, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("BatchFunction", scope.makeOpName("BatchFunction")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("BatchFunction")); opBuilder.addInputList(Operands.asOutputs(inTensors)); opBuilder.addInputList(Operands.asOutputs(capturedTensors)); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java new file mode 100644 index 00000000000..46e2406bcdf --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java @@ -0,0 +1,150 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; + +/** + * An n-way switch statement which calls a single branch function. + * <pre> + * An n-way switch statement, implementing the following: + * ``` + * switch (branch_index) { + * case 0: + * output = branches[0](input); + * break; + * case 1: + * output = branches[1](input); + * break; + * ... + * case [[nbranches-1]]: + * default: + * output = branches[nbranches-1](input); + * break; + * } + * ``` + * </pre> + * Selects between {@link StatefulCase} and {@link StatelessCase} based on the statefulness of the function arguments. + */ +@Operator +public interface Case extends Iterable<Operand<TType>> { + /** + * Factory method to create a class wrapping a new Case operation. + * + * @param scope current scope + * @param branchIndex The branch selector, an int32 Tensor. + * @param input A list of input tensors passed to the branch function. + * @param Tout A list of output types. + * @param branches <pre> + * A list of functions each of which takes 'inputs' and returns a list of + * tensors, whose types are the same as what every other branch returns. + * </pre> + * @param options carries optional attribute values + * @return a new instance of Case + */ + @Endpoint( + describeByClass = true, + name = "caseOp" + ) + static Case create(Scope scope, Operand<TInt32> branchIndex, Iterable<Operand<?>> input, + List<Class<? extends TType>> Tout, List<ConcreteFunction> branches, Options... options) { + boolean isStateful = false; + if (branches.stream().anyMatch(x -> x.isStateful())) { + isStateful = true; + } + if (isStateful) { + return StatefulCase.create(scope, branchIndex, input, Tout, branches, options); + } else { + return StatelessCase.create(scope, branchIndex, input, Tout, branches, options); + } + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + static Options outputShapes(List<Shape> outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + static Options outputShapes(Shape[] outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Gets output. + * A list of return values. + * @return output. + */ + List<Output<?>> output(); + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + Iterator<Operand<TType>> iterator(); + + /** + * Optional attributes for {@link org.tensorflow.op.core.Case} + */ + class Options { + List<Shape> outputShapes; + + private Options() { + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(List<Shape> outputShapes) { + this.outputShapes = outputShapes; + return this; + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(Shape... outputShapes) { + this.outputShapes = Arrays.asList(outputShapes); + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/For.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/For.java index 7d714de0879..96392a1a3cc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/For.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/For.java @@ -78,7 +78,7 @@ private For(Operation operation) { ) public static For create(Scope scope, Operand<TInt32> start, Operand<TInt32> limit, Operand<TInt32> delta, Iterable<Operand<?>> input, ConcreteFunction body) { - OperationBuilder opBuilder = scope.env().opBuilder("For", scope.makeOpName("For")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("For")); opBuilder.addInput(start.asOutput()); opBuilder.addInput(limit.asOutput()); opBuilder.addInput(delta.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GroupByReducerDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GroupByReducerDataset.java index b0e9ba81d19..b5c042cd733 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GroupByReducerDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GroupByReducerDataset.java @@ -82,7 +82,7 @@ public static GroupByReducerDataset create(Scope scope, Operand<? extends TType> Iterable<Operand<?>> finalizeFuncOtherArguments, ConcreteFunction keyFunc, ConcreteFunction initFunc, ConcreteFunction reduceFunc, ConcreteFunction finalizeFunc, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { - OperationBuilder opBuilder = scope.env().opBuilder("GroupByReducerDataset", scope.makeOpName("GroupByReducerDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("GroupByReducerDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(keyFuncOtherArguments)); opBuilder.addInputList(Operands.asOutputs(initFuncOtherArguments)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java new file mode 100644 index 00000000000..82134d19559 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java @@ -0,0 +1,146 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * output = cond ? then_branch(input) : else_branch(input) + * Selects between {@link StatefulIf} and {@link StatelessIf} based on the statefulness of the function arguments. + */ +@Operator +public interface If extends Iterable<Operand<TType>> { + /** + * Factory method to create a class wrapping a new If operation. + * + * @param scope current scope + * @param cond <pre> + * A Tensor. If the tensor is a scalar of non-boolean type, the + * scalar is converted to a boolean according to the + * following rule: if the scalar is a numerical value, non-zero means + * `True` and zero means False; if the scalar is a string, non-empty + * means `True` and empty means `False`. If the tensor is not a scalar, + * being empty means False and being non-empty means True. + * </pre> + * @param input A list of input tensors. + * @param Tout A list of output types. + * @param thenBranch <pre> + * A function that takes 'inputs' and returns a list of tensors, whose + * types are the same as what else_branch returns. + * </pre> + * @param elseBranch <pre> + * A function that takes 'inputs' and returns a list of tensors, whose + * types are the same as what then_branch returns. + * </pre> + * @param options carries optional attribute values + * @return a new instance of If + */ + @Endpoint( + describeByClass = true, + name = "ifOp" + ) + static If create(Scope scope, Operand<? extends TType> cond, Iterable<Operand<?>> input, + List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, ConcreteFunction elseBranch, + Options... options) { + boolean isStateful = false; + if (thenBranch.isStateful()) { + isStateful = true; + } + if (elseBranch.isStateful()) { + isStateful = true; + } + if (isStateful) { + return StatefulIf.create(scope, cond, input, Tout, thenBranch, elseBranch, options); + } else { + return StatelessIf.create(scope, cond, input, Tout, thenBranch, elseBranch, options); + } + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + static Options outputShapes(List<Shape> outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + static Options outputShapes(Shape[] outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Gets output. + * A list of return values. + * @return output. + */ + List<Output<?>> output(); + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + Iterator<Operand<TType>> iterator(); + + /** + * Optional attributes for {@link org.tensorflow.op.core.If} + */ + class Options { + List<Shape> outputShapes; + + private Options() { + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(List<Shape> outputShapes) { + this.outputShapes = outputShapes; + return this; + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(Shape... outputShapes) { + this.outputShapes = Arrays.asList(outputShapes); + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapDefun.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapDefun.java index 7a99f926eb2..60fb426abd6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapDefun.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapDefun.java @@ -85,7 +85,7 @@ private MapDefun(Operation operation) { public static MapDefun create(Scope scope, Iterable<Operand<?>> arguments, Iterable<Operand<?>> capturedInputs, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, ConcreteFunction f, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("MapDefun", scope.makeOpName("MapDefun")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("MapDefun")); opBuilder.addInputList(Operands.asOutputs(arguments)); opBuilder.addInputList(Operands.asOutputs(capturedInputs)); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java new file mode 100644 index 00000000000..42ccc7168b0 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java @@ -0,0 +1,155 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned. + * Selects between {@link StatefulPartitionedCall} and {@link StatelessPartitionedCall} based on the statefulness of the function arguments. + */ +@Operator +public interface PartitionedCall extends Iterable<Operand<TType>> { + /** + * Factory method to create a class wrapping a new StatefulPartitionedCall operation. + * + * @param scope current scope + * @param args A list of input tensors. + * @param Tout A list of output types. + * @param f <pre> + * A function that takes 'args', a list of tensors, and returns 'output', + * another list of tensors. Input and output types are specified by 'Tin' + * and 'Tout'. The function body of f will be placed and partitioned across + * devices, setting this op apart from the regular Call op. This op is + * stateful. + * </pre> + * @param options carries optional attribute values + * @return a new instance of PartitionedCall + */ + @Endpoint( + describeByClass = true + ) + static PartitionedCall create(Scope scope, Iterable<Operand<?>> args, + List<Class<? extends TType>> Tout, ConcreteFunction f, Options... options) { + boolean isStateful = false; + if (f.isStateful()) { + isStateful = true; + } + if (isStateful) { + return StatefulPartitionedCall.create(scope, args, Tout, f, options); + } else { + return StatelessPartitionedCall.create(scope, args, Tout, f, options); + } + } + + /** + * Sets the config option. + * + * @param config the config option + * @return this Options instance. + */ + static Options config(String config) { + return new Options().config(config); + } + + /** + * Sets the configProto option. + * + * @param configProto the configProto option + * @return this Options instance. + */ + static Options configProto(String configProto) { + return new Options().configProto(configProto); + } + + /** + * Sets the executorType option. + * + * @param executorType the executorType option + * @return this Options instance. + */ + static Options executorType(String executorType) { + return new Options().executorType(executorType); + } + + /** + * Gets output. + * A list of return values. + * @return output. + */ + List<Output<?>> output(); + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + Iterator<Operand<TType>> iterator(); + + /** + * Optional attributes for {@link org.tensorflow.op.core.PartitionedCall} + */ + class Options { + String config; + + String configProto; + + String executorType; + + private Options() { + } + + /** + * Sets the config option. + * + * @param config the config option + * @return this Options instance. + */ + public Options config(String config) { + this.config = config; + return this; + } + + /** + * Sets the configProto option. + * + * @param configProto the configProto option + * @return this Options instance. + */ + public Options configProto(String configProto) { + this.configProto = configProto; + return this; + } + + /** + * Sets the executorType option. + * + * @param executorType the executorType option + * @return this Options instance. + */ + public Options executorType(String executorType) { + this.executorType = executorType; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceDataset.java index 3e15fc24817..c82507fe5f7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceDataset.java @@ -74,7 +74,7 @@ private ReduceDataset(Operation operation) { public static ReduceDataset create(Scope scope, Operand<? extends TType> inputDataset, Iterable<Operand<?>> initialState, Iterable<Operand<?>> otherArguments, ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("ReduceDataset", scope.makeOpName("ReduceDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("ReduceDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(initialState)); opBuilder.addInputList(Operands.asOutputs(otherArguments)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteCall.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteCall.java index 416c7f222fc..3848a6ff0e7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteCall.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteCall.java @@ -69,7 +69,7 @@ private RemoteCall(Operation operation) { ) public static RemoteCall create(Scope scope, Operand<TString> target, Iterable<Operand<?>> args, List<Class<? extends TType>> Tout, ConcreteFunction f) { - OperationBuilder opBuilder = scope.env().opBuilder("RemoteCall", scope.makeOpName("RemoteCall")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("RemoteCall")); opBuilder.addInput(target.asOutput()); opBuilder.addInputList(Operands.asOutputs(args)); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulCase.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulCase.java index beac85f315d..817539cb187 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulCase.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulCase.java @@ -56,7 +56,7 @@ * </pre> */ @Operator -public final class StatefulCase extends RawOp implements Iterable<Operand<TType>> { +public final class StatefulCase extends RawOp implements Case { /** * The name of this op, as known by TensorFlow core engine */ @@ -92,8 +92,8 @@ private StatefulCase(Operation operation) { ) public static StatefulCase create(Scope scope, Operand<TInt32> branchIndex, Iterable<Operand<?>> input, List<Class<? extends TType>> Tout, - List<ConcreteFunction> branches, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("Case", scope.makeOpName("StatefulCase")); + List<ConcreteFunction> branches, Case.Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("StatefulCase")); opBuilder.addInput(branchIndex.asOutput()); opBuilder.addInputList(Operands.asOutputs(input)); opBuilder = scope.apply(opBuilder); @@ -104,7 +104,7 @@ public static StatefulCase create(Scope scope, Operand<TInt32> branchIndex, } opBuilder.setAttr("branches", branchesArray); if (options != null) { - for (Options opts : options) { + for (Case.Options opts : options) { if (opts.outputShapes != null) { Shape[] outputShapesArray = new Shape[opts.outputShapes.size()]; for (int i = 0 ; i < outputShapesArray.length ; i++) { @@ -117,31 +117,12 @@ public static StatefulCase create(Scope scope, Operand<TInt32> branchIndex, return new StatefulCase(opBuilder.build()); } - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public static Options outputShapes(List<Shape> outputShapes) { - return new Options().outputShapes(outputShapes); - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public static Options outputShapes(Shape[] outputShapes) { - return new Options().outputShapes(outputShapes); - } - /** * Gets output. * A list of return values. * @return output. */ + @Override public List<Output<?>> output() { return output; } @@ -151,36 +132,4 @@ public List<Output<?>> output() { public Iterator<Operand<TType>> iterator() { return (Iterator) output.iterator(); } - - /** - * Optional attributes for {@link org.tensorflow.op.core.StatefulCase} - */ - public static class Options { - private List<Shape> outputShapes; - - private Options() { - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public Options outputShapes(List<Shape> outputShapes) { - this.outputShapes = outputShapes; - return this; - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public Options outputShapes(Shape... outputShapes) { - this.outputShapes = Arrays.asList(outputShapes); - return this; - } - } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulIf.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulIf.java index 7f773e18c4b..e1f13750f3e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulIf.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulIf.java @@ -37,7 +37,7 @@ * output = cond ? then_branch(input) : else_branch(input) */ @Operator -public final class StatefulIf extends RawOp implements Iterable<Operand<TType>> { +public final class StatefulIf extends RawOp implements If { /** * The name of this op, as known by TensorFlow core engine */ @@ -84,8 +84,8 @@ private StatefulIf(Operation operation) { ) public static StatefulIf create(Scope scope, Operand<? extends TType> cond, Iterable<Operand<?>> input, List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, - ConcreteFunction elseBranch, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("If", scope.makeOpName("StatefulIf")); + ConcreteFunction elseBranch, If.Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("StatefulIf")); opBuilder.addInput(cond.asOutput()); opBuilder.addInputList(Operands.asOutputs(input)); opBuilder = scope.apply(opBuilder); @@ -93,7 +93,7 @@ public static StatefulIf create(Scope scope, Operand<? extends TType> cond, opBuilder.setAttr("then_branch", thenBranch); opBuilder.setAttr("else_branch", elseBranch); if (options != null) { - for (Options opts : options) { + for (If.Options opts : options) { if (opts.outputShapes != null) { Shape[] outputShapesArray = new Shape[opts.outputShapes.size()]; for (int i = 0 ; i < outputShapesArray.length ; i++) { @@ -106,31 +106,12 @@ public static StatefulIf create(Scope scope, Operand<? extends TType> cond, return new StatefulIf(opBuilder.build()); } - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public static Options outputShapes(List<Shape> outputShapes) { - return new Options().outputShapes(outputShapes); - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public static Options outputShapes(Shape[] outputShapes) { - return new Options().outputShapes(outputShapes); - } - /** * Gets output. * A list of return values. * @return output. */ + @Override public List<Output<?>> output() { return output; } @@ -140,36 +121,4 @@ public List<Output<?>> output() { public Iterator<Operand<TType>> iterator() { return (Iterator) output.iterator(); } - - /** - * Optional attributes for {@link org.tensorflow.op.core.StatefulIf} - */ - public static class Options { - private List<Shape> outputShapes; - - private Options() { - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public Options outputShapes(List<Shape> outputShapes) { - this.outputShapes = outputShapes; - return this; - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public Options outputShapes(Shape... outputShapes) { - this.outputShapes = Arrays.asList(outputShapes); - return this; - } - } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulPartitionedCall.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulPartitionedCall.java index 70090176588..ce789a235f9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulPartitionedCall.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulPartitionedCall.java @@ -36,7 +36,7 @@ * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned. */ @Operator -public final class StatefulPartitionedCall extends RawOp implements Iterable<Operand<TType>> { +public final class StatefulPartitionedCall extends RawOp implements PartitionedCall { /** * The name of this op, as known by TensorFlow core engine */ @@ -73,14 +73,14 @@ private StatefulPartitionedCall(Operation operation) { describeByClass = true ) public static StatefulPartitionedCall create(Scope scope, Iterable<Operand<?>> args, - List<Class<? extends TType>> Tout, ConcreteFunction f, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("StatefulPartitionedCall", scope.makeOpName("StatefulPartitionedCall")); + List<Class<? extends TType>> Tout, ConcreteFunction f, PartitionedCall.Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("StatefulPartitionedCall")); opBuilder.addInputList(Operands.asOutputs(args)); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); opBuilder.setAttr("f", f); if (options != null) { - for (Options opts : options) { + for (PartitionedCall.Options opts : options) { if (opts.config != null) { opBuilder.setAttr("config", opts.config); } @@ -95,41 +95,12 @@ public static StatefulPartitionedCall create(Scope scope, Iterable<Operand<?>> a return new StatefulPartitionedCall(opBuilder.build()); } - /** - * Sets the config option. - * - * @param config the config option - * @return this Options instance. - */ - public static Options config(String config) { - return new Options().config(config); - } - - /** - * Sets the configProto option. - * - * @param configProto the configProto option - * @return this Options instance. - */ - public static Options configProto(String configProto) { - return new Options().configProto(configProto); - } - - /** - * Sets the executorType option. - * - * @param executorType the executorType option - * @return this Options instance. - */ - public static Options executorType(String executorType) { - return new Options().executorType(executorType); - } - /** * Gets output. * A list of return values. * @return output. */ + @Override public List<Output<?>> output() { return output; } @@ -139,51 +110,4 @@ public List<Output<?>> output() { public Iterator<Operand<TType>> iterator() { return (Iterator) output.iterator(); } - - /** - * Optional attributes for {@link org.tensorflow.op.core.StatefulPartitionedCall} - */ - public static class Options { - private String config; - - private String configProto; - - private String executorType; - - private Options() { - } - - /** - * Sets the config option. - * - * @param config the config option - * @return this Options instance. - */ - public Options config(String config) { - this.config = config; - return this; - } - - /** - * Sets the configProto option. - * - * @param configProto the configProto option - * @return this Options instance. - */ - public Options configProto(String configProto) { - this.configProto = configProto; - return this; - } - - /** - * Sets the executorType option. - * - * @param executorType the executorType option - * @return this Options instance. - */ - public Options executorType(String executorType) { - this.executorType = executorType; - return this; - } - } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulWhile.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulWhile.java index 10edf7df77d..5d8fdcff241 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulWhile.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatefulWhile.java @@ -37,7 +37,7 @@ * output = input; While (Cond(output)) { output = Body(output) } */ @Operator -public final class StatefulWhile extends RawOp implements Iterable<Operand<TType>> { +public final class StatefulWhile extends RawOp implements While { /** * The name of this op, as known by TensorFlow core engine */ @@ -80,14 +80,14 @@ private StatefulWhile(Operation operation) { describeByClass = true ) public static StatefulWhile create(Scope scope, Iterable<Operand<?>> input, ConcreteFunction cond, - ConcreteFunction body, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("While", scope.makeOpName("StatefulWhile")); + ConcreteFunction body, While.Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("StatefulWhile")); opBuilder.addInputList(Operands.asOutputs(input)); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("cond", cond); opBuilder.setAttr("body", body); if (options != null) { - for (Options opts : options) { + for (While.Options opts : options) { if (opts.outputShapes != null) { Shape[] outputShapesArray = new Shape[opts.outputShapes.size()]; for (int i = 0 ; i < outputShapesArray.length ; i++) { @@ -103,41 +103,12 @@ public static StatefulWhile create(Scope scope, Iterable<Operand<?>> input, Conc return new StatefulWhile(opBuilder.build()); } - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public static Options outputShapes(List<Shape> outputShapes) { - return new Options().outputShapes(outputShapes); - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public static Options outputShapes(Shape[] outputShapes) { - return new Options().outputShapes(outputShapes); - } - - /** - * Sets the parallelIterations option. - * - * @param parallelIterations the parallelIterations option - * @return this Options instance. - */ - public static Options parallelIterations(Long parallelIterations) { - return new Options().parallelIterations(parallelIterations); - } - /** * Gets output. * A list of output tensors whose types are T. * @return output. */ + @Override public List<Output<?>> output() { return output; } @@ -147,49 +118,4 @@ public List<Output<?>> output() { public Iterator<Operand<TType>> iterator() { return (Iterator) output.iterator(); } - - /** - * Optional attributes for {@link org.tensorflow.op.core.StatefulWhile} - */ - public static class Options { - private List<Shape> outputShapes; - - private Long parallelIterations; - - private Options() { - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public Options outputShapes(List<Shape> outputShapes) { - this.outputShapes = outputShapes; - return this; - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public Options outputShapes(Shape... outputShapes) { - this.outputShapes = Arrays.asList(outputShapes); - return this; - } - - /** - * Sets the parallelIterations option. - * - * @param parallelIterations the parallelIterations option - * @return this Options instance. - */ - public Options parallelIterations(Long parallelIterations) { - this.parallelIterations = parallelIterations; - return this; - } - } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessCase.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessCase.java index 284f36c0db0..31326021300 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessCase.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessCase.java @@ -56,7 +56,7 @@ * This should only be used when the none of branches has stateful ops. * </pre> */ -public final class StatelessCase extends RawOp implements Iterable<Operand<TType>> { +public final class StatelessCase extends RawOp implements Case { /** * The name of this op, as known by TensorFlow core engine */ @@ -92,8 +92,8 @@ private StatelessCase(Operation operation) { ) public static StatelessCase create(Scope scope, Operand<TInt32> branchIndex, Iterable<Operand<?>> input, List<Class<? extends TType>> Tout, - List<ConcreteFunction> branches, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("StatelessCase", scope.makeOpName("StatelessCase")); + List<ConcreteFunction> branches, Case.Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("StatelessCase")); opBuilder.addInput(branchIndex.asOutput()); opBuilder.addInputList(Operands.asOutputs(input)); opBuilder = scope.apply(opBuilder); @@ -104,7 +104,7 @@ public static StatelessCase create(Scope scope, Operand<TInt32> branchIndex, } opBuilder.setAttr("branches", branchesArray); if (options != null) { - for (Options opts : options) { + for (Case.Options opts : options) { if (opts.outputShapes != null) { Shape[] outputShapesArray = new Shape[opts.outputShapes.size()]; for (int i = 0 ; i < outputShapesArray.length ; i++) { @@ -117,31 +117,12 @@ public static StatelessCase create(Scope scope, Operand<TInt32> branchIndex, return new StatelessCase(opBuilder.build()); } - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public static Options outputShapes(List<Shape> outputShapes) { - return new Options().outputShapes(outputShapes); - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public static Options outputShapes(Shape[] outputShapes) { - return new Options().outputShapes(outputShapes); - } - /** * Gets output. * A list of return values. * @return output. */ + @Override public List<Output<?>> output() { return output; } @@ -151,36 +132,4 @@ public List<Output<?>> output() { public Iterator<Operand<TType>> iterator() { return (Iterator) output.iterator(); } - - /** - * Optional attributes for {@link org.tensorflow.op.core.StatelessCase} - */ - public static class Options { - private List<Shape> outputShapes; - - private Options() { - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public Options outputShapes(List<Shape> outputShapes) { - this.outputShapes = outputShapes; - return this; - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public Options outputShapes(Shape... outputShapes) { - this.outputShapes = Arrays.asList(outputShapes); - return this; - } - } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessIf.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessIf.java index 8806806cdad..83589e73a34 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessIf.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessIf.java @@ -37,7 +37,7 @@ * output = cond ? then_branch(input) : else_branch(input) */ @Operator -public final class StatelessIf extends RawOp implements Iterable<Operand<TType>> { +public final class StatelessIf extends RawOp implements If { /** * The name of this op, as known by TensorFlow core engine */ @@ -87,8 +87,8 @@ private StatelessIf(Operation operation) { ) public static StatelessIf create(Scope scope, Operand<? extends TType> cond, Iterable<Operand<?>> input, List<Class<? extends TType>> Tout, ConcreteFunction thenBranch, - ConcreteFunction elseBranch, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("StatelessIf", scope.makeOpName("StatelessIf")); + ConcreteFunction elseBranch, If.Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("StatelessIf")); opBuilder.addInput(cond.asOutput()); opBuilder.addInputList(Operands.asOutputs(input)); opBuilder = scope.apply(opBuilder); @@ -96,7 +96,7 @@ public static StatelessIf create(Scope scope, Operand<? extends TType> cond, opBuilder.setAttr("then_branch", thenBranch); opBuilder.setAttr("else_branch", elseBranch); if (options != null) { - for (Options opts : options) { + for (If.Options opts : options) { if (opts.outputShapes != null) { Shape[] outputShapesArray = new Shape[opts.outputShapes.size()]; for (int i = 0 ; i < outputShapesArray.length ; i++) { @@ -109,31 +109,12 @@ public static StatelessIf create(Scope scope, Operand<? extends TType> cond, return new StatelessIf(opBuilder.build()); } - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public static Options outputShapes(List<Shape> outputShapes) { - return new Options().outputShapes(outputShapes); - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public static Options outputShapes(Shape[] outputShapes) { - return new Options().outputShapes(outputShapes); - } - /** * Gets output. * A list of return values. * @return output. */ + @Override public List<Output<?>> output() { return output; } @@ -143,36 +124,4 @@ public List<Output<?>> output() { public Iterator<Operand<TType>> iterator() { return (Iterator) output.iterator(); } - - /** - * Optional attributes for {@link org.tensorflow.op.core.StatelessIf} - */ - public static class Options { - private List<Shape> outputShapes; - - private Options() { - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public Options outputShapes(List<Shape> outputShapes) { - this.outputShapes = outputShapes; - return this; - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public Options outputShapes(Shape... outputShapes) { - this.outputShapes = Arrays.asList(outputShapes); - return this; - } - } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessPartitionedCall.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessPartitionedCall.java index 2fe53c2d398..f497454692b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessPartitionedCall.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessPartitionedCall.java @@ -36,7 +36,7 @@ * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned. */ @Operator -public final class StatelessPartitionedCall extends RawOp implements Iterable<Operand<TType>> { +public final class StatelessPartitionedCall extends RawOp implements PartitionedCall { /** * The name of this op, as known by TensorFlow core engine */ @@ -72,14 +72,14 @@ private StatelessPartitionedCall(Operation operation) { describeByClass = true ) public static StatelessPartitionedCall create(Scope scope, Iterable<Operand<?>> args, - List<Class<? extends TType>> Tout, ConcreteFunction f, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("PartitionedCall", scope.makeOpName("StatelessPartitionedCall")); + List<Class<? extends TType>> Tout, ConcreteFunction f, PartitionedCall.Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("StatelessPartitionedCall")); opBuilder.addInputList(Operands.asOutputs(args)); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); opBuilder.setAttr("f", f); if (options != null) { - for (Options opts : options) { + for (PartitionedCall.Options opts : options) { if (opts.config != null) { opBuilder.setAttr("config", opts.config); } @@ -94,41 +94,12 @@ public static StatelessPartitionedCall create(Scope scope, Iterable<Operand<?>> return new StatelessPartitionedCall(opBuilder.build()); } - /** - * Sets the config option. - * - * @param config the config option - * @return this Options instance. - */ - public static Options config(String config) { - return new Options().config(config); - } - - /** - * Sets the configProto option. - * - * @param configProto the configProto option - * @return this Options instance. - */ - public static Options configProto(String configProto) { - return new Options().configProto(configProto); - } - - /** - * Sets the executorType option. - * - * @param executorType the executorType option - * @return this Options instance. - */ - public static Options executorType(String executorType) { - return new Options().executorType(executorType); - } - /** * Gets output. * A list of return values. * @return output. */ + @Override public List<Output<?>> output() { return output; } @@ -138,51 +109,4 @@ public List<Output<?>> output() { public Iterator<Operand<TType>> iterator() { return (Iterator) output.iterator(); } - - /** - * Optional attributes for {@link org.tensorflow.op.core.StatelessPartitionedCall} - */ - public static class Options { - private String config; - - private String configProto; - - private String executorType; - - private Options() { - } - - /** - * Sets the config option. - * - * @param config the config option - * @return this Options instance. - */ - public Options config(String config) { - this.config = config; - return this; - } - - /** - * Sets the configProto option. - * - * @param configProto the configProto option - * @return this Options instance. - */ - public Options configProto(String configProto) { - this.configProto = configProto; - return this; - } - - /** - * Sets the executorType option. - * - * @param executorType the executorType option - * @return this Options instance. - */ - public Options executorType(String executorType) { - this.executorType = executorType; - return this; - } - } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessWhile.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessWhile.java index 8a806f1a9b1..77c067dfa42 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessWhile.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StatelessWhile.java @@ -37,7 +37,7 @@ * output = input; While (Cond(output)) { output = Body(output) } */ @Operator -public final class StatelessWhile extends RawOp implements Iterable<Operand<TType>> { +public final class StatelessWhile extends RawOp implements While { /** * The name of this op, as known by TensorFlow core engine */ @@ -83,14 +83,14 @@ private StatelessWhile(Operation operation) { describeByClass = true ) public static StatelessWhile create(Scope scope, Iterable<Operand<?>> input, - ConcreteFunction cond, ConcreteFunction body, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("StatelessWhile", scope.makeOpName("StatelessWhile")); + ConcreteFunction cond, ConcreteFunction body, While.Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("StatelessWhile")); opBuilder.addInputList(Operands.asOutputs(input)); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("cond", cond); opBuilder.setAttr("body", body); if (options != null) { - for (Options opts : options) { + for (While.Options opts : options) { if (opts.outputShapes != null) { Shape[] outputShapesArray = new Shape[opts.outputShapes.size()]; for (int i = 0 ; i < outputShapesArray.length ; i++) { @@ -106,41 +106,12 @@ public static StatelessWhile create(Scope scope, Iterable<Operand<?>> input, return new StatelessWhile(opBuilder.build()); } - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public static Options outputShapes(List<Shape> outputShapes) { - return new Options().outputShapes(outputShapes); - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public static Options outputShapes(Shape[] outputShapes) { - return new Options().outputShapes(outputShapes); - } - - /** - * Sets the parallelIterations option. - * - * @param parallelIterations the parallelIterations option - * @return this Options instance. - */ - public static Options parallelIterations(Long parallelIterations) { - return new Options().parallelIterations(parallelIterations); - } - /** * Gets output. * A list of output tensors whose types are T. * @return output. */ + @Override public List<Output<?>> output() { return output; } @@ -150,49 +121,4 @@ public List<Output<?>> output() { public Iterator<Operand<TType>> iterator() { return (Iterator) output.iterator(); } - - /** - * Optional attributes for {@link org.tensorflow.op.core.StatelessWhile} - */ - public static class Options { - private List<Shape> outputShapes; - - private Long parallelIterations; - - private Options() { - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public Options outputShapes(List<Shape> outputShapes) { - this.outputShapes = outputShapes; - return this; - } - - /** - * Sets the outputShapes option. - * - * @param outputShapes the outputShapes option - * @return this Options instance. - */ - public Options outputShapes(Shape... outputShapes) { - this.outputShapes = Arrays.asList(outputShapes); - return this; - } - - /** - * Sets the parallelIterations option. - * - * @param parallelIterations the parallelIterations option - * @return this Options instance. - */ - public Options parallelIterations(Long parallelIterations) { - this.parallelIterations = parallelIterations; - return this; - } - } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java new file mode 100644 index 00000000000..279683e3806 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java @@ -0,0 +1,165 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * output = input; While (Cond(output)) { output = Body(output) } + * Selects between {@link StatefulWhile} and {@link StatelessWhile} based on the statefulness of the function arguments. + */ +@Operator +public interface While extends Iterable<Operand<TType>> { + /** + * Factory method to create a class wrapping a new While operation. + * + * @param scope current scope + * @param input A list of input tensors whose types are T. + * @param cond <pre> + * A function takes 'input' and returns a tensor. If the tensor is + * a scalar of non-boolean, the scalar is converted to a boolean + * according to the following rule: if the scalar is a numerical + * value, non-zero means True and zero means False; if the scalar is + * a string, non-empty means True and empty means False. If the + * tensor is not a scalar, non-emptiness means True and False + * otherwise. + * </pre> + * @param body <pre> + * A function that takes a list of tensors and returns another + * list of tensors. Both lists have the same types as specified + * by T. + * </pre> + * @param options carries optional attribute values + * @return a new instance of While + */ + @Endpoint( + describeByClass = true, + name = "whileOp" + ) + static While create(Scope scope, Iterable<Operand<?>> input, ConcreteFunction cond, + ConcreteFunction body, Options... options) { + boolean isStateful = false; + if (cond.isStateful()) { + isStateful = true; + } + if (body.isStateful()) { + isStateful = true; + } + if (isStateful) { + return StatefulWhile.create(scope, input, cond, body, options); + } else { + return StatelessWhile.create(scope, input, cond, body, options); + } + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + static Options outputShapes(List<Shape> outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + static Options outputShapes(Shape[] outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Sets the parallelIterations option. + * + * @param parallelIterations the parallelIterations option + * @return this Options instance. + */ + static Options parallelIterations(Long parallelIterations) { + return new Options().parallelIterations(parallelIterations); + } + + /** + * Gets output. + * A list of output tensors whose types are T. + * @return output. + */ + List<Output<?>> output(); + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + Iterator<Operand<TType>> iterator(); + + /** + * Optional attributes for {@link org.tensorflow.op.core.While} + */ + class Options { + List<Shape> outputShapes; + + Long parallelIterations; + + private Options() { + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(List<Shape> outputShapes) { + this.outputShapes = outputShapes; + return this; + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(Shape... outputShapes) { + this.outputShapes = Arrays.asList(outputShapes); + return this; + } + + /** + * Sets the parallelIterations option. + * + * @param parallelIterations the parallelIterations option + * @return this Options instance. + */ + public Options parallelIterations(Long parallelIterations) { + this.parallelIterations = parallelIterations; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestBranchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestBranchDataset.java index 67f309f9d62..5fe596e3538 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestBranchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestBranchDataset.java @@ -72,7 +72,7 @@ public static ChooseFastestBranchDataset create(Scope scope, Operand<TInt64> ratioDenominator, Iterable<Operand<?>> otherArguments, Long numElementsPerBranch, List<ConcreteFunction> branches, List<Long> otherArgumentsLengths, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { - OperationBuilder opBuilder = scope.env().opBuilder("ChooseFastestBranchDataset", scope.makeOpName("ChooseFastestBranchDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("ChooseFastestBranchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(ratioNumerator.asOutput()); opBuilder.addInput(ratioDenominator.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterDataset.java index e0d8d172db5..75f9b98aa8a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterDataset.java @@ -76,7 +76,7 @@ private FilterDataset(Operation operation) { public static FilterDataset create(Scope scope, Operand<? extends TType> inputDataset, Iterable<Operand<?>> otherArguments, ConcreteFunction predicate, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { - OperationBuilder opBuilder = scope.env().opBuilder("FilterDataset", scope.makeOpName("FilterDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("FilterDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(otherArguments)); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FlatMapDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FlatMapDataset.java index f9022a2e3f2..6e4a48642ce 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FlatMapDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FlatMapDataset.java @@ -74,7 +74,7 @@ private FlatMapDataset(Operation operation) { public static FlatMapDataset create(Scope scope, Operand<? extends TType> inputDataset, Iterable<Operand<?>> otherArguments, ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { - OperationBuilder opBuilder = scope.env().opBuilder("FlatMapDataset", scope.makeOpName("FlatMapDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("FlatMapDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(otherArguments)); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GeneratorDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GeneratorDataset.java index d19c70f72bd..d1f47611ba2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GeneratorDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GeneratorDataset.java @@ -69,7 +69,7 @@ public static GeneratorDataset create(Scope scope, Iterable<Operand<?>> initFunc Iterable<Operand<?>> nextFuncOtherArgs, Iterable<Operand<?>> finalizeFuncOtherArgs, ConcreteFunction initFunc, ConcreteFunction nextFunc, ConcreteFunction finalizeFunc, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { - OperationBuilder opBuilder = scope.env().opBuilder("GeneratorDataset", scope.makeOpName("GeneratorDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("GeneratorDataset")); opBuilder.addInputList(Operands.asOutputs(initFuncOtherArgs)); opBuilder.addInputList(Operands.asOutputs(nextFuncOtherArgs)); opBuilder.addInputList(Operands.asOutputs(finalizeFuncOtherArgs)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByWindowDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByWindowDataset.java index 4e7811fd31d..11478287d00 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByWindowDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByWindowDataset.java @@ -73,7 +73,7 @@ public static GroupByWindowDataset create(Scope scope, Operand<? extends TType> Iterable<Operand<?>> windowSizeFuncOtherArguments, ConcreteFunction keyFunc, ConcreteFunction reduceFunc, ConcreteFunction windowSizeFunc, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { - OperationBuilder opBuilder = scope.env().opBuilder("GroupByWindowDataset", scope.makeOpName("GroupByWindowDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("GroupByWindowDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(keyFuncOtherArguments)); opBuilder.addInputList(Operands.asOutputs(reduceFuncOtherArguments)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InterleaveDataset.java index b8e0b43187e..64f177ec539 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InterleaveDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InterleaveDataset.java @@ -79,7 +79,7 @@ private InterleaveDataset(Operation operation) { public static InterleaveDataset create(Scope scope, Operand<? extends TType> inputDataset, Iterable<Operand<?>> otherArguments, Operand<TInt64> cycleLength, Operand<TInt64> blockLength, ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { - OperationBuilder opBuilder = scope.env().opBuilder("InterleaveDataset", scope.makeOpName("InterleaveDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("InterleaveDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(otherArguments)); opBuilder.addInput(cycleLength.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LoadDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LoadDataset.java index f46937b36f2..95ad802e5ac 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LoadDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LoadDataset.java @@ -67,7 +67,7 @@ private LoadDataset(Operation operation) { public static LoadDataset create(Scope scope, Operand<TString> path, Iterable<Operand<?>> readerFuncOtherArgs, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, ConcreteFunction readerFunc, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("LoadDataset", scope.makeOpName("LoadDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("LoadDataset")); opBuilder.addInput(path.asOutput()); opBuilder.addInputList(Operands.asOutputs(readerFuncOtherArgs)); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MapDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MapDataset.java index 5f4ab8e6776..5232a58cd1d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MapDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MapDataset.java @@ -70,7 +70,7 @@ private MapDataset(Operation operation) { public static MapDataset create(Scope scope, Operand<? extends TType> inputDataset, Iterable<Operand<?>> otherArguments, ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("MapDataset", scope.makeOpName("MapDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("MapDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(otherArguments)); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OneShotIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OneShotIterator.java index 01330924f62..59421497931 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OneShotIterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OneShotIterator.java @@ -84,7 +84,7 @@ private OneShotIterator(Operation operation) { ) public static OneShotIterator create(Scope scope, ConcreteFunction datasetFactory, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("OneShotIterator", scope.makeOpName("OneShotIterator")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("OneShotIterator")); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("dataset_factory", datasetFactory); opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelMapDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelMapDataset.java index b8488d7a1ac..ee9d6ec2c85 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelMapDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelMapDataset.java @@ -71,7 +71,7 @@ private ParallelMapDataset(Operation operation) { public static ParallelMapDataset create(Scope scope, Operand<? extends TType> inputDataset, Iterable<Operand<?>> otherArguments, Operand<TInt64> numParallelCalls, ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("ParallelMapDatasetV2", scope.makeOpName("ParallelMapDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("ParallelMapDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(otherArguments)); opBuilder.addInput(numParallelCalls.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SaveDataset.java index 9d2760b0424..4e5944ac7d9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SaveDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SaveDataset.java @@ -58,7 +58,7 @@ private SaveDataset(Operation operation) { public static SaveDataset create(Scope scope, Operand<? extends TType> inputDataset, Operand<TString> path, Iterable<Operand<?>> shardFuncOtherArgs, ConcreteFunction shardFunc, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("SaveDataset", scope.makeOpName("SaveDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("SaveDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(path.asOutput()); opBuilder.addInputList(Operands.asOutputs(shardFuncOtherArgs)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ScanDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ScanDataset.java index 4ed398fb214..dde6aea2ca5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ScanDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ScanDataset.java @@ -67,7 +67,7 @@ private ScanDataset(Operation operation) { public static ScanDataset create(Scope scope, Operand<? extends TType> inputDataset, Iterable<Operand<?>> initialState, Iterable<Operand<?>> otherArguments, ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("ScanDataset", scope.makeOpName("ScanDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("ScanDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(initialState)); opBuilder.addInputList(Operands.asOutputs(otherArguments)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java index 426020c2f17..5ed0cfa969c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java @@ -76,7 +76,7 @@ public static SnapshotDataset create(Scope scope, Operand<? extends TType> input Iterable<Operand<?>> shardFuncOtherArgs, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, ConcreteFunction readerFunc, ConcreteFunction shardFunc, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("SnapshotDatasetV2", scope.makeOpName("SnapshotDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("SnapshotDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(path.asOutput()); opBuilder.addInputList(Operands.asOutputs(readerFuncOtherArgs)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeWhileDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeWhileDataset.java index 23bc41932d8..3bcf539bd80 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeWhileDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeWhileDataset.java @@ -72,7 +72,7 @@ private TakeWhileDataset(Operation operation) { public static TakeWhileDataset create(Scope scope, Operand<? extends TType> inputDataset, Iterable<Operand<?>> otherArguments, ConcreteFunction predicate, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { - OperationBuilder opBuilder = scope.env().opBuilder("TakeWhileDataset", scope.makeOpName("TakeWhileDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("TakeWhileDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(otherArguments)); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByReducerDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByReducerDataset.java index 822d5e03c51..1b2c46e2de1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByReducerDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByReducerDataset.java @@ -82,7 +82,7 @@ public static GroupByReducerDataset create(Scope scope, Operand<? extends TType> Iterable<Operand<?>> finalizeFuncOtherArguments, ConcreteFunction keyFunc, ConcreteFunction initFunc, ConcreteFunction reduceFunc, ConcreteFunction finalizeFunc, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { - OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalGroupByReducerDataset", scope.makeOpName("GroupByReducerDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("GroupByReducerDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(keyFuncOtherArguments)); opBuilder.addInputList(Operands.asOutputs(initFuncOtherArguments)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByWindowDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByWindowDataset.java index fe598c19cad..043b87f399b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByWindowDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/GroupByWindowDataset.java @@ -73,7 +73,7 @@ public static GroupByWindowDataset create(Scope scope, Operand<? extends TType> Iterable<Operand<?>> windowSizeFuncOtherArguments, ConcreteFunction keyFunc, ConcreteFunction reduceFunc, ConcreteFunction windowSizeFunc, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { - OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalGroupByWindowDataset", scope.makeOpName("GroupByWindowDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("GroupByWindowDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(keyFuncOtherArguments)); opBuilder.addInputList(Operands.asOutputs(reduceFuncOtherArguments)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LegacyParallelInterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LegacyParallelInterleaveDataset.java index 4a33fd1657c..36dd881dc20 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LegacyParallelInterleaveDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LegacyParallelInterleaveDataset.java @@ -82,7 +82,7 @@ public static LegacyParallelInterleaveDataset create(Scope scope, Operand<TInt64> bufferOutputElements, Operand<TInt64> prefetchInputElements, ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("LegacyParallelInterleaveDatasetV2", scope.makeOpName("LegacyParallelInterleaveDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("LegacyParallelInterleaveDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(otherArguments)); opBuilder.addInput(cycleLength.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapAndBatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapAndBatchDataset.java index 07f5a18ac79..3a86ebd3ed3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapAndBatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapAndBatchDataset.java @@ -82,7 +82,7 @@ public static MapAndBatchDataset create(Scope scope, Operand<? extends TType> in Iterable<Operand<?>> otherArguments, Operand<TInt64> batchSize, Operand<TInt64> numParallelCalls, Operand<TBool> dropRemainder, ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalMapAndBatchDataset", scope.makeOpName("MapAndBatchDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("MapAndBatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(otherArguments)); opBuilder.addInput(batchSize.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapDataset.java index 19667eb06bf..efab3bb9ef3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MapDataset.java @@ -66,7 +66,7 @@ private MapDataset(Operation operation) { public static MapDataset create(Scope scope, Operand<? extends TType> inputDataset, Iterable<Operand<?>> otherArguments, ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalMapDataset", scope.makeOpName("MapDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("MapDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(otherArguments)); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParallelInterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParallelInterleaveDataset.java index 5ec8ad9819a..b0475a6d457 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParallelInterleaveDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParallelInterleaveDataset.java @@ -95,7 +95,7 @@ public static ParallelInterleaveDataset create(Scope scope, Operand<? extends TT Operand<TInt64> bufferOutputElements, Operand<TInt64> prefetchInputElements, Operand<TInt64> numParallelCalls, ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("ParallelInterleaveDatasetV4", scope.makeOpName("ParallelInterleaveDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("ParallelInterleaveDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(otherArguments)); opBuilder.addInput(cycleLength.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ScanDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ScanDataset.java index bbe5d3e07af..e60d6024026 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ScanDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ScanDataset.java @@ -67,7 +67,7 @@ private ScanDataset(Operation operation) { public static ScanDataset create(Scope scope, Operand<? extends TType> inputDataset, Iterable<Operand<?>> initialState, Iterable<Operand<?>> otherArguments, ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalScanDataset", scope.makeOpName("ScanDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("ScanDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(initialState)); opBuilder.addInputList(Operands.asOutputs(otherArguments)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/TakeWhileDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/TakeWhileDataset.java index a03264f0f4c..9f283a7eb1f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/TakeWhileDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/TakeWhileDataset.java @@ -72,7 +72,7 @@ private TakeWhileDataset(Operation operation) { public static TakeWhileDataset create(Scope scope, Operand<? extends TType> inputDataset, Iterable<Operand<?>> otherArguments, ConcreteFunction predicate, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { - OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalTakeWhileDataset", scope.makeOpName("TakeWhileDataset")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("TakeWhileDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(otherArguments)); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/Compile.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/Compile.java index 6e01d8f37e8..43725ba9b0b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/Compile.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/Compile.java @@ -95,7 +95,7 @@ private Compile(Operation operation) { public static Compile create(Scope scope, Iterable<Operand<TInt64>> dynamicShapes, Iterable<Operand<?>> guaranteedConstants, Long numComputations, ConcreteFunction function, String metadata) { - OperationBuilder opBuilder = scope.env().opBuilder("TPUCompile", scope.makeOpName("Compile")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("Compile")); opBuilder.addInputList(Operands.asOutputs(dynamicShapes)); opBuilder.addInputList(Operands.asOutputs(guaranteedConstants)); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/PartitionedCall.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/PartitionedCall.java index 90a536c99e2..a4086db1728 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/PartitionedCall.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/PartitionedCall.java @@ -69,7 +69,7 @@ private PartitionedCall(Operation operation) { public static PartitionedCall create(Scope scope, Iterable<Operand<?>> args, Operand<TInt32> deviceOrdinal, List<Class<? extends TType>> Tout, ConcreteFunction f, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("TPUPartitionedCall", scope.makeOpName("PartitionedCall")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("PartitionedCall")); opBuilder.addInputList(Operands.asOutputs(args)); opBuilder.addInput(deviceOrdinal.asOutput()); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SymbolicGradient.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SymbolicGradient.java index be3fccf6f90..e58ce0b2afb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SymbolicGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SymbolicGradient.java @@ -82,7 +82,7 @@ private SymbolicGradient(Operation operation) { ) public static SymbolicGradient create(Scope scope, Iterable<Operand<?>> input, List<Class<? extends TType>> Tout, ConcreteFunction f) { - OperationBuilder opBuilder = scope.env().opBuilder("SymbolicGradient", scope.makeOpName("SymbolicGradient")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("SymbolicGradient")); opBuilder.addInputList(Operands.asOutputs(input)); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("Tout", Operands.toDataTypes(Tout)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/If.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/If.java index 0f35595d405..c0bc7d5001b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/If.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/If.java @@ -74,7 +74,7 @@ private If(Operation operation) { ) public static If create(Scope scope, Operand<? extends TType> cond, Iterable<Operand<?>> inputs, ConcreteFunction thenBranch, ConcreteFunction elseBranch, List<Class<? extends TType>> Tout) { - OperationBuilder opBuilder = scope.env().opBuilder("XlaIf", scope.makeOpName("If")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("If")); opBuilder.addInput(cond.asOutput()); opBuilder.addInputList(Operands.asOutputs(inputs)); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Reduce.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Reduce.java index a0f70a32564..ec39d1def2c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Reduce.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Reduce.java @@ -68,7 +68,7 @@ private Reduce(Operation operation) { ) public static <T extends TType> Reduce<T> create(Scope scope, Operand<T> input, Operand<T> initValue, List<Long> dimensionsToReduce, ConcreteFunction reducer) { - OperationBuilder opBuilder = scope.env().opBuilder("XlaReduce", scope.makeOpName("Reduce")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("Reduce")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(initValue.asOutput()); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/ReduceWindow.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/ReduceWindow.java index 40d94d31a06..7e2a4eaf100 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/ReduceWindow.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/ReduceWindow.java @@ -75,7 +75,7 @@ public static <T extends TType, U extends TNumber> ReduceWindow<T> create(Scope Operand<T> input, Operand<T> initValue, Operand<U> windowDimensions, Operand<U> windowStrides, Operand<U> baseDilations, Operand<U> windowDilations, Operand<U> padding, ConcreteFunction computation) { - OperationBuilder opBuilder = scope.env().opBuilder("XlaReduceWindow", scope.makeOpName("ReduceWindow")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("ReduceWindow")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(initValue.asOutput()); opBuilder.addInput(windowDimensions.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Scatter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Scatter.java index 7148f19f805..9a249b0e862 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Scatter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Scatter.java @@ -73,7 +73,7 @@ private Scatter(Operation operation) { public static <T extends TType> Scatter<T> create(Scope scope, Operand<T> operand, Operand<? extends TNumber> scatterIndices, Operand<T> updates, ConcreteFunction updateComputation, String dimensionNumbers, Boolean indicesAreSorted) { - OperationBuilder opBuilder = scope.env().opBuilder("XlaScatter", scope.makeOpName("Scatter")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("Scatter")); opBuilder.addInput(operand.asOutput()); opBuilder.addInput(scatterIndices.asOutput()); opBuilder.addInput(updates.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SelectAndScatter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SelectAndScatter.java index 6912fd70677..4e3b9753f96 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SelectAndScatter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SelectAndScatter.java @@ -75,7 +75,7 @@ private SelectAndScatter(Operation operation) { public static <T extends TType, U extends TNumber> SelectAndScatter<T> create(Scope scope, Operand<T> operand, Operand<U> windowDimensions, Operand<U> windowStrides, Operand<U> padding, Operand<T> source, Operand<T> initValue, ConcreteFunction select, ConcreteFunction scatter) { - OperationBuilder opBuilder = scope.env().opBuilder("XlaSelectAndScatter", scope.makeOpName("SelectAndScatter")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("SelectAndScatter")); opBuilder.addInput(operand.asOutput()); opBuilder.addInput(windowDimensions.asOutput()); opBuilder.addInput(windowStrides.asOutput()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/While.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/While.java index 985aabed588..61e1e3a1ec6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/While.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/While.java @@ -77,7 +77,7 @@ private While(Operation operation) { ) public static While create(Scope scope, Iterable<Operand<?>> input, ConcreteFunction cond, ConcreteFunction body) { - OperationBuilder opBuilder = scope.env().opBuilder("XlaWhile", scope.makeOpName("While")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("While")); opBuilder.addInputList(Operands.asOutputs(input)); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("cond", cond); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaHostCompute.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaHostCompute.java index c45377d32e1..f442219f338 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaHostCompute.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaHostCompute.java @@ -77,7 +77,7 @@ private XlaHostCompute(Operation operation) { public static XlaHostCompute create(Scope scope, Iterable<Operand<?>> inputs, List<Class<? extends TType>> Toutputs, List<String> ancestors, List<Shape> shapes, ConcreteFunction shapeInferenceGraph, String key, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("XlaHostCompute", scope.makeOpName("XlaHostCompute")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("XlaHostCompute")); opBuilder.addInputList(Operands.asOutputs(inputs)); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("Toutputs", Operands.toDataTypes(Toutputs)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaLaunch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaLaunch.java index da119954460..110b723def4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaLaunch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaLaunch.java @@ -72,7 +72,7 @@ private XlaLaunch(Operation operation) { public static XlaLaunch create(Scope scope, Iterable<Operand<?>> constants, Iterable<Operand<?>> args, Iterable<Operand<? extends TType>> resources, List<Class<? extends TType>> Tresults, ConcreteFunction function) { - OperationBuilder opBuilder = scope.env().opBuilder("XlaLaunch", scope.makeOpName("XlaLaunch")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("XlaLaunch")); opBuilder.addInputList(Operands.asOutputs(constants)); opBuilder.addInputList(Operands.asOutputs(args)); opBuilder.addInputList(Operands.asOutputs(resources)); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicReduce.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicReduce.java index 6ce2ac4ff3e..16c8aeda370 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicReduce.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicReduce.java @@ -75,7 +75,7 @@ private XlaVariadicReduce(Operation operation) { public static <T extends TType> XlaVariadicReduce<T> create(Scope scope, Iterable<Operand<T>> input, Iterable<Operand<T>> initValue, List<Long> dimensionsToReduce, ConcreteFunction reducer) { - OperationBuilder opBuilder = scope.env().opBuilder("XlaVariadicReduce", scope.makeOpName("XlaVariadicReduce")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("XlaVariadicReduce")); opBuilder.addInputList(Operands.asOutputs(input)); opBuilder.addInputList(Operands.asOutputs(initValue)); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java index 303e1dcc296..a752f4baeea 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java @@ -41,6 +41,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.StringJoiner; import javax.lang.model.element.Modifier; import org.tensorflow.Names; import org.tensorflow.proto.framework.ApiDef; @@ -69,6 +70,8 @@ enum RenderMode { OPERAND; } + private static final String OP_NAME_FIELD = "OP_NAME"; + /** The in-progress class builder for the top level op class. */ private final TypeSpec.Builder builder; @@ -96,6 +99,10 @@ enum RenderMode { /** The endpoint being generated in this class. */ private final Endpoint endpoint; + private final StatefulPair statefulPair; + private final boolean isStateSelector; + private final boolean isStateSubclass; + /** * The generated options class, or null if it doesn't have one or {@link #buildOptionsClass()} has * not been ran. @@ -128,7 +135,8 @@ enum RenderMode { String fullPackage, String group, String className, - Endpoint endpoint) { + Endpoint endpoint, + StatefulPair statefulPair) { this.builder = builder; this.op = op; @@ -139,6 +147,9 @@ enum RenderMode { this.group = group; this.className = className; this.endpoint = endpoint; + this.statefulPair = statefulPair; + this.isStateSelector = statefulPair != null && className.equals(statefulPair.selectorClassName); + this.isStateSubclass = statefulPair != null && !isStateSelector; op.getAttrList() .forEach( @@ -206,8 +217,15 @@ private String fullClassName() { /** Build the class. */ void buildClass() { - builder.addModifiers(Modifier.PUBLIC, Modifier.FINAL); - builder.superclass(Names.RawOp); + builder.addModifiers(Modifier.PUBLIC); + if (!isStateSelector) { + builder.addModifiers(Modifier.FINAL); + builder.superclass(Names.RawOp); + } + + if (isStateSubclass) { + builder.addSuperinterface(ClassName.get(fullPackage, statefulPair.selectorClassName)); + } // add class javadocs String summary = parseDocumentation(apiDef.getSummary()); @@ -221,6 +239,15 @@ void buildClass() { builder.addJavadoc("$L", desc + "\n"); } + if (isStateSelector) { + builder.addJavadoc( + "Selects between {@link " + + statefulPair.statefulClassName + + "} and {@link " + + statefulPair.statelessClassName + + "} based on the statefulness of the function arguments."); + } + // add superinterface and set mode if (op.getOutputArgCount() == 1) { ArgDef output = op.getOutputArg(0); @@ -232,11 +259,15 @@ void buildClass() { if (iterable) { mode = RenderMode.LIST_OPERAND; - builder.addSuperinterface( - ParameterizedTypeName.get(ClassName.get(Iterable.class), operandType)); + if (!isStateSubclass) { + builder.addSuperinterface( + ParameterizedTypeName.get(ClassName.get(Iterable.class), operandType)); + } } else { mode = RenderMode.OPERAND; - builder.addSuperinterface(operandType); + if (!isStateSubclass) { + builder.addSuperinterface(operandType); + } } } @@ -276,7 +307,7 @@ void buildClass() { builder.addAnnotation(annotation.build()); } - if (!optionalAttributes.isEmpty()) { + if (!optionalAttributes.isEmpty() && !isStateSubclass) { buildOptionsClass(); } @@ -286,29 +317,31 @@ void buildClass() { buildInterfaceImpl(); } - // add op name field - builder.addField( - FieldSpec.builder( - TypeResolver.STRING, - OP_NAME_FIELD_NAME, - Modifier.PUBLIC, - Modifier.STATIC, - Modifier.FINAL) - .addJavadoc("$L", "The name of this op, as known by TensorFlow core engine") - .initializer("$S", op.getName()) - .build()); + if (!isStateSelector) { + // add op name field + builder.addField( + FieldSpec.builder( + TypeResolver.STRING, + OP_NAME_FIELD, + Modifier.PUBLIC, + Modifier.STATIC, + Modifier.FINAL) + .addJavadoc("$L", "The name of this op, as known by TensorFlow core engine") + .initializer("$S", op.getName()) + .build()); - // add output fields - if (op.getOutputArgCount() > 0) { - for (ArgDef output : op.getOutputArgList()) { - builder.addField( - resolver.typeOf(output).listIfIterable().javaType, - getJavaName(output), - Modifier.PRIVATE); + // add output fields + if (op.getOutputArgCount() > 0) { + for (ArgDef output : op.getOutputArgList()) { + builder.addField( + resolver.typeOf(output).listIfIterable().javaType, + getJavaName(output), + Modifier.PRIVATE); + } } - } - buildConstructor(); + buildConstructor(); + } } /** Add a nested class for Options */ @@ -375,9 +408,14 @@ private void buildOptionsClass() { .build()); } + FieldSpec.Builder field = + FieldSpec.builder(type.classIfGeneric().listIfIterable().javaType, name); + if (!isStateSelector) { + field.addModifiers(Modifier.PRIVATE); + } + // add the field - optionsBuilder.addField( - type.classIfGeneric().listIfIterable().javaType, name, Modifier.PRIVATE); + optionsBuilder.addField(field.build()); } // add a private constructor @@ -470,15 +508,26 @@ private void buildFactoryMethods() { body.addStatement( "$T opBuilder = scope.env().opBuilder($L, scope.makeOpName($S))", Names.OperationBuilder, - OP_NAME_FIELD_NAME, + OP_NAME_FIELD, className); + List<String> functionArgs = new ArrayList<>(); + List<String> iterableFunctionArgs = new ArrayList<>(); + // add the inputs as parameters, and add them to the op builder for (ArgDef input : op.getInputArgList()) { ApiDef.Arg argDef = argApis.get(input); ResolvedType type = resolver.typeOf(input); String name = getJavaName(input); + if (type.javaType.equals(Names.ConcreteFunction)) { + if (type.iterable) { + iterableFunctionArgs.add(name); + } else { + functionArgs.add(name); + } + } + ParameterSpec.Builder param = ParameterSpec.builder(type.iterableIfIterable().javaType, name); String description = argDef.getDescription().isEmpty() @@ -509,11 +558,19 @@ private void buildFactoryMethods() { ResolvedType type = resolver.typeOf(attr); ApiDef.Attr apiAttr = attrApis.get(attr); + String javaName = getJavaName(attr); + + if (type.javaType.equals(Names.ConcreteFunction)) { + if (type.iterable) { + iterableFunctionArgs.add(javaName); + } else { + functionArgs.add(javaName); + } + } ParameterSpec.Builder builder = ParameterSpec.builder(type.classIfGeneric().listIfIterable().javaType, getJavaName(attr)); - String javaName = getJavaName(attr); String description = apiAttr.getDescription().isEmpty() ? String.format("the value of the %s property", javaName) @@ -536,18 +593,26 @@ private void buildFactoryMethods() { writeSetAttr(body, attr, type, false); } + // TODO optional function attrs (there currently aren't any) + // add optional attributes - if (optionsClass != null) { + if (optionsClass != null || (isStateSubclass && statefulPair.hasOptionalAttrs())) { + + ClassName optionsClassName; + if (isStateSubclass) { + optionsClassName = ClassName.get(fullPackage, statefulPair.selectorClassName, "Options"); + } else { + optionsClassName = ClassName.get(fullPackage, className, "Options"); + } + factoryBuilder.addParameter( - ParameterSpec.builder( - ArrayTypeName.of(ClassName.get(fullPackage, className, "Options")), "options") - .build()); + ParameterSpec.builder(ArrayTypeName.of(optionsClassName), "options").build()); paramTags.put("options", CodeBlock.of("$L", "carries optional attribute values")); factoryBuilder.varargs(); body.beginControlFlow("if (options != null)"); - body.beginControlFlow("for (Options opts : options)"); + body.beginControlFlow("for ($T opts : options)", optionsClassName); for (AttrDef attr : optionalAttributes) { String name = getJavaName(attr); body.beginControlFlow("if (opts.$L != null)", name); @@ -564,7 +629,40 @@ private void buildFactoryMethods() { body.addStatement( "return new $L(opBuilder.build())", typeParams.isEmpty() ? className : (className + "<>")); + if (isStateSelector) { + body.clear(); + + body.addStatement("boolean isStateful = false"); + functionArgs.forEach( + arg -> { + body.beginControlFlow("if ($L.isStateful())", arg) + .addStatement("isStateful = true") + .endControlFlow(); + }); + iterableFunctionArgs.forEach( + arg -> { + body.beginControlFlow("if ($L.stream().anyMatch(x -> x.isStateful()))", arg) + .addStatement("isStateful = true") + .endControlFlow(); + }); + + StringJoiner argList = new StringJoiner(", "); + factoryBuilder.parameters.forEach(x -> argList.add(x.name)); + + body.beginControlFlow("if (isStateful)") + .addStatement( + "return $T.create($L)", + ClassName.get(fullPackage, statefulPair.statefulClassName), + argList.toString()) + .nextControlFlow("else") + .addStatement( + "return $T.create($L)", + ClassName.get(fullPackage, statefulPair.statelessClassName), + argList.toString()) + .endControlFlow(); + } factoryBuilder.addCode(body.build()); + paramTags.forEach( (param, doc) -> { String description = doc.toString(); @@ -658,7 +756,10 @@ private void buildSecondaryFactory( factoryBuilder.addJavadoc( "\n@return a new instance of $L, with default output types", className); - factoryBuilder.addCode(body.build()); + + if (!isStateSelector) { + factoryBuilder.addCode(body.build()); + } factoryBuilder.addTypeVariables(typeVars); builder.addMethod(factoryBuilder.build()); @@ -690,15 +791,25 @@ private void buildGettersAndSetters() { for (ArgDef output : op.getOutputArgList()) { String name = getJavaName(output); ApiDef.Arg argDef = argApis.get(output); - builder.addMethod( + + MethodSpec.Builder method = MethodSpec.methodBuilder(name) .addModifiers(Modifier.PUBLIC) .returns(resolver.typeOf(output).listIfIterable().javaType) .addJavadoc("Gets $L.\n", name) .addJavadoc("$L", parseDocumentation(argDef.getDescription())) - .addJavadoc("\n@return $L.", name) - .addCode("return $L;", name) - .build()); + .addJavadoc("\n@return $L.", name); + + if (isStateSelector) { + method.addModifiers(Modifier.ABSTRACT); + } else { + if (isStateSubclass) { + method.addAnnotation(Override.class); + } + method.addCode("return $L;", name); + } + + builder.addMethod(method.build()); } } @@ -718,14 +829,18 @@ private void buildInterfaceImpl() { .returns(outputType) .addAnnotation(Override.class); - if (uncheckedCast) { - asOutput.addAnnotation( - AnnotationSpec.builder(SuppressWarnings.class) - .addMember("value", "$S", "unchecked") - .build()); - asOutput.addCode("return ($T) $L;", outputType, getJavaName(output)); + if (isStateSelector) { + asOutput.addModifiers(Modifier.ABSTRACT); } else { - asOutput.addCode("return $L;", getJavaName(output)); + if (uncheckedCast) { + asOutput.addAnnotation( + AnnotationSpec.builder(SuppressWarnings.class) + .addMember("value", "$S", "unchecked") + .build()); + asOutput.addCode("return ($T) $L;", outputType, getJavaName(output)); + } else { + asOutput.addCode("return $L;", getJavaName(output)); + } } builder.addMethod(asOutput.build()); @@ -743,7 +858,11 @@ private void buildInterfaceImpl() { .addMember("value", "{$S, $S}", "rawtypes", "unchecked") .build()); - iterator.addCode("return ($T) $L.iterator();", Iterator.class, getJavaName(output)); + if (isStateSelector) { + iterator.addModifiers(Modifier.ABSTRACT); + } else { + iterator.addCode("return ($T) $L.iterator();", Iterator.class, getJavaName(output)); + } builder.addMethod(iterator.build()); } diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java index bea598ca37a..6611eab61e7 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java @@ -64,17 +64,20 @@ public boolean isStateVariant(FullOpDef other) { OpDef copy = opDef.toBuilder().setName(other.opDef.getName()).setIsStateful(other.isStateful()).build(); - return copy.equals(other.opDef) && packageName.equals(other.packageName); + return copy.equals(other.opDef) + && packageName.equals(other.packageName) + && group.equals(other.group); } public TypeSpec buildOpClass() { - return buildOpClass(className); + return buildOpClass(className, null); } - public TypeSpec buildOpClass(String className) { + TypeSpec buildOpClass(String className, StatefulPair pair) { TypeSpec.Builder cls = TypeSpec.classBuilder(className); try { - new ClassGenerator(cls, opDef, apiDef, basePackage, packageName, group, className, endpoint) + new ClassGenerator( + cls, opDef, apiDef, basePackage, packageName, group, className, endpoint, pair) .buildClass(); } catch (Exception e) { throw new IllegalStateException("Failed to generate class for op " + opDef.getName(), e); diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java index da82dd237e9..f833d8001cb 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java @@ -71,14 +71,48 @@ public static List<StatefulPair> extractStatefulPairs(List<FullOpDef> ops) { } public String getPackageName() { + // guaranteed to be the same in FullOpDef.isStateVariant return statefulOp.packageName; } + public String getGroup() { + // guaranteed to be the same in FullOpDef.isStateVariant + return statefulOp.group; + } + + public boolean hasOptionalAttrs() { + return statefulOp.opDef.getAttrList().stream() + .anyMatch(attr -> attr.hasDefaultValue() && !attr.getType().contains("type")); + } + public List<TypeSpec> buildOpClasses() { - TypeSpec stateful = statefulOp.buildOpClass(statefulClassName); - TypeSpec stateless = statelessOp.buildOpClass(statelessClassName); - return Arrays.asList(stateful, stateless); + TypeSpec.Builder selector = TypeSpec.interfaceBuilder(selectorClassName); + try { + new ClassGenerator( + selector, + statefulOp.opDef, + statefulOp.apiDef, + statefulOp.basePackage, + getPackageName(), + getGroup(), + selectorClassName, + statefulOp.endpoint, + this) + .buildClass(); + } catch (Exception e) { + throw new IllegalStateException( + "Failed to generate statefulness selector class for ops " + + statefulOp.opDef.getName() + + " and " + + statelessOp.opDef.getName(), + e); + } + + TypeSpec stateful = statefulOp.buildOpClass(statefulClassName, this); + TypeSpec stateless = statelessOp.buildOpClass(statelessClassName, this); + + return Arrays.asList(stateful, stateless, selector.build()); } @Override From ec92d15d4f8587903f21f75cb19c3b3a3f9fe6fb Mon Sep 17 00:00:00 2001 From: Ryan Nett <JNett96@gmail.com> Date: Mon, 31 May 2021 19:11:03 -0700 Subject: [PATCH 07/14] Test for wrappers, using If Signed-off-by: Ryan Nett <JNett96@gmail.com> --- .../annotations/org/tensorflow/op/Ops.java | 12 +- .../gen/java/org/tensorflow/op/core/Case.java | 3 +- .../gen/java/org/tensorflow/op/core/If.java | 3 +- .../tensorflow/op/core/PartitionedCall.java | 3 +- .../java/org/tensorflow/op/core/While.java | 3 +- .../java/org/tensorflow/op/core/IfTest.java | 173 ++++++++++++++++++ .../op/generator/ClassGenerator.java | 1 + 7 files changed, 190 insertions(+), 8 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 7b6aa4a1679..253e935ab48 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -1234,7 +1234,8 @@ public Map<String, Operand<?>> call(ConcreteFunction function, * } * ``` * </pre> - * Selects between {@link StatefulCase} and {@link StatelessCase} based on the statefulness of the function arguments. + * + * <p>Selects between {@link StatefulCase} and {@link StatelessCase} based on the statefulness of the function arguments. * * @param branchIndex The branch selector, an int32 Tensor. * @param input A list of input tensors passed to the branch function. @@ -2951,7 +2952,8 @@ public IdentityN identityN(Iterable<Operand<?>> input) { /** * output = cond ? then_branch(input) : else_branch(input) - * Selects between {@link StatefulIf} and {@link StatelessIf} based on the statefulness of the function arguments. + * + * <p>Selects between {@link StatefulIf} and {@link StatelessIf} based on the statefulness of the function arguments. * * @param cond <pre> * A Tensor. If the tensor is a scalar of non-boolean type, the @@ -4021,7 +4023,8 @@ public <T extends TType> ParallelDynamicStitch<T> parallelDynamicStitch( /** * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned. - * Selects between {@link StatefulPartitionedCall} and {@link StatelessPartitionedCall} based on the statefulness of the function arguments. + * + * <p>Selects between {@link StatefulPartitionedCall} and {@link StatelessPartitionedCall} based on the statefulness of the function arguments. * * @param args A list of input tensors. * @param Tout A list of output types. @@ -8108,7 +8111,8 @@ public Where where(Operand<? extends TType> condition) { /** * output = input; While (Cond(output)) { output = Body(output) } - * Selects between {@link StatefulWhile} and {@link StatelessWhile} based on the statefulness of the function arguments. + * + * <p>Selects between {@link StatefulWhile} and {@link StatelessWhile} based on the statefulness of the function arguments. * * @param input A list of input tensors whose types are T. * @param cond <pre> diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java index 46e2406bcdf..b137e28f283 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Case.java @@ -50,7 +50,8 @@ * } * ``` * </pre> - * Selects between {@link StatefulCase} and {@link StatelessCase} based on the statefulness of the function arguments. + * + * <p>Selects between {@link StatefulCase} and {@link StatelessCase} based on the statefulness of the function arguments. */ @Operator public interface Case extends Iterable<Operand<TType>> { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java index 82134d19559..2ba19c25f25 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/If.java @@ -31,7 +31,8 @@ /** * output = cond ? then_branch(input) : else_branch(input) - * Selects between {@link StatefulIf} and {@link StatelessIf} based on the statefulness of the function arguments. + * + * <p>Selects between {@link StatefulIf} and {@link StatelessIf} based on the statefulness of the function arguments. */ @Operator public interface If extends Iterable<Operand<TType>> { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java index 42ccc7168b0..7deb23131c8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/PartitionedCall.java @@ -29,7 +29,8 @@ /** * returns {@code f(inputs)}, where {@code f}'s body is placed and partitioned. - * Selects between {@link StatefulPartitionedCall} and {@link StatelessPartitionedCall} based on the statefulness of the function arguments. + * + * <p>Selects between {@link StatefulPartitionedCall} and {@link StatelessPartitionedCall} based on the statefulness of the function arguments. */ @Operator public interface PartitionedCall extends Iterable<Operand<TType>> { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java index 279683e3806..714929f426c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/While.java @@ -31,7 +31,8 @@ /** * output = input; While (Cond(output)) { output = Body(output) } - * Selects between {@link StatefulWhile} and {@link StatelessWhile} based on the statefulness of the function arguments. + * + * <p>Selects between {@link StatefulWhile} and {@link StatelessWhile} based on the statefulness of the function arguments. */ @Operator public interface While extends Iterable<Operand<TType>> { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java new file mode 100644 index 00000000000..b2911d2eec0 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java @@ -0,0 +1,173 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.op.core; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import java.util.Collections; +import org.junit.jupiter.api.Test; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.EagerSession; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.Signature; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; + +public class IfTest { + + private static Operand<TInt32> basicIf(Ops tf, Operand<TInt32> a, Operand<TInt32> b) { + ConcreteFunction thenBranch = + ConcreteFunction.create( + (ops) -> { + Operand<TInt32> a1 = ops.placeholder(TInt32.class); + Operand<TInt32> b1 = ops.placeholder(TInt32.class); + return Signature.builder().input("a", a1).input("b", b1).output("y", a1).build(); + }); + + ConcreteFunction elseBranch = + ConcreteFunction.create( + (ops) -> { + Operand<TInt32> a1 = ops.placeholder(TInt32.class); + Operand<TInt32> b1 = ops.placeholder(TInt32.class); + Operand<TInt32> y = ops.math.neg(b1); + return Signature.builder().input("a", a1).input("b", b1).output("y", y).build(); + }); + + return (Operand<TInt32>) + tf.ifOp( + tf.math.greater(a, b), + Arrays.asList(a, b), + Arrays.asList(TInt32.class), + thenBranch, + elseBranch) + .output() + .get(0); + } + + @Test + public void testGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + Ops tf = Ops.create(g); + Operand<TInt32> a = tf.placeholder(TInt32.class); + Operand<TInt32> b = tf.placeholder(TInt32.class); + Operand<TInt32> c = basicIf(tf, a, b); + + assertEquals(StatelessIf.OP_NAME, c.op().type()); + + try (TInt32 out = + (TInt32) + s.runner() + .feed(a, TInt32.scalarOf(2)) + .feed(b, TInt32.scalarOf(1)) + .fetch(c) + .run() + .get(0)) { + assertEquals(2, out.getInt()); + } + + try (TInt32 out = + (TInt32) + s.runner() + .feed(a, TInt32.scalarOf(2)) + .feed(b, TInt32.scalarOf(3)) + .fetch(c) + .run() + .get(0)) { + assertEquals(-3, out.getInt()); + } + } + } + + @Test + public void testStatefullness() { + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Operand<TInt32> a = tf.placeholder(TInt32.class); + Operand<TInt32> b = tf.placeholder(TInt32.class); + + ConcreteFunction thenBranch = + ConcreteFunction.create( + (ops) -> { + Operand<TInt32> a1 = ops.placeholder(TInt32.class); + Operand<TInt32> b1 = ops.placeholder(TInt32.class); + Operand<TInt32> result = + (Operand<TInt32>) + ops.statefulIf( + ops.constant(false), + Collections.emptyList(), + Arrays.asList(TInt32.class), + ConcreteFunction.create( + (ops1) -> + Signature.builder().output("y", ops1.constant(1)).build()), + ConcreteFunction.create( + (ops1) -> + Signature.builder().output("y", ops1.constant(1)).build())) + .output() + .get(0); + return Signature.builder() + .input("a", a1) + .input("b", b1) + .output("y", result) + .build(); + }); + + ConcreteFunction elseBranch = + ConcreteFunction.create( + (ops) -> { + Operand<TInt32> a1 = ops.placeholder(TInt32.class); + Operand<TInt32> b1 = ops.placeholder(TInt32.class); + Operand<TInt32> y = ops.math.neg(b1); + return Signature.builder().input("a", a1).input("b", b1).output("y", y).build(); + }); + + Operand<TInt32> output = + (Operand<TInt32>) + tf.ifOp( + tf.math.greater(a, b), + Arrays.asList(a, b), + Arrays.asList(TInt32.class), + thenBranch, + elseBranch) + .output() + .get(0); + + assertEquals(StatefulIf.OP_NAME, output.op().type()); + } + } + + @Test + public void testEager() { + try (EagerSession e = EagerSession.create()) { + Ops tf = Ops.create(e); + + Operand<TInt32> out1 = basicIf(tf, tf.constant(2), tf.constant(1)); + + assertEquals(StatelessIf.OP_NAME, out1.op().type()); + + try (TInt32 out = out1.asTensor()) { + assertEquals(2, out.getInt()); + } + + try (TInt32 out = basicIf(tf, tf.constant(2), tf.constant(3)).asTensor()) { + assertEquals(-3, out.getInt()); + } + } + } +} diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java index a752f4baeea..0d3546655a3 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java @@ -240,6 +240,7 @@ void buildClass() { } if (isStateSelector) { + builder.addJavadoc("\n<p>"); builder.addJavadoc( "Selects between {@link " + statefulPair.statefulClassName From d2b8be0c4ff057155e7a58796c74a474be81480b Mon Sep 17 00:00:00 2001 From: Ryan Nett <JNett96@gmail.com> Date: Mon, 31 May 2021 19:19:31 -0700 Subject: [PATCH 08/14] Use generated op names in ConcreteFunction Signed-off-by: Ryan Nett <JNett96@gmail.com> --- .../main/java/org/tensorflow/ConcreteFunction.java | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 3e264e0e25d..73932a8f995 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -42,6 +42,8 @@ import org.tensorflow.op.Scope; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.PlaceholderWithDefault; +import org.tensorflow.op.core.StatefulPartitionedCall; +import org.tensorflow.op.core.StatelessPartitionedCall; import org.tensorflow.proto.framework.AttrValue; import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.FunctionDef; @@ -207,11 +209,6 @@ public String toString() { return signature.toString(); } - // TODO migrate to the actual ops once they are generated - public static final String CALL_OP = "PartitionedCall"; - // TODO migrate to the actual ops once they are generated - public static final String STATEFUL_CALL_OP = "StatefulPartitionedCall"; - /** * Calls the function in an execution environment, adding its graph as a function if it isn't * already present. The inputs and outputs are keyed by the names set in the {@code Signature}. @@ -255,7 +252,9 @@ public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> argumen OperationBuilder opBuilder = scope .env() - .opBuilder(isStateful() ? STATEFUL_CALL_OP : CALL_OP, scope.makeOpName(displayName)); + .opBuilder( + isStateful() ? StatefulPartitionedCall.OP_NAME : StatelessPartitionedCall.OP_NAME, + scope.makeOpName(displayName)); opBuilder.addInputList(inputs); From add3c8036a69fb4aef05db45e4c1bb4cc0c7faeb Mon Sep 17 00:00:00 2001 From: Ryan Nett <JNett96@gmail.com> Date: Mon, 31 May 2021 19:37:28 -0700 Subject: [PATCH 09/14] Better dependency resolution Signed-off-by: Ryan Nett <JNett96@gmail.com> --- .../java/org/tensorflow/NativeFunction.java | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java index faab6dbca7b..967fec33d88 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java @@ -76,10 +76,21 @@ public synchronized List<String> getDependencies() { Set<String> deps = new LinkedHashSet<>(); for (NodeDef node : getFunctionDef().getNodeDefList()) { - if (node.getOp().equals(ConcreteFunction.CALL_OP) - || node.getOp().equals(ConcreteFunction.STATEFUL_CALL_OP)) { - deps.add(node.getAttrMap().get("f").getFunc().getName()); - } + node.getAttrMap() + .values() + .forEach( + (attr) -> { + if (attr.hasFunc()) { + deps.add(attr.getFunc().getName()); + } else if (attr.hasList()) { + attr.getList() + .getFuncList() + .forEach( + funcs -> { + deps.add(funcs.getName()); + }); + } + }); } dependencies = Collections.unmodifiableList(new ArrayList<>(deps)); } From 272af6a594bf99d6c162fa38e3012e46311e2755 Mon Sep 17 00:00:00 2001 From: Ryan Nett <JNett96@gmail.com> Date: Mon, 31 May 2021 19:46:02 -0700 Subject: [PATCH 10/14] PartitionedCall test Signed-off-by: Ryan Nett <JNett96@gmail.com> --- .../op/core/PartitionedCallTest.java | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/PartitionedCallTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/PartitionedCallTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/PartitionedCallTest.java new file mode 100644 index 00000000000..865a55f4ac9 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/PartitionedCallTest.java @@ -0,0 +1,68 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.op.core; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import org.junit.jupiter.api.Test; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.EagerSession; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.Signature; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; + +public class PartitionedCallTest { + + public static Signature plusTwo(Ops tf) { + Operand<TInt32> x = tf.placeholder(TInt32.class); + Operand<TInt32> y = tf.math.add(x, tf.constant(2)); + return Signature.builder().input("x", x).output("y", y).build(); + } + + @Test + public void testEager() { + try (EagerSession e = EagerSession.create(); + ConcreteFunction f = ConcreteFunction.create(PartitionedCallTest::plusTwo)) { + Ops tf = Ops.create(e); + Operand<TInt32> x = tf.constant(3); + Operand<TInt32> y = + (Operand<TInt32>) + tf.partitionedCall(Arrays.asList(x), Arrays.asList(TInt32.class), f).output().get(0); + assertEquals(5, y.asTensor().getInt()); + } + } + + @Test + public void testGraph() { + try (Graph g = new Graph(); + ConcreteFunction f = ConcreteFunction.create(PartitionedCallTest::plusTwo)) { + Ops tf = Ops.create(g); + Operand<TInt32> x = tf.placeholder(TInt32.class); + Operand<TInt32> y = + (Operand<TInt32>) + tf.partitionedCall(Arrays.asList(x), Arrays.asList(TInt32.class), f).output().get(0); + + try (Session s = new Session(g); + TInt32 out = (TInt32) s.runner().feed(x, TInt32.scalarOf(3)).fetch(y).run().get(0)) { + assertEquals(5, out.getInt()); + } + } + } +} From 6845d9fd83a09140f9d29f47a3075d60890d8f6c Mon Sep 17 00:00:00 2001 From: Ryan Nett <JNett96@gmail.com> Date: Mon, 31 May 2021 20:02:37 -0700 Subject: [PATCH 11/14] PartitionedCall to do calls Signed-off-by: Ryan Nett <JNett96@gmail.com> --- .../java/org/tensorflow/ConcreteFunction.java | 50 ++++++------------- 1 file changed, 14 insertions(+), 36 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 73932a8f995..f59604d7528 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -38,12 +38,12 @@ import org.tensorflow.internal.c_api.TF_Operation; import org.tensorflow.internal.c_api.TF_Output; import org.tensorflow.internal.c_api.TF_Status; +import org.tensorflow.internal.types.registry.TensorTypeRegistry; import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; +import org.tensorflow.op.core.PartitionedCall; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.PlaceholderWithDefault; -import org.tensorflow.op.core.StatefulPartitionedCall; -import org.tensorflow.op.core.StatelessPartitionedCall; import org.tensorflow.proto.framework.AttrValue; import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.FunctionDef; @@ -218,11 +218,8 @@ public String toString() { * @return the outputs of the function */ public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> arguments) { - List<Operand<?>> inputList = new ArrayList<>(); + List<Operand<?>> inputList = new ArrayList<>(signature.inputNames().size()); - Output<?>[] inputs = new Output<?>[signature().inputNames().size()]; - - int i = 0; for (String inputName : signature().inputNames()) { if (!arguments.containsKey(inputName)) { throw new IllegalArgumentException( @@ -240,42 +237,23 @@ public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> argumen + inputName + "\" was null."); } - inputs[i] = input.asOutput(); - i++; + inputList.add(input); } - scope.env().attachFunction(this); - String name = getDefinedName(); - - String displayName = Scope.isValidOpName(name) ? name : "FunctionCall"; - - OperationBuilder opBuilder = - scope - .env() - .opBuilder( - isStateful() ? StatefulPartitionedCall.OP_NAME : StatelessPartitionedCall.OP_NAME, - scope.makeOpName(displayName)); - - opBuilder.addInputList(inputs); - - opBuilder.setAttr("f", this); - opBuilder.setAttr("Tin", inputDtypes); - opBuilder.setAttr("Tout", outputDtypes); - - opBuilder = scope.apply(opBuilder); - Operation op = opBuilder.build(); - - int numOutputs1 = op.numOutputs(); - List<Operand<?>> outputList = new ArrayList<>(signature().outputNames().size()); - - for (i = 0; i < numOutputs1; i++) { - outputList.add(op.output(i)); - } + List<Output<?>> outputList = + PartitionedCall.create( + scope, + inputList, + Arrays.stream(inputDtypes) + .map(x -> TensorTypeRegistry.find(x).type()) + .collect(Collectors.toList()), + this) + .output(); Map<String, Operand<?>> namedOutputs = new LinkedHashMap<>(signature().outputNames().size()); List<String> outputNames = new ArrayList<>(signature().outputNames()); - for (i = 0; i < outputNames.size(); i++) { + for (int i = 0; i < outputNames.size(); i++) { String outputName = outputNames.get(i); if (i > outputList.size()) { From 69d46003192cfa90834cd024d04609a2fa8e2d22 Mon Sep 17 00:00:00 2001 From: Ryan Nett <JNett96@gmail.com> Date: Tue, 15 Jun 2021 19:28:41 -0700 Subject: [PATCH 12/14] Rebase fixes Signed-off-by: Ryan Nett <JNett96@gmail.com> --- pom.xml | 14 +- .../annotations/org/tensorflow/op/Ops.java | 43 ++--- .../annotations/org/tensorflow/op/XlaOps.java | 3 +- .../tensorflow/op/core/XlaVariadicSort.java | 101 ++++++++++ .../tensorflow/op/data/SnapshotDataset.java | 52 +++++ .../org/tensorflow/op/risc/RiscCondition.java | 95 +++++++++ .../org/tensorflow/op/risc/RiscWhile.java | 181 ++++++++++++++++++ .../org/tensorflow/op/xla/XlaHostCompute.java | 52 +++++ .../tensorflow/op/xla/XlaVariadicReduce.java | 3 +- .../java/org/tensorflow/ConcreteFunction.java | 26 +-- .../java/org/tensorflow/NativeFunction.java | 26 +-- .../java/org/tensorflow/op/core/IfTest.java | 26 +-- .../op/core/PartitionedCallTest.java | 22 +-- .../op/generator/ClassGenerator.java | 26 +-- .../tensorflow/op/generator/FullOpDef.java | 26 +-- .../op/generator/GeneratorUtils.java | 22 +-- .../tensorflow/op/generator/OpGenerator.java | 26 +-- .../tensorflow/op/generator/StatefulPair.java | 26 +-- 18 files changed, 626 insertions(+), 144 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaVariadicSort.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/risc/RiscCondition.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/risc/RiscWhile.java diff --git a/pom.xml b/pom.xml index faf005732dd..c76ba81cf3f 100644 --- a/pom.xml +++ b/pom.xml @@ -237,11 +237,10 @@ </profile> <!-- - Profile that enables format checks on compilation - Can auto-format using spotless:apply. + Profile to run spotless:apply on builds. Will run before format's check. --> <profile> - <id>check-format</id> + <id>apply-format</id> <build> <plugins> <plugin> @@ -255,7 +254,7 @@ <id>spotless-apply</id> <phase>initialize</phase> <goals> - <goal>check</goal> + <goal>apply</goal> </goals> </execution> </executions> @@ -265,10 +264,11 @@ </profile> <!-- - Profile to run spotless:apply on builds. Will run before format's check. + Profile that enables format checks on compilation + Can auto-format using spotless:apply. --> <profile> - <id>apply-format</id> + <id>check-format</id> <build> <plugins> <plugin> @@ -282,7 +282,7 @@ <id>spotless-check</id> <phase>initialize</phase> <goals> - <goal>apply</goal> + <goal>check</goal> </goals> </execution> </executions> diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 253e935ab48..dcf86da09a8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -172,7 +172,6 @@ import org.tensorflow.op.core.RefSelect; import org.tensorflow.op.core.RefSwitch; import org.tensorflow.op.core.RemoteCall; -import org.tensorflow.op.core.RemoteFusedGraphExecute; import org.tensorflow.op.core.Reshape; import org.tensorflow.op.core.ResourceCountUpTo; import org.tensorflow.op.core.ResourceGather; @@ -301,6 +300,7 @@ import org.tensorflow.op.core.XlaSetDynamicDimensionSize; import org.tensorflow.op.core.XlaSpmdFullToShardShape; import org.tensorflow.op.core.XlaSpmdShardToFullShape; +import org.tensorflow.op.core.XlaVariadicSort; import org.tensorflow.op.core.Zeros; import org.tensorflow.op.core.ZerosLike; import org.tensorflow.types.TBool; @@ -4356,27 +4356,6 @@ public RemoteCall remoteCall(Operand<TString> target, Iterable<Operand<?>> args, return RemoteCall.create(scope, target, args, Tout, f); } - /** - * Execute a sub graph on a remote processor. - * The graph specifications(such as graph itself, input tensors and output names) - * are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo - * as serialized_remote_fused_graph_execute_info. - * The specifications will be passed to a dedicated registered - * remote fused graph executor. The executor will send the graph specifications - * to a remote processor and execute that graph. The execution results - * will be passed to consumer nodes as outputs of this node. - * - * @param inputs Arbitrary number of tensors with arbitrary data types - * @param Toutputs the value of the Toutputs property - * @param serializedRemoteFusedGraphExecuteInfo Serialized protocol buffer - * of RemoteFusedGraphExecuteInfo which contains graph specifications. - * @return a new instance of RemoteFusedGraphExecute - */ - public RemoteFusedGraphExecute remoteFusedGraphExecute(Iterable<Operand<?>> inputs, - List<Class<? extends TType>> Toutputs, String serializedRemoteFusedGraphExecuteInfo) { - return RemoteFusedGraphExecute.create(scope, inputs, Toutputs, serializedRemoteFusedGraphExecuteInfo); - } - /** * Reshapes a tensor. * Given {@code tensor}, this operation returns a tensor that has the same values @@ -8239,6 +8218,26 @@ public <T extends TType> XlaSpmdShardToFullShape<T> xlaSpmdShardToFullShape(Oper return XlaSpmdShardToFullShape.create(scope, input, manualSharding, fullShape); } + /** + * Wraps the XLA Sort operator, documented at + * https://www.tensorflow.org/performance/xla/operation_semantics#sort + * . + * <p>Sorts one or more tensors, with support for custom comparator, dimension, and + * is_stable attributes. + * + * @param inputs A list of {@code Tensor} of identical shape but possibly different types. + * @param dimension The dimension along which to sort. Must be a compile-time constant. + * @param comparator A comparator function to apply to 2*N scalars and returning a + * boolean. N is the number of sort inputs. If you want to sort in ascending + * order then the comparator should perform a less-than comparison. + * @param isStable Whether to use stable sort. + * @return a new instance of XlaVariadicSort + */ + public XlaVariadicSort xlaVariadicSort(Iterable<Operand<?>> inputs, Operand<TInt32> dimension, + ConcreteFunction comparator, Boolean isStable) { + return XlaVariadicSort.create(scope, inputs, dimension, comparator, isStable); + } + /** * Creates a zeroed tensor given its type and shape. * diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java index ba43401e0fe..16298aebaca 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java @@ -607,7 +607,8 @@ public XlaSetBound xlaSetBound(Operand<TInt32> input, Operand<TInt32> bound) { } /** - * Wraps the variadic XLA Reduce operator, documented at + * Wraps the variadic XLA Reduce operator. + * Semantics are documented at * https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce. * * @param <T> data type for {@code output} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaVariadicSort.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaVariadicSort.java new file mode 100644 index 00000000000..5a4eafda539 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaVariadicSort.java @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; + +/** + * Wraps the XLA Sort operator, documented at + * https://www.tensorflow.org/performance/xla/operation_semantics#sort + * . + * <p>Sorts one or more tensors, with support for custom comparator, dimension, and + * is_stable attributes. + */ +@Operator +public final class XlaVariadicSort extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaVariadicSort"; + + private List<Output<?>> outputs; + + @SuppressWarnings("unchecked") + private XlaVariadicSort(Operation operation) { + super(operation); + int outputIdx = 0; + int outputsLength = operation.outputListLength("outputs"); + outputs = Arrays.asList(operation.outputList(outputIdx, outputsLength)); + outputIdx += outputsLength; + } + + /** + * Factory method to create a class wrapping a new XlaVariadicSort operation. + * + * @param scope current scope + * @param inputs A list of {@code Tensor} of identical shape but possibly different types. + * @param dimension The dimension along which to sort. Must be a compile-time constant. + * @param comparator A comparator function to apply to 2*N scalars and returning a + * boolean. N is the number of sort inputs. If you want to sort in ascending + * order then the comparator should perform a less-than comparison. + * @param isStable Whether to use stable sort. + * @return a new instance of XlaVariadicSort + */ + @Endpoint( + describeByClass = true + ) + public static XlaVariadicSort create(Scope scope, Iterable<Operand<?>> inputs, + Operand<TInt32> dimension, ConcreteFunction comparator, Boolean isStable) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("XlaVariadicSort")); + opBuilder.addInputList(Operands.asOutputs(inputs)); + opBuilder.addInput(dimension.asOutput()); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("comparator", comparator); + opBuilder.setAttr("is_stable", isStable); + return new XlaVariadicSort(opBuilder.build()); + } + + /** + * Gets outputs. + * A list of {@code Tensor} of same shape and types as the {@code input}. + * @return outputs. + */ + public List<Output<?>> outputs() { + return outputs; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) outputs.iterator(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java index 5ed0cfa969c..3fc7bd3ec93 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java @@ -101,6 +101,12 @@ public static SnapshotDataset create(Scope scope, Operand<? extends TType> input if (opts.writerPrefix != null) { opBuilder.setAttr("writer_prefix", opts.writerPrefix); } + if (opts.hashValid != null) { + opBuilder.setAttr("hash_valid", opts.hashValid); + } + if (opts.hash != null) { + opBuilder.setAttr("hash", opts.hash); + } } } return new SnapshotDataset(opBuilder.build()); @@ -136,6 +142,26 @@ public static Options writerPrefix(String writerPrefix) { return new Options().writerPrefix(writerPrefix); } + /** + * Sets the hashValid option. + * + * @param hashValid the hashValid option + * @return this Options instance. + */ + public static Options hashValid(Boolean hashValid) { + return new Options().hashValid(hashValid); + } + + /** + * Sets the hash option. + * + * @param hash the hash option + * @return this Options instance. + */ + public static Options hash(Long hash) { + return new Options().hash(hash); + } + /** * Gets handle. * @@ -161,6 +187,10 @@ public static class Options { private String writerPrefix; + private Boolean hashValid; + + private Long hash; + private Options() { } @@ -196,5 +226,27 @@ public Options writerPrefix(String writerPrefix) { this.writerPrefix = writerPrefix; return this; } + + /** + * Sets the hashValid option. + * + * @param hashValid the hashValid option + * @return this Options instance. + */ + public Options hashValid(Boolean hashValid) { + this.hashValid = hashValid; + return this; + } + + /** + * Sets the hash option. + * + * @param hash the hash option + * @return this Options instance. + */ + public Options hash(Long hash) { + this.hash = hash; + return this; + } } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/risc/RiscCondition.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/risc/RiscCondition.java new file mode 100644 index 00000000000..f6c37680a15 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/risc/RiscCondition.java @@ -0,0 +1,95 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.risc; + +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TNumber; + +/** + * The RiscCondition operation + * + * @param <U> data type for {@code output} output + */ +public final class RiscCondition<U extends TNumber> extends RawOp implements Operand<U> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "RiscCondition"; + + private Output<U> output; + + private RiscCondition(Operation operation) { + super(operation); + int outputIdx = 0; + output = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new RiscCondition operation. + * + * @param scope current scope + * @param pred the pred value + * @param inputTrue the inputTrue value + * @param inputFalse the inputFalse value + * @param funcTrue the value of the funcTrue property + * @param funcFalse the value of the funcFalse property + * @param DstT the value of the DstT property + * @param <U> data type for {@code RiscCondition} output and operands + * @param <T> data type for {@code RiscCondition} output and operands + * @return a new instance of RiscCondition + */ + @Endpoint( + describeByClass = true + ) + public static <U extends TNumber, T extends TNumber> RiscCondition<U> create(Scope scope, + Operand<TBool> pred, Operand<T> inputTrue, Operand<T> inputFalse, ConcreteFunction funcTrue, + ConcreteFunction funcFalse, Class<U> DstT) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("RiscCondition")); + opBuilder.addInput(pred.asOutput()); + opBuilder.addInput(inputTrue.asOutput()); + opBuilder.addInput(inputFalse.asOutput()); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("func_true", funcTrue); + opBuilder.setAttr("func_false", funcFalse); + opBuilder.setAttr("DstT", Operands.toDataType(DstT)); + return new RiscCondition<>(opBuilder.build()); + } + + /** + * Gets output. + * + * @return output. + */ + public Output<U> output() { + return output; + } + + @Override + public Output<U> asOutput() { + return output; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/risc/RiscWhile.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/risc/RiscWhile.java new file mode 100644 index 00000000000..328f098c70d --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/risc/RiscWhile.java @@ -0,0 +1,181 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.risc; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.family.TType; + +/** + * The RiscWhile operation + */ +public final class RiscWhile extends RawOp implements Iterable<Operand<TType>> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "RiscWhile"; + + private List<Output<?>> output; + + @SuppressWarnings("unchecked") + private RiscWhile(Operation operation) { + super(operation); + int outputIdx = 0; + int outputLength = operation.outputListLength("output"); + output = Arrays.asList(operation.outputList(outputIdx, outputLength)); + outputIdx += outputLength; + } + + /** + * Factory method to create a class wrapping a new RiscWhile operation. + * + * @param scope current scope + * @param input the input value + * @param cond the value of the cond property + * @param body the value of the body property + * @param options carries optional attribute values + * @return a new instance of RiscWhile + */ + @Endpoint( + describeByClass = true + ) + public static RiscWhile create(Scope scope, Iterable<Operand<?>> input, ConcreteFunction cond, + ConcreteFunction body, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("RiscWhile")); + opBuilder.addInputList(Operands.asOutputs(input)); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("cond", cond); + opBuilder.setAttr("body", body); + if (options != null) { + for (Options opts : options) { + if (opts.outputShapes != null) { + Shape[] outputShapesArray = new Shape[opts.outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = opts.outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + } + if (opts.parallelIterations != null) { + opBuilder.setAttr("parallel_iterations", opts.parallelIterations); + } + } + } + return new RiscWhile(opBuilder.build()); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public static Options outputShapes(List<Shape> outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public static Options outputShapes(Shape[] outputShapes) { + return new Options().outputShapes(outputShapes); + } + + /** + * Sets the parallelIterations option. + * + * @param parallelIterations the parallelIterations option + * @return this Options instance. + */ + public static Options parallelIterations(Long parallelIterations) { + return new Options().parallelIterations(parallelIterations); + } + + /** + * Gets output. + * + * @return output. + */ + public List<Output<?>> output() { + return output; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<TType>> iterator() { + return (Iterator) output.iterator(); + } + + /** + * Optional attributes for {@link org.tensorflow.op.risc.RiscWhile} + */ + public static class Options { + private List<Shape> outputShapes; + + private Long parallelIterations; + + private Options() { + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(List<Shape> outputShapes) { + this.outputShapes = outputShapes; + return this; + } + + /** + * Sets the outputShapes option. + * + * @param outputShapes the outputShapes option + * @return this Options instance. + */ + public Options outputShapes(Shape... outputShapes) { + this.outputShapes = Arrays.asList(outputShapes); + return this; + } + + /** + * Sets the parallelIterations option. + * + * @param parallelIterations the parallelIterations option + * @return this Options instance. + */ + public Options parallelIterations(Long parallelIterations) { + this.parallelIterations = parallelIterations; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaHostCompute.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaHostCompute.java index f442219f338..ff628c183c5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaHostCompute.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaHostCompute.java @@ -95,6 +95,12 @@ public static XlaHostCompute create(Scope scope, Iterable<Operand<?>> inputs, opBuilder.setAttr("key", key); if (options != null) { for (Options opts : options) { + if (opts.sendKey != null) { + opBuilder.setAttr("send_key", opts.sendKey); + } + if (opts.recvKey != null) { + opBuilder.setAttr("recv_key", opts.recvKey); + } if (opts.costEstimateNs != null) { opBuilder.setAttr("cost_estimate_ns", opts.costEstimateNs); } @@ -106,6 +112,26 @@ public static XlaHostCompute create(Scope scope, Iterable<Operand<?>> inputs, return new XlaHostCompute(opBuilder.build()); } + /** + * Sets the sendKey option. + * + * @param sendKey the sendKey option + * @return this Options instance. + */ + public static Options sendKey(String sendKey) { + return new Options().sendKey(sendKey); + } + + /** + * Sets the recvKey option. + * + * @param recvKey the recvKey option + * @return this Options instance. + */ + public static Options recvKey(String recvKey) { + return new Options().recvKey(recvKey); + } + /** * Sets the costEstimateNs option. * @@ -145,6 +171,10 @@ public Iterator<Operand<TType>> iterator() { * Optional attributes for {@link org.tensorflow.op.xla.XlaHostCompute} */ public static class Options { + private String sendKey; + + private String recvKey; + private Long costEstimateNs; private Long tpuCore; @@ -152,6 +182,28 @@ public static class Options { private Options() { } + /** + * Sets the sendKey option. + * + * @param sendKey the sendKey option + * @return this Options instance. + */ + public Options sendKey(String sendKey) { + this.sendKey = sendKey; + return this; + } + + /** + * Sets the recvKey option. + * + * @param recvKey the recvKey option + * @return this Options instance. + */ + public Options recvKey(String recvKey) { + this.recvKey = recvKey; + return this; + } + /** * Sets the costEstimateNs option. * diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicReduce.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicReduce.java index 16c8aeda370..317df94f72e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicReduce.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicReduce.java @@ -33,7 +33,8 @@ import org.tensorflow.types.family.TType; /** - * Wraps the variadic XLA Reduce operator, documented at + * Wraps the variadic XLA Reduce operator. + * Semantics are documented at * https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce. * * @param <T> data type for {@code output} output diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index f59604d7528..f5dcd7c2ce3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -1,18 +1,18 @@ /* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionSetAttrValueProto; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java index 967fec33d88..05e89a39722 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionName; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java index b2911d2eec0..57bc0bc9ffb 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.op.core; import static org.junit.jupiter.api.Assertions.assertEquals; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/PartitionedCallTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/PartitionedCallTest.java index 865a55f4ac9..f075934d2cc 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/PartitionedCallTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/PartitionedCallTest.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.op.core; import static org.junit.jupiter.api.Assertions.assertEquals; diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java index 0d3546655a3..d5d617158c7 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.op.generator; import static org.tensorflow.op.generator.GeneratorUtils.javaizeMemberName; diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java index 6611eab61e7..5e7775194d7 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/FullOpDef.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.op.generator; import com.squareup.javapoet.TypeSpec; diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java index f19d88416f5..0aa1638a861 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.op.generator; import org.commonmark.node.Node; diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java index b72cf664279..9dc597c7f74 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.op.generator; import com.google.protobuf.InvalidProtocolBufferException; diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java index f833d8001cb..246eb3878f6 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/StatefulPair.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.op.generator; import com.squareup.javapoet.TypeSpec; From cc3be7fd28887882cb7fdb092d8af2ae8153f9a5 Mon Sep 17 00:00:00 2001 From: klessard <klessard@expedia.com> Date: Sat, 26 Jun 2021 11:15:00 -0400 Subject: [PATCH 13/14] Reclassify some ops --- .../api_def_AssertCardinalityDataset.pbtxt | 2 +- .../api_def/api_def_CompressElement.pbtxt | 2 +- .../api_def/api_def_DataServiceDataset.pbtxt | 2 +- .../api_def_DummyIterationCounter.pbtxt | 2 +- .../api_def/api_def_DummySeedGenerator.pbtxt | 2 +- .../api_def_GroupByReducerDataset.pbtxt | 3 + ...ef_LegacyParallelInterleaveDatasetV2.pbtxt | 2 +- .../api_def/api_def_MapAndBatchDataset.pbtxt | 5 +- .../api_def_MapAndBatchDatasetV2.pbtxt | 6 - .../api_def_ParallelInterleaveDatasetV4.pbtxt | 2 +- .../api_def_ParseExampleDatasetV2.pbtxt | 2 +- .../bazel/api_def/api_def_ReduceDataset.pbtxt | 3 + .../api_def_StatsAggregatorHandleV2.pbtxt | 2 +- ..._def_StatsAggregatorSetSummaryWriter.pbtxt | 2 +- .../api_def/api_def_UncompressElement.pbtxt | 2 +- .../src/bazel/api_def/api_def_XlaConv.pbtxt | 4 +- .../src/bazel/api_def/api_def_XlaConvV2.pbtxt | 3 + .../src/bazel/api_def/api_def_XlaDot.pbtxt | 4 +- .../src/bazel/api_def/api_def_XlaDotV2.pbtxt | 3 + .../api_def_XlaSetDynamicDimensionSize.pbtxt | 3 + .../api_def_XlaSpmdFullToShardShape.pbtxt | 3 + .../api_def_XlaSpmdShardToFullShape.pbtxt | 3 + .../api_def/api_def_XlaVariadicSort.pbtxt | 3 + .../org/tensorflow/op/DataOps.java | 33 +++ .../annotations/org/tensorflow/op/Ops.java | 140 +---------- .../annotations/org/tensorflow/op/XlaOps.java | 106 +++++++- .../internal/c_api/global/tensorflow.java | 2 +- .../org/tensorflow/op/core/XlaConvV2.java | 108 -------- .../java/org/tensorflow/op/core/XlaDotV2.java | 94 ------- .../AssertCardinalityDataset.java | 2 +- .../{experimental => }/CompressElement.java | 2 +- .../DataServiceDataset.java | 4 +- .../DummyIterationCounter.java | 2 +- .../{core => data}/GroupByReducerDataset.java | 2 +- .../LegacyParallelInterleaveDataset.java | 4 +- .../op/data/MapAndBatchDataset.java | 158 ++++++++++++ .../op/data/ParallelInterleaveDataset.java | 177 +++++++++++++ .../op/data/ParseExampleDataset.java | 236 ++++++++++++++++++ .../op/{core => data}/ReduceDataset.java | 4 +- .../op/data/StatsAggregatorHandle.java | 6 +- .../StatsAggregatorSetSummaryWriter.java | 2 +- .../{experimental => }/UncompressElement.java | 2 +- .../ParallelInterleaveDataset.java | 99 ++------ .../experimental/ParseExampleDataset.java | 95 ++----- .../experimental/StatsAggregatorHandle.java | 6 +- .../DummySeedGenerator.java | 2 +- .../gen/java/org/tensorflow/op/xla/Conv.java | 30 ++- .../gen/java/org/tensorflow/op/xla/Dot.java | 24 +- .../SetDynamicDimensionSize.java} | 20 +- .../SpmdFullToShardShape.java} | 18 +- .../SpmdShardToFullShape.java} | 18 +- .../op/{core => xla}/XlaVariadicSort.java | 6 +- .../src/gen/resources/ops.pb | Bin 1480980 -> 1480855 bytes 53 files changed, 865 insertions(+), 602 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapAndBatchDatasetV2.pbtxt delete mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaConvV2.java delete mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaDotV2.java rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/{experimental => }/AssertCardinalityDataset.java (98%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/{experimental => }/CompressElement.java (98%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/{experimental => }/DataServiceDataset.java (97%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/{experimental => }/DummyIterationCounter.java (98%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/{core => data}/GroupByReducerDataset.java (99%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/{experimental => }/LegacyParallelInterleaveDataset.java (97%) create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MapAndBatchDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelInterleaveDataset.java create mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParseExampleDataset.java rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/{core => data}/ReduceDataset.java (98%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/{experimental => }/StatsAggregatorSetSummaryWriter.java (97%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/{experimental => }/UncompressElement.java (98%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/{experimental => }/DummySeedGenerator.java (97%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/{core/XlaSetDynamicDimensionSize.java => xla/SetDynamicDimensionSize.java} (81%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/{core/XlaSpmdFullToShardShape.java => xla/SpmdFullToShardShape.java} (84%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/{core/XlaSpmdShardToFullShape.java => xla/SpmdShardToFullShape.java} (84%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/{core => xla}/XlaVariadicSort.java (98%) diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AssertCardinalityDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AssertCardinalityDataset.pbtxt index 7f8ad1c4491..1701d76bf52 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AssertCardinalityDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AssertCardinalityDataset.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "AssertCardinalityDataset" endpoint { - name: "data.experimental.AssertCardinalityDataset" + name: "data.AssertCardinalityDataset" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CompressElement.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CompressElement.pbtxt index f84051aa746..08158de13a5 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CompressElement.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CompressElement.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "CompressElement" endpoint { - name: "data.experimental.CompressElement" + name: "data.CompressElement" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DataServiceDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DataServiceDataset.pbtxt index de2818df181..1dd6077f8db 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DataServiceDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DataServiceDataset.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "DataServiceDataset" endpoint { - name: "data.experimental.DataServiceDataset" + name: "data.DataServiceDataset" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DummyIterationCounter.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DummyIterationCounter.pbtxt index 4f770897717..12b6ffbf5fa 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DummyIterationCounter.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DummyIterationCounter.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "DummyIterationCounter" endpoint { - name: "data.experimental.DummyIterationCounter" + name: "data.DummyIterationCounter" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DummySeedGenerator.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DummySeedGenerator.pbtxt index 24e723ea75c..4d550787d03 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DummySeedGenerator.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DummySeedGenerator.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "DummySeedGenerator" endpoint { - name: "random.experimental.DummySeedGenerator" + name: "random.DummySeedGenerator" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GroupByReducerDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GroupByReducerDataset.pbtxt index 1bd2c8f531b..8f8ae87314e 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GroupByReducerDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GroupByReducerDataset.pbtxt @@ -1,3 +1,6 @@ op { graph_op_name: "GroupByReducerDataset" + endpoint { + name: "data.GroupByReducerDataset" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LegacyParallelInterleaveDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LegacyParallelInterleaveDatasetV2.pbtxt index a07ccaeaa1f..71012d0b557 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LegacyParallelInterleaveDatasetV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LegacyParallelInterleaveDatasetV2.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "LegacyParallelInterleaveDatasetV2" endpoint { - name: "data.experimental.LegacyParallelInterleaveDataset" + name: "data.LegacyParallelInterleaveDataset" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapAndBatchDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapAndBatchDataset.pbtxt index cb96bf63d8f..bbf61d0f31e 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapAndBatchDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapAndBatchDataset.pbtxt @@ -1,4 +1,7 @@ op { graph_op_name: "MapAndBatchDataset" - visibility: SKIP + visibility: VISIBLE + endpoint { + name: "data.MapAndBatchDataset" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapAndBatchDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapAndBatchDatasetV2.pbtxt deleted file mode 100644 index b29c21888fa..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapAndBatchDatasetV2.pbtxt +++ /dev/null @@ -1,6 +0,0 @@ -op { - graph_op_name: "MapAndBatchDatasetV2" - endpoint { - name: "data.MapAndBatchDataset" - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelInterleaveDatasetV4.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelInterleaveDatasetV4.pbtxt index 56cfd4f9429..5ed14b3695c 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelInterleaveDatasetV4.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelInterleaveDatasetV4.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "ParallelInterleaveDatasetV4" endpoint { - name: "data.experimental.ParallelInterleaveDataset" + name: "data.ParallelInterleaveDataset" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParseExampleDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParseExampleDatasetV2.pbtxt index e0d99d55539..3c88a87c8ab 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParseExampleDatasetV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParseExampleDatasetV2.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "ParseExampleDatasetV2" endpoint { - name: "data.experimental.ParseExampleDataset" + name: "data.ParseExampleDataset" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ReduceDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ReduceDataset.pbtxt index b16c5dbb96c..4417f01ef90 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ReduceDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ReduceDataset.pbtxt @@ -1,3 +1,6 @@ op { graph_op_name: "ReduceDataset" + endpoint { + name: "data.ReduceDataset" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_StatsAggregatorHandleV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_StatsAggregatorHandleV2.pbtxt index 3985b578962..358f535edb9 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_StatsAggregatorHandleV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_StatsAggregatorHandleV2.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "StatsAggregatorHandleV2" endpoint { - name: "data.experimental.StatsAggregatorHandle" + name: "data.StatsAggregatorHandle" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_StatsAggregatorSetSummaryWriter.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_StatsAggregatorSetSummaryWriter.pbtxt index 71f71d8d877..a0c6b1f833d 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_StatsAggregatorSetSummaryWriter.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_StatsAggregatorSetSummaryWriter.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "StatsAggregatorSetSummaryWriter" endpoint { - name: "data.experimental.StatsAggregatorSetSummaryWriter" + name: "data.StatsAggregatorSetSummaryWriter" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UncompressElement.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UncompressElement.pbtxt index 137fb6d75be..46225e53cdd 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UncompressElement.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UncompressElement.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "UncompressElement" endpoint { - name: "data.experimental.UncompressElement" + name: "data.UncompressElement" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaConv.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaConv.pbtxt index 3f36b3e7442..0f5f916d30f 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaConv.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaConv.pbtxt @@ -1,6 +1,4 @@ op { graph_op_name: "XlaConv" - endpoint { - name: "xla.Conv" - } + visibility: SKIP } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaConvV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaConvV2.pbtxt index d2c9637c0ba..bd999901f47 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaConvV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaConvV2.pbtxt @@ -1,3 +1,6 @@ op { graph_op_name: "XlaConvV2" + endpoint { + name: "xla.Conv" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaDot.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaDot.pbtxt index 9489678d9f6..9af5cd4caaa 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaDot.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaDot.pbtxt @@ -1,6 +1,4 @@ op { graph_op_name: "XlaDot" - endpoint { - name: "xla.Dot" - } + visibility: SKIP } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaDotV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaDotV2.pbtxt index 357866b27ac..77e09ad7491 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaDotV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaDotV2.pbtxt @@ -1,3 +1,6 @@ op { graph_op_name: "XlaDotV2" + endpoint { + name: "xla.Dot" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaSetDynamicDimensionSize.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaSetDynamicDimensionSize.pbtxt index aeaeb87d701..724d91eab72 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaSetDynamicDimensionSize.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaSetDynamicDimensionSize.pbtxt @@ -1,3 +1,6 @@ op { graph_op_name: "XlaSetDynamicDimensionSize" + endpoint { + name: "xla.SetDynamicDimensionSize" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaSpmdFullToShardShape.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaSpmdFullToShardShape.pbtxt index ea7bf6f9458..560b18cebcb 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaSpmdFullToShardShape.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaSpmdFullToShardShape.pbtxt @@ -1,3 +1,6 @@ op { graph_op_name: "XlaSpmdFullToShardShape" + endpoint { + name: "xla.SpmdFullToShardShape" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaSpmdShardToFullShape.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaSpmdShardToFullShape.pbtxt index 0e5842e82d0..e960c39d326 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaSpmdShardToFullShape.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaSpmdShardToFullShape.pbtxt @@ -1,3 +1,6 @@ op { graph_op_name: "XlaSpmdShardToFullShape" + endpoint { + name: "xla.SpmdShardToFullShape" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaVariadicSort.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaVariadicSort.pbtxt index 5ae24c7686a..3fa661cd3c6 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaVariadicSort.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_XlaVariadicSort.pbtxt @@ -1,3 +1,6 @@ op { graph_op_name: "XlaVariadicSort" + endpoint { + name: "xla.XlaVariadicSort" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java index edf3d88f8ed..4f933b2807f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java @@ -35,6 +35,7 @@ import org.tensorflow.op.data.IteratorGetNextSync; import org.tensorflow.op.data.IteratorToStringHandle; import org.tensorflow.op.data.MakeIterator; +import org.tensorflow.op.data.MapAndBatchDataset; import org.tensorflow.op.data.MapDataset; import org.tensorflow.op.data.OneShotIterator; import org.tensorflow.op.data.OptionalFromValue; @@ -290,6 +291,38 @@ public MakeIterator makeIterator(Operand<? extends TType> dataset, return MakeIterator.create(scope, dataset, iterator); } + /** + * Creates a dataset that fuses mapping with batching. + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset} and then + * batches {@code batch_size} of them. + * <p>Unlike a "MapDataset", which applies {@code f} sequentially, this dataset invokes up + * to {@code batch_size * num_parallel_batches} copies of {@code f} in parallel. + * + * @param inputDataset A variant tensor representing the input dataset. + * @param otherArguments A list of tensors, typically values that were captured when building a closure + * for {@code f}. + * @param batchSize A scalar representing the number of elements to accumulate in a + * batch. It determines the number of concurrent invocations of {@code f} that process + * elements from {@code input_dataset} in parallel. + * @param numParallelCalls A scalar representing the maximum number of parallel invocations of the {@code map_fn} + * function. Applying the {@code map_fn} on consecutive input elements in parallel has + * the potential to improve input pipeline throughput. + * @param dropRemainder A scalar representing whether the last batch should be dropped in case its size + * is smaller than desired. + * @param f A function to apply to the outputs of {@code input_dataset}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of MapAndBatchDataset + */ + public MapAndBatchDataset mapAndBatchDataset(Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, Operand<TInt64> batchSize, + Operand<TInt64> numParallelCalls, Operand<TBool> dropRemainder, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + MapAndBatchDataset.Options... options) { + return MapAndBatchDataset.create(scope, inputDataset, otherArguments, batchSize, numParallelCalls, dropRemainder, f, outputTypes, outputShapes, options); + } + /** * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. * diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index dcf86da09a8..7abb451be2d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -295,12 +295,6 @@ import org.tensorflow.op.core.VariableShape; import org.tensorflow.op.core.Where; import org.tensorflow.op.core.While; -import org.tensorflow.op.core.XlaConvV2; -import org.tensorflow.op.core.XlaDotV2; -import org.tensorflow.op.core.XlaSetDynamicDimensionSize; -import org.tensorflow.op.core.XlaSpmdFullToShardShape; -import org.tensorflow.op.core.XlaSpmdShardToFullShape; -import org.tensorflow.op.core.XlaVariadicSort; import org.tensorflow.op.core.Zeros; import org.tensorflow.op.core.ZerosLike; import org.tensorflow.types.TBool; @@ -372,20 +366,20 @@ public final class Ops { public final SparseOps sparse; - public final BitwiseOps bitwise; - public final TpuOps tpu; + public final BitwiseOps bitwise; + public final MathOps math; public final AudioOps audio; public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -403,13 +397,13 @@ private Ops(Scope scope) { random = new RandomOps(this); strings = new StringsOps(this); sparse = new SparseOps(this); - bitwise = new BitwiseOps(this); tpu = new TpuOps(this); + bitwise = new BitwiseOps(this); math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** @@ -8116,128 +8110,6 @@ public While whileOp(Iterable<Operand<?>> input, ConcreteFunction cond, Concrete return While.create(scope, input, cond, body, options); } - /** - * Wraps the XLA ConvGeneralDilated operator, documented at - * https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution - * . - * - * @param <W> data type for {@code output} output - * @param lhs the input tensor - * @param rhs the kernel tensor - * @param windowStrides the inter-window strides - * @param padding the padding to apply at the start and end of each input dimensions - * @param lhsDilation dilation to apply between input elements - * @param rhsDilation dilation to apply between kernel elements - * @param featureGroupCount number of feature groups for grouped convolution. - * @param dimensionNumbers a serialized xla::ConvolutionDimensionNumbers proto. - * @param precisionConfig a serialized xla::PrecisionConfig proto. - * @param preferredElementType The type of the tensor. - * @param <W> data type for {@code XlaConvV2} output and operands - * @param <V> data type for {@code XlaConvV2} output and operands - * @return a new instance of XlaConvV2 - */ - public <W extends TType, V extends TNumber> XlaConvV2<W> xlaConvV2(Operand<? extends TType> lhs, - Operand<? extends TType> rhs, Operand<V> windowStrides, Operand<V> padding, - Operand<V> lhsDilation, Operand<V> rhsDilation, Operand<V> featureGroupCount, - String dimensionNumbers, String precisionConfig, Class<W> preferredElementType) { - return XlaConvV2.create(scope, lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, featureGroupCount, dimensionNumbers, precisionConfig, preferredElementType); - } - - /** - * Wraps the XLA DotGeneral operator, documented at - * https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral - * . - * - * @param <V> data type for {@code output} output - * @param lhs the LHS tensor - * @param rhs the RHS tensor - * @param dimensionNumbers a serialized xla::DotDimensionNumbers proto. - * @param precisionConfig a serialized xla::PrecisionConfig proto. - * @param preferredElementType The type of the tensor. - * @param <V> data type for {@code XlaDotV2} output and operands - * @return a new instance of XlaDotV2 - */ - public <V extends TType> XlaDotV2<V> xlaDotV2(Operand<? extends TType> lhs, - Operand<? extends TType> rhs, String dimensionNumbers, String precisionConfig, - Class<V> preferredElementType) { - return XlaDotV2.create(scope, lhs, rhs, dimensionNumbers, precisionConfig, preferredElementType); - } - - /** - * Make a static dimension into a xla bounded dynamic dimension. - * <pre> - * The current static dimension size will become the bound and the second - * operand becomes the dynamic size of the dimension. - * </pre> - * - * @param <T> data type for {@code output} output - * @param input the input value - * @param dimIndex the dimIndex value - * @param sizeOutput the sizeOutput value - * @param <T> data type for {@code XlaSetDynamicDimensionSize} output and operands - * @return a new instance of XlaSetDynamicDimensionSize - */ - public <T extends TType> XlaSetDynamicDimensionSize<T> xlaSetDynamicDimensionSize( - Operand<T> input, Operand<TInt32> dimIndex, Operand<TInt32> sizeOutput) { - return XlaSetDynamicDimensionSize.create(scope, input, dimIndex, sizeOutput); - } - - /** - * An op used by XLA SPMD partitioner to switch from automatic partitioning to - * manual partitioning. It annotates the input (full-shape, to be automatically - * partitioned) with the same sharding used by manual partitioning, and outputs a - * shard-shaped tensor to be consumed by later manually-partitioned ops. If the - * shape is not evenly partitionable, the padding region will be masked with 0s. - * - * @param <T> data type for {@code output} output - * @param input the input value - * @param manualSharding the value of the manualSharding property - * @param <T> data type for {@code XlaSpmdFullToShardShape} output and operands - * @return a new instance of XlaSpmdFullToShardShape - */ - public <T extends TType> XlaSpmdFullToShardShape<T> xlaSpmdFullToShardShape(Operand<T> input, - String manualSharding) { - return XlaSpmdFullToShardShape.create(scope, input, manualSharding); - } - - /** - * An op used by XLA SPMD partitioner to switch from manual partitioning to - * automatic partitioning. It converts the shard-shaped, manually partitioned input - * into full-shaped tensor to be partitioned automatically with the same sharding - * used by manual partitioning. - * - * @param <T> data type for {@code output} output - * @param input the input value - * @param manualSharding the value of the manualSharding property - * @param fullShape the value of the fullShape property - * @param <T> data type for {@code XlaSpmdShardToFullShape} output and operands - * @return a new instance of XlaSpmdShardToFullShape - */ - public <T extends TType> XlaSpmdShardToFullShape<T> xlaSpmdShardToFullShape(Operand<T> input, - String manualSharding, Shape fullShape) { - return XlaSpmdShardToFullShape.create(scope, input, manualSharding, fullShape); - } - - /** - * Wraps the XLA Sort operator, documented at - * https://www.tensorflow.org/performance/xla/operation_semantics#sort - * . - * <p>Sorts one or more tensors, with support for custom comparator, dimension, and - * is_stable attributes. - * - * @param inputs A list of {@code Tensor} of identical shape but possibly different types. - * @param dimension The dimension along which to sort. Must be a compile-time constant. - * @param comparator A comparator function to apply to 2*N scalars and returning a - * boolean. N is the number of sort inputs. If you want to sort in ascending - * order then the comparator should perform a less-than comparison. - * @param isStable Whether to use stable sort. - * @return a new instance of XlaVariadicSort - */ - public XlaVariadicSort xlaVariadicSort(Iterable<Operand<?>> inputs, Operand<TInt32> dimension, - ConcreteFunction comparator, Boolean isStable) { - return XlaVariadicSort.create(scope, inputs, dimension, comparator, isStable); - } - /** * Creates a zeroed tensor given its type and shape. * diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java index 16298aebaca..07dd99967d1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java @@ -41,8 +41,11 @@ import org.tensorflow.op.xla.SelectAndScatter; import org.tensorflow.op.xla.SelfAdjointEig; import org.tensorflow.op.xla.Send; +import org.tensorflow.op.xla.SetDynamicDimensionSize; import org.tensorflow.op.xla.Sharding; import org.tensorflow.op.xla.Sort; +import org.tensorflow.op.xla.SpmdFullToShardShape; +import org.tensorflow.op.xla.SpmdShardToFullShape; import org.tensorflow.op.xla.Svd; import org.tensorflow.op.xla.While; import org.tensorflow.op.xla.XlaHostCompute; @@ -51,6 +54,7 @@ import org.tensorflow.op.xla.XlaSendToHost; import org.tensorflow.op.xla.XlaSetBound; import org.tensorflow.op.xla.XlaVariadicReduce; +import org.tensorflow.op.xla.XlaVariadicSort; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; @@ -105,7 +109,7 @@ public <T extends TType> ClusterOutput<T> clusterOutput(Operand<T> input) { * https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution * . * - * @param <T> data type for {@code output} output + * @param <W> data type for {@code output} output * @param lhs the input tensor * @param rhs the kernel tensor * @param windowStrides the inter-window strides @@ -115,14 +119,16 @@ public <T extends TType> ClusterOutput<T> clusterOutput(Operand<T> input) { * @param featureGroupCount number of feature groups for grouped convolution. * @param dimensionNumbers a serialized xla::ConvolutionDimensionNumbers proto. * @param precisionConfig a serialized xla::PrecisionConfig proto. - * @param <T> data type for {@code XlaConv} output and operands - * @param <U> data type for {@code XlaConv} output and operands + * @param preferredElementType The type of the tensor. + * @param <W> data type for {@code XlaConvV2} output and operands + * @param <V> data type for {@code XlaConvV2} output and operands * @return a new instance of Conv */ - public <T extends TType, U extends TNumber> Conv<T> conv(Operand<T> lhs, Operand<T> rhs, - Operand<U> windowStrides, Operand<U> padding, Operand<U> lhsDilation, Operand<U> rhsDilation, - Operand<U> featureGroupCount, String dimensionNumbers, String precisionConfig) { - return Conv.create(scope, lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, featureGroupCount, dimensionNumbers, precisionConfig); + public <W extends TType, V extends TNumber> Conv<W> conv(Operand<? extends TType> lhs, + Operand<? extends TType> rhs, Operand<V> windowStrides, Operand<V> padding, + Operand<V> lhsDilation, Operand<V> rhsDilation, Operand<V> featureGroupCount, + String dimensionNumbers, String precisionConfig, Class<W> preferredElementType) { + return Conv.create(scope, lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, featureGroupCount, dimensionNumbers, precisionConfig, preferredElementType); } /** @@ -147,17 +153,18 @@ public Dequantize dequantize(Operand<? extends TType> input, Float minRange, Flo * https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral * . * - * @param <T> data type for {@code output} output + * @param <V> data type for {@code output} output * @param lhs the LHS tensor * @param rhs the RHS tensor * @param dimensionNumbers a serialized xla::DotDimensionNumbers proto. * @param precisionConfig a serialized xla::PrecisionConfig proto. - * @param <T> data type for {@code XlaDot} output and operands + * @param preferredElementType The type of the tensor. + * @param <V> data type for {@code XlaDotV2} output and operands * @return a new instance of Dot */ - public <T extends TType> Dot<T> dot(Operand<T> lhs, Operand<T> rhs, String dimensionNumbers, - String precisionConfig) { - return Dot.create(scope, lhs, rhs, dimensionNumbers, precisionConfig); + public <V extends TType> Dot<V> dot(Operand<? extends TType> lhs, Operand<? extends TType> rhs, + String dimensionNumbers, String precisionConfig, Class<V> preferredElementType) { + return Dot.create(scope, lhs, rhs, dimensionNumbers, precisionConfig, preferredElementType); } /** @@ -454,6 +461,25 @@ public Send send(Operand<? extends TType> tensor, String tensorName) { return Send.create(scope, tensor, tensorName); } + /** + * Make a static dimension into a xla bounded dynamic dimension. + * <pre> + * The current static dimension size will become the bound and the second + * operand becomes the dynamic size of the dimension. + * </pre> + * + * @param <T> data type for {@code output} output + * @param input the input value + * @param dimIndex the dimIndex value + * @param sizeOutput the sizeOutput value + * @param <T> data type for {@code XlaSetDynamicDimensionSize} output and operands + * @return a new instance of SetDynamicDimensionSize + */ + public <T extends TType> SetDynamicDimensionSize<T> setDynamicDimensionSize(Operand<T> input, + Operand<TInt32> dimIndex, Operand<TInt32> sizeOutput) { + return SetDynamicDimensionSize.create(scope, input, dimIndex, sizeOutput); + } + /** * An op which shards the input based on the given sharding attribute. * @@ -482,6 +508,42 @@ public <T extends TType> Sort<T> sort(Operand<T> input) { return Sort.create(scope, input); } + /** + * An op used by XLA SPMD partitioner to switch from automatic partitioning to + * manual partitioning. It annotates the input (full-shape, to be automatically + * partitioned) with the same sharding used by manual partitioning, and outputs a + * shard-shaped tensor to be consumed by later manually-partitioned ops. If the + * shape is not evenly partitionable, the padding region will be masked with 0s. + * + * @param <T> data type for {@code output} output + * @param input the input value + * @param manualSharding the value of the manualSharding property + * @param <T> data type for {@code XlaSpmdFullToShardShape} output and operands + * @return a new instance of SpmdFullToShardShape + */ + public <T extends TType> SpmdFullToShardShape<T> spmdFullToShardShape(Operand<T> input, + String manualSharding) { + return SpmdFullToShardShape.create(scope, input, manualSharding); + } + + /** + * An op used by XLA SPMD partitioner to switch from manual partitioning to + * automatic partitioning. It converts the shard-shaped, manually partitioned input + * into full-shaped tensor to be partitioned automatically with the same sharding + * used by manual partitioning. + * + * @param <T> data type for {@code output} output + * @param input the input value + * @param manualSharding the value of the manualSharding property + * @param fullShape the value of the fullShape property + * @param <T> data type for {@code XlaSpmdShardToFullShape} output and operands + * @return a new instance of SpmdShardToFullShape + */ + public <T extends TType> SpmdShardToFullShape<T> spmdShardToFullShape(Operand<T> input, + String manualSharding, Shape fullShape) { + return SpmdShardToFullShape.create(scope, input, manualSharding, fullShape); + } + /** * Computes the eigen decomposition of a batch of self-adjoint matrices * (Note: Only real inputs are supported). @@ -624,6 +686,26 @@ public <T extends TType> XlaVariadicReduce<T> xlaVariadicReduce(Iterable<Operand return XlaVariadicReduce.create(scope, input, initValue, dimensionsToReduce, reducer); } + /** + * Wraps the XLA Sort operator, documented at + * https://www.tensorflow.org/performance/xla/operation_semantics#sort + * . + * <p>Sorts one or more tensors, with support for custom comparator, dimension, and + * is_stable attributes. + * + * @param inputs A list of {@code Tensor} of identical shape but possibly different types. + * @param dimension The dimension along which to sort. Must be a compile-time constant. + * @param comparator A comparator function to apply to 2*N scalars and returning a + * boolean. N is the number of sort inputs. If you want to sort in ascending + * order then the comparator should perform a less-than comparison. + * @param isStable Whether to use stable sort. + * @return a new instance of XlaVariadicSort + */ + public XlaVariadicSort xlaVariadicSort(Iterable<Operand<?>> inputs, Operand<TInt32> dimension, + ConcreteFunction comparator, Boolean isStable) { + return XlaVariadicSort.create(scope, inputs, dimension, comparator, isStable); + } + /** * Get the parent {@link Ops} object. */ diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java index 2441bc1af65..b345ab4dad2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java @@ -3150,7 +3150,7 @@ public static native void TF_RegisterFilesystemPlugin( // TF_InitKernel to do op/kernel registration. // Plugin should implement TF_InitKernel to register kernels. This function // should register all kernels in a plugin. -public static native void TF_InitKernel(); + // Targeting ../Create_func_TF_OpKernelConstruction.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaConvV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaConvV2.java deleted file mode 100644 index fe95ace9832..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaConvV2.java +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.core; - -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.op.Operands; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; - -/** - * Wraps the XLA ConvGeneralDilated operator, documented at - * https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution - * . - * - * @param <W> data type for {@code output} output - */ -@Operator -public final class XlaConvV2<W extends TType> extends RawOp implements Operand<W> { - /** - * The name of this op, as known by TensorFlow core engine - */ - public static final String OP_NAME = "XlaConvV2"; - - private Output<W> output; - - private XlaConvV2(Operation operation) { - super(operation); - int outputIdx = 0; - output = operation.output(outputIdx++); - } - - /** - * Factory method to create a class wrapping a new XlaConvV2 operation. - * - * @param scope current scope - * @param lhs the input tensor - * @param rhs the kernel tensor - * @param windowStrides the inter-window strides - * @param padding the padding to apply at the start and end of each input dimensions - * @param lhsDilation dilation to apply between input elements - * @param rhsDilation dilation to apply between kernel elements - * @param featureGroupCount number of feature groups for grouped convolution. - * @param dimensionNumbers a serialized xla::ConvolutionDimensionNumbers proto. - * @param precisionConfig a serialized xla::PrecisionConfig proto. - * @param preferredElementType The type of the tensor. - * @param <W> data type for {@code XlaConvV2} output and operands - * @param <V> data type for {@code XlaConvV2} output and operands - * @return a new instance of XlaConvV2 - */ - @Endpoint( - describeByClass = true - ) - public static <W extends TType, V extends TNumber> XlaConvV2<W> create(Scope scope, - Operand<? extends TType> lhs, Operand<? extends TType> rhs, Operand<V> windowStrides, - Operand<V> padding, Operand<V> lhsDilation, Operand<V> rhsDilation, - Operand<V> featureGroupCount, String dimensionNumbers, String precisionConfig, - Class<W> preferredElementType) { - OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("XlaConvV2")); - opBuilder.addInput(lhs.asOutput()); - opBuilder.addInput(rhs.asOutput()); - opBuilder.addInput(windowStrides.asOutput()); - opBuilder.addInput(padding.asOutput()); - opBuilder.addInput(lhsDilation.asOutput()); - opBuilder.addInput(rhsDilation.asOutput()); - opBuilder.addInput(featureGroupCount.asOutput()); - opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dimension_numbers", dimensionNumbers); - opBuilder.setAttr("precision_config", precisionConfig); - opBuilder.setAttr("preferred_element_type", Operands.toDataType(preferredElementType)); - return new XlaConvV2<>(opBuilder.build()); - } - - /** - * Gets output. - * - * @return output. - */ - public Output<W> output() { - return output; - } - - @Override - public Output<W> asOutput() { - return output; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaDotV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaDotV2.java deleted file mode 100644 index 6d94739df68..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaDotV2.java +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.core; - -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.op.Operands; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.types.family.TType; - -/** - * Wraps the XLA DotGeneral operator, documented at - * https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral - * . - * - * @param <V> data type for {@code output} output - */ -@Operator -public final class XlaDotV2<V extends TType> extends RawOp implements Operand<V> { - /** - * The name of this op, as known by TensorFlow core engine - */ - public static final String OP_NAME = "XlaDotV2"; - - private Output<V> output; - - private XlaDotV2(Operation operation) { - super(operation); - int outputIdx = 0; - output = operation.output(outputIdx++); - } - - /** - * Factory method to create a class wrapping a new XlaDotV2 operation. - * - * @param scope current scope - * @param lhs the LHS tensor - * @param rhs the RHS tensor - * @param dimensionNumbers a serialized xla::DotDimensionNumbers proto. - * @param precisionConfig a serialized xla::PrecisionConfig proto. - * @param preferredElementType The type of the tensor. - * @param <V> data type for {@code XlaDotV2} output and operands - * @return a new instance of XlaDotV2 - */ - @Endpoint( - describeByClass = true - ) - public static <V extends TType> XlaDotV2<V> create(Scope scope, Operand<? extends TType> lhs, - Operand<? extends TType> rhs, String dimensionNumbers, String precisionConfig, - Class<V> preferredElementType) { - OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("XlaDotV2")); - opBuilder.addInput(lhs.asOutput()); - opBuilder.addInput(rhs.asOutput()); - opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dimension_numbers", dimensionNumbers); - opBuilder.setAttr("precision_config", precisionConfig); - opBuilder.setAttr("preferred_element_type", Operands.toDataType(preferredElementType)); - return new XlaDotV2<>(opBuilder.build()); - } - - /** - * Gets output. - * - * @return output. - */ - public Output<V> output() { - return output; - } - - @Override - public Output<V> asOutput() { - return output; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertCardinalityDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertCardinalityDataset.java similarity index 98% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertCardinalityDataset.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertCardinalityDataset.java index 90ff976d9e5..e79b4f5ece5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertCardinalityDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertCardinalityDataset.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.data.experimental; +package org.tensorflow.op.data; import java.util.List; import org.tensorflow.Operand; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/CompressElement.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CompressElement.java similarity index 98% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/CompressElement.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CompressElement.java index 30caf1434f3..10be8060ede 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/CompressElement.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CompressElement.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.data.experimental; +package org.tensorflow.op.data; import org.tensorflow.Operand; import org.tensorflow.Operation; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DataServiceDataset.java similarity index 97% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DataServiceDataset.java index dd4a3e3f6d9..cb7ed81bcd1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DataServiceDataset.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.data.experimental; +package org.tensorflow.op.data; import java.util.List; import org.tensorflow.Operand; @@ -137,7 +137,7 @@ public Output<TType> asOutput() { } /** - * Optional attributes for {@link org.tensorflow.op.data.experimental.DataServiceDataset} + * Optional attributes for {@link org.tensorflow.op.data.DataServiceDataset} */ public static class Options { private Long taskRefreshIntervalHintMs; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DummyIterationCounter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DummyIterationCounter.java similarity index 98% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DummyIterationCounter.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DummyIterationCounter.java index d7e6e4ce05e..3b6fcd763b7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DummyIterationCounter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DummyIterationCounter.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.data.experimental; +package org.tensorflow.op.data; import org.tensorflow.Operand; import org.tensorflow.Operation; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GroupByReducerDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByReducerDataset.java similarity index 99% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GroupByReducerDataset.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByReducerDataset.java index b5c042cd733..cff25af74b0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GroupByReducerDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByReducerDataset.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.core; +package org.tensorflow.op.data; import java.util.List; import org.tensorflow.ConcreteFunction; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LegacyParallelInterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LegacyParallelInterleaveDataset.java similarity index 97% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LegacyParallelInterleaveDataset.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LegacyParallelInterleaveDataset.java index 36dd881dc20..adeb0e4f634 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LegacyParallelInterleaveDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LegacyParallelInterleaveDataset.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.data.experimental; +package org.tensorflow.op.data; import java.util.List; import org.tensorflow.ConcreteFunction; @@ -133,7 +133,7 @@ public Output<TType> asOutput() { } /** - * Optional attributes for {@link org.tensorflow.op.data.experimental.LegacyParallelInterleaveDataset} + * Optional attributes for {@link org.tensorflow.op.data.LegacyParallelInterleaveDataset} */ public static class Options { private String deterministic; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MapAndBatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MapAndBatchDataset.java new file mode 100644 index 00000000000..3526f689b30 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MapAndBatchDataset.java @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that fuses mapping with batching. + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset} and then + * batches {@code batch_size} of them. + * <p>Unlike a "MapDataset", which applies {@code f} sequentially, this dataset invokes up + * to {@code batch_size * num_parallel_batches} copies of {@code f} in parallel. + */ +@Operator( + group = "data" +) +public final class MapAndBatchDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "MapAndBatchDataset"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private MapAndBatchDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new MapAndBatchDataset operation. + * + * @param scope current scope + * @param inputDataset A variant tensor representing the input dataset. + * @param otherArguments A list of tensors, typically values that were captured when building a closure + * for {@code f}. + * @param batchSize A scalar representing the number of elements to accumulate in a + * batch. It determines the number of concurrent invocations of {@code f} that process + * elements from {@code input_dataset} in parallel. + * @param numParallelCalls A scalar representing the maximum number of parallel invocations of the {@code map_fn} + * function. Applying the {@code map_fn} on consecutive input elements in parallel has + * the potential to improve input pipeline throughput. + * @param dropRemainder A scalar representing whether the last batch should be dropped in case its size + * is smaller than desired. + * @param f A function to apply to the outputs of {@code input_dataset}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of MapAndBatchDataset + */ + @Endpoint( + describeByClass = true + ) + public static MapAndBatchDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, Operand<TInt64> batchSize, + Operand<TInt64> numParallelCalls, Operand<TBool> dropRemainder, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("MapAndBatchDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder.addInput(batchSize.asOutput()); + opBuilder.addInput(numParallelCalls.asOutput()); + opBuilder.addInput(dropRemainder.asOutput()); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("f", f); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + if (options != null) { + for (Options opts : options) { + if (opts.preserveCardinality != null) { + opBuilder.setAttr("preserve_cardinality", opts.preserveCardinality); + } + } + } + return new MapAndBatchDataset(opBuilder.build()); + } + + /** + * Sets the preserveCardinality option. + * + * @param preserveCardinality the preserveCardinality option + * @return this Options instance. + */ + public static Options preserveCardinality(Boolean preserveCardinality) { + return new Options().preserveCardinality(preserveCardinality); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.MapAndBatchDataset} + */ + public static class Options { + private Boolean preserveCardinality; + + private Options() { + } + + /** + * Sets the preserveCardinality option. + * + * @param preserveCardinality the preserveCardinality option + * @return this Options instance. + */ + public Options preserveCardinality(Boolean preserveCardinality) { + this.preserveCardinality = preserveCardinality; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelInterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelInterleaveDataset.java new file mode 100644 index 00000000000..59571258049 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelInterleaveDataset.java @@ -0,0 +1,177 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.List; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +/** + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. + * The resulting dataset is similar to the {@code InterleaveDataset}, except that the + * dataset will fetch records from the interleaved datasets in parallel. + * <p>The {@code tf.data} Python API creates instances of this op from + * {@code Dataset.interleave()} when the {@code num_parallel_calls} parameter of that method + * is set to any value other than {@code None}. + * <p>By default, the output of this dataset will be deterministic, which may result + * in the dataset blocking if the next data item to be returned isn't available. + * In order to avoid head-of-line blocking, one can either set the {@code deterministic} + * attribute to "false", or leave it as "default" and set the + * {@code experimental_deterministic} parameter of {@code tf.data.Options} to {@code False}. + * This can improve performance at the expense of non-determinism. + */ +public final class ParallelInterleaveDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ParallelInterleaveDatasetV4"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private ParallelInterleaveDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new ParallelInterleaveDatasetV4 operation. + * + * @param scope current scope + * @param inputDataset Dataset that produces a stream of arguments for the function {@code f}. + * @param otherArguments Additional arguments to pass to {@code f} beyond those produced by {@code input_dataset}. + * Evaluated once when the dataset is instantiated. + * @param cycleLength Number of datasets (each created by applying {@code f} to the elements of + * {@code input_dataset}) among which the {@code ParallelInterleaveDatasetV2} will cycle in a + * round-robin fashion. + * @param blockLength Number of elements at a time to produce from each interleaved invocation of a + * dataset returned by {@code f}. + * @param bufferOutputElements The number of elements each iterator being interleaved should buffer (similar + * to the {@code .prefetch()} transformation for each interleaved iterator). + * @param prefetchInputElements Determines the number of iterators to prefetch, allowing buffers to warm up and + * data to be pre-fetched without blocking the main thread. + * @param numParallelCalls Determines the number of threads that should be used for fetching data from + * input datasets in parallel. The Python API {@code tf.data.experimental.AUTOTUNE} + * constant can be used to indicate that the level of parallelism should be autotuned. + * @param f A function mapping elements of {@code input_dataset}, concatenated with + * {@code other_arguments}, to a Dataset variant that contains elements matching + * {@code output_types} and {@code output_shapes}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ParallelInterleaveDataset + */ + @Endpoint( + describeByClass = true + ) + public static ParallelInterleaveDataset create(Scope scope, Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, Operand<TInt64> cycleLength, Operand<TInt64> blockLength, + Operand<TInt64> bufferOutputElements, Operand<TInt64> prefetchInputElements, + Operand<TInt64> numParallelCalls, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("ParallelInterleaveDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInputList(Operands.asOutputs(otherArguments)); + opBuilder.addInput(cycleLength.asOutput()); + opBuilder.addInput(blockLength.asOutput()); + opBuilder.addInput(bufferOutputElements.asOutput()); + opBuilder.addInput(prefetchInputElements.asOutput()); + opBuilder.addInput(numParallelCalls.asOutput()); + opBuilder = scope.apply(opBuilder); + opBuilder.setAttr("f", f); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + if (options != null) { + for (Options opts : options) { + if (opts.deterministic != null) { + opBuilder.setAttr("deterministic", opts.deterministic); + } + } + } + return new ParallelInterleaveDataset(opBuilder.build()); + } + + /** + * Sets the deterministic option. + * + * @param deterministic A string indicating the op-level determinism to use. Deterministic controls + * whether the interleave is allowed to return elements out of order if the next + * element to be returned isn't available, but a later element is. Options are + * "true", "false", and "default". "default" indicates that determinism should be + * decided by the {@code experimental_deterministic} parameter of {@code tf.data.Options}. + * @return this Options instance. + */ + public static Options deterministic(String deterministic) { + return new Options().deterministic(deterministic); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.ParallelInterleaveDataset} + */ + public static class Options { + private String deterministic; + + private Options() { + } + + /** + * Sets the deterministic option. + * + * @param deterministic A string indicating the op-level determinism to use. Deterministic controls + * whether the interleave is allowed to return elements out of order if the next + * element to be returned isn't available, but a later element is. Options are + * "true", "false", and "default". "default" indicates that determinism should be + * decided by the {@code experimental_deterministic} parameter of {@code tf.data.Options}. + * @return this Options instance. + */ + public Options deterministic(String deterministic) { + this.deterministic = deterministic; + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParseExampleDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParseExampleDataset.java new file mode 100644 index 00000000000..35c277f46a9 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParseExampleDataset.java @@ -0,0 +1,236 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.Arrays; +import java.util.List; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +/** + * Transforms {@code input_dataset} containing {@code Example} protos as vectors of DT_STRING into a dataset of {@code Tensor} or {@code SparseTensor} objects representing the parsed features. + */ +public final class ParseExampleDataset extends RawOp implements Operand<TType> { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ParseExampleDatasetV2"; + + private Output<? extends TType> handle; + + @SuppressWarnings("unchecked") + private ParseExampleDataset(Operation operation) { + super(operation); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new ParseExampleDatasetV2 operation. + * + * @param scope current scope + * @param inputDataset the inputDataset value + * @param numParallelCalls the numParallelCalls value + * @param denseDefaults A dict mapping string keys to {@code Tensor}s. + * The keys of the dict must match the dense_keys of the feature. + * @param sparseKeys A list of string keys in the examples features. + * The results for these keys will be returned as {@code SparseTensor} objects. + * @param denseKeys A list of Ndense string Tensors (scalars). + * The keys expected in the Examples features associated with dense values. + * @param sparseTypes A list of {@code DTypes} of the same length as {@code sparse_keys}. + * Only {@code tf.float32} ({@code FloatList}), {@code tf.int64} ({@code Int64List}), + * and {@code tf.string} ({@code BytesList}) are supported. + * @param denseShapes List of tuples with the same length as {@code dense_keys}. + * The shape of the data for each dense feature referenced by {@code dense_keys}. + * Required for any input tensors identified by {@code dense_keys}. Must be + * either fully defined, or may contain an unknown first dimension. + * An unknown first dimension means the feature is treated as having + * a variable number of blocks, and the output shape along this dimension + * is considered unknown at graph build time. Padding is applied for + * minibatch elements smaller than the maximum number of blocks for the + * given feature along this dimension. + * @param outputTypes The type list for the return values. + * @param outputShapes The list of shapes being produced. + * @param raggedValueTypes the value of the raggedValueTypes property + * @param raggedSplitTypes the value of the raggedSplitTypes property + * @param options carries optional attribute values + * @return a new instance of ParseExampleDataset + */ + @Endpoint( + describeByClass = true + ) + public static ParseExampleDataset create(Scope scope, Operand<? extends TType> inputDataset, + Operand<TInt64> numParallelCalls, Iterable<Operand<?>> denseDefaults, List<String> sparseKeys, + List<String> denseKeys, List<Class<? extends TType>> sparseTypes, List<Shape> denseShapes, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + List<Class<? extends TType>> raggedValueTypes, + List<Class<? extends TNumber>> raggedSplitTypes, Options... options) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("ParseExampleDataset")); + opBuilder.addInput(inputDataset.asOutput()); + opBuilder.addInput(numParallelCalls.asOutput()); + opBuilder.addInputList(Operands.asOutputs(denseDefaults)); + opBuilder = scope.apply(opBuilder); + String[] sparseKeysArray = new String[sparseKeys.size()]; + for (int i = 0 ; i < sparseKeysArray.length ; i++) { + sparseKeysArray[i] = sparseKeys.get(i); + } + opBuilder.setAttr("sparse_keys", sparseKeysArray); + String[] denseKeysArray = new String[denseKeys.size()]; + for (int i = 0 ; i < denseKeysArray.length ; i++) { + denseKeysArray[i] = denseKeys.get(i); + } + opBuilder.setAttr("dense_keys", denseKeysArray); + opBuilder.setAttr("sparse_types", Operands.toDataTypes(sparseTypes)); + Shape[] denseShapesArray = new Shape[denseShapes.size()]; + for (int i = 0 ; i < denseShapesArray.length ; i++) { + denseShapesArray[i] = denseShapes.get(i); + } + opBuilder.setAttr("dense_shapes", denseShapesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + opBuilder.setAttr("ragged_value_types", Operands.toDataTypes(raggedValueTypes)); + opBuilder.setAttr("ragged_split_types", Operands.toDataTypes(raggedSplitTypes)); + if (options != null) { + for (Options opts : options) { + if (opts.deterministic != null) { + opBuilder.setAttr("deterministic", opts.deterministic); + } + if (opts.raggedKeys != null) { + String[] raggedKeysArray = new String[opts.raggedKeys.size()]; + for (int i = 0 ; i < raggedKeysArray.length ; i++) { + raggedKeysArray[i] = opts.raggedKeys.get(i); + } + opBuilder.setAttr("ragged_keys", raggedKeysArray); + } + } + } + return new ParseExampleDataset(opBuilder.build()); + } + + /** + * Sets the deterministic option. + * + * @param deterministic A string indicating the op-level determinism to use. Deterministic controls + * whether the dataset is allowed to return elements out of order if the next + * element to be returned isn't available, but a later element is. Options are + * "true", "false", and "default". "default" indicates that determinism should be + * decided by the {@code experimental_deterministic} parameter of {@code tf.data.Options}. + * @return this Options instance. + */ + public static Options deterministic(String deterministic) { + return new Options().deterministic(deterministic); + } + + /** + * Sets the raggedKeys option. + * + * @param raggedKeys the raggedKeys option + * @return this Options instance. + */ + public static Options raggedKeys(List<String> raggedKeys) { + return new Options().raggedKeys(raggedKeys); + } + + /** + * Sets the raggedKeys option. + * + * @param raggedKeys the raggedKeys option + * @return this Options instance. + */ + public static Options raggedKeys(String[] raggedKeys) { + return new Options().raggedKeys(raggedKeys); + } + + /** + * Gets handle. + * + * @return handle. + */ + public Output<? extends TType> handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output<TType> asOutput() { + return (Output<TType>) handle; + } + + /** + * Optional attributes for {@link org.tensorflow.op.data.ParseExampleDataset} + */ + public static class Options { + private String deterministic; + + private List<String> raggedKeys; + + private Options() { + } + + /** + * Sets the deterministic option. + * + * @param deterministic A string indicating the op-level determinism to use. Deterministic controls + * whether the dataset is allowed to return elements out of order if the next + * element to be returned isn't available, but a later element is. Options are + * "true", "false", and "default". "default" indicates that determinism should be + * decided by the {@code experimental_deterministic} parameter of {@code tf.data.Options}. + * @return this Options instance. + */ + public Options deterministic(String deterministic) { + this.deterministic = deterministic; + return this; + } + + /** + * Sets the raggedKeys option. + * + * @param raggedKeys the raggedKeys option + * @return this Options instance. + */ + public Options raggedKeys(List<String> raggedKeys) { + this.raggedKeys = raggedKeys; + return this; + } + + /** + * Sets the raggedKeys option. + * + * @param raggedKeys the raggedKeys option + * @return this Options instance. + */ + public Options raggedKeys(String... raggedKeys) { + this.raggedKeys = Arrays.asList(raggedKeys); + return this; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ReduceDataset.java similarity index 98% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceDataset.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ReduceDataset.java index c82507fe5f7..d2c4ffd3aae 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ReduceDataset.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.core; +package org.tensorflow.op.data; import java.util.Arrays; import java.util.Iterator; @@ -122,7 +122,7 @@ public Iterator<Operand<TType>> iterator() { } /** - * Optional attributes for {@link org.tensorflow.op.core.ReduceDataset} + * Optional attributes for {@link org.tensorflow.op.data.ReduceDataset} */ public static class Options { private Boolean useInterOpParallelism; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/StatsAggregatorHandle.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/StatsAggregatorHandle.java index a4cd180c217..99c227869fc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/StatsAggregatorHandle.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/StatsAggregatorHandle.java @@ -27,13 +27,13 @@ import org.tensorflow.types.family.TType; /** - * Creates a statistics manager resource. + * The StatsAggregatorHandleV2 operation */ public final class StatsAggregatorHandle extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "StatsAggregatorHandle"; + public static final String OP_NAME = "StatsAggregatorHandleV2"; private Output<? extends TType> handle; @@ -45,7 +45,7 @@ private StatsAggregatorHandle(Operation operation) { } /** - * Factory method to create a class wrapping a new StatsAggregatorHandle operation. + * Factory method to create a class wrapping a new StatsAggregatorHandleV2 operation. * * @param scope current scope * @param options carries optional attribute values diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorSetSummaryWriter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/StatsAggregatorSetSummaryWriter.java similarity index 97% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorSetSummaryWriter.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/StatsAggregatorSetSummaryWriter.java index a0630991efe..aa3e56da61b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorSetSummaryWriter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/StatsAggregatorSetSummaryWriter.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.data.experimental; +package org.tensorflow.op.data; import org.tensorflow.Operand; import org.tensorflow.Operation; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UncompressElement.java similarity index 98% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UncompressElement.java index 0f498626609..344915555c9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UncompressElement.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.data.experimental; +package org.tensorflow.op.data; import java.util.Arrays; import java.util.Iterator; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParallelInterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParallelInterleaveDataset.java index b0475a6d457..0aa546f2f9c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParallelInterleaveDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParallelInterleaveDataset.java @@ -28,28 +28,24 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; /** * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. - * The resulting dataset is similar to the {@code InterleaveDataset}, except that the - * dataset will fetch records from the interleaved datasets in parallel. - * <p>The {@code tf.data} Python API creates instances of this op from - * {@code Dataset.interleave()} when the {@code num_parallel_calls} parameter of that method - * is set to any value other than {@code None}. - * <p>By default, the output of this dataset will be deterministic, which may result - * in the dataset blocking if the next data item to be returned isn't available. - * In order to avoid head-of-line blocking, one can either set the {@code deterministic} - * attribute to "false", or leave it as "default" and set the - * {@code experimental_deterministic} parameter of {@code tf.data.Options} to {@code False}. - * This can improve performance at the expense of non-determinism. + * The resulting dataset is similar to the {@code InterleaveDataset}, with the exception + * that if retrieving the next value from a dataset would cause the requester to + * block, it will skip that input dataset. This dataset is especially useful + * when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it + * allows the training step to proceed so long as some data is available. + * <p>!! WARNING !! This dataset is not deterministic! */ public final class ParallelInterleaveDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "ParallelInterleaveDatasetV4"; + public static final String OP_NAME = "ExperimentalParallelInterleaveDataset"; private Output<? extends TType> handle; @@ -61,30 +57,21 @@ private ParallelInterleaveDataset(Operation operation) { } /** - * Factory method to create a class wrapping a new ParallelInterleaveDatasetV4 operation. + * Factory method to create a class wrapping a new ExperimentalParallelInterleaveDataset operation. * * @param scope current scope - * @param inputDataset Dataset that produces a stream of arguments for the function {@code f}. - * @param otherArguments Additional arguments to pass to {@code f} beyond those produced by {@code input_dataset}. - * Evaluated once when the dataset is instantiated. - * @param cycleLength Number of datasets (each created by applying {@code f} to the elements of - * {@code input_dataset}) among which the {@code ParallelInterleaveDatasetV2} will cycle in a - * round-robin fashion. - * @param blockLength Number of elements at a time to produce from each interleaved invocation of a - * dataset returned by {@code f}. - * @param bufferOutputElements The number of elements each iterator being interleaved should buffer (similar - * to the {@code .prefetch()} transformation for each interleaved iterator). - * @param prefetchInputElements Determines the number of iterators to prefetch, allowing buffers to warm up and - * data to be pre-fetched without blocking the main thread. - * @param numParallelCalls Determines the number of threads that should be used for fetching data from - * input datasets in parallel. The Python API {@code tf.data.experimental.AUTOTUNE} - * constant can be used to indicate that the level of parallelism should be autotuned. + * @param inputDataset the inputDataset value + * @param otherArguments the otherArguments value + * @param cycleLength the cycleLength value + * @param blockLength the blockLength value + * @param sloppy the sloppy value + * @param bufferOutputElements the bufferOutputElements value + * @param prefetchInputElements the prefetchInputElements value * @param f A function mapping elements of {@code input_dataset}, concatenated with * {@code other_arguments}, to a Dataset variant that contains elements matching * {@code output_types} and {@code output_shapes}. * @param outputTypes the value of the outputTypes property * @param outputShapes the value of the outputShapes property - * @param options carries optional attribute values * @return a new instance of ParallelInterleaveDataset */ @Endpoint( @@ -92,17 +79,17 @@ private ParallelInterleaveDataset(Operation operation) { ) public static ParallelInterleaveDataset create(Scope scope, Operand<? extends TType> inputDataset, Iterable<Operand<?>> otherArguments, Operand<TInt64> cycleLength, Operand<TInt64> blockLength, - Operand<TInt64> bufferOutputElements, Operand<TInt64> prefetchInputElements, - Operand<TInt64> numParallelCalls, ConcreteFunction f, - List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { + Operand<TBool> sloppy, Operand<TInt64> bufferOutputElements, + Operand<TInt64> prefetchInputElements, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("ParallelInterleaveDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(otherArguments)); opBuilder.addInput(cycleLength.asOutput()); opBuilder.addInput(blockLength.asOutput()); + opBuilder.addInput(sloppy.asOutput()); opBuilder.addInput(bufferOutputElements.asOutput()); opBuilder.addInput(prefetchInputElements.asOutput()); - opBuilder.addInput(numParallelCalls.asOutput()); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("f", f); opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); @@ -111,30 +98,9 @@ public static ParallelInterleaveDataset create(Scope scope, Operand<? extends TT outputShapesArray[i] = outputShapes.get(i); } opBuilder.setAttr("output_shapes", outputShapesArray); - if (options != null) { - for (Options opts : options) { - if (opts.deterministic != null) { - opBuilder.setAttr("deterministic", opts.deterministic); - } - } - } return new ParallelInterleaveDataset(opBuilder.build()); } - /** - * Sets the deterministic option. - * - * @param deterministic A string indicating the op-level determinism to use. Deterministic controls - * whether the interleave is allowed to return elements out of order if the next - * element to be returned isn't available, but a later element is. Options are - * "true", "false", and "default". "default" indicates that determinism should be - * decided by the {@code experimental_deterministic} parameter of {@code tf.data.Options}. - * @return this Options instance. - */ - public static Options deterministic(String deterministic) { - return new Options().deterministic(deterministic); - } - /** * Gets handle. * @@ -149,29 +115,4 @@ public Output<? extends TType> handle() { public Output<TType> asOutput() { return (Output<TType>) handle; } - - /** - * Optional attributes for {@link org.tensorflow.op.data.experimental.ParallelInterleaveDataset} - */ - public static class Options { - private String deterministic; - - private Options() { - } - - /** - * Sets the deterministic option. - * - * @param deterministic A string indicating the op-level determinism to use. Deterministic controls - * whether the interleave is allowed to return elements out of order if the next - * element to be returned isn't available, but a later element is. Options are - * "true", "false", and "default". "default" indicates that determinism should be - * decided by the {@code experimental_deterministic} parameter of {@code tf.data.Options}. - * @return this Options instance. - */ - public Options deterministic(String deterministic) { - this.deterministic = deterministic; - return this; - } - } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParseExampleDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParseExampleDataset.java index c3e179c29db..6a569cb0c68 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParseExampleDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParseExampleDataset.java @@ -17,7 +17,6 @@ package org.tensorflow.op.data.experimental; -import java.util.Arrays; import java.util.List; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -29,7 +28,6 @@ import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; /** @@ -39,7 +37,7 @@ public final class ParseExampleDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "ParseExampleDatasetV2"; + public static final String OP_NAME = "ExperimentalParseExampleDataset"; private Output<? extends TType> handle; @@ -51,7 +49,7 @@ private ParseExampleDataset(Operation operation) { } /** - * Factory method to create a class wrapping a new ParseExampleDatasetV2 operation. + * Factory method to create a class wrapping a new ExperimentalParseExampleDataset operation. * * @param scope current scope * @param inputDataset the inputDataset value @@ -76,8 +74,6 @@ private ParseExampleDataset(Operation operation) { * given feature along this dimension. * @param outputTypes The type list for the return values. * @param outputShapes The list of shapes being produced. - * @param raggedValueTypes the value of the raggedValueTypes property - * @param raggedSplitTypes the value of the raggedSplitTypes property * @param options carries optional attribute values * @return a new instance of ParseExampleDataset */ @@ -87,9 +83,7 @@ private ParseExampleDataset(Operation operation) { public static ParseExampleDataset create(Scope scope, Operand<? extends TType> inputDataset, Operand<TInt64> numParallelCalls, Iterable<Operand<?>> denseDefaults, List<String> sparseKeys, List<String> denseKeys, List<Class<? extends TType>> sparseTypes, List<Shape> denseShapes, - List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, - List<Class<? extends TType>> raggedValueTypes, - List<Class<? extends TNumber>> raggedSplitTypes, Options... options) { + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("ParseExampleDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numParallelCalls.asOutput()); @@ -117,19 +111,10 @@ public static ParseExampleDataset create(Scope scope, Operand<? extends TType> i outputShapesArray[i] = outputShapes.get(i); } opBuilder.setAttr("output_shapes", outputShapesArray); - opBuilder.setAttr("ragged_value_types", Operands.toDataTypes(raggedValueTypes)); - opBuilder.setAttr("ragged_split_types", Operands.toDataTypes(raggedSplitTypes)); if (options != null) { for (Options opts : options) { - if (opts.deterministic != null) { - opBuilder.setAttr("deterministic", opts.deterministic); - } - if (opts.raggedKeys != null) { - String[] raggedKeysArray = new String[opts.raggedKeys.size()]; - for (int i = 0 ; i < raggedKeysArray.length ; i++) { - raggedKeysArray[i] = opts.raggedKeys.get(i); - } - opBuilder.setAttr("ragged_keys", raggedKeysArray); + if (opts.sloppy != null) { + opBuilder.setAttr("sloppy", opts.sloppy); } } } @@ -137,37 +122,13 @@ public static ParseExampleDataset create(Scope scope, Operand<? extends TType> i } /** - * Sets the deterministic option. - * - * @param deterministic A string indicating the op-level determinism to use. Deterministic controls - * whether the dataset is allowed to return elements out of order if the next - * element to be returned isn't available, but a later element is. Options are - * "true", "false", and "default". "default" indicates that determinism should be - * decided by the {@code experimental_deterministic} parameter of {@code tf.data.Options}. - * @return this Options instance. - */ - public static Options deterministic(String deterministic) { - return new Options().deterministic(deterministic); - } - - /** - * Sets the raggedKeys option. + * Sets the sloppy option. * - * @param raggedKeys the raggedKeys option + * @param sloppy the sloppy option * @return this Options instance. */ - public static Options raggedKeys(List<String> raggedKeys) { - return new Options().raggedKeys(raggedKeys); - } - - /** - * Sets the raggedKeys option. - * - * @param raggedKeys the raggedKeys option - * @return this Options instance. - */ - public static Options raggedKeys(String[] raggedKeys) { - return new Options().raggedKeys(raggedKeys); + public static Options sloppy(Boolean sloppy) { + return new Options().sloppy(sloppy); } /** @@ -189,47 +150,19 @@ public Output<TType> asOutput() { * Optional attributes for {@link org.tensorflow.op.data.experimental.ParseExampleDataset} */ public static class Options { - private String deterministic; - - private List<String> raggedKeys; + private Boolean sloppy; private Options() { } /** - * Sets the deterministic option. - * - * @param deterministic A string indicating the op-level determinism to use. Deterministic controls - * whether the dataset is allowed to return elements out of order if the next - * element to be returned isn't available, but a later element is. Options are - * "true", "false", and "default". "default" indicates that determinism should be - * decided by the {@code experimental_deterministic} parameter of {@code tf.data.Options}. - * @return this Options instance. - */ - public Options deterministic(String deterministic) { - this.deterministic = deterministic; - return this; - } - - /** - * Sets the raggedKeys option. - * - * @param raggedKeys the raggedKeys option - * @return this Options instance. - */ - public Options raggedKeys(List<String> raggedKeys) { - this.raggedKeys = raggedKeys; - return this; - } - - /** - * Sets the raggedKeys option. + * Sets the sloppy option. * - * @param raggedKeys the raggedKeys option + * @param sloppy the sloppy option * @return this Options instance. */ - public Options raggedKeys(String... raggedKeys) { - this.raggedKeys = Arrays.asList(raggedKeys); + public Options sloppy(Boolean sloppy) { + this.sloppy = sloppy; return this; } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorHandle.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorHandle.java index e4183305b0f..4f1ea92e22d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorHandle.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorHandle.java @@ -27,13 +27,13 @@ import org.tensorflow.types.family.TType; /** - * The StatsAggregatorHandleV2 operation + * Creates a statistics manager resource. */ public final class StatsAggregatorHandle extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "StatsAggregatorHandleV2"; + public static final String OP_NAME = "ExperimentalStatsAggregatorHandle"; private Output<? extends TType> handle; @@ -45,7 +45,7 @@ private StatsAggregatorHandle(Operation operation) { } /** - * Factory method to create a class wrapping a new StatsAggregatorHandleV2 operation. + * Factory method to create a class wrapping a new ExperimentalStatsAggregatorHandle operation. * * @param scope current scope * @param options carries optional attribute values diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DummySeedGenerator.java similarity index 97% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DummySeedGenerator.java index b5f8baccd92..068f85a76df 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DummySeedGenerator.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.random.experimental; +package org.tensorflow.op.random; import org.tensorflow.Operand; import org.tensorflow.Operation; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Conv.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Conv.java index 1fe493c4223..fc238e7ebc5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Conv.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Conv.java @@ -21,6 +21,7 @@ import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -33,18 +34,18 @@ * https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution * . * - * @param <T> data type for {@code output} output + * @param <W> data type for {@code output} output */ @Operator( group = "xla" ) -public final class Conv<T extends TType> extends RawOp implements Operand<T> { +public final class Conv<W extends TType> extends RawOp implements Operand<W> { /** * The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "XlaConv"; + public static final String OP_NAME = "XlaConvV2"; - private Output<T> output; + private Output<W> output; private Conv(Operation operation) { super(operation); @@ -53,7 +54,7 @@ private Conv(Operation operation) { } /** - * Factory method to create a class wrapping a new XlaConv operation. + * Factory method to create a class wrapping a new XlaConvV2 operation. * * @param scope current scope * @param lhs the input tensor @@ -65,17 +66,19 @@ private Conv(Operation operation) { * @param featureGroupCount number of feature groups for grouped convolution. * @param dimensionNumbers a serialized xla::ConvolutionDimensionNumbers proto. * @param precisionConfig a serialized xla::PrecisionConfig proto. - * @param <T> data type for {@code XlaConv} output and operands - * @param <U> data type for {@code XlaConv} output and operands + * @param preferredElementType The type of the tensor. + * @param <W> data type for {@code XlaConvV2} output and operands + * @param <V> data type for {@code XlaConvV2} output and operands * @return a new instance of Conv */ @Endpoint( describeByClass = true ) - public static <T extends TType, U extends TNumber> Conv<T> create(Scope scope, Operand<T> lhs, - Operand<T> rhs, Operand<U> windowStrides, Operand<U> padding, Operand<U> lhsDilation, - Operand<U> rhsDilation, Operand<U> featureGroupCount, String dimensionNumbers, - String precisionConfig) { + public static <W extends TType, V extends TNumber> Conv<W> create(Scope scope, + Operand<? extends TType> lhs, Operand<? extends TType> rhs, Operand<V> windowStrides, + Operand<V> padding, Operand<V> lhsDilation, Operand<V> rhsDilation, + Operand<V> featureGroupCount, String dimensionNumbers, String precisionConfig, + Class<W> preferredElementType) { OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("Conv")); opBuilder.addInput(lhs.asOutput()); opBuilder.addInput(rhs.asOutput()); @@ -87,6 +90,7 @@ public static <T extends TType, U extends TNumber> Conv<T> create(Scope scope, O opBuilder = scope.apply(opBuilder); opBuilder.setAttr("dimension_numbers", dimensionNumbers); opBuilder.setAttr("precision_config", precisionConfig); + opBuilder.setAttr("preferred_element_type", Operands.toDataType(preferredElementType)); return new Conv<>(opBuilder.build()); } @@ -95,12 +99,12 @@ public static <T extends TType, U extends TNumber> Conv<T> create(Scope scope, O * * @return output. */ - public Output<T> output() { + public Output<W> output() { return output; } @Override - public Output<T> asOutput() { + public Output<W> asOutput() { return output; } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Dot.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Dot.java index 3de0491e718..420ad933a2b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Dot.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Dot.java @@ -21,6 +21,7 @@ import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -32,18 +33,18 @@ * https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral * . * - * @param <T> data type for {@code output} output + * @param <V> data type for {@code output} output */ @Operator( group = "xla" ) -public final class Dot<T extends TType> extends RawOp implements Operand<T> { +public final class Dot<V extends TType> extends RawOp implements Operand<V> { /** * The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "XlaDot"; + public static final String OP_NAME = "XlaDotV2"; - private Output<T> output; + private Output<V> output; private Dot(Operation operation) { super(operation); @@ -52,27 +53,30 @@ private Dot(Operation operation) { } /** - * Factory method to create a class wrapping a new XlaDot operation. + * Factory method to create a class wrapping a new XlaDotV2 operation. * * @param scope current scope * @param lhs the LHS tensor * @param rhs the RHS tensor * @param dimensionNumbers a serialized xla::DotDimensionNumbers proto. * @param precisionConfig a serialized xla::PrecisionConfig proto. - * @param <T> data type for {@code XlaDot} output and operands + * @param preferredElementType The type of the tensor. + * @param <V> data type for {@code XlaDotV2} output and operands * @return a new instance of Dot */ @Endpoint( describeByClass = true ) - public static <T extends TType> Dot<T> create(Scope scope, Operand<T> lhs, Operand<T> rhs, - String dimensionNumbers, String precisionConfig) { + public static <V extends TType> Dot<V> create(Scope scope, Operand<? extends TType> lhs, + Operand<? extends TType> rhs, String dimensionNumbers, String precisionConfig, + Class<V> preferredElementType) { OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("Dot")); opBuilder.addInput(lhs.asOutput()); opBuilder.addInput(rhs.asOutput()); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("dimension_numbers", dimensionNumbers); opBuilder.setAttr("precision_config", precisionConfig); + opBuilder.setAttr("preferred_element_type", Operands.toDataType(preferredElementType)); return new Dot<>(opBuilder.build()); } @@ -81,12 +85,12 @@ public static <T extends TType> Dot<T> create(Scope scope, Operand<T> lhs, Opera * * @return output. */ - public Output<T> output() { + public Output<V> output() { return output; } @Override - public Output<T> asOutput() { + public Output<V> asOutput() { return output; } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSetDynamicDimensionSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SetDynamicDimensionSize.java similarity index 81% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSetDynamicDimensionSize.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SetDynamicDimensionSize.java index 2ab7f6605ab..7388d3d9048 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSetDynamicDimensionSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SetDynamicDimensionSize.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.core; +package org.tensorflow.op.xla; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -37,8 +37,10 @@ * * @param <T> data type for {@code output} output */ -@Operator -public final class XlaSetDynamicDimensionSize<T extends TType> extends RawOp implements Operand<T> { +@Operator( + group = "xla" +) +public final class SetDynamicDimensionSize<T extends TType> extends RawOp implements Operand<T> { /** * The name of this op, as known by TensorFlow core engine */ @@ -46,7 +48,7 @@ public final class XlaSetDynamicDimensionSize<T extends TType> extends RawOp imp private Output<T> output; - private XlaSetDynamicDimensionSize(Operation operation) { + private SetDynamicDimensionSize(Operation operation) { super(operation); int outputIdx = 0; output = operation.output(outputIdx++); @@ -60,19 +62,19 @@ private XlaSetDynamicDimensionSize(Operation operation) { * @param dimIndex the dimIndex value * @param sizeOutput the sizeOutput value * @param <T> data type for {@code XlaSetDynamicDimensionSize} output and operands - * @return a new instance of XlaSetDynamicDimensionSize + * @return a new instance of SetDynamicDimensionSize */ @Endpoint( describeByClass = true ) - public static <T extends TType> XlaSetDynamicDimensionSize<T> create(Scope scope, - Operand<T> input, Operand<TInt32> dimIndex, Operand<TInt32> sizeOutput) { - OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("XlaSetDynamicDimensionSize")); + public static <T extends TType> SetDynamicDimensionSize<T> create(Scope scope, Operand<T> input, + Operand<TInt32> dimIndex, Operand<TInt32> sizeOutput) { + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("SetDynamicDimensionSize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(dimIndex.asOutput()); opBuilder.addInput(sizeOutput.asOutput()); opBuilder = scope.apply(opBuilder); - return new XlaSetDynamicDimensionSize<>(opBuilder.build()); + return new SetDynamicDimensionSize<>(opBuilder.build()); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdFullToShardShape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SpmdFullToShardShape.java similarity index 84% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdFullToShardShape.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SpmdFullToShardShape.java index cedee394b63..2cb2377382a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdFullToShardShape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SpmdFullToShardShape.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.core; +package org.tensorflow.op.xla; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -36,8 +36,10 @@ * * @param <T> data type for {@code output} output */ -@Operator -public final class XlaSpmdFullToShardShape<T extends TType> extends RawOp implements Operand<T> { +@Operator( + group = "xla" +) +public final class SpmdFullToShardShape<T extends TType> extends RawOp implements Operand<T> { /** * The name of this op, as known by TensorFlow core engine */ @@ -45,7 +47,7 @@ public final class XlaSpmdFullToShardShape<T extends TType> extends RawOp implem private Output<T> output; - private XlaSpmdFullToShardShape(Operation operation) { + private SpmdFullToShardShape(Operation operation) { super(operation); int outputIdx = 0; output = operation.output(outputIdx++); @@ -58,18 +60,18 @@ private XlaSpmdFullToShardShape(Operation operation) { * @param input the input value * @param manualSharding the value of the manualSharding property * @param <T> data type for {@code XlaSpmdFullToShardShape} output and operands - * @return a new instance of XlaSpmdFullToShardShape + * @return a new instance of SpmdFullToShardShape */ @Endpoint( describeByClass = true ) - public static <T extends TType> XlaSpmdFullToShardShape<T> create(Scope scope, Operand<T> input, + public static <T extends TType> SpmdFullToShardShape<T> create(Scope scope, Operand<T> input, String manualSharding) { - OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("XlaSpmdFullToShardShape")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("SpmdFullToShardShape")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("manual_sharding", manualSharding); - return new XlaSpmdFullToShardShape<>(opBuilder.build()); + return new SpmdFullToShardShape<>(opBuilder.build()); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdShardToFullShape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SpmdShardToFullShape.java similarity index 84% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdShardToFullShape.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SpmdShardToFullShape.java index dc5e3de5834..c4e49461a22 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdShardToFullShape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/SpmdShardToFullShape.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.core; +package org.tensorflow.op.xla; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -36,8 +36,10 @@ * * @param <T> data type for {@code output} output */ -@Operator -public final class XlaSpmdShardToFullShape<T extends TType> extends RawOp implements Operand<T> { +@Operator( + group = "xla" +) +public final class SpmdShardToFullShape<T extends TType> extends RawOp implements Operand<T> { /** * The name of this op, as known by TensorFlow core engine */ @@ -45,7 +47,7 @@ public final class XlaSpmdShardToFullShape<T extends TType> extends RawOp implem private Output<T> output; - private XlaSpmdShardToFullShape(Operation operation) { + private SpmdShardToFullShape(Operation operation) { super(operation); int outputIdx = 0; output = operation.output(outputIdx++); @@ -59,19 +61,19 @@ private XlaSpmdShardToFullShape(Operation operation) { * @param manualSharding the value of the manualSharding property * @param fullShape the value of the fullShape property * @param <T> data type for {@code XlaSpmdShardToFullShape} output and operands - * @return a new instance of XlaSpmdShardToFullShape + * @return a new instance of SpmdShardToFullShape */ @Endpoint( describeByClass = true ) - public static <T extends TType> XlaSpmdShardToFullShape<T> create(Scope scope, Operand<T> input, + public static <T extends TType> SpmdShardToFullShape<T> create(Scope scope, Operand<T> input, String manualSharding, Shape fullShape) { - OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("XlaSpmdShardToFullShape")); + OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("SpmdShardToFullShape")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("manual_sharding", manualSharding); opBuilder.setAttr("full_shape", fullShape); - return new XlaSpmdShardToFullShape<>(opBuilder.build()); + return new SpmdShardToFullShape<>(opBuilder.build()); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaVariadicSort.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicSort.java similarity index 98% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaVariadicSort.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicSort.java index 5a4eafda539..fdb7d680cc6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaVariadicSort.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaVariadicSort.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.core; +package org.tensorflow.op.xla; import java.util.Arrays; import java.util.Iterator; @@ -40,7 +40,9 @@ * <p>Sorts one or more tensors, with support for custom comparator, dimension, and * is_stable attributes. */ -@Operator +@Operator( + group = "xla" +) public final class XlaVariadicSort extends RawOp implements Iterable<Operand<TType>> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/resources/ops.pb b/tensorflow-core/tensorflow-core-api/src/gen/resources/ops.pb index 4c3e6bef038286ce7eb88eeb01798070a47682f6..14aed40998768ed31aaf82bf24803c9751f126bf 100644 GIT binary patch delta 1048 zcmZXSZD?Cn7{__e&CScX$&&QmoSUS{y4e>CYCnwOr-6ORC?a-to!ux*xF#1Y%S%me z$}-oLnyptqSjCMVnXuI_+Jay!(g$HwHoA#2nEOzsScE8ZewiqpR%ubsO<S4x;m<h_ z{Ql4LKQ9+P9^wl_^WC%LH*!8TJJog<Xtb+TGK*G^QOpeIjqI>hNg0+=GOdslZ=%VJ zD5~<DEzbeUbHscc&hZl~<b~w>K!ZJn+(^+Zl@4ajoSAnVjB1X<URxo!wj!Nwg)eYq zLrMWTEAy8(B#jCvM)f*W+mx`j>AD|-BEQAtc6i!ZbVVQu8h4Nmew)b=en6K?aDvb4 zaw{Bk7EKpSBaIb}e5R1=NR@NB%7AHRdii);IBvH`Wg$1VM?J^7VFM?IJVioY3-Zu$ z&wty`Tovtqs?oz(IqsQw0H5GSzqb#_T_4}<_cl9G+GxvGl)w7DO?>t%-(LRKi%JXs z@Sc1C+m3iYTXfX|v;p6|BgbIRvO*i4HAcGfnQp^Mzm(d(OV~Ly*1ncMknQELPl)1j zSoxug+==s0r;@<y14<8&-}QRs-!F#mp}V(+|N4{?!?FE}5Aew@<vGBGqsr|k$OWCB zu@xQPJ*U(FF7h#3Y4pZvv~I%7W6G%8Zk|$v{MbAdI6VnBa7NqJ3W*sv4IDNz<+Qnz zQHpD{p>DKSBfErxLLs5BPy?<;)bXtA9#Eb4PpA=Yy{h)Q?42?&xbNX-s}sAayN}Q% zeyOTH!pl{)ovZ=BTvc1V|K90)Rkg`pud29S)ec>O?|5%#<ZRP7K*P`H)BpQjQOcZr zxZWALv;?2x^xu)alEyvMb1ZB06!N1X68F^$yh+%k1oJrHW{XnbPY?BNBZK?!8e!i9 z_s3ZRBYu|l!ZcPiHV$lBqHcku3YKW(-P!Z2G4>1(^)ZSW#ySAr!1FrONbC=Zdbe<K zcyK14%o-iSjlvN-0`K6rIrb=!UkG2#u>=)Nk~TUFujQDIodx!#i#+h~pUW(Utrq(n z;B9`t%#t4Qgmo@#F^&Hi4nKy6M%e<u9KuOf0r-f&Qh5lQsBB|gV9krC*q9%#;h9M$ iNy#NQ4T_Y5Mse8445tSQMN0&UP}LcL7hh+mg8u;Hh)^8> delta 1017 zcmZWnZ)jUp6z9B`mzUhPEqP7eeNCFub)zz5c5B@S6^i{69SCl$(NKaE9?1i%$xBII zwpKfmHJxumHkh3rK_+Y;T>PMH4t8Yd;=nmJ8KSJ>V4zsp1g*m6l)4RlX;bF#!;kwr zzx%uAo^#K|lc~cWrA`diglTa$UOTAW1lBrGESd#-$S7pSEF(K+PvY*|Qb%>M<}B8L z7HdLgBfNp5bHeDur+@{9^11PXSu8%4HFKtA2gT@m)|oMEquZPqHw$A8(a1Vi=LEVs zCr)+2x0Jji#(^J_SWDdXCCtLkv9#HsMbpO1r=^%vzbjJxuIJ|nl<{jW=~(BI3dshK z;wisZ!&_Y12NO7@OI?o-yNOD<++@NuGh0o|EEsma;1;rhJybBPOg{HODc$$IxT!;p zIyZF5g_myV^2i|kN&P9gAPCg}4y0t=<L1$q7_I$W8kVUqBkx&&b5!5q+Xli-KYq5u zx4zYNJ<59i*LF9m310`jknr^rT%#k*AA?w3VJ*)Z;{#S^(6G}lHX9EL(J*VFqu)vI zNlr!eyI@pNmH7eBPf^x}J$*`)_9T=c@KjryF>Jxx*QH*1VvFJf3MQ4#O~O$fA8`~N z+q%>+jqX)0102VVj?zZ)1?hFKbN+zhVilb~uzKn>|2ZF9J#1!5X;goeV|eH3;A8l1 zc+JK3cf8GwmBF|j2392eA+DyJd2Nl0fQz6D)kO=<YoU&;=MP{y_B6)SHXYjPIlK}I z$zD9+_r5ByFL7Zy^bAxQ%R2A!>){v7yY8)ODB#RgLo`zjZ~p{-!7EQ{Gum}v>ho6m z!LNp2eJr)oM89_AESw?!kJc$FZ_7*`$r?j>YbU+w=Rd-11$}sLm~(tz;UC~{9zG8z z=^Kfk7Gr<Q%-=L8dxH>F{u9!6fnNdK7vV9QRe0J5NPmTR85oMp>z<0|ZQLB;BaI8_ zdyQ`fI7mUAhlS{($b8F=rL^pRM12t+c}_qHE0&ws&T$6tJC-fJ7GlfVrxw5D5pK!& za*0RiSdm`<ID{{icsoRzT8Uo+s8M+rp9JAU5i7fR#N*Z<Sz>ss#J5p-lJ8yxXX)Nv UE{bplVL#VIl?AAMKYu0gFUpNp!2kdN From ebf2817a128afb9a530d5dd8ecb70c2304187bb4 Mon Sep 17 00:00:00 2001 From: klessard <klessard@expedia.com> Date: Sat, 26 Jun 2021 11:58:28 -0400 Subject: [PATCH 14/14] Cleanup op dataset ops and make them visible --- .../api_def_AssertCardinalityDataset.pbtxt | 1 + .../api_def/api_def_AssertNextDataset.pbtxt | 1 + .../api_def/api_def_AutoShardDataset.pbtxt | 1 + .../api_def_BytesProducedStatsDataset.pbtxt | 1 + .../bazel/api_def/api_def_CSVDataset.pbtxt | 4 +- .../bazel/api_def/api_def_CSVDatasetV2.pbtxt | 3 +- .../bazel/api_def/api_def_CacheDataset.pbtxt | 4 +- .../api_def/api_def_CacheDatasetV2.pbtxt | 3 +- .../api_def_ChooseFastestBranchDataset.pbtxt | 1 + .../api_def_ChooseFastestDataset.pbtxt | 1 + .../api_def/api_def_DataServiceDataset.pbtxt | 4 +- .../api_def_DataServiceDatasetV2.pbtxt | 3 +- .../api_def/api_def_DatasetCardinality.pbtxt | 1 + .../api_def/api_def_DatasetFromGraph.pbtxt | 1 + .../api_def/api_def_DatasetToGraphV2.pbtxt | 1 + .../api_def_DatasetToSingleElement.pbtxt | 1 + .../api_def/api_def_DatasetToTFRecord.pbtxt | 1 + .../api_def_DenseToSparseBatchDataset.pbtxt | 1 + .../api_def_DirectedInterleaveDataset.pbtxt | 1 + .../api_def_EnqueueInQueueDataset.pbtxt | 1 + ...api_def_FilterByLastComponentDataset.pbtxt | 1 + .../api_def/api_def_FinalizeDataset.pbtxt | 3 +- .../api_def_FixedLengthRecordDatasetV2.pbtxt | 1 + .../api_def/api_def_GeneratorDataset.pbtxt | 1 + .../api_def_GroupByReducerDataset.pbtxt | 1 + .../api_def_GroupByWindowDataset.pbtxt | 1 + .../api_def/api_def_IgnoreErrorsDataset.pbtxt | 1 + .../api_def_InitializeTableFromDataset.pbtxt | 1 + .../bazel/api_def/api_def_KafkaDataset.pbtxt | 1 + .../bazel/api_def/api_def_LMDBDataset.pbtxt | 1 + .../api_def/api_def_LatencyStatsDataset.pbtxt | 1 + ...ef_LegacyParallelInterleaveDatasetV2.pbtxt | 1 + .../bazel/api_def/api_def_LoadDataset.pbtxt | 1 + .../api_def_MatchingFilesDataset.pbtxt | 1 + ...api_def_MaxIntraOpParallelismDataset.pbtxt | 1 + .../bazel/api_def/api_def_ModelDataset.pbtxt | 1 + .../api_def_NonSerializableDataset.pbtxt | 1 + .../api_def/api_def_OptimizeDataset.pbtxt | 4 +- .../api_def/api_def_OptimizeDatasetV2.pbtxt | 3 +- .../api_def/api_def_OptionsDataset.pbtxt | 3 +- .../api_def_PaddedBatchDatasetV2.pbtxt | 1 + .../api_def_ParallelBatchDataset.pbtxt | 3 +- .../api_def_ParallelInterleaveDatasetV4.pbtxt | 1 + .../api_def_ParallelMapDatasetV2.pbtxt | 1 + .../api_def_ParseExampleDatasetV2.pbtxt | 1 + .../api_def/api_def_PrefetchDataset.pbtxt | 1 + ...rependFromQueueAndPaddedBatchDataset.pbtxt | 1 + .../api_def_PrivateThreadPoolDataset.pbtxt | 1 + .../bazel/api_def/api_def_RandomDataset.pbtxt | 1 + .../api_def/api_def_RebatchDataset.pbtxt | 4 +- .../api_def/api_def_RebatchDatasetV2.pbtxt | 1 + .../bazel/api_def/api_def_ReduceDataset.pbtxt | 1 + .../api_def/api_def_RegisterDataset.pbtxt | 1 + .../api_def/api_def_SamplingDataset.pbtxt | 1 + .../bazel/api_def/api_def_SaveDataset.pbtxt | 1 + .../bazel/api_def/api_def_ScanDataset.pbtxt | 1 + .../api_def_SetStatsAggregatorDataset.pbtxt | 1 + .../bazel/api_def/api_def_ShardDataset.pbtxt | 1 + .../api_def_ShuffleAndRepeatDatasetV2.pbtxt | 1 + .../api_def/api_def_ShuffleDataset.pbtxt | 5 +- .../api_def/api_def_ShuffleDatasetV3.pbtxt | 1 + .../bazel/api_def/api_def_SleepDataset.pbtxt | 1 + .../api_def_SlidingWindowDataset.pbtxt | 1 + .../api_def/api_def_SnapshotDatasetV2.pbtxt | 1 + .../api_def_SparseTensorSliceDataset.pbtxt | 1 + .../bazel/api_def/api_def_SqlDataset.pbtxt | 1 + .../api_def/api_def_TakeWhileDataset.pbtxt | 1 + .../bazel/api_def/api_def_TensorDataset.pbtxt | 1 + .../api_def/api_def_ThreadPoolDataset.pbtxt | 1 + .../api_def/api_def_UnbatchDataset.pbtxt | 1 + .../bazel/api_def/api_def_UniqueDataset.pbtxt | 1 + .../api_def_UnwrapDatasetVariant.pbtxt | 1 + .../bazel/api_def/api_def_WindowDataset.pbtxt | 1 + .../api_def/api_def_WrapDatasetVariant.pbtxt | 1 + .../org/tensorflow/op/DataOps.java | 1346 ++++++++++++++++- .../op/data/AssertCardinalityDataset.java | 4 + .../tensorflow/op/data/AssertNextDataset.java | 4 + .../tensorflow/op/data/AutoShardDataset.java | 4 + .../op/data/BytesProducedStatsDataset.java | 4 + .../org/tensorflow/op/data/CSVDataset.java | 15 +- .../org/tensorflow/op/data/CSVDatasetV2.java | 112 -- .../org/tensorflow/op/data/CacheDataset.java | 23 +- .../tensorflow/op/data/CacheDatasetV2.java | 96 -- .../op/data/ChooseFastestBranchDataset.java | 4 + .../op/data/ChooseFastestDataset.java | 4 + .../op/data/DataServiceDataset.java | 172 --- .../DataServiceDatasetV2.java | 8 +- .../op/data/DatasetCardinality.java | 4 + .../tensorflow/op/data/DatasetFromGraph.java | 4 + .../tensorflow/op/data/DatasetToGraph.java | 4 + .../op/data/DatasetToSingleElement.java | 4 + .../tensorflow/op/data/DatasetToTfRecord.java | 4 + .../op/data/DenseToSparseBatchDataset.java | 4 + .../op/data/DirectedInterleaveDataset.java | 4 + .../op/data/FilterByLastComponentDataset.java | 4 + .../op/{rawops => data}/FinalizeDataset.java | 8 +- .../op/data/FixedLengthRecordDataset.java | 4 + .../tensorflow/op/data/GeneratorDataset.java | 4 + .../op/data/GroupByReducerDataset.java | 4 + .../op/data/GroupByWindowDataset.java | 4 + .../op/data/IgnoreErrorsDataset.java | 4 + .../op/data/InitializeTableFromDataset.java | 4 + .../org/tensorflow/op/data/LMDBDataset.java | 4 + .../op/data/LatencyStatsDataset.java | 4 + .../data/LegacyParallelInterleaveDataset.java | 4 + .../org/tensorflow/op/data/LoadDataset.java | 4 + .../op/data/MatchingFilesDataset.java | 4 + .../op/data/MaxIntraOpParallelismDataset.java | 4 + .../org/tensorflow/op/data/ModelDataset.java | 4 + .../op/data/NonSerializableDataset.java | 4 + .../tensorflow/op/data/OptimizeDataset.java | 23 +- .../tensorflow/op/data/OptimizeDatasetV2.java | 165 -- .../op/{rawops => data}/OptionsDataset.java | 6 +- .../op/data/PaddedBatchDataset.java | 4 + .../ParallelBatchDataset.java | 8 +- .../op/data/ParallelInterleaveDataset.java | 4 + .../op/data/ParallelMapDataset.java | 4 + .../op/data/ParseExampleDataset.java | 4 + .../tensorflow/op/data/PrefetchDataset.java | 4 + .../op/data/PrivateThreadPoolDataset.java | 4 + .../org/tensorflow/op/data/RandomDataset.java | 4 + .../tensorflow/op/data/RebatchDataset.java | 137 -- .../tensorflow/op/data/RebatchDatasetV2.java | 4 + .../org/tensorflow/op/data/ReduceDataset.java | 4 + .../tensorflow/op/data/RegisterDataset.java | 4 + .../tensorflow/op/data/SamplingDataset.java | 4 + .../org/tensorflow/op/data/SaveDataset.java | 4 + .../org/tensorflow/op/data/ScanDataset.java | 4 + .../op/data/SetStatsAggregatorDataset.java | 4 + .../org/tensorflow/op/data/ShardDataset.java | 4 + .../op/data/ShuffleAndRepeatDataset.java | 4 + .../tensorflow/op/data/ShuffleDataset.java | 4 + .../org/tensorflow/op/data/SleepDataset.java | 4 + .../op/data/SlidingWindowDataset.java | 4 + .../tensorflow/op/data/SnapshotDataset.java | 4 + .../op/data/SparseTensorSliceDataset.java | 4 + .../org/tensorflow/op/data/SqlDataset.java | 4 + .../tensorflow/op/data/TakeWhileDataset.java | 4 + .../org/tensorflow/op/data/TensorDataset.java | 4 + .../tensorflow/op/data/ThreadPoolDataset.java | 4 + .../tensorflow/op/data/UnbatchDataset.java | 4 + .../org/tensorflow/op/data/UniqueDataset.java | 4 + .../op/data/UnwrapDatasetVariant.java | 4 + .../org/tensorflow/op/data/WindowDataset.java | 4 + .../op/data/WrapDatasetVariant.java | 4 + .../src/gen/resources/ops.pb | Bin 1480855 -> 1480811 bytes 146 files changed, 1702 insertions(+), 756 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CSVDatasetV2.java delete mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDatasetV2.java delete mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DataServiceDataset.java rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/{rawops => data}/DataServiceDatasetV2.java (96%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/{rawops => data}/FinalizeDataset.java (95%) delete mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDatasetV2.java rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/{rawops => data}/OptionsDataset.java (96%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/{rawops => data}/ParallelBatchDataset.java (95%) delete mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDataset.java diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AssertCardinalityDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AssertCardinalityDataset.pbtxt index 1701d76bf52..fc93c13b627 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AssertCardinalityDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AssertCardinalityDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "AssertCardinalityDataset" + visibility: VISIBLE endpoint { name: "data.AssertCardinalityDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AssertNextDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AssertNextDataset.pbtxt index 5991221a549..d85694ae56e 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AssertNextDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AssertNextDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "AssertNextDataset" + visibility: VISIBLE endpoint { name: "data.AssertNextDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AutoShardDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AutoShardDataset.pbtxt index 7a5039c6744..75488eb84eb 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AutoShardDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_AutoShardDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "AutoShardDataset" + visibility: VISIBLE endpoint { name: "data.AutoShardDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_BytesProducedStatsDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_BytesProducedStatsDataset.pbtxt index cd7f24d9614..b9d81b54105 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_BytesProducedStatsDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_BytesProducedStatsDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "BytesProducedStatsDataset" + visibility: VISIBLE endpoint { name: "data.BytesProducedStatsDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CSVDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CSVDataset.pbtxt index 5b2950c34fb..b4637f34b40 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CSVDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CSVDataset.pbtxt @@ -1,6 +1,4 @@ op { graph_op_name: "CSVDataset" - endpoint { - name: "data.CSVDataset" - } + visibility: SKIP } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CSVDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CSVDatasetV2.pbtxt index bba2cafbdb7..2c2848e9a74 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CSVDatasetV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CSVDatasetV2.pbtxt @@ -1,6 +1,7 @@ op { graph_op_name: "CSVDatasetV2" + visibility: VISIBLE endpoint { - name: "data.CSVDatasetV2" + name: "data.CSVDataset" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CacheDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CacheDataset.pbtxt index 11c26c1dfc5..4b4b9bb3b53 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CacheDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CacheDataset.pbtxt @@ -1,6 +1,4 @@ op { graph_op_name: "CacheDataset" - endpoint { - name: "data.CacheDataset" - } + visibility: SKIP } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CacheDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CacheDatasetV2.pbtxt index 2992d9ae109..8c5b58383c3 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CacheDatasetV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_CacheDatasetV2.pbtxt @@ -1,6 +1,7 @@ op { graph_op_name: "CacheDatasetV2" + visibility: VISIBLE endpoint { - name: "data.CacheDatasetV2" + name: "data.CacheDataset" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ChooseFastestBranchDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ChooseFastestBranchDataset.pbtxt index ee776c9b105..b3be1a9cd3b 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ChooseFastestBranchDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ChooseFastestBranchDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "ChooseFastestBranchDataset" + visibility: VISIBLE endpoint { name: "data.ChooseFastestBranchDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ChooseFastestDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ChooseFastestDataset.pbtxt index 7f6aadeca93..a25508efebc 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ChooseFastestDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ChooseFastestDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "ChooseFastestDataset" + visibility: VISIBLE endpoint { name: "data.ChooseFastestDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DataServiceDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DataServiceDataset.pbtxt index 1dd6077f8db..cfe0fc978fc 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DataServiceDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DataServiceDataset.pbtxt @@ -1,6 +1,4 @@ op { graph_op_name: "DataServiceDataset" - endpoint { - name: "data.DataServiceDataset" - } + visibility: SKIP } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DataServiceDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DataServiceDatasetV2.pbtxt index da39be5c1c1..083caab7391 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DataServiceDatasetV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DataServiceDatasetV2.pbtxt @@ -1,6 +1,7 @@ op { graph_op_name: "DataServiceDatasetV2" + visibility: VISIBLE endpoint { - name: "rawops.DataServiceDatasetV2" + name: "data.DataServiceDatasetV2" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetCardinality.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetCardinality.pbtxt index efb67649b45..6a699fc214e 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetCardinality.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetCardinality.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "DatasetCardinality" + visibility: VISIBLE endpoint { name: "data.DatasetCardinality" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetFromGraph.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetFromGraph.pbtxt index 5cf72797f85..80ea2a57196 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetFromGraph.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetFromGraph.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "DatasetFromGraph" + visibility: VISIBLE endpoint { name: "data.DatasetFromGraph" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetToGraphV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetToGraphV2.pbtxt index 8e5eceabbfa..99be66c5e54 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetToGraphV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetToGraphV2.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "DatasetToGraphV2" + visibility: VISIBLE endpoint { name: "data.DatasetToGraph" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetToSingleElement.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetToSingleElement.pbtxt index 0ac42e0e936..0f0407914d4 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetToSingleElement.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetToSingleElement.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "DatasetToSingleElement" + visibility: VISIBLE endpoint { name: "data.DatasetToSingleElement" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetToTFRecord.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetToTFRecord.pbtxt index 3d388570630..c4d256969f7 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetToTFRecord.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DatasetToTFRecord.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "DatasetToTFRecord" + visibility: VISIBLE endpoint { name: "data.DatasetToTfRecord" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DenseToSparseBatchDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DenseToSparseBatchDataset.pbtxt index 76f6ba0b8ac..106059e48f0 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DenseToSparseBatchDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DenseToSparseBatchDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "DenseToSparseBatchDataset" + visibility: VISIBLE endpoint { name: "data.DenseToSparseBatchDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DirectedInterleaveDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DirectedInterleaveDataset.pbtxt index 24ada998b23..60c9729704e 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DirectedInterleaveDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_DirectedInterleaveDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "DirectedInterleaveDataset" + visibility: VISIBLE endpoint { name: "data.DirectedInterleaveDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_EnqueueInQueueDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_EnqueueInQueueDataset.pbtxt index 26051ab446f..804c2fc317e 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_EnqueueInQueueDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_EnqueueInQueueDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "EnqueueInQueueDataset" + visibility: VISIBLE endpoint { name: "data.EnqueueInQueueDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_FilterByLastComponentDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_FilterByLastComponentDataset.pbtxt index b7111f48fa9..4ad74385bce 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_FilterByLastComponentDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_FilterByLastComponentDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "FilterByLastComponentDataset" + visibility: VISIBLE endpoint { name: "data.FilterByLastComponentDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_FinalizeDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_FinalizeDataset.pbtxt index ab2a5fa846a..78b37455e28 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_FinalizeDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_FinalizeDataset.pbtxt @@ -1,6 +1,7 @@ op { graph_op_name: "FinalizeDataset" + visibility: VISIBLE endpoint { - name: "rawops.FinalizeDataset" + name: "data.FinalizeDataset" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_FixedLengthRecordDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_FixedLengthRecordDatasetV2.pbtxt index b8012bbe168..b0f66ca1642 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_FixedLengthRecordDatasetV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_FixedLengthRecordDatasetV2.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "FixedLengthRecordDatasetV2" + visibility: VISIBLE endpoint { name: "data.FixedLengthRecordDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GeneratorDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GeneratorDataset.pbtxt index b1719005e99..9ff9c28330a 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GeneratorDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GeneratorDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "GeneratorDataset" + visibility: VISIBLE endpoint { name: "data.GeneratorDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GroupByReducerDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GroupByReducerDataset.pbtxt index 8f8ae87314e..57b781c5a0a 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GroupByReducerDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GroupByReducerDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "GroupByReducerDataset" + visibility: VISIBLE endpoint { name: "data.GroupByReducerDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GroupByWindowDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GroupByWindowDataset.pbtxt index 9e4c4cd4ff2..529ca47d86e 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GroupByWindowDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_GroupByWindowDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "GroupByWindowDataset" + visibility: VISIBLE endpoint { name: "data.GroupByWindowDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_IgnoreErrorsDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_IgnoreErrorsDataset.pbtxt index 86ff4bd8b87..a8812f70e1a 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_IgnoreErrorsDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_IgnoreErrorsDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "IgnoreErrorsDataset" + visibility: VISIBLE endpoint { name: "data.IgnoreErrorsDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_InitializeTableFromDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_InitializeTableFromDataset.pbtxt index dd2c8233b20..bfd98dd2d1c 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_InitializeTableFromDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_InitializeTableFromDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "InitializeTableFromDataset" + visibility: VISIBLE endpoint { name: "data.InitializeTableFromDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_KafkaDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_KafkaDataset.pbtxt index 5f0da216cbb..6055a53c894 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_KafkaDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_KafkaDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "KafkaDataset" + visibility: VISIBLE endpoint { name: "data.KafkaDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LMDBDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LMDBDataset.pbtxt index 9c4d0119d88..68de33fde0b 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LMDBDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LMDBDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "LMDBDataset" + visibility: VISIBLE endpoint { name: "data.LMDBDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LatencyStatsDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LatencyStatsDataset.pbtxt index bf0bf2a5ed7..8447fb7e287 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LatencyStatsDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LatencyStatsDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "LatencyStatsDataset" + visibility: VISIBLE endpoint { name: "data.LatencyStatsDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LegacyParallelInterleaveDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LegacyParallelInterleaveDatasetV2.pbtxt index 71012d0b557..4046256a580 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LegacyParallelInterleaveDatasetV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LegacyParallelInterleaveDatasetV2.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "LegacyParallelInterleaveDatasetV2" + visibility: VISIBLE endpoint { name: "data.LegacyParallelInterleaveDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LoadDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LoadDataset.pbtxt index 88dce6062a3..7bbd800c889 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LoadDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LoadDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "LoadDataset" + visibility: VISIBLE endpoint { name: "data.LoadDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MatchingFilesDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MatchingFilesDataset.pbtxt index 749257c37b5..2bee85539f8 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MatchingFilesDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MatchingFilesDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "MatchingFilesDataset" + visibility: VISIBLE endpoint { name: "data.MatchingFilesDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MaxIntraOpParallelismDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MaxIntraOpParallelismDataset.pbtxt index 2e90a1cbbd4..20a2e55ba80 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MaxIntraOpParallelismDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MaxIntraOpParallelismDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "MaxIntraOpParallelismDataset" + visibility: VISIBLE endpoint { name: "data.MaxIntraOpParallelismDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ModelDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ModelDataset.pbtxt index 143c7afd720..8549b4582cc 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ModelDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ModelDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "ModelDataset" + visibility: VISIBLE endpoint { name: "data.ModelDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_NonSerializableDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_NonSerializableDataset.pbtxt index a88d14e65c9..4efa2ec9978 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_NonSerializableDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_NonSerializableDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "NonSerializableDataset" + visibility: VISIBLE endpoint { name: "data.NonSerializableDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OptimizeDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OptimizeDataset.pbtxt index e7ddf97d1ab..4220b306071 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OptimizeDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OptimizeDataset.pbtxt @@ -1,6 +1,4 @@ op { graph_op_name: "OptimizeDataset" - endpoint { - name: "data.OptimizeDataset" - } + visibility: SKIP } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OptimizeDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OptimizeDatasetV2.pbtxt index 7f04d4d27a0..4b0214740a9 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OptimizeDatasetV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OptimizeDatasetV2.pbtxt @@ -1,6 +1,7 @@ op { graph_op_name: "OptimizeDatasetV2" + visibility: VISIBLE endpoint { - name: "data.OptimizeDatasetV2" + name: "data.OptimizeDataset" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OptionsDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OptionsDataset.pbtxt index e90dfd7bd04..0ccf9d12546 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OptionsDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OptionsDataset.pbtxt @@ -1,6 +1,7 @@ op { graph_op_name: "OptionsDataset" + visibility: VISIBLE endpoint { - name: "rawops.OptionsDataset" + name: "data.OptionsDataset" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PaddedBatchDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PaddedBatchDatasetV2.pbtxt index 22dfe84f0ca..0999120d361 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PaddedBatchDatasetV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PaddedBatchDatasetV2.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "PaddedBatchDatasetV2" + visibility: VISIBLE endpoint { name: "data.PaddedBatchDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelBatchDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelBatchDataset.pbtxt index f05138a1bd8..3fd9b1a423b 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelBatchDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelBatchDataset.pbtxt @@ -1,6 +1,7 @@ op { graph_op_name: "ParallelBatchDataset" + visibility: VISIBLE endpoint { - name: "rawops.ParallelBatchDataset" + name: "data.ParallelBatchDataset" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelInterleaveDatasetV4.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelInterleaveDatasetV4.pbtxt index 5ed14b3695c..7dddfc15a66 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelInterleaveDatasetV4.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelInterleaveDatasetV4.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "ParallelInterleaveDatasetV4" + visibility: VISIBLE endpoint { name: "data.ParallelInterleaveDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelMapDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelMapDatasetV2.pbtxt index 8439c4ebc0e..9135d8f5c99 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelMapDatasetV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParallelMapDatasetV2.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "ParallelMapDatasetV2" + visibility: VISIBLE endpoint { name: "data.ParallelMapDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParseExampleDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParseExampleDatasetV2.pbtxt index 3c88a87c8ab..4ce175177f6 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParseExampleDatasetV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ParseExampleDatasetV2.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "ParseExampleDatasetV2" + visibility: VISIBLE endpoint { name: "data.ParseExampleDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PrefetchDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PrefetchDataset.pbtxt index beaad84d153..a8a3423123d 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PrefetchDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PrefetchDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "PrefetchDataset" + visibility: VISIBLE endpoint { name: "data.PrefetchDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PrependFromQueueAndPaddedBatchDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PrependFromQueueAndPaddedBatchDataset.pbtxt index 7c9d509b163..a5e8f99de23 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PrependFromQueueAndPaddedBatchDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PrependFromQueueAndPaddedBatchDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "PrependFromQueueAndPaddedBatchDataset" + visibility: VISIBLE endpoint { name: "data.PrependFromQueueAndPaddedBatchDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PrivateThreadPoolDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PrivateThreadPoolDataset.pbtxt index 13490a99de5..4c0cfbe6698 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PrivateThreadPoolDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_PrivateThreadPoolDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "PrivateThreadPoolDataset" + visibility: VISIBLE endpoint { name: "data.PrivateThreadPoolDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RandomDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RandomDataset.pbtxt index 43921e6eafe..19975d0a927 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RandomDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RandomDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "RandomDataset" + visibility: VISIBLE endpoint { name: "data.RandomDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RebatchDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RebatchDataset.pbtxt index 7fac4c9c8d0..61d010ff033 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RebatchDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RebatchDataset.pbtxt @@ -1,6 +1,4 @@ op { graph_op_name: "RebatchDataset" - endpoint { - name: "data.RebatchDataset" - } + visibility: SKIP } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RebatchDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RebatchDatasetV2.pbtxt index f287e867d2d..42ab3bc1a6e 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RebatchDatasetV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RebatchDatasetV2.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "RebatchDatasetV2" + visibility: VISIBLE endpoint { name: "data.RebatchDatasetV2" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ReduceDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ReduceDataset.pbtxt index 4417f01ef90..4493b8a20ac 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ReduceDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ReduceDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "ReduceDataset" + visibility: VISIBLE endpoint { name: "data.ReduceDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RegisterDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RegisterDataset.pbtxt index b628dfdcf10..dca405191e5 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RegisterDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_RegisterDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "RegisterDataset" + visibility: VISIBLE endpoint { name: "data.RegisterDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SamplingDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SamplingDataset.pbtxt index 23a5447e523..6b33a96d9e8 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SamplingDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SamplingDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SamplingDataset" + visibility: VISIBLE endpoint { name: "data.SamplingDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SaveDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SaveDataset.pbtxt index d670168929d..20b5562a385 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SaveDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SaveDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SaveDataset" + visibility: VISIBLE endpoint { name: "data.SaveDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ScanDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ScanDataset.pbtxt index 89b63c53f70..838de863044 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ScanDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ScanDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "ScanDataset" + visibility: VISIBLE endpoint { name: "data.ScanDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SetStatsAggregatorDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SetStatsAggregatorDataset.pbtxt index f57abe5a667..6b2b18d0b0b 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SetStatsAggregatorDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SetStatsAggregatorDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SetStatsAggregatorDataset" + visibility: VISIBLE endpoint { name: "data.SetStatsAggregatorDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShardDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShardDataset.pbtxt index 6bd05dae2fb..530d56293ca 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShardDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShardDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "ShardDataset" + visibility: VISIBLE endpoint { name: "data.ShardDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShuffleAndRepeatDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShuffleAndRepeatDatasetV2.pbtxt index 2c38ba7607a..b04b2ba824d 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShuffleAndRepeatDatasetV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShuffleAndRepeatDatasetV2.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "ShuffleAndRepeatDatasetV2" + visibility: VISIBLE endpoint { name: "data.ShuffleAndRepeatDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShuffleDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShuffleDataset.pbtxt index 40b2a3f53d0..612510fd386 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShuffleDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShuffleDataset.pbtxt @@ -1,7 +1,4 @@ op { graph_op_name: "ShuffleDataset" - visibility: VISIBLE - endpoint { - name: "data.ShuffleDataset" - } + visibility: SKIP } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShuffleDatasetV3.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShuffleDatasetV3.pbtxt index 677f0d2c1e9..a57f540ef0a 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShuffleDatasetV3.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ShuffleDatasetV3.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "ShuffleDatasetV3" + visibility: VISIBLE endpoint { name: "data.ShuffleDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SleepDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SleepDataset.pbtxt index f8abf22d64c..d43d54dca60 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SleepDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SleepDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SleepDataset" + visibility: VISIBLE endpoint { name: "data.SleepDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SlidingWindowDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SlidingWindowDataset.pbtxt index bada002ca92..332518c4cf1 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SlidingWindowDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SlidingWindowDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SlidingWindowDataset" + visibility: VISIBLE endpoint { name: "data.SlidingWindowDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SnapshotDatasetV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SnapshotDatasetV2.pbtxt index d6f7188631c..d275c27201c 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SnapshotDatasetV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SnapshotDatasetV2.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SnapshotDatasetV2" + visibility: VISIBLE endpoint { name: "data.SnapshotDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseTensorSliceDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseTensorSliceDataset.pbtxt index bb0d1d7a949..7cb0d05fba2 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseTensorSliceDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseTensorSliceDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SparseTensorSliceDataset" + visibility: VISIBLE endpoint { name: "data.SparseTensorSliceDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SqlDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SqlDataset.pbtxt index 8764e81af25..c4723512e3b 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SqlDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SqlDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "SqlDataset" + visibility: VISIBLE endpoint { name: "data.SqlDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_TakeWhileDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_TakeWhileDataset.pbtxt index 5571c0a1d00..9d6bf0f5a7b 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_TakeWhileDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_TakeWhileDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "TakeWhileDataset" + visibility: VISIBLE endpoint { name: "data.TakeWhileDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_TensorDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_TensorDataset.pbtxt index ed0ead6e7ab..30b95c2d9c8 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_TensorDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_TensorDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "TensorDataset" + visibility: VISIBLE endpoint { name: "data.TensorDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ThreadPoolDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ThreadPoolDataset.pbtxt index 1f8f25a9b57..93a1a33ceeb 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ThreadPoolDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_ThreadPoolDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "ThreadPoolDataset" + visibility: VISIBLE endpoint { name: "data.ThreadPoolDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UnbatchDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UnbatchDataset.pbtxt index 24907c804b0..a084a2db4ab 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UnbatchDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UnbatchDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "UnbatchDataset" + visibility: VISIBLE endpoint { name: "data.UnbatchDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UniqueDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UniqueDataset.pbtxt index e761f775b51..c6313c6fea8 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UniqueDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UniqueDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "UniqueDataset" + visibility: VISIBLE endpoint { name: "data.UniqueDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UnwrapDatasetVariant.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UnwrapDatasetVariant.pbtxt index 10e80a520dc..7520cc17855 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UnwrapDatasetVariant.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_UnwrapDatasetVariant.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "UnwrapDatasetVariant" + visibility: VISIBLE endpoint { name: "data.UnwrapDatasetVariant" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_WindowDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_WindowDataset.pbtxt index 69f12c55e1d..285d8387eb8 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_WindowDataset.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_WindowDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "WindowDataset" + visibility: VISIBLE endpoint { name: "data.WindowDataset" } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_WrapDatasetVariant.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_WrapDatasetVariant.pbtxt index 49b2f1c4409..7fe9c7ebad3 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_WrapDatasetVariant.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_WrapDatasetVariant.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "WrapDatasetVariant" + visibility: VISIBLE endpoint { name: "data.WrapDatasetVariant" } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java index 4f933b2807f..5f1765d1f59 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java @@ -22,38 +22,105 @@ import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.data.AnonymousIterator; +import org.tensorflow.op.data.AssertCardinalityDataset; +import org.tensorflow.op.data.AssertNextDataset; +import org.tensorflow.op.data.AutoShardDataset; import org.tensorflow.op.data.BatchDataset; +import org.tensorflow.op.data.BytesProducedStatsDataset; +import org.tensorflow.op.data.CSVDataset; +import org.tensorflow.op.data.CacheDataset; +import org.tensorflow.op.data.ChooseFastestBranchDataset; +import org.tensorflow.op.data.ChooseFastestDataset; import org.tensorflow.op.data.ConcatenateDataset; +import org.tensorflow.op.data.DataServiceDatasetV2; +import org.tensorflow.op.data.DatasetCardinality; +import org.tensorflow.op.data.DatasetFromGraph; +import org.tensorflow.op.data.DatasetToGraph; +import org.tensorflow.op.data.DatasetToSingleElement; +import org.tensorflow.op.data.DatasetToTfRecord; import org.tensorflow.op.data.DeleteIterator; +import org.tensorflow.op.data.DenseToSparseBatchDataset; import org.tensorflow.op.data.DeserializeIterator; +import org.tensorflow.op.data.DirectedInterleaveDataset; +import org.tensorflow.op.data.FilterByLastComponentDataset; import org.tensorflow.op.data.FilterDataset; +import org.tensorflow.op.data.FinalizeDataset; +import org.tensorflow.op.data.FixedLengthRecordDataset; import org.tensorflow.op.data.FlatMapDataset; +import org.tensorflow.op.data.GeneratorDataset; +import org.tensorflow.op.data.GroupByReducerDataset; +import org.tensorflow.op.data.GroupByWindowDataset; +import org.tensorflow.op.data.IgnoreErrorsDataset; +import org.tensorflow.op.data.InitializeTableFromDataset; import org.tensorflow.op.data.InterleaveDataset; import org.tensorflow.op.data.Iterator; import org.tensorflow.op.data.IteratorGetNext; import org.tensorflow.op.data.IteratorGetNextAsOptional; import org.tensorflow.op.data.IteratorGetNextSync; import org.tensorflow.op.data.IteratorToStringHandle; +import org.tensorflow.op.data.LMDBDataset; +import org.tensorflow.op.data.LatencyStatsDataset; +import org.tensorflow.op.data.LegacyParallelInterleaveDataset; +import org.tensorflow.op.data.LoadDataset; import org.tensorflow.op.data.MakeIterator; import org.tensorflow.op.data.MapAndBatchDataset; import org.tensorflow.op.data.MapDataset; +import org.tensorflow.op.data.MatchingFilesDataset; +import org.tensorflow.op.data.MaxIntraOpParallelismDataset; +import org.tensorflow.op.data.ModelDataset; +import org.tensorflow.op.data.NonSerializableDataset; import org.tensorflow.op.data.OneShotIterator; +import org.tensorflow.op.data.OptimizeDataset; import org.tensorflow.op.data.OptionalFromValue; import org.tensorflow.op.data.OptionalGetValue; import org.tensorflow.op.data.OptionalHasValue; import org.tensorflow.op.data.OptionalNone; +import org.tensorflow.op.data.OptionsDataset; +import org.tensorflow.op.data.PaddedBatchDataset; +import org.tensorflow.op.data.ParallelBatchDataset; +import org.tensorflow.op.data.ParallelInterleaveDataset; +import org.tensorflow.op.data.ParallelMapDataset; +import org.tensorflow.op.data.ParseExampleDataset; +import org.tensorflow.op.data.PrefetchDataset; +import org.tensorflow.op.data.PrivateThreadPoolDataset; +import org.tensorflow.op.data.RandomDataset; import org.tensorflow.op.data.RangeDataset; +import org.tensorflow.op.data.RebatchDatasetV2; +import org.tensorflow.op.data.ReduceDataset; +import org.tensorflow.op.data.RegisterDataset; import org.tensorflow.op.data.RepeatDataset; +import org.tensorflow.op.data.SamplingDataset; +import org.tensorflow.op.data.SaveDataset; +import org.tensorflow.op.data.ScanDataset; import org.tensorflow.op.data.SerializeIterator; +import org.tensorflow.op.data.SetStatsAggregatorDataset; +import org.tensorflow.op.data.ShardDataset; +import org.tensorflow.op.data.ShuffleAndRepeatDataset; +import org.tensorflow.op.data.ShuffleDataset; import org.tensorflow.op.data.SkipDataset; +import org.tensorflow.op.data.SleepDataset; +import org.tensorflow.op.data.SlidingWindowDataset; +import org.tensorflow.op.data.SnapshotDataset; +import org.tensorflow.op.data.SparseTensorSliceDataset; +import org.tensorflow.op.data.SqlDataset; import org.tensorflow.op.data.TakeDataset; +import org.tensorflow.op.data.TakeWhileDataset; +import org.tensorflow.op.data.TensorDataset; import org.tensorflow.op.data.TensorSliceDataset; import org.tensorflow.op.data.TextLineDataset; import org.tensorflow.op.data.TfRecordDataset; +import org.tensorflow.op.data.ThreadPoolDataset; +import org.tensorflow.op.data.UnbatchDataset; +import org.tensorflow.op.data.UniqueDataset; +import org.tensorflow.op.data.UnwrapDatasetVariant; +import org.tensorflow.op.data.WindowDataset; +import org.tensorflow.op.data.WrapDatasetVariant; import org.tensorflow.op.data.ZipDataset; import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; /** @@ -83,6 +150,68 @@ public AnonymousIterator anonymousIterator(List<Class<? extends TType>> outputTy return AnonymousIterator.create(scope, outputTypes, outputShapes); } + /** + * The AssertCardinalityDataset operation + * + * @param inputDataset the inputDataset value + * @param cardinality the cardinality value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of AssertCardinalityDataset + */ + public AssertCardinalityDataset assertCardinalityDataset(Operand<? extends TType> inputDataset, + Operand<TInt64> cardinality, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes) { + return AssertCardinalityDataset.create(scope, inputDataset, cardinality, outputTypes, outputShapes); + } + + /** + * A transformation that asserts which transformations happen next. + * This transformation checks whether the camel-case names (i.e. "FlatMap", not + * "flat_map") of the transformations following this transformation match the list + * of names in the {@code transformations} argument. If there is a mismatch, the + * transformation raises an exception. + * <p>The check occurs when iterating over the contents of the dataset, which + * means that the check happens <em>after</em> any static optimizations are applied + * to the dataset graph. + * + * @param inputDataset A variant tensor representing the input dataset. + * {@code data.AssertNextDataset} passes through the outputs of its input dataset. + * @param transformations A {@code tf.string} vector {@code tf.Tensor} identifying the transformations that are + * expected to happen next. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of AssertNextDataset + */ + public AssertNextDataset assertNextDataset(Operand<? extends TType> inputDataset, + Operand<TString> transformations, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes) { + return AssertNextDataset.create(scope, inputDataset, transformations, outputTypes, outputShapes); + } + + /** + * Creates a dataset that shards the input dataset. + * Creates a dataset that shards the input dataset by num_workers, returning a + * sharded dataset for the index-th worker. This attempts to automatically shard + * a dataset by examining the Dataset graph and inserting a shard op before the + * inputs to a reader Dataset (e.g. CSVDataset, TFRecordDataset). + * <p>This dataset will throw a NotFound error if we cannot shard the dataset + * automatically. + * + * @param inputDataset A variant tensor representing the input dataset. + * @param numWorkers A scalar representing the number of workers to distribute this dataset across. + * @param index A scalar representing the index of the current worker out of num_workers. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of AutoShardDataset + */ + public AutoShardDataset autoShardDataset(Operand<? extends TType> inputDataset, + Operand<TInt64> numWorkers, Operand<TInt64> index, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes, AutoShardDataset.Options... options) { + return AutoShardDataset.create(scope, inputDataset, numWorkers, index, outputTypes, outputShapes, options); + } + /** * Creates a dataset that batches {@code batch_size} elements from {@code input_dataset}. * @@ -101,6 +230,95 @@ public BatchDataset batchDataset(Operand<? extends TType> inputDataset, Operand< return BatchDataset.create(scope, inputDataset, batchSize, dropRemainder, outputTypes, outputShapes, options); } + /** + * Records the bytes size of each element of {@code input_dataset} in a StatsAggregator. + * + * @param inputDataset the inputDataset value + * @param tag the tag value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of BytesProducedStatsDataset + */ + public BytesProducedStatsDataset bytesProducedStatsDataset(Operand<? extends TType> inputDataset, + Operand<TString> tag, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return BytesProducedStatsDataset.create(scope, inputDataset, tag, outputTypes, outputShapes); + } + + /** + * The CSVDatasetV2 operation + * + * @param filenames the filenames value + * @param compressionType the compressionType value + * @param bufferSize the bufferSize value + * @param header the header value + * @param fieldDelim the fieldDelim value + * @param useQuoteDelim the useQuoteDelim value + * @param naValue the naValue value + * @param selectCols the selectCols value + * @param recordDefaults the recordDefaults value + * @param excludeCols the excludeCols value + * @param outputShapes the value of the outputShapes property + * @return a new instance of CSVDataset + */ + public CSVDataset cSVDataset(Operand<TString> filenames, Operand<TString> compressionType, + Operand<TInt64> bufferSize, Operand<TBool> header, Operand<TString> fieldDelim, + Operand<TBool> useQuoteDelim, Operand<TString> naValue, Operand<TInt64> selectCols, + Iterable<Operand<?>> recordDefaults, Operand<TInt64> excludeCols, List<Shape> outputShapes) { + return CSVDataset.create(scope, filenames, compressionType, bufferSize, header, fieldDelim, useQuoteDelim, naValue, selectCols, recordDefaults, excludeCols, outputShapes); + } + + /** + * The CacheDatasetV2 operation + * + * @param inputDataset the inputDataset value + * @param filename the filename value + * @param cache the cache value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of CacheDataset + */ + public CacheDataset cacheDataset(Operand<? extends TType> inputDataset, Operand<TString> filename, + Operand<? extends TType> cache, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes) { + return CacheDataset.create(scope, inputDataset, filename, cache, outputTypes, outputShapes); + } + + /** + * The ChooseFastestBranchDataset operation + * + * @param inputDataset the inputDataset value + * @param ratioNumerator the ratioNumerator value + * @param ratioDenominator the ratioDenominator value + * @param otherArguments the otherArguments value + * @param numElementsPerBranch the value of the numElementsPerBranch property + * @param branches the value of the branches property + * @param otherArgumentsLengths the value of the otherArgumentsLengths property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of ChooseFastestBranchDataset + */ + public ChooseFastestBranchDataset chooseFastestBranchDataset( + Operand<? extends TType> inputDataset, Operand<TInt64> ratioNumerator, + Operand<TInt64> ratioDenominator, Iterable<Operand<?>> otherArguments, + Long numElementsPerBranch, List<ConcreteFunction> branches, List<Long> otherArgumentsLengths, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return ChooseFastestBranchDataset.create(scope, inputDataset, ratioNumerator, ratioDenominator, otherArguments, numElementsPerBranch, branches, otherArgumentsLengths, outputTypes, outputShapes); + } + + /** + * The ChooseFastestDataset operation + * + * @param inputDatasets the inputDatasets value + * @param numExperiments the value of the numExperiments property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of ChooseFastestDataset + */ + public ChooseFastestDataset chooseFastestDataset(Iterable<Operand<? extends TType>> inputDatasets, + Long numExperiments, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return ChooseFastestDataset.create(scope, inputDatasets, numExperiments, outputTypes, outputShapes); + } + /** * Creates a dataset that concatenates {@code input_dataset} with {@code another_dataset}. * @@ -116,6 +334,94 @@ public ConcatenateDataset concatenateDataset(Operand<? extends TType> inputDatas return ConcatenateDataset.create(scope, inputDataset, anotherDataset, outputTypes, outputShapes); } + /** + * Creates a dataset that reads data from the tf.data service. + * + * @param datasetId the datasetId value + * @param processingMode the processingMode value + * @param address the address value + * @param protocol the protocol value + * @param jobName the jobName value + * @param consumerIndex the consumerIndex value + * @param numConsumers the numConsumers value + * @param maxOutstandingRequests the maxOutstandingRequests value + * @param iterationCounter the iterationCounter value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of DataServiceDatasetV2 + */ + public DataServiceDatasetV2 dataServiceDatasetV2(Operand<TInt64> datasetId, + Operand<TString> processingMode, Operand<TString> address, Operand<TString> protocol, + Operand<TString> jobName, Operand<TInt64> consumerIndex, Operand<TInt64> numConsumers, + Operand<TInt64> maxOutstandingRequests, Operand<? extends TType> iterationCounter, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + DataServiceDatasetV2.Options... options) { + return DataServiceDatasetV2.create(scope, datasetId, processingMode, address, protocol, jobName, consumerIndex, numConsumers, maxOutstandingRequests, iterationCounter, outputTypes, outputShapes, options); + } + + /** + * Returns the cardinality of {@code input_dataset}. + * Returns the cardinality of {@code input_dataset}. + * + * @param inputDataset A variant tensor representing the dataset to return cardinality for. + * @return a new instance of DatasetCardinality + */ + public DatasetCardinality datasetCardinality(Operand<? extends TType> inputDataset) { + return DatasetCardinality.create(scope, inputDataset); + } + + /** + * Creates a dataset from the given {@code graph_def}. + * Creates a dataset from the provided {@code graph_def}. + * + * @param graphDef The graph representation of the dataset (as serialized GraphDef). + * @return a new instance of DatasetFromGraph + */ + public DatasetFromGraph datasetFromGraph(Operand<TString> graphDef) { + return DatasetFromGraph.create(scope, graphDef); + } + + /** + * Returns a serialized GraphDef representing {@code input_dataset}. + * Returns a graph representation for {@code input_dataset}. + * + * @param inputDataset A variant tensor representing the dataset to return the graph representation for. + * @param options carries optional attribute values + * @return a new instance of DatasetToGraph + */ + public DatasetToGraph datasetToGraph(Operand<? extends TType> inputDataset, + DatasetToGraph.Options... options) { + return DatasetToGraph.create(scope, inputDataset, options); + } + + /** + * Outputs the single element from the given dataset. + * + * @param dataset A handle to a dataset that contains a single element. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of DatasetToSingleElement + */ + public DatasetToSingleElement datasetToSingleElement(Operand<? extends TType> dataset, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return DatasetToSingleElement.create(scope, dataset, outputTypes, outputShapes); + } + + /** + * Writes the given dataset to the given file using the TFRecord format. + * + * @param inputDataset A variant tensor representing the dataset to write. + * @param filename A scalar string tensor representing the filename to use. + * @param compressionType A scalar string tensor containing either (i) the empty string (no + * compression), (ii) "ZLIB", or (iii) "GZIP". + * @return a new instance of DatasetToTfRecord + */ + public DatasetToTfRecord datasetToTfRecord(Operand<? extends TType> inputDataset, + Operand<TString> filename, Operand<TString> compressionType) { + return DatasetToTfRecord.create(scope, inputDataset, filename, compressionType); + } + /** * A container for an iterator resource. * @@ -128,6 +434,25 @@ public DeleteIterator deleteIterator(Operand<? extends TType> handle, return DeleteIterator.create(scope, handle, deleter); } + /** + * Creates a dataset that batches input elements into a SparseTensor. + * + * @param inputDataset A handle to an input dataset. Must have a single component. + * @param batchSize A scalar representing the number of elements to accumulate in a + * batch. + * @param rowShape A vector representing the dense shape of each row in the produced + * SparseTensor. The shape may be partially specified, using {@code -1} to indicate + * that a particular dimension should use the maximum size of all batch elements. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of DenseToSparseBatchDataset + */ + public DenseToSparseBatchDataset denseToSparseBatchDataset(Operand<? extends TType> inputDataset, + Operand<TInt64> batchSize, Operand<TInt64> rowShape, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes) { + return DenseToSparseBatchDataset.create(scope, inputDataset, batchSize, rowShape, outputTypes, outputShapes); + } + /** * Converts the given variant tensor to an iterator and stores it in the given resource. * @@ -141,6 +466,38 @@ public DeserializeIterator deserializeIterator(Operand<? extends TType> resource return DeserializeIterator.create(scope, resourceHandle, serialized); } + /** + * A substitute for {@code InterleaveDataset} on a fixed list of {@code N} datasets. + * + * @param selectorInputDataset A dataset of scalar {@code DT_INT64} elements that determines which of the + * {@code N} data inputs should produce the next output element. + * @param dataInputDatasets {@code N} datasets with the same type that will be interleaved according to + * the values of {@code selector_input_dataset}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of DirectedInterleaveDataset + */ + public DirectedInterleaveDataset directedInterleaveDataset( + Operand<? extends TType> selectorInputDataset, + Iterable<Operand<? extends TType>> dataInputDatasets, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return DirectedInterleaveDataset.create(scope, selectorInputDataset, dataInputDatasets, outputTypes, outputShapes); + } + + /** + * Creates a dataset containing elements of first component of {@code input_dataset} having true in the last component. + * + * @param inputDataset the inputDataset value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of FilterByLastComponentDataset + */ + public FilterByLastComponentDataset filterByLastComponentDataset( + Operand<? extends TType> inputDataset, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes) { + return FilterByLastComponentDataset.create(scope, inputDataset, outputTypes, outputShapes); + } + /** * Creates a dataset containing elements of {@code input_dataset} matching {@code predicate}. * The {@code predicate} function must return a scalar boolean and accept the @@ -164,6 +521,38 @@ public FilterDataset filterDataset(Operand<? extends TType> inputDataset, return FilterDataset.create(scope, inputDataset, otherArguments, predicate, outputTypes, outputShapes); } + /** + * Creates a dataset by applying {@code tf.data.Options} to {@code input_dataset}. + * + * @param inputDataset A variant tensor representing the input dataset. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of FinalizeDataset + */ + public FinalizeDataset finalizeDataset(Operand<? extends TType> inputDataset, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + FinalizeDataset.Options... options) { + return FinalizeDataset.create(scope, inputDataset, outputTypes, outputShapes, options); + } + + /** + * The FixedLengthRecordDatasetV2 operation + * + * @param filenames the filenames value + * @param headerBytes the headerBytes value + * @param recordBytes the recordBytes value + * @param footerBytes the footerBytes value + * @param bufferSize the bufferSize value + * @param compressionType the compressionType value + * @return a new instance of FixedLengthRecordDataset + */ + public FixedLengthRecordDataset fixedLengthRecordDataset(Operand<TString> filenames, + Operand<TInt64> headerBytes, Operand<TInt64> recordBytes, Operand<TInt64> footerBytes, + Operand<TInt64> bufferSize, Operand<TString> compressionType) { + return FixedLengthRecordDataset.create(scope, filenames, headerBytes, recordBytes, footerBytes, bufferSize, compressionType); + } + /** * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. * Unlike MapDataset, the {@code f} in FlatMapDataset is expected to return a @@ -185,6 +574,110 @@ public FlatMapDataset flatMapDataset(Operand<? extends TType> inputDataset, return FlatMapDataset.create(scope, inputDataset, otherArguments, f, outputTypes, outputShapes); } + /** + * Creates a dataset that invokes a function to generate elements. + * + * @param initFuncOtherArgs the initFuncOtherArgs value + * @param nextFuncOtherArgs the nextFuncOtherArgs value + * @param finalizeFuncOtherArgs the finalizeFuncOtherArgs value + * @param initFunc the value of the initFunc property + * @param nextFunc the value of the nextFunc property + * @param finalizeFunc the value of the finalizeFunc property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of GeneratorDataset + */ + public GeneratorDataset generatorDataset(Iterable<Operand<?>> initFuncOtherArgs, + Iterable<Operand<?>> nextFuncOtherArgs, Iterable<Operand<?>> finalizeFuncOtherArgs, + ConcreteFunction initFunc, ConcreteFunction nextFunc, ConcreteFunction finalizeFunc, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return GeneratorDataset.create(scope, initFuncOtherArgs, nextFuncOtherArgs, finalizeFuncOtherArgs, initFunc, nextFunc, finalizeFunc, outputTypes, outputShapes); + } + + /** + * Creates a dataset that computes a group-by on {@code input_dataset}. + * Creates a dataset that computes a group-by on {@code input_dataset}. + * + * @param inputDataset A variant tensor representing the input dataset. + * @param keyFuncOtherArguments A list of tensors, typically values that were captured when + * building a closure for {@code key_func}. + * @param initFuncOtherArguments A list of tensors, typically values that were captured when + * building a closure for {@code init_func}. + * @param reduceFuncOtherArguments A list of tensors, typically values that were captured when + * building a closure for {@code reduce_func}. + * @param finalizeFuncOtherArguments A list of tensors, typically values that were captured when + * building a closure for {@code finalize_func}. + * @param keyFunc A function mapping an element of {@code input_dataset}, concatenated + * with {@code key_func_other_arguments} to a scalar value of type DT_INT64. + * @param initFunc A function mapping a key of type DT_INT64, concatenated with + * {@code init_func_other_arguments} to the initial reducer state. + * @param reduceFunc A function mapping the current reducer state and an element of {@code input_dataset}, + * concatenated with {@code reduce_func_other_arguments} to a new reducer state. + * @param finalizeFunc A function mapping the final reducer state to an output element. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of GroupByReducerDataset + */ + public GroupByReducerDataset groupByReducerDataset(Operand<? extends TType> inputDataset, + Iterable<Operand<?>> keyFuncOtherArguments, Iterable<Operand<?>> initFuncOtherArguments, + Iterable<Operand<?>> reduceFuncOtherArguments, + Iterable<Operand<?>> finalizeFuncOtherArguments, ConcreteFunction keyFunc, + ConcreteFunction initFunc, ConcreteFunction reduceFunc, ConcreteFunction finalizeFunc, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return GroupByReducerDataset.create(scope, inputDataset, keyFuncOtherArguments, initFuncOtherArguments, reduceFuncOtherArguments, finalizeFuncOtherArguments, keyFunc, initFunc, reduceFunc, finalizeFunc, outputTypes, outputShapes); + } + + /** + * Creates a dataset that computes a windowed group-by on {@code input_dataset}. + * // TODO(mrry): Support non-int64 keys. + * + * @param inputDataset the inputDataset value + * @param keyFuncOtherArguments the keyFuncOtherArguments value + * @param reduceFuncOtherArguments the reduceFuncOtherArguments value + * @param windowSizeFuncOtherArguments the windowSizeFuncOtherArguments value + * @param keyFunc A function mapping an element of {@code input_dataset}, concatenated + * with {@code key_func_other_arguments} to a scalar value of type DT_INT64. + * @param reduceFunc the value of the reduceFunc property + * @param windowSizeFunc the value of the windowSizeFunc property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of GroupByWindowDataset + */ + public GroupByWindowDataset groupByWindowDataset(Operand<? extends TType> inputDataset, + Iterable<Operand<?>> keyFuncOtherArguments, Iterable<Operand<?>> reduceFuncOtherArguments, + Iterable<Operand<?>> windowSizeFuncOtherArguments, ConcreteFunction keyFunc, + ConcreteFunction reduceFunc, ConcreteFunction windowSizeFunc, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return GroupByWindowDataset.create(scope, inputDataset, keyFuncOtherArguments, reduceFuncOtherArguments, windowSizeFuncOtherArguments, keyFunc, reduceFunc, windowSizeFunc, outputTypes, outputShapes); + } + + /** + * Creates a dataset that contains the elements of {@code input_dataset} ignoring errors. + * + * @param inputDataset the inputDataset value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of IgnoreErrorsDataset + */ + public IgnoreErrorsDataset ignoreErrorsDataset(Operand<? extends TType> inputDataset, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + IgnoreErrorsDataset.Options... options) { + return IgnoreErrorsDataset.create(scope, inputDataset, outputTypes, outputShapes, options); + } + + /** + * The InitializeTableFromDataset operation + * + * @param tableHandle the tableHandle value + * @param dataset the dataset value + * @return a new instance of InitializeTableFromDataset + */ + public InitializeTableFromDataset initializeTableFromDataset(Operand<? extends TType> tableHandle, + Operand<? extends TType> dataset) { + return InitializeTableFromDataset.create(scope, tableHandle, dataset); + } + /** * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. * Unlike MapDataset, the {@code f} in InterleaveDataset is expected to return @@ -277,6 +770,91 @@ public IteratorToStringHandle iteratorToStringHandle(Operand<? extends TType> re return IteratorToStringHandle.create(scope, resourceHandle); } + /** + * Creates a dataset that emits the key-value pairs in one or more LMDB files. + * The Lightning Memory-Mapped Database Manager, or LMDB, is an embedded binary + * key-value database. This dataset can read the contents of LMDB database files, + * the names of which generally have the {@code .mdb} suffix. + * <p>Each output element consists of a key-value pair represented as a pair of + * scalar string {@code Tensor}s, where the first {@code Tensor} contains the key and the + * second {@code Tensor} contains the value. + * <p>LMDB uses different file formats on big- and little-endian machines. + * {@code data.LMDBDataset} can only read files in the format of the host machine. + * + * @param filenames A scalar or a vector containing the name(s) of the binary file(s) to be + * read. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of LMDBDataset + */ + public LMDBDataset lMDBDataset(Operand<TString> filenames, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return LMDBDataset.create(scope, filenames, outputTypes, outputShapes); + } + + /** + * Records the latency of producing {@code input_dataset} elements in a StatsAggregator. + * + * @param inputDataset the inputDataset value + * @param tag the tag value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of LatencyStatsDataset + */ + public LatencyStatsDataset latencyStatsDataset(Operand<? extends TType> inputDataset, + Operand<TString> tag, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return LatencyStatsDataset.create(scope, inputDataset, tag, outputTypes, outputShapes); + } + + /** + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. + * The resulting dataset is similar to the {@code InterleaveDataset}, with the exception + * that if retrieving the next value from a dataset would cause the requester to + * block, it will skip that input dataset. This dataset is especially useful + * when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it + * allows the training step to proceed so long as some data is available. + * <p>!! WARNING !! This dataset is not deterministic! + * + * @param inputDataset the inputDataset value + * @param otherArguments the otherArguments value + * @param cycleLength the cycleLength value + * @param blockLength the blockLength value + * @param bufferOutputElements the bufferOutputElements value + * @param prefetchInputElements the prefetchInputElements value + * @param f A function mapping elements of {@code input_dataset}, concatenated with + * {@code other_arguments}, to a Dataset variant that contains elements matching + * {@code output_types} and {@code output_shapes}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of LegacyParallelInterleaveDataset + */ + public LegacyParallelInterleaveDataset legacyParallelInterleaveDataset( + Operand<? extends TType> inputDataset, Iterable<Operand<?>> otherArguments, + Operand<TInt64> cycleLength, Operand<TInt64> blockLength, + Operand<TInt64> bufferOutputElements, Operand<TInt64> prefetchInputElements, + ConcreteFunction f, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + LegacyParallelInterleaveDataset.Options... options) { + return LegacyParallelInterleaveDataset.create(scope, inputDataset, otherArguments, cycleLength, blockLength, bufferOutputElements, prefetchInputElements, f, outputTypes, outputShapes, options); + } + + /** + * The LoadDataset operation + * + * @param path the path value + * @param readerFuncOtherArgs the readerFuncOtherArgs value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param readerFunc the value of the readerFunc property + * @param options carries optional attribute values + * @return a new instance of LoadDataset + */ + public LoadDataset loadDataset(Operand<TString> path, Iterable<Operand<?>> readerFuncOtherArgs, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + ConcreteFunction readerFunc, LoadDataset.Options... options) { + return LoadDataset.create(scope, path, readerFuncOtherArgs, outputTypes, outputShapes, readerFunc, options); + } + /** * Makes a new iterator from the given {@code dataset} and stores it in {@code iterator}. * This operation may be executed multiple times. Each execution will reset the @@ -342,26 +920,80 @@ public MapDataset mapDataset(Operand<? extends TType> inputDataset, } /** - * Makes a "one-shot" iterator that can be iterated only once. - * A one-shot iterator bundles the logic for defining the dataset and - * the state of the iterator in a single op, which allows simple input - * pipelines to be defined without an additional initialization - * ("MakeIterator") step. - * <p>One-shot iterators have the following limitations: - * <ul> - * <li>They do not support parameterization: all logic for creating the underlying - * dataset must be bundled in the {@code dataset_factory} function.</li> - * <li>They are not resettable. Once a one-shot iterator reaches the end of its - * underlying dataset, subsequent "IteratorGetNext" operations on that - * iterator will always produce an {@code OutOfRange} error.</li> - * </ul> - * <p>For greater flexibility, use "Iterator" and "MakeIterator" to define - * an iterator using an arbitrary subgraph, which may capture tensors - * (including fed values) as parameters, and which may be reset multiple - * times by rerunning "MakeIterator". + * The MatchingFilesDataset operation * - * @param datasetFactory A function of type {@code () -> DT_VARIANT}, where the returned - * DT_VARIANT is a dataset. + * @param patterns the patterns value + * @return a new instance of MatchingFilesDataset + */ + public MatchingFilesDataset matchingFilesDataset(Operand<TString> patterns) { + return MatchingFilesDataset.create(scope, patterns); + } + + /** + * Creates a dataset that overrides the maximum intra-op parallelism. + * + * @param inputDataset the inputDataset value + * @param maxIntraOpParallelism Identifies the maximum intra-op parallelism to use. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of MaxIntraOpParallelismDataset + */ + public MaxIntraOpParallelismDataset maxIntraOpParallelismDataset( + Operand<? extends TType> inputDataset, Operand<TInt64> maxIntraOpParallelism, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return MaxIntraOpParallelismDataset.create(scope, inputDataset, maxIntraOpParallelism, outputTypes, outputShapes); + } + + /** + * Identity transformation that models performance. + * Identity transformation that models performance. + * + * @param inputDataset A variant tensor representing the input dataset. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ModelDataset + */ + public ModelDataset modelDataset(Operand<? extends TType> inputDataset, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + ModelDataset.Options... options) { + return ModelDataset.create(scope, inputDataset, outputTypes, outputShapes, options); + } + + /** + * The NonSerializableDataset operation + * + * @param inputDataset the inputDataset value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of NonSerializableDataset + */ + public NonSerializableDataset nonSerializableDataset(Operand<? extends TType> inputDataset, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return NonSerializableDataset.create(scope, inputDataset, outputTypes, outputShapes); + } + + /** + * Makes a "one-shot" iterator that can be iterated only once. + * A one-shot iterator bundles the logic for defining the dataset and + * the state of the iterator in a single op, which allows simple input + * pipelines to be defined without an additional initialization + * ("MakeIterator") step. + * <p>One-shot iterators have the following limitations: + * <ul> + * <li>They do not support parameterization: all logic for creating the underlying + * dataset must be bundled in the {@code dataset_factory} function.</li> + * <li>They are not resettable. Once a one-shot iterator reaches the end of its + * underlying dataset, subsequent "IteratorGetNext" operations on that + * iterator will always produce an {@code OutOfRange} error.</li> + * </ul> + * <p>For greater flexibility, use "Iterator" and "MakeIterator" to define + * an iterator using an arbitrary subgraph, which may capture tensors + * (including fed values) as parameters, and which may be reset multiple + * times by rerunning "MakeIterator". + * + * @param datasetFactory A function of type {@code () -> DT_VARIANT}, where the returned + * DT_VARIANT is a dataset. * @param outputTypes the value of the outputTypes property * @param outputShapes the value of the outputShapes property * @param options carries optional attribute values @@ -373,6 +1005,26 @@ public OneShotIterator oneShotIterator(ConcreteFunction datasetFactory, return OneShotIterator.create(scope, datasetFactory, outputTypes, outputShapes, options); } + /** + * Creates a dataset by applying related optimizations to {@code input_dataset}. + * Creates a dataset by applying related optimizations to {@code input_dataset}. + * + * @param inputDataset A variant tensor representing the input dataset. + * @param optimizationsEnabled A {@code tf.string} vector {@code tf.Tensor} identifying user enabled optimizations. + * @param optimizationsDisabled A {@code tf.string} vector {@code tf.Tensor} identifying user disabled optimizations. + * @param optimizationsDefault A {@code tf.string} vector {@code tf.Tensor} identifying optimizations by default. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of OptimizeDataset + */ + public OptimizeDataset optimizeDataset(Operand<? extends TType> inputDataset, + Operand<TString> optimizationsEnabled, Operand<TString> optimizationsDisabled, + Operand<TString> optimizationsDefault, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes, OptimizeDataset.Options... options) { + return OptimizeDataset.create(scope, inputDataset, optimizationsEnabled, optimizationsDisabled, optimizationsDefault, outputTypes, outputShapes, options); + } + /** * Constructs an Optional variant from a tuple of tensors. * @@ -415,6 +1067,228 @@ public OptionalNone optionalNone() { return OptionalNone.create(scope); } + /** + * Creates a dataset by attaching tf.data.Options to {@code input_dataset}. + * + * @param inputDataset A variant tensor representing the input dataset. + * @param serializedOptions A {@code tf.string} scalar {@code tf.Tensor} of serialized {@code tf.data.Options} protocol buffer. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of OptionsDataset + */ + public OptionsDataset optionsDataset(Operand<? extends TType> inputDataset, + String serializedOptions, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes) { + return OptionsDataset.create(scope, inputDataset, serializedOptions, outputTypes, outputShapes); + } + + /** + * Creates a dataset that batches and pads {@code batch_size} elements from the input. + * + * @param inputDataset the inputDataset value + * @param batchSize A scalar representing the number of elements to accumulate in a + * batch. + * @param paddedShapes A list of int64 tensors representing the desired padded shapes + * of the corresponding output components. These shapes may be partially + * specified, using {@code -1} to indicate that a particular dimension should be + * padded to the maximum size of all batch elements. + * @param paddingValues A list of scalars containing the padding value to use for + * each of the outputs. + * @param dropRemainder A scalar representing whether the last batch should be dropped in case its size + * is smaller than desired. + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of PaddedBatchDataset + */ + public PaddedBatchDataset paddedBatchDataset(Operand<? extends TType> inputDataset, + Operand<TInt64> batchSize, Iterable<Operand<TInt64>> paddedShapes, + Iterable<Operand<?>> paddingValues, Operand<TBool> dropRemainder, List<Shape> outputShapes, + PaddedBatchDataset.Options... options) { + return PaddedBatchDataset.create(scope, inputDataset, batchSize, paddedShapes, paddingValues, dropRemainder, outputShapes, options); + } + + /** + * The ParallelBatchDataset operation + * + * @param inputDataset the inputDataset value + * @param batchSize the batchSize value + * @param numParallelCalls the numParallelCalls value + * @param dropRemainder the dropRemainder value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ParallelBatchDataset + */ + public ParallelBatchDataset parallelBatchDataset(Operand<? extends TType> inputDataset, + Operand<TInt64> batchSize, Operand<TInt64> numParallelCalls, Operand<TBool> dropRemainder, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + ParallelBatchDataset.Options... options) { + return ParallelBatchDataset.create(scope, inputDataset, batchSize, numParallelCalls, dropRemainder, outputTypes, outputShapes, options); + } + + /** + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. + * The resulting dataset is similar to the {@code InterleaveDataset}, except that the + * dataset will fetch records from the interleaved datasets in parallel. + * <p>The {@code tf.data} Python API creates instances of this op from + * {@code Dataset.interleave()} when the {@code num_parallel_calls} parameter of that method + * is set to any value other than {@code None}. + * <p>By default, the output of this dataset will be deterministic, which may result + * in the dataset blocking if the next data item to be returned isn't available. + * In order to avoid head-of-line blocking, one can either set the {@code deterministic} + * attribute to "false", or leave it as "default" and set the + * {@code experimental_deterministic} parameter of {@code tf.data.Options} to {@code False}. + * This can improve performance at the expense of non-determinism. + * + * @param inputDataset Dataset that produces a stream of arguments for the function {@code f}. + * @param otherArguments Additional arguments to pass to {@code f} beyond those produced by {@code input_dataset}. + * Evaluated once when the dataset is instantiated. + * @param cycleLength Number of datasets (each created by applying {@code f} to the elements of + * {@code input_dataset}) among which the {@code ParallelInterleaveDatasetV2} will cycle in a + * round-robin fashion. + * @param blockLength Number of elements at a time to produce from each interleaved invocation of a + * dataset returned by {@code f}. + * @param bufferOutputElements The number of elements each iterator being interleaved should buffer (similar + * to the {@code .prefetch()} transformation for each interleaved iterator). + * @param prefetchInputElements Determines the number of iterators to prefetch, allowing buffers to warm up and + * data to be pre-fetched without blocking the main thread. + * @param numParallelCalls Determines the number of threads that should be used for fetching data from + * input datasets in parallel. The Python API {@code tf.data.experimental.AUTOTUNE} + * constant can be used to indicate that the level of parallelism should be autotuned. + * @param f A function mapping elements of {@code input_dataset}, concatenated with + * {@code other_arguments}, to a Dataset variant that contains elements matching + * {@code output_types} and {@code output_shapes}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ParallelInterleaveDataset + */ + public ParallelInterleaveDataset parallelInterleaveDataset(Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, Operand<TInt64> cycleLength, Operand<TInt64> blockLength, + Operand<TInt64> bufferOutputElements, Operand<TInt64> prefetchInputElements, + Operand<TInt64> numParallelCalls, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + ParallelInterleaveDataset.Options... options) { + return ParallelInterleaveDataset.create(scope, inputDataset, otherArguments, cycleLength, blockLength, bufferOutputElements, prefetchInputElements, numParallelCalls, f, outputTypes, outputShapes, options); + } + + /** + * Creates a dataset that applies {@code f} to the outputs of {@code input_dataset}. + * Unlike a "MapDataset", which applies {@code f} sequentially, this dataset invokes up + * to {@code num_parallel_calls} copies of {@code f} in parallel. + * + * @param inputDataset the inputDataset value + * @param otherArguments the otherArguments value + * @param numParallelCalls The number of concurrent invocations of {@code f} that process + * elements from {@code input_dataset} in parallel. + * @param f the value of the f property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ParallelMapDataset + */ + public ParallelMapDataset parallelMapDataset(Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, Operand<TInt64> numParallelCalls, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + ParallelMapDataset.Options... options) { + return ParallelMapDataset.create(scope, inputDataset, otherArguments, numParallelCalls, f, outputTypes, outputShapes, options); + } + + /** + * Transforms {@code input_dataset} containing {@code Example} protos as vectors of DT_STRING into a dataset of {@code Tensor} or {@code SparseTensor} objects representing the parsed features. + * + * @param inputDataset the inputDataset value + * @param numParallelCalls the numParallelCalls value + * @param denseDefaults A dict mapping string keys to {@code Tensor}s. + * The keys of the dict must match the dense_keys of the feature. + * @param sparseKeys A list of string keys in the examples features. + * The results for these keys will be returned as {@code SparseTensor} objects. + * @param denseKeys A list of Ndense string Tensors (scalars). + * The keys expected in the Examples features associated with dense values. + * @param sparseTypes A list of {@code DTypes} of the same length as {@code sparse_keys}. + * Only {@code tf.float32} ({@code FloatList}), {@code tf.int64} ({@code Int64List}), + * and {@code tf.string} ({@code BytesList}) are supported. + * @param denseShapes List of tuples with the same length as {@code dense_keys}. + * The shape of the data for each dense feature referenced by {@code dense_keys}. + * Required for any input tensors identified by {@code dense_keys}. Must be + * either fully defined, or may contain an unknown first dimension. + * An unknown first dimension means the feature is treated as having + * a variable number of blocks, and the output shape along this dimension + * is considered unknown at graph build time. Padding is applied for + * minibatch elements smaller than the maximum number of blocks for the + * given feature along this dimension. + * @param outputTypes The type list for the return values. + * @param outputShapes The list of shapes being produced. + * @param raggedValueTypes the value of the raggedValueTypes property + * @param raggedSplitTypes the value of the raggedSplitTypes property + * @param options carries optional attribute values + * @return a new instance of ParseExampleDataset + */ + public ParseExampleDataset parseExampleDataset(Operand<? extends TType> inputDataset, + Operand<TInt64> numParallelCalls, Iterable<Operand<?>> denseDefaults, List<String> sparseKeys, + List<String> denseKeys, List<Class<? extends TType>> sparseTypes, List<Shape> denseShapes, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + List<Class<? extends TType>> raggedValueTypes, + List<Class<? extends TNumber>> raggedSplitTypes, ParseExampleDataset.Options... options) { + return ParseExampleDataset.create(scope, inputDataset, numParallelCalls, denseDefaults, sparseKeys, denseKeys, sparseTypes, denseShapes, outputTypes, outputShapes, raggedValueTypes, raggedSplitTypes, options); + } + + /** + * Creates a dataset that asynchronously prefetches elements from {@code input_dataset}. + * + * @param inputDataset the inputDataset value + * @param bufferSize The maximum number of elements to buffer in an iterator over + * this dataset. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of PrefetchDataset + */ + public PrefetchDataset prefetchDataset(Operand<? extends TType> inputDataset, + Operand<TInt64> bufferSize, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes, PrefetchDataset.Options... options) { + return PrefetchDataset.create(scope, inputDataset, bufferSize, outputTypes, outputShapes, options); + } + + /** + * Creates a dataset that uses a custom thread pool to compute {@code input_dataset}. + * + * @param inputDataset the inputDataset value + * @param numThreads Identifies the number of threads to use for the private threadpool. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of PrivateThreadPoolDataset + */ + public PrivateThreadPoolDataset privateThreadPoolDataset(Operand<? extends TType> inputDataset, + Operand<TInt64> numThreads, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes) { + return PrivateThreadPoolDataset.create(scope, inputDataset, numThreads, outputTypes, outputShapes); + } + + /** + * Creates a Dataset that returns pseudorandom numbers. + * Creates a Dataset that returns a stream of uniformly distributed + * pseudorandom 64-bit signed integers. + * <p>In the TensorFlow Python API, you can instantiate this dataset via the + * class {@code tf.data.experimental.RandomDataset}. + * <p>Instances of this dataset are also created as a result of the + * {@code hoist_random_uniform} static optimization. Whether this optimization is + * performed is determined by the {@code experimental_optimization.hoist_random_uniform} + * option of {@code tf.data.Options}. + * + * @param seed A scalar seed for the random number generator. If either seed or + * seed2 is set to be non-zero, the random number generator is seeded + * by the given seed. Otherwise, a random seed is used. + * @param seed2 A second scalar seed to avoid seed collision. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of RandomDataset + */ + public RandomDataset randomDataset(Operand<TInt64> seed, Operand<TInt64> seed2, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return RandomDataset.create(scope, seed, seed2, outputTypes, outputShapes); + } + /** * Creates a dataset with a range of values. Corresponds to python's xrange. * @@ -430,6 +1304,61 @@ public RangeDataset rangeDataset(Operand<TInt64> start, Operand<TInt64> stop, return RangeDataset.create(scope, start, stop, step, outputTypes, outputShapes); } + /** + * Creates a dataset that changes the batch size. + * Creates a dataset that rebatches elements from {@code input_dataset} into new batch + * sizes. + * + * @param inputDataset A variant tensor representing the input dataset. + * @param batchSizes A vector of integers representing the size of batches to produce. These values + * are cycled through in order. + * @param dropRemainder the dropRemainder value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of RebatchDatasetV2 + */ + public RebatchDatasetV2 rebatchDatasetV2(Operand<? extends TType> inputDataset, + Operand<TInt64> batchSizes, Operand<TBool> dropRemainder, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return RebatchDatasetV2.create(scope, inputDataset, batchSizes, dropRemainder, outputTypes, outputShapes); + } + + /** + * Reduces the input dataset to a singleton using a reduce function. + * + * @param inputDataset A variant tensor representing the input dataset. + * @param initialState A nested structure of tensors, representing the initial state of the + * transformation. + * @param otherArguments the otherArguments value + * @param f A function that maps {@code (old_state, input_element)} to {@code new_state}. It must take + * two arguments and return a nested structures of tensors. The structure of + * {@code new_state} must match the structure of {@code initial_state}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ReduceDataset + */ + public ReduceDataset reduceDataset(Operand<? extends TType> inputDataset, + Iterable<Operand<?>> initialState, Iterable<Operand<?>> otherArguments, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + ReduceDataset.Options... options) { + return ReduceDataset.create(scope, inputDataset, initialState, otherArguments, f, outputTypes, outputShapes, options); + } + + /** + * Registers a dataset with the tf.data service. + * + * @param dataset the dataset value + * @param address the address value + * @param protocol the protocol value + * @param externalStatePolicy the value of the externalStatePolicy property + * @return a new instance of RegisterDataset + */ + public RegisterDataset registerDataset(Operand<? extends TType> dataset, Operand<TString> address, + Operand<TString> protocol, Long externalStatePolicy) { + return RegisterDataset.create(scope, dataset, address, protocol, externalStatePolicy); + } + /** * Creates a dataset that emits the outputs of {@code input_dataset} {@code count} times. * @@ -445,6 +1374,64 @@ public RepeatDataset repeatDataset(Operand<? extends TType> inputDataset, Operan return RepeatDataset.create(scope, inputDataset, count, outputTypes, outputShapes); } + /** + * Creates a dataset that takes a Bernoulli sample of the contents of another dataset. + * There is no transformation in the {@code tf.data} Python API for creating this dataset. + * Instead, it is created as a result of the {@code filter_with_random_uniform_fusion} + * static optimization. Whether this optimization is performed is determined by the + * {@code experimental_optimization.filter_with_random_uniform_fusion} option of + * {@code tf.data.Options}. + * + * @param inputDataset the inputDataset value + * @param rate A scalar representing the sample rate. Each element of {@code input_dataset} is + * retained with this probability, independent of all other elements. + * @param seed A scalar representing seed of random number generator. + * @param seed2 A scalar representing seed2 of random number generator. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of SamplingDataset + */ + public SamplingDataset samplingDataset(Operand<? extends TType> inputDataset, + Operand<TFloat32> rate, Operand<TInt64> seed, Operand<TInt64> seed2, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return SamplingDataset.create(scope, inputDataset, rate, seed, seed2, outputTypes, outputShapes); + } + + /** + * The SaveDataset operation + * + * @param inputDataset the inputDataset value + * @param path the path value + * @param shardFuncOtherArgs the shardFuncOtherArgs value + * @param shardFunc the value of the shardFunc property + * @param options carries optional attribute values + * @return a new instance of SaveDataset + */ + public SaveDataset saveDataset(Operand<? extends TType> inputDataset, Operand<TString> path, + Iterable<Operand<?>> shardFuncOtherArgs, ConcreteFunction shardFunc, + SaveDataset.Options... options) { + return SaveDataset.create(scope, inputDataset, path, shardFuncOtherArgs, shardFunc, options); + } + + /** + * Creates a dataset successively reduces {@code f} over the elements of {@code input_dataset}. + * + * @param inputDataset the inputDataset value + * @param initialState the initialState value + * @param otherArguments the otherArguments value + * @param f the value of the f property + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ScanDataset + */ + public ScanDataset scanDataset(Operand<? extends TType> inputDataset, + Iterable<Operand<?>> initialState, Iterable<Operand<?>> otherArguments, ConcreteFunction f, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + ScanDataset.Options... options) { + return ScanDataset.create(scope, inputDataset, initialState, otherArguments, f, outputTypes, outputShapes, options); + } + /** * Converts the given {@code resource_handle} representing an iterator to a variant tensor. * @@ -457,6 +1444,83 @@ public SerializeIterator serializeIterator(Operand<? extends TType> resourceHand return SerializeIterator.create(scope, resourceHandle, options); } + /** + * The SetStatsAggregatorDataset operation + * + * @param inputDataset the inputDataset value + * @param statsAggregator the statsAggregator value + * @param tag the tag value + * @param counterPrefix the counterPrefix value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of SetStatsAggregatorDataset + */ + public SetStatsAggregatorDataset setStatsAggregatorDataset(Operand<? extends TType> inputDataset, + Operand<? extends TType> statsAggregator, Operand<TString> tag, + Operand<TString> counterPrefix, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes) { + return SetStatsAggregatorDataset.create(scope, inputDataset, statsAggregator, tag, counterPrefix, outputTypes, outputShapes); + } + + /** + * Creates a {@code Dataset} that includes only 1/{@code num_shards} of this dataset. + * + * @param inputDataset the inputDataset value + * @param numShards An integer representing the number of shards operating in parallel. + * @param index An integer representing the current worker index. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ShardDataset + */ + public ShardDataset shardDataset(Operand<? extends TType> inputDataset, Operand<TInt64> numShards, + Operand<TInt64> index, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + ShardDataset.Options... options) { + return ShardDataset.create(scope, inputDataset, numShards, index, outputTypes, outputShapes, options); + } + + /** + * The ShuffleAndRepeatDatasetV2 operation + * + * @param inputDataset the inputDataset value + * @param bufferSize the bufferSize value + * @param seed the seed value + * @param seed2 the seed2 value + * @param count the count value + * @param seedGenerator the seedGenerator value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ShuffleAndRepeatDataset + */ + public ShuffleAndRepeatDataset shuffleAndRepeatDataset(Operand<? extends TType> inputDataset, + Operand<TInt64> bufferSize, Operand<TInt64> seed, Operand<TInt64> seed2, + Operand<TInt64> count, Operand<? extends TType> seedGenerator, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, + ShuffleAndRepeatDataset.Options... options) { + return ShuffleAndRepeatDataset.create(scope, inputDataset, bufferSize, seed, seed2, count, seedGenerator, outputTypes, outputShapes, options); + } + + /** + * The ShuffleDatasetV3 operation + * + * @param inputDataset the inputDataset value + * @param bufferSize the bufferSize value + * @param seed the seed value + * @param seed2 the seed2 value + * @param seedGenerator the seedGenerator value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param options carries optional attribute values + * @return a new instance of ShuffleDataset + */ + public ShuffleDataset shuffleDataset(Operand<? extends TType> inputDataset, + Operand<TInt64> bufferSize, Operand<TInt64> seed, Operand<TInt64> seed2, + Operand<? extends TType> seedGenerator, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes, ShuffleDataset.Options... options) { + return ShuffleDataset.create(scope, inputDataset, bufferSize, seed, seed2, seedGenerator, outputTypes, outputShapes, options); + } + /** * Creates a dataset that skips {@code count} elements from the {@code input_dataset}. * @@ -472,6 +1536,95 @@ public SkipDataset skipDataset(Operand<? extends TType> inputDataset, Operand<TI return SkipDataset.create(scope, inputDataset, count, outputTypes, outputShapes); } + /** + * The SleepDataset operation + * + * @param inputDataset the inputDataset value + * @param sleepMicroseconds the sleepMicroseconds value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of SleepDataset + */ + public SleepDataset sleepDataset(Operand<? extends TType> inputDataset, + Operand<TInt64> sleepMicroseconds, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes) { + return SleepDataset.create(scope, inputDataset, sleepMicroseconds, outputTypes, outputShapes); + } + + /** + * Creates a dataset that passes a sliding window over {@code input_dataset}. + * + * @param inputDataset the inputDataset value + * @param windowSize A scalar representing the number of elements in the + * sliding window. + * @param windowShift A scalar representing the steps moving the sliding window + * forward in one iteration. It must be positive. + * @param windowStride A scalar representing the stride of the input elements of the sliding window. + * It must be positive. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of SlidingWindowDataset + */ + public SlidingWindowDataset slidingWindowDataset(Operand<? extends TType> inputDataset, + Operand<TInt64> windowSize, Operand<TInt64> windowShift, Operand<TInt64> windowStride, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return SlidingWindowDataset.create(scope, inputDataset, windowSize, windowShift, windowStride, outputTypes, outputShapes); + } + + /** + * Creates a dataset that will write to / read from a snapshot. + * This dataset attempts to determine whether a valid snapshot exists at the + * {@code snapshot_path}, and reads from the snapshot in lieu of using {@code input_dataset}. + * If not, it will run the preprocessing pipeline as usual, and write out a + * snapshot of the data processed for future use. + * + * @param inputDataset A variant tensor representing the input dataset. + * @param path The path we should write snapshots to / read snapshots from. + * @param readerFuncOtherArgs the readerFuncOtherArgs value + * @param shardFuncOtherArgs the shardFuncOtherArgs value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @param readerFunc Optional. A function to control how to read data from snapshot shards. + * @param shardFunc Optional. A function to control how to shard data when writing a snapshot. + * @param options carries optional attribute values + * @return a new instance of SnapshotDataset + */ + public SnapshotDataset snapshotDataset(Operand<? extends TType> inputDataset, + Operand<TString> path, Iterable<Operand<?>> readerFuncOtherArgs, + Iterable<Operand<?>> shardFuncOtherArgs, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes, ConcreteFunction readerFunc, ConcreteFunction shardFunc, + SnapshotDataset.Options... options) { + return SnapshotDataset.create(scope, inputDataset, path, readerFuncOtherArgs, shardFuncOtherArgs, outputTypes, outputShapes, readerFunc, shardFunc, options); + } + + /** + * Creates a dataset that splits a SparseTensor into elements row-wise. + * + * @param indices the indices value + * @param values the values value + * @param denseShape the denseShape value + * @return a new instance of SparseTensorSliceDataset + */ + public SparseTensorSliceDataset sparseTensorSliceDataset(Operand<TInt64> indices, + Operand<? extends TType> values, Operand<TInt64> denseShape) { + return SparseTensorSliceDataset.create(scope, indices, values, denseShape); + } + + /** + * Creates a dataset that executes a SQL query and emits rows of the result set. + * + * @param driverName The database type. Currently, the only supported type is 'sqlite'. + * @param dataSourceName A connection string to connect to the database. + * @param query A SQL query to execute. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of SqlDataset + */ + public SqlDataset sqlDataset(Operand<TString> driverName, Operand<TString> dataSourceName, + Operand<TString> query, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return SqlDataset.create(scope, driverName, dataSourceName, query, outputTypes, outputShapes); + } + /** * Creates a dataset that contains {@code count} elements from the {@code input_dataset}. * @@ -488,6 +1641,40 @@ public TakeDataset takeDataset(Operand<? extends TType> inputDataset, Operand<TI return TakeDataset.create(scope, inputDataset, count, outputTypes, outputShapes); } + /** + * Creates a dataset that stops iteration when predicate` is false. + * The {@code predicate} function must return a scalar boolean and accept the + * following arguments: + * <ul> + * <li>One tensor for each component of an element of {@code input_dataset}.</li> + * <li>One tensor for each value in {@code other_arguments}.</li> + * </ul> + * + * @param inputDataset the inputDataset value + * @param otherArguments A list of tensors, typically values that were captured when + * building a closure for {@code predicate}. + * @param predicate A function returning a scalar boolean. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of TakeWhileDataset + */ + public TakeWhileDataset takeWhileDataset(Operand<? extends TType> inputDataset, + Iterable<Operand<?>> otherArguments, ConcreteFunction predicate, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return TakeWhileDataset.create(scope, inputDataset, otherArguments, predicate, outputTypes, outputShapes); + } + + /** + * Creates a dataset that emits {@code components} as a tuple of tensors once. + * + * @param components the components value + * @param outputShapes the value of the outputShapes property + * @return a new instance of TensorDataset + */ + public TensorDataset tensorDataset(Iterable<Operand<?>> components, List<Shape> outputShapes) { + return TensorDataset.create(scope, components, outputShapes); + } + /** * Creates a dataset that emits each dim-0 slice of {@code components} once. * @@ -531,6 +1718,127 @@ public TfRecordDataset tfRecordDataset(Operand<TString> filenames, return TfRecordDataset.create(scope, filenames, compressionType, bufferSize); } + /** + * Creates a dataset that uses a custom thread pool to compute {@code input_dataset}. + * + * @param inputDataset the inputDataset value + * @param threadPool A resource produced by the ThreadPoolHandle op. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of ThreadPoolDataset + */ + public ThreadPoolDataset threadPoolDataset(Operand<? extends TType> inputDataset, + Operand<? extends TType> threadPool, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes) { + return ThreadPoolDataset.create(scope, inputDataset, threadPool, outputTypes, outputShapes); + } + + /** + * A dataset that splits the elements of its input into multiple elements. + * + * @param inputDataset the inputDataset value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of UnbatchDataset + */ + public UnbatchDataset unbatchDataset(Operand<? extends TType> inputDataset, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return UnbatchDataset.create(scope, inputDataset, outputTypes, outputShapes); + } + + /** + * Creates a dataset that contains the unique elements of {@code input_dataset}. + * + * @param inputDataset the inputDataset value + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of UniqueDataset + */ + public UniqueDataset uniqueDataset(Operand<? extends TType> inputDataset, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { + return UniqueDataset.create(scope, inputDataset, outputTypes, outputShapes); + } + + /** + * The UnwrapDatasetVariant operation + * + * @param inputHandle the inputHandle value + * @return a new instance of UnwrapDatasetVariant + */ + public UnwrapDatasetVariant unwrapDatasetVariant(Operand<? extends TType> inputHandle) { + return UnwrapDatasetVariant.create(scope, inputHandle); + } + + /** + * Combines (nests of) input elements into a dataset of (nests of) windows. + * <p>A "window" is a finite dataset of flat elements of size {@code size} (or possibly + * fewer if there are not enough input elements to fill the window and + * {@code drop_remainder} evaluates to false). + * <p>The {@code shift} argument determines the number of input elements by which + * the window moves on each iteration. The first element in the {@code k}th window + * will be element + * <pre> + * 1 + (k-1) * shift + * </pre> + * <p>of the input dataset. In particular, the first element of the first window + * will always be the first element of the input dataset. + * <p>If the {@code stride} parameter is greater than 1, then each window will skip + * {@code (stride - 1)} input elements between each element that appears in the + * window. Output windows will still contain {@code size} elements regardless of + * the value of {@code stride}. + * <p>The {@code stride} argument determines the stride of the input elements, and the + * {@code shift} argument determines the shift of the window. + * <p>For example, letting {@code {...}} to represent a Dataset: + * <ul> + * <li>{@code tf.data.Dataset.range(7).window(2)} produces + * {@code {{0, 1}, {2, 3}, {4, 5}, {6}}}</li> + * <li>{@code tf.data.Dataset.range(7).window(3, 2, 1, True)} produces + * {@code {{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}}</li> + * <li>{@code tf.data.Dataset.range(7).window(3, 1, 2, True)} produces + * {@code {{0, 2, 4}, {1, 3, 5}, {2, 4, 6}}}</li> + * </ul> + * <p>Note that when the {@code window} transformation is applied to a dataset of + * nested elements, it produces a dataset of nested windows. + * <p>For example: + * <ul> + * <li>{@code tf.data.Dataset.from_tensor_slices((range(4), range(4))).window(2)} + * produces {@code {({0, 1}, {0, 1}), ({2, 3}, {2, 3})}}</li> + * <li>{@code tf.data.Dataset.from_tensor_slices({"a": range(4)}).window(2)} + * produces {@code {{"a": {0, 1}}, {"a": {2, 3}}}}</li> + * </ul> + * + * @param inputDataset the inputDataset value + * @param sizeOutput An integer scalar, representing the number of elements + * of the input dataset to combine into a window. Must be positive. + * @param shift An integer scalar, representing the number of input elements + * by which the window moves in each iteration. Defaults to {@code size}. + * Must be positive. + * @param stride An integer scalar, representing the stride of the input elements + * in the sliding window. Must be positive. The default value of 1 means + * "retain every input element". + * @param dropRemainder A Boolean scalar, representing whether the last window should be + * dropped if its size is smaller than {@code window_size}. + * @param outputTypes the value of the outputTypes property + * @param outputShapes the value of the outputShapes property + * @return a new instance of WindowDataset + */ + public WindowDataset windowDataset(Operand<? extends TType> inputDataset, + Operand<TInt64> sizeOutput, Operand<TInt64> shift, Operand<TInt64> stride, + Operand<TBool> dropRemainder, List<Class<? extends TType>> outputTypes, + List<Shape> outputShapes) { + return WindowDataset.create(scope, inputDataset, sizeOutput, shift, stride, dropRemainder, outputTypes, outputShapes); + } + + /** + * The WrapDatasetVariant operation + * + * @param inputHandle the inputHandle value + * @return a new instance of WrapDatasetVariant + */ + public WrapDatasetVariant wrapDatasetVariant(Operand<? extends TType> inputHandle) { + return WrapDatasetVariant.create(scope, inputHandle); + } + /** * Creates a dataset that zips together {@code input_datasets}. * The elements of the resulting dataset are created by zipping corresponding diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertCardinalityDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertCardinalityDataset.java index e79b4f5ece5..f23c263c973 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertCardinalityDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertCardinalityDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; /** * The AssertCardinalityDataset operation */ +@Operator( + group = "data" +) public final class AssertCardinalityDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertNextDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertNextDataset.java index 55d04a508aa..193a9c7e21f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertNextDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertNextDataset.java @@ -27,6 +27,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; @@ -40,6 +41,9 @@ * means that the check happens <em>after</em> any static optimizations are applied * to the dataset graph. */ +@Operator( + group = "data" +) public final class AssertNextDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AutoShardDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AutoShardDataset.java index 901aa906739..fd68bffe522 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AutoShardDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AutoShardDataset.java @@ -27,6 +27,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -39,6 +40,9 @@ * <p>This dataset will throw a NotFound error if we cannot shard the dataset * automatically. */ +@Operator( + group = "data" +) public final class AutoShardDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BytesProducedStatsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BytesProducedStatsDataset.java index d2965b028d7..2bea64b5ee9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BytesProducedStatsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BytesProducedStatsDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; /** * Records the bytes size of each element of {@code input_dataset} in a StatsAggregator. */ +@Operator( + group = "data" +) public final class BytesProducedStatsDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CSVDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CSVDataset.java index 6eaaf8f621c..321cfb95669 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CSVDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CSVDataset.java @@ -27,19 +27,23 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; /** - * The CSVDataset operation + * The CSVDatasetV2 operation */ +@Operator( + group = "data" +) public final class CSVDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "CSVDataset"; + public static final String OP_NAME = "CSVDatasetV2"; private Output<? extends TType> handle; @@ -51,7 +55,7 @@ private CSVDataset(Operation operation) { } /** - * Factory method to create a class wrapping a new CSVDataset operation. + * Factory method to create a class wrapping a new CSVDatasetV2 operation. * * @param scope current scope * @param filenames the filenames value @@ -63,6 +67,7 @@ private CSVDataset(Operation operation) { * @param naValue the naValue value * @param selectCols the selectCols value * @param recordDefaults the recordDefaults value + * @param excludeCols the excludeCols value * @param outputShapes the value of the outputShapes property * @return a new instance of CSVDataset */ @@ -72,7 +77,8 @@ private CSVDataset(Operation operation) { public static CSVDataset create(Scope scope, Operand<TString> filenames, Operand<TString> compressionType, Operand<TInt64> bufferSize, Operand<TBool> header, Operand<TString> fieldDelim, Operand<TBool> useQuoteDelim, Operand<TString> naValue, - Operand<TInt64> selectCols, Iterable<Operand<?>> recordDefaults, List<Shape> outputShapes) { + Operand<TInt64> selectCols, Iterable<Operand<?>> recordDefaults, Operand<TInt64> excludeCols, + List<Shape> outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("CSVDataset")); opBuilder.addInput(filenames.asOutput()); opBuilder.addInput(compressionType.asOutput()); @@ -83,6 +89,7 @@ public static CSVDataset create(Scope scope, Operand<TString> filenames, opBuilder.addInput(naValue.asOutput()); opBuilder.addInput(selectCols.asOutput()); opBuilder.addInputList(Operands.asOutputs(recordDefaults)); + opBuilder.addInput(excludeCols.asOutput()); opBuilder = scope.apply(opBuilder); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0 ; i < outputShapesArray.length ; i++) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CSVDatasetV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CSVDatasetV2.java deleted file mode 100644 index 47db26f6dd6..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CSVDatasetV2.java +++ /dev/null @@ -1,112 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.data; - -import java.util.List; -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Operands; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.types.TBool; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.TString; -import org.tensorflow.types.family.TType; - -/** - * The CSVDatasetV2 operation - */ -public final class CSVDatasetV2 extends RawOp implements Operand<TType> { - /** - * The name of this op, as known by TensorFlow core engine - */ - public static final String OP_NAME = "CSVDatasetV2"; - - private Output<? extends TType> handle; - - @SuppressWarnings("unchecked") - private CSVDatasetV2(Operation operation) { - super(operation); - int outputIdx = 0; - handle = operation.output(outputIdx++); - } - - /** - * Factory method to create a class wrapping a new CSVDatasetV2 operation. - * - * @param scope current scope - * @param filenames the filenames value - * @param compressionType the compressionType value - * @param bufferSize the bufferSize value - * @param header the header value - * @param fieldDelim the fieldDelim value - * @param useQuoteDelim the useQuoteDelim value - * @param naValue the naValue value - * @param selectCols the selectCols value - * @param recordDefaults the recordDefaults value - * @param excludeCols the excludeCols value - * @param outputShapes the value of the outputShapes property - * @return a new instance of CSVDatasetV2 - */ - @Endpoint( - describeByClass = true - ) - public static CSVDatasetV2 create(Scope scope, Operand<TString> filenames, - Operand<TString> compressionType, Operand<TInt64> bufferSize, Operand<TBool> header, - Operand<TString> fieldDelim, Operand<TBool> useQuoteDelim, Operand<TString> naValue, - Operand<TInt64> selectCols, Iterable<Operand<?>> recordDefaults, Operand<TInt64> excludeCols, - List<Shape> outputShapes) { - OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("CSVDatasetV2")); - opBuilder.addInput(filenames.asOutput()); - opBuilder.addInput(compressionType.asOutput()); - opBuilder.addInput(bufferSize.asOutput()); - opBuilder.addInput(header.asOutput()); - opBuilder.addInput(fieldDelim.asOutput()); - opBuilder.addInput(useQuoteDelim.asOutput()); - opBuilder.addInput(naValue.asOutput()); - opBuilder.addInput(selectCols.asOutput()); - opBuilder.addInputList(Operands.asOutputs(recordDefaults)); - opBuilder.addInput(excludeCols.asOutput()); - opBuilder = scope.apply(opBuilder); - Shape[] outputShapesArray = new Shape[outputShapes.size()]; - for (int i = 0 ; i < outputShapesArray.length ; i++) { - outputShapesArray[i] = outputShapes.get(i); - } - opBuilder.setAttr("output_shapes", outputShapesArray); - return new CSVDatasetV2(opBuilder.build()); - } - - /** - * Gets handle. - * - * @return handle. - */ - public Output<? extends TType> handle() { - return handle; - } - - @Override - @SuppressWarnings("unchecked") - public Output<TType> asOutput() { - return (Output<TType>) handle; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDataset.java index b610b7af19a..b9903b62d21 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDataset.java @@ -27,21 +27,21 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; /** - * Creates a dataset that caches elements from {@code input_dataset}. - * A CacheDataset will iterate over the input_dataset, and store tensors. If the - * cache already exists, the cache will be used. If the cache is inappropriate - * (e.g. cannot be opened, contains tensors of the wrong shape / size), an error - * will the returned when used. + * The CacheDatasetV2 operation */ +@Operator( + group = "data" +) public final class CacheDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "CacheDataset"; + public static final String OP_NAME = "CacheDatasetV2"; private Output<? extends TType> handle; @@ -53,12 +53,12 @@ private CacheDataset(Operation operation) { } /** - * Factory method to create a class wrapping a new CacheDataset operation. + * Factory method to create a class wrapping a new CacheDatasetV2 operation. * * @param scope current scope * @param inputDataset the inputDataset value - * @param filename A path on the filesystem where we should cache the dataset. Note: this - * will be a directory. + * @param filename the filename value + * @param cache the cache value * @param outputTypes the value of the outputTypes property * @param outputShapes the value of the outputShapes property * @return a new instance of CacheDataset @@ -67,11 +67,12 @@ private CacheDataset(Operation operation) { describeByClass = true ) public static CacheDataset create(Scope scope, Operand<? extends TType> inputDataset, - Operand<TString> filename, List<Class<? extends TType>> outputTypes, - List<Shape> outputShapes) { + Operand<TString> filename, Operand<? extends TType> cache, + List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("CacheDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(filename.asOutput()); + opBuilder.addInput(cache.asOutput()); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDatasetV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDatasetV2.java deleted file mode 100644 index 60b371e0f4c..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDatasetV2.java +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.data; - -import java.util.List; -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Operands; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.types.TString; -import org.tensorflow.types.family.TType; - -/** - * The CacheDatasetV2 operation - */ -public final class CacheDatasetV2 extends RawOp implements Operand<TType> { - /** - * The name of this op, as known by TensorFlow core engine - */ - public static final String OP_NAME = "CacheDatasetV2"; - - private Output<? extends TType> handle; - - @SuppressWarnings("unchecked") - private CacheDatasetV2(Operation operation) { - super(operation); - int outputIdx = 0; - handle = operation.output(outputIdx++); - } - - /** - * Factory method to create a class wrapping a new CacheDatasetV2 operation. - * - * @param scope current scope - * @param inputDataset the inputDataset value - * @param filename the filename value - * @param cache the cache value - * @param outputTypes the value of the outputTypes property - * @param outputShapes the value of the outputShapes property - * @return a new instance of CacheDatasetV2 - */ - @Endpoint( - describeByClass = true - ) - public static CacheDatasetV2 create(Scope scope, Operand<? extends TType> inputDataset, - Operand<TString> filename, Operand<? extends TType> cache, - List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) { - OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("CacheDatasetV2")); - opBuilder.addInput(inputDataset.asOutput()); - opBuilder.addInput(filename.asOutput()); - opBuilder.addInput(cache.asOutput()); - opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); - Shape[] outputShapesArray = new Shape[outputShapes.size()]; - for (int i = 0 ; i < outputShapesArray.length ; i++) { - outputShapesArray[i] = outputShapes.get(i); - } - opBuilder.setAttr("output_shapes", outputShapesArray); - return new CacheDatasetV2(opBuilder.build()); - } - - /** - * Gets handle. - * - * @return handle. - */ - public Output<? extends TType> handle() { - return handle; - } - - @Override - @SuppressWarnings("unchecked") - public Output<TType> asOutput() { - return (Output<TType>) handle; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestBranchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestBranchDataset.java index 5fe596e3538..5623bcf21b5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestBranchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestBranchDataset.java @@ -28,12 +28,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; /** * The ChooseFastestBranchDataset operation */ +@Operator( + group = "data" +) public final class ChooseFastestBranchDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestDataset.java index d91ed316fe5..e5c26edcba8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestDataset.java @@ -27,11 +27,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * The ChooseFastestDataset operation */ +@Operator( + group = "data" +) public final class ChooseFastestDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DataServiceDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DataServiceDataset.java deleted file mode 100644 index cb7ed81bcd1..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DataServiceDataset.java +++ /dev/null @@ -1,172 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.data; - -import java.util.List; -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Operands; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.TString; -import org.tensorflow.types.family.TType; - -/** - * Creates a dataset that reads data from the tf.data service. - */ -public final class DataServiceDataset extends RawOp implements Operand<TType> { - /** - * The name of this op, as known by TensorFlow core engine - */ - public static final String OP_NAME = "DataServiceDataset"; - - private Output<? extends TType> handle; - - @SuppressWarnings("unchecked") - private DataServiceDataset(Operation operation) { - super(operation); - int outputIdx = 0; - handle = operation.output(outputIdx++); - } - - /** - * Factory method to create a class wrapping a new DataServiceDataset operation. - * - * @param scope current scope - * @param datasetId the datasetId value - * @param processingMode the processingMode value - * @param address the address value - * @param protocol the protocol value - * @param jobName the jobName value - * @param maxOutstandingRequests the maxOutstandingRequests value - * @param iterationCounter the iterationCounter value - * @param outputTypes the value of the outputTypes property - * @param outputShapes the value of the outputShapes property - * @param options carries optional attribute values - * @return a new instance of DataServiceDataset - */ - @Endpoint( - describeByClass = true - ) - public static DataServiceDataset create(Scope scope, Operand<TInt64> datasetId, - Operand<TString> processingMode, Operand<TString> address, Operand<TString> protocol, - Operand<TString> jobName, Operand<TInt64> maxOutstandingRequests, - Operand<? extends TType> iterationCounter, List<Class<? extends TType>> outputTypes, - List<Shape> outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("DataServiceDataset")); - opBuilder.addInput(datasetId.asOutput()); - opBuilder.addInput(processingMode.asOutput()); - opBuilder.addInput(address.asOutput()); - opBuilder.addInput(protocol.asOutput()); - opBuilder.addInput(jobName.asOutput()); - opBuilder.addInput(maxOutstandingRequests.asOutput()); - opBuilder.addInput(iterationCounter.asOutput()); - opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); - Shape[] outputShapesArray = new Shape[outputShapes.size()]; - for (int i = 0 ; i < outputShapesArray.length ; i++) { - outputShapesArray[i] = outputShapes.get(i); - } - opBuilder.setAttr("output_shapes", outputShapesArray); - if (options != null) { - for (Options opts : options) { - if (opts.taskRefreshIntervalHintMs != null) { - opBuilder.setAttr("task_refresh_interval_hint_ms", opts.taskRefreshIntervalHintMs); - } - if (opts.dataTransferProtocol != null) { - opBuilder.setAttr("data_transfer_protocol", opts.dataTransferProtocol); - } - } - } - return new DataServiceDataset(opBuilder.build()); - } - - /** - * Sets the taskRefreshIntervalHintMs option. - * - * @param taskRefreshIntervalHintMs the taskRefreshIntervalHintMs option - * @return this Options instance. - */ - public static Options taskRefreshIntervalHintMs(Long taskRefreshIntervalHintMs) { - return new Options().taskRefreshIntervalHintMs(taskRefreshIntervalHintMs); - } - - /** - * Sets the dataTransferProtocol option. - * - * @param dataTransferProtocol the dataTransferProtocol option - * @return this Options instance. - */ - public static Options dataTransferProtocol(String dataTransferProtocol) { - return new Options().dataTransferProtocol(dataTransferProtocol); - } - - /** - * Gets handle. - * - * @return handle. - */ - public Output<? extends TType> handle() { - return handle; - } - - @Override - @SuppressWarnings("unchecked") - public Output<TType> asOutput() { - return (Output<TType>) handle; - } - - /** - * Optional attributes for {@link org.tensorflow.op.data.DataServiceDataset} - */ - public static class Options { - private Long taskRefreshIntervalHintMs; - - private String dataTransferProtocol; - - private Options() { - } - - /** - * Sets the taskRefreshIntervalHintMs option. - * - * @param taskRefreshIntervalHintMs the taskRefreshIntervalHintMs option - * @return this Options instance. - */ - public Options taskRefreshIntervalHintMs(Long taskRefreshIntervalHintMs) { - this.taskRefreshIntervalHintMs = taskRefreshIntervalHintMs; - return this; - } - - /** - * Sets the dataTransferProtocol option. - * - * @param dataTransferProtocol the dataTransferProtocol option - * @return this Options instance. - */ - public Options dataTransferProtocol(String dataTransferProtocol) { - this.dataTransferProtocol = dataTransferProtocol; - return this; - } - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/rawops/DataServiceDatasetV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DataServiceDatasetV2.java similarity index 96% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/rawops/DataServiceDatasetV2.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DataServiceDatasetV2.java index 96059a264e4..69c97855be0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/rawops/DataServiceDatasetV2.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DataServiceDatasetV2.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.rawops; +package org.tensorflow.op.data; import java.util.List; import org.tensorflow.Operand; @@ -27,6 +27,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; @@ -34,6 +35,9 @@ /** * Creates a dataset that reads data from the tf.data service. */ +@Operator( + group = "data" +) public final class DataServiceDatasetV2 extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine @@ -141,7 +145,7 @@ public Output<TType> asOutput() { } /** - * Optional attributes for {@link org.tensorflow.op.rawops.DataServiceDatasetV2} + * Optional attributes for {@link org.tensorflow.op.data.DataServiceDatasetV2} */ public static class Options { private Long taskRefreshIntervalHintMs; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetCardinality.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetCardinality.java index 2d1b8a24ba5..8ec90a071c5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetCardinality.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetCardinality.java @@ -24,6 +24,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -31,6 +32,9 @@ * Returns the cardinality of {@code input_dataset}. * Returns the cardinality of {@code input_dataset}. */ +@Operator( + group = "data" +) public final class DatasetCardinality extends RawOp implements Operand<TInt64> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetFromGraph.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetFromGraph.java index 0064b4a1efa..d54a9ce08de 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetFromGraph.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetFromGraph.java @@ -24,6 +24,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; @@ -31,6 +32,9 @@ * Creates a dataset from the given {@code graph_def}. * Creates a dataset from the provided {@code graph_def}. */ +@Operator( + group = "data" +) public final class DatasetFromGraph extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToGraph.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToGraph.java index 1f550abaf7d..d799766128b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToGraph.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToGraph.java @@ -24,6 +24,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; @@ -31,6 +32,9 @@ * Returns a serialized GraphDef representing {@code input_dataset}. * Returns a graph representation for {@code input_dataset}. */ +@Operator( + group = "data" +) public final class DatasetToGraph extends RawOp implements Operand<TString> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToSingleElement.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToSingleElement.java index dffd7bebf02..9b302486b3b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToSingleElement.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToSingleElement.java @@ -29,11 +29,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * Outputs the single element from the given dataset. */ +@Operator( + group = "data" +) public final class DatasetToSingleElement extends RawOp implements Iterable<Operand<TType>> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToTfRecord.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToTfRecord.java index 4ccdc21bafa..9e8b386200f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToTfRecord.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToTfRecord.java @@ -23,12 +23,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; /** * Writes the given dataset to the given file using the TFRecord format. */ +@Operator( + group = "data" +) public final class DatasetToTfRecord extends RawOp { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DenseToSparseBatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DenseToSparseBatchDataset.java index 07d490b1008..ecea765507d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DenseToSparseBatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DenseToSparseBatchDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; /** * Creates a dataset that batches input elements into a SparseTensor. */ +@Operator( + group = "data" +) public final class DenseToSparseBatchDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DirectedInterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DirectedInterleaveDataset.java index 5f4f262a18c..58637eb9f7f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DirectedInterleaveDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DirectedInterleaveDataset.java @@ -27,11 +27,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * A substitute for {@code InterleaveDataset} on a fixed list of {@code N} datasets. */ +@Operator( + group = "data" +) public final class DirectedInterleaveDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterByLastComponentDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterByLastComponentDataset.java index 22ef6836c25..0dfdd3eb57d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterByLastComponentDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterByLastComponentDataset.java @@ -27,11 +27,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * Creates a dataset containing elements of first component of {@code input_dataset} having true in the last component. */ +@Operator( + group = "data" +) public final class FilterByLastComponentDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/rawops/FinalizeDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FinalizeDataset.java similarity index 95% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/rawops/FinalizeDataset.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FinalizeDataset.java index 6b37412d717..a56b7d175bf 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/rawops/FinalizeDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FinalizeDataset.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.rawops; +package org.tensorflow.op.data; import java.util.List; import org.tensorflow.Operand; @@ -27,11 +27,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * Creates a dataset by applying {@code tf.data.Options} to {@code input_dataset}. */ +@Operator( + group = "data" +) public final class FinalizeDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine @@ -107,7 +111,7 @@ public Output<TType> asOutput() { } /** - * Optional attributes for {@link org.tensorflow.op.rawops.FinalizeDataset} + * Optional attributes for {@link org.tensorflow.op.data.FinalizeDataset} */ public static class Options { private Boolean hasCapturedRef; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FixedLengthRecordDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FixedLengthRecordDataset.java index 6c5f4ec9fec..adf6f2a31d4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FixedLengthRecordDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FixedLengthRecordDataset.java @@ -24,6 +24,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; @@ -31,6 +32,9 @@ /** * The FixedLengthRecordDatasetV2 operation */ +@Operator( + group = "data" +) public final class FixedLengthRecordDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GeneratorDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GeneratorDataset.java index d1f47611ba2..0887f089adb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GeneratorDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GeneratorDataset.java @@ -28,11 +28,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * Creates a dataset that invokes a function to generate elements. */ +@Operator( + group = "data" +) public final class GeneratorDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByReducerDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByReducerDataset.java index cff25af74b0..e9c2828179e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByReducerDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByReducerDataset.java @@ -28,12 +28,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * Creates a dataset that computes a group-by on {@code input_dataset}. * Creates a dataset that computes a group-by on {@code input_dataset}. */ +@Operator( + group = "data" +) public final class GroupByReducerDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByWindowDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByWindowDataset.java index 11478287d00..f84ec83d20c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByWindowDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/GroupByWindowDataset.java @@ -28,12 +28,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * Creates a dataset that computes a windowed group-by on {@code input_dataset}. * // TODO(mrry): Support non-int64 keys. */ +@Operator( + group = "data" +) public final class GroupByWindowDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IgnoreErrorsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IgnoreErrorsDataset.java index 152d5696a91..5eb18474f4f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IgnoreErrorsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IgnoreErrorsDataset.java @@ -27,11 +27,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * Creates a dataset that contains the elements of {@code input_dataset} ignoring errors. */ +@Operator( + group = "data" +) public final class IgnoreErrorsDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InitializeTableFromDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InitializeTableFromDataset.java index b57533279cd..7afd248da3d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InitializeTableFromDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InitializeTableFromDataset.java @@ -23,11 +23,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * The InitializeTableFromDataset operation */ +@Operator( + group = "data" +) public final class InitializeTableFromDataset extends RawOp { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LMDBDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LMDBDataset.java index 77d37448de1..9aef6cf746b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LMDBDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LMDBDataset.java @@ -27,6 +27,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; @@ -41,6 +42,9 @@ * <p>LMDB uses different file formats on big- and little-endian machines. * {@code data.LMDBDataset} can only read files in the format of the host machine. */ +@Operator( + group = "data" +) public final class LMDBDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LatencyStatsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LatencyStatsDataset.java index 29256f43c80..2d30dae30b2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LatencyStatsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LatencyStatsDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; /** * Records the latency of producing {@code input_dataset} elements in a StatsAggregator. */ +@Operator( + group = "data" +) public final class LatencyStatsDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LegacyParallelInterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LegacyParallelInterleaveDataset.java index adeb0e4f634..6f769b45b97 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LegacyParallelInterleaveDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LegacyParallelInterleaveDataset.java @@ -28,6 +28,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -40,6 +41,9 @@ * allows the training step to proceed so long as some data is available. * <p>!! WARNING !! This dataset is not deterministic! */ +@Operator( + group = "data" +) public final class LegacyParallelInterleaveDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LoadDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LoadDataset.java index 95ad802e5ac..5897e41b71e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LoadDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LoadDataset.java @@ -28,12 +28,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; /** * The LoadDataset operation */ +@Operator( + group = "data" +) public final class LoadDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MatchingFilesDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MatchingFilesDataset.java index a4fefc83572..d90d48a6011 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MatchingFilesDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MatchingFilesDataset.java @@ -24,12 +24,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; /** * The MatchingFilesDataset operation */ +@Operator( + group = "data" +) public final class MatchingFilesDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MaxIntraOpParallelismDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MaxIntraOpParallelismDataset.java index b63d4d577e3..406c90a988c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MaxIntraOpParallelismDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MaxIntraOpParallelismDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; /** * Creates a dataset that overrides the maximum intra-op parallelism. */ +@Operator( + group = "data" +) public final class MaxIntraOpParallelismDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ModelDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ModelDataset.java index 4133f57bddf..a88c103708f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ModelDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ModelDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * Identity transformation that models performance. * Identity transformation that models performance. */ +@Operator( + group = "data" +) public final class ModelDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/NonSerializableDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/NonSerializableDataset.java index 8acb8a1a9c7..7db69b45f8d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/NonSerializableDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/NonSerializableDataset.java @@ -27,11 +27,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * The NonSerializableDataset operation */ +@Operator( + group = "data" +) public final class NonSerializableDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDataset.java index e4c68e1585b..7c1352a9e00 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDataset.java @@ -28,18 +28,22 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; /** - * Creates a dataset by applying optimizations to {@code input_dataset}. - * Creates a dataset by applying optimizations to {@code input_dataset}. + * Creates a dataset by applying related optimizations to {@code input_dataset}. + * Creates a dataset by applying related optimizations to {@code input_dataset}. */ +@Operator( + group = "data" +) public final class OptimizeDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "OptimizeDataset"; + public static final String OP_NAME = "OptimizeDatasetV2"; private Output<? extends TType> handle; @@ -51,11 +55,13 @@ private OptimizeDataset(Operation operation) { } /** - * Factory method to create a class wrapping a new OptimizeDataset operation. + * Factory method to create a class wrapping a new OptimizeDatasetV2 operation. * * @param scope current scope * @param inputDataset A variant tensor representing the input dataset. - * @param optimizations A {@code tf.string} vector {@code tf.Tensor} identifying optimizations to use. + * @param optimizationsEnabled A {@code tf.string} vector {@code tf.Tensor} identifying user enabled optimizations. + * @param optimizationsDisabled A {@code tf.string} vector {@code tf.Tensor} identifying user disabled optimizations. + * @param optimizationsDefault A {@code tf.string} vector {@code tf.Tensor} identifying optimizations by default. * @param outputTypes the value of the outputTypes property * @param outputShapes the value of the outputShapes property * @param options carries optional attribute values @@ -65,11 +71,14 @@ private OptimizeDataset(Operation operation) { describeByClass = true ) public static OptimizeDataset create(Scope scope, Operand<? extends TType> inputDataset, - Operand<TString> optimizations, List<Class<? extends TType>> outputTypes, + Operand<TString> optimizationsEnabled, Operand<TString> optimizationsDisabled, + Operand<TString> optimizationsDefault, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("OptimizeDataset")); opBuilder.addInput(inputDataset.asOutput()); - opBuilder.addInput(optimizations.asOutput()); + opBuilder.addInput(optimizationsEnabled.asOutput()); + opBuilder.addInput(optimizationsDisabled.asOutput()); + opBuilder.addInput(optimizationsDefault.asOutput()); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDatasetV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDatasetV2.java deleted file mode 100644 index f2a7659cd76..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDatasetV2.java +++ /dev/null @@ -1,165 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.data; - -import java.util.Arrays; -import java.util.List; -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Operands; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.types.TString; -import org.tensorflow.types.family.TType; - -/** - * Creates a dataset by applying related optimizations to {@code input_dataset}. - * Creates a dataset by applying related optimizations to {@code input_dataset}. - */ -public final class OptimizeDatasetV2 extends RawOp implements Operand<TType> { - /** - * The name of this op, as known by TensorFlow core engine - */ - public static final String OP_NAME = "OptimizeDatasetV2"; - - private Output<? extends TType> handle; - - @SuppressWarnings("unchecked") - private OptimizeDatasetV2(Operation operation) { - super(operation); - int outputIdx = 0; - handle = operation.output(outputIdx++); - } - - /** - * Factory method to create a class wrapping a new OptimizeDatasetV2 operation. - * - * @param scope current scope - * @param inputDataset A variant tensor representing the input dataset. - * @param optimizationsEnabled A {@code tf.string} vector {@code tf.Tensor} identifying user enabled optimizations. - * @param optimizationsDisabled A {@code tf.string} vector {@code tf.Tensor} identifying user disabled optimizations. - * @param optimizationsDefault A {@code tf.string} vector {@code tf.Tensor} identifying optimizations by default. - * @param outputTypes the value of the outputTypes property - * @param outputShapes the value of the outputShapes property - * @param options carries optional attribute values - * @return a new instance of OptimizeDatasetV2 - */ - @Endpoint( - describeByClass = true - ) - public static OptimizeDatasetV2 create(Scope scope, Operand<? extends TType> inputDataset, - Operand<TString> optimizationsEnabled, Operand<TString> optimizationsDisabled, - Operand<TString> optimizationsDefault, List<Class<? extends TType>> outputTypes, - List<Shape> outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("OptimizeDatasetV2")); - opBuilder.addInput(inputDataset.asOutput()); - opBuilder.addInput(optimizationsEnabled.asOutput()); - opBuilder.addInput(optimizationsDisabled.asOutput()); - opBuilder.addInput(optimizationsDefault.asOutput()); - opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); - Shape[] outputShapesArray = new Shape[outputShapes.size()]; - for (int i = 0 ; i < outputShapesArray.length ; i++) { - outputShapesArray[i] = outputShapes.get(i); - } - opBuilder.setAttr("output_shapes", outputShapesArray); - if (options != null) { - for (Options opts : options) { - if (opts.optimizationConfigs != null) { - String[] optimizationConfigsArray = new String[opts.optimizationConfigs.size()]; - for (int i = 0 ; i < optimizationConfigsArray.length ; i++) { - optimizationConfigsArray[i] = opts.optimizationConfigs.get(i); - } - opBuilder.setAttr("optimization_configs", optimizationConfigsArray); - } - } - } - return new OptimizeDatasetV2(opBuilder.build()); - } - - /** - * Sets the optimizationConfigs option. - * - * @param optimizationConfigs the optimizationConfigs option - * @return this Options instance. - */ - public static Options optimizationConfigs(List<String> optimizationConfigs) { - return new Options().optimizationConfigs(optimizationConfigs); - } - - /** - * Sets the optimizationConfigs option. - * - * @param optimizationConfigs the optimizationConfigs option - * @return this Options instance. - */ - public static Options optimizationConfigs(String[] optimizationConfigs) { - return new Options().optimizationConfigs(optimizationConfigs); - } - - /** - * Gets handle. - * - * @return handle. - */ - public Output<? extends TType> handle() { - return handle; - } - - @Override - @SuppressWarnings("unchecked") - public Output<TType> asOutput() { - return (Output<TType>) handle; - } - - /** - * Optional attributes for {@link org.tensorflow.op.data.OptimizeDatasetV2} - */ - public static class Options { - private List<String> optimizationConfigs; - - private Options() { - } - - /** - * Sets the optimizationConfigs option. - * - * @param optimizationConfigs the optimizationConfigs option - * @return this Options instance. - */ - public Options optimizationConfigs(List<String> optimizationConfigs) { - this.optimizationConfigs = optimizationConfigs; - return this; - } - - /** - * Sets the optimizationConfigs option. - * - * @param optimizationConfigs the optimizationConfigs option - * @return this Options instance. - */ - public Options optimizationConfigs(String... optimizationConfigs) { - this.optimizationConfigs = Arrays.asList(optimizationConfigs); - return this; - } - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/rawops/OptionsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptionsDataset.java similarity index 96% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/rawops/OptionsDataset.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptionsDataset.java index ec3bf2560c2..2ff6923fc71 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/rawops/OptionsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptionsDataset.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.rawops; +package org.tensorflow.op.data; import java.util.List; import org.tensorflow.Operand; @@ -27,11 +27,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * Creates a dataset by attaching tf.data.Options to {@code input_dataset}. */ +@Operator( + group = "data" +) public final class OptionsDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PaddedBatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PaddedBatchDataset.java index 5e74c59c2dc..058015155ea 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PaddedBatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PaddedBatchDataset.java @@ -27,6 +27,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -34,6 +35,9 @@ /** * Creates a dataset that batches and pads {@code batch_size} elements from the input. */ +@Operator( + group = "data" +) public final class PaddedBatchDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/rawops/ParallelBatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelBatchDataset.java similarity index 95% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/rawops/ParallelBatchDataset.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelBatchDataset.java index 23fe867b5ae..f313afbf0b8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/rawops/ParallelBatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelBatchDataset.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.rawops; +package org.tensorflow.op.data; import java.util.List; import org.tensorflow.Operand; @@ -27,6 +27,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -34,6 +35,9 @@ /** * The ParallelBatchDataset operation */ +@Operator( + group = "data" +) public final class ParallelBatchDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine @@ -116,7 +120,7 @@ public Output<TType> asOutput() { } /** - * Optional attributes for {@link org.tensorflow.op.rawops.ParallelBatchDataset} + * Optional attributes for {@link org.tensorflow.op.data.ParallelBatchDataset} */ public static class Options { private String deterministic; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelInterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelInterleaveDataset.java index 59571258049..bf87079b34c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelInterleaveDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelInterleaveDataset.java @@ -28,6 +28,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -45,6 +46,9 @@ * {@code experimental_deterministic} parameter of {@code tf.data.Options} to {@code False}. * This can improve performance at the expense of non-determinism. */ +@Operator( + group = "data" +) public final class ParallelInterleaveDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelMapDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelMapDataset.java index ee9d6ec2c85..8d3c1686b5f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelMapDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParallelMapDataset.java @@ -28,6 +28,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -36,6 +37,9 @@ * Unlike a "MapDataset", which applies {@code f} sequentially, this dataset invokes up * to {@code num_parallel_calls} copies of {@code f} in parallel. */ +@Operator( + group = "data" +) public final class ParallelMapDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParseExampleDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParseExampleDataset.java index 35c277f46a9..6389e91a952 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParseExampleDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ParseExampleDataset.java @@ -28,6 +28,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; @@ -35,6 +36,9 @@ /** * Transforms {@code input_dataset} containing {@code Example} protos as vectors of DT_STRING into a dataset of {@code Tensor} or {@code SparseTensor} objects representing the parsed features. */ +@Operator( + group = "data" +) public final class ParseExampleDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrefetchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrefetchDataset.java index 8ef26922652..92d74b4742a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrefetchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrefetchDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; /** * Creates a dataset that asynchronously prefetches elements from {@code input_dataset}. */ +@Operator( + group = "data" +) public final class PrefetchDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrivateThreadPoolDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrivateThreadPoolDataset.java index 3290840da75..e5e387da993 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrivateThreadPoolDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrivateThreadPoolDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; /** * Creates a dataset that uses a custom thread pool to compute {@code input_dataset}. */ +@Operator( + group = "data" +) public final class PrivateThreadPoolDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RandomDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RandomDataset.java index 788ff579222..ef1b7a72431 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RandomDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RandomDataset.java @@ -27,6 +27,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -41,6 +42,9 @@ * performed is determined by the {@code experimental_optimization.hoist_random_uniform} * option of {@code tf.data.Options}. */ +@Operator( + group = "data" +) public final class RandomDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDataset.java deleted file mode 100644 index 874587c68e7..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDataset.java +++ /dev/null @@ -1,137 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.data; - -import java.util.List; -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Operands; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TType; - -/** - * Creates a dataset that changes the batch size. - * Creates a dataset that changes the batch size of the dataset to current batch - * size // num_workers. - */ -public final class RebatchDataset extends RawOp implements Operand<TType> { - /** - * The name of this op, as known by TensorFlow core engine - */ - public static final String OP_NAME = "RebatchDataset"; - - private Output<? extends TType> handle; - - @SuppressWarnings("unchecked") - private RebatchDataset(Operation operation) { - super(operation); - int outputIdx = 0; - handle = operation.output(outputIdx++); - } - - /** - * Factory method to create a class wrapping a new RebatchDataset operation. - * - * @param scope current scope - * @param inputDataset A variant tensor representing the input dataset. - * @param numReplicas A scalar representing the number of replicas to distribute this batch across. As - * a result of this transformation the current batch size would end up being - * divided by this parameter. - * @param outputTypes the value of the outputTypes property - * @param outputShapes the value of the outputShapes property - * @param options carries optional attribute values - * @return a new instance of RebatchDataset - */ - @Endpoint( - describeByClass = true - ) - public static RebatchDataset create(Scope scope, Operand<? extends TType> inputDataset, - Operand<TInt64> numReplicas, List<Class<? extends TType>> outputTypes, - List<Shape> outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("RebatchDataset")); - opBuilder.addInput(inputDataset.asOutput()); - opBuilder.addInput(numReplicas.asOutput()); - opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); - Shape[] outputShapesArray = new Shape[outputShapes.size()]; - for (int i = 0 ; i < outputShapesArray.length ; i++) { - outputShapesArray[i] = outputShapes.get(i); - } - opBuilder.setAttr("output_shapes", outputShapesArray); - if (options != null) { - for (Options opts : options) { - if (opts.useFallback != null) { - opBuilder.setAttr("use_fallback", opts.useFallback); - } - } - } - return new RebatchDataset(opBuilder.build()); - } - - /** - * Sets the useFallback option. - * - * @param useFallback the useFallback option - * @return this Options instance. - */ - public static Options useFallback(Boolean useFallback) { - return new Options().useFallback(useFallback); - } - - /** - * Gets handle. - * - * @return handle. - */ - public Output<? extends TType> handle() { - return handle; - } - - @Override - @SuppressWarnings("unchecked") - public Output<TType> asOutput() { - return (Output<TType>) handle; - } - - /** - * Optional attributes for {@link org.tensorflow.op.data.RebatchDataset} - */ - public static class Options { - private Boolean useFallback; - - private Options() { - } - - /** - * Sets the useFallback option. - * - * @param useFallback the useFallback option - * @return this Options instance. - */ - public Options useFallback(Boolean useFallback) { - this.useFallback = useFallback; - return this; - } - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDatasetV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDatasetV2.java index b98013d55bc..2c50b6cf519 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDatasetV2.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDatasetV2.java @@ -27,6 +27,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -36,6 +37,9 @@ * Creates a dataset that rebatches elements from {@code input_dataset} into new batch * sizes. */ +@Operator( + group = "data" +) public final class RebatchDatasetV2 extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ReduceDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ReduceDataset.java index d2c4ffd3aae..143ec183fee 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ReduceDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ReduceDataset.java @@ -30,11 +30,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * Reduces the input dataset to a singleton using a reduce function. */ +@Operator( + group = "data" +) public final class ReduceDataset extends RawOp implements Iterable<Operand<TType>> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RegisterDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RegisterDataset.java index 5809a919ebd..e389557f99f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RegisterDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RegisterDataset.java @@ -24,6 +24,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; @@ -31,6 +32,9 @@ /** * Registers a dataset with the tf.data service. */ +@Operator( + group = "data" +) public final class RegisterDataset extends RawOp implements Operand<TInt64> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SamplingDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SamplingDataset.java index 27778a1dd49..550df76210e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SamplingDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SamplingDataset.java @@ -27,6 +27,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -39,6 +40,9 @@ * {@code experimental_optimization.filter_with_random_uniform_fusion} option of * {@code tf.data.Options}. */ +@Operator( + group = "data" +) public final class SamplingDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SaveDataset.java index 4e5944ac7d9..74a20665508 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SaveDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SaveDataset.java @@ -25,12 +25,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; /** * The SaveDataset operation */ +@Operator( + group = "data" +) public final class SaveDataset extends RawOp { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ScanDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ScanDataset.java index dde6aea2ca5..ab71a2a7da3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ScanDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ScanDataset.java @@ -28,11 +28,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * Creates a dataset successively reduces {@code f} over the elements of {@code input_dataset}. */ +@Operator( + group = "data" +) public final class ScanDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SetStatsAggregatorDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SetStatsAggregatorDataset.java index 63bce56359c..fd935d5862f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SetStatsAggregatorDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SetStatsAggregatorDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; /** * The SetStatsAggregatorDataset operation */ +@Operator( + group = "data" +) public final class SetStatsAggregatorDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShardDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShardDataset.java index ea9b7059eea..1609721dbde 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShardDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShardDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; /** * Creates a {@code Dataset} that includes only 1/{@code num_shards} of this dataset. */ +@Operator( + group = "data" +) public final class ShardDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java index 1a7d3667632..c57edd1bb10 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; /** * The ShuffleAndRepeatDatasetV2 operation */ +@Operator( + group = "data" +) public final class ShuffleAndRepeatDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java index 4ae9f5a57cf..4bdbd0abc2c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; /** * The ShuffleDatasetV3 operation */ +@Operator( + group = "data" +) public final class ShuffleDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SleepDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SleepDataset.java index bc3b16e01a2..74992fd71fa 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SleepDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SleepDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; /** * The SleepDataset operation */ +@Operator( + group = "data" +) public final class SleepDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SlidingWindowDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SlidingWindowDataset.java index 123a5780271..4a5d1629816 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SlidingWindowDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SlidingWindowDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; /** * Creates a dataset that passes a sliding window over {@code input_dataset}. */ +@Operator( + group = "data" +) public final class SlidingWindowDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java index 3fc7bd3ec93..a1cbcd27635 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java @@ -28,6 +28,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; @@ -38,6 +39,9 @@ * If not, it will run the preprocessing pipeline as usual, and write out a * snapshot of the data processed for future use. */ +@Operator( + group = "data" +) public final class SnapshotDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SparseTensorSliceDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SparseTensorSliceDataset.java index 188bf6f6a8b..6de2eb40539 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SparseTensorSliceDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SparseTensorSliceDataset.java @@ -24,12 +24,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; /** * Creates a dataset that splits a SparseTensor into elements row-wise. */ +@Operator( + group = "data" +) public final class SparseTensorSliceDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SqlDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SqlDataset.java index 7240a21d3b4..5fd6a3d5784 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SqlDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SqlDataset.java @@ -27,12 +27,16 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; /** * Creates a dataset that executes a SQL query and emits rows of the result set. */ +@Operator( + group = "data" +) public final class SqlDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeWhileDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeWhileDataset.java index 3bcf539bd80..5f916997113 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeWhileDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeWhileDataset.java @@ -28,6 +28,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** @@ -39,6 +40,9 @@ * <li>One tensor for each value in {@code other_arguments}.</li> * </ul> */ +@Operator( + group = "data" +) public final class TakeWhileDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TensorDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TensorDataset.java index 1c93d86e475..a089e8e1be5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TensorDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TensorDataset.java @@ -27,11 +27,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * Creates a dataset that emits {@code components} as a tuple of tensors once. */ +@Operator( + group = "data" +) public final class TensorDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ThreadPoolDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ThreadPoolDataset.java index 9f67fdf0a66..17a7c0c9065 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ThreadPoolDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ThreadPoolDataset.java @@ -27,11 +27,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * Creates a dataset that uses a custom thread pool to compute {@code input_dataset}. */ +@Operator( + group = "data" +) public final class ThreadPoolDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnbatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnbatchDataset.java index a827341e997..2fd64618e8e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnbatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnbatchDataset.java @@ -27,11 +27,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * A dataset that splits the elements of its input into multiple elements. */ +@Operator( + group = "data" +) public final class UnbatchDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UniqueDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UniqueDataset.java index 79e21583277..6938c786469 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UniqueDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UniqueDataset.java @@ -27,11 +27,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * Creates a dataset that contains the unique elements of {@code input_dataset}. */ +@Operator( + group = "data" +) public final class UniqueDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnwrapDatasetVariant.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnwrapDatasetVariant.java index df30e1749d9..2bbc602b5fc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnwrapDatasetVariant.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnwrapDatasetVariant.java @@ -24,11 +24,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * The UnwrapDatasetVariant operation */ +@Operator( + group = "data" +) public final class UnwrapDatasetVariant extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WindowDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WindowDataset.java index 3e2069a2a93..3988ddcd99d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WindowDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WindowDataset.java @@ -27,6 +27,7 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -69,6 +70,9 @@ * produces {@code {{"a": {0, 1}}, {"a": {2, 3}}}}</li> * </ul> */ +@Operator( + group = "data" +) public final class WindowDataset extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WrapDatasetVariant.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WrapDatasetVariant.java index 8f6de33fda7..aec61c37b21 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WrapDatasetVariant.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WrapDatasetVariant.java @@ -24,11 +24,15 @@ import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** * The WrapDatasetVariant operation */ +@Operator( + group = "data" +) public final class WrapDatasetVariant extends RawOp implements Operand<TType> { /** * The name of this op, as known by TensorFlow core engine diff --git a/tensorflow-core/tensorflow-core-api/src/gen/resources/ops.pb b/tensorflow-core/tensorflow-core-api/src/gen/resources/ops.pb index 14aed40998768ed31aaf82bf24803c9751f126bf..40322dc87805dc5e0fa65f0e0cfc1cbd1fb53a92 100644 GIT binary patch delta 1317 zcmYk6drVtZ7{GH*&%M1px4pf$y_ZLMi~)NjI#wqmf)b_{jG1g?O-Lkjq(VFD0L`cn z5p`%x&CtPE*a;ySV}Hn^Xn;jelo?$#i;7FwA024sfF&w(=v<a1%0v8aEswu`$@l)g z^L>YnRNWq~8oqNLp$mVVZ$N-U9|Rw?Ax|<jY6Fn~FBuZvN9?ObfxOi`#ha&4_-%yv z%yT&3YF!X0ebQgaE-kx@+4T`&LjTngm-xW5DHu_)$4)7KJcMHiUE;}gcpGx{5^4vh zgh}Nt74b(x*g*X73;0I5bW;Fzpr(##^VW2B)<GPsIf^TiW=%Jc2iIdO)Xrz4cppZ8 zGVgKRrT>hmveCv2&ZeTcfZIE8yBq30aR#5k2t{`cV7_C(vb_p@!z&jpK9s&fsLKrP z4<2m}*KEk5Qx_jwvV<_6lzL8*Bwn~;sYK{jPZvpnIi=CmT5$X10?DvIf~yuX22<qk z38air<T0q*ROhjPmBjccCEcdB)J}R$PCJOd6M|&XOw)>=91}V)zuZrhFccY+1Rj&b z>=0P~-6(bwG|f7i#8VmLm^#bCje+`b>sq16lSK7+km86qWCxeukBM(y0XFQjRx7|O z2d(egfK?&u{XM{nLF<GUxOvKY>>1#>TuDy@@)D`d4W)0Tx&N20`K3+#pkF$d1=^J+ zDR~UEiN5m~+G*3-hqJ=N{P0fO<XJE&nXvWL0%Q5~-f`e}Us9BWM%Z|aj;8)=q0ap; zMZygmYR4XpI(gqUI)~(OiE0Tw)3#`tY-oy(Cr^rN($$XRkK28X9^s2W(p#eZF`<g- z_+n$zsk<=`8*UPvEM=Bf=k{6FG(MQJxtQw@lW<`(KJ8i79}(T$c9Xv9g$`3`a=d?D zt;pLk|5TTk9e9GUt{gduT`x4;>_)acjPA3}9680kK9P#R(LjBDu>J{FU1U7Kf6I}V z&2Wh~&SL|=2&ud<U$!BxR?7J*LW89tc{2>EEGmz_2K<NEu?H4lDbvoEG99;JXZYNR z!@FPX7O8R|$R-MnO{^FxZE1I*>jS~(1Cg4#)!wj2jmy}qY&*nz7lc;CM=Z)NB*d?p zT}6tUFB0W-$jn;}WoZ+nYT2#umfg-j7C;vCs;`K^=>oMtgY9%t)f@z$Sr)b;?t5DO z`Z8!GRU0J0VO{&lT&s>utp(YZZB+5sy&MX$haS~4*h-4F<+N9Iu&OLAoN`f0DC<`r z5q=0Lmd|8qpPL!`wrLsWnSNHJeQRFN`jQb*=lvzxM>5n}zE|VrdtC$VV0q^U*ElR~ zbZic;bk1Gs2N`-y&oY_asRxe$-&xSVF9G_CjC(%dQ*Rh2iZ@1Yzi8;JZQdy7d%iH7 SIbc#TW2oj01?G){q5lAj#iSts delta 1323 zcmYk5e@s(X6vuP#yYIEVk3RbP3re9-Kw+5!Mu2QA6Be0sV^LfaTpa2uwu6mdf$oRf zj8&K5@CSy3-RuWy|AH8jfo9w#;)G~a*vOcRC?rnE$cAH!E@Zkn96PTqmA^iD=iK}K zo^$Sb^L<-c|JJ@cJxDlnx2GNfwx9B?LBfHH6a3-?a<(FYPjABoI^B{;={ip(Q#18; zcCuf*Z<-gvkuq<<+u{qD1YTZ?ZHR4n1CQt~NHk76{X|f5u-=E3)$COtZe}AjxaqOa zNjMk1SD9}Tk}Oo+1u{8sB^FM%vykSs-H(6cufK_ZN=Co3l2#lv|Dp+at808QYj&j# z7jb+LyRqRerb;-=Dqd7m*dZfP*p@cD&3y|+6EO>LwoyyO8T8m^tJKH&e{Fc51=<cC z$Hy@eT2n)qr-p>=t!RXA_*W=H?vF%j)-*i6=67nVA61d$pb23y-m@t9F`DG@OF})u z^CB}X33emcrBG+H_ua;(mYmKYv3_-i=moGE6v$baJ%3Lm>k06@fi!_N&x<3UYVlDc z@w5Y7T_p4tlr*Fl`H)^p_e1RG2C0L{GXjlYU)$iVt9_ph=gUrZ><7`vdRip!sxcf( zttd5-+M|Vpud(1Ht>Om*(o1om`D0MpKL)JasN10c2dZ?34ZwAN-Cw1^r@M9Ic|dwY z7kn9bC{u5B0#BCcKh#E-;?a+$0q0uu31^WokZZ^a!Dsfkp>r4Tn*zGfM#IAhkBjBb zigva*D+c-IA-dBh4~R5=Q&XU}KGIDxo4Q61yAF_8bVVa}YK_c%?*yGgXpGg|pbOcP z5>=xL%V)!j25pB=FXz}8QL6aONqSpCeeCiy?O5F<P(@2OHnfy2uXa)bwMEAjO|jg7 zRL5`6&=2yUli0(TC~`LDpIGIGMtqg9rVQE6_IyDy+4t%433Q(|X2^;1iZ$TJGUWUx zK_QL0BD3h<tatE)0@;B0c&S{0_}VS<EIdMYhF^BW@G@KF(|dr|NgPgsU6^vk@hO*a zz5`;H`i*&;2$QJ7D*h3>67LbJyyDyG4OG`WY6Ui)4HLz}&mm<=+p=>+=`;a7Hf1pt z6e|i9UQrnT#{-D1%`@$mfL|Axid5j87fgR%2F}E*-P)#pY*nWWP`-;MN6J6pP@iJm zS!yx6<x&^%e#&m9s^`tada8J8-uC<JmY;*w!y$g%p<Y6)_ixd{2h-Klnq*^^nxZYy zkz(}<0cj_7M5IegR4ZRYq#c-FeNNp73)$)}SGl{~9NGuxFMMbI7M8HJea_6==Pb{A zAa?B;YnsM;N3A~iad_1O>u3outJt<s26XPV9eoyfYS6Yj1H7!%9P#>cOw8f0P1`1A L$QJ%?E9(6pPkyWj