diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/java_defs.h b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/java_defs.h index e41dc2dd9df..6028c4ea71d 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/java_defs.h +++ b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/java_defs.h @@ -97,9 +97,6 @@ class Type { static Type IterableOf(const Type& type) { return Interface("Iterable").add_parameter(type); } - static Type DataTypeOf(const Type& type) { - return Class("DataType", "org.tensorflow").add_parameter(type); - } static Type ForDataType(DataType data_type) { switch (data_type) { case DataType::DT_BOOL: diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc index 14d33c58fc7..16744db3799 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc +++ b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc @@ -103,13 +103,22 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode, } for (const AttributeSpec& attribute : op.attributes()) { out->push_back(attribute.var().type()); - out->push_back(attribute.jni_type()); + if (attribute.jni_type().name() == "DataType") { + out->push_back(Type::Class("Operands", "org.tensorflow.op")); + } else { + out->push_back(attribute.jni_type()); + } if (attribute.has_default_value() && attribute.type().kind() == Type::GENERIC) { out->push_back(Type::ForDataType(attribute.default_value()->type())); } } for (const AttributeSpec& optional_attribute : op.optional_attributes()) { + if (optional_attribute.jni_type().name() == "DataType") { + out->push_back(Type::Class("Operands", "org.tensorflow.op")); + } else { + out->push_back(optional_attribute.jni_type()); + } out->push_back(optional_attribute.var().type()); } } @@ -117,25 +126,32 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode, void WriteSetAttrDirective(const AttributeSpec& attr, bool optional, SourceWriter* writer) { string var_name = optional ? "opts." + attr.var().name() : attr.var().name(); - if (attr.iterable()) { - string array_name = attr.var().name() + "Array"; - writer->AppendType(attr.jni_type()) - .Append("[] " + array_name + " = new ") - .AppendType(attr.jni_type()) - .Append("[" + var_name + ".size()];") - .EndLine() - .BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)") - .Append(array_name + "[i] = "); - writer->Append(var_name + ".get(i);"); - writer->EndLine() - .EndBlock() - .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ") - .Append(array_name + ");") - .EndLine(); - } else { + if (attr.jni_type().name() == "DataType") { writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ") - .Append(var_name + ");") - .EndLine(); + .Append(attr.iterable() ? "Operands.toDataTypes(" : "Operands.toDataType(") + .Append(attr.var().name() + "));") + .EndLine(); + } else { + if (attr.iterable()) { + string array_name = attr.var().name() + "Array"; + writer->AppendType(attr.jni_type()) + .Append("[] " + array_name + " = new ") + .AppendType(attr.jni_type()) + .Append("[" + var_name + ".size()];") + .EndLine() + .BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)") + .Append(array_name + "[i] = "); + writer->Append(var_name + ".get(i);"); + writer->EndLine() + .EndBlock() + .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ") + .Append(array_name + ");") + .EndLine(); + } else { + writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ") + .Append(var_name + ");") + .EndLine(); + } } } @@ -177,7 +193,7 @@ void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class, if (attr.type().kind() == Type::GENERIC && default_types.find(attr.type().name()) != default_types.end()) { factory_statement << default_types.at(attr.type().name()).name() - << ".DTYPE"; + << ".class"; } else { AddArgument(attr.var(), attr.description(), &factory, &factory_doc); factory_statement << attr.var().name(); diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc index 7d184bf2a46..56de15d9ab7 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc +++ b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc @@ -81,13 +81,19 @@ class TypeResolver { std::pair MakeTypePair(const Type& type) { return std::make_pair(type, type); } - Type NextGeneric() { + Type NextGeneric(const OpDef_AttrDef& attr_def) { char generic_letter = next_generic_letter_++; if (next_generic_letter_ > 'Z') { next_generic_letter_ = 'A'; } - return Type::Generic(string(1, generic_letter)) - .add_supertype(Type::Class("TType", "org.tensorflow.types.family")); + return Type::Generic(string(1, generic_letter)); + } + Type TypeFamilyOf(const OpDef_AttrDef& attr_def) { + // TODO(karllessard) support more type families + if (IsRealNumbers(attr_def.allowed_values())) { + return Type::Interface("TNumber", "org.tensorflow.types.family"); + } + return Type::Interface("TType", "org.tensorflow.types.family"); } }; @@ -155,11 +161,9 @@ std::pair TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")); } else if (attr_type == "type") { - Type type = *iterable_out ? Type::Wildcard() : NextGeneric(); - if (IsRealNumbers(attr_def.allowed_values())) { - type.add_supertype(Type::Class("TNumber", "org.tensorflow.types.family")); - } - types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow")); + Type type = *iterable_out ? Type::Wildcard() : NextGeneric(attr_def); + type.add_supertype(TypeFamilyOf(attr_def)); + types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow.proto.framework")); } else { LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type @@ -305,7 +309,7 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, bool iterable = false; std::pair types = type_resolver->TypesOf(attr_def, &iterable); Type var_type = types.first.kind() == Type::GENERIC - ? Type::DataTypeOf(types.first) + ? Type::ClassOf(types.first) : types.first; if (iterable) { var_type = Type::ListOf(var_type); diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.cc b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.cc index 8598b1d945d..37315f0dff3 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.cc +++ b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.cc @@ -85,6 +85,7 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) { SourceWriter& SourceWriter::AppendType(const Type& type) { if (type.wildcard()) { Append("?"); + WriteTypeBounds(type.supertypes()); } else { Append(type.name()); if (!type.parameters().empty()) { @@ -321,14 +322,27 @@ SourceWriter& SourceWriter::WriteGenerics( Append(", "); } Append(pt->name()); - if (!pt->supertypes().empty()) { - Append(" extends ").AppendType(pt->supertypes().front()); - } + WriteTypeBounds(pt->supertypes()); first = false; } return Append(">"); } +SourceWriter& SourceWriter::WriteTypeBounds( + const std::list& bounds) { + bool first = true; + for (const Type& bound : bounds) { + if (first) { + Append(" extends "); + first = false; + } else { + Append(" & "); + } + AppendType(bound); + } + return *this; +} + SourceWriter::GenericNamespace* SourceWriter::PushGenericNamespace( int modifiers) { GenericNamespace* generic_namespace; diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.h b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.h index 097887083e7..26b97f7a9c4 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.h +++ b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.h @@ -213,6 +213,7 @@ class SourceWriter { SourceWriter& WriteJavadoc(const Javadoc& javadoc); SourceWriter& WriteAnnotations(const std::list& annotations); SourceWriter& WriteGenerics(const std::list& generics); + SourceWriter& WriteTypeBounds(const std::list& bounds); GenericNamespace* PushGenericNamespace(int modifiers); void PopGenericNamespace(); }; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataExperimentalOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataExperimentalOps.java index f78720827b0..ab4089a045d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataExperimentalOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataExperimentalOps.java @@ -18,12 +18,12 @@ package org.tensorflow.op; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.data.experimental.DataServiceDataset; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * An API for building {@code data.experimental} operations as {@link Op Op}s @@ -57,7 +57,7 @@ public final class DataExperimentalOps { public DataServiceDataset dataServiceDataset(Operand datasetId, Operand processingMode, Operand address, Operand protocol, Operand jobName, Operand maxOutstandingRequests, Operand iterationCounter, - List> outputTypes, List outputShapes, + List> outputTypes, List outputShapes, DataServiceDataset.Options... options) { return DataServiceDataset.create(scope, datasetId, processingMode, address, protocol, jobName, maxOutstandingRequests, iterationCounter, outputTypes, outputShapes, options); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java index b3207a5d4e7..f5f3e7ebf86 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java @@ -18,7 +18,6 @@ package org.tensorflow.op; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.data.AnonymousIterator; @@ -49,6 +48,7 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * An API for building {@code data} operations as {@link Op Op}s @@ -75,7 +75,7 @@ public final class DataOps { * @param outputShapes * @return a new instance of AnonymousIterator */ - public AnonymousIterator anonymousIterator(List> outputTypes, + public AnonymousIterator anonymousIterator(List> outputTypes, List outputShapes) { return AnonymousIterator.create(scope, outputTypes, outputShapes); } @@ -93,8 +93,8 @@ public AnonymousIterator anonymousIterator(List> outputTypes, * @return a new instance of BatchDataset */ public BatchDataset batchDataset(Operand inputDataset, Operand batchSize, - Operand dropRemainder, List> outputTypes, List outputShapes, - BatchDataset.Options... options) { + Operand dropRemainder, List> outputTypes, + List outputShapes, BatchDataset.Options... options) { return BatchDataset.create(scope, inputDataset, batchSize, dropRemainder, outputTypes, outputShapes, options); } @@ -129,7 +129,7 @@ public CSVDataset cSVDataset(Operand filenames, Operand compre * @return a new instance of ConcatenateDataset */ public ConcatenateDataset concatenateDataset(Operand inputDataset, Operand anotherDataset, - List> outputTypes, List outputShapes) { + List> outputTypes, List outputShapes) { return ConcatenateDataset.create(scope, inputDataset, anotherDataset, outputTypes, outputShapes); } @@ -164,8 +164,8 @@ public DeserializeIterator deserializeIterator(Operand resourceHandle, Operan * @param outputShapes * @return a new instance of Iterator */ - public Iterator iterator(String sharedName, String container, List> outputTypes, - List outputShapes) { + public Iterator iterator(String sharedName, String container, + List> outputTypes, List outputShapes) { return Iterator.create(scope, sharedName, container, outputTypes, outputShapes); } @@ -177,8 +177,8 @@ public Iterator iterator(String sharedName, String container, List> * @param outputShapes * @return a new instance of IteratorGetNext */ - public IteratorGetNext iteratorGetNext(Operand iterator, List> outputTypes, - List outputShapes) { + public IteratorGetNext iteratorGetNext(Operand iterator, + List> outputTypes, List outputShapes) { return IteratorGetNext.create(scope, iterator, outputTypes, outputShapes); } @@ -191,7 +191,7 @@ public IteratorGetNext iteratorGetNext(Operand iterator, List> ou * @return a new instance of IteratorGetNextAsOptional */ public IteratorGetNextAsOptional iteratorGetNextAsOptional(Operand iterator, - List> outputTypes, List outputShapes) { + List> outputTypes, List outputShapes) { return IteratorGetNextAsOptional.create(scope, iterator, outputTypes, outputShapes); } @@ -208,8 +208,8 @@ public IteratorGetNextAsOptional iteratorGetNextAsOptional(Operand iterator, * @param outputShapes * @return a new instance of IteratorGetNextSync */ - public IteratorGetNextSync iteratorGetNextSync(Operand iterator, List> outputTypes, - List outputShapes) { + public IteratorGetNextSync iteratorGetNextSync(Operand iterator, + List> outputTypes, List outputShapes) { return IteratorGetNextSync.create(scope, iterator, outputTypes, outputShapes); } @@ -255,8 +255,8 @@ public OptionalFromValue optionalFromValue(Iterable> components) { * @param outputShapes * @return a new instance of OptionalGetValue */ - public OptionalGetValue optionalGetValue(Operand optional, List> outputTypes, - List outputShapes) { + public OptionalGetValue optionalGetValue(Operand optional, + List> outputTypes, List outputShapes) { return OptionalGetValue.create(scope, optional, outputTypes, outputShapes); } @@ -290,7 +290,7 @@ public OptionalNone optionalNone() { * @return a new instance of RangeDataset */ public RangeDataset rangeDataset(Operand start, Operand stop, - Operand step, List> outputTypes, List outputShapes) { + Operand step, List> outputTypes, List outputShapes) { return RangeDataset.create(scope, start, stop, step, outputTypes, outputShapes); } @@ -305,7 +305,7 @@ public RangeDataset rangeDataset(Operand start, Operand stop, * @return a new instance of RepeatDataset */ public RepeatDataset repeatDataset(Operand inputDataset, Operand count, - List> outputTypes, List outputShapes) { + List> outputTypes, List outputShapes) { return RepeatDataset.create(scope, inputDataset, count, outputTypes, outputShapes); } @@ -332,7 +332,7 @@ public SerializeIterator serializeIterator(Operand resourceHandle, * @return a new instance of SkipDataset */ public SkipDataset skipDataset(Operand inputDataset, Operand count, - List> outputTypes, List outputShapes) { + List> outputTypes, List outputShapes) { return SkipDataset.create(scope, inputDataset, count, outputTypes, outputShapes); } @@ -348,7 +348,7 @@ public SkipDataset skipDataset(Operand inputDataset, Operand count, * @return a new instance of TakeDataset */ public TakeDataset takeDataset(Operand inputDataset, Operand count, - List> outputTypes, List outputShapes) { + List> outputTypes, List outputShapes) { return TakeDataset.create(scope, inputDataset, count, outputTypes, outputShapes); } @@ -409,8 +409,8 @@ public TfRecordDataset tfRecordDataset(Operand filenames, * @param outputShapes * @return a new instance of ZipDataset */ - public ZipDataset zipDataset(Iterable> inputDatasets, List> outputTypes, - List outputShapes) { + public ZipDataset zipDataset(Iterable> inputDatasets, + List> outputTypes, List outputShapes) { return ZipDataset.create(scope, inputDatasets, outputTypes, outputShapes); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DtypesOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DtypesOps.java index 1fb69c6637d..acf6a748b70 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DtypesOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DtypesOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.dtypes.AsString; import org.tensorflow.op.dtypes.Cast; @@ -73,7 +72,7 @@ public AsString asString(Operand input, AsString.Options... * @param options carries optional attributes values * @return a new instance of Cast */ - public Cast cast(Operand x, DataType DstT, + public Cast cast(Operand x, Class DstT, Cast.Options... options) { return Cast.create(scope, x, DstT, options); } @@ -102,7 +101,7 @@ public Cast cast(Operand x, DataType * @return a new instance of Complex */ public Complex complex(Operand real, Operand imag, - DataType Tout) { + Class Tout) { return Complex.create(scope, real, imag, Tout); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ImageOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ImageOps.java index 35791f74ec8..13db1243c8a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ImageOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ImageOps.java @@ -18,7 +18,6 @@ package org.tensorflow.op; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.image.AdjustContrast; import org.tensorflow.op.image.AdjustHue; @@ -272,7 +271,7 @@ public CropAndResizeGradBoxes cropAndResizeGradBoxes(Operand */ public CropAndResizeGradImage cropAndResizeGradImage( Operand grads, Operand boxes, Operand boxInd, - Operand imageSize, DataType T, CropAndResizeGradImage.Options... options) { + Operand imageSize, Class T, CropAndResizeGradImage.Options... options) { return CropAndResizeGradImage.create(scope, grads, boxes, boxInd, imageSize, T, options); } @@ -463,7 +462,7 @@ public DecodePng decodePng(Operand contents, DecodePng.Options. * @param options carries optional attributes values * @return a new instance of DecodePng */ - public DecodePng decodePng(Operand contents, DataType dtype, + public DecodePng decodePng(Operand contents, Class dtype, DecodePng.Options... options) { return DecodePng.create(scope, contents, dtype, options); } @@ -625,7 +624,7 @@ public ExtractJpegShape extractJpegShape(Operand contents) { * @return a new instance of ExtractJpegShape */ public ExtractJpegShape extractJpegShape(Operand contents, - DataType outputType) { + Class outputType) { return ExtractJpegShape.create(scope, contents, outputType); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/IoOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/IoOps.java index 3cb7569ab38..f8d48de3690 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/IoOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/IoOps.java @@ -18,7 +18,6 @@ package org.tensorflow.op; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.io.DecodeBase64; @@ -171,7 +170,7 @@ public DecodeJsonExample decodeJsonExample(Operand jsonExamples) { * @return a new instance of DecodePaddedRaw */ public DecodePaddedRaw decodePaddedRaw(Operand inputBytes, - Operand fixedLength, DataType outType, DecodePaddedRaw.Options... options) { + Operand fixedLength, Class outType, DecodePaddedRaw.Options... options) { return DecodePaddedRaw.create(scope, inputBytes, fixedLength, outType, options); } @@ -184,7 +183,7 @@ public DecodePaddedRaw decodePaddedRaw(Operand i * @param options carries optional attributes values * @return a new instance of DecodeRaw */ - public DecodeRaw decodeRaw(Operand bytes, DataType outType, + public DecodeRaw decodeRaw(Operand bytes, Class outType, DecodeRaw.Options... options) { return DecodeRaw.create(scope, bytes, outType, options); } @@ -241,7 +240,7 @@ public DecodeRaw decodeRaw(Operand bytes, DataType * @return a new instance of DeserializeManySparse */ public DeserializeManySparse deserializeManySparse( - Operand serializedSparse, DataType dtype) { + Operand serializedSparse, Class dtype) { return DeserializeManySparse.create(scope, serializedSparse, dtype); } @@ -270,7 +269,8 @@ public EncodeBase64 encodeBase64(Operand input, EncodeBase64.Options... * @param options carries optional attributes values * @return a new instance of FifoQueue */ - public FifoQueue fifoQueue(List> componentTypes, FifoQueue.Options... options) { + public FifoQueue fifoQueue(List> componentTypes, + FifoQueue.Options... options) { return FifoQueue.create(scope, componentTypes, options); } @@ -334,7 +334,7 @@ public MatchingFiles matchingFiles(Operand pattern) { * @param options carries optional attributes values * @return a new instance of PaddingFifoQueue */ - public PaddingFifoQueue paddingFifoQueue(List> componentTypes, + public PaddingFifoQueue paddingFifoQueue(List> componentTypes, PaddingFifoQueue.Options... options) { return PaddingFifoQueue.create(scope, componentTypes, options); } @@ -396,9 +396,9 @@ public PaddingFifoQueue paddingFifoQueue(List> componentTypes, */ public ParseExample parseExample(Operand serialized, Operand names, Operand sparseKeys, Operand denseKeys, Operand raggedKeys, - Iterable> denseDefaults, Long numSparse, List> sparseTypes, - List> raggedValueTypes, List> raggedSplitTypes, - List denseShapes) { + Iterable> denseDefaults, Long numSparse, List> sparseTypes, + List> raggedValueTypes, + List> raggedSplitTypes, List denseShapes) { return ParseExample.create(scope, serialized, names, sparseKeys, denseKeys, raggedKeys, denseDefaults, numSparse, sparseTypes, raggedValueTypes, raggedSplitTypes, denseShapes); } @@ -454,10 +454,13 @@ public ParseSequenceExample parseSequenceExample(Operand serialized, Operand contextDenseKeys, Operand contextRaggedKeys, Operand featureListSparseKeys, Operand featureListDenseKeys, Operand featureListRaggedKeys, Operand featureListDenseMissingAssumedEmpty, - Iterable> contextDenseDefaults, List> contextSparseTypes, - List> contextRaggedValueTypes, List> contextRaggedSplitTypes, - List> featureListDenseTypes, List> featureListSparseTypes, - List> featureListRaggedValueTypes, List> featureListRaggedSplitTypes, + Iterable> contextDenseDefaults, List> contextSparseTypes, + List> contextRaggedValueTypes, + List> contextRaggedSplitTypes, + List> featureListDenseTypes, + List> featureListSparseTypes, + List> featureListRaggedValueTypes, + List> featureListRaggedSplitTypes, ParseSequenceExample.Options... options) { return ParseSequenceExample.create(scope, serialized, debugName, contextSparseKeys, contextDenseKeys, contextRaggedKeys, featureListSparseKeys, featureListDenseKeys, featureListRaggedKeys, featureListDenseMissingAssumedEmpty, contextDenseDefaults, contextSparseTypes, contextRaggedValueTypes, contextRaggedSplitTypes, featureListDenseTypes, featureListSparseTypes, featureListRaggedValueTypes, featureListRaggedSplitTypes, options); } @@ -499,7 +502,7 @@ public ParseSequenceExample parseSequenceExample(Operand serialized, */ public ParseSingleExample parseSingleExample(Operand serialized, Iterable> denseDefaults, Long numSparse, List sparseKeys, - List denseKeys, List> sparseTypes, List denseShapes) { + List denseKeys, List> sparseTypes, List denseShapes) { return ParseSingleExample.create(scope, serialized, denseDefaults, numSparse, sparseKeys, denseKeys, sparseTypes, denseShapes); } @@ -553,8 +556,9 @@ public ParseSingleSequenceExample parseSingleSequenceExample(Operand se Iterable> contextSparseKeys, Iterable> contextDenseKeys, Iterable> featureListSparseKeys, Iterable> featureListDenseKeys, Iterable> contextDenseDefaults, - Operand debugName, List> contextSparseTypes, - List> featureListDenseTypes, List> featureListSparseTypes, + Operand debugName, List> contextSparseTypes, + List> featureListDenseTypes, + List> featureListSparseTypes, ParseSingleSequenceExample.Options... options) { return ParseSingleSequenceExample.create(scope, serialized, featureListDenseMissingAssumedEmpty, contextSparseKeys, contextDenseKeys, featureListSparseKeys, featureListDenseKeys, contextDenseDefaults, debugName, contextSparseTypes, featureListDenseTypes, featureListSparseTypes, options); } @@ -569,7 +573,7 @@ public ParseSingleSequenceExample parseSingleSequenceExample(Operand se * @return a new instance of ParseTensor */ public ParseTensor parseTensor(Operand serialized, - DataType outType) { + Class outType) { return ParseTensor.create(scope, serialized, outType); } @@ -590,8 +594,8 @@ public ParseTensor parseTensor(Operand serialized, * @param options carries optional attributes values * @return a new instance of PriorityQueue */ - public PriorityQueue priorityQueue(List> componentTypes, List shapes, - PriorityQueue.Options... options) { + public PriorityQueue priorityQueue(List> componentTypes, + List shapes, PriorityQueue.Options... options) { return PriorityQueue.create(scope, componentTypes, shapes, options); } @@ -627,7 +631,7 @@ public QueueClose queueClose(Operand handle, QueueClose.Options... options) { * @param options carries optional attributes values * @return a new instance of QueueDequeue */ - public QueueDequeue queueDequeue(Operand handle, List> componentTypes, + public QueueDequeue queueDequeue(Operand handle, List> componentTypes, QueueDequeue.Options... options) { return QueueDequeue.create(scope, handle, componentTypes, options); } @@ -656,7 +660,7 @@ public QueueDequeue queueDequeue(Operand handle, List> componentT * @return a new instance of QueueDequeueMany */ public QueueDequeueMany queueDequeueMany(Operand handle, Operand n, - List> componentTypes, QueueDequeueMany.Options... options) { + List> componentTypes, QueueDequeueMany.Options... options) { return QueueDequeueMany.create(scope, handle, n, componentTypes, options); } @@ -688,7 +692,7 @@ public QueueDequeueMany queueDequeueMany(Operand handle, Operand n, * @return a new instance of QueueDequeueUpTo */ public QueueDequeueUpTo queueDequeueUpTo(Operand handle, Operand n, - List> componentTypes, QueueDequeueUpTo.Options... options) { + List> componentTypes, QueueDequeueUpTo.Options... options) { return QueueDequeueUpTo.create(scope, handle, n, componentTypes, options); } @@ -765,7 +769,7 @@ public QueueSize queueSize(Operand handle) { * @param options carries optional attributes values * @return a new instance of RandomShuffleQueue */ - public RandomShuffleQueue randomShuffleQueue(List> componentTypes, + public RandomShuffleQueue randomShuffleQueue(List> componentTypes, RandomShuffleQueue.Options... options) { return RandomShuffleQueue.create(scope, componentTypes, options); } @@ -917,7 +921,7 @@ public SerializeManySparse serializeManySparse( */ public SerializeManySparse serializeManySparse( Operand sparseIndices, Operand sparseValues, Operand sparseShape, - DataType outType) { + Class outType) { return SerializeManySparse.create(scope, sparseIndices, sparseValues, sparseShape, outType); } @@ -948,7 +952,7 @@ public SerializeSparse serializeSparse(Operand SerializeSparse serializeSparse( Operand sparseIndices, Operand sparseValues, Operand sparseShape, - DataType outType) { + Class outType) { return SerializeSparse.create(scope, sparseIndices, sparseValues, sparseShape, outType); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/LinalgOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/LinalgOps.java index 6d63d9ddd89..f15c50fe691 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/LinalgOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/LinalgOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.linalg.BandPart; import org.tensorflow.op.linalg.BatchCholesky; @@ -399,7 +398,7 @@ public Det det(Operand input) { * @param options carries optional attributes values * @return a new instance of Eig */ - public Eig eig(Operand input, DataType Tout, + public Eig eig(Operand input, Class Tout, Eig.Options... options) { return Eig.create(scope, input, Tout, options); } @@ -685,7 +684,7 @@ public Lu lu(Operand input) { * @return a new instance of Lu */ public Lu lu(Operand input, - DataType outputIdxType) { + Class outputIdxType) { return Lu.create(scope, input, outputIdxType); } @@ -1376,7 +1375,7 @@ public Qr qr(Operand input, Qr.Options... options) { */ public QuantizedMatMul quantizedMatMul( Operand a, Operand b, Operand minA, Operand maxA, - Operand minB, Operand maxB, DataType Toutput, DataType Tactivation, + Operand minB, Operand maxB, Class Toutput, Class Tactivation, QuantizedMatMul.Options... options) { return QuantizedMatMul.create(scope, a, b, minA, maxA, minB, maxB, Toutput, Tactivation, options); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/MathOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/MathOps.java index b193c2e404d..252a84fd745 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/MathOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/MathOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.math.Abs; @@ -292,7 +291,7 @@ public Angle angle(Operand input) { * @param Tout * @return a new instance of Angle */ - public Angle angle(Operand input, DataType Tout) { + public Angle angle(Operand input, Class Tout) { return Angle.create(scope, input, Tout); } @@ -360,7 +359,7 @@ public ArgMax argMax(Operand inp * @return a new instance of ArgMax */ public ArgMax argMax(Operand input, - Operand dimension, DataType outputType) { + Operand dimension, Class outputType) { return ArgMax.create(scope, input, dimension, outputType); } @@ -415,7 +414,7 @@ public ArgMin argMin(Operand inp * @return a new instance of ArgMin */ public ArgMin argMin(Operand input, - Operand dimension, DataType outputType) { + Operand dimension, Class outputType) { return ArgMin.create(scope, input, dimension, outputType); } @@ -654,7 +653,7 @@ public ComplexAbs complexAbs(Operand x) { * @return a new instance of ComplexAbs */ public ComplexAbs complexAbs(Operand x, - DataType Tout) { + Class Tout) { return ComplexAbs.create(scope, x, Tout); } @@ -1181,7 +1180,7 @@ public Imag imag(Operand input) { * @param Tout * @return a new instance of Imag */ - public Imag imag(Operand input, DataType Tout) { + public Imag imag(Operand input, Class Tout) { return Imag.create(scope, input, Tout); } @@ -1638,7 +1637,7 @@ public Pow pow(Operand x, Operand y) { */ public QuantizedAdd quantizedAdd( Operand x, Operand y, Operand minX, Operand maxX, - Operand minY, Operand maxY, DataType Toutput) { + Operand minY, Operand maxY, Class Toutput) { return QuantizedAdd.create(scope, x, y, minX, maxX, minY, maxY, Toutput); } @@ -1657,7 +1656,7 @@ public QuantizedAdd quant */ public QuantizedMul quantizedMul( Operand x, Operand y, Operand minX, Operand maxX, - Operand minY, Operand maxY, DataType Toutput) { + Operand minY, Operand maxY, Class Toutput) { return QuantizedMul.create(scope, x, y, minX, maxX, minY, maxY, Toutput); } @@ -1702,7 +1701,7 @@ public Real real(Operand input) { * @param Tout * @return a new instance of Real */ - public Real real(Operand input, DataType Tout) { + public Real real(Operand input, Class Tout) { return Real.create(scope, input, Tout); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java index 0992cc606ba..8b5b01fac32 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java @@ -18,7 +18,6 @@ package org.tensorflow.op; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.nn.AvgPool; import org.tensorflow.op.nn.AvgPool3d; @@ -645,8 +644,8 @@ public CudnnRNNParamsToCanonical cudnnRNNParamsToCanonica * @return a new instance of CudnnRnnParamsSize */ public CudnnRnnParamsSize cudnnRnnParamsSize( - Operand numLayers, Operand numUnits, Operand inputSize, DataType T, - DataType S, CudnnRnnParamsSize.Options... options) { + Operand numLayers, Operand numUnits, Operand inputSize, Class T, + Class S, CudnnRnnParamsSize.Options... options) { return CudnnRnnParamsSize.create(scope, numLayers, numUnits, inputSize, T, S, options); } @@ -1504,7 +1503,7 @@ public MaxPoolWithArgmax maxPoolWithArgmax(Operan * @return a new instance of MaxPoolWithArgmax */ public MaxPoolWithArgmax maxPoolWithArgmax( - Operand input, List ksize, List strides, DataType Targmax, String padding, + Operand input, List ksize, List strides, Class Targmax, String padding, MaxPoolWithArgmax.Options... options) { return MaxPoolWithArgmax.create(scope, input, ksize, strides, Targmax, padding, options); } @@ -1591,7 +1590,7 @@ public QuantizedBatchNormWithGlobalNormalizat Operand t, Operand tMin, Operand tMax, Operand m, Operand mMin, Operand mMax, Operand v, Operand vMin, Operand vMax, Operand beta, Operand betaMin, Operand betaMax, - Operand gamma, Operand gammaMin, Operand gammaMax, DataType outType, + Operand gamma, Operand gammaMin, Operand gammaMax, Class outType, Float varianceEpsilon, Boolean scaleAfterNormalization) { return QuantizedBatchNormWithGlobalNormalization.create(scope, t, tMin, tMax, m, mMin, mMax, v, vMin, vMax, beta, betaMin, betaMax, gamma, gammaMin, gammaMax, outType, varianceEpsilon, scaleAfterNormalization); } @@ -1613,7 +1612,7 @@ public QuantizedBatchNormWithGlobalNormalizat */ public QuantizedBiasAdd quantizedBiasAdd( Operand input, Operand bias, Operand minInput, Operand maxInput, - Operand minBias, Operand maxBias, DataType outType) { + Operand minBias, Operand maxBias, Class outType) { return QuantizedBiasAdd.create(scope, input, bias, minInput, maxInput, minBias, maxBias, outType); } @@ -1641,7 +1640,7 @@ public QuantizedBiasAdd q */ public QuantizedConv2d quantizedConv2d( Operand input, Operand filter, Operand minInput, Operand maxInput, - Operand minFilter, Operand maxFilter, DataType outType, + Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, QuantizedConv2d.Options... options) { return QuantizedConv2d.create(scope, input, filter, minInput, maxInput, minFilter, maxFilter, outType, strides, padding, options); } @@ -1692,7 +1691,7 @@ public QuantizedMaxPool quantizedMaxPool(Operand input, * @return a new instance of QuantizedRelu */ public QuantizedRelu quantizedRelu(Operand features, - Operand minFeatures, Operand maxFeatures, DataType outType) { + Operand minFeatures, Operand maxFeatures, Class outType) { return QuantizedRelu.create(scope, features, minFeatures, maxFeatures, outType); } @@ -1707,7 +1706,7 @@ public QuantizedRelu quantizedRelu(Operand * @return a new instance of QuantizedRelu6 */ public QuantizedRelu6 quantizedRelu6(Operand features, - Operand minFeatures, Operand maxFeatures, DataType outType) { + Operand minFeatures, Operand maxFeatures, Class outType) { return QuantizedRelu6.create(scope, features, minFeatures, maxFeatures, outType); } @@ -1724,7 +1723,7 @@ public QuantizedRelu6 quantizedRelu6(Opera */ public QuantizedReluX quantizedReluX(Operand features, Operand maxValue, Operand minFeatures, Operand maxFeatures, - DataType outType) { + Class outType) { return QuantizedReluX.create(scope, features, maxValue, minFeatures, maxFeatures, outType); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 06908c41d6a..d6e69085324 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -19,7 +19,6 @@ import java.nio.charset.Charset; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.DeviceSpec; import org.tensorflow.EagerSession; import org.tensorflow.ExecutionEnvironment; @@ -648,7 +647,7 @@ public AssignVariableOp assignVariableOp(Operand resource, * @param options carries optional attributes values * @return a new instance of Barrier */ - public Barrier barrier(List> componentTypes, Barrier.Options... options) { + public Barrier barrier(List> componentTypes, Barrier.Options... options) { return Barrier.create(scope, componentTypes, options); } @@ -729,7 +728,7 @@ public BarrierReadySize barrierReadySize(Operand handle) { * @return a new instance of BarrierTakeMany */ public BarrierTakeMany barrierTakeMany(Operand handle, Operand numElements, - List> componentTypes, BarrierTakeMany.Options... options) { + List> componentTypes, BarrierTakeMany.Options... options) { return BarrierTakeMany.create(scope, handle, numElements, componentTypes, options); } @@ -989,7 +988,7 @@ public BatchToSpaceNd * @param type * @return a new instance of Bitcast */ - public Bitcast bitcast(Operand input, DataType type) { + public Bitcast bitcast(Operand input, Class type) { return Bitcast.create(scope, input, type); } @@ -1854,16 +1853,16 @@ public Constant constant(Charset charset, Shape shape, DataBuffer the tensor type * @param scope is a scope used to add the underlying operation. - * @param type the tensor datatype. + * @param type the tensor type class * @param shape the tensor shape. * @param data a buffer containing the tensor data. * @return a constant of type `type` * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the * buffer */ - public Constant constant(DataType type, Shape shape, - ByteDataBuffer data) { + public Constant constant(Class type, Shape shape, ByteDataBuffer data) { return Constant.tensorOf(scope, type, shape, data); } @@ -2138,7 +2137,7 @@ public EditDistance editDistance(Operand hypothesisInd * @param options carries optional attributes values * @return a new instance of Empty */ - public Empty empty(Operand shape, DataType dtype, + public Empty empty(Operand shape, Class dtype, Empty.Options... options) { return Empty.create(scope, shape, dtype, options); } @@ -2159,7 +2158,7 @@ public Empty empty(Operand shape, DataType dtype * @return a new instance of EmptyTensorList */ public EmptyTensorList emptyTensorList( - Operand elementShape, Operand maxNumElements, DataType elementDtype) { + Operand elementShape, Operand maxNumElements, Class elementDtype) { return EmptyTensorList.create(scope, elementShape, maxNumElements, elementDtype); } @@ -2491,7 +2490,7 @@ public GetSessionHandle getSessionHandle(Operand value) { * @return a new instance of GetSessionTensor */ public GetSessionTensor getSessionTensor(Operand handle, - DataType dtype) { + Class dtype) { return GetSessionTensor.create(scope, handle, dtype); } @@ -2571,8 +2570,8 @@ public GuaranteeConst guaranteeConst(Operand input) { * @param options carries optional attributes values * @return a new instance of HashTable */ - public HashTable hashTable(DataType keyDtype, - DataType valueDtype, HashTable.Options... options) { + public HashTable hashTable(Class keyDtype, + Class valueDtype, HashTable.Options... options) { return HashTable.create(scope, keyDtype, valueDtype, options); } @@ -2635,7 +2634,7 @@ public HistogramFixedWidth histogramFixedWidth(Opera * @return a new instance of HistogramFixedWidth */ public HistogramFixedWidth histogramFixedWidth( - Operand values, Operand valueRange, Operand nbins, DataType dtype) { + Operand values, Operand valueRange, Operand nbins, Class dtype) { return HistogramFixedWidth.create(scope, values, valueRange, nbins, dtype); } @@ -2685,7 +2684,7 @@ public IdentityN identityN(Iterable> input) { * NewReadOnlyMemoryRegionFromFile in tensorflow::Env. * @return a new instance of ImmutableConst */ - public ImmutableConst immutableConst(DataType dtype, Shape shape, + public ImmutableConst immutableConst(Class dtype, Shape shape, String memoryRegionName) { return ImmutableConst.create(scope, dtype, shape, memoryRegionName); } @@ -2710,7 +2709,7 @@ public ImmutableConst immutableConst(DataType dtype, Sha * try (Session s = new Session(g)) { * s.run(tf.init()); // initialize all variables * - * try (Tensor t = s.runner().fetch(z).run().get(0).expect(TInt32.DTYPE)) { + * try (TInt32 t = (TInt32)s.runner().fetch(z).run().get(0)) { * assertEquals(30, t.data().getInt()); * } * } @@ -2737,7 +2736,7 @@ public ImmutableConst immutableConst(DataType dtype, Sha * try (SavedModelBundle model = SavedModelBundle.load("/path/to/model", "train")) { * model.session().run(Init.DEFAULT_NAME); * - * try (Tensor t = s.runner().fetch("z").run().get(0).expect(TInt32.DTYPE)) { + * try (TInt32 t = (TInt32)s.runner().fetch("z").run().get(0)) { * assertEquals(30, t.data().getInt()); * } * } @@ -2880,7 +2879,7 @@ public IsVariableInitialized isVariableInitialized(Operand * @return a new instance of LookupTableExport */ public LookupTableExport lookupTableExport( - Operand tableHandle, DataType Tkeys, DataType Tvalues) { + Operand tableHandle, Class Tkeys, Class Tvalues) { return LookupTableExport.create(scope, tableHandle, Tkeys, Tvalues); } @@ -2966,7 +2965,7 @@ public LoopCond loopCond(Operand input) { * @param options carries optional attributes values * @return a new instance of MapClear */ - public MapClear mapClear(List> dtypes, MapClear.Options... options) { + public MapClear mapClear(List> dtypes, MapClear.Options... options) { return MapClear.create(scope, dtypes, options); } @@ -2977,7 +2976,7 @@ public MapClear mapClear(List> dtypes, MapClear.Options... options) * @param options carries optional attributes values * @return a new instance of MapIncompleteSize */ - public MapIncompleteSize mapIncompleteSize(List> dtypes, + public MapIncompleteSize mapIncompleteSize(List> dtypes, MapIncompleteSize.Options... options) { return MapIncompleteSize.create(scope, dtypes, options); } @@ -2994,8 +2993,8 @@ public MapIncompleteSize mapIncompleteSize(List> dtypes, * @param options carries optional attributes values * @return a new instance of MapPeek */ - public MapPeek mapPeek(Operand key, Operand indices, List> dtypes, - MapPeek.Options... options) { + public MapPeek mapPeek(Operand key, Operand indices, + List> dtypes, MapPeek.Options... options) { return MapPeek.create(scope, key, indices, dtypes, options); } @@ -3006,7 +3005,7 @@ public MapPeek mapPeek(Operand key, Operand indices, List> dtypes, MapSize.Options... options) { + public MapSize mapSize(List> dtypes, MapSize.Options... options) { return MapSize.create(scope, dtypes, options); } @@ -3022,7 +3021,8 @@ public MapSize mapSize(List> dtypes, MapSize.Options... options) { * @return a new instance of MapStage */ public MapStage mapStage(Operand key, Operand indices, - Iterable> values, List> dtypes, MapStage.Options... options) { + Iterable> values, List> dtypes, + MapStage.Options... options) { return MapStage.create(scope, key, indices, values, dtypes, options); } @@ -3039,7 +3039,7 @@ public MapStage mapStage(Operand key, Operand indices, * @return a new instance of MapUnstage */ public MapUnstage mapUnstage(Operand key, Operand indices, - List> dtypes, MapUnstage.Options... options) { + List> dtypes, MapUnstage.Options... options) { return MapUnstage.create(scope, key, indices, dtypes, options); } @@ -3054,8 +3054,8 @@ public MapUnstage mapUnstage(Operand key, Operand indices, * @param options carries optional attributes values * @return a new instance of MapUnstageNoKey */ - public MapUnstageNoKey mapUnstageNoKey(Operand indices, List> dtypes, - MapUnstageNoKey.Options... options) { + public MapUnstageNoKey mapUnstageNoKey(Operand indices, + List> dtypes, MapUnstageNoKey.Options... options) { return MapUnstageNoKey.create(scope, indices, dtypes, options); } @@ -3196,7 +3196,7 @@ public MirrorPad mirrorPad(Operand in * @return a new instance of MlirPassthroughOp */ public MlirPassthroughOp mlirPassthroughOp(Iterable> inputs, String mlirModule, - List> Toutputs) { + List> Toutputs) { return MlirPassthroughOp.create(scope, inputs, mlirModule, Toutputs); } @@ -3218,7 +3218,7 @@ public MlirPassthroughOp mlirPassthroughOp(Iterable> inputs, String m * @return a new instance of MutableDenseHashTable */ public MutableDenseHashTable mutableDenseHashTable( - Operand emptyKey, Operand deletedKey, DataType valueDtype, + Operand emptyKey, Operand deletedKey, Class valueDtype, MutableDenseHashTable.Options... options) { return MutableDenseHashTable.create(scope, emptyKey, deletedKey, valueDtype, options); } @@ -3235,8 +3235,8 @@ public MutableDenseHashTable mutableDenseHash * @param options carries optional attributes values * @return a new instance of MutableHashTable */ - public MutableHashTable mutableHashTable(DataType keyDtype, - DataType valueDtype, MutableHashTable.Options... options) { + public MutableHashTable mutableHashTable(Class keyDtype, + Class valueDtype, MutableHashTable.Options... options) { return MutableHashTable.create(scope, keyDtype, valueDtype, options); } @@ -3253,7 +3253,7 @@ public MutableHashTable mutableHashTable(Data * @return a new instance of MutableHashTableOfTensors */ public MutableHashTableOfTensors mutableHashTableOfTensors( - DataType keyDtype, DataType valueDtype, MutableHashTableOfTensors.Options... options) { + Class keyDtype, Class valueDtype, MutableHashTableOfTensors.Options... options) { return MutableHashTableOfTensors.create(scope, keyDtype, valueDtype, options); } @@ -3434,11 +3434,11 @@ public OneHot oneHot(Operand indices, * * @param scope is a scope used to add the underlying operation * @param dims a 1-D operand that represents the shape of the output tensor - * @param type the output tensor datatype. Can not be TString. + * @param type the output tensor type class. Can not be TString. * @return a constant tensor initialized with ones * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with ones. */ - public Ones ones(Operand dims, DataType type) { + public Ones ones(Operand dims, Class type) { return Ones.create(scope, dims, type); } @@ -3460,7 +3460,7 @@ public OnesLike onesLike(Operand x) { * @param options carries optional attributes values * @return a new instance of OrderedMapClear */ - public OrderedMapClear orderedMapClear(List> dtypes, + public OrderedMapClear orderedMapClear(List> dtypes, OrderedMapClear.Options... options) { return OrderedMapClear.create(scope, dtypes, options); } @@ -3472,7 +3472,7 @@ public OrderedMapClear orderedMapClear(List> dtypes, * @param options carries optional attributes values * @return a new instance of OrderedMapIncompleteSize */ - public OrderedMapIncompleteSize orderedMapIncompleteSize(List> dtypes, + public OrderedMapIncompleteSize orderedMapIncompleteSize(List> dtypes, OrderedMapIncompleteSize.Options... options) { return OrderedMapIncompleteSize.create(scope, dtypes, options); } @@ -3491,7 +3491,7 @@ public OrderedMapIncompleteSize orderedMapIncompleteSize(List> dtype * @return a new instance of OrderedMapPeek */ public OrderedMapPeek orderedMapPeek(Operand key, Operand indices, - List> dtypes, OrderedMapPeek.Options... options) { + List> dtypes, OrderedMapPeek.Options... options) { return OrderedMapPeek.create(scope, key, indices, dtypes, options); } @@ -3502,7 +3502,7 @@ public OrderedMapPeek orderedMapPeek(Operand key, Operand indice * @param options carries optional attributes values * @return a new instance of OrderedMapSize */ - public OrderedMapSize orderedMapSize(List> dtypes, + public OrderedMapSize orderedMapSize(List> dtypes, OrderedMapSize.Options... options) { return OrderedMapSize.create(scope, dtypes, options); } @@ -3521,7 +3521,8 @@ public OrderedMapSize orderedMapSize(List> dtypes, * @return a new instance of OrderedMapStage */ public OrderedMapStage orderedMapStage(Operand key, Operand indices, - Iterable> values, List> dtypes, OrderedMapStage.Options... options) { + Iterable> values, List> dtypes, + OrderedMapStage.Options... options) { return OrderedMapStage.create(scope, key, indices, values, dtypes, options); } @@ -3538,7 +3539,7 @@ public OrderedMapStage orderedMapStage(Operand key, Operand indi * @return a new instance of OrderedMapUnstage */ public OrderedMapUnstage orderedMapUnstage(Operand key, Operand indices, - List> dtypes, OrderedMapUnstage.Options... options) { + List> dtypes, OrderedMapUnstage.Options... options) { return OrderedMapUnstage.create(scope, key, indices, dtypes, options); } @@ -3554,7 +3555,7 @@ public OrderedMapUnstage orderedMapUnstage(Operand key, Operand * @return a new instance of OrderedMapUnstageNoKey */ public OrderedMapUnstageNoKey orderedMapUnstageNoKey(Operand indices, - List> dtypes, OrderedMapUnstageNoKey.Options... options) { + List> dtypes, OrderedMapUnstageNoKey.Options... options) { return OrderedMapUnstageNoKey.create(scope, indices, dtypes, options); } @@ -3705,7 +3706,7 @@ public ParallelDynamicStitch parallelDynamicStitch( * @param options carries optional attributes values * @return a new instance of Placeholder */ - public Placeholder placeholder(DataType dtype, + public Placeholder placeholder(Class dtype, Placeholder.Options... options) { return Placeholder.create(scope, dtype, options); } @@ -3834,8 +3835,7 @@ public Rank rank(Operand input) { * @param dtype the dtype of the value. * @return a new instance of ReadVariableOp */ - public ReadVariableOp readVariableOp(Operand resource, - DataType dtype) { + public ReadVariableOp readVariableOp(Operand resource, Class dtype) { return ReadVariableOp.create(scope, resource, dtype); } @@ -4016,7 +4016,7 @@ public RefSwitch refSwitch(Operand data, Operand * @return a new instance of RemoteFusedGraphExecute */ public RemoteFusedGraphExecute remoteFusedGraphExecute(Iterable> inputs, - List> Toutputs, String serializedRemoteFusedGraphExecuteInfo) { + List> Toutputs, String serializedRemoteFusedGraphExecuteInfo) { return RemoteFusedGraphExecute.create(scope, inputs, Toutputs, serializedRemoteFusedGraphExecuteInfo); } @@ -4103,7 +4103,7 @@ public Reshape reshape(Operand tensor * @return a new instance of ResourceCountUpTo */ public ResourceCountUpTo resourceCountUpTo(Operand resource, Long limit, - DataType T) { + Class T) { return ResourceCountUpTo.create(scope, resource, limit, T); } @@ -4131,7 +4131,7 @@ public ResourceCountUpTo resourceCountUpTo(Operand res * @return a new instance of ResourceGather */ public ResourceGather resourceGather(Operand resource, - Operand indices, DataType dtype, ResourceGather.Options... options) { + Operand indices, Class dtype, ResourceGather.Options... options) { return ResourceGather.create(scope, resource, indices, dtype, options); } @@ -4144,7 +4144,7 @@ public ResourceGather resourceGather(Ope * @return a new instance of ResourceGatherNd */ public ResourceGatherNd resourceGatherNd( - Operand resource, Operand indices, DataType dtype) { + Operand resource, Operand indices, Class dtype) { return ResourceGatherNd.create(scope, resource, indices, dtype); } @@ -5423,7 +5423,7 @@ public SetDiff1d setDiff1d(Operand x, Operand * @return a new instance of SetDiff1d */ public SetDiff1d setDiff1d(Operand x, Operand y, - DataType outIdx) { + Class outIdx) { return SetDiff1d.create(scope, x, y, outIdx); } @@ -5484,7 +5484,7 @@ public org.tensorflow.op.core.Shape shape(Operand i * @return a new instance of Shape */ public org.tensorflow.op.core.Shape shape( - Operand input, DataType outType) { + Operand input, Class outType) { return org.tensorflow.op.core.Shape.create(scope, input, outType); } @@ -5512,7 +5512,7 @@ public ShapeN shapeN(Iterable> input) { * @return a new instance of ShapeN */ public ShapeN shapeN(Iterable> input, - DataType outType) { + Class outType) { return ShapeN.create(scope, input, outType); } @@ -5553,7 +5553,7 @@ public Size size(Operand input) { * @param outType * @return a new instance of Size */ - public Size size(Operand input, DataType outType) { + public Size size(Operand input, Class outType) { return Size.create(scope, input, outType); } @@ -5833,7 +5833,7 @@ public Stage stage(Iterable> values, Stage.Options... options) { * @param options carries optional attributes values * @return a new instance of StageClear */ - public StageClear stageClear(List> dtypes, StageClear.Options... options) { + public StageClear stageClear(List> dtypes, StageClear.Options... options) { return StageClear.create(scope, dtypes, options); } @@ -5849,7 +5849,7 @@ public StageClear stageClear(List> dtypes, StageClear.Options... opt * @param options carries optional attributes values * @return a new instance of StagePeek */ - public StagePeek stagePeek(Operand index, List> dtypes, + public StagePeek stagePeek(Operand index, List> dtypes, StagePeek.Options... options) { return StagePeek.create(scope, index, dtypes, options); } @@ -5861,7 +5861,7 @@ public StagePeek stagePeek(Operand index, List> dtypes, * @param options carries optional attributes values * @return a new instance of StageSize */ - public StageSize stageSize(List> dtypes, StageSize.Options... options) { + public StageSize stageSize(List> dtypes, StageSize.Options... options) { return StageSize.create(scope, dtypes, options); } @@ -6126,7 +6126,7 @@ public SwitchCond switchCond(Operand data, Operand TemporaryVariable temporaryVariable(Shape shape, DataType dtype, + public TemporaryVariable temporaryVariable(Shape shape, Class dtype, TemporaryVariable.Options... options) { return TemporaryVariable.create(scope, shape, dtype, options); } @@ -6141,7 +6141,7 @@ public TemporaryVariable temporaryVariable(Shape shape, Dat * @param options carries optional attributes values * @return a new instance of TensorArray */ - public TensorArray tensorArray(Operand size, DataType dtype, + public TensorArray tensorArray(Operand size, Class dtype, TensorArray.Options... options) { return TensorArray.create(scope, size, dtype, options); } @@ -6181,7 +6181,7 @@ public TensorArrayClose tensorArrayClose(Operand handle) { * @return a new instance of TensorArrayConcat */ public TensorArrayConcat tensorArrayConcat(Operand handle, - Operand flowIn, DataType dtype, TensorArrayConcat.Options... options) { + Operand flowIn, Class dtype, TensorArrayConcat.Options... options) { return TensorArrayConcat.create(scope, handle, flowIn, dtype, options); } @@ -6199,7 +6199,7 @@ public TensorArrayConcat tensorArrayConcat(Operand handl * @return a new instance of TensorArrayGather */ public TensorArrayGather tensorArrayGather(Operand handle, - Operand indices, Operand flowIn, DataType dtype, + Operand indices, Operand flowIn, Class dtype, TensorArrayGather.Options... options) { return TensorArrayGather.create(scope, handle, indices, flowIn, dtype, options); } @@ -6287,7 +6287,7 @@ public TensorArrayGradWithShape tensorArrayGradWithShape(Operand handle, * @return a new instance of TensorArrayPack */ public TensorArrayPack tensorArrayPack(Operand handle, - Operand flowIn, DataType dtype, TensorArrayPack.Options... options) { + Operand flowIn, Class dtype, TensorArrayPack.Options... options) { return TensorArrayPack.create(scope, handle, flowIn, dtype, options); } @@ -6302,7 +6302,7 @@ public TensorArrayPack tensorArrayPack(Operand han * @return a new instance of TensorArrayRead */ public TensorArrayRead tensorArrayRead(Operand handle, - Operand index, Operand flowIn, DataType dtype) { + Operand index, Operand flowIn, Class dtype) { return TensorArrayRead.create(scope, handle, index, flowIn, dtype); } @@ -6419,7 +6419,7 @@ public TensorArrayWrite tensorArrayWrite(Operand handle, */ public TensorListConcat tensorListConcat( Operand inputHandle, Operand elementShape, Operand leadingDims, - DataType elementDtype) { + Class elementDtype) { return TensorListConcat.create(scope, inputHandle, elementShape, leadingDims, elementDtype); } @@ -6431,7 +6431,7 @@ public TensorListConcat tensorListConcat * @return a new instance of TensorListConcatLists */ public TensorListConcatLists tensorListConcatLists(Operand inputA, - Operand inputB, DataType elementDtype) { + Operand inputB, Class elementDtype) { return TensorListConcatLists.create(scope, inputA, inputB, elementDtype); } @@ -6447,7 +6447,7 @@ public TensorListConcatLists tensorListConcatLists(Operand * @return a new instance of TensorListElementShape */ public TensorListElementShape tensorListElementShape( - Operand inputHandle, DataType shapeType) { + Operand inputHandle, Class shapeType) { return TensorListElementShape.create(scope, inputHandle, shapeType); } @@ -6486,7 +6486,7 @@ public TensorListFromTensor tensorListFromT * @return a new instance of TensorListGather */ public TensorListGather tensorListGather(Operand inputHandle, - Operand indices, Operand elementShape, DataType elementDtype) { + Operand indices, Operand elementShape, Class elementDtype) { return TensorListGather.create(scope, inputHandle, indices, elementShape, elementDtype); } @@ -6500,7 +6500,7 @@ public TensorListGather tensorListGather(Operand inputHa * @return a new instance of TensorListGetItem */ public TensorListGetItem tensorListGetItem(Operand inputHandle, - Operand index, Operand elementShape, DataType elementDtype) { + Operand index, Operand elementShape, Class elementDtype) { return TensorListGetItem.create(scope, inputHandle, index, elementShape, elementDtype); } @@ -6534,7 +6534,7 @@ public TensorListLength tensorListLength(Operand inputHandle) { * @return a new instance of TensorListPopBack */ public TensorListPopBack tensorListPopBack(Operand inputHandle, - Operand elementShape, DataType elementDtype) { + Operand elementShape, Class elementDtype) { return TensorListPopBack.create(scope, inputHandle, elementShape, elementDtype); } @@ -6581,7 +6581,7 @@ public TensorListPushBackBatch tensorListPushBackBatch(Operand * @return a new instance of TensorListReserve */ public TensorListReserve tensorListReserve( - Operand elementShape, Operand numElements, DataType elementDtype) { + Operand elementShape, Operand numElements, Class elementDtype) { return TensorListReserve.create(scope, elementShape, numElements, elementDtype); } @@ -6697,7 +6697,7 @@ public TensorListSplit tensorListSplit(Oper * @return a new instance of TensorListStack */ public TensorListStack tensorListStack(Operand inputHandle, - Operand elementShape, DataType elementDtype, TensorListStack.Options... options) { + Operand elementShape, Class elementDtype, TensorListStack.Options... options) { return TensorListStack.create(scope, inputHandle, elementShape, elementDtype, options); } @@ -7304,7 +7304,7 @@ public Unique unique(Operand * @return a new instance of Unique */ public Unique unique(Operand x, - Operand axis, DataType outIdx) { + Operand axis, Class outIdx) { return Unique.create(scope, x, axis, outIdx); } @@ -7421,7 +7421,7 @@ public UniqueWithCounts uniqueWi * @return a new instance of UniqueWithCounts */ public UniqueWithCounts uniqueWithCounts( - Operand x, Operand axis, DataType outIdx) { + Operand x, Operand axis, Class outIdx) { return UniqueWithCounts.create(scope, x, axis, outIdx); } @@ -7494,7 +7494,7 @@ public Unstack unstack(Operand value, Long num, * @param options carries optional attributes values * @return a new instance of Unstage */ - public Unstage unstage(List> dtypes, Unstage.Options... options) { + public Unstage unstage(List> dtypes, Unstage.Options... options) { return Unstage.create(scope, dtypes, options); } @@ -7507,7 +7507,7 @@ public Unstage unstage(List> dtypes, Unstage.Options... options) { * @param options carries optional attributes values * @return a new instance of VarHandleOp */ - public VarHandleOp varHandleOp(DataType dtype, Shape shape, + public VarHandleOp varHandleOp(Class dtype, Shape shape, VarHandleOp.Options... options) { return VarHandleOp.create(scope, dtype, shape, options); } @@ -7550,7 +7550,7 @@ public Variable variable(Operand init, Variable.Options. * @param options carries optional attributes values * @return a new instance of Variable */ - public Variable variable(Shape shape, DataType dtype, + public Variable variable(Shape shape, Class dtype, Variable.Options... options) { return Variable.create(scope, shape, dtype, options); } @@ -7590,7 +7590,7 @@ public VariableShape variableShape(Operand input) { * @param outType * @return a new instance of VariableShape */ - public VariableShape variableShape(Operand input, DataType outType) { + public VariableShape variableShape(Operand input, Class outType) { return VariableShape.create(scope, input, outType); } @@ -7708,7 +7708,7 @@ public XlaSpmdShardToFullShape xlaSpmdShardToFullShape(Oper * @return a constant tensor initialized with zeros * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros. */ - public Zeros zeros(Operand dims, DataType type) { + public Zeros zeros(Operand dims, Class type) { return Zeros.create(scope, dims, type); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/QuantizationOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/QuantizationOps.java index 0a0703a8f5b..bcca8f36505 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/QuantizationOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/QuantizationOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.quantization.Dequantize; import org.tensorflow.op.quantization.FakeQuantWithMinMaxArgs; @@ -176,7 +175,7 @@ public Dequantize dequantize(Operand input, * @return a new instance of Dequantize */ public Dequantize dequantize(Operand input, - Operand minRange, Operand maxRange, DataType dtype, + Operand minRange, Operand maxRange, Class dtype, Dequantize.Options... options) { return Dequantize.create(scope, input, minRange, maxRange, dtype, options); } @@ -507,7 +506,7 @@ public FakeQuantWithMinMaxVarsPerChannelGradient fakeQuantWithMinMaxVarsPerChann * @return a new instance of Quantize */ public Quantize quantize(Operand input, Operand minRange, - Operand maxRange, DataType T, Quantize.Options... options) { + Operand maxRange, Class T, Quantize.Options... options) { return Quantize.create(scope, input, minRange, maxRange, T, options); } @@ -565,8 +564,7 @@ public QuantizeAndDequantize quantizeAndDequantize(Operan * @return a new instance of QuantizeDownAndShrinkRange */ public QuantizeDownAndShrinkRange quantizeDownAndShrinkRange( - Operand input, Operand inputMin, Operand inputMax, - DataType outType) { + Operand input, Operand inputMin, Operand inputMax, Class outType) { return QuantizeDownAndShrinkRange.create(scope, input, inputMin, inputMax, outType); } @@ -628,7 +626,7 @@ public RequantizationRange requantizationRange(Operand inpu */ public Requantize requantize(Operand input, Operand inputMin, Operand inputMax, Operand requestedOutputMin, - Operand requestedOutputMax, DataType outType) { + Operand requestedOutputMax, Class outType) { return Requantize.create(scope, input, inputMin, inputMax, requestedOutputMin, requestedOutputMax, outType); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/RandomOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/RandomOps.java index 1dde01b96b1..f0c3b8c660c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/RandomOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/RandomOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.random.AllCandidateSampler; import org.tensorflow.op.random.LogUniformCandidateSampler; @@ -144,7 +143,7 @@ public Multinomial multinomial(Operand logits, * @return a new instance of Multinomial */ public Multinomial multinomial(Operand logits, - Operand numSamples, DataType outputDtype, Multinomial.Options... options) { + Operand numSamples, Class outputDtype, Multinomial.Options... options) { return Multinomial.create(scope, logits, numSamples, outputDtype, options); } @@ -239,7 +238,7 @@ public RandomPoisson randomPoisso * @return a new instance of RandomPoisson */ public RandomPoisson randomPoisson( - Operand shape, Operand rate, DataType dtype, RandomPoisson.Options... options) { + Operand shape, Operand rate, Class dtype, RandomPoisson.Options... options) { return RandomPoisson.create(scope, shape, rate, dtype, options); } @@ -277,7 +276,7 @@ public RandomShuffle randomShuffle(Operand value, * @return a new instance of RandomStandardNormal */ public RandomStandardNormal randomStandardNormal( - Operand shape, DataType dtype, RandomStandardNormal.Options... options) { + Operand shape, Class dtype, RandomStandardNormal.Options... options) { return RandomStandardNormal.create(scope, shape, dtype, options); } @@ -294,7 +293,7 @@ public RandomStandardNormal randomStan * @return a new instance of RandomUniform */ public RandomUniform randomUniform(Operand shape, - DataType dtype, RandomUniform.Options... options) { + Class dtype, RandomUniform.Options... options) { return RandomUniform.create(scope, shape, dtype, options); } @@ -361,7 +360,7 @@ public StatefulRandomBinomial sta */ public StatefulRandomBinomial statefulRandomBinomial( Operand resource, Operand algorithm, Operand shape, Operand counts, - Operand probs, DataType dtype) { + Operand probs, Class dtype) { return StatefulRandomBinomial.create(scope, resource, algorithm, shape, counts, probs, dtype); } @@ -394,7 +393,7 @@ public StatefulStandardNormal statefulStandardNormal * @return a new instance of StatefulStandardNormal */ public StatefulStandardNormal statefulStandardNormal( - Operand resource, Operand algorithm, Operand shape, DataType dtype) { + Operand resource, Operand algorithm, Operand shape, Class dtype) { return StatefulStandardNormal.create(scope, resource, algorithm, shape, dtype); } @@ -425,7 +424,7 @@ public StatelessMultinomial state * @return a new instance of StatelessMultinomial */ public StatelessMultinomial statelessMultinomial( - Operand logits, Operand numSamples, Operand seed, DataType outputDtype) { + Operand logits, Operand numSamples, Operand seed, Class outputDtype) { return StatelessMultinomial.create(scope, logits, numSamples, seed, outputDtype); } @@ -460,7 +459,7 @@ public StatelessRandomNormal st * @return a new instance of StatelessRandomNormal */ public StatelessRandomNormal statelessRandomNormal( - Operand shape, Operand seed, DataType dtype) { + Operand shape, Operand seed, Class dtype) { return StatelessRandomNormal.create(scope, shape, seed, dtype); } @@ -497,7 +496,7 @@ public StatelessRandomUniform s * @return a new instance of StatelessRandomUniform */ public StatelessRandomUniform statelessRandomUniform( - Operand shape, Operand seed, DataType dtype) { + Operand shape, Operand seed, Class dtype) { return StatelessRandomUniform.create(scope, shape, seed, dtype); } @@ -536,7 +535,7 @@ public StatelessTruncatedNormal * @return a new instance of StatelessTruncatedNormal */ public StatelessTruncatedNormal statelessTruncatedNormal( - Operand shape, Operand seed, DataType dtype) { + Operand shape, Operand seed, Class dtype) { return StatelessTruncatedNormal.create(scope, shape, seed, dtype); } @@ -554,7 +553,7 @@ public StatelessTrunca * @return a new instance of TruncatedNormal */ public TruncatedNormal truncatedNormal(Operand shape, - DataType dtype, TruncatedNormal.Options... options) { + Class dtype, TruncatedNormal.Options... options) { return TruncatedNormal.create(scope, shape, dtype, options); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ShapeOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ShapeOps.java index 5ed7ab60999..ac5ec77a7fb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ShapeOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/ShapeOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.core.Shape; import org.tensorflow.op.core.Shapes; @@ -115,12 +114,12 @@ public Operand flatten(Shape shape) { * @param the shape datatype * @param scope current scope * @param operand the operand to flatten - * @param dType the shape datatype + * @param type the shape datatype * @return the reshaped operand */ public Operand flatten(Operand operand, - DataType dType) { - return Shapes.flatten(scope, operand, dType); + Class type) { + return Shapes.flatten(scope, operand, type); } /** @@ -129,11 +128,11 @@ public Operand flatten(Operand operan * @param the shape datatype * @param scope current scope * @param shape the TensorFlow shape - * @param dType the shape datatype + * @param type the shape datatype * @return the flattened shape */ - public Operand flatten(Shape shape, DataType dType) { - return Shapes.flatten(scope, shape, dType); + public Operand flatten(Shape shape, Class type) { + return Shapes.flatten(scope, shape, type); } /** @@ -152,12 +151,12 @@ public Operand head(Shape shape) { * * @param scope current scope * @param shape the TensorFlow shape - * @param dType the shape datatype. + * @param type the shape datatype. * @param the shape datatype. * @return a 1-dimensional Operand containing the Shape's first dimension */ - public Operand head(Shape shape, DataType dType) { - return Shapes.head(scope, shape, dType); + public Operand head(Shape shape, Class type) { + return Shapes.head(scope, shape, type); } /** @@ -177,11 +176,11 @@ public Operand numDimensions(Shape shape) { * @param the shape datatype * @param scope the curren scope * @param shape the shape - * @param dType the shape datatype + * @param type the shape datatype * @return the number of dimensions */ - public Operand numDimensions(Shape shape, DataType dType) { - return Shapes.numDimensions(scope, shape, dType); + public Operand numDimensions(Shape shape, Class type) { + return Shapes.numDimensions(scope, shape, type); } /** @@ -261,12 +260,12 @@ public Operand reduceDims(Shape shape, Operand axis) { * @param scope current scope * @param operand the operand * @param axis the axis - * @param dType the shape datatype + * @param type the shape datatype * @return the reshaped operand */ public Operand reduceDims(Operand operand, - Operand axis, DataType dType) { - return Shapes.reduceDims(scope, operand, axis, dType); + Operand axis, Class type) { + return Shapes.reduceDims(scope, operand, axis, type); } /** @@ -276,12 +275,11 @@ public Operand reduceDims(Operand ope * @param scope current scope * @param shape the TensorFlow shape * @param axis the axis - * @param dType the shape datatype + * @param type the shape datatype * @return the reduced shape */ - public Operand reduceDims(Shape shape, Operand axis, - DataType dType) { - return Shapes.reduceDims(scope, shape, axis, dType); + public Operand reduceDims(Shape shape, Operand axis, Class type) { + return Shapes.reduceDims(scope, shape, axis, type); } /** @@ -308,28 +306,28 @@ public Operand size(Operand input, Operand } /** - * Get the size represented by the TensorFlow shape. + * Get the size of the specified dimension in the shape. * - * @param the type of the shape * @param scope current scope * @param shape the TensorFlow shape - * @param dType the shape datatype - * @return the size + * @param dim the dimension + * @return the size of the specified dimension */ - public Operand size(Shape shape, DataType dType) { - return Shapes.size(scope, shape, dType); + public Operand size(Shape shape, Operand dim) { + return Shapes.size(scope, shape, dim); } /** - * Get the size of the specified dimension in the shape. + * Get the size represented by the TensorFlow shape. * + * @param the type of the shape * @param scope current scope * @param shape the TensorFlow shape - * @param dim the dimension - * @return the size of the specified dimension + * @param type the shape datatype + * @return the size */ - public Operand size(Shape shape, Operand dim) { - return Shapes.size(scope, shape, dim); + public Operand size(Shape shape, Class type) { + return Shapes.size(scope, shape, type); } /** @@ -339,12 +337,12 @@ public Operand size(Shape shape, Operand dim) { * @param scope current scope * @param input the operand * @param dim the dimension - * @param dType the shape datatype + * @param type the shape datatype * @return the size of the specified dimension */ public Operand size(Operand input, Operand dim, - DataType dType) { - return Shapes.size(scope, input, dim, dType); + Class type) { + return Shapes.size(scope, input, dim, type); } /** @@ -354,11 +352,11 @@ public Operand size(Operand input, Op * @param scope current scope * @param shape the TensorFlow shape * @param dim the dimension - * @param dType the shape datatype + * @param type the shape datatype * @return the size of the specified dimension */ - public Operand size(Shape shape, Operand dim, DataType dType) { - return Shapes.size(scope, shape, dim, dType); + public Operand size(Shape shape, Operand dim, Class type) { + return Shapes.size(scope, shape, dim, type); } /** @@ -378,11 +376,11 @@ public Operand squeeze(Shape shape) { * @param the shape datatype. * @param scope current scope * @param shape the TensorFlow shape - * @param dType the shape datatype. + * @param type the shape datatype. * @return the squeezed shape */ - public Operand squeeze(Shape shape, DataType dType) { - return Shapes.squeeze(scope, shape, dType); + public Operand squeeze(Shape shape, Class type) { + return Shapes.squeeze(scope, shape, type); } /** @@ -404,13 +402,13 @@ public Operand tail(Shape shape) { * * @param scope current scope * @param shape the TensorFlow shape - * @param dType the shape datatype. + * @param type the shape datatype. * @param the shape datatype. * @return a 1-dimensional Operand that contains the dimension matching the last dimension of the * Shape */ - public Operand tail(Shape shape, DataType dType) { - return Shapes.tail(scope, shape, dType); + public Operand tail(Shape shape, Class type) { + return Shapes.tail(scope, shape, type); } /** @@ -434,13 +432,13 @@ public Operand take(Shape shape, Operand n) { * @param scope current scope * @param shape the TensorFlow shape * @param n the number of leading dimensions to get, must be <= than the shape's numDimensions() - * @param dType the shape datatype. + * @param type the shape datatype. * @param the shape datatype. * @return a 1-dimensional operand with the dimensions matching * the first n dimensions of the * shape */ - public Operand take(Shape shape, Operand n, DataType dType) { - return Shapes.take(scope, shape, n, dType); + public Operand take(Shape shape, Operand n, Class type) { + return Shapes.take(scope, shape, n, type); } /** @@ -464,13 +462,13 @@ public Operand takeLast(Shape shape, Operand * @param scope current scope * @param shape the TensorFlow shape * @param n the number of leading dimensions to get, must be <= than the shape's numDimensions() - * @param dType the shape datatype. + * @param type the shape datatype. * @param the shape datatype. * @return a 1-dimensional operand containing the dimensions matching the last n dimensions of the * shape */ - public Operand takeLast(Shape shape, Operand n, DataType dType) { - return Shapes.takeLast(scope, shape, n, dType); + public Operand takeLast(Shape shape, Operand n, Class type) { + return Shapes.takeLast(scope, shape, n, type); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SignalOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SignalOps.java index b71fb5cce13..e8ac9a0e53b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SignalOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SignalOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.signal.BatchFft; import org.tensorflow.op.signal.BatchFft2d; @@ -245,7 +244,7 @@ public Irfft irfft(Operand input, Operand * @return a new instance of Irfft */ public Irfft irfft(Operand input, - Operand fftLength, DataType Treal) { + Operand fftLength, Class Treal) { return Irfft.create(scope, input, fftLength, Treal); } @@ -301,7 +300,7 @@ public Irfft2d irfft2d(Operand input, Operand Irfft2d irfft2d(Operand input, - Operand fftLength, DataType Treal) { + Operand fftLength, Class Treal) { return Irfft2d.create(scope, input, fftLength, Treal); } @@ -357,7 +356,7 @@ public Irfft3d irfft3d(Operand input, Operand Irfft3d irfft3d(Operand input, - Operand fftLength, DataType Treal) { + Operand fftLength, Class Treal) { return Irfft3d.create(scope, input, fftLength, Treal); } @@ -382,7 +381,7 @@ public Irfft3d irfft3d(Operand input, * @return a new instance of Rfft */ public Rfft rfft(Operand input, - Operand fftLength, DataType Tcomplex) { + Operand fftLength, Class Tcomplex) { return Rfft.create(scope, input, fftLength, Tcomplex); } @@ -408,7 +407,7 @@ public Rfft rfft(Operand input, * @return a new instance of Rfft2d */ public Rfft2d rfft2d(Operand input, - Operand fftLength, DataType Tcomplex) { + Operand fftLength, Class Tcomplex) { return Rfft2d.create(scope, input, fftLength, Tcomplex); } @@ -434,7 +433,7 @@ public Rfft2d rfft2d(Operand input, * @return a new instance of Rfft3d */ public Rfft3d rfft3d(Operand input, - Operand fftLength, DataType Tcomplex) { + Operand fftLength, Class Tcomplex) { return Rfft3d.create(scope, input, fftLength, Tcomplex); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SparseOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SparseOps.java index 4bd61718ecc..3971fc6fc06 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SparseOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/SparseOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.sparse.AddManySparseToTensorsMap; @@ -272,7 +271,7 @@ public DenseToSparseSetOperation denseToSparseSetOperation( * @return a new instance of DeserializeSparse */ public DeserializeSparse deserializeSparse( - Operand serializedSparse, DataType dtype) { + Operand serializedSparse, Class dtype) { return DeserializeSparse.create(scope, serializedSparse, dtype); } @@ -318,7 +317,7 @@ public SparseAccumulatorApplyGradient sparseAccumulatorApplyGr * @return a new instance of SparseAccumulatorTakeGradient */ public SparseAccumulatorTakeGradient sparseAccumulatorTakeGradient( - Operand handle, Operand numRequired, DataType dtype) { + Operand handle, Operand numRequired, Class dtype) { return SparseAccumulatorTakeGradient.create(scope, handle, numRequired, dtype); } @@ -479,8 +478,8 @@ public SparseConcat sparseConcat(Iterable> * @param options carries optional attributes values * @return a new instance of SparseConditionalAccumulator */ - public SparseConditionalAccumulator sparseConditionalAccumulator( - DataType dtype, Shape shape, SparseConditionalAccumulator.Options... options) { + public SparseConditionalAccumulator sparseConditionalAccumulator(Class dtype, + Shape shape, SparseConditionalAccumulator.Options... options) { return SparseConditionalAccumulator.create(scope, dtype, shape, options); } @@ -1496,7 +1495,7 @@ public SparseToSparseSetOperation sparseToSparseSetOperatio * @return a new instance of TakeManySparseFromTensorsMap */ public TakeManySparseFromTensorsMap takeManySparseFromTensorsMap( - Operand sparseHandles, DataType dtype, + Operand sparseHandles, Class dtype, TakeManySparseFromTensorsMap.Options... options) { return TakeManySparseFromTensorsMap.create(scope, sparseHandles, dtype, options); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/StringsOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/StringsOps.java index d1f2940bfa2..6d380c31bc6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/StringsOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/StringsOps.java @@ -18,7 +18,6 @@ package org.tensorflow.op; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.strings.Join; import org.tensorflow.op.strings.Lower; @@ -486,8 +485,7 @@ public ToNumber toNumber(Operand stringTensor) { * @param outType The numeric type to interpret each string in `string_tensor` as. * @return a new instance of ToNumber */ - public ToNumber toNumber(Operand stringTensor, - DataType outType) { + public ToNumber toNumber(Operand stringTensor, Class outType) { return ToNumber.create(scope, stringTensor, outType); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TrainOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TrainOps.java index d45e2c85425..d21b5e037d3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TrainOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TrainOps.java @@ -18,7 +18,6 @@ package org.tensorflow.op; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.train.AccumulatorApplyGradient; @@ -162,7 +161,7 @@ public AccumulatorSetGlobalStep accumulatorSetGlobalStep(Operand handle * @return a new instance of AccumulatorTakeGradient */ public AccumulatorTakeGradient accumulatorTakeGradient( - Operand handle, Operand numRequired, DataType dtype) { + Operand handle, Operand numRequired, Class dtype) { return AccumulatorTakeGradient.create(scope, handle, numRequired, dtype); } @@ -545,7 +544,7 @@ public BatchMatMul batchMatMul(Operand x, Operand y, * @param options carries optional attributes values * @return a new instance of ConditionalAccumulator */ - public ConditionalAccumulator conditionalAccumulator(DataType dtype, + public ConditionalAccumulator conditionalAccumulator(Class dtype, Shape shape, ConditionalAccumulator.Options... options) { return ConditionalAccumulator.create(scope, dtype, shape, options); } @@ -1298,7 +1297,7 @@ public ResourceSparseApplyRmsProp resourceS * @return a new instance of Restore */ public Restore restore(Operand prefix, Operand tensorNames, - Operand shapeAndSlices, List> dtypes) { + Operand shapeAndSlices, List> dtypes) { return Restore.create(scope, prefix, tensorNames, shapeAndSlices, dtypes); } @@ -1324,7 +1323,7 @@ public Restore restore(Operand prefix, Operand tensorNames, * @return a new instance of RestoreSlice */ public RestoreSlice restoreSlice(Operand filePattern, - Operand tensorName, Operand shapeAndSlice, DataType dt, + Operand tensorName, Operand shapeAndSlice, Class dt, RestoreSlice.Options... options) { return RestoreSlice.create(scope, filePattern, tensorName, shapeAndSlice, dt, options); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java index 585e721733a..4c8df665739 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java @@ -17,7 +17,6 @@ // package org.tensorflow.op; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.xla.BroadcastHelper; @@ -281,7 +280,7 @@ public Pad pad(Operand input, Operand * @param shape The shape of the tensor. * @return a new instance of Recv */ - public Recv recv(DataType dtype, String tensorName, Shape shape) { + public Recv recv(Class dtype, String tensorName, Shape shape) { return Recv.create(scope, dtype, tensorName, shape); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseAnd.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseAnd.java index 692eaa591aa..86fad697878 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseAnd.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseAnd.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Elementwise computes the bitwise AND of `x` and `y`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseOr.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseOr.java index 3d7837ec98a..cea0b766cfe 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseOr.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseOr.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Elementwise computes the bitwise OR of `x` and `y`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseXor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseXor.java index af7de8fc140..f209732bb5c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseXor.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseXor.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Elementwise computes the bitwise XOR of `x` and `y`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/Invert.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/Invert.java index fe92038487b..4f4e063ed50 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/Invert.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/Invert.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Invert (flip) each bit of supported types; for example, type `uint8` value 01010101 becomes 10101010. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/LeftShift.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/LeftShift.java index c85fa62863a..98ca43ae889 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/LeftShift.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/LeftShift.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Elementwise computes the bitwise left-shift of `x` and `y`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/RightShift.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/RightShift.java index 80bb7418680..4d08a05bf1a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/RightShift.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/RightShift.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Elementwise computes the bitwise right-shift of `x` and `y`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/AllReduce.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/AllReduce.java index 3301d27fcd8..adaafc5fc90 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/AllReduce.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/AllReduce.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Mutually reduces multiple tensors of identical type and shape. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/BroadcastRecv.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/BroadcastRecv.java index 4b22670cb9d..bce981a1f6f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/BroadcastRecv.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/BroadcastRecv.java @@ -17,12 +17,12 @@ package org.tensorflow.op.collective; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -77,10 +77,10 @@ private Options() { * @return a new instance of BroadcastRecv */ @Endpoint(describeByClass = true) - public static BroadcastRecv create(Scope scope, DataType T, Long groupSize, Long groupKey, Long instanceKey, Shape shape, Options... options) { + public static BroadcastRecv create(Scope scope, Class T, Long groupSize, Long groupKey, Long instanceKey, Shape shape, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("CollectiveBcastRecv", scope.makeOpName("BroadcastRecv")); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("T", T); + opBuilder.setAttr("T", Operands.toDataType(T)); opBuilder.setAttr("group_size", groupSize); opBuilder.setAttr("group_key", groupKey); opBuilder.setAttr("instance_key", instanceKey); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/All.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/All.java index 231336110a5..17b3a3bf0c3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/All.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/All.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the "logical and" of elements across dimensions of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Any.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Any.java index 031cfdb03b3..85a60f6b2e0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Any.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Any.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the "logical or" of elements across dimensions of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Barrier.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Barrier.java index 6bdda78f84d..b429cce3084 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Barrier.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Barrier.java @@ -18,17 +18,18 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * Defines a barrier that persists across different graph executions. @@ -105,14 +106,10 @@ private Options() { * @return a new instance of Barrier */ @Endpoint(describeByClass = true) - public static Barrier create(Scope scope, List> componentTypes, Options... options) { + public static Barrier create(Scope scope, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Barrier", scope.makeOpName("Barrier")); opBuilder = scope.apply(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; - for (int i = 0; i < componentTypesArray.length; ++i) { - componentTypesArray[i] = componentTypes.get(i); - } - opBuilder.setAttr("component_types", componentTypesArray); + opBuilder.setAttr("component_types", Operands.toDataTypes(componentTypes)); if (options != null) { for (Options opts : options) { if (opts.shapes != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierTakeMany.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierTakeMany.java index 1473dce90b6..4bf16ef62e0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierTakeMany.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierTakeMany.java @@ -19,11 +19,11 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -31,6 +31,7 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * Takes the given number of completed elements from a barrier. @@ -98,16 +99,12 @@ private Options() { * @return a new instance of BarrierTakeMany */ @Endpoint(describeByClass = true) - public static BarrierTakeMany create(Scope scope, Operand handle, Operand numElements, List> componentTypes, Options... options) { + public static BarrierTakeMany create(Scope scope, Operand handle, Operand numElements, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("BarrierTakeMany", scope.makeOpName("BarrierTakeMany")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(numElements.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; - for (int i = 0; i < componentTypesArray.length; ++i) { - componentTypesArray[i] = componentTypes.get(i); - } - opBuilder.setAttr("component_types", componentTypesArray); + opBuilder.setAttr("component_types", Operands.toDataTypes(componentTypes)); if (options != null) { for (Options opts : options) { if (opts.allowSmallBatch != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bitcast.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bitcast.java index f4f74845369..04ea83cf3be 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bitcast.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bitcast.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -96,11 +96,11 @@ public final class Bitcast extends RawOp implements Operand * @return a new instance of Bitcast */ @Endpoint(describeByClass = true) - public static Bitcast create(Scope scope, Operand input, DataType type) { + public static Bitcast create(Scope scope, Operand input, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("Bitcast", scope.makeOpName("Bitcast")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("type", type); + opBuilder.setAttr("type", Operands.toDataType(type)); return new Bitcast(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastDynamicShape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastDynamicShape.java index 346ff14c7ec..91611ab6888 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastDynamicShape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastDynamicShape.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Return the shape of s0 op s1 with broadcast. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastGradientArgs.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastGradientArgs.java index ca291872f8d..8aa19f97259 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastGradientArgs.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BroadcastGradientArgs.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Return the reduction indices for computing gradients of s0 op s1 with broadcast. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bucketize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bucketize.java index ce79347d467..87197f19c3d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bucketize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Bucketize.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Bucketizes 'input' based on 'boundaries'. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CollectiveGather.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CollectiveGather.java index b84c8833780..2e1de3898fa 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CollectiveGather.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CollectiveGather.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Mutually accumulates multiple tensors of identical type and shape. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CountUpTo.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CountUpTo.java index c3f2dd9870f..5d518629b78 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CountUpTo.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/CountUpTo.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Increments 'ref' until it reaches 'limit'. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DecodeProto.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DecodeProto.java index 127a5bffe8c..1b8373c8e6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DecodeProto.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DecodeProto.java @@ -19,17 +19,18 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * The op extracts fields from a serialized protocol buffers message into tensors. @@ -135,7 +136,7 @@ private Options() { * @return a new instance of DecodeProto */ @Endpoint(describeByClass = true) - public static DecodeProto create(Scope scope, Operand bytes, String messageType, List fieldNames, List> outputTypes, Options... options) { + public static DecodeProto create(Scope scope, Operand bytes, String messageType, List fieldNames, List> outputTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("DecodeProtoV2", scope.makeOpName("DecodeProto")); opBuilder.addInput(bytes.asOutput()); opBuilder = scope.apply(opBuilder); @@ -145,11 +146,7 @@ public static DecodeProto create(Scope scope, Operand bytes, String mes fieldNamesArray[i] = fieldNames.get(i); } opBuilder.setAttr("field_names", fieldNamesArray); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); if (options != null) { for (Options opts : options) { if (opts.descriptorSource != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Empty.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Empty.java index 64e18b3e566..7e7140b9015 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Empty.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Empty.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -68,11 +68,11 @@ private Options() { * @return a new instance of Empty */ @Endpoint(describeByClass = true) - public static Empty create(Scope scope, Operand shape, DataType dtype, Options... options) { + public static Empty create(Scope scope, Operand shape, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Empty", scope.makeOpName("Empty")); opBuilder.addInput(shape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.init != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/EmptyTensorList.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/EmptyTensorList.java index 37fcfd5cd93..e5ec261e5b8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/EmptyTensorList.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/EmptyTensorList.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -53,12 +53,12 @@ public final class EmptyTensorList extends RawOp implements Operand { * @return a new instance of EmptyTensorList */ @Endpoint(describeByClass = true) - public static EmptyTensorList create(Scope scope, Operand elementShape, Operand maxNumElements, DataType elementDtype) { + public static EmptyTensorList create(Scope scope, Operand elementShape, Operand maxNumElements, Class elementDtype) { OperationBuilder opBuilder = scope.env().opBuilder("EmptyTensorList", scope.makeOpName("EmptyTensorList")); opBuilder.addInput(elementShape.asOutput()); opBuilder.addInput(maxNumElements.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("element_dtype", elementDtype); + opBuilder.setAttr("element_dtype", Operands.toDataType(elementDtype)); return new EmptyTensorList(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ExtractVolumePatches.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ExtractVolumePatches.java index f90febe4c1d..b52cc4b3a77 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ExtractVolumePatches.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ExtractVolumePatches.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Extract `patches` from `input` and put them in the "depth" output dimension. 3D extension of `extract_image_patches`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GetSessionTensor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GetSessionTensor.java index ba2e22c337d..8a04f173ef8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GetSessionTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GetSessionTensor.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -46,11 +46,11 @@ public final class GetSessionTensor extends RawOp implements Op * @return a new instance of GetSessionTensor */ @Endpoint(describeByClass = true) - public static GetSessionTensor create(Scope scope, Operand handle, DataType dtype) { + public static GetSessionTensor create(Scope scope, Operand handle, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("GetSessionTensor", scope.makeOpName("GetSessionTensor")); opBuilder.addInput(handle.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new GetSessionTensor(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HashTable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HashTable.java index 7ecb8439cd2..75f53539f84 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HashTable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HashTable.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -88,11 +88,11 @@ private Options() { * @return a new instance of HashTable */ @Endpoint(describeByClass = true) - public static HashTable create(Scope scope, DataType keyDtype, DataType valueDtype, Options... options) { + public static HashTable create(Scope scope, Class keyDtype, Class valueDtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("HashTableV2", scope.makeOpName("HashTable")); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("key_dtype", keyDtype); - opBuilder.setAttr("value_dtype", valueDtype); + opBuilder.setAttr("key_dtype", Operands.toDataType(keyDtype)); + opBuilder.setAttr("value_dtype", Operands.toDataType(valueDtype)); if (options != null) { for (Options opts : options) { if (opts.container != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HistogramFixedWidth.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HistogramFixedWidth.java index fbdedfb8d5c..86e6cce541a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HistogramFixedWidth.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/HistogramFixedWidth.java @@ -17,18 +17,17 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Return histogram of values. @@ -67,13 +66,13 @@ public final class HistogramFixedWidth extends RawOp implemen * @return a new instance of HistogramFixedWidth */ @Endpoint(describeByClass = true) - public static HistogramFixedWidth create(Scope scope, Operand values, Operand valueRange, Operand nbins, DataType dtype) { + public static HistogramFixedWidth create(Scope scope, Operand values, Operand valueRange, Operand nbins, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("HistogramFixedWidth", scope.makeOpName("HistogramFixedWidth")); opBuilder.addInput(values.asOutput()); opBuilder.addInput(valueRange.asOutput()); opBuilder.addInput(nbins.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new HistogramFixedWidth(opBuilder.build()); } @@ -90,7 +89,7 @@ public static HistogramFixedWidth crea */ @Endpoint(describeByClass = true) public static HistogramFixedWidth create(Scope scope, Operand values, Operand valueRange, Operand nbins) { - return create(scope, values, valueRange, nbins, TInt32.DTYPE); + return create(scope, values, valueRange, nbins, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ImmutableConst.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ImmutableConst.java index 404c4672776..319304aff8b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ImmutableConst.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ImmutableConst.java @@ -17,12 +17,12 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -50,10 +50,10 @@ public final class ImmutableConst extends RawOp implements Oper * @return a new instance of ImmutableConst */ @Endpoint(describeByClass = true) - public static ImmutableConst create(Scope scope, DataType dtype, Shape shape, String memoryRegionName) { + public static ImmutableConst create(Scope scope, Class dtype, Shape shape, String memoryRegionName) { OperationBuilder opBuilder = scope.env().opBuilder("ImmutableConst", scope.makeOpName("ImmutableConst")); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); opBuilder.setAttr("shape", shape); opBuilder.setAttr("memory_region_name", memoryRegionName); return new ImmutableConst(opBuilder.build()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LinSpace.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LinSpace.java index c0694e4165a..29662e643f8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LinSpace.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LinSpace.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Generates values in an interval. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableExport.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableExport.java index 952bf8e9994..7685fa5e7b4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableExport.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableExport.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,12 +47,12 @@ public final class LookupTableExport extends R * @return a new instance of LookupTableExport */ @Endpoint(describeByClass = true) - public static LookupTableExport create(Scope scope, Operand tableHandle, DataType Tkeys, DataType Tvalues) { + public static LookupTableExport create(Scope scope, Operand tableHandle, Class Tkeys, Class Tvalues) { OperationBuilder opBuilder = scope.env().opBuilder("LookupTableExportV2", scope.makeOpName("LookupTableExport")); opBuilder.addInput(tableHandle.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Tkeys", Tkeys); - opBuilder.setAttr("Tvalues", Tvalues); + opBuilder.setAttr("Tkeys", Operands.toDataType(Tkeys)); + opBuilder.setAttr("Tvalues", Operands.toDataType(Tvalues)); return new LookupTableExport(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LowerBound.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LowerBound.java index f2671d23aae..1aa8badb5a1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LowerBound.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LowerBound.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -66,12 +66,12 @@ public final class LowerBound extends RawOp implements Operan * @return a new instance of LowerBound */ @Endpoint(describeByClass = true) - public static LowerBound create(Scope scope, Operand sortedInputs, Operand values, DataType outType) { + public static LowerBound create(Scope scope, Operand sortedInputs, Operand values, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("LowerBound", scope.makeOpName("LowerBound")); opBuilder.addInput(sortedInputs.asOutput()); opBuilder.addInput(values.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new LowerBound(opBuilder.build()); } @@ -86,7 +86,7 @@ public static LowerBound create(Scope sc */ @Endpoint(describeByClass = true) public static LowerBound create(Scope scope, Operand sortedInputs, Operand values) { - return create(scope, sortedInputs, values, TInt32.DTYPE); + return create(scope, sortedInputs, values, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java index 5b3c43ee0d8..e680e7db08a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java @@ -18,13 +18,14 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; /** * Op removes all elements in the underlying container. @@ -87,14 +88,10 @@ private Options() { * @return a new instance of MapClear */ @Endpoint(describeByClass = true) - public static MapClear create(Scope scope, List> dtypes, Options... options) { + public static MapClear create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MapClear", scope.makeOpName("MapClear")); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapIncompleteSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapIncompleteSize.java index 016ec0f023c..6fa921bb408 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapIncompleteSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapIncompleteSize.java @@ -18,16 +18,17 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; /** * Op returns the number of incomplete elements in the underlying container. @@ -90,14 +91,10 @@ private Options() { * @return a new instance of MapIncompleteSize */ @Endpoint(describeByClass = true) - public static MapIncompleteSize create(Scope scope, List> dtypes, Options... options) { + public static MapIncompleteSize create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MapIncompleteSize", scope.makeOpName("MapIncompleteSize")); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapPeek.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapPeek.java index 15e941ec891..316bc08c64c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapPeek.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapPeek.java @@ -20,11 +20,11 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -99,16 +99,12 @@ private Options() { * @return a new instance of MapPeek */ @Endpoint(describeByClass = true) - public static MapPeek create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { + public static MapPeek create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MapPeek", scope.makeOpName("MapPeek")); opBuilder.addInput(key.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapSize.java index f6e99dc5ee4..0cd9510a8f7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapSize.java @@ -18,16 +18,17 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; /** * Op returns the number of elements in the underlying container. @@ -90,14 +91,10 @@ private Options() { * @return a new instance of MapSize */ @Endpoint(describeByClass = true) - public static MapSize create(Scope scope, List> dtypes, Options... options) { + public static MapSize create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MapSize", scope.makeOpName("MapSize")); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java index 2e5c7d0ea4f..76f9086f46e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java @@ -18,7 +18,6 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -29,6 +28,7 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; /** * Stage (key, values) in the underlying container which behaves like a hashtable. @@ -97,17 +97,13 @@ private Options() { * @return a new instance of MapStage */ @Endpoint(describeByClass = true) - public static MapStage create(Scope scope, Operand key, Operand indices, Iterable> values, List> dtypes, Options... options) { + public static MapStage create(Scope scope, Operand key, Operand indices, Iterable> values, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MapStage", scope.makeOpName("MapStage")); opBuilder.addInput(key.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder.addInputList(Operands.asOutputs(values)); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstage.java index e6cab844ab0..6d189a50d7b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstage.java @@ -20,11 +20,11 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -99,16 +99,12 @@ private Options() { * @return a new instance of MapUnstage */ @Endpoint(describeByClass = true) - public static MapUnstage create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { + public static MapUnstage create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MapUnstage", scope.makeOpName("MapUnstage")); opBuilder.addInput(key.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstageNoKey.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstageNoKey.java index a523400c34e..9848ab2d845 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstageNoKey.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapUnstageNoKey.java @@ -19,17 +19,18 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; /** * Op removes and returns a random (key, value) @@ -96,15 +97,11 @@ private Options() { * @return a new instance of MapUnstageNoKey */ @Endpoint(describeByClass = true) - public static MapUnstageNoKey create(Scope scope, Operand indices, List> dtypes, Options... options) { + public static MapUnstageNoKey create(Scope scope, Operand indices, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MapUnstageNoKey", scope.makeOpName("MapUnstageNoKey")); opBuilder.addInput(indices.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MlirPassthroughOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MlirPassthroughOp.java index 1596cf10e84..e80af4ecee8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MlirPassthroughOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MlirPassthroughOp.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -78,16 +77,12 @@ public final class MlirPassthroughOp extends RawOp implements Iterable> inputs, String mlirModule, List> Toutputs) { + public static MlirPassthroughOp create(Scope scope, Iterable> inputs, String mlirModule, List> Toutputs) { OperationBuilder opBuilder = scope.env().opBuilder("MlirPassthroughOp", scope.makeOpName("MlirPassthroughOp")); opBuilder.addInputList(Operands.asOutputs(inputs)); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("mlir_module", mlirModule); - DataType[] ToutputsArray = new DataType[Toutputs.size()]; - for (int i = 0; i < ToutputsArray.length; ++i) { - ToutputsArray[i] = Toutputs.get(i); - } - opBuilder.setAttr("Toutputs", ToutputsArray); + opBuilder.setAttr("Toutputs", Operands.toDataTypes(Toutputs)); return new MlirPassthroughOp(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableDenseHashTable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableDenseHashTable.java index dc73ae16585..28a85dc0082 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableDenseHashTable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableDenseHashTable.java @@ -17,12 +17,12 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -122,12 +122,12 @@ private Options() { * @return a new instance of MutableDenseHashTable */ @Endpoint(describeByClass = true) - public static MutableDenseHashTable create(Scope scope, Operand emptyKey, Operand deletedKey, DataType valueDtype, Options... options) { + public static MutableDenseHashTable create(Scope scope, Operand emptyKey, Operand deletedKey, Class valueDtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MutableDenseHashTableV2", scope.makeOpName("MutableDenseHashTable")); opBuilder.addInput(emptyKey.asOutput()); opBuilder.addInput(deletedKey.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("value_dtype", valueDtype); + opBuilder.setAttr("value_dtype", Operands.toDataType(valueDtype)); if (options != null) { for (Options opts : options) { if (opts.container != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTable.java index 469742f61fa..1a551a3d77f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTable.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -88,11 +88,11 @@ private Options() { * @return a new instance of MutableHashTable */ @Endpoint(describeByClass = true) - public static MutableHashTable create(Scope scope, DataType keyDtype, DataType valueDtype, Options... options) { + public static MutableHashTable create(Scope scope, Class keyDtype, Class valueDtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MutableHashTableV2", scope.makeOpName("MutableHashTable")); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("key_dtype", keyDtype); - opBuilder.setAttr("value_dtype", valueDtype); + opBuilder.setAttr("key_dtype", Operands.toDataType(keyDtype)); + opBuilder.setAttr("value_dtype", Operands.toDataType(valueDtype)); if (options != null) { for (Options opts : options) { if (opts.container != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTableOfTensors.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTableOfTensors.java index d15b924f034..c054b60ad0b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTableOfTensors.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MutableHashTableOfTensors.java @@ -17,12 +17,12 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -97,11 +97,11 @@ private Options() { * @return a new instance of MutableHashTableOfTensors */ @Endpoint(describeByClass = true) - public static MutableHashTableOfTensors create(Scope scope, DataType keyDtype, DataType valueDtype, Options... options) { + public static MutableHashTableOfTensors create(Scope scope, Class keyDtype, Class valueDtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MutableHashTableOfTensorsV2", scope.makeOpName("MutableHashTableOfTensors")); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("key_dtype", keyDtype); - opBuilder.setAttr("value_dtype", valueDtype); + opBuilder.setAttr("key_dtype", Operands.toDataType(keyDtype)); + opBuilder.setAttr("value_dtype", Operands.toDataType(valueDtype)); if (options != null) { for (Options opts : options) { if (opts.container != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclAllReduce.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclAllReduce.java index 1e2e0b440ff..8d891145adb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclAllReduce.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclAllReduce.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs a tensor containing the reduction across all input tensors. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclBroadcast.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclBroadcast.java index 7b5dd9d98b7..8febd22137b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclBroadcast.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclBroadcast.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Sends `input` to all devices that are connected to the output. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclReduce.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclReduce.java index a88f18ebe5f..ceb8480274c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclReduce.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NcclReduce.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Reduces `input` from `num_devices` using `reduction` to a single device. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java index 14c7aef5aac..dd789bcec82 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java @@ -18,13 +18,14 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; /** * Op removes all elements in the underlying container. @@ -87,14 +88,10 @@ private Options() { * @return a new instance of OrderedMapClear */ @Endpoint(describeByClass = true) - public static OrderedMapClear create(Scope scope, List> dtypes, Options... options) { + public static OrderedMapClear create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OrderedMapClear", scope.makeOpName("OrderedMapClear")); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapIncompleteSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapIncompleteSize.java index 06c0e20af69..56d520b02ff 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapIncompleteSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapIncompleteSize.java @@ -18,16 +18,17 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; /** * Op returns the number of incomplete elements in the underlying container. @@ -90,14 +91,10 @@ private Options() { * @return a new instance of OrderedMapIncompleteSize */ @Endpoint(describeByClass = true) - public static OrderedMapIncompleteSize create(Scope scope, List> dtypes, Options... options) { + public static OrderedMapIncompleteSize create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OrderedMapIncompleteSize", scope.makeOpName("OrderedMapIncompleteSize")); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapPeek.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapPeek.java index 91af29049bf..893be796c79 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapPeek.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapPeek.java @@ -20,11 +20,11 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -100,16 +100,12 @@ private Options() { * @return a new instance of OrderedMapPeek */ @Endpoint(describeByClass = true) - public static OrderedMapPeek create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { + public static OrderedMapPeek create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OrderedMapPeek", scope.makeOpName("OrderedMapPeek")); opBuilder.addInput(key.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapSize.java index 1a995fbd947..3c561660900 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapSize.java @@ -18,16 +18,17 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; /** * Op returns the number of elements in the underlying container. @@ -90,14 +91,10 @@ private Options() { * @return a new instance of OrderedMapSize */ @Endpoint(describeByClass = true) - public static OrderedMapSize create(Scope scope, List> dtypes, Options... options) { + public static OrderedMapSize create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OrderedMapSize", scope.makeOpName("OrderedMapSize")); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java index aee7309dbd4..a78c20a9623 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java @@ -18,7 +18,6 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -29,6 +28,7 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; /** * Stage (key, values) in the underlying container which behaves like a ordered @@ -99,17 +99,13 @@ private Options() { * @return a new instance of OrderedMapStage */ @Endpoint(describeByClass = true) - public static OrderedMapStage create(Scope scope, Operand key, Operand indices, Iterable> values, List> dtypes, Options... options) { + public static OrderedMapStage create(Scope scope, Operand key, Operand indices, Iterable> values, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OrderedMapStage", scope.makeOpName("OrderedMapStage")); opBuilder.addInput(key.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder.addInputList(Operands.asOutputs(values)); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstage.java index 5666b321065..667f0f198fb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstage.java @@ -20,11 +20,11 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -99,16 +99,12 @@ private Options() { * @return a new instance of OrderedMapUnstage */ @Endpoint(describeByClass = true) - public static OrderedMapUnstage create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { + public static OrderedMapUnstage create(Scope scope, Operand key, Operand indices, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OrderedMapUnstage", scope.makeOpName("OrderedMapUnstage")); opBuilder.addInput(key.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstageNoKey.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstageNoKey.java index 4f28b783451..fb0d239d6a2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstageNoKey.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapUnstageNoKey.java @@ -19,17 +19,18 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; /** * Op removes and returns the (key, value) element with the smallest @@ -96,15 +97,11 @@ private Options() { * @return a new instance of OrderedMapUnstageNoKey */ @Endpoint(describeByClass = true) - public static OrderedMapUnstageNoKey create(Scope scope, Operand indices, List> dtypes, Options... options) { + public static OrderedMapUnstageNoKey create(Scope scope, Operand indices, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OrderedMapUnstageNoKey", scope.makeOpName("OrderedMapUnstageNoKey")); opBuilder.addInput(indices.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Placeholder.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Placeholder.java index f173c27a36c..a510d2d5cb7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Placeholder.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Placeholder.java @@ -17,12 +17,12 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -70,10 +70,10 @@ private Options() { * @return a new instance of Placeholder */ @Endpoint(describeByClass = true) - public static Placeholder create(Scope scope, DataType dtype, Options... options) { + public static Placeholder create(Scope scope, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Placeholder", scope.makeOpName("Placeholder")); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.shape != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Range.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Range.java index 05e58a15da0..7f30e607c86 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Range.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Range.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Creates a sequence of numbers. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReadVariableOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReadVariableOp.java index ee6624f4fc8..1e73225cc89 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReadVariableOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReadVariableOp.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -52,11 +52,11 @@ public final class ReadVariableOp extends RawOp implements Oper * @return a new instance of ReadVariableOp */ @Endpoint(describeByClass = true) - public static ReadVariableOp create(Scope scope, Operand resource, DataType dtype) { + public static ReadVariableOp create(Scope scope, Operand resource, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("ReadVariableOp", scope.makeOpName("ReadVariableOp")); opBuilder.addInput(resource.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new ReadVariableOp(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Recv.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Recv.java index 148c8a5d1ce..ccae187f1ec 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Recv.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Recv.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -70,10 +70,10 @@ private Options() { * @return a new instance of Recv */ @Endpoint(describeByClass = true) - public static Recv create(Scope scope, DataType tensorType, String tensorName, String sendDevice, Long sendDeviceIncarnation, String recvDevice, Options... options) { + public static Recv create(Scope scope, Class tensorType, String tensorName, String sendDevice, Long sendDeviceIncarnation, String recvDevice, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Recv", scope.makeOpName("Recv")); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("tensor_type", tensorType); + opBuilder.setAttr("tensor_type", Operands.toDataType(tensorType)); opBuilder.setAttr("tensor_name", tensorName); opBuilder.setAttr("send_device", sendDevice); opBuilder.setAttr("send_device_incarnation", sendDeviceIncarnation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAll.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAll.java index 0766d3a0cf6..9c731309c44 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAll.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAll.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the "logical and" of elements across dimensions of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAny.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAny.java index b853c9c3df4..ce57bd4911b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAny.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ReduceAny.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the "logical or" of elements across dimensions of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteFusedGraphExecute.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteFusedGraphExecute.java index dd7cbdad550..b0dcdfc5398 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteFusedGraphExecute.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RemoteFusedGraphExecute.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -57,15 +56,11 @@ public final class RemoteFusedGraphExecute extends RawOp implements Iterable> inputs, List> Toutputs, String serializedRemoteFusedGraphExecuteInfo) { + public static RemoteFusedGraphExecute create(Scope scope, Iterable> inputs, List> Toutputs, String serializedRemoteFusedGraphExecuteInfo) { OperationBuilder opBuilder = scope.env().opBuilder("RemoteFusedGraphExecute", scope.makeOpName("RemoteFusedGraphExecute")); opBuilder.addInputList(Operands.asOutputs(inputs)); opBuilder = scope.apply(opBuilder); - DataType[] ToutputsArray = new DataType[Toutputs.size()]; - for (int i = 0; i < ToutputsArray.length; ++i) { - ToutputsArray[i] = Toutputs.get(i); - } - opBuilder.setAttr("Toutputs", ToutputsArray); + opBuilder.setAttr("Toutputs", Operands.toDataTypes(Toutputs)); opBuilder.setAttr("serialized_remote_fused_graph_execute_info", serializedRemoteFusedGraphExecuteInfo); return new RemoteFusedGraphExecute(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceCountUpTo.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceCountUpTo.java index 7600321885f..cbfb7ea15d2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceCountUpTo.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceCountUpTo.java @@ -17,17 +17,16 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Increments variable pointed to by 'resource' until it reaches 'limit'. @@ -48,12 +47,12 @@ public final class ResourceCountUpTo extends RawOp implements * @return a new instance of ResourceCountUpTo */ @Endpoint(describeByClass = true) - public static ResourceCountUpTo create(Scope scope, Operand resource, Long limit, DataType T) { + public static ResourceCountUpTo create(Scope scope, Operand resource, Long limit, Class T) { OperationBuilder opBuilder = scope.env().opBuilder("ResourceCountUpTo", scope.makeOpName("ResourceCountUpTo")); opBuilder.addInput(resource.asOutput()); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("limit", limit); - opBuilder.setAttr("T", T); + opBuilder.setAttr("T", Operands.toDataType(T)); return new ResourceCountUpTo(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGather.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGather.java index 8d29d6f9a6f..4830d44d3ba 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGather.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGather.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -90,12 +90,12 @@ private Options() { * @return a new instance of ResourceGather */ @Endpoint(describeByClass = true) - public static ResourceGather create(Scope scope, Operand resource, Operand indices, DataType dtype, Options... options) { + public static ResourceGather create(Scope scope, Operand resource, Operand indices, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ResourceGather", scope.makeOpName("ResourceGather")); opBuilder.addInput(resource.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.batchDims != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGatherNd.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGatherNd.java index ed90a270f83..d78b649a997 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGatherNd.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceGatherNd.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -45,12 +45,12 @@ public final class ResourceGatherNd extends RawOp implements Op * @return a new instance of ResourceGatherNd */ @Endpoint(describeByClass = true) - public static ResourceGatherNd create(Scope scope, Operand resource, Operand indices, DataType dtype) { + public static ResourceGatherNd create(Scope scope, Operand resource, Operand indices, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("ResourceGatherNd", scope.makeOpName("ResourceGatherNd")); opBuilder.addInput(resource.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new ResourceGatherNd(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMax.java index e96cd9db5fd..746a6d30a35 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMax.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Reduces sparse updates into a variable reference using the `max` operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMin.java index 596ae2b1882..4f854496bf3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterMin.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Reduces sparse updates into a variable reference using the `min` operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/SetDiff1d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/SetDiff1d.java index 4c5a51af3f4..515f413fca9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/SetDiff1d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/SetDiff1d.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -69,12 +69,12 @@ public final class SetDiff1d extends RawOp { * @return a new instance of SetDiff1d */ @Endpoint(describeByClass = true) - public static SetDiff1d create(Scope scope, Operand x, Operand y, DataType outIdx) { + public static SetDiff1d create(Scope scope, Operand x, Operand y, Class outIdx) { OperationBuilder opBuilder = scope.env().opBuilder("ListDiff", scope.makeOpName("SetDiff1d")); opBuilder.addInput(x.asOutput()); opBuilder.addInput(y.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_idx", outIdx); + opBuilder.setAttr("out_idx", Operands.toDataType(outIdx)); return new SetDiff1d(opBuilder.build()); } @@ -88,7 +88,7 @@ public static SetDiff1d create(Scope */ @Endpoint(describeByClass = true) public static SetDiff1d create(Scope scope, Operand x, Operand y) { - return create(scope, x, y, TInt32.DTYPE); + return create(scope, x, y, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Shape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Shape.java index 4fdefdae3a7..a2d8a43a496 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Shape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Shape.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -56,11 +56,11 @@ public final class Shape extends RawOp implements Operand * @return a new instance of Shape */ @Endpoint(describeByClass = true) - public static Shape create(Scope scope, Operand input, DataType outType) { + public static Shape create(Scope scope, Operand input, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("Shape", scope.makeOpName("Shape")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new Shape(opBuilder.build()); } @@ -73,7 +73,7 @@ public static Shape create(Scope scope, */ @Endpoint(describeByClass = true) public static Shape create(Scope scope, Operand input) { - return create(scope, input, TInt32.DTYPE); + return create(scope, input, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ShapeN.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ShapeN.java index c295cf810e2..b196e321962 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ShapeN.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ShapeN.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -53,11 +52,11 @@ public final class ShapeN extends RawOp implements Iterable ShapeN create(Scope scope, Iterable> input, DataType outType) { + public static ShapeN create(Scope scope, Iterable> input, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("ShapeN", scope.makeOpName("ShapeN")); opBuilder.addInputList(Operands.asOutputs(input)); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new ShapeN(opBuilder.build()); } @@ -70,7 +69,7 @@ public static ShapeN create(Scope scope, */ @Endpoint(describeByClass = true) public static ShapeN create(Scope scope, Iterable> input) { - return create(scope, input, TInt32.DTYPE); + return create(scope, input, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Size.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Size.java index fbc3f2928a8..127608e1fac 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Size.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Size.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -57,11 +57,11 @@ public final class Size extends RawOp implements Operand { * @return a new instance of Size */ @Endpoint(describeByClass = true) - public static Size create(Scope scope, Operand input, DataType outType) { + public static Size create(Scope scope, Operand input, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("Size", scope.makeOpName("Size")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new Size(opBuilder.build()); } @@ -74,7 +74,7 @@ public static Size create(Scope scope, O */ @Endpoint(describeByClass = true) public static Size create(Scope scope, Operand input) { - return create(scope, input, TInt32.DTYPE); + return create(scope, input, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java index 5bcf715d7a6..7be90432f6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java @@ -18,13 +18,14 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; /** * Op removes all elements in the underlying container. @@ -87,14 +88,10 @@ private Options() { * @return a new instance of StageClear */ @Endpoint(describeByClass = true) - public static StageClear create(Scope scope, List> dtypes, Options... options) { + public static StageClear create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("StageClear", scope.makeOpName("StageClear")); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StagePeek.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StagePeek.java index f200a074545..eb6e6dde91c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StagePeek.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StagePeek.java @@ -20,11 +20,11 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -98,15 +98,11 @@ private Options() { * @return a new instance of StagePeek */ @Endpoint(describeByClass = true) - public static StagePeek create(Scope scope, Operand index, List> dtypes, Options... options) { + public static StagePeek create(Scope scope, Operand index, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("StagePeek", scope.makeOpName("StagePeek")); opBuilder.addInput(index.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageSize.java index 09791a6332f..c70660c2071 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageSize.java @@ -18,16 +18,17 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; /** * Op returns the number of elements in the underlying container. @@ -90,14 +91,10 @@ private Options() { * @return a new instance of StageSize */ @Endpoint(describeByClass = true) - public static StageSize create(Scope scope, List> dtypes, Options... options) { + public static StageSize create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("StageSize", scope.makeOpName("StageSize")); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TemporaryVariable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TemporaryVariable.java index fad59c7168a..15f54863824 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TemporaryVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TemporaryVariable.java @@ -17,12 +17,12 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -82,11 +82,11 @@ private Options() { * @return a new instance of TemporaryVariable */ @Endpoint(describeByClass = true) - public static TemporaryVariable create(Scope scope, Shape shape, DataType dtype, Options... options) { + public static TemporaryVariable create(Scope scope, Shape shape, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TemporaryVariable", scope.makeOpName("TemporaryVariable")); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("shape", shape); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.varName != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArray.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArray.java index f8ab8d68d64..c162f3110f6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArray.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArray.java @@ -17,12 +17,12 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -116,11 +116,11 @@ private Options() { * @return a new instance of TensorArray */ @Endpoint(describeByClass = true) - public static TensorArray create(Scope scope, Operand size, DataType dtype, Options... options) { + public static TensorArray create(Scope scope, Operand size, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TensorArrayV3", scope.makeOpName("TensorArray")); opBuilder.addInput(size.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.elementShape != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayConcat.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayConcat.java index 64c24bd69d4..a861b19400e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayConcat.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayConcat.java @@ -17,12 +17,12 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -83,12 +83,12 @@ private Options() { * @return a new instance of TensorArrayConcat */ @Endpoint(describeByClass = true) - public static TensorArrayConcat create(Scope scope, Operand handle, Operand flowIn, DataType dtype, Options... options) { + public static TensorArrayConcat create(Scope scope, Operand handle, Operand flowIn, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TensorArrayConcatV3", scope.makeOpName("TensorArrayConcat")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(flowIn.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.elementShapeExcept0 != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayGather.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayGather.java index 76f577f4e01..8f4bcea319e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayGather.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayGather.java @@ -17,12 +17,12 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -74,13 +74,13 @@ private Options() { * @return a new instance of TensorArrayGather */ @Endpoint(describeByClass = true) - public static TensorArrayGather create(Scope scope, Operand handle, Operand indices, Operand flowIn, DataType dtype, Options... options) { + public static TensorArrayGather create(Scope scope, Operand handle, Operand indices, Operand flowIn, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TensorArrayGatherV3", scope.makeOpName("TensorArrayGather")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder.addInput(flowIn.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.elementShape != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayPack.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayPack.java index d9505d5c4a2..aa45df4b937 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayPack.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayPack.java @@ -17,12 +17,12 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -67,12 +67,12 @@ private Options() { * @return a new instance of TensorArrayPack */ @Endpoint(describeByClass = true) - public static TensorArrayPack create(Scope scope, Operand handle, Operand flowIn, DataType dtype, Options... options) { + public static TensorArrayPack create(Scope scope, Operand handle, Operand flowIn, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TensorArrayPack", scope.makeOpName("TensorArrayPack")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(flowIn.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.elementShape != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayRead.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayRead.java index 3e299a0c5d5..c076510e393 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayRead.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayRead.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -49,13 +49,13 @@ public final class TensorArrayRead extends RawOp implements Ope * @return a new instance of TensorArrayRead */ @Endpoint(describeByClass = true) - public static TensorArrayRead create(Scope scope, Operand handle, Operand index, Operand flowIn, DataType dtype) { + public static TensorArrayRead create(Scope scope, Operand handle, Operand index, Operand flowIn, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("TensorArrayReadV3", scope.makeOpName("TensorArrayRead")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(index.asOutput()); opBuilder.addInput(flowIn.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new TensorArrayRead(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcat.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcat.java index 1e1941053ac..685c9c1d900 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcat.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcat.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -62,13 +62,13 @@ public final class TensorListConcat extends RawOp { * @return a new instance of TensorListConcat */ @Endpoint(describeByClass = true) - public static TensorListConcat create(Scope scope, Operand inputHandle, Operand elementShape, Operand leadingDims, DataType elementDtype) { + public static TensorListConcat create(Scope scope, Operand inputHandle, Operand elementShape, Operand leadingDims, Class elementDtype) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListConcatV2", scope.makeOpName("TensorListConcat")); opBuilder.addInput(inputHandle.asOutput()); opBuilder.addInput(elementShape.asOutput()); opBuilder.addInput(leadingDims.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("element_dtype", elementDtype); + opBuilder.setAttr("element_dtype", Operands.toDataType(elementDtype)); return new TensorListConcat(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcatLists.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcatLists.java index 7554a301d16..25b2e06df84 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcatLists.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListConcatLists.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -43,12 +43,12 @@ public final class TensorListConcatLists extends RawOp implements Operand * @return a new instance of TensorListConcatLists */ @Endpoint(describeByClass = true) - public static TensorListConcatLists create(Scope scope, Operand inputA, Operand inputB, DataType elementDtype) { + public static TensorListConcatLists create(Scope scope, Operand inputA, Operand inputB, Class elementDtype) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListConcatLists", scope.makeOpName("TensorListConcatLists")); opBuilder.addInput(inputA.asOutput()); opBuilder.addInput(inputB.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("element_dtype", elementDtype); + opBuilder.setAttr("element_dtype", Operands.toDataType(elementDtype)); return new TensorListConcatLists(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListElementShape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListElementShape.java index d0511f3b96a..a985213de4e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListElementShape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListElementShape.java @@ -17,17 +17,16 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * The shape of the elements of the given list, as a tensor. @@ -49,11 +48,11 @@ public final class TensorListElementShape extends RawOp imple * @return a new instance of TensorListElementShape */ @Endpoint(describeByClass = true) - public static TensorListElementShape create(Scope scope, Operand inputHandle, DataType shapeType) { + public static TensorListElementShape create(Scope scope, Operand inputHandle, Class shapeType) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListElementShape", scope.makeOpName("TensorListElementShape")); opBuilder.addInput(inputHandle.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("shape_type", shapeType); + opBuilder.setAttr("shape_type", Operands.toDataType(shapeType)); return new TensorListElementShape(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGather.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGather.java index 2acb7aea4ae..90cacdd0435 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGather.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGather.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -55,13 +55,13 @@ public final class TensorListGather extends RawOp implements Op * @return a new instance of TensorListGather */ @Endpoint(describeByClass = true) - public static TensorListGather create(Scope scope, Operand inputHandle, Operand indices, Operand elementShape, DataType elementDtype) { + public static TensorListGather create(Scope scope, Operand inputHandle, Operand indices, Operand elementShape, Class elementDtype) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListGather", scope.makeOpName("TensorListGather")); opBuilder.addInput(inputHandle.asOutput()); opBuilder.addInput(indices.asOutput()); opBuilder.addInput(elementShape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("element_dtype", elementDtype); + opBuilder.setAttr("element_dtype", Operands.toDataType(elementDtype)); return new TensorListGather(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGetItem.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGetItem.java index af86e14fa92..7f96f897242 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGetItem.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListGetItem.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -46,13 +46,13 @@ public final class TensorListGetItem extends RawOp implements O * @return a new instance of TensorListGetItem */ @Endpoint(describeByClass = true) - public static TensorListGetItem create(Scope scope, Operand inputHandle, Operand index, Operand elementShape, DataType elementDtype) { + public static TensorListGetItem create(Scope scope, Operand inputHandle, Operand index, Operand elementShape, Class elementDtype) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListGetItem", scope.makeOpName("TensorListGetItem")); opBuilder.addInput(inputHandle.asOutput()); opBuilder.addInput(index.asOutput()); opBuilder.addInput(elementShape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("element_dtype", elementDtype); + opBuilder.setAttr("element_dtype", Operands.toDataType(elementDtype)); return new TensorListGetItem(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListPopBack.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListPopBack.java index 9e601b5733b..5600bf69778 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListPopBack.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListPopBack.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -54,12 +54,12 @@ public final class TensorListPopBack extends RawOp { * @return a new instance of TensorListPopBack */ @Endpoint(describeByClass = true) - public static TensorListPopBack create(Scope scope, Operand inputHandle, Operand elementShape, DataType elementDtype) { + public static TensorListPopBack create(Scope scope, Operand inputHandle, Operand elementShape, Class elementDtype) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListPopBack", scope.makeOpName("TensorListPopBack")); opBuilder.addInput(inputHandle.asOutput()); opBuilder.addInput(elementShape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("element_dtype", elementDtype); + opBuilder.setAttr("element_dtype", Operands.toDataType(elementDtype)); return new TensorListPopBack(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListReserve.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListReserve.java index 439bfe69886..71e4c2b2f06 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListReserve.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListReserve.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -51,12 +51,12 @@ public final class TensorListReserve extends RawOp implements Operand { * @return a new instance of TensorListReserve */ @Endpoint(describeByClass = true) - public static TensorListReserve create(Scope scope, Operand elementShape, Operand numElements, DataType elementDtype) { + public static TensorListReserve create(Scope scope, Operand elementShape, Operand numElements, Class elementDtype) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListReserve", scope.makeOpName("TensorListReserve")); opBuilder.addInput(elementShape.asOutput()); opBuilder.addInput(numElements.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("element_dtype", elementDtype); + opBuilder.setAttr("element_dtype", Operands.toDataType(elementDtype)); return new TensorListReserve(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListStack.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListStack.java index 5ce4e4242ea..20284f1e76e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListStack.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorListStack.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -74,12 +74,12 @@ private Options() { * @return a new instance of TensorListStack */ @Endpoint(describeByClass = true) - public static TensorListStack create(Scope scope, Operand inputHandle, Operand elementShape, DataType elementDtype, Options... options) { + public static TensorListStack create(Scope scope, Operand inputHandle, Operand elementShape, Class elementDtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TensorListStack", scope.makeOpName("TensorListStack")); opBuilder.addInput(inputHandle.asOutput()); opBuilder.addInput(elementShape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("element_dtype", elementDtype); + opBuilder.setAttr("element_dtype", Operands.toDataType(elementDtype)); if (options != null) { for (Options opts : options) { if (opts.numElements != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unique.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unique.java index 29703554467..40b303e2a05 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unique.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unique.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -90,12 +90,12 @@ public final class Unique extends RawOp { * @return a new instance of Unique */ @Endpoint(describeByClass = true) - public static Unique create(Scope scope, Operand x, Operand axis, DataType outIdx) { + public static Unique create(Scope scope, Operand x, Operand axis, Class outIdx) { OperationBuilder opBuilder = scope.env().opBuilder("UniqueV2", scope.makeOpName("Unique")); opBuilder.addInput(x.asOutput()); opBuilder.addInput(axis.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_idx", outIdx); + opBuilder.setAttr("out_idx", Operands.toDataType(outIdx)); return new Unique(opBuilder.build()); } @@ -110,7 +110,7 @@ public static Unique Unique create(Scope scope, Operand x, Operand axis) { - return create(scope, x, axis, TInt32.DTYPE); + return create(scope, x, axis, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UniqueWithCounts.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UniqueWithCounts.java index 1c0b63c6444..cf2ae7566cc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UniqueWithCounts.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UniqueWithCounts.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -94,12 +94,12 @@ public final class UniqueWithCounts extends * @return a new instance of UniqueWithCounts */ @Endpoint(describeByClass = true) - public static UniqueWithCounts create(Scope scope, Operand x, Operand axis, DataType outIdx) { + public static UniqueWithCounts create(Scope scope, Operand x, Operand axis, Class outIdx) { OperationBuilder opBuilder = scope.env().opBuilder("UniqueWithCountsV2", scope.makeOpName("UniqueWithCounts")); opBuilder.addInput(x.asOutput()); opBuilder.addInput(axis.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_idx", outIdx); + opBuilder.setAttr("out_idx", Operands.toDataType(outIdx)); return new UniqueWithCounts(opBuilder.build()); } @@ -114,7 +114,7 @@ public static UniqueWith */ @Endpoint(describeByClass = true) public static UniqueWithCounts create(Scope scope, Operand x, Operand axis) { - return create(scope, x, axis, TInt32.DTYPE); + return create(scope, x, axis, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UnravelIndex.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UnravelIndex.java index 18ff27b22f4..0ce98a63884 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UnravelIndex.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UnravelIndex.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Converts an array of flat indices into a tuple of coordinate arrays. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unstage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unstage.java index e765d849117..2997b5eaa57 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unstage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Unstage.java @@ -20,11 +20,11 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -95,14 +95,10 @@ private Options() { * @return a new instance of Unstage */ @Endpoint(describeByClass = true) - public static Unstage create(Scope scope, List> dtypes, Options... options) { + public static Unstage create(Scope scope, List> dtypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Unstage", scope.makeOpName("Unstage")); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); if (options != null) { for (Options opts : options) { if (opts.capacity != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UpperBound.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UpperBound.java index c8f83e5bbb0..6bc9d8034fd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UpperBound.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/UpperBound.java @@ -17,11 +17,11 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -66,12 +66,12 @@ public final class UpperBound extends RawOp implements Operan * @return a new instance of UpperBound */ @Endpoint(describeByClass = true) - public static UpperBound create(Scope scope, Operand sortedInputs, Operand values, DataType outType) { + public static UpperBound create(Scope scope, Operand sortedInputs, Operand values, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("UpperBound", scope.makeOpName("UpperBound")); opBuilder.addInput(sortedInputs.asOutput()); opBuilder.addInput(values.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new UpperBound(opBuilder.build()); } @@ -86,7 +86,7 @@ public static UpperBound create(Scope sc */ @Endpoint(describeByClass = true) public static UpperBound create(Scope scope, Operand sortedInputs, Operand values) { - return create(scope, sortedInputs, values, TInt32.DTYPE); + return create(scope, sortedInputs, values, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VarHandleOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VarHandleOp.java index 1d5c519c417..41d962d4381 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VarHandleOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VarHandleOp.java @@ -18,12 +18,12 @@ package org.tensorflow.op.core; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -85,10 +85,10 @@ private Options() { * @return a new instance of VarHandleOp */ @Endpoint(describeByClass = true) - public static VarHandleOp create(Scope scope, DataType dtype, Shape shape, Options... options) { + public static VarHandleOp create(Scope scope, Class dtype, Shape shape, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("VarHandleOp", scope.makeOpName("VarHandleOp")); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); opBuilder.setAttr("shape", shape); if (options != null) { for (Options opts : options) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java index 38b66d3deca..98e545d7b76 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java @@ -17,12 +17,12 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -81,11 +81,11 @@ private Options() { * @return a new instance of Variable */ @Endpoint(describeByClass = true) - public static Variable create(Scope scope, Shape shape, DataType dtype, Options... options) { + public static Variable create(Scope scope, Shape shape, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("VariableV2", scope.makeOpName("Variable")); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("shape", shape); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.container != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VariableShape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VariableShape.java index 3f7a30bc2f4..ab37a6c4d08 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VariableShape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/VariableShape.java @@ -17,18 +17,17 @@ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the shape of the variable pointed to by `resource`. @@ -56,11 +55,11 @@ public final class VariableShape extends RawOp implements Ope * @return a new instance of VariableShape */ @Endpoint(describeByClass = true) - public static VariableShape create(Scope scope, Operand input, DataType outType) { + public static VariableShape create(Scope scope, Operand input, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("VariableShape", scope.makeOpName("VariableShape")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new VariableShape(opBuilder.build()); } @@ -73,7 +72,7 @@ public static VariableShape create(Scope scope, Operand create(Scope scope, Operand input) { - return create(scope, input, TInt32.DTYPE); + return create(scope, input, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousIterator.java index 3e9325f8ffc..3be3e0be3eb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousIterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousIterator.java @@ -18,15 +18,16 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; /** * A container for an iterator resource. @@ -43,14 +44,10 @@ public final class AnonymousIterator extends RawOp { * @return a new instance of AnonymousIterator */ @Endpoint(describeByClass = true) - public static AnonymousIterator create(Scope scope, List> outputTypes, List outputShapes) { + public static AnonymousIterator create(Scope scope, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("AnonymousIteratorV2", scope.makeOpName("AnonymousIterator")); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousMultiDeviceIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousMultiDeviceIterator.java index c88b52dac89..d02a1f37cf3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousMultiDeviceIterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AnonymousMultiDeviceIterator.java @@ -18,15 +18,16 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; /** * A container for a multi device iterator resource. @@ -43,7 +44,7 @@ public final class AnonymousMultiDeviceIterator extends RawOp { * @return a new instance of AnonymousMultiDeviceIterator */ @Endpoint(describeByClass = true) - public static AnonymousMultiDeviceIterator create(Scope scope, List devices, List> outputTypes, List outputShapes) { + public static AnonymousMultiDeviceIterator create(Scope scope, List devices, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("AnonymousMultiDeviceIterator", scope.makeOpName("AnonymousMultiDeviceIterator")); opBuilder = scope.apply(opBuilder); String[] devicesArray = new String[devices.size()]; @@ -51,11 +52,7 @@ public static AnonymousMultiDeviceIterator create(Scope scope, List devi devicesArray[i] = devices.get(i); } opBuilder.setAttr("devices", devicesArray); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertNextDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertNextDataset.java index 56c48d9093c..ee9da65be4d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertNextDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AssertNextDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -58,16 +58,12 @@ public final class AssertNextDataset extends RawOp implements Operand { * @return a new instance of AssertNextDataset */ @Endpoint(describeByClass = true) - public static AssertNextDataset create(Scope scope, Operand inputDataset, Operand transformations, List> outputTypes, List outputShapes) { + public static AssertNextDataset create(Scope scope, Operand inputDataset, Operand transformations, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("AssertNextDataset", scope.makeOpName("AssertNextDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(transformations.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AutoShardDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AutoShardDataset.java index 92b73b6603e..1e18bf6be11 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AutoShardDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/AutoShardDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -76,17 +76,13 @@ private Options() { * @return a new instance of AutoShardDataset */ @Endpoint(describeByClass = true) - public static AutoShardDataset create(Scope scope, Operand inputDataset, Operand numWorkers, Operand index, List> outputTypes, List outputShapes, Options... options) { + public static AutoShardDataset create(Scope scope, Operand inputDataset, Operand numWorkers, Operand index, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("AutoShardDataset", scope.makeOpName("AutoShardDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numWorkers.asOutput()); opBuilder.addInput(index.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BatchDataset.java index a37b4eb3169..ee03bac2b95 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BatchDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -71,17 +71,13 @@ private Options() { * @return a new instance of BatchDataset */ @Endpoint(describeByClass = true) - public static BatchDataset create(Scope scope, Operand inputDataset, Operand batchSize, Operand dropRemainder, List> outputTypes, List outputShapes, Options... options) { + public static BatchDataset create(Scope scope, Operand inputDataset, Operand batchSize, Operand dropRemainder, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("BatchDatasetV2", scope.makeOpName("BatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(batchSize.asOutput()); opBuilder.addInput(dropRemainder.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BytesProducedStatsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BytesProducedStatsDataset.java index 4e5dbab1143..b9cea475727 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BytesProducedStatsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/BytesProducedStatsDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,16 +47,12 @@ public final class BytesProducedStatsDataset extends RawOp implements Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { + public static BytesProducedStatsDataset create(Scope scope, Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("BytesProducedStatsDataset", scope.makeOpName("BytesProducedStatsDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(tag.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDataset.java index 4df9c69bc61..25431c48ec6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -53,16 +53,12 @@ public final class CacheDataset extends RawOp implements Operand { * @return a new instance of CacheDataset */ @Endpoint(describeByClass = true) - public static CacheDataset create(Scope scope, Operand inputDataset, Operand filename, List> outputTypes, List outputShapes) { + public static CacheDataset create(Scope scope, Operand inputDataset, Operand filename, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("CacheDataset", scope.makeOpName("CacheDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(filename.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDatasetV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDatasetV2.java index ed5f087c039..c5cd457be55 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDatasetV2.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/CacheDatasetV2.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,17 +47,13 @@ public final class CacheDatasetV2 extends RawOp implements Operand { * @return a new instance of CacheDatasetV2 */ @Endpoint(describeByClass = true) - public static CacheDatasetV2 create(Scope scope, Operand inputDataset, Operand filename, Operand cache, List> outputTypes, List outputShapes) { + public static CacheDatasetV2 create(Scope scope, Operand inputDataset, Operand filename, Operand cache, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("CacheDatasetV2", scope.makeOpName("CacheDatasetV2")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(filename.asOutput()); opBuilder.addInput(cache.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestDataset.java index c10cf4d1660..e29e27f73f4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ChooseFastestDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,16 +45,12 @@ public final class ChooseFastestDataset extends RawOp implements Operand * @return a new instance of ChooseFastestDataset */ @Endpoint(describeByClass = true) - public static ChooseFastestDataset create(Scope scope, Iterable> inputDatasets, Long numExperiments, List> outputTypes, List outputShapes) { + public static ChooseFastestDataset create(Scope scope, Iterable> inputDatasets, Long numExperiments, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ChooseFastestDataset", scope.makeOpName("ChooseFastestDataset")); opBuilder.addInputList(Operands.asOutputs(inputDatasets)); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("num_experiments", numExperiments); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ConcatenateDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ConcatenateDataset.java index 9806530cd4b..bd32060b138 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ConcatenateDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ConcatenateDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,16 +47,12 @@ public final class ConcatenateDataset extends RawOp implements Operand { * @return a new instance of ConcatenateDataset */ @Endpoint(describeByClass = true) - public static ConcatenateDataset create(Scope scope, Operand inputDataset, Operand anotherDataset, List> outputTypes, List outputShapes) { + public static ConcatenateDataset create(Scope scope, Operand inputDataset, Operand anotherDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ConcatenateDataset", scope.makeOpName("ConcatenateDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(anotherDataset.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToSingleElement.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToSingleElement.java index 5df75adc3f4..a74413e1935 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToSingleElement.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToSingleElement.java @@ -20,12 +20,12 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,15 +47,11 @@ public final class DatasetToSingleElement extends RawOp implements Iterable dataset, List> outputTypes, List outputShapes) { + public static DatasetToSingleElement create(Scope scope, Operand dataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("DatasetToSingleElement", scope.makeOpName("DatasetToSingleElement")); opBuilder.addInput(dataset.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DenseToSparseBatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DenseToSparseBatchDataset.java index 87c843cb06c..1b5c9ec6ceb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DenseToSparseBatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DenseToSparseBatchDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -51,17 +51,13 @@ public final class DenseToSparseBatchDataset extends RawOp implements Operand inputDataset, Operand batchSize, Operand rowShape, List> outputTypes, List outputShapes) { + public static DenseToSparseBatchDataset create(Scope scope, Operand inputDataset, Operand batchSize, Operand rowShape, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("DenseToSparseBatchDataset", scope.makeOpName("DenseToSparseBatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(batchSize.asOutput()); opBuilder.addInput(rowShape.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DirectedInterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DirectedInterleaveDataset.java index 654fec6185c..0e543dc351a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DirectedInterleaveDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DirectedInterleaveDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,16 +48,12 @@ public final class DirectedInterleaveDataset extends RawOp implements Operand selectorInputDataset, Iterable> dataInputDatasets, List> outputTypes, List outputShapes) { + public static DirectedInterleaveDataset create(Scope scope, Operand selectorInputDataset, Iterable> dataInputDatasets, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("DirectedInterleaveDataset", scope.makeOpName("DirectedInterleaveDataset")); opBuilder.addInput(selectorInputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(dataInputDatasets)); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterByLastComponentDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterByLastComponentDataset.java index 9119d082346..d1d1c55e08f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterByLastComponentDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/FilterByLastComponentDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -45,15 +45,11 @@ public final class FilterByLastComponentDataset extends RawOp implements Operand * @return a new instance of FilterByLastComponentDataset */ @Endpoint(describeByClass = true) - public static FilterByLastComponentDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { + public static FilterByLastComponentDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("FilterByLastComponentDataset", scope.makeOpName("FilterByLastComponentDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IgnoreErrorsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IgnoreErrorsDataset.java index 2d0822d06a2..0b4bec924b2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IgnoreErrorsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IgnoreErrorsDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -45,15 +45,11 @@ public final class IgnoreErrorsDataset extends RawOp implements Operand { * @return a new instance of IgnoreErrorsDataset */ @Endpoint(describeByClass = true) - public static IgnoreErrorsDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { + public static IgnoreErrorsDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("IgnoreErrorsDataset", scope.makeOpName("IgnoreErrorsDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/Iterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/Iterator.java index c1c3804bdbe..d20e98d5c9b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/Iterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/Iterator.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -46,16 +46,12 @@ public final class Iterator extends RawOp implements Operand { * @return a new instance of Iterator */ @Endpoint(describeByClass = true) - public static Iterator create(Scope scope, String sharedName, String container, List> outputTypes, List outputShapes) { + public static Iterator create(Scope scope, String sharedName, String container, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("IteratorV2", scope.makeOpName("Iterator")); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("shared_name", sharedName); opBuilder.setAttr("container", container); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorFromStringHandle.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorFromStringHandle.java index b65e2c57225..4de1a9f1e9e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorFromStringHandle.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorFromStringHandle.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -64,15 +64,11 @@ private Options() { * @return a new instance of IteratorFromStringHandle */ @Endpoint(describeByClass = true) - public static IteratorFromStringHandle create(Scope scope, Operand stringHandle, List> outputTypes, Options... options) { + public static IteratorFromStringHandle create(Scope scope, Operand stringHandle, List> outputTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("IteratorFromStringHandleV2", scope.makeOpName("IteratorFromStringHandle")); opBuilder.addInput(stringHandle.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); if (options != null) { for (Options opts : options) { if (opts.outputShapes != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNext.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNext.java index 33c7d435baa..a41f441f650 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNext.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNext.java @@ -20,12 +20,12 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -48,15 +48,11 @@ public final class IteratorGetNext extends RawOp implements Iterable iterator, List> outputTypes, List outputShapes) { + public static IteratorGetNext create(Scope scope, Operand iterator, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("IteratorGetNext", scope.makeOpName("IteratorGetNext")); opBuilder.addInput(iterator.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextAsOptional.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextAsOptional.java index 2ab688f371e..7d4f441eb83 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextAsOptional.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextAsOptional.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -46,15 +46,11 @@ public final class IteratorGetNextAsOptional extends RawOp implements Operand iterator, List> outputTypes, List outputShapes) { + public static IteratorGetNextAsOptional create(Scope scope, Operand iterator, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("IteratorGetNextAsOptional", scope.makeOpName("IteratorGetNextAsOptional")); opBuilder.addInput(iterator.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextSync.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextSync.java index 4b7114fb0c7..bdce447dd0a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextSync.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/IteratorGetNextSync.java @@ -20,12 +20,12 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -53,15 +53,11 @@ public final class IteratorGetNextSync extends RawOp implements Iterable iterator, List> outputTypes, List outputShapes) { + public static IteratorGetNextSync create(Scope scope, Operand iterator, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("IteratorGetNextSync", scope.makeOpName("IteratorGetNextSync")); opBuilder.addInput(iterator.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LMDBDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LMDBDataset.java index bca9d0b50cf..c67912ec018 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LMDBDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LMDBDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -58,15 +58,11 @@ public final class LMDBDataset extends RawOp implements Operand { * @return a new instance of LMDBDataset */ @Endpoint(describeByClass = true) - public static LMDBDataset create(Scope scope, Operand filenames, List> outputTypes, List outputShapes) { + public static LMDBDataset create(Scope scope, Operand filenames, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("LMDBDataset", scope.makeOpName("LMDBDataset")); opBuilder.addInput(filenames.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LatencyStatsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LatencyStatsDataset.java index e9a9425f891..047057a1c25 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LatencyStatsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LatencyStatsDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,16 +47,12 @@ public final class LatencyStatsDataset extends RawOp implements Operand { * @return a new instance of LatencyStatsDataset */ @Endpoint(describeByClass = true) - public static LatencyStatsDataset create(Scope scope, Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { + public static LatencyStatsDataset create(Scope scope, Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("LatencyStatsDataset", scope.makeOpName("LatencyStatsDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(tag.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LeakyReluGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LeakyReluGrad.java index 8ac38a713ea..3b41dd2e918 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LeakyReluGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/LeakyReluGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes rectified linear gradients for a LeakyRelu operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MaxIntraOpParallelismDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MaxIntraOpParallelismDataset.java index 3f3b43038c2..14a35d14d07 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MaxIntraOpParallelismDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MaxIntraOpParallelismDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,16 +47,12 @@ public final class MaxIntraOpParallelismDataset extends RawOp implements Operand * @return a new instance of MaxIntraOpParallelismDataset */ @Endpoint(describeByClass = true) - public static MaxIntraOpParallelismDataset create(Scope scope, Operand inputDataset, Operand maxIntraOpParallelism, List> outputTypes, List outputShapes) { + public static MaxIntraOpParallelismDataset create(Scope scope, Operand inputDataset, Operand maxIntraOpParallelism, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("MaxIntraOpParallelismDataset", scope.makeOpName("MaxIntraOpParallelismDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(maxIntraOpParallelism.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ModelDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ModelDataset.java index 9cb6913b9d6..f62be66854b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ModelDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ModelDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -76,15 +76,11 @@ private Options() { * @return a new instance of ModelDataset */ @Endpoint(describeByClass = true) - public static ModelDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes, Options... options) { + public static ModelDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ModelDataset", scope.makeOpName("ModelDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIterator.java index 5db1b83475f..e6ee9d12c02 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIterator.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -49,7 +49,7 @@ public final class MultiDeviceIterator extends RawOp implements Operand { * @return a new instance of MultiDeviceIterator */ @Endpoint(describeByClass = true) - public static MultiDeviceIterator create(Scope scope, List devices, String sharedName, String container, List> outputTypes, List outputShapes) { + public static MultiDeviceIterator create(Scope scope, List devices, String sharedName, String container, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("MultiDeviceIterator", scope.makeOpName("MultiDeviceIterator")); opBuilder = scope.apply(opBuilder); String[] devicesArray = new String[devices.size()]; @@ -59,11 +59,7 @@ public static MultiDeviceIterator create(Scope scope, List devices, Stri opBuilder.setAttr("devices", devicesArray); opBuilder.setAttr("shared_name", sharedName); opBuilder.setAttr("container", container); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorFromStringHandle.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorFromStringHandle.java index 3a4113e10b5..5e6fdb1c640 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorFromStringHandle.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorFromStringHandle.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -65,15 +65,11 @@ private Options() { * @return a new instance of MultiDeviceIteratorFromStringHandle */ @Endpoint(describeByClass = true) - public static MultiDeviceIteratorFromStringHandle create(Scope scope, Operand stringHandle, List> outputTypes, Options... options) { + public static MultiDeviceIteratorFromStringHandle create(Scope scope, Operand stringHandle, List> outputTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MultiDeviceIteratorFromStringHandle", scope.makeOpName("MultiDeviceIteratorFromStringHandle")); opBuilder.addInput(stringHandle.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); if (options != null) { for (Options opts : options) { if (opts.outputShapes != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorGetNextFromShard.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorGetNextFromShard.java index c7313076e26..578bf00fe78 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorGetNextFromShard.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MultiDeviceIteratorGetNextFromShard.java @@ -20,12 +20,12 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -51,17 +51,13 @@ public final class MultiDeviceIteratorGetNextFromShard extends RawOp implements * @return a new instance of MultiDeviceIteratorGetNextFromShard */ @Endpoint(describeByClass = true) - public static MultiDeviceIteratorGetNextFromShard create(Scope scope, Operand multiDeviceIterator, Operand shardNum, Operand incarnationId, List> outputTypes, List outputShapes) { + public static MultiDeviceIteratorGetNextFromShard create(Scope scope, Operand multiDeviceIterator, Operand shardNum, Operand incarnationId, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("MultiDeviceIteratorGetNextFromShard", scope.makeOpName("MultiDeviceIteratorGetNextFromShard")); opBuilder.addInput(multiDeviceIterator.asOutput()); opBuilder.addInput(shardNum.asOutput()); opBuilder.addInput(incarnationId.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/NonSerializableDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/NonSerializableDataset.java index c5b887812ef..49fece5850f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/NonSerializableDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/NonSerializableDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -44,15 +44,11 @@ public final class NonSerializableDataset extends RawOp implements Operand inputDataset, List> outputTypes, List outputShapes) { + public static NonSerializableDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("NonSerializableDataset", scope.makeOpName("NonSerializableDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDataset.java index d08521d3d54..c9bd55a0320 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptimizeDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -69,16 +69,12 @@ private Options() { * @return a new instance of OptimizeDataset */ @Endpoint(describeByClass = true) - public static OptimizeDataset create(Scope scope, Operand inputDataset, Operand optimizations, List> outputTypes, List outputShapes, Options... options) { + public static OptimizeDataset create(Scope scope, Operand inputDataset, Operand optimizations, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OptimizeDataset", scope.makeOpName("OptimizeDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(optimizations.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptionalGetValue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptionalGetValue.java index 5218e9e64e7..eeaa8d61642 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptionalGetValue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/OptionalGetValue.java @@ -20,12 +20,12 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -48,15 +48,11 @@ public final class OptionalGetValue extends RawOp implements Iterable optional, List> outputTypes, List outputShapes) { + public static OptionalGetValue create(Scope scope, Operand optional, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("OptionalGetValue", scope.makeOpName("OptionalGetValue")); opBuilder.addInput(optional.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrefetchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrefetchDataset.java index 65dcbf4806d..7e86b47bfe1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrefetchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrefetchDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -77,16 +77,12 @@ private Options() { * @return a new instance of PrefetchDataset */ @Endpoint(describeByClass = true) - public static PrefetchDataset create(Scope scope, Operand inputDataset, Operand bufferSize, List> outputTypes, List outputShapes, Options... options) { + public static PrefetchDataset create(Scope scope, Operand inputDataset, Operand bufferSize, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("PrefetchDataset", scope.makeOpName("PrefetchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(bufferSize.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrivateThreadPoolDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrivateThreadPoolDataset.java index 63f8045f2e3..92d7753b54a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrivateThreadPoolDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/PrivateThreadPoolDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,16 +47,12 @@ public final class PrivateThreadPoolDataset extends RawOp implements Operand inputDataset, Operand numThreads, List> outputTypes, List outputShapes) { + public static PrivateThreadPoolDataset create(Scope scope, Operand inputDataset, Operand numThreads, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("PrivateThreadPoolDataset", scope.makeOpName("PrivateThreadPoolDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numThreads.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RandomDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RandomDataset.java index 431e2244118..5108bec6d0d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RandomDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RandomDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -60,16 +60,12 @@ public final class RandomDataset extends RawOp implements Operand { * @return a new instance of RandomDataset */ @Endpoint(describeByClass = true) - public static RandomDataset create(Scope scope, Operand seed, Operand seed2, List> outputTypes, List outputShapes) { + public static RandomDataset create(Scope scope, Operand seed, Operand seed2, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("RandomDataset", scope.makeOpName("RandomDataset")); opBuilder.addInput(seed.asOutput()); opBuilder.addInput(seed2.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RangeDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RangeDataset.java index 8fdcdfad0e6..0acede3a72b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RangeDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RangeDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -49,17 +49,13 @@ public final class RangeDataset extends RawOp implements Operand { * @return a new instance of RangeDataset */ @Endpoint(describeByClass = true) - public static RangeDataset create(Scope scope, Operand start, Operand stop, Operand step, List> outputTypes, List outputShapes) { + public static RangeDataset create(Scope scope, Operand start, Operand stop, Operand step, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("RangeDataset", scope.makeOpName("RangeDataset")); opBuilder.addInput(start.asOutput()); opBuilder.addInput(stop.asOutput()); opBuilder.addInput(step.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDataset.java index 050e765ecf6..4d73b0dbbcf 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RebatchDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -72,16 +72,12 @@ private Options() { * @return a new instance of RebatchDataset */ @Endpoint(describeByClass = true) - public static RebatchDataset create(Scope scope, Operand inputDataset, Operand numReplicas, List> outputTypes, List outputShapes, Options... options) { + public static RebatchDataset create(Scope scope, Operand inputDataset, Operand numReplicas, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("RebatchDataset", scope.makeOpName("RebatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numReplicas.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RepeatDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RepeatDataset.java index 33a598eb33c..ad327527c06 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RepeatDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RepeatDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -49,16 +49,12 @@ public final class RepeatDataset extends RawOp implements Operand { * @return a new instance of RepeatDataset */ @Endpoint(describeByClass = true) - public static RepeatDataset create(Scope scope, Operand inputDataset, Operand count, List> outputTypes, List outputShapes) { + public static RepeatDataset create(Scope scope, Operand inputDataset, Operand count, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("RepeatDataset", scope.makeOpName("RepeatDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(count.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SamplingDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SamplingDataset.java index ec75a203366..7f67f238136 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SamplingDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SamplingDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -57,18 +57,14 @@ public final class SamplingDataset extends RawOp implements Operand { * @return a new instance of SamplingDataset */ @Endpoint(describeByClass = true) - public static SamplingDataset create(Scope scope, Operand inputDataset, Operand rate, Operand seed, Operand seed2, List> outputTypes, List outputShapes) { + public static SamplingDataset create(Scope scope, Operand inputDataset, Operand rate, Operand seed, Operand seed2, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("SamplingDataset", scope.makeOpName("SamplingDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(rate.asOutput()); opBuilder.addInput(seed.asOutput()); opBuilder.addInput(seed2.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SetStatsAggregatorDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SetStatsAggregatorDataset.java index 247856790af..7e774e3a103 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SetStatsAggregatorDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SetStatsAggregatorDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -48,18 +48,14 @@ public final class SetStatsAggregatorDataset extends RawOp implements Operand inputDataset, Operand statsAggregator, Operand tag, Operand counterPrefix, List> outputTypes, List outputShapes) { + public static SetStatsAggregatorDataset create(Scope scope, Operand inputDataset, Operand statsAggregator, Operand tag, Operand counterPrefix, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("SetStatsAggregatorDataset", scope.makeOpName("SetStatsAggregatorDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(statsAggregator.asOutput()); opBuilder.addInput(tag.asOutput()); opBuilder.addInput(counterPrefix.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShardDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShardDataset.java index 1255fe5e4b4..53b184949e7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShardDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShardDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -68,17 +68,13 @@ private Options() { * @return a new instance of ShardDataset */ @Endpoint(describeByClass = true) - public static ShardDataset create(Scope scope, Operand inputDataset, Operand numShards, Operand index, List> outputTypes, List outputShapes, Options... options) { + public static ShardDataset create(Scope scope, Operand inputDataset, Operand numShards, Operand index, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ShardDataset", scope.makeOpName("ShardDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numShards.asOutput()); opBuilder.addInput(index.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java index 063b9d7352a..bc486a3777c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -70,7 +70,7 @@ private Options() { * @return a new instance of ShuffleAndRepeatDataset */ @Endpoint(describeByClass = true) - public static ShuffleAndRepeatDataset create(Scope scope, Operand inputDataset, Operand bufferSize, Operand seed, Operand seed2, Operand count, Operand seedGenerator, List> outputTypes, List outputShapes, Options... options) { + public static ShuffleAndRepeatDataset create(Scope scope, Operand inputDataset, Operand bufferSize, Operand seed, Operand seed2, Operand count, Operand seedGenerator, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ShuffleAndRepeatDatasetV2", scope.makeOpName("ShuffleAndRepeatDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(bufferSize.asOutput()); @@ -79,11 +79,7 @@ public static ShuffleAndRepeatDataset create(Scope scope, Operand inputDatase opBuilder.addInput(count.asOutput()); opBuilder.addInput(seedGenerator.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java index 27e4982773f..d458c12b68c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -69,7 +69,7 @@ private Options() { * @return a new instance of ShuffleDataset */ @Endpoint(describeByClass = true) - public static ShuffleDataset create(Scope scope, Operand inputDataset, Operand bufferSize, Operand seed, Operand seed2, Operand seedGenerator, List> outputTypes, List outputShapes, Options... options) { + public static ShuffleDataset create(Scope scope, Operand inputDataset, Operand bufferSize, Operand seed, Operand seed2, Operand seedGenerator, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ShuffleDatasetV3", scope.makeOpName("ShuffleDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(bufferSize.asOutput()); @@ -77,11 +77,7 @@ public static ShuffleDataset create(Scope scope, Operand inputDataset, Operan opBuilder.addInput(seed2.asOutput()); opBuilder.addInput(seedGenerator.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SkipDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SkipDataset.java index b4f77efda63..dea92ac046b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SkipDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SkipDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -49,16 +49,12 @@ public final class SkipDataset extends RawOp implements Operand { * @return a new instance of SkipDataset */ @Endpoint(describeByClass = true) - public static SkipDataset create(Scope scope, Operand inputDataset, Operand count, List> outputTypes, List outputShapes) { + public static SkipDataset create(Scope scope, Operand inputDataset, Operand count, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("SkipDataset", scope.makeOpName("SkipDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(count.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SleepDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SleepDataset.java index 4d338ac51b6..926cc93fe89 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SleepDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SleepDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -46,16 +46,12 @@ public final class SleepDataset extends RawOp implements Operand { * @return a new instance of SleepDataset */ @Endpoint(describeByClass = true) - public static SleepDataset create(Scope scope, Operand inputDataset, Operand sleepMicroseconds, List> outputTypes, List outputShapes) { + public static SleepDataset create(Scope scope, Operand inputDataset, Operand sleepMicroseconds, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("SleepDataset", scope.makeOpName("SleepDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(sleepMicroseconds.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SlidingWindowDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SlidingWindowDataset.java index 2ce91cfa3ec..2c10ea43cd9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SlidingWindowDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SlidingWindowDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -52,18 +52,14 @@ public final class SlidingWindowDataset extends RawOp implements Operand * @return a new instance of SlidingWindowDataset */ @Endpoint(describeByClass = true) - public static SlidingWindowDataset create(Scope scope, Operand inputDataset, Operand windowSize, Operand windowShift, Operand windowStride, List> outputTypes, List outputShapes) { + public static SlidingWindowDataset create(Scope scope, Operand inputDataset, Operand windowSize, Operand windowShift, Operand windowStride, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("SlidingWindowDataset", scope.makeOpName("SlidingWindowDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(windowSize.asOutput()); opBuilder.addInput(windowShift.asOutput()); opBuilder.addInput(windowStride.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java deleted file mode 100644 index dd9afe060dc..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SnapshotDataset.java +++ /dev/null @@ -1,375 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.data; - -import java.util.List; -import org.tensorflow.DataType; -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.types.TString; -import org.tensorflow.types.family.TType; - -/** - * Creates a dataset that will write to / read from a snapshot. - *

- * This dataset attempts to determine whether a valid snapshot exists at the - * `snapshot_path`, and reads from the snapshot in lieu of using `input_dataset`. - * If not, it will run the preprocessing pipeline as usual, and write out a - * snapshot of the data processed for future use. - */ -public final class SnapshotDataset extends RawOp implements Operand { - - /** - * Optional attributes for {@link org.tensorflow.op.data.SnapshotDataset} - */ - public static class Options { - - /** - * @param compression - */ - public Options compression(String compression) { - this.compression = compression; - return this; - } - - /** - * @param readerPathPrefix - */ - public Options readerPathPrefix(String readerPathPrefix) { - this.readerPathPrefix = readerPathPrefix; - return this; - } - - /** - * @param writerPathPrefix - */ - public Options writerPathPrefix(String writerPathPrefix) { - this.writerPathPrefix = writerPathPrefix; - return this; - } - - /** - * @param shardSizeBytes - */ - public Options shardSizeBytes(Long shardSizeBytes) { - this.shardSizeBytes = shardSizeBytes; - return this; - } - - /** - * @param pendingSnapshotExpirySeconds - */ - public Options pendingSnapshotExpirySeconds(Long pendingSnapshotExpirySeconds) { - this.pendingSnapshotExpirySeconds = pendingSnapshotExpirySeconds; - return this; - } - - /** - * @param numReaderThreads - */ - public Options numReaderThreads(Long numReaderThreads) { - this.numReaderThreads = numReaderThreads; - return this; - } - - /** - * @param readerBufferSize - */ - public Options readerBufferSize(Long readerBufferSize) { - this.readerBufferSize = readerBufferSize; - return this; - } - - /** - * @param numWriterThreads - */ - public Options numWriterThreads(Long numWriterThreads) { - this.numWriterThreads = numWriterThreads; - return this; - } - - /** - * @param writerBufferSize - */ - public Options writerBufferSize(Long writerBufferSize) { - this.writerBufferSize = writerBufferSize; - return this; - } - - /** - * @param shuffleOnRead - */ - public Options shuffleOnRead(Boolean shuffleOnRead) { - this.shuffleOnRead = shuffleOnRead; - return this; - } - - /** - * @param seed - */ - public Options seed(Long seed) { - this.seed = seed; - return this; - } - - /** - * @param seed2 - */ - public Options seed2(Long seed2) { - this.seed2 = seed2; - return this; - } - - /** - * @param mode - */ - public Options mode(String mode) { - this.mode = mode; - return this; - } - - /** - * @param snapshotName - */ - public Options snapshotName(String snapshotName) { - this.snapshotName = snapshotName; - return this; - } - - private String compression; - private String readerPathPrefix; - private String writerPathPrefix; - private Long shardSizeBytes; - private Long pendingSnapshotExpirySeconds; - private Long numReaderThreads; - private Long readerBufferSize; - private Long numWriterThreads; - private Long writerBufferSize; - private Boolean shuffleOnRead; - private Long seed; - private Long seed2; - private String mode; - private String snapshotName; - - private Options() { - } - } - - /** - * Factory method to create a class wrapping a new SnapshotDataset operation. - * - * @param scope current scope - * @param inputDataset A variant tensor representing the input dataset. - * @param path The path we should write snapshots to / read snapshots from. - * @param outputTypes - * @param outputShapes - * @param options carries optional attributes values - * @return a new instance of SnapshotDataset - */ - @Endpoint(describeByClass = true) - public static SnapshotDataset create(Scope scope, Operand inputDataset, Operand path, List> outputTypes, List outputShapes, Options... options) { - OperationBuilder opBuilder = scope.env().opBuilder("SnapshotDataset", scope.makeOpName("SnapshotDataset")); - opBuilder.addInput(inputDataset.asOutput()); - opBuilder.addInput(path.asOutput()); - opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); - Shape[] outputShapesArray = new Shape[outputShapes.size()]; - for (int i = 0; i < outputShapesArray.length; ++i) { - outputShapesArray[i] = outputShapes.get(i); - } - opBuilder.setAttr("output_shapes", outputShapesArray); - if (options != null) { - for (Options opts : options) { - if (opts.compression != null) { - opBuilder.setAttr("compression", opts.compression); - } - if (opts.readerPathPrefix != null) { - opBuilder.setAttr("reader_path_prefix", opts.readerPathPrefix); - } - if (opts.writerPathPrefix != null) { - opBuilder.setAttr("writer_path_prefix", opts.writerPathPrefix); - } - if (opts.shardSizeBytes != null) { - opBuilder.setAttr("shard_size_bytes", opts.shardSizeBytes); - } - if (opts.pendingSnapshotExpirySeconds != null) { - opBuilder.setAttr("pending_snapshot_expiry_seconds", opts.pendingSnapshotExpirySeconds); - } - if (opts.numReaderThreads != null) { - opBuilder.setAttr("num_reader_threads", opts.numReaderThreads); - } - if (opts.readerBufferSize != null) { - opBuilder.setAttr("reader_buffer_size", opts.readerBufferSize); - } - if (opts.numWriterThreads != null) { - opBuilder.setAttr("num_writer_threads", opts.numWriterThreads); - } - if (opts.writerBufferSize != null) { - opBuilder.setAttr("writer_buffer_size", opts.writerBufferSize); - } - if (opts.shuffleOnRead != null) { - opBuilder.setAttr("shuffle_on_read", opts.shuffleOnRead); - } - if (opts.seed != null) { - opBuilder.setAttr("seed", opts.seed); - } - if (opts.seed2 != null) { - opBuilder.setAttr("seed2", opts.seed2); - } - if (opts.mode != null) { - opBuilder.setAttr("mode", opts.mode); - } - if (opts.snapshotName != null) { - opBuilder.setAttr("snapshot_name", opts.snapshotName); - } - } - } - return new SnapshotDataset(opBuilder.build()); - } - - /** - * @param compression - */ - public static Options compression(String compression) { - return new Options().compression(compression); - } - - /** - * @param readerPathPrefix - */ - public static Options readerPathPrefix(String readerPathPrefix) { - return new Options().readerPathPrefix(readerPathPrefix); - } - - /** - * @param writerPathPrefix - */ - public static Options writerPathPrefix(String writerPathPrefix) { - return new Options().writerPathPrefix(writerPathPrefix); - } - - /** - * @param shardSizeBytes - */ - public static Options shardSizeBytes(Long shardSizeBytes) { - return new Options().shardSizeBytes(shardSizeBytes); - } - - /** - * @param pendingSnapshotExpirySeconds - */ - public static Options pendingSnapshotExpirySeconds(Long pendingSnapshotExpirySeconds) { - return new Options().pendingSnapshotExpirySeconds(pendingSnapshotExpirySeconds); - } - - /** - * @param numReaderThreads - */ - public static Options numReaderThreads(Long numReaderThreads) { - return new Options().numReaderThreads(numReaderThreads); - } - - /** - * @param readerBufferSize - */ - public static Options readerBufferSize(Long readerBufferSize) { - return new Options().readerBufferSize(readerBufferSize); - } - - /** - * @param numWriterThreads - */ - public static Options numWriterThreads(Long numWriterThreads) { - return new Options().numWriterThreads(numWriterThreads); - } - - /** - * @param writerBufferSize - */ - public static Options writerBufferSize(Long writerBufferSize) { - return new Options().writerBufferSize(writerBufferSize); - } - - /** - * @param shuffleOnRead - */ - public static Options shuffleOnRead(Boolean shuffleOnRead) { - return new Options().shuffleOnRead(shuffleOnRead); - } - - /** - * @param seed - */ - public static Options seed(Long seed) { - return new Options().seed(seed); - } - - /** - * @param seed2 - */ - public static Options seed2(Long seed2) { - return new Options().seed2(seed2); - } - - /** - * @param mode - */ - public static Options mode(String mode) { - return new Options().mode(mode); - } - - /** - * @param snapshotName - */ - public static Options snapshotName(String snapshotName) { - return new Options().snapshotName(snapshotName); - } - - /** - */ - public Output handle() { - return handle; - } - - @Override - @SuppressWarnings("unchecked") - public Output asOutput() { - return (Output) handle; - } - - /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "SnapshotDataset"; - - private Output handle; - - private SnapshotDataset(Operation operation) { - super(operation); - int outputIdx = 0; - handle = operation.output(outputIdx++); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SqlDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SqlDataset.java index 5fa54ddc199..e22c4f13bb1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SqlDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/SqlDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -48,17 +48,13 @@ public final class SqlDataset extends RawOp implements Operand { * @return a new instance of SqlDataset */ @Endpoint(describeByClass = true) - public static SqlDataset create(Scope scope, Operand driverName, Operand dataSourceName, Operand query, List> outputTypes, List outputShapes) { + public static SqlDataset create(Scope scope, Operand driverName, Operand dataSourceName, Operand query, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("SqlDataset", scope.makeOpName("SqlDataset")); opBuilder.addInput(driverName.asOutput()); opBuilder.addInput(dataSourceName.asOutput()); opBuilder.addInput(query.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeDataset.java index ea531147a0e..61d3ebc7916 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/TakeDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -50,16 +50,12 @@ public final class TakeDataset extends RawOp implements Operand { * @return a new instance of TakeDataset */ @Endpoint(describeByClass = true) - public static TakeDataset create(Scope scope, Operand inputDataset, Operand count, List> outputTypes, List outputShapes) { + public static TakeDataset create(Scope scope, Operand inputDataset, Operand count, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("TakeDataset", scope.makeOpName("TakeDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(count.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ThreadPoolDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ThreadPoolDataset.java index 1c8886b81f5..fe1f7d9f859 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ThreadPoolDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ThreadPoolDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -46,16 +46,12 @@ public final class ThreadPoolDataset extends RawOp implements Operand { * @return a new instance of ThreadPoolDataset */ @Endpoint(describeByClass = true) - public static ThreadPoolDataset create(Scope scope, Operand inputDataset, Operand threadPool, List> outputTypes, List outputShapes) { + public static ThreadPoolDataset create(Scope scope, Operand inputDataset, Operand threadPool, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ThreadPoolDataset", scope.makeOpName("ThreadPoolDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(threadPool.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnbatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnbatchDataset.java index e573884c4c4..1e8cc5cd305 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnbatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UnbatchDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -45,15 +45,11 @@ public final class UnbatchDataset extends RawOp implements Operand { * @return a new instance of UnbatchDataset */ @Endpoint(describeByClass = true) - public static UnbatchDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { + public static UnbatchDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("UnbatchDataset", scope.makeOpName("UnbatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UniqueDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UniqueDataset.java index 375a80da417..6bd918a3db0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UniqueDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/UniqueDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -45,15 +45,11 @@ public final class UniqueDataset extends RawOp implements Operand { * @return a new instance of UniqueDataset */ @Endpoint(describeByClass = true) - public static UniqueDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { + public static UniqueDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("UniqueDataset", scope.makeOpName("UniqueDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WindowDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WindowDataset.java index 89ada3cfb40..a44807564c1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WindowDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/WindowDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -98,7 +98,7 @@ public final class WindowDataset extends RawOp implements Operand { * @return a new instance of WindowDataset */ @Endpoint(describeByClass = true) - public static WindowDataset create(Scope scope, Operand inputDataset, Operand size, Operand shift, Operand stride, Operand dropRemainder, List> outputTypes, List outputShapes) { + public static WindowDataset create(Scope scope, Operand inputDataset, Operand size, Operand shift, Operand stride, Operand dropRemainder, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("WindowDataset", scope.makeOpName("WindowDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(size.asOutput()); @@ -106,11 +106,7 @@ public static WindowDataset create(Scope scope, Operand inputDataset, Operand opBuilder.addInput(stride.asOutput()); opBuilder.addInput(dropRemainder.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ZipDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ZipDataset.java index 133ef39430c..47dd0be0775 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ZipDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ZipDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -53,15 +52,11 @@ public final class ZipDataset extends RawOp implements Operand { * @return a new instance of ZipDataset */ @Endpoint(describeByClass = true) - public static ZipDataset create(Scope scope, Iterable> inputDatasets, List> outputTypes, List outputShapes) { + public static ZipDataset create(Scope scope, Iterable> inputDatasets, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ZipDataset", scope.makeOpName("ZipDataset")); opBuilder.addInputList(Operands.asOutputs(inputDatasets)); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertCardinalityDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertCardinalityDataset.java index 09ade444b05..8ab5989b44b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertCardinalityDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertCardinalityDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -46,16 +46,12 @@ public final class AssertCardinalityDataset extends RawOp implements Operand inputDataset, Operand cardinality, List> outputTypes, List outputShapes) { + public static AssertCardinalityDataset create(Scope scope, Operand inputDataset, Operand cardinality, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("AssertCardinalityDataset", scope.makeOpName("AssertCardinalityDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(cardinality.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertNextDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertNextDataset.java index ac3227bda10..2483e971723 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertNextDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AssertNextDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -46,16 +46,12 @@ public final class AssertNextDataset extends RawOp implements Operand { * @return a new instance of AssertNextDataset */ @Endpoint(describeByClass = true) - public static AssertNextDataset create(Scope scope, Operand inputDataset, Operand transformations, List> outputTypes, List outputShapes) { + public static AssertNextDataset create(Scope scope, Operand inputDataset, Operand transformations, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalAssertNextDataset", scope.makeOpName("AssertNextDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(transformations.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AutoShardDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AutoShardDataset.java index 556b7acd89a..fc630d53919 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AutoShardDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/AutoShardDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -76,17 +76,13 @@ private Options() { * @return a new instance of AutoShardDataset */ @Endpoint(describeByClass = true) - public static AutoShardDataset create(Scope scope, Operand inputDataset, Operand numWorkers, Operand index, List> outputTypes, List outputShapes, Options... options) { + public static AutoShardDataset create(Scope scope, Operand inputDataset, Operand numWorkers, Operand index, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalAutoShardDataset", scope.makeOpName("AutoShardDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numWorkers.asOutput()); opBuilder.addInput(index.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/BytesProducedStatsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/BytesProducedStatsDataset.java index ca545d5e89b..76a5a96f895 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/BytesProducedStatsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/BytesProducedStatsDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,16 +47,12 @@ public final class BytesProducedStatsDataset extends RawOp implements Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { + public static BytesProducedStatsDataset create(Scope scope, Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalBytesProducedStatsDataset", scope.makeOpName("BytesProducedStatsDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(tag.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ChooseFastestDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ChooseFastestDataset.java index b3551d4f38d..b735a535c48 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ChooseFastestDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ChooseFastestDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -46,16 +45,12 @@ public final class ChooseFastestDataset extends RawOp implements Operand * @return a new instance of ChooseFastestDataset */ @Endpoint(describeByClass = true) - public static ChooseFastestDataset create(Scope scope, Iterable> inputDatasets, Long numExperiments, List> outputTypes, List outputShapes) { + public static ChooseFastestDataset create(Scope scope, Iterable> inputDatasets, Long numExperiments, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalChooseFastestDataset", scope.makeOpName("ChooseFastestDataset")); opBuilder.addInputList(Operands.asOutputs(inputDatasets)); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("num_experiments", numExperiments); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java index 56037ab9963..6307dddf9dc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -73,7 +73,7 @@ private Options() { * @return a new instance of DataServiceDataset */ @Endpoint(describeByClass = true) - public static DataServiceDataset create(Scope scope, Operand datasetId, Operand processingMode, Operand address, Operand protocol, Operand jobName, Operand maxOutstandingRequests, Operand iterationCounter, List> outputTypes, List outputShapes, Options... options) { + public static DataServiceDataset create(Scope scope, Operand datasetId, Operand processingMode, Operand address, Operand protocol, Operand jobName, Operand maxOutstandingRequests, Operand iterationCounter, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("DataServiceDataset", scope.makeOpName("DataServiceDataset")); opBuilder.addInput(datasetId.asOutput()); opBuilder.addInput(processingMode.asOutput()); @@ -83,11 +83,7 @@ public static DataServiceDataset create(Scope scope, Operand datasetId, opBuilder.addInput(maxOutstandingRequests.asOutput()); opBuilder.addInput(iterationCounter.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DenseToSparseBatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DenseToSparseBatchDataset.java index 1edc739575f..fd19c2fc45c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DenseToSparseBatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DenseToSparseBatchDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -51,17 +51,13 @@ public final class DenseToSparseBatchDataset extends RawOp implements Operand inputDataset, Operand batchSize, Operand rowShape, List> outputTypes, List outputShapes) { + public static DenseToSparseBatchDataset create(Scope scope, Operand inputDataset, Operand batchSize, Operand rowShape, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalDenseToSparseBatchDataset", scope.makeOpName("DenseToSparseBatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(batchSize.asOutput()); opBuilder.addInput(rowShape.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DirectedInterleaveDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DirectedInterleaveDataset.java index 076f7220fc0..c1bdefa37dc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DirectedInterleaveDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DirectedInterleaveDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -49,16 +48,12 @@ public final class DirectedInterleaveDataset extends RawOp implements Operand selectorInputDataset, Iterable> dataInputDatasets, List> outputTypes, List outputShapes) { + public static DirectedInterleaveDataset create(Scope scope, Operand selectorInputDataset, Iterable> dataInputDatasets, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalDirectedInterleaveDataset", scope.makeOpName("DirectedInterleaveDataset")); opBuilder.addInput(selectorInputDataset.asOutput()); opBuilder.addInputList(Operands.asOutputs(dataInputDatasets)); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/IgnoreErrorsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/IgnoreErrorsDataset.java index 8390d3d672f..419baa98adf 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/IgnoreErrorsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/IgnoreErrorsDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -45,15 +45,11 @@ public final class IgnoreErrorsDataset extends RawOp implements Operand { * @return a new instance of IgnoreErrorsDataset */ @Endpoint(describeByClass = true) - public static IgnoreErrorsDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { + public static IgnoreErrorsDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalIgnoreErrorsDataset", scope.makeOpName("IgnoreErrorsDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LatencyStatsDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LatencyStatsDataset.java index f20c55f6c82..1c333d77ac5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LatencyStatsDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LatencyStatsDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,16 +47,12 @@ public final class LatencyStatsDataset extends RawOp implements Operand { * @return a new instance of LatencyStatsDataset */ @Endpoint(describeByClass = true) - public static LatencyStatsDataset create(Scope scope, Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { + public static LatencyStatsDataset create(Scope scope, Operand inputDataset, Operand tag, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalLatencyStatsDataset", scope.makeOpName("LatencyStatsDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(tag.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LmdbDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LmdbDataset.java index 339154d5c31..4cdb41cc8e3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LmdbDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/LmdbDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -45,15 +45,11 @@ public final class LmdbDataset extends RawOp implements Operand { * @return a new instance of LmdbDataset */ @Endpoint(describeByClass = true) - public static LmdbDataset create(Scope scope, Operand filenames, List> outputTypes, List outputShapes) { + public static LmdbDataset create(Scope scope, Operand filenames, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalLMDBDataset", scope.makeOpName("LmdbDataset")); opBuilder.addInput(filenames.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MaxIntraOpParallelismDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MaxIntraOpParallelismDataset.java index cb22b4e910d..37d6f3915cc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MaxIntraOpParallelismDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/MaxIntraOpParallelismDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,16 +47,12 @@ public final class MaxIntraOpParallelismDataset extends RawOp implements Operand * @return a new instance of MaxIntraOpParallelismDataset */ @Endpoint(describeByClass = true) - public static MaxIntraOpParallelismDataset create(Scope scope, Operand inputDataset, Operand maxIntraOpParallelism, List> outputTypes, List outputShapes) { + public static MaxIntraOpParallelismDataset create(Scope scope, Operand inputDataset, Operand maxIntraOpParallelism, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalMaxIntraOpParallelismDataset", scope.makeOpName("MaxIntraOpParallelismDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(maxIntraOpParallelism.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/NonSerializableDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/NonSerializableDataset.java index 06543121f3a..ef096a1b19f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/NonSerializableDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/NonSerializableDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -44,15 +44,11 @@ public final class NonSerializableDataset extends RawOp implements Operand inputDataset, List> outputTypes, List outputShapes) { + public static NonSerializableDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalNonSerializableDataset", scope.makeOpName("NonSerializableDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParseExampleDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParseExampleDataset.java index c1465ff7082..eb9aceba455 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParseExampleDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ParseExampleDataset.java @@ -18,7 +18,6 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -102,7 +101,7 @@ private Options() { * @return a new instance of ParseExampleDataset */ @Endpoint(describeByClass = true) - public static ParseExampleDataset create(Scope scope, Operand inputDataset, Operand numParallelCalls, Iterable> denseDefaults, List sparseKeys, List denseKeys, List> sparseTypes, List denseShapes, List> outputTypes, List outputShapes, List> raggedValueTypes, List> raggedSplitTypes, Options... options) { + public static ParseExampleDataset create(Scope scope, Operand inputDataset, Operand numParallelCalls, Iterable> denseDefaults, List sparseKeys, List denseKeys, List> sparseTypes, List denseShapes, List> outputTypes, List outputShapes, List> raggedValueTypes, List> raggedSplitTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ParseExampleDatasetV2", scope.makeOpName("ParseExampleDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numParallelCalls.asOutput()); @@ -118,36 +117,20 @@ public static ParseExampleDataset create(Scope scope, Operand inputDataset, O denseKeysArray[i] = denseKeys.get(i); } opBuilder.setAttr("dense_keys", denseKeysArray); - DataType[] sparseTypesArray = new DataType[sparseTypes.size()]; - for (int i = 0; i < sparseTypesArray.length; ++i) { - sparseTypesArray[i] = sparseTypes.get(i); - } - opBuilder.setAttr("sparse_types", sparseTypesArray); + opBuilder.setAttr("sparse_types", Operands.toDataTypes(sparseTypes)); Shape[] denseShapesArray = new Shape[denseShapes.size()]; for (int i = 0; i < denseShapesArray.length; ++i) { denseShapesArray[i] = denseShapes.get(i); } opBuilder.setAttr("dense_shapes", denseShapesArray); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); } opBuilder.setAttr("output_shapes", outputShapesArray); - DataType[] raggedValueTypesArray = new DataType[raggedValueTypes.size()]; - for (int i = 0; i < raggedValueTypesArray.length; ++i) { - raggedValueTypesArray[i] = raggedValueTypes.get(i); - } - opBuilder.setAttr("ragged_value_types", raggedValueTypesArray); - DataType[] raggedSplitTypesArray = new DataType[raggedSplitTypes.size()]; - for (int i = 0; i < raggedSplitTypesArray.length; ++i) { - raggedSplitTypesArray[i] = raggedSplitTypes.get(i); - } - opBuilder.setAttr("ragged_split_types", raggedSplitTypesArray); + opBuilder.setAttr("ragged_value_types", Operands.toDataTypes(raggedValueTypes)); + opBuilder.setAttr("ragged_split_types", Operands.toDataTypes(raggedSplitTypes)); if (options != null) { for (Options opts : options) { if (opts.deterministic != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/PrivateThreadPoolDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/PrivateThreadPoolDataset.java index d987b21f0bd..bc0a4750675 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/PrivateThreadPoolDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/PrivateThreadPoolDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,16 +47,12 @@ public final class PrivateThreadPoolDataset extends RawOp implements Operand inputDataset, Operand numThreads, List> outputTypes, List outputShapes) { + public static PrivateThreadPoolDataset create(Scope scope, Operand inputDataset, Operand numThreads, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalPrivateThreadPoolDataset", scope.makeOpName("PrivateThreadPoolDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numThreads.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RandomDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RandomDataset.java index 239e8531c03..435d331f68e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RandomDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RandomDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -49,16 +49,12 @@ public final class RandomDataset extends RawOp implements Operand { * @return a new instance of RandomDataset */ @Endpoint(describeByClass = true) - public static RandomDataset create(Scope scope, Operand seed, Operand seed2, List> outputTypes, List outputShapes) { + public static RandomDataset create(Scope scope, Operand seed, Operand seed2, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalRandomDataset", scope.makeOpName("RandomDataset")); opBuilder.addInput(seed.asOutput()); opBuilder.addInput(seed2.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RebatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RebatchDataset.java index d970139989a..219f9bd44b4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RebatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/RebatchDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -72,16 +72,12 @@ private Options() { * @return a new instance of RebatchDataset */ @Endpoint(describeByClass = true) - public static RebatchDataset create(Scope scope, Operand inputDataset, Operand numReplicas, List> outputTypes, List outputShapes, Options... options) { + public static RebatchDataset create(Scope scope, Operand inputDataset, Operand numReplicas, List> outputTypes, List outputShapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalRebatchDataset", scope.makeOpName("RebatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(numReplicas.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SetStatsAggregatorDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SetStatsAggregatorDataset.java index 66e9be0e7e9..b87aa6597ea 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SetStatsAggregatorDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SetStatsAggregatorDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -48,18 +48,14 @@ public final class SetStatsAggregatorDataset extends RawOp implements Operand inputDataset, Operand statsAggregator, Operand tag, Operand counterPrefix, List> outputTypes, List outputShapes) { + public static SetStatsAggregatorDataset create(Scope scope, Operand inputDataset, Operand statsAggregator, Operand tag, Operand counterPrefix, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalSetStatsAggregatorDataset", scope.makeOpName("SetStatsAggregatorDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(statsAggregator.asOutput()); opBuilder.addInput(tag.asOutput()); opBuilder.addInput(counterPrefix.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SleepDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SleepDataset.java index 8377ac20bdf..4fac4cfc637 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SleepDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SleepDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -46,16 +46,12 @@ public final class SleepDataset extends RawOp implements Operand { * @return a new instance of SleepDataset */ @Endpoint(describeByClass = true) - public static SleepDataset create(Scope scope, Operand inputDataset, Operand sleepMicroseconds, List> outputTypes, List outputShapes) { + public static SleepDataset create(Scope scope, Operand inputDataset, Operand sleepMicroseconds, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalSleepDataset", scope.makeOpName("SleepDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(sleepMicroseconds.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SlidingWindowDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SlidingWindowDataset.java index a68f34c069d..85dde0defb6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SlidingWindowDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SlidingWindowDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -52,18 +52,14 @@ public final class SlidingWindowDataset extends RawOp implements Operand * @return a new instance of SlidingWindowDataset */ @Endpoint(describeByClass = true) - public static SlidingWindowDataset create(Scope scope, Operand inputDataset, Operand windowSize, Operand windowShift, Operand windowStride, List> outputTypes, List outputShapes) { + public static SlidingWindowDataset create(Scope scope, Operand inputDataset, Operand windowSize, Operand windowShift, Operand windowStride, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalSlidingWindowDataset", scope.makeOpName("SlidingWindowDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(windowSize.asOutput()); opBuilder.addInput(windowShift.asOutput()); opBuilder.addInput(windowStride.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SqlDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SqlDataset.java index 126818001ab..6964f134918 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SqlDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/SqlDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -48,17 +48,13 @@ public final class SqlDataset extends RawOp implements Operand { * @return a new instance of SqlDataset */ @Endpoint(describeByClass = true) - public static SqlDataset create(Scope scope, Operand driverName, Operand dataSourceName, Operand query, List> outputTypes, List outputShapes) { + public static SqlDataset create(Scope scope, Operand driverName, Operand dataSourceName, Operand query, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalSqlDataset", scope.makeOpName("SqlDataset")); opBuilder.addInput(driverName.asOutput()); opBuilder.addInput(dataSourceName.asOutput()); opBuilder.addInput(query.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ThreadPoolDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ThreadPoolDataset.java index c8c7d309f69..fb2af7e60a7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ThreadPoolDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/ThreadPoolDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -46,16 +46,12 @@ public final class ThreadPoolDataset extends RawOp implements Operand { * @return a new instance of ThreadPoolDataset */ @Endpoint(describeByClass = true) - public static ThreadPoolDataset create(Scope scope, Operand inputDataset, Operand threadPool, List> outputTypes, List outputShapes) { + public static ThreadPoolDataset create(Scope scope, Operand inputDataset, Operand threadPool, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalThreadPoolDataset", scope.makeOpName("ThreadPoolDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder.addInput(threadPool.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UnbatchDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UnbatchDataset.java index 5e0750da3c6..c00d3fe5470 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UnbatchDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UnbatchDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -45,15 +45,11 @@ public final class UnbatchDataset extends RawOp implements Operand { * @return a new instance of UnbatchDataset */ @Endpoint(describeByClass = true) - public static UnbatchDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { + public static UnbatchDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalUnbatchDataset", scope.makeOpName("UnbatchDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java index 486a71b0f9f..36e07fe069c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java @@ -20,12 +20,12 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,15 +47,11 @@ public final class UncompressElement extends RawOp implements Iterable compressed, List> outputTypes, List outputShapes) { + public static UncompressElement create(Scope scope, Operand compressed, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("UncompressElement", scope.makeOpName("UncompressElement")); opBuilder.addInput(compressed.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UniqueDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UniqueDataset.java index dde6d8a9552..6461c5afafa 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UniqueDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UniqueDataset.java @@ -18,12 +18,12 @@ package org.tensorflow.op.data.experimental; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -45,15 +45,11 @@ public final class UniqueDataset extends RawOp implements Operand { * @return a new instance of UniqueDataset */ @Endpoint(describeByClass = true) - public static UniqueDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { + public static UniqueDataset create(Scope scope, Operand inputDataset, List> outputTypes, List outputShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ExperimentalUniqueDataset", scope.makeOpName("UniqueDataset")); opBuilder.addInput(inputDataset.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] outputTypesArray = new DataType[outputTypes.size()]; - for (int i = 0; i < outputTypesArray.length; ++i) { - outputTypesArray[i] = outputTypes.get(i); - } - opBuilder.setAttr("output_types", outputTypesArray); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); Shape[] outputShapesArray = new Shape[outputShapes.size()]; for (int i = 0; i < outputShapesArray.length; ++i) { outputShapesArray[i] = outputShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/CheckNumerics.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/CheckNumerics.java index 2f24bf16a5c..a24bde53102 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/CheckNumerics.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/CheckNumerics.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Checks a tensor for NaN, -Inf and +Inf values. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/DebugNumericsSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/DebugNumericsSummary.java index 4b37bd511ed..281d59d21c8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/DebugNumericsSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/debugging/DebugNumericsSummary.java @@ -17,11 +17,11 @@ package org.tensorflow.op.debugging; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -132,11 +132,11 @@ private Options() { * @return a new instance of DebugNumericsSummary */ @Endpoint(describeByClass = true) - public static DebugNumericsSummary create(Scope scope, Operand input, DataType outputDtype, Options... options) { + public static DebugNumericsSummary create(Scope scope, Operand input, Class outputDtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("DebugNumericSummaryV2", scope.makeOpName("DebugNumericsSummary")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("output_dtype", outputDtype); + opBuilder.setAttr("output_dtype", Operands.toDataType(outputDtype)); if (options != null) { for (Options opts : options) { if (opts.tensorDebugMode != null) { @@ -160,7 +160,7 @@ public static DebugNumericsSummary creat */ @Endpoint(describeByClass = true) public static DebugNumericsSummary create(Scope scope, Operand input, Options... options) { - return create(scope, input, TFloat32.DTYPE, options); + return create(scope, input, TFloat32.class, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Cast.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Cast.java index 0731c5112ae..502d0e564e3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Cast.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Cast.java @@ -17,11 +17,11 @@ package org.tensorflow.op.dtypes; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -65,11 +65,11 @@ private Options() { * @return a new instance of Cast */ @Endpoint(describeByClass = true) - public static Cast create(Scope scope, Operand x, DataType DstT, Options... options) { + public static Cast create(Scope scope, Operand x, Class DstT, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Cast", scope.makeOpName("Cast")); opBuilder.addInput(x.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("DstT", DstT); + opBuilder.setAttr("DstT", Operands.toDataType(DstT)); if (options != null) { for (Options opts : options) { if (opts.Truncate != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Complex.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Complex.java index aeed73b547c..126d26f5945 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Complex.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/dtypes/Complex.java @@ -17,11 +17,11 @@ package org.tensorflow.op.dtypes; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -62,12 +62,12 @@ public final class Complex extends RawOp implements Operand * @return a new instance of Complex */ @Endpoint(describeByClass = true) - public static Complex create(Scope scope, Operand real, Operand imag, DataType Tout) { + public static Complex create(Scope scope, Operand real, Operand imag, Class Tout) { OperationBuilder opBuilder = scope.env().opBuilder("Complex", scope.makeOpName("Complex")); opBuilder.addInput(real.asOutput()); opBuilder.addInput(imag.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Tout", Tout); + opBuilder.setAttr("Tout", Operands.toDataType(Tout)); return new Complex(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustContrast.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustContrast.java index 8e53fc204ce..1c3eb4b86bc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustContrast.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustContrast.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Adjust the contrast of one or more images. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustHue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustHue.java index e76cf4cc056..fa9d0e9a50c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustHue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustHue.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Adjust the hue of one or more images. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustSaturation.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustSaturation.java index f8f62572e43..5c023d89c6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustSaturation.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/AdjustSaturation.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Adjust the saturation of one or more images. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResize.java index a02c0afc7e1..7724191d9ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResize.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Extracts crops from the input image tensor and resizes them. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradBoxes.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradBoxes.java index e86c413c8d9..82dae7df1e0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradBoxes.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradBoxes.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of the crop_and_resize op wrt the input boxes tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradImage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradImage.java index 0d2221348d7..97cd59a6db6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradImage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/CropAndResizeGradImage.java @@ -17,11 +17,11 @@ package org.tensorflow.op.image; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -29,7 +29,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of the crop_and_resize op wrt the input image tensor. @@ -84,14 +83,14 @@ private Options() { * @return a new instance of CropAndResizeGradImage */ @Endpoint(describeByClass = true) - public static CropAndResizeGradImage create(Scope scope, Operand grads, Operand boxes, Operand boxInd, Operand imageSize, DataType T, Options... options) { + public static CropAndResizeGradImage create(Scope scope, Operand grads, Operand boxes, Operand boxInd, Operand imageSize, Class T, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("CropAndResizeGradImage", scope.makeOpName("CropAndResizeGradImage")); opBuilder.addInput(grads.asOutput()); opBuilder.addInput(boxes.asOutput()); opBuilder.addInput(boxInd.asOutput()); opBuilder.addInput(imageSize.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("T", T); + opBuilder.setAttr("T", Operands.toDataType(T)); if (options != null) { for (Options opts : options) { if (opts.method != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DecodePng.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DecodePng.java index 2130da508dd..c2eb543f629 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DecodePng.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DecodePng.java @@ -17,11 +17,11 @@ package org.tensorflow.op.image; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -29,7 +29,6 @@ import org.tensorflow.types.TString; import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Decode a PNG-encoded image to a uint8 or uint16 tensor. @@ -92,11 +91,11 @@ private Options() { * @return a new instance of DecodePng */ @Endpoint(describeByClass = true) - public static DecodePng create(Scope scope, Operand contents, DataType dtype, Options... options) { + public static DecodePng create(Scope scope, Operand contents, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("DecodePng", scope.makeOpName("DecodePng")); opBuilder.addInput(contents.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.channels != null) { @@ -117,7 +116,7 @@ public static DecodePng create(Scope scope, Operand create(Scope scope, Operand contents, Options... options) { - return create(scope, contents, TUint8.DTYPE, options); + return create(scope, contents, TUint8.class, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DrawBoundingBoxes.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DrawBoundingBoxes.java index 3edef6b3b4a..5204c603b31 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DrawBoundingBoxes.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/DrawBoundingBoxes.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Draw bounding boxes on a batch of images. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/EncodePng.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/EncodePng.java index df2579b94d1..cd0e8b89f7b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/EncodePng.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/EncodePng.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * PNG-encode an image. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractJpegShape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractJpegShape.java index 1e56198a4b8..bb12cb5cb1a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractJpegShape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractJpegShape.java @@ -17,11 +17,11 @@ package org.tensorflow.op.image; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -29,7 +29,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Extract the shape information of a JPEG-encoded image. @@ -51,11 +50,11 @@ public final class ExtractJpegShape extends RawOp implements * @return a new instance of ExtractJpegShape */ @Endpoint(describeByClass = true) - public static ExtractJpegShape create(Scope scope, Operand contents, DataType outputType) { + public static ExtractJpegShape create(Scope scope, Operand contents, Class outputType) { OperationBuilder opBuilder = scope.env().opBuilder("ExtractJpegShape", scope.makeOpName("ExtractJpegShape")); opBuilder.addInput(contents.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("output_type", outputType); + opBuilder.setAttr("output_type", Operands.toDataType(outputType)); return new ExtractJpegShape(opBuilder.build()); } @@ -68,7 +67,7 @@ public static ExtractJpegShape create(Scope scope, Operan */ @Endpoint(describeByClass = true) public static ExtractJpegShape create(Scope scope, Operand contents) { - return create(scope, contents, TInt32.DTYPE); + return create(scope, contents, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/HsvToRgb.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/HsvToRgb.java index f51027d884b..7934480bad9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/HsvToRgb.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/HsvToRgb.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Convert one or more images from HSV to RGB. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ImageProjectiveTransformV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ImageProjectiveTransformV2.java index fc72745c500..5fd57e44754 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ImageProjectiveTransformV2.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ImageProjectiveTransformV2.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Applies the given transform to each of the images. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/NonMaxSuppression.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/NonMaxSuppression.java index da5681b0bd3..0f2e983f786 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/NonMaxSuppression.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/NonMaxSuppression.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Greedily selects a subset of bounding boxes in descending order of score, diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RandomCrop.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RandomCrop.java index 355b2bf37ed..133c38eaaa3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RandomCrop.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RandomCrop.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Randomly crop `image`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeArea.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeArea.java index 2ad69becce8..4d998f6cf1b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeArea.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeArea.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Resize `images` to `size` using area interpolation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubic.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubic.java index b6adb9242d4..ecc156ed9d8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubic.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubic.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Resize `images` to `size` using bicubic interpolation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubicGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubicGrad.java index 4f59a26a025..e5526a769f3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubicGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBicubicGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of bicubic interpolation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinear.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinear.java index fa02d5b9ec7..0c96019b54d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinear.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinear.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Resize `images` to `size` using bilinear interpolation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinearGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinearGrad.java index 18dc81aa7ec..38c20992da0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinearGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeBilinearGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of bilinear interpolation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighbor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighbor.java index 72fca1f45fe..3e4763c692c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighbor.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighbor.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Resize `images` to `size` using nearest neighbor interpolation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighborGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighborGrad.java index 56f33a63e05..c3d2dd10e2a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighborGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ResizeNearestNeighborGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of nearest neighbor interpolation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RgbToHsv.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RgbToHsv.java index 1b7a0c33d3c..2746052a7b0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RgbToHsv.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/RgbToHsv.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Converts one or more images from RGB to HSV. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/SampleDistortedBoundingBox.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/SampleDistortedBoundingBox.java index 4fec1955f54..78a37b37dc2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/SampleDistortedBoundingBox.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/SampleDistortedBoundingBox.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Generate a single randomly distorted bounding box for an image. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslate.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslate.java index 45bb5c00961..41d49fd75f2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslate.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslate.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** */ diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslateGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslateGrad.java index 680351d79a4..b6f55d1d159 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslateGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ScaleAndTranslateGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodePaddedRaw.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodePaddedRaw.java index 2a37c8e0bad..ca41464e1b3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodePaddedRaw.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodePaddedRaw.java @@ -17,11 +17,11 @@ package org.tensorflow.op.io; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -29,7 +29,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Reinterpret the bytes of a string as a vector of numbers. @@ -71,12 +70,12 @@ private Options() { * @return a new instance of DecodePaddedRaw */ @Endpoint(describeByClass = true) - public static DecodePaddedRaw create(Scope scope, Operand inputBytes, Operand fixedLength, DataType outType, Options... options) { + public static DecodePaddedRaw create(Scope scope, Operand inputBytes, Operand fixedLength, Class outType, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("DecodePaddedRaw", scope.makeOpName("DecodePaddedRaw")); opBuilder.addInput(inputBytes.asOutput()); opBuilder.addInput(fixedLength.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); if (options != null) { for (Options opts : options) { if (opts.littleEndian != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodeRaw.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodeRaw.java index fac97605459..5ef7da94d51 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodeRaw.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DecodeRaw.java @@ -17,11 +17,11 @@ package org.tensorflow.op.io; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -68,11 +68,11 @@ private Options() { * @return a new instance of DecodeRaw */ @Endpoint(describeByClass = true) - public static DecodeRaw create(Scope scope, Operand bytes, DataType outType, Options... options) { + public static DecodeRaw create(Scope scope, Operand bytes, Class outType, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("DecodeRaw", scope.makeOpName("DecodeRaw")); opBuilder.addInput(bytes.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); if (options != null) { for (Options opts : options) { if (opts.littleEndian != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DeserializeManySparse.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DeserializeManySparse.java index b1006d65daa..9d2e496e50e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DeserializeManySparse.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/DeserializeManySparse.java @@ -17,11 +17,11 @@ package org.tensorflow.op.io; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -90,11 +90,11 @@ public final class DeserializeManySparse extends RawOp { * @return a new instance of DeserializeManySparse */ @Endpoint(describeByClass = true) - public static DeserializeManySparse create(Scope scope, Operand serializedSparse, DataType dtype) { + public static DeserializeManySparse create(Scope scope, Operand serializedSparse, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("DeserializeManySparse", scope.makeOpName("DeserializeManySparse")); opBuilder.addInput(serializedSparse.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new DeserializeManySparse(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/FifoQueue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/FifoQueue.java index 8ea087d91f9..0559fbddbf1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/FifoQueue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/FifoQueue.java @@ -18,12 +18,12 @@ package org.tensorflow.op.io; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -97,14 +97,10 @@ private Options() { * @return a new instance of FifoQueue */ @Endpoint(describeByClass = true) - public static FifoQueue create(Scope scope, List> componentTypes, Options... options) { + public static FifoQueue create(Scope scope, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("FIFOQueueV2", scope.makeOpName("FifoQueue")); opBuilder = scope.apply(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; - for (int i = 0; i < componentTypesArray.length; ++i) { - componentTypesArray[i] = componentTypes.get(i); - } - opBuilder.setAttr("component_types", componentTypesArray); + opBuilder.setAttr("component_types", Operands.toDataTypes(componentTypes)); if (options != null) { for (Options opts : options) { if (opts.shapes != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PaddingFifoQueue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PaddingFifoQueue.java index 615e5749f28..691e35140ca 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PaddingFifoQueue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PaddingFifoQueue.java @@ -18,12 +18,12 @@ package org.tensorflow.op.io; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -105,14 +105,10 @@ private Options() { * @return a new instance of PaddingFifoQueue */ @Endpoint(describeByClass = true) - public static PaddingFifoQueue create(Scope scope, List> componentTypes, Options... options) { + public static PaddingFifoQueue create(Scope scope, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("PaddingFIFOQueueV2", scope.makeOpName("PaddingFifoQueue")); opBuilder = scope.apply(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; - for (int i = 0; i < componentTypesArray.length; ++i) { - componentTypesArray[i] = componentTypes.get(i); - } - opBuilder.setAttr("component_types", componentTypesArray); + opBuilder.setAttr("component_types", Operands.toDataTypes(componentTypes)); if (options != null) { for (Options opts : options) { if (opts.shapes != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseExample.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseExample.java index 512e6d6f1ee..bdd66a0a62a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseExample.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseExample.java @@ -19,7 +19,6 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -33,6 +32,7 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; /** * Transforms a vector of tf.Example protos (as strings) into typed tensors. @@ -97,7 +97,7 @@ public final class ParseExample extends RawOp { * @return a new instance of ParseExample */ @Endpoint(describeByClass = true) - public static ParseExample create(Scope scope, Operand serialized, Operand names, Operand sparseKeys, Operand denseKeys, Operand raggedKeys, Iterable> denseDefaults, Long numSparse, List> sparseTypes, List> raggedValueTypes, List> raggedSplitTypes, List denseShapes) { + public static ParseExample create(Scope scope, Operand serialized, Operand names, Operand sparseKeys, Operand denseKeys, Operand raggedKeys, Iterable> denseDefaults, Long numSparse, List> sparseTypes, List> raggedValueTypes, List> raggedSplitTypes, List denseShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ParseExampleV2", scope.makeOpName("ParseExample")); opBuilder.addInput(serialized.asOutput()); opBuilder.addInput(names.asOutput()); @@ -107,21 +107,9 @@ public static ParseExample create(Scope scope, Operand serialized, Oper opBuilder.addInputList(Operands.asOutputs(denseDefaults)); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("num_sparse", numSparse); - DataType[] sparseTypesArray = new DataType[sparseTypes.size()]; - for (int i = 0; i < sparseTypesArray.length; ++i) { - sparseTypesArray[i] = sparseTypes.get(i); - } - opBuilder.setAttr("sparse_types", sparseTypesArray); - DataType[] raggedValueTypesArray = new DataType[raggedValueTypes.size()]; - for (int i = 0; i < raggedValueTypesArray.length; ++i) { - raggedValueTypesArray[i] = raggedValueTypes.get(i); - } - opBuilder.setAttr("ragged_value_types", raggedValueTypesArray); - DataType[] raggedSplitTypesArray = new DataType[raggedSplitTypes.size()]; - for (int i = 0; i < raggedSplitTypesArray.length; ++i) { - raggedSplitTypesArray[i] = raggedSplitTypes.get(i); - } - opBuilder.setAttr("ragged_split_types", raggedSplitTypesArray); + opBuilder.setAttr("sparse_types", Operands.toDataTypes(sparseTypes)); + opBuilder.setAttr("ragged_value_types", Operands.toDataTypes(raggedValueTypes)); + opBuilder.setAttr("ragged_split_types", Operands.toDataTypes(raggedSplitTypes)); Shape[] denseShapesArray = new Shape[denseShapes.size()]; for (int i = 0; i < denseShapesArray.length; ++i) { denseShapesArray[i] = denseShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSequenceExample.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSequenceExample.java index cb798f6ada0..b4c9479d7c7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSequenceExample.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSequenceExample.java @@ -19,7 +19,6 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -34,6 +33,7 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; /** * Transforms a vector of tf.io.SequenceExample protos (as strings) into @@ -153,7 +153,7 @@ private Options() { * @return a new instance of ParseSequenceExample */ @Endpoint(describeByClass = true) - public static ParseSequenceExample create(Scope scope, Operand serialized, Operand debugName, Operand contextSparseKeys, Operand contextDenseKeys, Operand contextRaggedKeys, Operand featureListSparseKeys, Operand featureListDenseKeys, Operand featureListRaggedKeys, Operand featureListDenseMissingAssumedEmpty, Iterable> contextDenseDefaults, List> contextSparseTypes, List> contextRaggedValueTypes, List> contextRaggedSplitTypes, List> featureListDenseTypes, List> featureListSparseTypes, List> featureListRaggedValueTypes, List> featureListRaggedSplitTypes, Options... options) { + public static ParseSequenceExample create(Scope scope, Operand serialized, Operand debugName, Operand contextSparseKeys, Operand contextDenseKeys, Operand contextRaggedKeys, Operand featureListSparseKeys, Operand featureListDenseKeys, Operand featureListRaggedKeys, Operand featureListDenseMissingAssumedEmpty, Iterable> contextDenseDefaults, List> contextSparseTypes, List> contextRaggedValueTypes, List> contextRaggedSplitTypes, List> featureListDenseTypes, List> featureListSparseTypes, List> featureListRaggedValueTypes, List> featureListRaggedSplitTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ParseSequenceExampleV2", scope.makeOpName("ParseSequenceExample")); opBuilder.addInput(serialized.asOutput()); opBuilder.addInput(debugName.asOutput()); @@ -166,41 +166,13 @@ public static ParseSequenceExample create(Scope scope, Operand serializ opBuilder.addInput(featureListDenseMissingAssumedEmpty.asOutput()); opBuilder.addInputList(Operands.asOutputs(contextDenseDefaults)); opBuilder = scope.apply(opBuilder); - DataType[] contextSparseTypesArray = new DataType[contextSparseTypes.size()]; - for (int i = 0; i < contextSparseTypesArray.length; ++i) { - contextSparseTypesArray[i] = contextSparseTypes.get(i); - } - opBuilder.setAttr("context_sparse_types", contextSparseTypesArray); - DataType[] contextRaggedValueTypesArray = new DataType[contextRaggedValueTypes.size()]; - for (int i = 0; i < contextRaggedValueTypesArray.length; ++i) { - contextRaggedValueTypesArray[i] = contextRaggedValueTypes.get(i); - } - opBuilder.setAttr("context_ragged_value_types", contextRaggedValueTypesArray); - DataType[] contextRaggedSplitTypesArray = new DataType[contextRaggedSplitTypes.size()]; - for (int i = 0; i < contextRaggedSplitTypesArray.length; ++i) { - contextRaggedSplitTypesArray[i] = contextRaggedSplitTypes.get(i); - } - opBuilder.setAttr("context_ragged_split_types", contextRaggedSplitTypesArray); - DataType[] featureListDenseTypesArray = new DataType[featureListDenseTypes.size()]; - for (int i = 0; i < featureListDenseTypesArray.length; ++i) { - featureListDenseTypesArray[i] = featureListDenseTypes.get(i); - } - opBuilder.setAttr("feature_list_dense_types", featureListDenseTypesArray); - DataType[] featureListSparseTypesArray = new DataType[featureListSparseTypes.size()]; - for (int i = 0; i < featureListSparseTypesArray.length; ++i) { - featureListSparseTypesArray[i] = featureListSparseTypes.get(i); - } - opBuilder.setAttr("feature_list_sparse_types", featureListSparseTypesArray); - DataType[] featureListRaggedValueTypesArray = new DataType[featureListRaggedValueTypes.size()]; - for (int i = 0; i < featureListRaggedValueTypesArray.length; ++i) { - featureListRaggedValueTypesArray[i] = featureListRaggedValueTypes.get(i); - } - opBuilder.setAttr("feature_list_ragged_value_types", featureListRaggedValueTypesArray); - DataType[] featureListRaggedSplitTypesArray = new DataType[featureListRaggedSplitTypes.size()]; - for (int i = 0; i < featureListRaggedSplitTypesArray.length; ++i) { - featureListRaggedSplitTypesArray[i] = featureListRaggedSplitTypes.get(i); - } - opBuilder.setAttr("feature_list_ragged_split_types", featureListRaggedSplitTypesArray); + opBuilder.setAttr("context_sparse_types", Operands.toDataTypes(contextSparseTypes)); + opBuilder.setAttr("context_ragged_value_types", Operands.toDataTypes(contextRaggedValueTypes)); + opBuilder.setAttr("context_ragged_split_types", Operands.toDataTypes(contextRaggedSplitTypes)); + opBuilder.setAttr("feature_list_dense_types", Operands.toDataTypes(featureListDenseTypes)); + opBuilder.setAttr("feature_list_sparse_types", Operands.toDataTypes(featureListSparseTypes)); + opBuilder.setAttr("feature_list_ragged_value_types", Operands.toDataTypes(featureListRaggedValueTypes)); + opBuilder.setAttr("feature_list_ragged_split_types", Operands.toDataTypes(featureListRaggedSplitTypes)); if (options != null) { for (Options opts : options) { if (opts.NcontextSparse != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleExample.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleExample.java index a18f4d69e32..f860fec4dc4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleExample.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleExample.java @@ -19,7 +19,6 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -32,6 +31,7 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * Transforms a tf.Example proto (as a string) into typed tensors. @@ -76,7 +76,7 @@ public final class ParseSingleExample extends RawOp { * @return a new instance of ParseSingleExample */ @Endpoint(describeByClass = true) - public static ParseSingleExample create(Scope scope, Operand serialized, Iterable> denseDefaults, Long numSparse, List sparseKeys, List denseKeys, List> sparseTypes, List denseShapes) { + public static ParseSingleExample create(Scope scope, Operand serialized, Iterable> denseDefaults, Long numSparse, List sparseKeys, List denseKeys, List> sparseTypes, List denseShapes) { OperationBuilder opBuilder = scope.env().opBuilder("ParseSingleExample", scope.makeOpName("ParseSingleExample")); opBuilder.addInput(serialized.asOutput()); opBuilder.addInputList(Operands.asOutputs(denseDefaults)); @@ -92,11 +92,7 @@ public static ParseSingleExample create(Scope scope, Operand serialized denseKeysArray[i] = denseKeys.get(i); } opBuilder.setAttr("dense_keys", denseKeysArray); - DataType[] sparseTypesArray = new DataType[sparseTypes.size()]; - for (int i = 0; i < sparseTypesArray.length; ++i) { - sparseTypesArray[i] = sparseTypes.get(i); - } - opBuilder.setAttr("sparse_types", sparseTypesArray); + opBuilder.setAttr("sparse_types", Operands.toDataTypes(sparseTypes)); Shape[] denseShapesArray = new Shape[denseShapes.size()]; for (int i = 0; i < denseShapesArray.length; ++i) { denseShapesArray[i] = denseShapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleSequenceExample.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleSequenceExample.java index 67c6ea97ac1..b54be0324b9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleSequenceExample.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseSingleSequenceExample.java @@ -19,7 +19,6 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -32,6 +31,7 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors. @@ -122,7 +122,7 @@ private Options() { * @return a new instance of ParseSingleSequenceExample */ @Endpoint(describeByClass = true) - public static ParseSingleSequenceExample create(Scope scope, Operand serialized, Operand featureListDenseMissingAssumedEmpty, Iterable> contextSparseKeys, Iterable> contextDenseKeys, Iterable> featureListSparseKeys, Iterable> featureListDenseKeys, Iterable> contextDenseDefaults, Operand debugName, List> contextSparseTypes, List> featureListDenseTypes, List> featureListSparseTypes, Options... options) { + public static ParseSingleSequenceExample create(Scope scope, Operand serialized, Operand featureListDenseMissingAssumedEmpty, Iterable> contextSparseKeys, Iterable> contextDenseKeys, Iterable> featureListSparseKeys, Iterable> featureListDenseKeys, Iterable> contextDenseDefaults, Operand debugName, List> contextSparseTypes, List> featureListDenseTypes, List> featureListSparseTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ParseSingleSequenceExample", scope.makeOpName("ParseSingleSequenceExample")); opBuilder.addInput(serialized.asOutput()); opBuilder.addInput(featureListDenseMissingAssumedEmpty.asOutput()); @@ -133,21 +133,9 @@ public static ParseSingleSequenceExample create(Scope scope, Operand se opBuilder.addInputList(Operands.asOutputs(contextDenseDefaults)); opBuilder.addInput(debugName.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] contextSparseTypesArray = new DataType[contextSparseTypes.size()]; - for (int i = 0; i < contextSparseTypesArray.length; ++i) { - contextSparseTypesArray[i] = contextSparseTypes.get(i); - } - opBuilder.setAttr("context_sparse_types", contextSparseTypesArray); - DataType[] featureListDenseTypesArray = new DataType[featureListDenseTypes.size()]; - for (int i = 0; i < featureListDenseTypesArray.length; ++i) { - featureListDenseTypesArray[i] = featureListDenseTypes.get(i); - } - opBuilder.setAttr("feature_list_dense_types", featureListDenseTypesArray); - DataType[] featureListSparseTypesArray = new DataType[featureListSparseTypes.size()]; - for (int i = 0; i < featureListSparseTypesArray.length; ++i) { - featureListSparseTypesArray[i] = featureListSparseTypes.get(i); - } - opBuilder.setAttr("feature_list_sparse_types", featureListSparseTypesArray); + opBuilder.setAttr("context_sparse_types", Operands.toDataTypes(contextSparseTypes)); + opBuilder.setAttr("feature_list_dense_types", Operands.toDataTypes(featureListDenseTypes)); + opBuilder.setAttr("feature_list_sparse_types", Operands.toDataTypes(featureListSparseTypes)); if (options != null) { for (Options opts : options) { if (opts.contextDenseShapes != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseTensor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseTensor.java index 993141554a6..9af89bdcd2d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ParseTensor.java @@ -17,11 +17,11 @@ package org.tensorflow.op.io; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,11 +47,11 @@ public final class ParseTensor extends RawOp implements Operand * @return a new instance of ParseTensor */ @Endpoint(describeByClass = true) - public static ParseTensor create(Scope scope, Operand serialized, DataType outType) { + public static ParseTensor create(Scope scope, Operand serialized, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("ParseTensor", scope.makeOpName("ParseTensor")); opBuilder.addInput(serialized.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new ParseTensor(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PriorityQueue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PriorityQueue.java index b49a498fb0e..9decf07b572 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PriorityQueue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/PriorityQueue.java @@ -18,12 +18,12 @@ package org.tensorflow.op.io; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -95,14 +95,10 @@ private Options() { * @return a new instance of PriorityQueue */ @Endpoint(describeByClass = true) - public static PriorityQueue create(Scope scope, List> componentTypes, List shapes, Options... options) { + public static PriorityQueue create(Scope scope, List> componentTypes, List shapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("PriorityQueueV2", scope.makeOpName("PriorityQueue")); opBuilder = scope.apply(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; - for (int i = 0; i < componentTypesArray.length; ++i) { - componentTypesArray[i] = componentTypes.get(i); - } - opBuilder.setAttr("component_types", componentTypesArray); + opBuilder.setAttr("component_types", Operands.toDataTypes(componentTypes)); Shape[] shapesArray = new Shape[shapes.size()]; for (int i = 0; i < shapesArray.length; ++i) { shapesArray[i] = shapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeue.java index 273ff1c91a5..b4d4bb71aaf 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeue.java @@ -20,11 +20,11 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -75,15 +75,11 @@ private Options() { * @return a new instance of QueueDequeue */ @Endpoint(describeByClass = true) - public static QueueDequeue create(Scope scope, Operand handle, List> componentTypes, Options... options) { + public static QueueDequeue create(Scope scope, Operand handle, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QueueDequeueV2", scope.makeOpName("QueueDequeue")); opBuilder.addInput(handle.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; - for (int i = 0; i < componentTypesArray.length; ++i) { - componentTypesArray[i] = componentTypes.get(i); - } - opBuilder.setAttr("component_types", componentTypesArray); + opBuilder.setAttr("component_types", Operands.toDataTypes(componentTypes)); if (options != null) { for (Options opts : options) { if (opts.timeoutMs != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueMany.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueMany.java index 229c544e139..1ef3426f354 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueMany.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueMany.java @@ -20,11 +20,11 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -84,16 +84,12 @@ private Options() { * @return a new instance of QueueDequeueMany */ @Endpoint(describeByClass = true) - public static QueueDequeueMany create(Scope scope, Operand handle, Operand n, List> componentTypes, Options... options) { + public static QueueDequeueMany create(Scope scope, Operand handle, Operand n, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QueueDequeueManyV2", scope.makeOpName("QueueDequeueMany")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(n.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; - for (int i = 0; i < componentTypesArray.length; ++i) { - componentTypesArray[i] = componentTypes.get(i); - } - opBuilder.setAttr("component_types", componentTypesArray); + opBuilder.setAttr("component_types", Operands.toDataTypes(componentTypes)); if (options != null) { for (Options opts : options) { if (opts.timeoutMs != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueUpTo.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueUpTo.java index 332eefb5096..53516d87ed5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueUpTo.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueDequeueUpTo.java @@ -20,11 +20,11 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -88,16 +88,12 @@ private Options() { * @return a new instance of QueueDequeueUpTo */ @Endpoint(describeByClass = true) - public static QueueDequeueUpTo create(Scope scope, Operand handle, Operand n, List> componentTypes, Options... options) { + public static QueueDequeueUpTo create(Scope scope, Operand handle, Operand n, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QueueDequeueUpToV2", scope.makeOpName("QueueDequeueUpTo")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(n.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; - for (int i = 0; i < componentTypesArray.length; ++i) { - componentTypesArray[i] = componentTypes.get(i); - } - opBuilder.setAttr("component_types", componentTypesArray); + opBuilder.setAttr("component_types", Operands.toDataTypes(componentTypes)); if (options != null) { for (Options opts : options) { if (opts.timeoutMs != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/RandomShuffleQueue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/RandomShuffleQueue.java index da168c4f8a8..9ce330cb960 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/RandomShuffleQueue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/RandomShuffleQueue.java @@ -18,12 +18,12 @@ package org.tensorflow.op.io; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -127,14 +127,10 @@ private Options() { * @return a new instance of RandomShuffleQueue */ @Endpoint(describeByClass = true) - public static RandomShuffleQueue create(Scope scope, List> componentTypes, Options... options) { + public static RandomShuffleQueue create(Scope scope, List> componentTypes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("RandomShuffleQueueV2", scope.makeOpName("RandomShuffleQueue")); opBuilder = scope.apply(opBuilder); - DataType[] componentTypesArray = new DataType[componentTypes.size()]; - for (int i = 0; i < componentTypesArray.length; ++i) { - componentTypesArray[i] = componentTypes.get(i); - } - opBuilder.setAttr("component_types", componentTypesArray); + opBuilder.setAttr("component_types", Operands.toDataTypes(componentTypes)); if (options != null) { for (Options opts : options) { if (opts.shapes != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeManySparse.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeManySparse.java index a513c0b8846..8a712e7278b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeManySparse.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeManySparse.java @@ -17,11 +17,11 @@ package org.tensorflow.op.io; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -58,13 +58,13 @@ public final class SerializeManySparse extends RawOp implements * @return a new instance of SerializeManySparse */ @Endpoint(describeByClass = true) - public static SerializeManySparse create(Scope scope, Operand sparseIndices, Operand sparseValues, Operand sparseShape, DataType outType) { + public static SerializeManySparse create(Scope scope, Operand sparseIndices, Operand sparseValues, Operand sparseShape, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("SerializeManySparse", scope.makeOpName("SerializeManySparse")); opBuilder.addInput(sparseIndices.asOutput()); opBuilder.addInput(sparseValues.asOutput()); opBuilder.addInput(sparseShape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new SerializeManySparse(opBuilder.build()); } @@ -79,7 +79,7 @@ public static SerializeManySparse create(S */ @Endpoint(describeByClass = true) public static SerializeManySparse create(Scope scope, Operand sparseIndices, Operand sparseValues, Operand sparseShape) { - return create(scope, sparseIndices, sparseValues, sparseShape, TString.DTYPE); + return create(scope, sparseIndices, sparseValues, sparseShape, TString.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeSparse.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeSparse.java index a817d2f6cd8..611feb50188 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeSparse.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/SerializeSparse.java @@ -17,11 +17,11 @@ package org.tensorflow.op.io; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -50,13 +50,13 @@ public final class SerializeSparse extends RawOp implements Ope * @return a new instance of SerializeSparse */ @Endpoint(describeByClass = true) - public static SerializeSparse create(Scope scope, Operand sparseIndices, Operand sparseValues, Operand sparseShape, DataType outType) { + public static SerializeSparse create(Scope scope, Operand sparseIndices, Operand sparseValues, Operand sparseShape, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("SerializeSparse", scope.makeOpName("SerializeSparse")); opBuilder.addInput(sparseIndices.asOutput()); opBuilder.addInput(sparseValues.asOutput()); opBuilder.addInput(sparseShape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new SerializeSparse(opBuilder.build()); } @@ -71,7 +71,7 @@ public static SerializeSparse create(Scope */ @Endpoint(describeByClass = true) public static SerializeSparse create(Scope scope, Operand sparseIndices, Operand sparseValues, Operand sparseShape) { - return create(scope, sparseIndices, sparseValues, sparseShape, TString.DTYPE); + return create(scope, sparseIndices, sparseValues, sparseShape, TString.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholesky.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholesky.java index e5399fa96ef..a7ba5c36639 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholesky.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholesky.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholeskyGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholeskyGrad.java index a71729b2b02..515e582f254 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholeskyGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchCholeskyGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixInverse.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixInverse.java index d9012f7615a..709a59404b9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixInverse.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixInverse.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolve.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolve.java index 0c5cfddf80c..64559a23676 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolve.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolve.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolveLs.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolveLs.java index 7c4b17062b1..71490a6b153 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolveLs.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixSolveLs.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixTriangularSolve.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixTriangularSolve.java index 5e4a64c943d..7bccee674a9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixTriangularSolve.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchMatrixTriangularSolve.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchSelfAdjointEig.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchSelfAdjointEig.java index 5f060b01f47..450a7ee419d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchSelfAdjointEig.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/BatchSelfAdjointEig.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code e()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/CholeskyGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/CholeskyGrad.java index 062178ecd74..5a623ebd876 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/CholeskyGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/CholeskyGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the reverse mode backpropagated gradient of the Cholesky algorithm. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Cross.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Cross.java index 3b02aa71ba7..b62fcbf512c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Cross.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Cross.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Compute the pairwise cross product. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Eig.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Eig.java index e595bc9ec4a..ac976cc929e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Eig.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Eig.java @@ -17,11 +17,11 @@ package org.tensorflow.op.linalg; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -78,11 +78,11 @@ private Options() { * @return a new instance of Eig */ @Endpoint(describeByClass = true) - public static Eig create(Scope scope, Operand input, DataType Tout, Options... options) { + public static Eig create(Scope scope, Operand input, Class Tout, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Eig", scope.makeOpName("Eig")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Tout", Tout); + opBuilder.setAttr("Tout", Operands.toDataType(Tout)); if (options != null) { for (Options opts : options) { if (opts.computeV != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Lu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Lu.java index 412fa76ba75..3edb62f927a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Lu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/Lu.java @@ -17,11 +17,11 @@ package org.tensorflow.op.linalg; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -67,11 +67,11 @@ public final class Lu extends RawOp { * @return a new instance of Lu */ @Endpoint(describeByClass = true) - public static Lu create(Scope scope, Operand input, DataType outputIdxType) { + public static Lu create(Scope scope, Operand input, Class outputIdxType) { OperationBuilder opBuilder = scope.env().opBuilder("Lu", scope.makeOpName("Lu")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("output_idx_type", outputIdxType); + opBuilder.setAttr("output_idx_type", Operands.toDataType(outputIdxType)); return new Lu(opBuilder.build()); } @@ -85,7 +85,7 @@ public static Lu create(Scope scope, */ @Endpoint(describeByClass = true) public static Lu create(Scope scope, Operand input) { - return create(scope, input, TInt32.DTYPE); + return create(scope, input, TInt32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMul.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMul.java index 7f29bb51c6d..517578fe83e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMul.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMul.java @@ -17,11 +17,11 @@ package org.tensorflow.op.linalg; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -87,7 +87,7 @@ private Options() { * @return a new instance of QuantizedMatMul */ @Endpoint(describeByClass = true) - public static QuantizedMatMul create(Scope scope, Operand a, Operand b, Operand minA, Operand maxA, Operand minB, Operand maxB, DataType Toutput, DataType Tactivation, Options... options) { + public static QuantizedMatMul create(Scope scope, Operand a, Operand b, Operand minA, Operand maxA, Operand minB, Operand maxB, Class Toutput, Class Tactivation, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedMatMul", scope.makeOpName("QuantizedMatMul")); opBuilder.addInput(a.asOutput()); opBuilder.addInput(b.asOutput()); @@ -96,8 +96,8 @@ public static QuantizedMatMulWithBias create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, DataType Toutput, Options... options) { + public static QuantizedMatMulWithBias create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Class Toutput, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedMatMulWithBias", scope.makeOpName("QuantizedMatMulWithBias")); opBuilder.addInput(a.asOutput()); opBuilder.addInput(b.asOutput()); @@ -107,7 +107,7 @@ public static QuantizedMatMulWithBiasAndRelu create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, DataType Toutput, Options... options) { + public static QuantizedMatMulWithBiasAndRelu create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Class Toutput, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedMatMulWithBiasAndRelu", scope.makeOpName("QuantizedMatMulWithBiasAndRelu")); opBuilder.addInput(a.asOutput()); opBuilder.addInput(b.asOutput()); @@ -108,7 +108,7 @@ public static QuantizedMatMu opBuilder.addInput(minB.asOutput()); opBuilder.addInput(maxB.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Toutput", Toutput); + opBuilder.setAttr("Toutput", Operands.toDataType(Toutput)); if (options != null) { for (Options opts : options) { if (opts.transposeA != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBiasAndReluAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBiasAndReluAndRequantize.java index a7a02d36bc5..a20287446c7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBiasAndReluAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/QuantizedMatMulWithBiasAndReluAndRequantize.java @@ -17,11 +17,11 @@ package org.tensorflow.op.linalg; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -101,7 +101,7 @@ private Options() { * @return a new instance of QuantizedMatMulWithBiasAndReluAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedMatMulWithBiasAndReluAndRequantize create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Operand minFreezedOutput, Operand maxFreezedOutput, DataType Toutput, Options... options) { + public static QuantizedMatMulWithBiasAndReluAndRequantize create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Operand minFreezedOutput, Operand maxFreezedOutput, Class Toutput, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedMatMulWithBiasAndReluAndRequantize", scope.makeOpName("QuantizedMatMulWithBiasAndReluAndRequantize")); opBuilder.addInput(a.asOutput()); opBuilder.addInput(b.asOutput()); @@ -113,7 +113,7 @@ public static extends RawOp { * @return a new instance of CSRSparseMatrixComponents */ @Endpoint(describeByClass = true) - public static CSRSparseMatrixComponents create(Scope scope, Operand csrSparseMatrix, Operand index, DataType type) { + public static CSRSparseMatrixComponents create(Scope scope, Operand csrSparseMatrix, Operand index, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("CSRSparseMatrixComponents", scope.makeOpName("CSRSparseMatrixComponents")); opBuilder.addInput(csrSparseMatrix.asOutput()); opBuilder.addInput(index.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("type", type); + opBuilder.setAttr("type", Operands.toDataType(type)); return new CSRSparseMatrixComponents(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToDense.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToDense.java index 456464df341..2c42786fb53 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToDense.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToDense.java @@ -17,11 +17,11 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -44,11 +44,11 @@ public final class CSRSparseMatrixToDense extends RawOp impleme * @return a new instance of CSRSparseMatrixToDense */ @Endpoint(describeByClass = true) - public static CSRSparseMatrixToDense create(Scope scope, Operand sparseInput, DataType type) { + public static CSRSparseMatrixToDense create(Scope scope, Operand sparseInput, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("CSRSparseMatrixToDense", scope.makeOpName("CSRSparseMatrixToDense")); opBuilder.addInput(sparseInput.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("type", type); + opBuilder.setAttr("type", Operands.toDataType(type)); return new CSRSparseMatrixToDense(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToSparseTensor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToSparseTensor.java index 8ac5ccfdf0b..3b88662432f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToSparseTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/CSRSparseMatrixToSparseTensor.java @@ -17,11 +17,11 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -45,11 +45,11 @@ public final class CSRSparseMatrixToSparseTensor extends RawOp * @return a new instance of CSRSparseMatrixToSparseTensor */ @Endpoint(describeByClass = true) - public static CSRSparseMatrixToSparseTensor create(Scope scope, Operand sparseMatrix, DataType type) { + public static CSRSparseMatrixToSparseTensor create(Scope scope, Operand sparseMatrix, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("CSRSparseMatrixToSparseTensor", scope.makeOpName("CSRSparseMatrixToSparseTensor")); opBuilder.addInput(sparseMatrix.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("type", type); + opBuilder.setAttr("type", Operands.toDataType(type)); return new CSRSparseMatrixToSparseTensor(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmax.java index ae1f276fff8..973a63ade9d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmax.java @@ -17,11 +17,11 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -49,11 +49,11 @@ public final class SparseMatrixSoftmax extends RawOp implements Operand { * @return a new instance of SparseMatrixSoftmax */ @Endpoint(describeByClass = true) - public static SparseMatrixSoftmax create(Scope scope, Operand logits, DataType type) { + public static SparseMatrixSoftmax create(Scope scope, Operand logits, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("SparseMatrixSoftmax", scope.makeOpName("SparseMatrixSoftmax")); opBuilder.addInput(logits.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("type", type); + opBuilder.setAttr("type", Operands.toDataType(type)); return new SparseMatrixSoftmax(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmaxGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmaxGrad.java index 14d3900cf9d..3de1b6edc1b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmaxGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSoftmaxGrad.java @@ -17,11 +17,11 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -44,12 +44,12 @@ public final class SparseMatrixSoftmaxGrad extends RawOp implements Operand SparseMatrixSoftmaxGrad create(Scope scope, Operand softmax, Operand gradSoftmax, DataType type) { + public static SparseMatrixSoftmaxGrad create(Scope scope, Operand softmax, Operand gradSoftmax, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("SparseMatrixSoftmaxGrad", scope.makeOpName("SparseMatrixSoftmaxGrad")); opBuilder.addInput(softmax.asOutput()); opBuilder.addInput(gradSoftmax.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("type", type); + opBuilder.setAttr("type", Operands.toDataType(type)); return new SparseMatrixSoftmaxGrad(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseCholesky.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseCholesky.java index 1fc8703ac5c..91ec5e652a4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseCholesky.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseCholesky.java @@ -17,11 +17,11 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -115,12 +115,12 @@ public final class SparseMatrixSparseCholesky extends RawOp implements Operand SparseMatrixSparseCholesky create(Scope scope, Operand input, Operand permutation, DataType type) { + public static SparseMatrixSparseCholesky create(Scope scope, Operand input, Operand permutation, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("SparseMatrixSparseCholesky", scope.makeOpName("SparseMatrixSparseCholesky")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(permutation.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("type", type); + opBuilder.setAttr("type", Operands.toDataType(type)); return new SparseMatrixSparseCholesky(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseMatMul.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseMatMul.java index c8d48b4269c..2e47c1ca6ab 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseMatMul.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixSparseMatMul.java @@ -17,11 +17,11 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -163,12 +163,12 @@ private Options() { * @return a new instance of SparseMatrixSparseMatMul */ @Endpoint(describeByClass = true) - public static SparseMatrixSparseMatMul create(Scope scope, Operand a, Operand b, DataType type, Options... options) { + public static SparseMatrixSparseMatMul create(Scope scope, Operand a, Operand b, Class type, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("SparseMatrixSparseMatMul", scope.makeOpName("SparseMatrixSparseMatMul")); opBuilder.addInput(a.asOutput()); opBuilder.addInput(b.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("type", type); + opBuilder.setAttr("type", Operands.toDataType(type)); if (options != null) { for (Options opts : options) { if (opts.transposeA != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixTranspose.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixTranspose.java index e4149a7195b..f1f71b33b4f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixTranspose.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixTranspose.java @@ -17,11 +17,11 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -65,11 +65,11 @@ private Options() { * @return a new instance of SparseMatrixTranspose */ @Endpoint(describeByClass = true) - public static SparseMatrixTranspose create(Scope scope, Operand input, DataType type, Options... options) { + public static SparseMatrixTranspose create(Scope scope, Operand input, Class type, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("SparseMatrixTranspose", scope.makeOpName("SparseMatrixTranspose")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("type", type); + opBuilder.setAttr("type", Operands.toDataType(type)); if (options != null) { for (Options opts : options) { if (opts.conjugate != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixZeros.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixZeros.java index c51571e054b..eea60f2298a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixZeros.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/sparse/SparseMatrixZeros.java @@ -17,11 +17,11 @@ package org.tensorflow.op.linalg.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -43,11 +43,11 @@ public final class SparseMatrixZeros extends RawOp implements Operand { * @return a new instance of SparseMatrixZeros */ @Endpoint(describeByClass = true) - public static SparseMatrixZeros create(Scope scope, Operand denseShape, DataType type) { + public static SparseMatrixZeros create(Scope scope, Operand denseShape, Class type) { OperationBuilder opBuilder = scope.env().opBuilder("SparseMatrixZeros", scope.makeOpName("SparseMatrixZeros")); opBuilder.addInput(denseShape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("type", type); + opBuilder.setAttr("type", Operands.toDataType(type)); return new SparseMatrixZeros(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Abs.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Abs.java index 32f68df69cc..dd31fbdc711 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Abs.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Abs.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the absolute value of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Angle.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Angle.java index ed03e60c54c..9187006ac28 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Angle.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Angle.java @@ -17,11 +17,11 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -63,11 +63,11 @@ public final class Angle extends RawOp implements Operand * @return a new instance of Angle */ @Endpoint(describeByClass = true) - public static Angle create(Scope scope, Operand input, DataType Tout) { + public static Angle create(Scope scope, Operand input, Class Tout) { OperationBuilder opBuilder = scope.env().opBuilder("Angle", scope.makeOpName("Angle")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Tout", Tout); + opBuilder.setAttr("Tout", Operands.toDataType(Tout)); return new Angle(opBuilder.build()); } @@ -80,7 +80,7 @@ public static Angle create(Scope scope, */ @Endpoint(describeByClass = true) public static Angle create(Scope scope, Operand input) { - return create(scope, input, TFloat32.DTYPE); + return create(scope, input, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMax.java index f21885ef002..2672a53fc83 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMax.java @@ -17,11 +17,11 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -63,12 +63,12 @@ public final class ArgMax extends RawOp implements Operand * @return a new instance of ArgMax */ @Endpoint(describeByClass = true) - public static ArgMax create(Scope scope, Operand input, Operand dimension, DataType outputType) { + public static ArgMax create(Scope scope, Operand input, Operand dimension, Class outputType) { OperationBuilder opBuilder = scope.env().opBuilder("ArgMax", scope.makeOpName("ArgMax")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(dimension.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("output_type", outputType); + opBuilder.setAttr("output_type", Operands.toDataType(outputType)); return new ArgMax(opBuilder.build()); } @@ -84,7 +84,7 @@ public static ArgMax */ @Endpoint(describeByClass = true) public static ArgMax create(Scope scope, Operand input, Operand dimension) { - return create(scope, input, dimension, TInt64.DTYPE); + return create(scope, input, dimension, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMin.java index 4df20c59ca9..be484929189 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ArgMin.java @@ -17,11 +17,11 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -63,12 +63,12 @@ public final class ArgMin extends RawOp implements Operand * @return a new instance of ArgMin */ @Endpoint(describeByClass = true) - public static ArgMin create(Scope scope, Operand input, Operand dimension, DataType outputType) { + public static ArgMin create(Scope scope, Operand input, Operand dimension, Class outputType) { OperationBuilder opBuilder = scope.env().opBuilder("ArgMin", scope.makeOpName("ArgMin")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(dimension.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("output_type", outputType); + opBuilder.setAttr("output_type", Operands.toDataType(outputType)); return new ArgMin(opBuilder.build()); } @@ -84,7 +84,7 @@ public static ArgMin */ @Endpoint(describeByClass = true) public static ArgMin create(Scope scope, Operand input, Operand dimension) { - return create(scope, input, dimension, TInt64.DTYPE); + return create(scope, input, dimension, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Atan2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Atan2.java index 1f6912eeaf0..e8ad73ca897 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Atan2.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Atan2.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes arctangent of `y/x` element-wise, respecting signs of the arguments. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java index 1569151d185..ec9c56c199e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0e.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0e.java index 7d28d75e7be..26cfa68f377 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0e.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0e.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java index 38d6c2bf2d2..4eb2b0d38da 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1e.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1e.java index 083ab146ff4..8eed4770299 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1e.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1e.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Betainc.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Betainc.java index f917f35ec86..86caad140d5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Betainc.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Betainc.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Compute the regularized incomplete beta integral \\(I_x(a, b)\\). diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Bincount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Bincount.java index 6d666696e56..9e8f03af4cc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Bincount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Bincount.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Counts the number of occurrences of each value in an integer array. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ceil.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ceil.java index cc1138a827a..ab36643898b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ceil.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ceil.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns element-wise smallest integer not less than x. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ComplexAbs.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ComplexAbs.java index dfbef1e184e..ce3c3607215 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ComplexAbs.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/ComplexAbs.java @@ -17,11 +17,11 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -52,11 +52,11 @@ public final class ComplexAbs extends RawOp implements Operan * @return a new instance of ComplexAbs */ @Endpoint(describeByClass = true) - public static ComplexAbs create(Scope scope, Operand x, DataType Tout) { + public static ComplexAbs create(Scope scope, Operand x, Class Tout) { OperationBuilder opBuilder = scope.env().opBuilder("ComplexAbs", scope.makeOpName("ComplexAbs")); opBuilder.addInput(x.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Tout", Tout); + opBuilder.setAttr("Tout", Operands.toDataType(Tout)); return new ComplexAbs(opBuilder.build()); } @@ -69,7 +69,7 @@ public static ComplexAbs create(Scope sc */ @Endpoint(describeByClass = true) public static ComplexAbs create(Scope scope, Operand x) { - return create(scope, x, TFloat32.DTYPE); + return create(scope, x, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/CumulativeLogsumexp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/CumulativeLogsumexp.java index d0e285db1eb..4e53bb981c0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/CumulativeLogsumexp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/CumulativeLogsumexp.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Compute the cumulative product of the tensor `x` along `axis`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java index 786a0c93f6f..d6a9d95e4d8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Counts the number of occurrences of each value in an integer array. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Digamma.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Digamma.java index f26827cde84..c64f3b40f6d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Digamma.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Digamma.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes Psi, the derivative of Lgamma (the log of the absolute value of diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erf.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erf.java index 0df38217ffd..eb723c89f2a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erf.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erf.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the Gauss error function of `x` element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erfc.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erfc.java index 1e61da00dad..123ce651d03 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erfc.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Erfc.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the complementary error function of `x` element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Floor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Floor.java index d6664846c30..2fa22d7a60c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Floor.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Floor.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns element-wise largest integer not greater than x. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/FloorMod.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/FloorMod.java index 9c8706740fc..18c6c4fd34d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/FloorMod.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/FloorMod.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns element-wise remainder of division. When `x < 0` xor `y < 0` is diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Greater.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Greater.java index 25d7249a489..3db7896647a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Greater.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Greater.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the truth value of (x > y) element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/GreaterEqual.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/GreaterEqual.java index 84905869eae..abeb67c3e66 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/GreaterEqual.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/GreaterEqual.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the truth value of (x >= y) element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igamma.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igamma.java index ab0b7406aa9..7cb4c526e14 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igamma.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igamma.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Compute the lower regularized incomplete Gamma function `P(a, x)`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IgammaGradA.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IgammaGradA.java index 2be04b3b61f..e81622a3391 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IgammaGradA.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IgammaGradA.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of `igamma(a, x)` wrt `a`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igammac.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igammac.java index 24203335180..ab67c74b7b5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igammac.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Igammac.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Compute the upper regularized incomplete Gamma function `Q(a, x)`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Imag.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Imag.java index ce2f523a282..a983a950dff 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Imag.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Imag.java @@ -17,11 +17,11 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -59,11 +59,11 @@ public final class Imag extends RawOp implements Operand { * @return a new instance of Imag */ @Endpoint(describeByClass = true) - public static Imag create(Scope scope, Operand input, DataType Tout) { + public static Imag create(Scope scope, Operand input, Class Tout) { OperationBuilder opBuilder = scope.env().opBuilder("Imag", scope.makeOpName("Imag")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Tout", Tout); + opBuilder.setAttr("Tout", Operands.toDataType(Tout)); return new Imag(opBuilder.build()); } @@ -76,7 +76,7 @@ public static Imag create(Scope scope, O */ @Endpoint(describeByClass = true) public static Imag create(Scope scope, Operand input) { - return create(scope, input, TFloat32.DTYPE); + return create(scope, input, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/InvertPermutation.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/InvertPermutation.java index e24943d7b8f..04373176367 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/InvertPermutation.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/InvertPermutation.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the inverse permutation of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsFinite.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsFinite.java index 4d1abb1e31b..516f2330b39 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsFinite.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsFinite.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns which elements of x are finite. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsInf.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsInf.java index a8638f0642a..ac1cf7946cc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsInf.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsInf.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns which elements of x are Inf. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsNan.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsNan.java index e31e8ddf29e..f58f6a4aa38 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsNan.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/IsNan.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns which elements of x are NaN. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Less.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Less.java index 186a9b2b61e..9564f1b84ff 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Less.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Less.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the truth value of (x < y) element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/LessEqual.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/LessEqual.java index ae724a8c853..10c0b7647cb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/LessEqual.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/LessEqual.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the truth value of (x <= y) element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Lgamma.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Lgamma.java index 176eb94ede4..817993ef5c5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Lgamma.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Lgamma.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the log of the absolute value of `Gamma(x)` element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Maximum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Maximum.java index 3f472f50956..e592e34e500 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Maximum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Maximum.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the max of x and y (i.e. x > y ? x : y) element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Minimum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Minimum.java index 0e64109ed02..1c3ab3bb724 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Minimum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Minimum.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the min of x and y (i.e. x < y ? x : y) element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Mod.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Mod.java index de961916821..e8ead53a30d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Mod.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Mod.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns element-wise remainder of division. This emulates C semantics in that diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ndtri.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ndtri.java index 4c9af64e439..5e5f9d8b326 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ndtri.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Ndtri.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/NextAfter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/NextAfter.java index 96657da0c5e..ec8a0231618 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/NextAfter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/NextAfter.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the next representable value of `x1` in the direction of `x2`, element-wise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Polygamma.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Polygamma.java index 5f236686689..f17bc4958b1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Polygamma.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Polygamma.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Compute the polygamma function \\(\psi^{(n)}(x)\\). diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/PopulationCount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/PopulationCount.java index a9e7e157c5d..c3dfd8c23c9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/PopulationCount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/PopulationCount.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes element-wise population count (a.k.a. popcount, bitsum, bitcount). diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedAdd.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedAdd.java index 3b0e3b7851e..ad4a7435e16 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedAdd.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedAdd.java @@ -17,11 +17,11 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -51,7 +51,7 @@ public final class QuantizedAdd extends RawOp { * @return a new instance of QuantizedAdd */ @Endpoint(describeByClass = true) - public static QuantizedAdd create(Scope scope, Operand x, Operand y, Operand minX, Operand maxX, Operand minY, Operand maxY, DataType Toutput) { + public static QuantizedAdd create(Scope scope, Operand x, Operand y, Operand minX, Operand maxX, Operand minY, Operand maxY, Class Toutput) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedAdd", scope.makeOpName("QuantizedAdd")); opBuilder.addInput(x.asOutput()); opBuilder.addInput(y.asOutput()); @@ -60,7 +60,7 @@ public static QuantizedAdd(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedMul.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedMul.java index 47ea14523df..cc0a23b4a3c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedMul.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/QuantizedMul.java @@ -17,11 +17,11 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -51,7 +51,7 @@ public final class QuantizedMul extends RawOp { * @return a new instance of QuantizedMul */ @Endpoint(describeByClass = true) - public static QuantizedMul create(Scope scope, Operand x, Operand y, Operand minX, Operand maxX, Operand minY, Operand maxY, DataType Toutput) { + public static QuantizedMul create(Scope scope, Operand x, Operand y, Operand minX, Operand maxX, Operand minY, Operand maxY, Class Toutput) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedMul", scope.makeOpName("QuantizedMul")); opBuilder.addInput(x.asOutput()); opBuilder.addInput(y.asOutput()); @@ -60,7 +60,7 @@ public static QuantizedMul(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Real.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Real.java index 44a4330129d..eba116d98f2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Real.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Real.java @@ -17,11 +17,11 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -59,11 +59,11 @@ public final class Real extends RawOp implements Operand { * @return a new instance of Real */ @Endpoint(describeByClass = true) - public static Real create(Scope scope, Operand input, DataType Tout) { + public static Real create(Scope scope, Operand input, Class Tout) { OperationBuilder opBuilder = scope.env().opBuilder("Real", scope.makeOpName("Real")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Tout", Tout); + opBuilder.setAttr("Tout", Operands.toDataType(Tout)); return new Real(opBuilder.build()); } @@ -76,7 +76,7 @@ public static Real create(Scope scope, O */ @Endpoint(describeByClass = true) public static Real create(Scope scope, Operand input) { - return create(scope, input, TFloat32.DTYPE); + return create(scope, input, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/RequantizePerChannel.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/RequantizePerChannel.java index 092787d2300..d1263d29d2e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/RequantizePerChannel.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/RequantizePerChannel.java @@ -17,11 +17,11 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -49,7 +49,7 @@ public final class RequantizePerChannel extends RawOp { * @return a new instance of RequantizePerChannel */ @Endpoint(describeByClass = true) - public static RequantizePerChannel create(Scope scope, Operand input, Operand inputMin, Operand inputMax, Operand requestedOutputMin, Operand requestedOutputMax, DataType outType) { + public static RequantizePerChannel create(Scope scope, Operand input, Operand inputMin, Operand inputMax, Operand requestedOutputMin, Operand requestedOutputMax, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("RequantizePerChannel", scope.makeOpName("RequantizePerChannel")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(inputMin.asOutput()); @@ -57,7 +57,7 @@ public static RequantizePerChannel create( opBuilder.addInput(requestedOutputMin.asOutput()); opBuilder.addInput(requestedOutputMax.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new RequantizePerChannel(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Rint.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Rint.java index 000014596d4..1de8f528899 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Rint.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Rint.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns element-wise integer closest to x. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMax.java index 824c509b390..8c0d4238d3f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMax.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the maximum along segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMin.java index ea4df712929..19ab9059871 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SegmentMin.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the minimum along segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SobolSample.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SobolSample.java index 674448c5325..5424188a9a6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SobolSample.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SobolSample.java @@ -17,11 +17,11 @@ package org.tensorflow.op.math; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -29,7 +29,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Generates points from the Sobol sequence. @@ -54,13 +53,13 @@ public final class SobolSample extends RawOp implements Opera * @return a new instance of SobolSample */ @Endpoint(describeByClass = true) - public static SobolSample create(Scope scope, Operand dim, Operand numResults, Operand skip, DataType dtype) { + public static SobolSample create(Scope scope, Operand dim, Operand numResults, Operand skip, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("SobolSample", scope.makeOpName("SobolSample")); opBuilder.addInput(dim.asOutput()); opBuilder.addInput(numResults.asOutput()); opBuilder.addInput(skip.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new SobolSample(opBuilder.build()); } @@ -77,7 +76,7 @@ public static SobolSample create(Scope scope, Operand create(Scope scope, Operand dim, Operand numResults, Operand skip) { - return create(scope, dim, numResults, skip, TFloat32.DTYPE); + return create(scope, dim, numResults, skip, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Softplus.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Softplus.java index 1f1d84bb394..1014467d255 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Softplus.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Softplus.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes softplus: `log(exp(features) + 1)`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SoftplusGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SoftplusGrad.java index f46ebfc0c30..e3582acab64 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SoftplusGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/SoftplusGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes softplus gradients for a softplus operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/TruncateMod.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/TruncateMod.java index 9a737a40101..9419e0db0e4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/TruncateMod.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/TruncateMod.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns element-wise remainder of division. This emulates C semantics in that diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMax.java index 8f6b834fceb..efaaf2dea96 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMax.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the maximum along segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMin.java index 2f9eadff306..e73ec14aa1b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/UnsortedSegmentMin.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the minimum along segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Zeta.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Zeta.java index b923c8671f6..73c290262c8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Zeta.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/Zeta.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Compute the Hurwitz zeta function \\(\zeta(x, q)\\). diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/erfinv.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/erfinv.java index 9d4324aa5c7..cfa2adab329 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/erfinv.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/erfinv.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java index 00ea6ba2590..597e82504d0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java index 6b37411908f..0e05c32334d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java index fa35c222824..ae0586d549c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java index f24d8b24126..e4a677d2f1a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java index 89bf309b636..bcdd407508b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java index fbd6c2a7aa4..c5824c47cbd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java index ab10c5bb5c2..03c8c1ade60 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java index 507a4cadbbe..92a7561ee8a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Dawsn.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Dawsn.java index cded11af606..9e50d639d02 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Dawsn.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Dawsn.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Expint.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Expint.java index 2dff32ad8ef..8cff09b15c3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Expint.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Expint.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelCos.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelCos.java index 86a507a30a8..f3d0848f814 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelCos.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelCos.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelSin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelSin.java index 1e14aa97198..33d738bf9fd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelSin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/FresnelSin.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Spence.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Spence.java index 14515106925..39a3c134ebb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Spence.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/Spence.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code y()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool.java index 2586bc9d248..77fb90557ec 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs average pooling on the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3d.java index e1f7341dd14..7b675fc816e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3d.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs 3D average pooling on the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3dGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3dGrad.java index 252be1aafb7..96ad463e8aa 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3dGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPool3dGrad.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients of average pooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPoolGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPoolGrad.java index c8bc81a45b5..585f72414e1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPoolGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/AvgPoolGrad.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients of the average pooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTM.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTM.java index b290b029683..9b1b33daee8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTM.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTM.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the LSTM cell forward propagation for all the time steps. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTMGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTMGrad.java index 45bec4954b8..f7f0c2449d4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTMGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/BlockLSTMGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the LSTM cell backward propagation for the entire time sequence. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2d.java index 312aa16796d..e3d98fe0882 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2d.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes a 2-D convolution given 4-D `input` and `filter` tensors. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropFilter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropFilter.java index 2cfdaa08a6c..bc282d8b349 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropFilter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropFilter.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradients of convolution with respect to the filter. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropInput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropInput.java index dafc1789740..3b43fa3301d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropInput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv2dBackpropInput.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradients of convolution with respect to the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3d.java index 563958a8710..b778819c0ce 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3d.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes a 3-D convolution given 5-D `input` and `filter` tensors. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropFilter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropFilter.java index 2ca19483b4f..6ec7dd5b71a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropFilter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropFilter.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradients of 3-D convolution with respect to the filter. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropInput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropInput.java index a7069f560d0..0d6c1792495 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropInput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Conv3dBackpropInput.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradients of 3-D convolution with respect to the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcBeamSearchDecoder.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcBeamSearchDecoder.java index 9f21dfce317..6fb3ae265f4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcBeamSearchDecoder.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcBeamSearchDecoder.java @@ -30,7 +30,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs beam search decoding on the logits given in input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcGreedyDecoder.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcGreedyDecoder.java index db0020e6ee4..2e22105a783 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcGreedyDecoder.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcGreedyDecoder.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs greedy decoding on the logits given in inputs. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcLoss.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcLoss.java index 155881c8034..1ea61f47a0b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcLoss.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CtcLoss.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Calculates the CTC Loss (log probability) for each batch entry. Also calculates diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNN.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNN.java index 418df889760..12e94a102f6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNN.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNN.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * A RNN backed by cuDNN. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNBackprop.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNBackprop.java index 9d440d46895..78a08d58328 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNBackprop.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNBackprop.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Backprop step of CudnnRNNV3. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNCanonicalToParams.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNCanonicalToParams.java index 65382bfe14c..2c5c451fac8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNCanonicalToParams.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNCanonicalToParams.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Converts CudnnRNN params from canonical form to usable form. It supports the projection in LSTM. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNParamsToCanonical.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNParamsToCanonical.java index bf070e2b602..adf30df3312 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNParamsToCanonical.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRNNParamsToCanonical.java @@ -29,7 +29,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Retrieves CudnnRNN params in canonical form. It supports the projection in LSTM. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRnnParamsSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRnnParamsSize.java index fc6221112d1..14dcc23a10e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRnnParamsSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/CudnnRnnParamsSize.java @@ -17,18 +17,17 @@ package org.tensorflow.op.nn; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes size of weights that can be used by a Cudnn RNN model. @@ -146,14 +145,14 @@ private Options() { * @return a new instance of CudnnRnnParamsSize */ @Endpoint(describeByClass = true) - public static CudnnRnnParamsSize create(Scope scope, Operand numLayers, Operand numUnits, Operand inputSize, DataType T, DataType S, Options... options) { + public static CudnnRnnParamsSize create(Scope scope, Operand numLayers, Operand numUnits, Operand inputSize, Class T, Class S, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("CudnnRNNParamsSize", scope.makeOpName("CudnnRnnParamsSize")); opBuilder.addInput(numLayers.asOutput()); opBuilder.addInput(numUnits.asOutput()); opBuilder.addInput(inputSize.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("T", T); - opBuilder.setAttr("S", S); + opBuilder.setAttr("T", Operands.toDataType(T)); + opBuilder.setAttr("S", Operands.toDataType(S)); if (options != null) { for (Options opts : options) { if (opts.rnnMode != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatDimMap.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatDimMap.java index 36e93d0275e..115a5bb12d8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatDimMap.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatDimMap.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the dimension index in the destination data format given the one in diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatVecPermute.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatVecPermute.java index f683175840d..6f8799f0e68 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatVecPermute.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DataFormatVecPermute.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the permuted vector/tensor in the destination data format given the diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNative.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNative.java index 4b90f9e2e61..8420a6524dc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNative.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNative.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropFilter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropFilter.java index 9fbc1c64cdc..3825826677b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropFilter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropFilter.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradients of depthwise convolution with respect to the filter. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropInput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropInput.java index e8bed246669..35f10aaed5d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropInput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/DepthwiseConv2dNativeBackpropInput.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradients of depthwise convolution with respect to the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2d.java index e5ce7139ecb..d87f33ca07f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2d.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropFilter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropFilter.java index 327646fbb15..c4e1d45a440 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropFilter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropFilter.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of morphological 2-D dilation with respect to the filter. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropInput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropInput.java index a546990717b..3f6bb100bc5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropInput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Dilation2dBackpropInput.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the gradient of morphological 2-D dilation with respect to the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Elu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Elu.java index 2de2a4230fc..86012a4c4ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Elu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Elu.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/EluGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/EluGrad.java index 80905bdafc0..4771374123e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/EluGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/EluGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients for the exponential linear (Elu) operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPool.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPool.java index 6eef240adab..a9bddf8774c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPool.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPool.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs fractional average pooling on the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPoolGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPoolGrad.java index 8c84128977f..094ce640a09 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPoolGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalAvgPoolGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradient of the FractionalAvgPool function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPool.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPool.java index c05d2b19f05..86c7e204757 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPool.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPool.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs fractional max pooling on the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPoolGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPoolGrad.java index 939e1cb1d20..e059e233e05 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPoolGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FractionalMaxPoolGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradient of the FractionalMaxPool function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNorm.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNorm.java index 7ca906dd6e8..29a6c2bdad3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNorm.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNorm.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Batch normalization. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNormGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNormGrad.java index c1e01556128..aa8d3b94b99 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNormGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedBatchNormGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Gradient for batch normalization. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedPadConv2d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedPadConv2d.java index 74d3caeb6f3..db1a7dac897 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedPadConv2d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedPadConv2d.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs a padding as a preprocess during a convolution. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedResizeAndPadConv2d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedResizeAndPadConv2d.java index 4694a56caba..1bc106ee063 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedResizeAndPadConv2d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/FusedResizeAndPadConv2d.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs a resize and padding as a preprocess during a convolution. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCell.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCell.java index a25a1ccd584..19ee4709eb8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCell.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCell.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the GRU cell forward propagation for 1 time step. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCellGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCellGrad.java index 823658a8df5..16d2efc3dd5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCellGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/GRUBlockCellGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the GRU cell back-propagation for 1 time step. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/InTopK.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/InTopK.java index 9df615a2106..bdf5f761d1a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/InTopK.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/InTopK.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Says whether the targets are in the top `K` predictions. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/L2Loss.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/L2Loss.java index 8c30f10a3df..45856e6c255 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/L2Loss.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/L2Loss.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * L2 Loss. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCell.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCell.java index 666d0645290..ad144fed082 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCell.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCell.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the LSTM cell forward propagation for 1 time step. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCellGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCellGrad.java index 26df5317669..6bcf274af61 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCellGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LSTMBlockCellGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the LSTM cell backward propagation for 1 timestep. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LeakyRelu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LeakyRelu.java index ba090ab23d7..2016d3a6abb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LeakyRelu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LeakyRelu.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes rectified linear: `max(features, features * alpha)`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalization.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalization.java index 858201caee2..1e28dd0ee14 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalization.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalization.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Local Response Normalization. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalizationGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalizationGrad.java index d83b0e75975..51915dc43a7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalizationGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LocalResponseNormalizationGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Gradients for Local Response Normalization. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LogSoftmax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LogSoftmax.java index 3fe2e55871f..75addd4b6a8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LogSoftmax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LogSoftmax.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes log softmax activations. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3d.java index 63ee208be62..bb802045fc4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3d.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs 3D max pooling on the input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGrad.java index 213f4fefb03..12966ef1c04 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients of 3D max pooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGradGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGradGrad.java index cddc0186e5d..1153cc1fb68 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGradGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPool3dGradGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes second-order gradients of the maxpooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGrad.java index 985d6fa1214..58bcd623533 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients of the maxpooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGrad.java index 8138cc51206..84ad249f82d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes second-order gradients of the maxpooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGradWithArgmax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGradWithArgmax.java index 59794d0adfb..5b4af53845d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGradWithArgmax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGradWithArgmax.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes second-order gradients of the maxpooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradWithArgmax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradWithArgmax.java index 75c814d293b..ef9af71cb95 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradWithArgmax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradWithArgmax.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients of the maxpooling function. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolWithArgmax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolWithArgmax.java index 29c3a7a627c..f269753c21b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolWithArgmax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolWithArgmax.java @@ -18,18 +18,17 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs max pooling on the input and outputs both max values and indices. @@ -83,7 +82,7 @@ private Options() { * @return a new instance of MaxPoolWithArgmax */ @Endpoint(describeByClass = true) - public static MaxPoolWithArgmax create(Scope scope, Operand input, List ksize, List strides, DataType Targmax, String padding, Options... options) { + public static MaxPoolWithArgmax create(Scope scope, Operand input, List ksize, List strides, Class Targmax, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("MaxPoolWithArgmax", scope.makeOpName("MaxPoolWithArgmax")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); @@ -97,7 +96,7 @@ public static MaxPoolWithArgmax cre stridesArray[i] = strides.get(i); } opBuilder.setAttr("strides", stridesArray); - opBuilder.setAttr("Targmax", Targmax); + opBuilder.setAttr("Targmax", Operands.toDataType(Targmax)); opBuilder.setAttr("padding", padding); if (options != null) { for (Options opts : options) { @@ -123,7 +122,7 @@ public static MaxPoolWithArgmax cre */ @Endpoint(describeByClass = true) public static MaxPoolWithArgmax create(Scope scope, Operand input, List ksize, List strides, String padding, Options... options) { - return create(scope, input, ksize, strides, TInt64.DTYPE, padding, options); + return create(scope, input, ksize, strides, TInt64.class, padding, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/NthElement.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/NthElement.java index 2c787751fff..9ea182bb4e5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/NthElement.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/NthElement.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Finds values of the `n`-th order statistic for the last dimension. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBatchNormWithGlobalNormalization.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBatchNormWithGlobalNormalization.java index 97d2aced9b7..60f242ef636 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBatchNormWithGlobalNormalization.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBatchNormWithGlobalNormalization.java @@ -17,11 +17,11 @@ package org.tensorflow.op.nn; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -73,7 +73,7 @@ public final class QuantizedBatchNormWithGlobalNormalization ex * @return a new instance of QuantizedBatchNormWithGlobalNormalization */ @Endpoint(describeByClass = true) - public static QuantizedBatchNormWithGlobalNormalization create(Scope scope, Operand t, Operand tMin, Operand tMax, Operand m, Operand mMin, Operand mMax, Operand v, Operand vMin, Operand vMax, Operand beta, Operand betaMin, Operand betaMax, Operand gamma, Operand gammaMin, Operand gammaMax, DataType outType, Float varianceEpsilon, Boolean scaleAfterNormalization) { + public static QuantizedBatchNormWithGlobalNormalization create(Scope scope, Operand t, Operand tMin, Operand tMax, Operand m, Operand mMin, Operand mMax, Operand v, Operand vMin, Operand vMax, Operand beta, Operand betaMin, Operand betaMax, Operand gamma, Operand gammaMin, Operand gammaMax, Class outType, Float varianceEpsilon, Boolean scaleAfterNormalization) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedBatchNormWithGlobalNormalization", scope.makeOpName("QuantizedBatchNormWithGlobalNormalization")); opBuilder.addInput(t.asOutput()); opBuilder.addInput(tMin.asOutput()); @@ -91,7 +91,7 @@ public static QuantizedBatchNormWithGlobalNor opBuilder.addInput(gammaMin.asOutput()); opBuilder.addInput(gammaMax.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); opBuilder.setAttr("variance_epsilon", varianceEpsilon); opBuilder.setAttr("scale_after_normalization", scaleAfterNormalization); return new QuantizedBatchNormWithGlobalNormalization(opBuilder.build()); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBiasAdd.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBiasAdd.java index 89171c3bc15..05e2ec67aa5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBiasAdd.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedBiasAdd.java @@ -17,11 +17,11 @@ package org.tensorflow.op.nn; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -53,7 +53,7 @@ public final class QuantizedBiasAdd extends RawOp { * @return a new instance of QuantizedBiasAdd */ @Endpoint(describeByClass = true) - public static QuantizedBiasAdd create(Scope scope, Operand input, Operand bias, Operand minInput, Operand maxInput, Operand minBias, Operand maxBias, DataType outType) { + public static QuantizedBiasAdd create(Scope scope, Operand input, Operand bias, Operand minInput, Operand maxInput, Operand minBias, Operand maxBias, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedBiasAdd", scope.makeOpName("QuantizedBiasAdd")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(bias.asOutput()); @@ -62,7 +62,7 @@ public static QuantizedBiasA opBuilder.addInput(minBias.asOutput()); opBuilder.addInput(maxBias.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new QuantizedBiasAdd(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRelu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRelu.java index 516b4d762f1..89aee4bd5f5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRelu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRelu.java @@ -18,11 +18,11 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -80,7 +80,7 @@ private Options() { * @return a new instance of QuantizedConv2DAndRelu */ @Endpoint(describeByClass = true) - public static QuantizedConv2DAndRelu create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DAndRelu create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DAndRelu", scope.makeOpName("QuantizedConv2DAndRelu")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -89,7 +89,7 @@ public static QuantizedConv2 opBuilder.addInput(minFilter.asOutput()); opBuilder.addInput(maxFilter.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); long[] stridesArray = new long[strides.size()]; for (int i = 0; i < stridesArray.length; ++i) { stridesArray[i] = strides.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndReluAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndReluAndRequantize.java index 99f10f69ad4..513a9b240e5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndReluAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndReluAndRequantize.java @@ -18,11 +18,11 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -82,7 +82,7 @@ private Options() { * @return a new instance of QuantizedConv2DAndReluAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedConv2DAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DAndReluAndRequantize", scope.makeOpName("QuantizedConv2DAndReluAndRequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -93,7 +93,7 @@ public static QuantizedConv2 opBuilder.addInput(minFreezedOutput.asOutput()); opBuilder.addInput(maxFreezedOutput.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); long[] stridesArray = new long[strides.size()]; for (int i = 0; i < stridesArray.length; ++i) { stridesArray[i] = strides.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRequantize.java index f1531565cee..80a1ddf146a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DAndRequantize.java @@ -18,11 +18,11 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -82,7 +82,7 @@ private Options() { * @return a new instance of QuantizedConv2DAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedConv2DAndRequantize create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DAndRequantize create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DAndRequantize", scope.makeOpName("QuantizedConv2DAndRequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -93,7 +93,7 @@ public static QuantizedConv2 opBuilder.addInput(minFreezedOutput.asOutput()); opBuilder.addInput(maxFreezedOutput.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); long[] stridesArray = new long[strides.size()]; for (int i = 0; i < stridesArray.length; ++i) { stridesArray[i] = strides.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DPerChannel.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DPerChannel.java index adb0975e3ef..f51143d15be 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DPerChannel.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DPerChannel.java @@ -18,11 +18,11 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -73,7 +73,7 @@ private Options() { * @return a new instance of QuantizedConv2DPerChannel */ @Endpoint(describeByClass = true) - public static QuantizedConv2DPerChannel create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DPerChannel create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DPerChannel", scope.makeOpName("QuantizedConv2DPerChannel")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -82,7 +82,7 @@ public static QuantizedConv2 opBuilder.addInput(minFilter.asOutput()); opBuilder.addInput(maxFilter.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); long[] stridesArray = new long[strides.size()]; for (int i = 0; i < stridesArray.length; ++i) { stridesArray[i] = strides.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBias.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBias.java index 8411f5e30ec..c579e306102 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBias.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBias.java @@ -18,11 +18,11 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -81,7 +81,7 @@ private Options() { * @return a new instance of QuantizedConv2DWithBias */ @Endpoint(describeByClass = true) - public static QuantizedConv2DWithBias create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DWithBias create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DWithBias", scope.makeOpName("QuantizedConv2DWithBias")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -91,7 +91,7 @@ public static QuantizedConv2 opBuilder.addInput(minFilter.asOutput()); opBuilder.addInput(maxFilter.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); long[] stridesArray = new long[strides.size()]; for (int i = 0; i < stridesArray.length; ++i) { stridesArray[i] = strides.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndRelu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndRelu.java index d204bd9e6ec..ac9672e9f94 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndRelu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndRelu.java @@ -18,11 +18,11 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -81,7 +81,7 @@ private Options() { * @return a new instance of QuantizedConv2DWithBiasAndRelu */ @Endpoint(describeByClass = true) - public static QuantizedConv2DWithBiasAndRelu create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DWithBiasAndRelu create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DWithBiasAndRelu", scope.makeOpName("QuantizedConv2DWithBiasAndRelu")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -91,7 +91,7 @@ public static QuantizedConv2 opBuilder.addInput(minFilter.asOutput()); opBuilder.addInput(maxFilter.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); long[] stridesArray = new long[strides.size()]; for (int i = 0; i < stridesArray.length; ++i) { stridesArray[i] = strides.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndReluAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndReluAndRequantize.java index dd3f8d7cd05..084f211be2c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndReluAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasAndReluAndRequantize.java @@ -18,11 +18,11 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -83,7 +83,7 @@ private Options() { * @return a new instance of QuantizedConv2DWithBiasAndReluAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedConv2DWithBiasAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DWithBiasAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DWithBiasAndReluAndRequantize", scope.makeOpName("QuantizedConv2DWithBiasAndReluAndRequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -95,7 +95,7 @@ public static QuantizedConv2DWithBiasAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DWithBiasAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DWithBiasAndRequantize", scope.makeOpName("QuantizedConv2DWithBiasAndRequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -95,7 +95,7 @@ public static QuantizedConv2DWithBiasSignedSumAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Operand summand, Operand minSummand, Operand maxSummand, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DWithBiasSignedSumAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Operand summand, Operand minSummand, Operand maxSummand, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize", scope.makeOpName("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -101,7 +101,7 @@ public static QuantizedConv2DWithBiasSumAndRelu create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand summand, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DWithBiasSumAndRelu create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand summand, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DWithBiasSumAndRelu", scope.makeOpName("QuantizedConv2DWithBiasSumAndRelu")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -93,7 +93,7 @@ public static QuantizedConv2 opBuilder.addInput(maxFilter.asOutput()); opBuilder.addInput(summand.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); long[] stridesArray = new long[strides.size()]; for (int i = 0; i < stridesArray.length; ++i) { stridesArray[i] = strides.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSumAndReluAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSumAndReluAndRequantize.java index 953cf2cddbe..68052783adb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSumAndReluAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedConv2DWithBiasSumAndReluAndRequantize.java @@ -18,11 +18,11 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -86,7 +86,7 @@ private Options() { * @return a new instance of QuantizedConv2DWithBiasSumAndReluAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedConv2DWithBiasSumAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Operand summand, Operand minSummand, Operand maxSummand, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2DWithBiasSumAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Operand summand, Operand minSummand, Operand maxSummand, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2DWithBiasSumAndReluAndRequantize", scope.makeOpName("QuantizedConv2DWithBiasSumAndReluAndRequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -101,7 +101,7 @@ public static QuantizedConv2d create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedConv2d create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedConv2D", scope.makeOpName("QuantizedConv2d")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -93,7 +93,7 @@ public static QuantizedConv2 opBuilder.addInput(minFilter.asOutput()); opBuilder.addInput(maxFilter.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); long[] stridesArray = new long[strides.size()]; for (int i = 0; i < stridesArray.length; ++i) { stridesArray[i] = strides.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2D.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2D.java index ca98dc872ad..7dae76d87bc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2D.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2D.java @@ -18,11 +18,11 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -73,7 +73,7 @@ private Options() { * @return a new instance of QuantizedDepthwiseConv2D */ @Endpoint(describeByClass = true) - public static QuantizedDepthwiseConv2D create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedDepthwiseConv2D create(Scope scope, Operand input, Operand filter, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedDepthwiseConv2D", scope.makeOpName("QuantizedDepthwiseConv2D")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -82,7 +82,7 @@ public static QuantizedDepth opBuilder.addInput(minFilter.asOutput()); opBuilder.addInput(maxFilter.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); long[] stridesArray = new long[strides.size()]; for (int i = 0; i < stridesArray.length; ++i) { stridesArray[i] = strides.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBias.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBias.java index bd2fe163ed6..fc7f07525af 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBias.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBias.java @@ -18,11 +18,11 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -74,7 +74,7 @@ private Options() { * @return a new instance of QuantizedDepthwiseConv2DWithBias */ @Endpoint(describeByClass = true) - public static QuantizedDepthwiseConv2DWithBias create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedDepthwiseConv2DWithBias create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedDepthwiseConv2DWithBias", scope.makeOpName("QuantizedDepthwiseConv2DWithBias")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -84,7 +84,7 @@ public static QuantizedDepth opBuilder.addInput(minFilter.asOutput()); opBuilder.addInput(maxFilter.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); long[] stridesArray = new long[strides.size()]; for (int i = 0; i < stridesArray.length; ++i) { stridesArray[i] = strides.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndRelu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndRelu.java index 5726f9af259..c16b3c91a0e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndRelu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndRelu.java @@ -18,11 +18,11 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -83,7 +83,7 @@ private Options() { * @return a new instance of QuantizedDepthwiseConv2DWithBiasAndRelu */ @Endpoint(describeByClass = true) - public static QuantizedDepthwiseConv2DWithBiasAndRelu create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, DataType outType, List strides, String padding, Options... options) { + public static QuantizedDepthwiseConv2DWithBiasAndRelu create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedDepthwiseConv2DWithBiasAndRelu", scope.makeOpName("QuantizedDepthwiseConv2DWithBiasAndRelu")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -93,7 +93,7 @@ public static QuantizedDepth opBuilder.addInput(minFilter.asOutput()); opBuilder.addInput(maxFilter.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); long[] stridesArray = new long[strides.size()]; for (int i = 0; i < stridesArray.length; ++i) { stridesArray[i] = strides.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize.java index 073e4f0305b..02dccc28354 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize.java @@ -18,11 +18,11 @@ package org.tensorflow.op.nn; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -85,7 +85,7 @@ private Options() { * @return a new instance of QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize */ @Endpoint(describeByClass = true) - public static QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, DataType outType, List strides, String padding, Options... options) { + public static QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize create(Scope scope, Operand input, Operand filter, Operand bias, Operand minInput, Operand maxInput, Operand minFilter, Operand maxFilter, Operand minFreezedOutput, Operand maxFreezedOutput, Class outType, List strides, String padding, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize", scope.makeOpName("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(filter.asOutput()); @@ -97,7 +97,7 @@ public static extends RawOp { * @return a new instance of QuantizedRelu */ @Endpoint(describeByClass = true) - public static QuantizedRelu create(Scope scope, Operand features, Operand minFeatures, Operand maxFeatures, DataType outType) { + public static QuantizedRelu create(Scope scope, Operand features, Operand minFeatures, Operand maxFeatures, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedRelu", scope.makeOpName("QuantizedRelu")); opBuilder.addInput(features.asOutput()); opBuilder.addInput(minFeatures.asOutput()); opBuilder.addInput(maxFeatures.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new QuantizedRelu(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedRelu6.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedRelu6.java index 33276f514b3..2f99d1ab861 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedRelu6.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedRelu6.java @@ -17,11 +17,11 @@ package org.tensorflow.op.nn; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -48,13 +48,13 @@ public final class QuantizedRelu6 extends RawOp { * @return a new instance of QuantizedRelu6 */ @Endpoint(describeByClass = true) - public static QuantizedRelu6 create(Scope scope, Operand features, Operand minFeatures, Operand maxFeatures, DataType outType) { + public static QuantizedRelu6 create(Scope scope, Operand features, Operand minFeatures, Operand maxFeatures, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedRelu6", scope.makeOpName("QuantizedRelu6")); opBuilder.addInput(features.asOutput()); opBuilder.addInput(minFeatures.asOutput()); opBuilder.addInput(maxFeatures.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new QuantizedRelu6(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedReluX.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedReluX.java index d20e5cd6934..afd943b7595 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedReluX.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/QuantizedReluX.java @@ -17,11 +17,11 @@ package org.tensorflow.op.nn; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -49,14 +49,14 @@ public final class QuantizedReluX extends RawOp { * @return a new instance of QuantizedReluX */ @Endpoint(describeByClass = true) - public static QuantizedReluX create(Scope scope, Operand features, Operand maxValue, Operand minFeatures, Operand maxFeatures, DataType outType) { + public static QuantizedReluX create(Scope scope, Operand features, Operand maxValue, Operand minFeatures, Operand maxFeatures, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedReluX", scope.makeOpName("QuantizedReluX")); opBuilder.addInput(features.asOutput()); opBuilder.addInput(maxValue.asOutput()); opBuilder.addInput(minFeatures.asOutput()); opBuilder.addInput(maxFeatures.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new QuantizedReluX(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6.java index ca05ffc74c5..efb80c73be6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes rectified linear 6: `min(max(features, 0), 6)`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6Grad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6Grad.java index 6408b81970c..38e042b1635 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6Grad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Relu6Grad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes rectified linear 6 gradients for a Relu6 operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/ReluGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/ReluGrad.java index 3ea222907ef..27999096f0e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/ReluGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/ReluGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes rectified linear gradients for a Relu operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Selu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Selu.java index a90097c263c..0e7b45fcd6a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Selu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Selu.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SeluGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SeluGrad.java index 8ced5023316..e80b0e7114b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SeluGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SeluGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients for the scaled exponential linear (Selu) operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softmax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softmax.java index 549f3f9ad92..d7180ef953f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softmax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softmax.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes softmax activations. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softsign.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softsign.java index acc3fc20ca8..efd21d04476 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softsign.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/Softsign.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes softsign: `features / (abs(features) + 1)`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftsignGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftsignGrad.java index 2f9f5a911e6..a8d5669551d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftsignGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftsignGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes softsign gradients for a softsign operation. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/TopK.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/TopK.java index f0c37766115..d86357aa90d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/TopK.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/TopK.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Finds values and indices of the `k` largest elements for the last dimension. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java index c1adf5358d6..8032a4c2512 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes softmax cross entropy cost and gradients to backpropagate. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java index 54ed2d30c95..6cbd4fddeb1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes softmax cross entropy cost and gradients to backpropagate. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Dequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Dequantize.java index 3f53e0b5bc7..19c738bc721 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Dequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Dequantize.java @@ -17,11 +17,11 @@ package org.tensorflow.op.quantization; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -137,13 +137,13 @@ private Options() { * @return a new instance of Dequantize */ @Endpoint(describeByClass = true) - public static Dequantize create(Scope scope, Operand input, Operand minRange, Operand maxRange, DataType dtype, Options... options) { + public static Dequantize create(Scope scope, Operand input, Operand minRange, Operand maxRange, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Dequantize", scope.makeOpName("Dequantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(minRange.asOutput()); opBuilder.addInput(maxRange.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.mode != null) { @@ -172,7 +172,7 @@ public static Dequantize create(Scope sc */ @Endpoint(describeByClass = true) public static Dequantize create(Scope scope, Operand input, Operand minRange, Operand maxRange, Options... options) { - return create(scope, input, minRange, maxRange, TFloat32.DTYPE, options); + return create(scope, input, minRange, maxRange, TFloat32.class, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Quantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Quantize.java index a310df7317e..edab6594b2c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Quantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/Quantize.java @@ -17,11 +17,11 @@ package org.tensorflow.op.quantization; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -220,13 +220,13 @@ private Options() { * @return a new instance of Quantize */ @Endpoint(describeByClass = true) - public static Quantize create(Scope scope, Operand input, Operand minRange, Operand maxRange, DataType T, Options... options) { + public static Quantize create(Scope scope, Operand input, Operand minRange, Operand maxRange, Class T, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizeV2", scope.makeOpName("Quantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(minRange.asOutput()); opBuilder.addInput(maxRange.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("T", T); + opBuilder.setAttr("T", Operands.toDataType(T)); if (options != null) { for (Options opts : options) { if (opts.mode != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeAndDequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeAndDequantize.java index e587abc87ac..d66ff6c4871 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeAndDequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeAndDequantize.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Quantizes then dequantizes a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeDownAndShrinkRange.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeDownAndShrinkRange.java index 35952c74490..5430dbc69ca 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeDownAndShrinkRange.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizeDownAndShrinkRange.java @@ -17,11 +17,11 @@ package org.tensorflow.op.quantization; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -71,13 +71,13 @@ public final class QuantizeDownAndShrinkRange extends RawOp { * @return a new instance of QuantizeDownAndShrinkRange */ @Endpoint(describeByClass = true) - public static QuantizeDownAndShrinkRange create(Scope scope, Operand input, Operand inputMin, Operand inputMax, DataType outType) { + public static QuantizeDownAndShrinkRange create(Scope scope, Operand input, Operand inputMin, Operand inputMax, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizeDownAndShrinkRange", scope.makeOpName("QuantizeDownAndShrinkRange")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(inputMin.asOutput()); opBuilder.addInput(inputMax.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new QuantizeDownAndShrinkRange(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizedMatMulWithBiasAndDequantize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizedMatMulWithBiasAndDequantize.java index f47477719f5..e925c15546b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizedMatMulWithBiasAndDequantize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/quantization/QuantizedMatMulWithBiasAndDequantize.java @@ -17,11 +17,11 @@ package org.tensorflow.op.quantization; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -90,7 +90,7 @@ private Options() { * @return a new instance of QuantizedMatMulWithBiasAndDequantize */ @Endpoint(describeByClass = true) - public static QuantizedMatMulWithBiasAndDequantize create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Operand minFreezedOutput, Operand maxFreezedOutput, DataType Toutput, Options... options) { + public static QuantizedMatMulWithBiasAndDequantize create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Operand minFreezedOutput, Operand maxFreezedOutput, Class Toutput, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedMatMulWithBiasAndDequantize", scope.makeOpName("QuantizedMatMulWithBiasAndDequantize")); opBuilder.addInput(a.asOutput()); opBuilder.addInput(b.asOutput()); @@ -102,7 +102,7 @@ public static QuantizedMatMulWithBiasAndRequantize create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Operand minFreezedOutput, Operand maxFreezedOutput, DataType Toutput, Options... options) { + public static QuantizedMatMulWithBiasAndRequantize create(Scope scope, Operand a, Operand b, Operand bias, Operand minA, Operand maxA, Operand minB, Operand maxB, Operand minFreezedOutput, Operand maxFreezedOutput, Class Toutput, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("QuantizedMatMulWithBiasAndRequantize", scope.makeOpName("QuantizedMatMulWithBiasAndRequantize")); opBuilder.addInput(a.asOutput()); opBuilder.addInput(b.asOutput()); @@ -101,7 +101,7 @@ public static extends RawOp { * @return a new instance of Requantize */ @Endpoint(describeByClass = true) - public static Requantize create(Scope scope, Operand input, Operand inputMin, Operand inputMax, Operand requestedOutputMin, Operand requestedOutputMax, DataType outType) { + public static Requantize create(Scope scope, Operand input, Operand inputMin, Operand inputMax, Operand requestedOutputMin, Operand requestedOutputMax, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("Requantize", scope.makeOpName("Requantize")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(inputMin.asOutput()); @@ -66,7 +66,7 @@ public static Requantize create(Scope scop opBuilder.addInput(requestedOutputMin.asOutput()); opBuilder.addInput(requestedOutputMax.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new Requantize(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java index 3b23fa1e5a0..12ca81cf3c6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Counts the number of occurrences of each value in an integer array. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java index bc188254edd..857d412129f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs sparse-output bin counting for a ragged tensor input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java index 308bc5fd3d1..3ad9d57582f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java @@ -17,7 +17,6 @@ package org.tensorflow.op.ragged; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; @@ -62,7 +61,7 @@ public final class RaggedCross extends RawOp * @return a new instance of RaggedCross */ @Endpoint(describeByClass = true) - public static RaggedCross create(Scope scope, Iterable> raggedValues, Iterable> raggedRowSplits, Iterable> sparseIndices, Iterable> sparseValues, Iterable> sparseShape, Iterable> denseInputs, String inputOrder, Boolean hashedOutput, Long numBuckets, Long hashKey, DataType outValuesType, DataType outRowSplitsType) { + public static RaggedCross create(Scope scope, Iterable> raggedValues, Iterable> raggedRowSplits, Iterable> sparseIndices, Iterable> sparseValues, Iterable> sparseShape, Iterable> denseInputs, String inputOrder, Boolean hashedOutput, Long numBuckets, Long hashKey, Class outValuesType, Class outRowSplitsType) { OperationBuilder opBuilder = scope.env().opBuilder("RaggedCross", scope.makeOpName("RaggedCross")); opBuilder.addInputList(Operands.asOutputs(raggedValues)); opBuilder.addInputList(Operands.asOutputs(raggedRowSplits)); @@ -75,8 +74,8 @@ public static RaggedCross create(Scop opBuilder.setAttr("hashed_output", hashedOutput); opBuilder.setAttr("num_buckets", numBuckets); opBuilder.setAttr("hash_key", hashKey); - opBuilder.setAttr("out_values_type", outValuesType); - opBuilder.setAttr("out_row_splits_type", outRowSplitsType); + opBuilder.setAttr("out_values_type", Operands.toDataType(outValuesType)); + opBuilder.setAttr("out_row_splits_type", Operands.toDataType(outRowSplitsType)); return new RaggedCross(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedRange.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedRange.java index 8f98601a153..67df6a76a63 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedRange.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedRange.java @@ -17,18 +17,17 @@ package org.tensorflow.op.ragged; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns a `RaggedTensor` containing the specified sequences of numbers. @@ -64,13 +63,13 @@ public final class RaggedRange extends Raw * @return a new instance of RaggedRange */ @Endpoint(describeByClass = true) - public static RaggedRange create(Scope scope, Operand starts, Operand limits, Operand deltas, DataType Tsplits) { + public static RaggedRange create(Scope scope, Operand starts, Operand limits, Operand deltas, Class Tsplits) { OperationBuilder opBuilder = scope.env().opBuilder("RaggedRange", scope.makeOpName("RaggedRange")); opBuilder.addInput(starts.asOutput()); opBuilder.addInput(limits.asOutput()); opBuilder.addInput(deltas.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Tsplits", Tsplits); + opBuilder.setAttr("Tsplits", Operands.toDataType(Tsplits)); return new RaggedRange(opBuilder.build()); } @@ -85,7 +84,7 @@ public static RaggedRange create(Sc */ @Endpoint(describeByClass = true) public static RaggedRange create(Scope scope, Operand starts, Operand limits, Operand deltas) { - return create(scope, starts, limits, deltas, TInt64.DTYPE); + return create(scope, starts, limits, deltas, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedTensorFromVariant.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedTensorFromVariant.java index 7c82a9bf5e7..f58a5ca64f3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedTensorFromVariant.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedTensorFromVariant.java @@ -19,11 +19,11 @@ import java.util.Arrays; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -67,14 +67,14 @@ public final class RaggedTensorFromVariant e * @return a new instance of RaggedTensorFromVariant */ @Endpoint(describeByClass = true) - public static RaggedTensorFromVariant create(Scope scope, Operand encodedRagged, Long inputRaggedRank, Long outputRaggedRank, DataType Tvalues, DataType Tsplits) { + public static RaggedTensorFromVariant create(Scope scope, Operand encodedRagged, Long inputRaggedRank, Long outputRaggedRank, Class Tvalues, Class Tsplits) { OperationBuilder opBuilder = scope.env().opBuilder("RaggedTensorFromVariant", scope.makeOpName("RaggedTensorFromVariant")); opBuilder.addInput(encodedRagged.asOutput()); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("input_ragged_rank", inputRaggedRank); opBuilder.setAttr("output_ragged_rank", outputRaggedRank); - opBuilder.setAttr("Tvalues", Tvalues); - opBuilder.setAttr("Tsplits", Tsplits); + opBuilder.setAttr("Tvalues", Operands.toDataType(Tvalues)); + opBuilder.setAttr("Tsplits", Operands.toDataType(Tsplits)); return new RaggedTensorFromVariant(opBuilder.build()); } @@ -91,8 +91,8 @@ public static RaggedTensorFromVariant * @return a new instance of RaggedTensorFromVariant */ @Endpoint(describeByClass = true) - public static RaggedTensorFromVariant create(Scope scope, Operand encodedRagged, Long inputRaggedRank, Long outputRaggedRank, DataType Tvalues) { - return create(scope, encodedRagged, inputRaggedRank, outputRaggedRank, Tvalues, TInt64.DTYPE); + public static RaggedTensorFromVariant create(Scope scope, Operand encodedRagged, Long inputRaggedRank, Long outputRaggedRank, Class Tvalues) { + return create(scope, encodedRagged, inputRaggedRank, outputRaggedRank, Tvalues, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/Multinomial.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/Multinomial.java index 716a076af30..5dc566feb29 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/Multinomial.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/Multinomial.java @@ -17,11 +17,11 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -29,7 +29,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Draws samples from a multinomial distribution. @@ -80,12 +79,12 @@ private Options() { * @return a new instance of Multinomial */ @Endpoint(describeByClass = true) - public static Multinomial create(Scope scope, Operand logits, Operand numSamples, DataType outputDtype, Options... options) { + public static Multinomial create(Scope scope, Operand logits, Operand numSamples, Class outputDtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("Multinomial", scope.makeOpName("Multinomial")); opBuilder.addInput(logits.asOutput()); opBuilder.addInput(numSamples.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("output_dtype", outputDtype); + opBuilder.setAttr("output_dtype", Operands.toDataType(outputDtype)); if (options != null) { for (Options opts : options) { if (opts.seed != null) { @@ -111,7 +110,7 @@ public static Multinomial create(Scope */ @Endpoint(describeByClass = true) public static Multinomial create(Scope scope, Operand logits, Operand numSamples, Options... options) { - return create(scope, logits, numSamples, TInt64.DTYPE, options); + return create(scope, logits, numSamples, TInt64.class, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/NonDeterministicInts.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/NonDeterministicInts.java index 56637c87704..2f9aad878b3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/NonDeterministicInts.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/NonDeterministicInts.java @@ -17,11 +17,11 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -47,11 +47,11 @@ public final class NonDeterministicInts extends RawOp implement * @return a new instance of NonDeterministicInts */ @Endpoint(describeByClass = true) - public static NonDeterministicInts create(Scope scope, Operand shape, DataType dtype) { + public static NonDeterministicInts create(Scope scope, Operand shape, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("NonDeterministicInts", scope.makeOpName("NonDeterministicInts")); opBuilder.addInput(shape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new NonDeterministicInts(opBuilder.build()); } @@ -64,7 +64,7 @@ public static NonDeterministicInts create( */ @Endpoint(describeByClass = true) public static NonDeterministicInts create(Scope scope, Operand shape) { - return create(scope, shape, TInt64.DTYPE); + return create(scope, shape, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/ParameterizedTruncatedNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/ParameterizedTruncatedNormal.java index 2e60c56577e..f4300ebaa75 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/ParameterizedTruncatedNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/ParameterizedTruncatedNormal.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs random values from a normal distribution. The parameters may each be a diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGamma.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGamma.java index a7ffc285a5d..ddf132eacb8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGamma.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGamma.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs random values from the Gamma distribution(s) described by alpha. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGammaGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGammaGrad.java index ec330931444..35f9c06d172 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGammaGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomGammaGrad.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the derivative of a Gamma random sample w.r.t. `alpha`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomPoisson.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomPoisson.java index 02a8494bbc6..8b3c70059e0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomPoisson.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomPoisson.java @@ -17,18 +17,17 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs random values from the Poisson distribution(s) described by rate. @@ -91,12 +90,12 @@ private Options() { * @return a new instance of RandomPoisson */ @Endpoint(describeByClass = true) - public static RandomPoisson create(Scope scope, Operand shape, Operand rate, DataType dtype, Options... options) { + public static RandomPoisson create(Scope scope, Operand shape, Operand rate, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("RandomPoissonV2", scope.makeOpName("RandomPoisson")); opBuilder.addInput(shape.asOutput()); opBuilder.addInput(rate.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.seed != null) { @@ -123,7 +122,7 @@ public static RandomPo */ @Endpoint(describeByClass = true) public static RandomPoisson create(Scope scope, Operand shape, Operand rate, Options... options) { - return create(scope, shape, rate, TInt64.DTYPE, options); + return create(scope, shape, rate, TInt64.class, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomStandardNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomStandardNormal.java index 8e323c90bed..c699ec4509c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomStandardNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomStandardNormal.java @@ -17,17 +17,16 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs random values from a normal distribution. @@ -79,11 +78,11 @@ private Options() { * @return a new instance of RandomStandardNormal */ @Endpoint(describeByClass = true) - public static RandomStandardNormal create(Scope scope, Operand shape, DataType dtype, Options... options) { + public static RandomStandardNormal create(Scope scope, Operand shape, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("RandomStandardNormal", scope.makeOpName("RandomStandardNormal")); opBuilder.addInput(shape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.seed != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniform.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniform.java index 6f5b8a97caf..d3bbe991d60 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniform.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniform.java @@ -17,17 +17,16 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs random values from a uniform distribution. @@ -80,11 +79,11 @@ private Options() { * @return a new instance of RandomUniform */ @Endpoint(describeByClass = true) - public static RandomUniform create(Scope scope, Operand shape, DataType dtype, Options... options) { + public static RandomUniform create(Scope scope, Operand shape, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("RandomUniform", scope.makeOpName("RandomUniform")); opBuilder.addInput(shape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.seed != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniformInt.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniformInt.java index 72f8ede53a4..f6ab24df811 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniformInt.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RandomUniformInt.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs random integers from a uniform distribution. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulRandomBinomial.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulRandomBinomial.java index 33c2b7be4dc..67264028253 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulRandomBinomial.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulRandomBinomial.java @@ -17,18 +17,17 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output @@ -49,7 +48,7 @@ public final class StatefulRandomBinomial extends RawOp imple * @return a new instance of StatefulRandomBinomial */ @Endpoint(describeByClass = true) - public static StatefulRandomBinomial create(Scope scope, Operand resource, Operand algorithm, Operand shape, Operand counts, Operand probs, DataType dtype) { + public static StatefulRandomBinomial create(Scope scope, Operand resource, Operand algorithm, Operand shape, Operand counts, Operand probs, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatefulRandomBinomial", scope.makeOpName("StatefulRandomBinomial")); opBuilder.addInput(resource.asOutput()); opBuilder.addInput(algorithm.asOutput()); @@ -57,7 +56,7 @@ public static Stateful opBuilder.addInput(counts.asOutput()); opBuilder.addInput(probs.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new StatefulRandomBinomial(opBuilder.build()); } @@ -74,7 +73,7 @@ public static Stateful */ @Endpoint(describeByClass = true) public static StatefulRandomBinomial create(Scope scope, Operand resource, Operand algorithm, Operand shape, Operand counts, Operand probs) { - return create(scope, resource, algorithm, shape, counts, probs, TInt64.DTYPE); + return create(scope, resource, algorithm, shape, counts, probs, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulStandardNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulStandardNormal.java index 581509c7f3b..340703444ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulStandardNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulStandardNormal.java @@ -17,11 +17,11 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -51,13 +51,13 @@ public final class StatefulStandardNormal extends RawOp impleme * @return a new instance of StatefulStandardNormal */ @Endpoint(describeByClass = true) - public static StatefulStandardNormal create(Scope scope, Operand resource, Operand algorithm, Operand shape, DataType dtype) { + public static StatefulStandardNormal create(Scope scope, Operand resource, Operand algorithm, Operand shape, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatefulStandardNormalV2", scope.makeOpName("StatefulStandardNormal")); opBuilder.addInput(resource.asOutput()); opBuilder.addInput(algorithm.asOutput()); opBuilder.addInput(shape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new StatefulStandardNormal(opBuilder.build()); } @@ -72,7 +72,7 @@ public static StatefulStandardNormal creat */ @Endpoint(describeByClass = true) public static StatefulStandardNormal create(Scope scope, Operand resource, Operand algorithm, Operand shape) { - return create(scope, resource, algorithm, shape, TFloat32.DTYPE); + return create(scope, resource, algorithm, shape, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulTruncatedNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulTruncatedNormal.java index 22b75825390..850473f7662 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulTruncatedNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulTruncatedNormal.java @@ -17,11 +17,11 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -52,13 +52,13 @@ public final class StatefulTruncatedNormal extends RawOp implem * @return a new instance of StatefulTruncatedNormal */ @Endpoint(describeByClass = true) - public static StatefulTruncatedNormal create(Scope scope, Operand resource, Operand algorithm, Operand shape, DataType dtype) { + public static StatefulTruncatedNormal create(Scope scope, Operand resource, Operand algorithm, Operand shape, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatefulTruncatedNormal", scope.makeOpName("StatefulTruncatedNormal")); opBuilder.addInput(resource.asOutput()); opBuilder.addInput(algorithm.asOutput()); opBuilder.addInput(shape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new StatefulTruncatedNormal(opBuilder.build()); } @@ -73,7 +73,7 @@ public static StatefulTruncatedNormal crea */ @Endpoint(describeByClass = true) public static StatefulTruncatedNormal create(Scope scope, Operand resource, Operand algorithm, Operand shape) { - return create(scope, resource, algorithm, shape, TFloat32.DTYPE); + return create(scope, resource, algorithm, shape, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniform.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniform.java index 08eed6f0889..b33ced6beec 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniform.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniform.java @@ -17,11 +17,11 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -51,13 +51,13 @@ public final class StatefulUniform extends RawOp implements Ope * @return a new instance of StatefulUniform */ @Endpoint(describeByClass = true) - public static StatefulUniform create(Scope scope, Operand resource, Operand algorithm, Operand shape, DataType dtype) { + public static StatefulUniform create(Scope scope, Operand resource, Operand algorithm, Operand shape, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatefulUniform", scope.makeOpName("StatefulUniform")); opBuilder.addInput(resource.asOutput()); opBuilder.addInput(algorithm.asOutput()); opBuilder.addInput(shape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new StatefulUniform(opBuilder.build()); } @@ -72,7 +72,7 @@ public static StatefulUniform create(Scope */ @Endpoint(describeByClass = true) public static StatefulUniform create(Scope scope, Operand resource, Operand algorithm, Operand shape) { - return create(scope, resource, algorithm, shape, TFloat32.DTYPE); + return create(scope, resource, algorithm, shape, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniformFullInt.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniformFullInt.java index 4489103a8bf..6a06923808a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniformFullInt.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatefulUniformFullInt.java @@ -17,11 +17,11 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -49,13 +49,13 @@ public final class StatefulUniformFullInt extends RawOp impleme * @return a new instance of StatefulUniformFullInt */ @Endpoint(describeByClass = true) - public static StatefulUniformFullInt create(Scope scope, Operand resource, Operand algorithm, Operand shape, DataType dtype) { + public static StatefulUniformFullInt create(Scope scope, Operand resource, Operand algorithm, Operand shape, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatefulUniformFullInt", scope.makeOpName("StatefulUniformFullInt")); opBuilder.addInput(resource.asOutput()); opBuilder.addInput(algorithm.asOutput()); opBuilder.addInput(shape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new StatefulUniformFullInt(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessMultinomial.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessMultinomial.java index 60940410499..93f7a9254de 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessMultinomial.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessMultinomial.java @@ -17,11 +17,11 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -29,7 +29,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Draws samples from a multinomial distribution. @@ -51,13 +50,13 @@ public final class StatelessMultinomial extends RawOp impleme * @return a new instance of StatelessMultinomial */ @Endpoint(describeByClass = true) - public static StatelessMultinomial create(Scope scope, Operand logits, Operand numSamples, Operand seed, DataType outputDtype) { + public static StatelessMultinomial create(Scope scope, Operand logits, Operand numSamples, Operand seed, Class outputDtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatelessMultinomial", scope.makeOpName("StatelessMultinomial")); opBuilder.addInput(logits.asOutput()); opBuilder.addInput(numSamples.asOutput()); opBuilder.addInput(seed.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("output_dtype", outputDtype); + opBuilder.setAttr("output_dtype", Operands.toDataType(outputDtype)); return new StatelessMultinomial(opBuilder.build()); } @@ -73,7 +72,7 @@ public static Stateles */ @Endpoint(describeByClass = true) public static StatelessMultinomial create(Scope scope, Operand logits, Operand numSamples, Operand seed) { - return create(scope, logits, numSamples, seed, TInt64.DTYPE); + return create(scope, logits, numSamples, seed, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java index f179c37e30e..f6d5deadabf 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * @param data type for {@code output()} output diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomBinomial.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomBinomial.java index 9c0e3cad0df..6da606468d1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomBinomial.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomBinomial.java @@ -17,18 +17,17 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom random numbers from a binomial distribution. @@ -55,14 +54,14 @@ public final class StatelessRandomBinomial extends RawOp impl * @return a new instance of StatelessRandomBinomial */ @Endpoint(describeByClass = true) - public static StatelessRandomBinomial create(Scope scope, Operand shape, Operand seed, Operand counts, Operand probs, DataType dtype) { + public static StatelessRandomBinomial create(Scope scope, Operand shape, Operand seed, Operand counts, Operand probs, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatelessRandomBinomial", scope.makeOpName("StatelessRandomBinomial")); opBuilder.addInput(shape.asOutput()); opBuilder.addInput(seed.asOutput()); opBuilder.addInput(counts.asOutput()); opBuilder.addInput(probs.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new StatelessRandomBinomial(opBuilder.build()); } @@ -80,7 +79,7 @@ public static StatelessRandomBinomial create(Scope scope, Operand shape, Operand seed, Operand counts, Operand probs) { - return create(scope, shape, seed, counts, probs, TInt64.DTYPE); + return create(scope, shape, seed, counts, probs, TInt64.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomGamma.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomGamma.java index faf065bfb5a..2b4ca37dee9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomGamma.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomGamma.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom random numbers from a gamma distribution. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomNormal.java index 34fc3b313f5..6438b29ec4c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomNormal.java @@ -17,18 +17,17 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom values from a normal distribution. @@ -52,12 +51,12 @@ public final class StatelessRandomNormal extends RawOp implem * @return a new instance of StatelessRandomNormal */ @Endpoint(describeByClass = true) - public static StatelessRandomNormal create(Scope scope, Operand shape, Operand seed, DataType dtype) { + public static StatelessRandomNormal create(Scope scope, Operand shape, Operand seed, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatelessRandomNormal", scope.makeOpName("StatelessRandomNormal")); opBuilder.addInput(shape.asOutput()); opBuilder.addInput(seed.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new StatelessRandomNormal(opBuilder.build()); } @@ -71,7 +70,7 @@ public static Stateles */ @Endpoint(describeByClass = true) public static StatelessRandomNormal create(Scope scope, Operand shape, Operand seed) { - return create(scope, shape, seed, TFloat32.DTYPE); + return create(scope, shape, seed, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomPoisson.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomPoisson.java index 06afdb5eb18..c44af6a5e72 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomPoisson.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomPoisson.java @@ -17,17 +17,16 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom random numbers from a Poisson distribution. @@ -52,13 +51,13 @@ public final class StatelessRandomPoisson extends RawOp imple * @return a new instance of StatelessRandomPoisson */ @Endpoint(describeByClass = true) - public static StatelessRandomPoisson create(Scope scope, Operand shape, Operand seed, Operand lam, DataType dtype) { + public static StatelessRandomPoisson create(Scope scope, Operand shape, Operand seed, Operand lam, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatelessRandomPoisson", scope.makeOpName("StatelessRandomPoisson")); opBuilder.addInput(shape.asOutput()); opBuilder.addInput(seed.asOutput()); opBuilder.addInput(lam.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new StatelessRandomPoisson(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniform.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniform.java index 142c30fbf32..814b89ba69c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniform.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniform.java @@ -17,18 +17,17 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom random values from a uniform distribution. @@ -53,12 +52,12 @@ public final class StatelessRandomUniform extends RawOp imple * @return a new instance of StatelessRandomUniform */ @Endpoint(describeByClass = true) - public static StatelessRandomUniform create(Scope scope, Operand shape, Operand seed, DataType dtype) { + public static StatelessRandomUniform create(Scope scope, Operand shape, Operand seed, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatelessRandomUniform", scope.makeOpName("StatelessRandomUniform")); opBuilder.addInput(shape.asOutput()); opBuilder.addInput(seed.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new StatelessRandomUniform(opBuilder.build()); } @@ -72,7 +71,7 @@ public static Stateles */ @Endpoint(describeByClass = true) public static StatelessRandomUniform create(Scope scope, Operand shape, Operand seed) { - return create(scope, shape, seed, TFloat32.DTYPE); + return create(scope, shape, seed, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformFullInt.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformFullInt.java index 29b7dc61980..d90112620a5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformFullInt.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformFullInt.java @@ -17,17 +17,16 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom random integers from a uniform distribution. @@ -50,12 +49,12 @@ public final class StatelessRandomUniformFullInt extends RawO * @return a new instance of StatelessRandomUniformFullInt */ @Endpoint(describeByClass = true) - public static StatelessRandomUniformFullInt create(Scope scope, Operand shape, Operand seed, DataType dtype) { + public static StatelessRandomUniformFullInt create(Scope scope, Operand shape, Operand seed, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatelessRandomUniformFullInt", scope.makeOpName("StatelessRandomUniformFullInt")); opBuilder.addInput(shape.asOutput()); opBuilder.addInput(seed.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new StatelessRandomUniformFullInt(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformInt.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformInt.java index 287cbba9e79..fb88cc4f69d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformInt.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessRandomUniformInt.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom random integers from a uniform distribution. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessTruncatedNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessTruncatedNormal.java index 41f5c746dfa..e38cd931564 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessTruncatedNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessTruncatedNormal.java @@ -17,18 +17,17 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs deterministic pseudorandom values from a truncated normal distribution. @@ -54,12 +53,12 @@ public final class StatelessTruncatedNormal extends RawOp imp * @return a new instance of StatelessTruncatedNormal */ @Endpoint(describeByClass = true) - public static StatelessTruncatedNormal create(Scope scope, Operand shape, Operand seed, DataType dtype) { + public static StatelessTruncatedNormal create(Scope scope, Operand shape, Operand seed, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("StatelessTruncatedNormal", scope.makeOpName("StatelessTruncatedNormal")); opBuilder.addInput(shape.asOutput()); opBuilder.addInput(seed.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new StatelessTruncatedNormal(opBuilder.build()); } @@ -73,7 +72,7 @@ public static Stateles */ @Endpoint(describeByClass = true) public static StatelessTruncatedNormal create(Scope scope, Operand shape, Operand seed) { - return create(scope, shape, seed, TFloat32.DTYPE); + return create(scope, shape, seed, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/TruncatedNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/TruncatedNormal.java index 1212ddef9e1..ff111c2e8d1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/TruncatedNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/TruncatedNormal.java @@ -17,17 +17,16 @@ package org.tensorflow.op.random; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs random values from a truncated normal distribution. @@ -81,11 +80,11 @@ private Options() { * @return a new instance of TruncatedNormal */ @Endpoint(describeByClass = true) - public static TruncatedNormal create(Scope scope, Operand shape, DataType dtype, Options... options) { + public static TruncatedNormal create(Scope scope, Operand shape, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TruncatedNormal", scope.makeOpName("TruncatedNormal")); opBuilder.addInput(shape.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.seed != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft.java index 62ace95189d..bd0d2a330ee 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft.java @@ -17,11 +17,11 @@ package org.tensorflow.op.signal; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -63,12 +63,12 @@ public final class Irfft extends RawOp implements Operand * @return a new instance of Irfft */ @Endpoint(describeByClass = true) - public static Irfft create(Scope scope, Operand input, Operand fftLength, DataType Treal) { + public static Irfft create(Scope scope, Operand input, Operand fftLength, Class Treal) { OperationBuilder opBuilder = scope.env().opBuilder("IRFFT", scope.makeOpName("Irfft")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(fftLength.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Treal", Treal); + opBuilder.setAttr("Treal", Operands.toDataType(Treal)); return new Irfft(opBuilder.build()); } @@ -82,7 +82,7 @@ public static Irfft create(Scope scope, */ @Endpoint(describeByClass = true) public static Irfft create(Scope scope, Operand input, Operand fftLength) { - return create(scope, input, fftLength, TFloat32.DTYPE); + return create(scope, input, fftLength, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft2d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft2d.java index 6c278ed40a0..861e7aaf62e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft2d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft2d.java @@ -17,11 +17,11 @@ package org.tensorflow.op.signal; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -64,12 +64,12 @@ public final class Irfft2d extends RawOp implements Operand Irfft2d create(Scope scope, Operand input, Operand fftLength, DataType Treal) { + public static Irfft2d create(Scope scope, Operand input, Operand fftLength, Class Treal) { OperationBuilder opBuilder = scope.env().opBuilder("IRFFT2D", scope.makeOpName("Irfft2d")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(fftLength.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Treal", Treal); + opBuilder.setAttr("Treal", Operands.toDataType(Treal)); return new Irfft2d(opBuilder.build()); } @@ -83,7 +83,7 @@ public static Irfft2d create(Scope scope */ @Endpoint(describeByClass = true) public static Irfft2d create(Scope scope, Operand input, Operand fftLength) { - return create(scope, input, fftLength, TFloat32.DTYPE); + return create(scope, input, fftLength, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft3d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft3d.java index d57d57566dd..d68a0904eff 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft3d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Irfft3d.java @@ -17,11 +17,11 @@ package org.tensorflow.op.signal; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -64,12 +64,12 @@ public final class Irfft3d extends RawOp implements Operand Irfft3d create(Scope scope, Operand input, Operand fftLength, DataType Treal) { + public static Irfft3d create(Scope scope, Operand input, Operand fftLength, Class Treal) { OperationBuilder opBuilder = scope.env().opBuilder("IRFFT3D", scope.makeOpName("Irfft3d")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(fftLength.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Treal", Treal); + opBuilder.setAttr("Treal", Operands.toDataType(Treal)); return new Irfft3d(opBuilder.build()); } @@ -83,7 +83,7 @@ public static Irfft3d create(Scope scope */ @Endpoint(describeByClass = true) public static Irfft3d create(Scope scope, Operand input, Operand fftLength) { - return create(scope, input, fftLength, TFloat32.DTYPE); + return create(scope, input, fftLength, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft.java index 4571349f1db..3b8e54a361e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft.java @@ -17,11 +17,11 @@ package org.tensorflow.op.signal; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -59,12 +59,12 @@ public final class Rfft extends RawOp implements Operand { * @return a new instance of Rfft */ @Endpoint(describeByClass = true) - public static Rfft create(Scope scope, Operand input, Operand fftLength, DataType Tcomplex) { + public static Rfft create(Scope scope, Operand input, Operand fftLength, Class Tcomplex) { OperationBuilder opBuilder = scope.env().opBuilder("RFFT", scope.makeOpName("Rfft")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(fftLength.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Tcomplex", Tcomplex); + opBuilder.setAttr("Tcomplex", Operands.toDataType(Tcomplex)); return new Rfft(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft2d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft2d.java index e3daf88207b..cf00ae1b33b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft2d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft2d.java @@ -17,11 +17,11 @@ package org.tensorflow.op.signal; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -60,12 +60,12 @@ public final class Rfft2d extends RawOp implements Operand { * @return a new instance of Rfft2d */ @Endpoint(describeByClass = true) - public static Rfft2d create(Scope scope, Operand input, Operand fftLength, DataType Tcomplex) { + public static Rfft2d create(Scope scope, Operand input, Operand fftLength, Class Tcomplex) { OperationBuilder opBuilder = scope.env().opBuilder("RFFT2D", scope.makeOpName("Rfft2d")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(fftLength.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Tcomplex", Tcomplex); + opBuilder.setAttr("Tcomplex", Operands.toDataType(Tcomplex)); return new Rfft2d(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft3d.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft3d.java index f72996a23cf..c01d979d7cb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft3d.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/signal/Rfft3d.java @@ -17,11 +17,11 @@ package org.tensorflow.op.signal; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -60,12 +60,12 @@ public final class Rfft3d extends RawOp implements Operand { * @return a new instance of Rfft3d */ @Endpoint(describeByClass = true) - public static Rfft3d create(Scope scope, Operand input, Operand fftLength, DataType Tcomplex) { + public static Rfft3d create(Scope scope, Operand input, Operand fftLength, Class Tcomplex) { OperationBuilder opBuilder = scope.env().opBuilder("RFFT3D", scope.makeOpName("Rfft3d")); opBuilder.addInput(input.asOutput()); opBuilder.addInput(fftLength.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("Tcomplex", Tcomplex); + opBuilder.setAttr("Tcomplex", Operands.toDataType(Tcomplex)); return new Rfft3d(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java index 58e4e8c2ff1..c06191594a3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs sparse-output bin counting for a tf.tensor input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DeserializeSparse.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DeserializeSparse.java index a6873029f69..f229dd2ad98 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DeserializeSparse.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DeserializeSparse.java @@ -17,11 +17,11 @@ package org.tensorflow.op.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -89,11 +89,11 @@ public final class DeserializeSparse extends RawOp { * @return a new instance of DeserializeSparse */ @Endpoint(describeByClass = true) - public static DeserializeSparse create(Scope scope, Operand serializedSparse, DataType dtype) { + public static DeserializeSparse create(Scope scope, Operand serializedSparse, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("DeserializeSparse", scope.makeOpName("DeserializeSparse")); opBuilder.addInput(serializedSparse.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new DeserializeSparse(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorTakeGradient.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorTakeGradient.java index bfe138e1589..f7250020c10 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorTakeGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorTakeGradient.java @@ -17,11 +17,11 @@ package org.tensorflow.op.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -57,12 +57,12 @@ public final class SparseAccumulatorTakeGradient extends RawOp * @return a new instance of SparseAccumulatorTakeGradient */ @Endpoint(describeByClass = true) - public static SparseAccumulatorTakeGradient create(Scope scope, Operand handle, Operand numRequired, DataType dtype) { + public static SparseAccumulatorTakeGradient create(Scope scope, Operand handle, Operand numRequired, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("SparseAccumulatorTakeGradient", scope.makeOpName("SparseAccumulatorTakeGradient")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(numRequired.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new SparseAccumulatorTakeGradient(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java index abb959533b6..bea0d8b3c71 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Counts the number of occurrences of each value in an integer array. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseConditionalAccumulator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseConditionalAccumulator.java index c428e2286cf..e012cb631c8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseConditionalAccumulator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseConditionalAccumulator.java @@ -17,12 +17,12 @@ package org.tensorflow.op.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -92,10 +92,10 @@ private Options() { * @return a new instance of SparseConditionalAccumulator */ @Endpoint(describeByClass = true) - public static SparseConditionalAccumulator create(Scope scope, DataType dtype, Shape shape, Options... options) { + public static SparseConditionalAccumulator create(Scope scope, Class dtype, Shape shape, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("SparseConditionalAccumulator", scope.makeOpName("SparseConditionalAccumulator")); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); opBuilder.setAttr("shape", shape); if (options != null) { for (Options opts : options) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java index 80f3aa2c352..85e6707f0c0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Performs sparse-output bin counting for a sparse tensor input. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseMatMul.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseMatMul.java index 9c5b6df4b07..d254cb3224c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseMatMul.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseMatMul.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Multiply matrix "a" by matrix "b". diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMax.java index 863e1112a31..190f06c2dbf 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMax.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the max of elements across dimensions of a SparseTensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMaxSparse.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMaxSparse.java index 1bf33bd9e68..ae36fcb07e3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMaxSparse.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseReduceMaxSparse.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the max of elements across dimensions of a SparseTensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMean.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMean.java index 4269fd92830..a373bd33f03 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMean.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMean.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the mean along sparse segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanGrad.java index cb2ba01028b..8e558afbe9e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients for SparseSegmentMean. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanWithNumSegments.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanWithNumSegments.java index 4ea84634502..d62b956b096 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanWithNumSegments.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentMeanWithNumSegments.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the mean along sparse segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtN.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtN.java index 0ffcbc1a257..50fe5cc8068 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtN.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtN.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the sum along sparse segments of a tensor divided by the sqrt of N. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNGrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNGrad.java index 772b5a71f31..9faa31633c0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNGrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNGrad.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes gradients for SparseSegmentSqrtN. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNWithNumSegments.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNWithNumSegments.java index 5288004ed14..f962af5be0f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNWithNumSegments.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSqrtNWithNumSegments.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the sum along sparse segments of a tensor divided by the sqrt of N. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSum.java index 7fd16c2eb43..add00c19798 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSum.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the sum along sparse segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSumWithNumSegments.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSumWithNumSegments.java index e4775c51cf2..b956d005069 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSumWithNumSegments.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSegmentSumWithNumSegments.java @@ -26,7 +26,6 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the sum along sparse segments of a tensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSoftmax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSoftmax.java index 7e3f66ebc05..252a00e838b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSoftmax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSoftmax.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Applies softmax to a batched N-D `SparseTensor`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSparseMaximum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSparseMaximum.java index 9903a494197..8886a6fda12 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSparseMaximum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseSparseMaximum.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Returns the element-wise max of two SparseTensors. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/TakeManySparseFromTensorsMap.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/TakeManySparseFromTensorsMap.java index ab0123a2c14..a18674e8c4c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/TakeManySparseFromTensorsMap.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/TakeManySparseFromTensorsMap.java @@ -17,11 +17,11 @@ package org.tensorflow.op.sparse; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -124,11 +124,11 @@ private Options() { * @return a new instance of TakeManySparseFromTensorsMap */ @Endpoint(describeByClass = true) - public static TakeManySparseFromTensorsMap create(Scope scope, Operand sparseHandles, DataType dtype, Options... options) { + public static TakeManySparseFromTensorsMap create(Scope scope, Operand sparseHandles, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("TakeManySparseFromTensorsMap", scope.makeOpName("TakeManySparseFromTensorsMap")); opBuilder.addInput(sparseHandles.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); if (options != null) { for (Options opts : options) { if (opts.container != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/StringNGrams.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/StringNGrams.java index 2c967f3ef62..e504e164c41 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/StringNGrams.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/StringNGrams.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Creates ngrams from ragged string data. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/Substr.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/Substr.java index 1848bfc513c..472e30bc98c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/Substr.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/Substr.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Return substrings from `Tensor` of strings. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/ToNumber.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/ToNumber.java index db2bac2cf74..547a4f17b95 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/ToNumber.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/ToNumber.java @@ -17,11 +17,11 @@ package org.tensorflow.op.strings; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -29,7 +29,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Converts each string in the input Tensor to the specified numeric type. @@ -58,11 +57,11 @@ public final class ToNumber extends RawOp implements Operand< * @return a new instance of ToNumber */ @Endpoint(describeByClass = true) - public static ToNumber create(Scope scope, Operand stringTensor, DataType outType) { + public static ToNumber create(Scope scope, Operand stringTensor, Class outType) { OperationBuilder opBuilder = scope.env().opBuilder("StringToNumber", scope.makeOpName("ToNumber")); opBuilder.addInput(stringTensor.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("out_type", outType); + opBuilder.setAttr("out_type", Operands.toDataType(outType)); return new ToNumber(opBuilder.build()); } @@ -75,7 +74,7 @@ public static ToNumber create(Scope scope, Operand create(Scope scope, Operand stringTensor) { - return create(scope, stringTensor, TFloat32.DTYPE); + return create(scope, stringTensor, TFloat32.class); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecode.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecode.java index d5532df9e78..65e5e256780 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecode.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecode.java @@ -17,11 +17,11 @@ package org.tensorflow.op.strings; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -30,7 +30,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Decodes each string in `input` into a sequence of Unicode code points. @@ -116,12 +115,12 @@ private Options() { * @return a new instance of UnicodeDecode */ @Endpoint(describeByClass = true) - public static UnicodeDecode create(Scope scope, Operand input, String inputEncoding, DataType Tsplits, Options... options) { + public static UnicodeDecode create(Scope scope, Operand input, String inputEncoding, Class Tsplits, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("UnicodeDecode", scope.makeOpName("UnicodeDecode")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("input_encoding", inputEncoding); - opBuilder.setAttr("Tsplits", Tsplits); + opBuilder.setAttr("Tsplits", Operands.toDataType(Tsplits)); if (options != null) { for (Options opts : options) { if (opts.errors != null) { @@ -151,7 +150,7 @@ public static UnicodeDecode create(Scope scope, Operand create(Scope scope, Operand input, String inputEncoding, Options... options) { - return create(scope, input, inputEncoding, TInt64.DTYPE, options); + return create(scope, input, inputEncoding, TInt64.class, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecodeWithOffsets.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecodeWithOffsets.java index 1ab449f34c1..2b10934b37a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecodeWithOffsets.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeDecodeWithOffsets.java @@ -17,11 +17,11 @@ package org.tensorflow.op.strings; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -30,7 +30,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Decodes each string in `input` into a sequence of Unicode code points. @@ -122,12 +121,12 @@ private Options() { * @return a new instance of UnicodeDecodeWithOffsets */ @Endpoint(describeByClass = true) - public static UnicodeDecodeWithOffsets create(Scope scope, Operand input, String inputEncoding, DataType Tsplits, Options... options) { + public static UnicodeDecodeWithOffsets create(Scope scope, Operand input, String inputEncoding, Class Tsplits, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("UnicodeDecodeWithOffsets", scope.makeOpName("UnicodeDecodeWithOffsets")); opBuilder.addInput(input.asOutput()); opBuilder = scope.apply(opBuilder); opBuilder.setAttr("input_encoding", inputEncoding); - opBuilder.setAttr("Tsplits", Tsplits); + opBuilder.setAttr("Tsplits", Operands.toDataType(Tsplits)); if (options != null) { for (Options opts : options) { if (opts.errors != null) { @@ -157,7 +156,7 @@ public static UnicodeDecodeWithOffsets create(Scope scope */ @Endpoint(describeByClass = true) public static UnicodeDecodeWithOffsets create(Scope scope, Operand input, String inputEncoding, Options... options) { - return create(scope, input, inputEncoding, TInt64.DTYPE, options); + return create(scope, input, inputEncoding, TInt64.class, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeEncode.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeEncode.java index aa49be53b13..d24d09a67ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeEncode.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnicodeEncode.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Encode a tensor of ints into unicode strings. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnsortedSegmentJoin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnsortedSegmentJoin.java index 44532692d04..075d21d219e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnsortedSegmentJoin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/strings/UnsortedSegmentJoin.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Joins the elements of `inputs` based on `segment_ids`. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/HistogramSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/HistogramSummary.java index a309f20f391..0acfec7cac1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/HistogramSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/HistogramSummary.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs a `Summary` protocol buffer with a histogram. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImageSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImageSummary.java index 796f5e24bfc..7217055142b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImageSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImageSummary.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs a `Summary` protocol buffer with images. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ScalarSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ScalarSummary.java index 48517f8a00e..37740dc7ef3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ScalarSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ScalarSummary.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Outputs a `Summary` protocol buffer with scalar values. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteHistogramSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteHistogramSummary.java index 0c4d1363c35..7daeb5ccbd0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteHistogramSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteHistogramSummary.java @@ -27,7 +27,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** */ diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java index 7a816e98b6a..70eed3c13a1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java @@ -28,7 +28,6 @@ import org.tensorflow.types.TString; import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** */ diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java index 45dddbda36f..ce5a91db59b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java @@ -27,7 +27,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** */ diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/CrossReplicaSum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/CrossReplicaSum.java index dd45b4298b8..ff2cb8390be 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/CrossReplicaSum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/CrossReplicaSum.java @@ -27,7 +27,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * An Op to sum inputs across replicated TPU instances. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java index 5b799cc2f89..a6256e829fa 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Eases the porting of code that uses tf.nn.embedding_lookup(). diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java index d724e366c91..f302cf7ad5f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * An op that enqueues TPUEmbedding input indices from a SparseTensor. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java index e7f0444cbf4..a482e5f5ebf 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java @@ -28,7 +28,6 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Eases the porting of code that uses tf.nn.embedding_lookup_sparse(). diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeue.java index 265a9fa70f8..325db448246 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeue.java @@ -17,12 +17,12 @@ package org.tensorflow.op.tpu; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -45,10 +45,10 @@ public final class InfeedDequeue extends RawOp implements Opera * @return a new instance of InfeedDequeue */ @Endpoint(describeByClass = true) - public static InfeedDequeue create(Scope scope, DataType dtype, Shape shape) { + public static InfeedDequeue create(Scope scope, Class dtype, Shape shape) { OperationBuilder opBuilder = scope.env().opBuilder("InfeedDequeue", scope.makeOpName("InfeedDequeue")); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); opBuilder.setAttr("shape", shape); return new InfeedDequeue(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeueTuple.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeueTuple.java index 941886098b6..4816ea0a306 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeueTuple.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedDequeueTuple.java @@ -20,12 +20,12 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -46,14 +46,10 @@ public final class InfeedDequeueTuple extends RawOp implements Iterable> dtypes, List shapes) { + public static InfeedDequeueTuple create(Scope scope, List> dtypes, List shapes) { OperationBuilder opBuilder = scope.env().opBuilder("InfeedDequeueTuple", scope.makeOpName("InfeedDequeueTuple")); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); Shape[] shapesArray = new Shape[shapes.size()]; for (int i = 0; i < shapesArray.length; ++i) { shapesArray[i] = shapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeue.java index 29c516c6a57..e5aa5a35ed7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeue.java @@ -17,12 +17,12 @@ package org.tensorflow.op.tpu; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -69,10 +69,10 @@ private Options() { * @return a new instance of OutfeedDequeue */ @Endpoint(describeByClass = true) - public static OutfeedDequeue create(Scope scope, DataType dtype, Shape shape, Options... options) { + public static OutfeedDequeue create(Scope scope, Class dtype, Shape shape, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OutfeedDequeue", scope.makeOpName("OutfeedDequeue")); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); opBuilder.setAttr("shape", shape); if (options != null) { for (Options opts : options) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeueTuple.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeueTuple.java index e86472c77e0..705666e1f02 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeueTuple.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedDequeueTuple.java @@ -20,12 +20,12 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -71,14 +71,10 @@ private Options() { * @return a new instance of OutfeedDequeueTuple */ @Endpoint(describeByClass = true) - public static OutfeedDequeueTuple create(Scope scope, List> dtypes, List shapes, Options... options) { + public static OutfeedDequeueTuple create(Scope scope, List> dtypes, List shapes, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("OutfeedDequeueTuple", scope.makeOpName("OutfeedDequeueTuple")); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); Shape[] shapesArray = new Shape[shapes.size()]; for (int i = 0; i < shapesArray.length; ++i) { shapesArray[i] = shapes.get(i); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorTakeGradient.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorTakeGradient.java index cb6f79796ff..ed8becceee2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorTakeGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorTakeGradient.java @@ -17,11 +17,11 @@ package org.tensorflow.op.train; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -55,12 +55,12 @@ public final class AccumulatorTakeGradient extends RawOp implem * @return a new instance of AccumulatorTakeGradient */ @Endpoint(describeByClass = true) - public static AccumulatorTakeGradient create(Scope scope, Operand handle, Operand numRequired, DataType dtype) { + public static AccumulatorTakeGradient create(Scope scope, Operand handle, Operand numRequired, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("AccumulatorTakeGradient", scope.makeOpName("AccumulatorTakeGradient")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(numRequired.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new AccumulatorTakeGradient(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ConditionalAccumulator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ConditionalAccumulator.java index 5cc5297559f..684d9e3485e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ConditionalAccumulator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ConditionalAccumulator.java @@ -17,12 +17,12 @@ package org.tensorflow.op.train; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -92,10 +92,10 @@ private Options() { * @return a new instance of ConditionalAccumulator */ @Endpoint(describeByClass = true) - public static ConditionalAccumulator create(Scope scope, DataType dtype, Shape shape, Options... options) { + public static ConditionalAccumulator create(Scope scope, Class dtype, Shape shape, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ConditionalAccumulator", scope.makeOpName("ConditionalAccumulator")); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); opBuilder.setAttr("shape", shape); if (options != null) { for (Options opts : options) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorTakeGradient.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorTakeGradient.java index 5758aec8517..62e33a44749 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorTakeGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorTakeGradient.java @@ -17,11 +17,11 @@ package org.tensorflow.op.train; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -53,12 +53,12 @@ public final class ResourceAccumulatorTakeGradient extends RawO * @return a new instance of ResourceAccumulatorTakeGradient */ @Endpoint(describeByClass = true) - public static ResourceAccumulatorTakeGradient create(Scope scope, Operand handle, Operand numRequired, DataType dtype) { + public static ResourceAccumulatorTakeGradient create(Scope scope, Operand handle, Operand numRequired, Class dtype) { OperationBuilder opBuilder = scope.env().opBuilder("ResourceAccumulatorTakeGradient", scope.makeOpName("ResourceAccumulatorTakeGradient")); opBuilder.addInput(handle.asOutput()); opBuilder.addInput(numRequired.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); return new ResourceAccumulatorTakeGradient(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceConditionalAccumulator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceConditionalAccumulator.java index 5471a48a5e9..d44b93e6ac1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceConditionalAccumulator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceConditionalAccumulator.java @@ -17,12 +17,12 @@ package org.tensorflow.op.train; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -92,10 +92,10 @@ private Options() { * @return a new instance of ResourceConditionalAccumulator */ @Endpoint(describeByClass = true) - public static ResourceConditionalAccumulator create(Scope scope, DataType dtype, Shape shape, Options... options) { + public static ResourceConditionalAccumulator create(Scope scope, Class dtype, Shape shape, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("ResourceConditionalAccumulator", scope.makeOpName("ResourceConditionalAccumulator")); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); opBuilder.setAttr("shape", shape); if (options != null) { for (Options opts : options) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Restore.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Restore.java index a198182719e..bf19384f360 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Restore.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Restore.java @@ -20,11 +20,11 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -65,17 +65,13 @@ public final class Restore extends RawOp implements Iterable> { * @return a new instance of Restore */ @Endpoint(describeByClass = true) - public static Restore create(Scope scope, Operand prefix, Operand tensorNames, Operand shapeAndSlices, List> dtypes) { + public static Restore create(Scope scope, Operand prefix, Operand tensorNames, Operand shapeAndSlices, List> dtypes) { OperationBuilder opBuilder = scope.env().opBuilder("RestoreV2", scope.makeOpName("Restore")); opBuilder.addInput(prefix.asOutput()); opBuilder.addInput(tensorNames.asOutput()); opBuilder.addInput(shapeAndSlices.asOutput()); opBuilder = scope.apply(opBuilder); - DataType[] dtypesArray = new DataType[dtypes.size()]; - for (int i = 0; i < dtypesArray.length; ++i) { - dtypesArray[i] = dtypes.get(i); - } - opBuilder.setAttr("dtypes", dtypesArray); + opBuilder.setAttr("dtypes", Operands.toDataTypes(dtypes)); return new Restore(opBuilder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/RestoreSlice.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/RestoreSlice.java index 55fded1c9ff..579b5b8dbad 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/RestoreSlice.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/RestoreSlice.java @@ -17,11 +17,11 @@ package org.tensorflow.op.train; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -79,13 +79,13 @@ private Options() { * @return a new instance of RestoreSlice */ @Endpoint(describeByClass = true) - public static RestoreSlice create(Scope scope, Operand filePattern, Operand tensorName, Operand shapeAndSlice, DataType dt, Options... options) { + public static RestoreSlice create(Scope scope, Operand filePattern, Operand tensorName, Operand shapeAndSlice, Class dt, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("RestoreSlice", scope.makeOpName("RestoreSlice")); opBuilder.addInput(filePattern.asOutput()); opBuilder.addInput(tensorName.asOutput()); opBuilder.addInput(shapeAndSlice.asOutput()); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dt", dt); + opBuilder.setAttr("dt", Operands.toDataType(dt)); if (options != null) { for (Options opts : options) { if (opts.preferredShard != null) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Recv.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Recv.java index 13a7e85b001..05781f61017 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Recv.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Recv.java @@ -17,12 +17,12 @@ package org.tensorflow.op.xla; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; @@ -50,10 +50,10 @@ public final class Recv extends RawOp implements Operand { * @return a new instance of Recv */ @Endpoint(describeByClass = true) - public static Recv create(Scope scope, DataType dtype, String tensorName, Shape shape) { + public static Recv create(Scope scope, Class dtype, String tensorName, Shape shape) { OperationBuilder opBuilder = scope.env().opBuilder("XlaRecv", scope.makeOpName("Recv")); opBuilder = scope.apply(opBuilder); - opBuilder.setAttr("dtype", dtype); + opBuilder.setAttr("dtype", Operands.toDataType(dtype)); opBuilder.setAttr("tensor_name", tensorName); opBuilder.setAttr("shape", shape); return new Recv(opBuilder.build()); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java index 18f42d08e82..0ffd6c2205e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java @@ -17,6 +17,7 @@ import org.bytedeco.javacpp.Pointer; import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.family.TType; /** @@ -76,7 +77,7 @@ public String toString() { * @param outputIdx index of the output of this operation * @return output tensor datatype */ - abstract DataType dtype(int outputIdx); + abstract DataType dtype(int outputIdx); /** * Returns the tensor of the {@code outputIdx}th output of this operation. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 0bb0d20cae3..71dc0f7cefc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -54,7 +54,7 @@ public class ConcreteFunction implements AutoCloseable { * public class MyModel { * * public static Signature addTwo(Ops tf) { - * Placeholder input = tf.placeholder(TFloat32.DTYPE); + * Placeholder input = tf.placeholder(TFloat32.class); * Add output = tf.math.add(input, tf.constant(2.0f)); * return Signature.builder("addTwo").input("x", input).output("y", output).build(); * } @@ -92,7 +92,7 @@ public static ConcreteFunction create(Function functionBuilder) * *

{@code
    * try (Graph g = new Graph()) {
-   *   Placeholder input = tf.placeholder(TFloat32.DTYPE);
+   *   Placeholder input = tf.placeholder(TFloat32.class);
    *   Add output = tf.math.add(input, tf.constant(2.0f));
    *   Signature signature = Signature.builder().input("x", input).output("y", output).build();
    *
@@ -121,7 +121,7 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
    *
    * 
{@code
    * try (Graph g = new Graph()) {
-   *   Placeholder input = tf.placeholder(TFloat32.DTYPE);
+   *   Placeholder input = tf.placeholder(TFloat32.class);
    *   Add output = tf.math.add(input, tf.constant(2.0f));
    *   Signature signature = Signature.builder().input("x", input).output("y", output).build();
    *
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java
deleted file mode 100644
index 60657837969..00000000000
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java
+++ /dev/null
@@ -1,166 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-package org.tensorflow;
-
-import org.tensorflow.types.TBfloat16;
-import org.tensorflow.types.TBool;
-import org.tensorflow.types.TFloat16;
-import org.tensorflow.types.TFloat32;
-import org.tensorflow.types.TFloat64;
-import org.tensorflow.types.TInt32;
-import org.tensorflow.types.TInt64;
-import org.tensorflow.types.TUint8;
-import org.tensorflow.types.TString;
-import org.tensorflow.types.family.TType;
-
-/** Represents a type of elements in a {@link Tensor} */
-public final class DataType {
-
-  /**
-   * Creates a new datatype
-   *
-   * @param name readable-name for this type
-   * @param value must match the corresponding TF_* value in the TensorFlow C API.
-   * @param byteSize size of an element of this type, in bytes, -1 if unknown
-   * @param  a tensor type
-   * @param tensorMapper method for mapping tensor memory to elements of this type
-   */
-  public static  DataType create(
-      String name, int value, int byteSize, TensorMapper tensorMapper) {
-    return new DataType<>(name, value, byteSize, tensorMapper);
-  }
-
-  /**
-   * Gets the DataType associated with the readable-name for the type
-   * 

The name must match exactly the name used to create the desired DataType

- * - * @param name readable-name for the type - * @return the DataType - * @throws java.lang.IllegalArgumentException if the name is not a valid data type name - * @throws java.lang.NullPointerException if name is null - */ - public static DataType of(String name) { - switch (name) { - case TBfloat16.NAME: - return TBfloat16.DTYPE; - case TFloat16.NAME: - return TFloat16.DTYPE; - case TFloat32.NAME: - return TFloat32.DTYPE; - case TFloat64.NAME: - return TFloat64.DTYPE; - case TUint8.NAME: - return TUint8.DTYPE; - case TInt32.NAME: - return TInt32.DTYPE; - case TInt64.NAME: - return TInt64.DTYPE; - case TBool.NAME: - return TBool.DTYPE; - case TString.NAME: - return TString.DTYPE; - default: - throw new IllegalArgumentException(String.format("%s is an unknown DataType", name)); - } - } - - /** Returns true if this data type represents a floating point type */ - public boolean isFloating() { - switch (this.name()) { - case TBfloat16.NAME: - case TFloat16.NAME: - case TFloat32.NAME: - case TFloat64.NAME: - return true; - default: - return false; - } - } - - /** Returns true if this data type represents an integer type */ - public boolean isInteger() { - switch (this.name()) { - case TInt32.NAME: - case TInt64.NAME: - case TUint8.NAME: - return true; - default: - return false; - } - } - - /** Returns true if this data type represents a numeric type */ - public boolean isNumeric() { - return isFloating() || isInteger(); - } - - /** Returns true if this data type represents a boolean type */ - public boolean isBoolean() { - return this.name().equals(TBool.NAME); - } - - /** Returns true if this data type represents a string type */ - public boolean isString() { - return this.name().equals(TString.NAME); - } - - /** Returns the size of an element of this type, in bytes, or -1 if element size is variable. */ - public int byteSize() { - return byteSize; - } - - /** Returns true if this datatype has elements of variable length */ - public boolean isVariableLength() { - return byteSize == -1; - } - - /** Returns a readable name for this type */ - public String name() { - return name; - } - - @Override - public String toString() { - return name + " (" + nativeCode + ")"; - } - - /** Returns the numeric code for this datatype, as recognized by the native library (C API) */ - int nativeCode() { - return nativeCode; - } - - /** - * Maps a raw tensor to a typed tensor. - * - * @param tensor tensor to map - * @return data structure of elements of this type - */ - T map(RawTensor tensor) { - return tensorMapper.mapDense(tensor); - } - - private final int nativeCode; - private final int byteSize; - private final String name; - private final TensorMapper tensorMapper; - - private DataType(String name, int nativeCode, int byteSize, TensorMapper tensorMapper) { - this.name = name; - this.nativeCode = nativeCode; - this.byteSize = byteSize; - this.tensorMapper = tensorMapper; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java deleted file mode 100644 index 77c0de0c83f..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2019 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ======================================================================= - */ - -package org.tensorflow; - -import java.util.HashMap; -import java.util.Map; -import org.tensorflow.types.TBfloat16; -import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat16; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.TString; -import org.tensorflow.types.TUint8; - -/** - * Utility class for working with {@link DataType} objects. - */ -final class DataTypes { - - /** - * Find a data type from the type code returned by the native layer (C API). - * - *

Only data types registered via {@link #register(DataType)} can be resolved. - * - * @param nativeCode native code - * @return data type for this code - * @throws IllegalArgumentException if the code matches no registered data type - */ - static DataType fromNativeCode(int nativeCode) { - DataType dataType = DATA_TYPE_REGISTRY.get(nativeCode); - if (dataType == null) { - throw new IllegalArgumentException( - "DataType " + nativeCode + " is not recognized in Java (version " + TensorFlow.version() + ")"); - } - return dataType; - } - - private static final Map> DATA_TYPE_REGISTRY = new HashMap<>(); - - static { - register(TBool.DTYPE); - register(TFloat64.DTYPE); - register(TFloat32.DTYPE); - register(TFloat16.DTYPE); - register(TInt32.DTYPE); - register(TInt64.DTYPE); - register(TString.DTYPE); - register(TUint8.DTYPE); - register(TBfloat16.DTYPE); - } - - // TODO (karllessard): Right now this method is private but we might want to expose it - // to allow user to register custom data types? - private static void register(DataType dataType) { - DATA_TYPE_REGISTRY.put(dataType.nativeCode(), dataType); - DATA_TYPE_REGISTRY.put(dataType.nativeCode() + 100, dataType); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index 30387e390ed..09e5a47f8fd 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -29,6 +29,7 @@ import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; /** * Implementation of an {@link Operation} executed eagerly. @@ -83,12 +84,12 @@ public int inputListLength(final String name) { } @Override - public TFE_TensorHandle getUnsafeNativeHandle(int outputIndex) { + TFE_TensorHandle getUnsafeNativeHandle(int outputIndex) { return outputHandles[outputIndex]; } @Override - public Shape shape(int outputIndex) { + Shape shape(int outputIndex) { // If the tensor of this output has already been resolved, return its shape. // Otherwise, retrieve the tensor shape from the native library. Tensor tensor = outputTensors.get(outputIndex); @@ -104,7 +105,7 @@ public Shape shape(int outputIndex) { } @Override - public DataType dtype(int outputIndex) { + DataType dtype(int outputIndex) { // If the tensor of this output has already been resolved, return its datatype. // Otherwise, retrieve the tensor datatype from the native library. Tensor tensor = outputTensors.get(outputIndex); @@ -112,11 +113,11 @@ public DataType dtype(int outputIndex) { return tensor.dataType(); } TFE_TensorHandle outputNativeHandle = getUnsafeNativeHandle(outputIndex); - return DataTypes.fromNativeCode(dataType(outputNativeHandle)); + return DataType.forNumber(dataType(outputNativeHandle)); } @Override - public Tensor tensor(int outputIndex) { + Tensor tensor(int outputIndex) { Tensor tensor = outputTensors.get(outputIndex); if (tensor == null) { tensor = resolveTensor(outputIndex); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index 5c975929fee..9df8444a11f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -48,6 +48,7 @@ import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; /** * An {@link OperationBuilder} for building {@link Operation Operations} that are executed eagerly. @@ -159,16 +160,16 @@ public EagerOperationBuilder setAttr(String name, boolean[] values) { } @Override - public EagerOperationBuilder setAttr(String name, DataType value) { - setAttrType(opHandle, name, value.nativeCode()); + public EagerOperationBuilder setAttr(String name, DataType value) { + setAttrType(opHandle, name, value.getNumber()); return this; } @Override - public EagerOperationBuilder setAttr(String name, DataType[] values) { + public EagerOperationBuilder setAttr(String name, DataType[] values) { int[] c = new int[values.length]; for (int i = 0; i < values.length; ++i) { - c[i] = values[i].nativeCode(); + c[i] = values[i].getNumber(); } setAttrTypeList(opHandle, name, c); return this; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 142d481c04f..d70460ee4ea 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -54,6 +54,7 @@ import org.tensorflow.proto.framework.GraphDef; import org.tensorflow.proto.util.SaverDef; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** @@ -767,14 +768,14 @@ private static SaverDef addVariableSaver(Graph graph) { List varNames = new ArrayList<>(); List> varOutputs = new ArrayList<>(); - List> varTypes = new ArrayList<>(); + List> varTypes = new ArrayList<>(); for (Iterator iter = graph.operations(); iter.hasNext();) { Operation op = iter.next(); if (op.type().equals("VariableV2")) { varNames.add(op.name()); varOutputs.add(op.output(0)); - varTypes.add(op.output(0).dataType()); + varTypes.add(op.output(0).type()); } } @@ -783,7 +784,7 @@ private static SaverDef addVariableSaver(Graph graph) { Constant varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp))); Operand varSlices = tf.zerosLike(varNamesTensor); - Placeholder saveFilename = tf.placeholder(TString.DTYPE); + Placeholder saveFilename = tf.placeholder(TString.class); Save saveVariables = tf.train.save( saveFilename, varNamesTensor, diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java index d2fbc4e4995..e1255748c3b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java @@ -31,6 +31,7 @@ import org.tensorflow.internal.c_api.TF_Output; import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; /** * Implementation for an {@link Operation} added as a node to a {@link Graph}. @@ -148,10 +149,10 @@ Shape shape(int outputIdx) { } @Override - DataType dtype(int outputIdx) { + DataType dtype(int outputIdx) { Graph.Reference r = graph.ref(); try { - return DataTypes.fromNativeCode(dtype(r.nativeHandle(), getUnsafeNativeHandle(), outputIdx)); + return DataType.forNumber(dtype(r.nativeHandle(), getUnsafeNativeHandle(), outputIdx)); } finally { r.close(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index 5fda65480e9..927d9c52dd1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -52,6 +52,7 @@ import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; /** An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}. */ public final class GraphOperationBuilder implements OperationBuilder { @@ -224,7 +225,7 @@ public GraphOperationBuilder setAttr(String name, boolean[] value) { public GraphOperationBuilder setAttr(String name, DataType value) { Graph.Reference r = graph.ref(); try { - setAttrType(unsafeNativeHandle, name, value.nativeCode()); + setAttrType(unsafeNativeHandle, name, value.getNumber()); } finally { r.close(); } @@ -235,7 +236,7 @@ public GraphOperationBuilder setAttr(String name, DataType value) { public GraphOperationBuilder setAttr(String name, DataType[] value) { int[] ctypes = new int[value.length]; for (int i = 0; i < value.length; ++i) { - ctypes[i] = value[i].nativeCode(); + ctypes[i] = value[i].getNumber(); } Graph.Reference r = graph.ref(); try { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java index 31a93fb999a..80f62eb5acc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java @@ -30,11 +30,11 @@ * * // The "decodeJpeg" operation can be used as an operand to the "cast" operation * Operand decodeJpeg = tf.image.decodeJpeg(...); - * tf.dtypes.cast(decodeJpeg, TFloat32.DTYPE); + * tf.dtypes.cast(decodeJpeg, TFloat32.class); * * // The output "y" of the "unique" operation can be used as an operand to the "cast" operation * Output y = tf.unique(...).y(); - * tf.dtypes.cast(y, TFloat32.DTYPE); + * tf.dtypes.cast(y, TFloat32.class); * * // The "split" operation can be used as operand list to the "concat" operation * Iterable> split = tf.split(...); @@ -66,10 +66,10 @@ default T asTensor() { } /** - * Returns the data type of this operand + * Returns the tensor type of this operand */ - default DataType dataType() { - return asOutput().dataType(); + default Class type() { + return asOutput().type(); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java index 79f21c33fb7..a487d8b9237 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java @@ -16,6 +16,7 @@ package org.tensorflow; import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; /** * A builder for {@link Operation}s. @@ -177,7 +178,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, DataType value); + OperationBuilder setAttr(String name, DataType value); /** * Set the type values of an attribute of the operation being built. @@ -186,7 +187,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, DataType[] value); + OperationBuilder setAttr(String name, DataType[] value); /** * Set the tensor value of an attribute of the operation being built. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java index 5b8337f4d70..9e7dedfdc75 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java @@ -17,8 +17,9 @@ import java.util.Objects; import org.bytedeco.javacpp.Pointer; +import org.tensorflow.internal.types.registry.TensorTypeRegistry; import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.Shaped; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.family.TType; /** @@ -39,23 +40,30 @@ public int index() { /** Returns the DataType of the tensor referred to by this Output. */ @SuppressWarnings("unchecked") - public DataType dataType() { - return (DataType)operation.dtype(index); + public DataType dataType() { + return operation.dtype(index); + } + + /** Returns the type of the tensor referred to by this Output. */ + @SuppressWarnings("unchecked") + @Override + public Class type() { + return (Class)TensorTypeRegistry.find(dataType()).type(); } /** * Returns this Output object with the type {@code Output}. This method is useful when given a * value of type {@code Output}. * - * @param dt any supported tensor data type + * @param type any supported tensor type * @throws IllegalArgumentException if the actual data type of this object does not match the type * {@code U}. */ @SuppressWarnings("unchecked") - public Output expect(DataType dt) { - if (!dt.equals(this.dataType())) { + public Output expect(Class type) { + if (type != type()) { throw new IllegalArgumentException( - "Cannot cast from output of " + this.dataType() + " to output of " + dt); + "Cannot cast from output of " + this.type().getSimpleName() + " to output of " + type.getSimpleName()); } return ((Output) this); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index 5af44ce73df..c332fd7f1d1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -23,9 +23,11 @@ import org.bytedeco.javacpp.PointerScope; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.types.registry.TensorTypeInfo; +import org.tensorflow.internal.types.registry.TensorTypeRegistry; import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.Shaped; import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.family.TType; /** @@ -42,8 +44,8 @@ public final class RawTensor implements Tensor { @Override - public DataType dataType() { - return dtype; + public DataType dataType() { + return typeInfo.dataType(); } @Override @@ -84,33 +86,52 @@ public ByteDataBuffer data() { */ @Override public String toString() { - return String.format("%s tensor with shape %s", dtype.toString(), shape); + return String.format("%s tensor with shape %s", typeInfo.dataType(), shape); } /** * Allocates a new tensor in native memory of the given type, shape and size. * *

The size of the tensor must be at least large enough to contain all scalars for the - * given type and shape, i.e. size >= dtype.byteSize() * shape.size(). More memory - * can be allocated to store also metadata within the tensor itself, e.g. a lookup table - * in a string tensor. + * given type and shape. More memory can also be allocated to store also metadata within the + * tensor itself, e.g. a lookup table in a string tensor. * - * @param dtype data type + * @param type tensor type class * @param shape shape of the tensor - * @param size size of the tensor + * @param size size in bytes of the tensor, or -1 to compute the size from the shape * @return allocated tensor + * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to + * store the tensor data + * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given + * {@code type} are of variable length (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially + * {@link Shape#hasUnknownDimension() unknown} + * @throws IllegalStateException if tensor failed to be allocated */ - static RawTensor allocate(DataType dtype, Shape shape, long size) { - // Minimum requirements for datatypes of variable length cannot be verified in a relevant way so - // we only validate them for fixed length datatypes - if (!dtype.isVariableLength() && shape.size() * dtype.byteSize() > size) { - throw new IllegalArgumentException("Tensor size is not large enough to contain all scalar values"); + static RawTensor allocate(Class type, Shape shape, long size) { + if (shape.hasUnknownDimension()) { + throw new IllegalArgumentException( + "Cannot allocate a tensor from a totally or partially unknown shape"); } - TF_Tensor nativeHandle = allocate(dtype.nativeCode(), shape.asArray(), size); + TensorTypeInfo typeInfo = TensorTypeRegistry.find(type); + long allocatedSize = size; + if (allocatedSize < 0) { + if (typeInfo.isVariableLength()) { + throw new IllegalArgumentException( + "Explicit size is required for variable-length tensor types"); + } + allocatedSize = shape.size() * typeInfo.byteSize(); + + } else if (!typeInfo.isVariableLength() && shape.size() * typeInfo.byteSize() > allocatedSize) { + // Minimum requirements for datatypes of variable length cannot be verified in a relevant way so + // we only validate them for fixed length datatypes + throw new IllegalArgumentException( + "Tensor size is not large enough to contain all scalar values"); + } + TF_Tensor nativeHandle = allocate(typeInfo.dataType().getNumber(), shape.asArray(), allocatedSize); try (PointerScope scope = new PointerScope()) { scope.attach(nativeHandle); - - RawTensor t = new RawTensor(dtype, shape); + RawTensor t = new RawTensor(typeInfo, shape); t.tensorHandle = nativeHandle; t.tensorScope = scope.extend(); return t; @@ -123,7 +144,8 @@ static RawTensor allocate(DataType dtype, Shape shape, long size) { *

Takes ownership of the handle. */ static RawTensor fromHandle(TF_Tensor handle) { - RawTensor t = new RawTensor(DataTypes.fromNativeCode(dtype(handle)), Shape.of(shape(handle))); + TensorTypeInfo typeInfo = TensorTypeRegistry.find(DataType.forNumber(dtype(handle))); + RawTensor t = new RawTensor(typeInfo, Shape.of(shape(handle))); try (PointerScope scope = new PointerScope()) { scope.attach(handle); t.tensorHandle = handle; @@ -158,13 +180,10 @@ TF_Tensor nativeHandle() { *

In some cases, it is more useful to keep a typed reference to a tensor rather than its raw * nature to prevent mapping its memory on every access (e.g. when calling {@link Operand#asTensor()}). * - * @param type of the tensor (must be compatible with the internal representation of this tensor, - * as indicated by {@link #dataType()}) * @return typed reference to this tensor - * @throws ClassCastException if {@code T} is not compatible type with {@link #dataType()} */ - T asTypedTensor() { - return (T)dtype.map(this); + TType asTypedTensor() { + return typeInfo.mapper().mapDense(this); } private static TF_Tensor requireHandle(TF_Tensor handle) { @@ -197,14 +216,14 @@ private static long[] shape(TF_Tensor handle) { return dims; } - RawTensor(DataType dtype, Shape shape) { - this.dtype = dtype; + RawTensor(TensorTypeInfo typeInfo, Shape shape) { + this.typeInfo = typeInfo; this.shape = shape; } private PointerScope tensorScope; private TF_Tensor tensorHandle; - private final DataType dtype; + private final TensorTypeInfo typeInfo; private final Shape shape; private ByteDataBuffer buffer = null; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index 376dc9039fc..ea32d1fff13 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -18,7 +18,6 @@ import java.util.Map; import java.util.Set; import org.tensorflow.ndarray.Shape; -import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.proto.framework.TensorInfo; import org.tensorflow.proto.framework.TensorShapeProto; @@ -113,7 +112,7 @@ private static TensorInfo toTensorInfo(Output operand) { tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i))); } return TensorInfo.newBuilder() - .setDtype(DataType.forNumber(operand.dataType().nativeCode())) + .setDtype(operand.dataType()) .setTensorShape(tensorShapeBuilder) .setName(operand.op().name() + ":" + operand.index()) .build(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index bccdf698608..fc1275229bf 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -15,19 +15,11 @@ package org.tensorflow; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_Dim; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_NumDims; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorByteSize; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorType; - import java.util.function.Consumer; -import org.bytedeco.javacpp.PointerScope; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.Shaped; import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.family.TType; /** @@ -57,37 +49,45 @@ public interface Tensor extends Shaped, AutoCloseable { *

The amount of memory to allocate is derived from the datatype and the shape of the tensor, * and is left uninitialized. * - * @param the tensor element type - * @param dtype datatype of the tensor + * @param the tensor type + * @param type the tensor type class * @param shape shape of the tensor * @return an allocated but uninitialized tensor + * @throws IllegalArgumentException if elements of the given {@code type} are of variable length + * (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially + * {@link Shape#hasUnknownDimension() unknown} * @throws IllegalStateException if tensor failed to be allocated */ - static T of(DataType dtype, Shape shape) { - return of(dtype, shape, shape.size() * dtype.byteSize()); + static T of(Class type, Shape shape) { + return of(type, shape, -1); } /** * Allocates a tensor of a given datatype, shape and size. * - *

This method is identical to {@link #of(DataType, Shape)}, except that the final size of the - * tensor is explicitly set instead of computing it from the datatype and shape, which could be + *

This method is identical to {@link #of(Class, Shape)}, except that the final size of the + * tensor can be explicitly set instead of computing it from the datatype and shape, which could be * larger than the actual space required to store the data but not smaller. * - * @param the tensor element type - * @param dtype datatype of the tensor + * @param the tensor type + * @param type the tensor type class * @param shape shape of the tensor - * @param size size, in bytes, of the tensor + * @param size size in bytes of the tensor or -1 to compute the size from the shape * @return an allocated but uninitialized tensor - * @see #of(DataType, Shape) + * @see #of(Class, Shape) * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to * store the tensor data + * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given + * {@code type} are of variable length (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially + * {@link Shape#hasUnknownDimension() unknown} * @throws IllegalStateException if tensor failed to be allocated */ - static T of(DataType dtype, Shape shape, long size) { - RawTensor tensor = RawTensor.allocate(dtype, shape, size); + static T of(Class type, Shape shape, long size) { + RawTensor tensor = RawTensor.allocate(type, shape, size); try { - return tensor.asTypedTensor(); + return (T)tensor.asTypedTensor(); } catch (Exception e) { tensor.close(); throw e; @@ -103,7 +103,7 @@ static T of(DataType dtype, Shape shape, long size) { * *

{@code
    * FloatNdArray data = ...
-   * try (TFloat32 t = Tensor.of(TFloat32.DTYPE, Shape.of(2, 2), data::copyTo)) {
+   * try (TFloat32 t = Tensor.of(TFloat32.class, Shape.of(2, 2), data::copyTo)) {
    *   ...
    * }
    * }
@@ -111,39 +111,47 @@ static T of(DataType dtype, Shape shape, long size) { *

If {@code dataInitializer} fails and throws an exception, the allocated tensor will be * automatically released before rethrowing the same exception. * - * @param the tensor element type - * @param dtype datatype of the tensor + * @param the tensor type + * @param type the tensor type class * @param shape shape of the tensor * @param dataInitializer method receiving accessor to the allocated tensor data for initialization * @return an allocated and initialized tensor + * @throws IllegalArgumentException if elements of the given {@code type} are of variable length + * (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially + * {@link Shape#hasUnknownDimension() unknown} * @throws IllegalStateException if tensor failed to be allocated */ - static T of(DataType dtype, Shape shape, Consumer dataInitializer) { - return of(dtype, shape, shape.size() * dtype.byteSize(), dataInitializer); + static T of(Class type, Shape shape, Consumer dataInitializer) { + return of(type, shape, -1, dataInitializer); } /** * Allocates a tensor of a given datatype, shape and size. * - *

This method is identical to {@link #of(DataType, Shape, Consumer)}, except that the final - * size for the tensor is explicitly set instead of being computed from the datatype and shape. + *

This method is identical to {@link #of(Class, Shape, Consumer)}, except that the final + * size for the tensor can be explicitly set instead of being computed from the datatype and shape. * *

This could be useful for tensor types that stores data but also metadata in the tensor memory, * such as the lookup table in a tensor of strings. * - * @param the tensor element type - * @param dtype datatype of the tensor + * @param the tensor type + * @param type the tensor type class * @param shape shape of the tensor - * @param size size, in bytes, of the tensor + * @param size size in bytes of the tensor or -1 to compute the size from the shape * @param dataInitializer method receiving accessor to the allocated tensor data for initialization * @return an allocated and initialized tensor - * @see #of(DataType, Shape, long, Consumer) + * @see #of(Class, Shape, long, Consumer) * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to * store the tensor data + * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given + * {@code type} are of variable length (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially + * {@link Shape#hasUnknownDimension() unknown} * @throws IllegalStateException if tensor failed to be allocated */ - static T of(DataType dtype, Shape shape, long size, Consumer dataInitializer) { - T tensor = of(dtype, shape, size); + static T of(Class type, Shape shape, long size, Consumer dataInitializer) { + T tensor = of(type, shape, size); try { dataInitializer.accept(tensor); return tensor; @@ -159,21 +167,24 @@ static T of(DataType dtype, Shape shape, long size, Consume *

Data must have been encoded into {@code data} as per the specification of the TensorFlow C API. * - * @param the tensor element type - * @param dtype the tensor element data type + * @param the tensor type + * @param type the tensor type class * @param shape the tensor shape. * @param rawData a buffer containing the tensor raw data. - * @throws IllegalArgumentException if {@code rawData} is not large enough to contain the tensor data + * @throws IllegalArgumentException if {@code rawData} is not large enough to contain the tensor + * data + * @throws IllegalArgumentException if {@code shape} is totally or partially + * {@link Shape#hasUnknownDimension() unknown} * @throws IllegalStateException if tensor failed to be allocated with the given parameters */ - static T of(DataType dtype, Shape shape, ByteDataBuffer rawData) { - return of(dtype, shape, rawData.size(), t -> rawData.copyTo(t.asRawTensor().data(), rawData.size())); + static T of(Class type, Shape shape, ByteDataBuffer rawData) { + return of(type, shape, rawData.size(), t -> rawData.copyTo(t.asRawTensor().data(), rawData.size())); } /** * Returns the {@link DataType} of elements stored in the tensor. */ - DataType dataType(); + DataType dataType(); /** * Returns the size, in bytes, of the tensor data. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceProvider.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceProvider.java index 0d3d60772be..1a63d551336 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceProvider.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceProvider.java @@ -19,7 +19,6 @@ import java.util.Iterator; import java.util.function.Function; import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArraySequence; /** * Produces sequence of bytes to be stored in a {@link ByteSequenceTensorBuffer}. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Mapper.java index 7b3a400ba16..27688e55779 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Mapper.java @@ -16,9 +16,8 @@ */ package org.tensorflow.internal.types; -import org.tensorflow.TensorMapper; -import org.tensorflow.DataType; import org.tensorflow.RawTensor; +import org.tensorflow.TensorMapper; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.ndarray.buffer.FloatDataBuffer; import org.tensorflow.ndarray.buffer.layout.DataLayouts; @@ -40,8 +39,8 @@ protected TBfloat16 mapDense(RawTensor tensor) { private static final class DenseTBfloat16 extends FloatDenseNdArray implements TBfloat16 { @Override - public DataType dataType() { - return TBfloat16.DTYPE; + public Class type() { + return TBfloat16.class; } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolMapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolMapper.java index 2a9c44330c2..ff4c11a521b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolMapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolMapper.java @@ -16,7 +16,6 @@ */ package org.tensorflow.internal.types; -import org.tensorflow.DataType; import org.tensorflow.RawTensor; import org.tensorflow.TensorMapper; import org.tensorflow.internal.buffer.TensorBuffers; @@ -39,8 +38,8 @@ protected TBool mapDense(RawTensor tensor) { private static final class DenseTBool extends BooleanDenseNdArray implements TBool { @Override - public DataType dataType() { - return TBool.DTYPE; + public Class type() { + return TBool.class; } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Mapper.java index 02ef883f9bf..fec84843f57 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Mapper.java @@ -16,7 +16,6 @@ */ package org.tensorflow.internal.types; -import org.tensorflow.DataType; import org.tensorflow.RawTensor; import org.tensorflow.TensorMapper; import org.tensorflow.internal.buffer.TensorBuffers; @@ -40,8 +39,8 @@ protected TFloat16 mapDense(RawTensor tensor) { private static final class DenseTFloat16 extends FloatDenseNdArray implements TFloat16 { @Override - public DataType dataType() { - return TFloat16.DTYPE; + public Class type() { + return TFloat16.class; } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat32Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat32Mapper.java index 1f446e820a7..62fc0d226ac 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat32Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat32Mapper.java @@ -16,7 +16,6 @@ */ package org.tensorflow.internal.types; -import org.tensorflow.DataType; import org.tensorflow.RawTensor; import org.tensorflow.TensorMapper; import org.tensorflow.internal.buffer.TensorBuffers; @@ -39,8 +38,8 @@ protected TFloat32 mapDense(RawTensor tensor) { private static final class DenseTFloat32 extends FloatDenseNdArray implements TFloat32 { @Override - public DataType dataType() { - return TFloat32.DTYPE; + public Class type() { + return TFloat32.class; } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat64Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat64Mapper.java index 9735604ecb2..375a7429950 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat64Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat64Mapper.java @@ -16,7 +16,6 @@ */ package org.tensorflow.internal.types; -import org.tensorflow.DataType; import org.tensorflow.RawTensor; import org.tensorflow.TensorMapper; import org.tensorflow.internal.buffer.TensorBuffers; @@ -39,8 +38,8 @@ protected TFloat64 mapDense(RawTensor tensor) { private static final class DenseTFloat64 extends DoubleDenseNdArray implements TFloat64 { @Override - public DataType dataType() { - return TFloat64.DTYPE; + public Class type() { + return TFloat64.class; } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt32Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt32Mapper.java index e0cf98c4ad8..fa0852a8b09 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt32Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt32Mapper.java @@ -16,7 +16,6 @@ */ package org.tensorflow.internal.types; -import org.tensorflow.DataType; import org.tensorflow.RawTensor; import org.tensorflow.TensorMapper; import org.tensorflow.internal.buffer.TensorBuffers; @@ -39,8 +38,8 @@ protected TInt32 mapDense(RawTensor tensor) { private static final class DenseTInt32 extends IntDenseNdArray implements TInt32 { @Override - public DataType dataType() { - return TInt32.DTYPE; + public Class type() { + return TInt32.class; } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt64Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt64Mapper.java index d8a629e16ef..c5f2325e25a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt64Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt64Mapper.java @@ -16,7 +16,6 @@ */ package org.tensorflow.internal.types; -import org.tensorflow.DataType; import org.tensorflow.RawTensor; import org.tensorflow.TensorMapper; import org.tensorflow.internal.buffer.TensorBuffers; @@ -39,8 +38,8 @@ protected TInt64 mapDense(RawTensor tensor) { private static final class DenseTInt64 extends LongDenseNdArray implements TInt64 { @Override - public DataType dataType() { - return TInt64.DTYPE; + public Class type() { + return TInt64.class; } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringMapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringMapper.java index 51f224c8734..de7c6016e0e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringMapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringMapper.java @@ -18,11 +18,10 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import org.tensorflow.DataType; import org.tensorflow.RawTensor; import org.tensorflow.TensorMapper; -import org.tensorflow.internal.buffer.ByteSequenceTensorBuffer; import org.tensorflow.internal.buffer.ByteSequenceProvider; +import org.tensorflow.internal.buffer.ByteSequenceTensorBuffer; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; @@ -79,8 +78,8 @@ public NdArray asBytes() { } @Override - public DataType dataType() { - return TString.DTYPE; + public Class type() { + return TString.class; } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint8Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint8Mapper.java index 5ef0fcad692..427debd1ac8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint8Mapper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint8Mapper.java @@ -16,7 +16,6 @@ */ package org.tensorflow.internal.types; -import org.tensorflow.DataType; import org.tensorflow.RawTensor; import org.tensorflow.TensorMapper; import org.tensorflow.internal.buffer.TensorBuffers; @@ -39,8 +38,8 @@ protected TUint8 mapDense(RawTensor tensor) { private static final class DenseTUint8 extends ByteDenseNdArray implements TUint8 { @Override - public DataType dataType() { - return TUint8.DTYPE; + public Class type() { + return TUint8.class; } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/registry/TensorTypeInfo.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/registry/TensorTypeInfo.java new file mode 100644 index 00000000000..a4a89a71649 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/registry/TensorTypeInfo.java @@ -0,0 +1,76 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.internal.types.registry; + +import org.tensorflow.TensorMapper; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.family.TType; + +/** + * Registered information about a tensor type. + * + * @param the tensor type + */ +public final class TensorTypeInfo { + + /** + * Returns the class of this tensor type + */ + public Class type() { + return type; + } + + /** + * Returns the corresponding data type for this tensor type + */ + public DataType dataType() { + return dataType; + } + + /** + * Returns the number of bytes required to store one element of the corresponding data type, -1 if variable. + */ + public int byteSize() { + return byteSize; + } + + /** + * Returns true if elements of the corresponding data type are of variable length (undefined number of bytes) + */ + public boolean isVariableLength() { + return byteSize < 0; + } + + /** + * Returns an object used to map {@link org.tensorflow.RawTensor raw tensors} to a tensor of this type + */ + public TensorMapper mapper() { + return mapper; + } + + TensorTypeInfo(Class type, DataType dataType, int byteSize, TensorMapper mapper) { + this.type = type; + this.dataType = dataType; + this.byteSize = byteSize; + this.mapper = mapper; + } + + private final Class type; + private final DataType dataType; + private final int byteSize; + private final TensorMapper mapper; +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/registry/TensorTypeRegistry.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/registry/TensorTypeRegistry.java new file mode 100644 index 00000000000..a30138e0386 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/registry/TensorTypeRegistry.java @@ -0,0 +1,104 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.internal.types.registry; + +import java.util.HashMap; +import java.util.Map; +import org.tensorflow.TensorMapper; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; +import org.tensorflow.types.annotation.TensorType; +import org.tensorflow.types.family.TType; + +/** + * Repository of all registered tensor types. + */ +public final class TensorTypeRegistry { + + /** + * Find registered information about a tensor type from its equivalent data type + * + * @param dataType data type + * @return type registered information + * @throws IllegalArgumentException if no tensor type for this data type has been registered + */ + public static TensorTypeInfo find(DataType dataType) { + TensorTypeInfo typeInfo = TYPES_BY_CODE.get(dataType.getNumber()); + if (typeInfo == null) { + throw new IllegalArgumentException("No tensor type has been registered for data type " + dataType); + } + return (TensorTypeInfo)typeInfo; + } + + /** + * Find registered information about a tensor type from its class + * + * @param type class implementing {@link TType} + * @return type registered information + * @throws IllegalArgumentException if the provided class has not been registered as a tensor type + */ + public static TensorTypeInfo find(Class type) { + TensorTypeInfo typeInfo = TYPES_BY_CLASS.get(type); + if (typeInfo == null) { + throw new IllegalArgumentException("Class \"" + type.getName() + "\" is not registered as a tensor type"); + } + return (TensorTypeInfo)typeInfo; + } + + private static final Map> TYPES_BY_CODE = new HashMap<>(); + private static final Map, TensorTypeInfo> TYPES_BY_CLASS = new HashMap<>(); + + private static void register(Class type) { + TensorType typeAnnot = type.getDeclaredAnnotation(TensorType.class); + if (typeAnnot == null) { + throw new IllegalArgumentException("Class \"" + type.getName() + "\" must be annotated " + + "with @TensorType to be registered as a tensor type"); + } + TensorMapper mapper; + try { + mapper = (TensorMapper)typeAnnot.mapperClass().newInstance(); + } catch (ReflectiveOperationException e) { + throw new IllegalArgumentException("Class \"" + type.getName() + "\" must have a public " + + "parameter-less constructor to be used as a tensor mapper"); + } + TensorTypeInfo typeInfo = new TensorTypeInfo<>(type, typeAnnot.dataType(), typeAnnot.byteSize(), mapper); + TYPES_BY_CLASS.put(type, typeInfo); + TYPES_BY_CODE.put(typeInfo.dataType().getNumber(), typeInfo); + TYPES_BY_CODE.put(typeInfo.dataType().getNumber() + 100, typeInfo); + } + + static { + // TODO (karllessard) scan and registered automatically all annotated tensors types + register(TBool.class); + register(TFloat64.class); + register(TFloat32.class); + register(TFloat16.class); + register(TInt32.class); + register(TInt64.class); + register(TString.class); + register(TUint8.class); + register(TBfloat16.class); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Operands.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Operands.java index ac48da80326..5706ff1f283 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Operands.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Operands.java @@ -16,10 +16,14 @@ package org.tensorflow.op; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import org.tensorflow.Operand; import org.tensorflow.OperationBuilder; import org.tensorflow.Output; +import org.tensorflow.internal.types.registry.TensorTypeRegistry; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.family.TType; /** Utilities for manipulating operand related types and lists. */ public final class Operands { @@ -41,6 +45,31 @@ public static Output[] asOutputs(Iterable> inputs) { return outputList.toArray(new Output[outputList.size()]); } + /** + * Converts a tensor type class to a {@link DataType} attribute. + * + * @param type tensor type class + * @return data type + */ + public static DataType toDataType(Class type) { + return TensorTypeRegistry.find(type).dataType(); + } + + /** + * Converts a list of tensor type classes to an array of {@link DataType} attributes. + * + * @param types tensor type classes + * @return an array of data types + */ + public static DataType[] toDataTypes(Collection> types) { + DataType[] dataTypes = new DataType[types.size()]; + int i = 0; + for (Class type : types) { + dataTypes[i++] = toDataType(type); + } + return dataTypes; + } + // Disabled constructor private Operands() {} } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java index 3be83d8173e..1b6aee0284b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java @@ -16,7 +16,6 @@ package org.tensorflow.op.core; import java.nio.charset.Charset; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.Output; @@ -1013,8 +1012,9 @@ public static Constant tensorOf(Scope scope, Shape shape, ByteDataBuffer /** * Create a constant with data from the given buffer. * + * @param the tensor type * @param scope is a scope used to add the underlying operation. - * @param type the tensor datatype. + * @param type the tensor type class * @param shape the tensor shape. * @param data a buffer containing the tensor data. * @return a constant of type `type` @@ -1022,7 +1022,7 @@ public static Constant tensorOf(Scope scope, Shape shape, ByteDataBuffer * buffer */ @Endpoint - public static Constant tensorOf(Scope scope, DataType type, Shape shape, + public static Constant tensorOf(Scope scope, Class type, Shape shape, ByteDataBuffer data) { try (T value = Tensor.of(type, shape, data)) { return create(scope, value); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Gradients.java index 2827276c32c..82edab51d40 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Gradients.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Gradients.java @@ -22,7 +22,6 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.op.Op; import org.tensorflow.op.Operands; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java index f9ce837fe60..59682777966 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java @@ -16,7 +16,6 @@ package org.tensorflow.op.core; import org.tensorflow.Operand; -import org.tensorflow.Output; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; @@ -46,8 +45,7 @@ private Helpers() {} */ @Endpoint(name = "variable") public static Variable createVariableWithInit(Scope scope, Operand init, Variable.Options... options) { - Output initOutput = init.asOutput(); - Variable newVar = Variable.create(scope,initOutput.shape(), initOutput.dataType(), options); + Variable newVar = Variable.create(scope, init.shape(), init.type(), options); Assign assignOp = Assign.create(scope, newVar, init); Init.add(scope, assignOp); return newVar; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Init.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Init.java index e0f64b6b19a..b7b65a973c9 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Init.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Init.java @@ -35,7 +35,7 @@ public final class Init extends RawOp { * try (Session s = new Session(g)) { * s.run(tf.init()); // initialize all variables * - * try (Tensor t = s.runner().fetch(z).run().get(0).expect(TInt32.DTYPE)) { + * try (TInt32 t = (TInt32)s.runner().fetch(z).run().get(0)) { * assertEquals(30, t.data().getInt()); * } * } @@ -62,7 +62,7 @@ public final class Init extends RawOp { * try (SavedModelBundle model = SavedModelBundle.load("/path/to/model", "train")) { * model.session().run(Init.DEFAULT_NAME); * - * try (Tensor t = s.runner().fetch("z").run().get(0).expect(TInt32.DTYPE)) { + * try (TInt32 t = (TInt32)s.runner().fetch("z").run().get(0)) { * assertEquals(30, t.data().getInt()); * } * } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java index 3af0846b441..a57c05f1940 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java @@ -14,7 +14,6 @@ ==============================================================================*/ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.Output; @@ -31,7 +30,7 @@ * An operator creating a constant initialized with ones of the shape given by `dims`. * *

For example, the following expression - *

{@code tf.ones(tf.constant(shape), TFloat32.DTYPE)}
+ *
{@code tf.ones(tf.constant(shape), TFloat32.class)}
* is the equivalent of *
{@code tf.fill(tf.constant(shape), tf.constant(1.0f))}
* @@ -45,14 +44,14 @@ public final class Ones implements Op, Operand { * * @param scope is a scope used to add the underlying operation * @param dims a 1-D operand that represents the shape of the output tensor - * @param type the output tensor datatype. Can not be TString. + * @param type the output tensor type class. Can not be TString. * @return a constant tensor initialized with ones * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with ones. */ @Endpoint - public static Ones create(Scope scope, Operand dims, DataType type) { + public static Ones create(Scope scope, Operand dims, Class type) { Scope onesScope = scope.withSubScope("Ones"); - if (type == TString.DTYPE) { + if (type == TString.class) { throw new IllegalArgumentException("Can't create Ones of String DataType"); } Operand one = Cast.create(onesScope.withName("One"), Constant.scalarOf(onesScope, 1), type); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Shapes.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Shapes.java index 613cb729341..2bf2eecc4cb 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Shapes.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Shapes.java @@ -15,16 +15,12 @@ package org.tensorflow.op.core; import java.util.Arrays; - -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; - -import org.tensorflow.op.math.FloorMod; - import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.math.FloorMod; import org.tensorflow.op.math.NotEqual; import org.tensorflow.op.math.Sub; import org.tensorflow.types.TBool; @@ -51,8 +47,8 @@ * Operand numPred = tf.shape.size(predShape, tf.constant(0)); * Operand predFlat = tf.shape.flatten(yPred); * - * Shape predShape64 = tf.shape(yPred, TInt64.DTYPE); - * Operand predSqueezed = tf.shape.squeeze(predShape64, TInt64.DTYPE); + * Shape predShape64 = tf.shape(yPred, TInt64.class); + * Operand predSqueezed = tf.shape.squeeze(predShape64, TInt64.class); * }
*/ @Operator(group = "shape") @@ -68,7 +64,7 @@ public abstract class Shapes { */ @Endpoint(name = "flatten") public static Operand flatten(Scope scope, Operand operand) { - return flatten(scope, operand, TInt32.DTYPE); + return flatten(scope, operand, TInt32.class); } /** @@ -78,13 +74,13 @@ public static Operand flatten(Scope scope, Operand opera * @param the shape datatype * @param scope current scope * @param operand the operand to flatten - * @param dType the shape datatype + * @param type the shape datatype * @return the reshaped operand */ @Endpoint(name = "flatten") public static Operand flatten( - Scope scope, Operand operand, DataType dType) { - Operand flatShape = flatten(scope, Shape.create(scope, operand, dType), dType); + Scope scope, Operand operand, Class type) { + Operand flatShape = flatten(scope, Shape.create(scope, operand, type), type); return Reshape.create(scope, operand, flatShape); } @@ -97,7 +93,7 @@ public static Operand flatten( */ @Endpoint(name = "flatten") public static Operand flatten(Scope scope, Shape shape) { - return flatten(scope, shape, TInt32.DTYPE); + return flatten(scope, shape, TInt32.class); } /** @@ -106,16 +102,16 @@ public static Operand flatten(Scope scope, Shape shape) { * @param the shape datatype * @param scope current scope * @param shape the TensorFlow shape - * @param dType the shape datatype + * @param type the shape datatype * @return the flattened shape */ @Endpoint(name = "flatten") public static Operand flatten( - Scope scope, Shape shape, DataType dType) { + Scope scope, Shape shape, Class type) { return ExpandDims.create( scope, - size(scope, shape, dType), - Cast.create(scope, Constant.scalarOf(scope, -1), TInt32.DTYPE)); + size(scope, shape, type), + Cast.create(scope, Constant.scalarOf(scope, -1), TInt32.class)); } /** @@ -127,7 +123,7 @@ public static Operand flatten( */ @Endpoint(name = "size") public static Operand size(Scope scope, Shape shape) { - return size(scope, shape, TInt32.DTYPE); + return size(scope, shape, TInt32.class); } /** @@ -136,20 +132,20 @@ public static Operand size(Scope scope, Shape shape) { * @param the type of the shape * @param scope current scope * @param shape the TensorFlow shape - * @param dType the shape datatype + * @param type the shape datatype * @return the size */ @Endpoint(name = "size") public static Operand size( - Scope scope, Shape shape, DataType dType) { + Scope scope, Shape shape, Class type) { Slice dims = Slice.create( scope, shape, - Cast.create(scope, Constant.arrayOf(scope, 0), dType), + Cast.create(scope, Constant.arrayOf(scope, 0), type), ExpandDims.create( scope, - Cast.create(scope, Constant.scalarOf(scope, -1), dType), + Cast.create(scope, Constant.scalarOf(scope, -1), type), Constant.scalarOf(scope, -1))); return ReduceProd.create(scope, dims, Constant.scalarOf(scope, 0)); } @@ -164,7 +160,7 @@ public static Operand size( */ @Endpoint(name = "size") public static Operand size(Scope scope, Shape shape, Operand dim) { - return size(scope, shape, dim, TInt32.DTYPE); + return size(scope, shape, dim, TInt32.class); } /** @@ -174,20 +170,20 @@ public static Operand size(Scope scope, Shape shape, Operand Operand size( - Scope scope, Shape shape, Operand dim, DataType dType) { + Scope scope, Shape shape, Operand dim, Class type) { return Slice.create( scope, shape, - ExpandDims.create(scope, dim, Cast.create(scope, Constant.scalarOf(scope, -1), dType)), + ExpandDims.create(scope, dim, Cast.create(scope, Constant.scalarOf(scope, -1), type)), ExpandDims.create( scope, - Cast.create(scope, Constant.scalarOf(scope, 1), dType), - Cast.create(scope, Constant.scalarOf(scope, -1), dType))); + Cast.create(scope, Constant.scalarOf(scope, 1), type), + Cast.create(scope, Constant.scalarOf(scope, -1), type))); } /** @@ -201,7 +197,7 @@ public static Operand size( @Endpoint(name = "size") public static Operand size( Scope scope, Operand input, Operand dim) { - return size(scope, input, dim, TInt32.DTYPE); + return size(scope, input, dim, TInt32.class); } /** @@ -211,13 +207,13 @@ public static Operand size( * @param scope current scope * @param input the operand * @param dim the dimension - * @param dType the shape datatype + * @param type the shape datatype * @return the size of the specified dimension */ @Endpoint(name = "size") public static Operand size( - Scope scope, Operand input, Operand dim, DataType dType) { - return size(scope, Shape.create(scope, input, dType), dim, dType); + Scope scope, Operand input, Operand dim, Class type) { + return size(scope, Shape.create(scope, input, type), dim, type); } /** @@ -229,7 +225,7 @@ public static Operand size( */ @Endpoint(name = "numDimensions") public static Operand numDimensions(Scope scope, Shape shape) { - return Size.create(scope, shape, TInt32.DTYPE); + return Size.create(scope, shape, TInt32.class); } /** @@ -238,13 +234,13 @@ public static Operand numDimensions(Scope scope, Shape shape) { * @param the shape datatype * @param scope the curren scope * @param shape the shape - * @param dType the shape datatype + * @param type the shape datatype * @return the number of dimensions */ @Endpoint(name = "numDimensions") public static Operand numDimensions( - Scope scope, Shape shape, DataType dType) { - return Size.create(scope, shape, dType); + Scope scope, Shape shape, Class type) { + return Size.create(scope, shape, type); } /** @@ -259,7 +255,7 @@ public static Operand numDimensions( @Endpoint(name = "reduceDims") public static Operand reduceDims( Scope scope, Operand operand, Operand axis) { - return reduceDims(scope, operand, axis, TInt32.DTYPE); + return reduceDims(scope, operand, axis, TInt32.class); } /** @@ -270,14 +266,14 @@ public static Operand reduceDims( * @param scope current scope * @param operand the operand * @param axis the axis - * @param dType the shape datatype + * @param type the shape datatype * @return the reshaped operand */ @Endpoint(name = "reduceDims") public static Operand reduceDims( - Scope scope, Operand operand, Operand axis, DataType dType) { - Shape newShape = Shape.create(scope, operand, dType); - return Reshape.create(scope, operand, reduceDims(scope, newShape, axis, dType)); + Scope scope, Operand operand, Operand axis, Class type) { + Shape newShape = Shape.create(scope, operand, type); + return Reshape.create(scope, operand, reduceDims(scope, newShape, axis, type)); } /** @@ -290,7 +286,7 @@ public static Operand reduceDims( */ @Endpoint(name = "reduceDims") public static Operand reduceDims(Scope scope, Shape shape, Operand axis) { - return reduceDims(scope, shape, axis, TInt32.DTYPE); + return reduceDims(scope, shape, axis, TInt32.class); } /** @@ -300,13 +296,13 @@ public static Operand reduceDims(Scope scope, Shape shape, Opera * @param scope current scope * @param shape the TensorFlow shape * @param axis the axis - * @param dType the shape datatype + * @param type the shape datatype * @return the reduced shape */ @Endpoint(name = "reduceDims") public static Operand reduceDims( - Scope scope, Shape shape, Operand axis, DataType dType) { - Size rank = Size.create(scope, shape, dType); + Scope scope, Shape shape, Operand axis, Class type) { + Size rank = Size.create(scope, shape, type); axis = FloorMod.create(scope, axis, rank); Sub remainder = Sub.create(scope, rank, axis); @@ -314,7 +310,7 @@ public static Operand reduceDims( Slice.create( scope, shape, - Cast.create(scope, Constant.arrayOf(scope, 0), dType), + Cast.create(scope, Constant.arrayOf(scope, 0), type), ExpandDims.create(scope, axis, Constant.scalarOf(scope, -1))); Operand dims2 = @@ -324,7 +320,7 @@ public static Operand reduceDims( ExpandDims.create(scope, axis, Constant.scalarOf(scope, -1)), ExpandDims.create( scope, - Cast.create(scope, Constant.scalarOf(scope, -1), dType), + Cast.create(scope, Constant.scalarOf(scope, -1), type), Constant.scalarOf(scope, -1))); Operand prod = @@ -343,7 +339,7 @@ public static Operand reduceDims( */ @Endpoint(name = "squeeze") public static Operand squeeze(Scope scope, Shape shape) { - return squeeze(scope, shape, TInt32.DTYPE); + return squeeze(scope, shape, TInt32.class); } /** @@ -352,14 +348,14 @@ public static Operand squeeze(Scope scope, Shape shape) { * @param the shape datatype. * @param scope current scope * @param shape the TensorFlow shape - * @param dType the shape datatype. + * @param type the shape datatype. * @return the squeezed shape */ @Endpoint(name = "squeeze") public static Operand squeeze( - Scope scope, Shape shape, DataType dType) { + Scope scope, Shape shape, Class type) { Operand mask = - NotEqual.create(scope, shape, Cast.create(scope, OnesLike.create(scope, shape), dType)); + NotEqual.create(scope, shape, Cast.create(scope, OnesLike.create(scope, shape), type)); return Gather.create(scope, shape, Where.create(scope, mask), Constant.scalarOf(scope, 0)); } @@ -373,7 +369,7 @@ public static Operand squeeze( */ @Endpoint(name = "head") public static Operand head(Scope scope, Shape shape) { - return head(scope, shape, TInt32.DTYPE); + return head(scope, shape, TInt32.class); } /** @@ -381,14 +377,14 @@ public static Operand head(Scope scope, Shape shape) { * * @param scope current scope * @param shape the TensorFlow shape - * @param dType the shape datatype. + * @param type the shape datatype. * @param the shape datatype. * @return a 1-dimensional Operand containing the Shape's first dimension */ @Endpoint(name = "head") public static Operand head( - Scope scope, Shape shape, DataType dType) { - return take(scope, shape, Cast.create(scope, Constant.scalarOf(scope, 1), dType), dType); + Scope scope, Shape shape, Class type) { + return take(scope, shape, Cast.create(scope, Constant.scalarOf(scope, 1), type), type); } /** @@ -403,7 +399,7 @@ public static Operand head( */ @Endpoint(name = "take") public static Operand take(Scope scope, Shape shape, Operand n) { - return take(scope, shape, n, TInt32.DTYPE); + return take(scope, shape, n, TInt32.class); } /** @@ -413,18 +409,18 @@ public static Operand take(Scope scope, Shape shape, Operand the shape datatype. * @return a 1-dimensional operand with the dimensions matching * the first n dimensions of the * shape */ @Endpoint(name = "take") public static Operand take( - Scope scope, Shape shape, Operand n, DataType dType) { + Scope scope, Shape shape, Operand n, Class type) { return Slice.create( scope, shape, - Cast.create(scope, Constant.arrayOf(scope, 0), dType), + Cast.create(scope, Constant.arrayOf(scope, 0), type), ExpandDims.create(scope, n, Constant.scalarOf(scope, -1))); } @@ -439,7 +435,7 @@ public static Operand take( */ @Endpoint(name = "tail") public static Operand tail(Scope scope, Shape shape) { - return tail(scope, shape, TInt32.DTYPE); + return tail(scope, shape, TInt32.class); } /** @@ -448,15 +444,15 @@ public static Operand tail(Scope scope, Shape shape) { * * @param scope current scope * @param shape the TensorFlow shape - * @param dType the shape datatype. + * @param type the shape datatype. * @param the shape datatype. * @return a 1-dimensional Operand that contains the dimension matching the last dimension of the * Shape */ @Endpoint(name = "tail") public static Operand tail( - Scope scope, Shape shape, DataType dType) { - return takeLast(scope, shape, Cast.create(scope, Constant.scalarOf(scope, 1), dType), dType); + Scope scope, Shape shape, Class type) { + return takeLast(scope, shape, Cast.create(scope, Constant.scalarOf(scope, 1), type), type); } /** @@ -472,7 +468,7 @@ public static Operand tail( @Endpoint(name = "takeLast") public static Operand takeLast( Scope scope, Shape shape, Operand n) { - return takeLast(scope, shape, n, TInt32.DTYPE); + return takeLast(scope, shape, n, TInt32.class); } /** @@ -482,16 +478,16 @@ public static Operand takeLast( * @param scope current scope * @param shape the TensorFlow shape * @param n the number of leading dimensions to get, must be <= than the shape's numDimensions() - * @param dType the shape datatype. + * @param type the shape datatype. * @param the shape datatype. * @return a 1-dimensional operand containing the dimensions matching the last n dimensions of the * shape */ @Endpoint(name = "takeLast") public static Operand takeLast( - Scope scope, Shape shape, Operand n, DataType dType) { + Scope scope, Shape shape, Operand n, Class type) { - Size rank = Size.create(scope, shape, dType); + Size rank = Size.create(scope, shape, type); Sub start = Sub.create(scope, rank, n); return Slice.create( scope, @@ -499,7 +495,7 @@ public static Operand takeLast( ExpandDims.create(scope, start, Constant.scalarOf(scope, -1)), ExpandDims.create( scope, - Cast.create(scope, Constant.scalarOf(scope, -1), dType), + Cast.create(scope, Constant.scalarOf(scope, -1), type), Constant.scalarOf(scope, -1))); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Zeros.java index 4aad417b117..a5b5bb137c2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Zeros.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Zeros.java @@ -14,7 +14,6 @@ ==============================================================================*/ package org.tensorflow.op.core; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.Output; @@ -31,7 +30,7 @@ * An operator creating a constant initialized with zeros of the shape given by `dims`. * *

For example, the following expression - *

{@code tf.zeros(tf.constant(shape), TFloat32.DTYPE)
+ *
{@code tf.zeros(tf.constant(shape), TFloat32.class)
* is the equivalent of *
{@code tf.fill(tf.constant(shape), tf.constant(0.0f))
* @@ -51,10 +50,10 @@ public final class Zeros implements Op, Operand { */ @Endpoint @SuppressWarnings("unchecked") - public static Zeros create(Scope scope, Operand dims, DataType type) { + public static Zeros create(Scope scope, Operand dims, Class type) { Scope zerosScope = scope.withSubScope("Zeros"); Operand zero; - if (type == TString.DTYPE) { + if (type == TString.class) { zero = (Operand)Constant.scalarOf(zerosScope.withName("Zero"), ""); } else { zero = Cast.create(zerosScope.withName("Zero"), Constant.scalarOf(zerosScope, 0), type); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java index e823ed9f6bd..92c413f7e52 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java @@ -72,7 +72,7 @@ public static Operand sigmoidCrossEntropyWithLogits( scope = scope.withSubScope("SigmoidCrossEntropyWithLogits"); Operand zeros = - Cast.create(scope, ZerosLike.create(scope, logits), logits.asOutput().dataType()); + Cast.create(scope, ZerosLike.create(scope, logits), logits.asOutput().type()); Operand cond = GreaterEqual.create(scope, logits, zeros); Operand reluLogits = Select.create(scope, cond, logits, zeros); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java index 67cbe3fb98c..ddeacbea4d4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -1,6 +1,5 @@ package org.tensorflow.op.nn; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; @@ -78,24 +77,22 @@ public static Operand softmaxCrossEntr axis += logits.shape().numDimensions(); } - - boolean convertToFloat32 = - logits.asOutput().dataType() == TFloat16.DTYPE - || logits.asOutput().dataType() == TBfloat16.DTYPE; - if (convertToFloat32) { + if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) { Operand result = softmaxCrossEntropyWithLogits(scope, - Cast.create(scope, labels, TFloat32.DTYPE), - Cast.create(scope, logits, TFloat32.DTYPE), + Cast.create(scope, labels, TFloat32.class), + Cast.create(scope, logits, TFloat32.class), axis); - return Cast.create(scope, result, logits.asOutput().dataType()); - } else if(!logits.asOutput().dataType().equals(labels.asOutput().dataType())) { + return Cast.create(scope, result, logits.asOutput().type()); + } + + if (logits.asOutput().type() != labels.asOutput().type()) { return softmaxCrossEntropyWithLogits(scope, - Cast.create(scope, labels, logits.asOutput().dataType()), + Cast.create(scope, labels, logits.asOutput().type()), logits, axis); } - Operand inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.DTYPE); + Operand inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.class); Shape shape = logits.shape(); // Move the dim to the end if dim is not the last dimension. @@ -167,13 +164,13 @@ private static Operand flattenOuterDims(Scope scope, Oper } } - Operand rank = Cast.create(scope, Rank.create(scope, logits), TInt64.DTYPE); + Operand rank = Cast.create(scope, Rank.create(scope, logits), TInt64.class); Operand rankMinusOne = Sub.create(scope, rank, one); Operand lastDimSize = Slice.create( scope, - org.tensorflow.op.core.Shape.create(scope, logits, TInt64.DTYPE), + org.tensorflow.op.core.Shape.create(scope, logits, TInt64.class), rankMinusOne, one); Operand concat = @@ -197,15 +194,15 @@ private static Operand flattenOuterDims(Scope scope, Oper */ private static Operand moveDimToEnd( Scope scope, Operand input, int dimIndex, Operand rank) { - DataType rankDType = rank.asOutput().dataType(); - Operand one = Cast.create(scope, Constant.scalarOf(scope, 1), rankDType); + Class rankType = rank.asOutput().type(); + Operand one = Cast.create(scope, Constant.scalarOf(scope, 1), rankType); List> concatList = Arrays.asList( Range.create( - scope, Cast.create(scope, Constant.scalarOf(scope, dimIndex), rankDType), one, one), + scope, Cast.create(scope, Constant.scalarOf(scope, dimIndex), rankType), one, one), Range.create( scope, - Cast.create(scope, Constant.scalarOf(scope, dimIndex + 1), rankDType), + Cast.create(scope, Constant.scalarOf(scope, dimIndex + 1), rankType), one, one)); return Transpose.create( diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java index 3598edbe223..54b32bb5c63 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -74,11 +74,8 @@ public static Operand sparseSoftmaxCrossE scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits"); /** cannot use generics on preciseLogits as it may be recast later */ Operand preciseLogits = logits; - boolean convertToFloat32 = - logits.asOutput().dataType() == TFloat16.DTYPE - || logits.asOutput().dataType() == TBfloat16.DTYPE; - if (convertToFloat32) { - preciseLogits = Cast.create(scope, logits, TFloat32.DTYPE); + if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) { + preciseLogits = Cast.create(scope, logits, TFloat32.class); } Shape labelsStaticShape = labels.shape(); org.tensorflow.op.core.Shape labelsShape = @@ -115,8 +112,8 @@ public static Operand sparseSoftmaxCrossE org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create( scope, preciseLogits, labels); Operand loss = smax.loss(); - if (logits.asOutput().dataType() == TFloat16.DTYPE) { - loss = Cast.create(scope, loss, TFloat16.DTYPE); + if (logits.asOutput().type() == TFloat16.class) { + loss = Cast.create(scope, loss, TFloat16.class); } return loss; } @@ -153,8 +150,8 @@ public static Operand sparseSoftmaxCrossE scope, preciseLogits, labels); Operand cost = smax.loss(); cost = Reshape.create(scope, cost, labelsShape); - if (logits.asOutput().dataType() == TFloat16.DTYPE) { - cost = Cast.create(scope, cost, TFloat16.DTYPE); + if (logits.asOutput().type() == TFloat16.class) { + cost = Cast.create(scope, cost, TFloat16.class); } return cost; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java index 94c6f8790b5..ef20b5ec2b6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java @@ -18,7 +18,6 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TBfloat16Mapper; @@ -27,6 +26,8 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TFloating; /** @@ -45,12 +46,8 @@ *

Note that some CPUs support the bfloat16 format natively, which can result in faster * computation compared to {@link TFloat16} when GPUs are not used. */ +@TensorType(dataType = DataType.DT_BFLOAT16, byteSize = 2, mapperClass = TBfloat16Mapper.class) public interface TBfloat16 extends FloatNdArray, TFloating { - /** readable-name for the data type */ - static final String NAME = "BFLOAT16"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 14, 2, new TBfloat16Mapper()); /** * Allocates a new tensor for storing a single float value. @@ -59,7 +56,7 @@ public interface TBfloat16 extends FloatNdArray, TFloating { * @return the new tensor */ static TBfloat16 scalarOf(float value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setFloat(value)); + return Tensor.of(TBfloat16.class, Shape.scalar(), data -> data.setFloat(value)); } /** @@ -72,7 +69,7 @@ static TBfloat16 vectorOf(float... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(TBfloat16.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -84,7 +81,7 @@ static TBfloat16 vectorOf(float... values) { * @return the new tensor */ static TBfloat16 tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + return Tensor.of(TBfloat16.class, src.shape(), src::copyTo); } /** @@ -94,7 +91,7 @@ static TBfloat16 tensorOf(NdArray src) { * @return the new tensor */ static TBfloat16 tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + return Tensor.of(TBfloat16.class, shape); } /** @@ -105,7 +102,7 @@ static TBfloat16 tensorOf(Shape shape) { * @return the new tensor */ static TBfloat16 tensorOf(Shape shape, FloatDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + return Tensor.of(TBfloat16.class, shape, d -> d.write(data)); } /** @@ -117,7 +114,7 @@ static TBfloat16 tensorOf(Shape shape, FloatDataBuffer data) { * @throws TensorFlowException if the tensor cannot be allocated or initialized */ static TBfloat16 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); + return Tensor.of(TBfloat16.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java index bab5e7910b4..0158c12b910 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java @@ -18,12 +18,8 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; -import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.internal.types.TBoolMapper; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.NdArray; @@ -31,7 +27,8 @@ import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.BooleanDataBuffer; import org.tensorflow.ndarray.buffer.layout.DataLayouts; -import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TType; /** @@ -41,12 +38,8 @@ * explicit mapping between Java boolean values and byte buffers using the {@link DataLayouts#BOOL * BOOL} layout, which may impact I/O performances. */ +@TensorType(dataType = DataType.DT_BOOL, byteSize = 1, mapperClass = TBoolMapper.class) public interface TBool extends BooleanNdArray, TType { - /** readable-name for the data type */ - static final String NAME = "BOOL"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 10, 1, new TBoolMapper()); /** * Allocates a new tensor for storing a single boolean value. @@ -55,7 +48,7 @@ public interface TBool extends BooleanNdArray, TType { * @return the new tensor */ static TBool scalarOf(boolean value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setBoolean(value)); + return Tensor.of(TBool.class, Shape.scalar(), data -> data.setBoolean(value)); } /** @@ -68,7 +61,7 @@ static TBool vectorOf(boolean... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(TBool.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -80,7 +73,7 @@ static TBool vectorOf(boolean... values) { * @return the new tensor */ static TBool tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + return Tensor.of(TBool.class, src.shape(), src::copyTo); } /** @@ -90,7 +83,7 @@ static TBool tensorOf(NdArray src) { * @return the new tensor */ static TBool tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + return Tensor.of(TBool.class, shape); } /** @@ -101,7 +94,7 @@ static TBool tensorOf(Shape shape) { * @return the new tensor */ static TBool tensorOf(Shape shape, BooleanDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + return Tensor.of(TBool.class, shape, d -> d.write(data)); } /** @@ -113,6 +106,6 @@ static TBool tensorOf(Shape shape, BooleanDataBuffer data) { * @throws TensorFlowException if the tensor cannot be allocated or initialized */ static TBool tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); + return Tensor.of(TBool.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java index 0decbb66d12..a43a0831f10 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java @@ -18,7 +18,6 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TFloat16Mapper; @@ -27,6 +26,8 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TFloating; /** @@ -42,14 +43,9 @@ * most CPUs do not support this format natively. For CPU computation on 16-bit floats, the {@link * TBfloat16} tensor type might be a better option. */ +@TensorType(dataType = DataType.DT_HALF, byteSize = 2, mapperClass = TFloat16Mapper.class) public interface TFloat16 extends FloatNdArray, TFloating { - /** readable-name for the data type */ - static final String NAME = "FLOAT16"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 19, 2, new TFloat16Mapper()); - /** * Allocates a new tensor for storing a single float value. * @@ -57,7 +53,7 @@ public interface TFloat16 extends FloatNdArray, TFloating { * @return the new tensor */ static TFloat16 scalarOf(float value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setFloat(value)); + return Tensor.of(TFloat16.class, Shape.scalar(), data -> data.setFloat(value)); } /** @@ -70,7 +66,7 @@ static TFloat16 vectorOf(float... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(TFloat16.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -82,7 +78,7 @@ static TFloat16 vectorOf(float... values) { * @return the new tensor */ static TFloat16 tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + return Tensor.of(TFloat16.class, src.shape(), src::copyTo); } /** @@ -92,7 +88,7 @@ static TFloat16 tensorOf(NdArray src) { * @return the new tensor */ static TFloat16 tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + return Tensor.of(TFloat16.class, shape); } /** @@ -103,7 +99,7 @@ static TFloat16 tensorOf(Shape shape) { * @return the new tensor */ static TFloat16 tensorOf(Shape shape, FloatDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + return Tensor.of(TFloat16.class, shape, d -> d.write(data)); } /** @@ -115,6 +111,6 @@ static TFloat16 tensorOf(Shape shape, FloatDataBuffer data) { * @throws TensorFlowException if the tensor cannot be allocated or initialized */ static TFloat16 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); + return Tensor.of(TFloat16.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java index 6300650811e..35208f7de43 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java @@ -18,7 +18,6 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TFloat32Mapper; @@ -27,17 +26,14 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TFloating; /** IEEE-754 single-precision 32-bit float tensor type. */ +@TensorType(dataType = DataType.DT_FLOAT, byteSize = 4, mapperClass = TFloat32Mapper.class) public interface TFloat32 extends FloatNdArray, TFloating { - /** readable-name for the data type */ - static final String NAME = "FLOAT"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 1, 4, new TFloat32Mapper()); - /** * Allocates a new tensor for storing a single float value. * @@ -45,7 +41,7 @@ public interface TFloat32 extends FloatNdArray, TFloating { * @return the new tensor */ static TFloat32 scalarOf(float value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setFloat(value)); + return Tensor.of(TFloat32.class, Shape.scalar(), data -> data.setFloat(value)); } /** @@ -58,7 +54,7 @@ static TFloat32 vectorOf(float... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(TFloat32.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -70,7 +66,7 @@ static TFloat32 vectorOf(float... values) { * @return the new tensor */ static TFloat32 tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + return Tensor.of(TFloat32.class, src.shape(), src::copyTo); } /** @@ -80,7 +76,7 @@ static TFloat32 tensorOf(NdArray src) { * @return the new tensor */ static TFloat32 tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + return Tensor.of(TFloat32.class, shape); } /** @@ -91,7 +87,7 @@ static TFloat32 tensorOf(Shape shape) { * @return the new tensor */ static TFloat32 tensorOf(Shape shape, FloatDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + return Tensor.of(TFloat32.class, shape, d -> d.write(data)); } /** @@ -103,6 +99,6 @@ static TFloat32 tensorOf(Shape shape, FloatDataBuffer data) { * @throws TensorFlowException if the tensor cannot be allocated or initialized */ static TFloat32 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); + return Tensor.of(TFloat32.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java index 923b9992400..957612691e5 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java @@ -18,31 +18,23 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; -import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.internal.types.TFloat64Mapper; import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.DoubleDataBuffer; -import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TFloating; /** IEEE-754 double-precision 64-bit float tensor type. */ +@TensorType(dataType = DataType.DT_DOUBLE, byteSize = 8, mapperClass = TFloat64Mapper.class) public interface TFloat64 extends DoubleNdArray, TFloating { - /** readable-name for the data type */ - static final String NAME = "DOUBLE"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 2, 8, new TFloat64Mapper()); - /** * Allocates a new tensor for storing a single double value. * @@ -50,7 +42,7 @@ public interface TFloat64 extends DoubleNdArray, TFloating { * @return the new tensor */ static TFloat64 scalarOf(double value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setDouble(value)); + return Tensor.of(TFloat64.class, Shape.scalar(), data -> data.setDouble(value)); } /** @@ -63,7 +55,7 @@ static TFloat64 vectorOf(double... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(TFloat64.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -75,7 +67,7 @@ static TFloat64 vectorOf(double... values) { * @return the new tensor */ static TFloat64 tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + return Tensor.of(TFloat64.class, src.shape(), src::copyTo); } /** @@ -85,7 +77,7 @@ static TFloat64 tensorOf(NdArray src) { * @return the new tensor */ static TFloat64 tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + return Tensor.of(TFloat64.class, shape); } /** @@ -96,7 +88,7 @@ static TFloat64 tensorOf(Shape shape) { * @return the new tensor */ static TFloat64 tensorOf(Shape shape, DoubleDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + return Tensor.of(TFloat64.class, shape, d -> d.write(data)); } /** @@ -108,6 +100,6 @@ static TFloat64 tensorOf(Shape shape, DoubleDataBuffer data) { * @throws TensorFlowException if the tensor cannot be allocated or initialized */ static TFloat64 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); + return Tensor.of(TFloat64.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java index ccb865e2793..8f6b587795b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java @@ -18,28 +18,20 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; -import org.tensorflow.RawTensor; import org.tensorflow.Tensor; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.internal.types.TInt32Mapper; import org.tensorflow.ndarray.IntNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.IntDataBuffer; -import org.tensorflow.ndarray.impl.dense.IntDenseNdArray; -import org.tensorflow.types.family.TNumber; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; +import org.tensorflow.types.family.TIntegral; /** 32-bit signed integer tensor type. */ -public interface TInt32 extends IntNdArray, TNumber { - - /** readable-name for the data type */ - static final String NAME = "INT32"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 3, 4, new TInt32Mapper()); +@TensorType(dataType = DataType.DT_INT32, byteSize = 4, mapperClass = TInt32Mapper.class) +public interface TInt32 extends IntNdArray, TIntegral { /** * Allocates a new tensor for storing a single int value. @@ -48,7 +40,7 @@ public interface TInt32 extends IntNdArray, TNumber { * @return the new tensor */ static TInt32 scalarOf(int value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setInt(value)); + return Tensor.of(TInt32.class, Shape.scalar(), data -> data.setInt(value)); } /** @@ -62,7 +54,7 @@ static TInt32 vectorOf(int... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(TInt32.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -74,7 +66,7 @@ static TInt32 vectorOf(int... values) { * @return the new tensor */ static TInt32 tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + return Tensor.of(TInt32.class, src.shape(), src::copyTo); } /** @@ -84,7 +76,7 @@ static TInt32 tensorOf(NdArray src) { * @return the new tensor */ static TInt32 tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + return Tensor.of(TInt32.class, shape); } /** @@ -95,7 +87,7 @@ static TInt32 tensorOf(Shape shape) { * @return the new tensor */ static TInt32 tensorOf(Shape shape, IntDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + return Tensor.of(TInt32.class, shape, d -> d.write(data)); } /** @@ -106,7 +98,7 @@ static TInt32 tensorOf(Shape shape, IntDataBuffer data) { * @return the new tensor */ static TInt32 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); + return Tensor.of(TInt32.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java index 02763391ff6..867248c5392 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java @@ -18,29 +18,21 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; -import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.internal.types.TInt64Mapper; import org.tensorflow.ndarray.LongNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.LongDataBuffer; -import org.tensorflow.ndarray.impl.dense.LongDenseNdArray; -import org.tensorflow.types.family.TNumber; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; +import org.tensorflow.types.family.TIntegral; /** 64-bit signed integer tensor type. */ -public interface TInt64 extends LongNdArray, TNumber { - - /** readable-name for the data type */ - static final String NAME = "INT64"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 9, 8, new TInt64Mapper()); +@TensorType(dataType = DataType.DT_INT64, byteSize = 8, mapperClass = TInt64Mapper.class) +public interface TInt64 extends LongNdArray, TIntegral { /** * Allocates a new tensor for storing a single long value. @@ -49,7 +41,7 @@ public interface TInt64 extends LongNdArray, TNumber { * @return the new tensor */ static TInt64 scalarOf(long value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setLong(value)); + return Tensor.of(TInt64.class, Shape.scalar(), data -> data.setLong(value)); } /** @@ -62,7 +54,7 @@ static TInt64 vectorOf(long... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(TInt64.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -74,7 +66,7 @@ static TInt64 vectorOf(long... values) { * @return the new tensor */ static TInt64 tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + return Tensor.of(TInt64.class, src.shape(), src::copyTo); } /** @@ -84,7 +76,7 @@ static TInt64 tensorOf(NdArray src) { * @return the new tensor */ static TInt64 tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + return Tensor.of(TInt64.class, shape); } /** @@ -95,7 +87,7 @@ static TInt64 tensorOf(Shape shape) { * @return the new tensor */ static TInt64 tensorOf(Shape shape, LongDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + return Tensor.of(TInt64.class, shape, d -> d.write(data)); } /** @@ -107,6 +99,6 @@ static TInt64 tensorOf(Shape shape, LongDataBuffer data) { * @throws TensorFlowException if the tensor cannot be allocated or initialized */ static TInt64 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); + return Tensor.of(TInt64.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java index 6d7f7426c1a..b3000cc2f8a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java @@ -20,7 +20,6 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.function.Function; -import org.tensorflow.DataType; import org.tensorflow.Tensor; import org.tensorflow.internal.types.TStringInitializer; import org.tensorflow.internal.types.TStringMapper; @@ -28,6 +27,8 @@ import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TType; /** @@ -39,14 +40,9 @@ * its values initially, so TensorFlow can compute and allocate the right amount of memory. Then the * data in the tensor is initialized once and cannot be modified afterwards. */ +@TensorType(dataType = DataType.DT_STRING, byteSize = -1, mapperClass = TStringMapper.class) public interface TString extends NdArray, TType { - /** readable-name for the data type */ - static final String NAME = "STRING"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 7, -1, new TStringMapper()); - /** * Allocates a new tensor for storing a string scalar. * @@ -110,7 +106,7 @@ static TString tensorOf(NdArray src) { */ static TString tensorOf(Charset charset, NdArray src) { TStringInitializer initializer = new TStringInitializer<>(src, s -> s.getBytes(charset)); - return Tensor.of(TString.DTYPE, src.shape(), initializer.computeRequiredSize(), initializer); + return Tensor.of(TString.class, src.shape(), initializer.computeRequiredSize(), initializer); } /** @@ -171,7 +167,7 @@ static TString tensorOf(Charset charset, Shape shape, DataBuffer data) { */ static TString tensorOfBytes(NdArray src) { TStringInitializer initializer = new TStringInitializer<>(src, Function.identity()); - return Tensor.of(TString.DTYPE, src.shape(), initializer.computeRequiredSize(), initializer); + return Tensor.of(TString.class, src.shape(), initializer.computeRequiredSize(), initializer); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java index a6f5dba8971..eae86414cb4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java @@ -18,29 +18,21 @@ package org.tensorflow.types; import java.util.function.Consumer; -import org.tensorflow.DataType; -import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.internal.types.TUint8Mapper; import org.tensorflow.ndarray.ByteNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.ByteDataBuffer; -import org.tensorflow.ndarray.impl.dense.ByteDenseNdArray; -import org.tensorflow.types.family.TNumber; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.annotation.TensorType; +import org.tensorflow.types.family.TIntegral; /** 8-bit unsigned integer tensor type. */ -public interface TUint8 extends ByteNdArray, TNumber { - - /** readable-name for the data type */ - static final String NAME = "UINT8"; - - /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 4, 1, new TUint8Mapper()); +@TensorType(dataType = DataType.DT_UINT8, byteSize = 1, mapperClass = TUint8Mapper.class) +public interface TUint8 extends ByteNdArray, TIntegral { /** * Allocates a new tensor for storing a single byte value. @@ -49,7 +41,7 @@ public interface TUint8 extends ByteNdArray, TNumber { * @return the new tensor */ static TUint8 scalarOf(byte value) { - return Tensor.of(DTYPE, Shape.scalar(), data -> data.setByte(value)); + return Tensor.of(TUint8.class, Shape.scalar(), data -> data.setByte(value)); } /** @@ -62,7 +54,7 @@ static TUint8 vectorOf(byte... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(DTYPE, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(TUint8.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -74,7 +66,7 @@ static TUint8 vectorOf(byte... values) { * @return the new tensor */ static TUint8 tensorOf(NdArray src) { - return Tensor.of(DTYPE, src.shape(), src::copyTo); + return Tensor.of(TUint8.class, src.shape(), src::copyTo); } /** @@ -84,7 +76,7 @@ static TUint8 tensorOf(NdArray src) { * @return the new tensor */ static TUint8 tensorOf(Shape shape) { - return Tensor.of(DTYPE, shape); + return Tensor.of(TUint8.class, shape); } /** @@ -95,7 +87,7 @@ static TUint8 tensorOf(Shape shape) { * @return the new tensor */ static TUint8 tensorOf(Shape shape, ByteDataBuffer data) { - return Tensor.of(DTYPE, shape, d -> d.write(data)); + return Tensor.of(TUint8.class, shape, d -> d.write(data)); } /** @@ -107,6 +99,6 @@ static TUint8 tensorOf(Shape shape, ByteDataBuffer data) { * @throws TensorFlowException if the tensor cannot be allocated or initialized */ static TUint8 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(DTYPE, shape, dataInit); + return Tensor.of(TUint8.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/annotation/TensorType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/annotation/TensorType.java new file mode 100644 index 00000000000..78ab5d7a8b6 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/annotation/TensorType.java @@ -0,0 +1,53 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.types.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import org.tensorflow.TensorMapper; +import org.tensorflow.proto.framework.DataType; + +/** + * Annotation for all tensor types. + * + *

Any interface extending {@link org.tensorflow.types.family.TType TType} to be registered as a + * tensor type must be annotated with {@code @TensorType} to provide metadata required for allocating + * and mapping tensors of this type.

+ */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface TensorType { + + /** + * The data type of each elements in a tensor of this type + */ + DataType dataType(); + + /** + * The number of bytes required one element of a tensor of type, -1 for variable-length element tensors + */ + int byteSize(); + + /** + * The class of the {@link TensorMapper} to allocate and use for mapping tensors of this type + */ + Class> mapperClass(); +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TIntegral.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TIntegral.java new file mode 100644 index 00000000000..3652ea4613c --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TIntegral.java @@ -0,0 +1,25 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.types.family; + +/** + * Common interface for all integral numeric tensors. + * + *

Operations that only accepts integral values as some of their operands enforce that the tensor + * types for these operands to be bound to this interface. + */ +public interface TIntegral extends TNumber {} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java index 21d275296c8..2fc423b914e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java @@ -18,6 +18,7 @@ package org.tensorflow.types.family; import org.tensorflow.Tensor; +import org.tensorflow.proto.framework.DataType; /** * Common interface for all typed tensors. @@ -60,6 +61,16 @@ */ public interface TType extends Tensor { + /** + * Returns the type of this tensor as a registered subclass of {@code TType} + */ + Class type(); + + @Override + default DataType dataType() { + return asRawTensor().dataType(); + } + @Override default long numBytes() { return asRawTensor().numBytes(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/package-info.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/package-info.java index afbd69fabe5..746ae703694 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/package-info.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/package-info.java @@ -25,17 +25,14 @@ * bound to one of the marker interface found in {@link org.tensorflow.types.family}, according * to the nature of the data. * - *

Each tensor type must provide a static instance of {@link org.tensorflow.DataType} - * carrying type metadata that should be used for allocating a tensor of this type or to pass - * this type as an operation argument. For example, metadata about TensorFlow int32 type is - * found in {@link org.tensorflow.types.TInt32#DTYPE TInt32.DTYPE}. + *

Each tensor type must be annotated with {@link org.tensorflow.types.annotation.TensorType} to + * provide type metadata that should be used for allocating or mapping tensors of this type. * *

Instances of tensor types must also implement the {@link org.tensorflow.ndarray.NdArray NdArray} - * interface so a user can access directly the tensor data in a n-dimensional space by invoking - * {@link org.tensorflow.Tensor#data() Tensor.data()}. + * interface so a user can access directly the tensor data in a n-dimensional space. * *

Note that while it is always possible to allocate a tensor using the - * {@link org.tensorflow.Tensor#of(org.tensorflow.DataType, Shape) Tensor.of(...)} + * {@link org.tensorflow.Tensor#of(Class, Shape) Tensor.of(...)} * method, most tensor types expose factory methods that simplify the creation process, like * {@code scalarOf(...)}, {@code vectorOf(...)}, {@code tensorOf(...)}, etc. */ diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index 72dcfb12430..b2b2c34e223 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -16,7 +16,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.fail; import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; @@ -29,14 +28,14 @@ public class ConcreteFunctionTest { private static Signature plusFive(Ops tf) { - Placeholder input = tf.placeholder(TFloat32.DTYPE); + Placeholder input = tf.placeholder(TFloat32.class); Add output = tf.math.add(input, tf.constant(5.0f)); Init init = tf.init(); // for native resource management tests return Signature.builder().key("plusFive").input("x", input).output("y", output).build(); } private static Signature minusTwo(Ops tf) { - Placeholder input = tf.placeholder(TFloat32.DTYPE); + Placeholder input = tf.placeholder(TFloat32.class); Sub output = tf.math.sub(input, tf.constant(2.0f)); return Signature.builder().key("minusTwo").input("x", input).output("y", output).build(); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java index 6751c513ef3..b39ecec9881 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java @@ -18,9 +18,9 @@ import static org.junit.jupiter.api.Assertions.fail; import org.junit.jupiter.api.Test; -import org.tensorflow.op.Ops; import org.tensorflow.ndarray.Shape; -import org.tensorflow.types.TFloat32; +import org.tensorflow.op.Ops; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TInt32; /** Unit tests for {@link EagerOperationBuilder} class. */ @@ -45,7 +45,7 @@ public void failToBuildOpIfSessionIsClosed() { opBuilder = new EagerOperationBuilder(session, "Empty", "empty"); } try { - opBuilder.setAttr("dtype", TFloat32.DTYPE); + opBuilder.setAttr("dtype", DataType.DT_FLOAT); fail(); } catch (IllegalStateException e) { // expected @@ -90,7 +90,7 @@ public void setAttrs() { // dtype, tensor attributes. try (TInt32 t = TInt32.scalarOf(1)) { opBuilder(session, "Const", "DataTypeAndTensor") - .setAttr("dtype", TInt32.DTYPE) + .setAttr("dtype", t.dataType()) .setAttr("value", t) .build(); } @@ -98,7 +98,7 @@ public void setAttrs() { opBuilder(session, "RandomUniform", "DataTypeAndInt") .addInput(tf.array(1).asOutput()) .setAttr("seed", 10) - .setAttr("dtype", TFloat32.DTYPE) + .setAttr("dtype", DataType.DT_FLOAT) .build(); // list(int), string opBuilder(session, "MaxPool", "IntListAndString") @@ -119,7 +119,7 @@ public void setAttrs() { .build(); // list(shape) opBuilder(session, "FIFOQueue", "queue") - .setAttr("component_types", new DataType[] {TInt32.DTYPE, TInt32.DTYPE}) + .setAttr("component_types", new DataType[] {DataType.DT_INT32, DataType.DT_INT32}) .setAttr("shapes", new Shape[] {Shape.of(2, 2), Shape.of(2, 2, 2)}) .build(); // bool diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java index a14a295fddd..2920fbdf59f 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java @@ -19,10 +19,10 @@ import static org.junit.jupiter.api.Assertions.fail; import org.junit.jupiter.api.Test; -import org.tensorflow.exceptions.TFFailedPreconditionException; import org.tensorflow.exceptions.TFInvalidArgumentException; -import org.tensorflow.op.Ops; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -49,10 +49,10 @@ public void outputDataTypeAndShape() { TInt32 t = TInt32.tensorOf(Shape.of(2, 3))) { EagerOperation op = opBuilder(session, "Const", "OutputAttrs") - .setAttr("dtype", TInt32.DTYPE) + .setAttr("dtype", t.dataType()) .setAttr("value", t) .build(); - assertEquals(TInt32.DTYPE, op.dtype(0)); + assertEquals(DataType.DT_INT32, op.dtype(0)); assertEquals(2, op.shape(0).size(0)); assertEquals(3, op.shape(0).size(1)); } @@ -72,7 +72,7 @@ public void outputTensor() { // Validate that we retrieve the right shape and datatype from the tensor // that has been resolved assertEquals(0, add.shape(0).numDimensions()); - assertEquals(TInt32.DTYPE, add.dtype(0)); + assertEquals(DataType.DT_INT32, add.dtype(0)); } } @@ -123,7 +123,7 @@ public void numOutputs() { opBuilder(session, "UniqueWithCountsV2", "unq") .addInput(tf.constant(new int[]{1, 2, 1}).asOutput()) .addInput(tf.constant(new int[]{0}).asOutput()) - .setAttr("out_idx", TInt32.DTYPE) + .setAttr("out_idx", DataType.DT_INT32) .build(); assertEquals(3, op.numOutputs()); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java index 7573a25ac13..bbb9e23ec90 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java @@ -21,11 +21,11 @@ import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TFInvalidArgumentException; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Constant; -import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; /** Unit tests for {@link org.tensorflow.GraphOperationBuilder}. */ @@ -90,7 +90,7 @@ public void setAttr() { // dtype, tensor attributes. try (TInt32 t = TInt32.scalarOf(1)) { g.opBuilder("Const", "DataTypeAndTensor") - .setAttr("dtype", TInt32.DTYPE) + .setAttr("dtype", t.dataType()) .setAttr("value", t) .build() .output(0); @@ -106,7 +106,7 @@ public void setAttr() { g.opBuilder("RandomUniform", "Int") .addInput(tf.array(1).asOutput()) .setAttr("seed", 10) - .setAttr("dtype", TFloat32.DTYPE) + .setAttr("dtype", DataType.DT_FLOAT) .build(); assertTrue(hasNode(g, "Int")); // list(int) @@ -132,23 +132,23 @@ public void setAttrShape() { try (Graph g = new Graph()) { Output n = g.opBuilder("Placeholder", "unknown") - .setAttr("dtype", TFloat32.DTYPE) + .setAttr("dtype", DataType.DT_FLOAT) .setAttr("shape", Shape.unknown()) .build() .output(0); assertEquals(-1, n.shape().numDimensions()); - assertEquals(TFloat32.DTYPE, n.dataType()); + assertEquals(DataType.DT_FLOAT, n.dataType()); n = g.opBuilder("Placeholder", "batch_of_vectors") - .setAttr("dtype", TFloat32.DTYPE) + .setAttr("dtype", DataType.DT_FLOAT) .setAttr("shape", Shape.of(-1, 784)) .build() .output(0); assertEquals(2, n.shape().numDimensions()); assertEquals(-1, n.shape().size(0)); assertEquals(784, n.shape().size(1)); - assertEquals(TFloat32.DTYPE, n.dataType()); + assertEquals(DataType.DT_FLOAT, n.dataType()); } } @@ -172,7 +172,7 @@ public void addControlInput() { TBool yes = TBool.scalarOf(true); TBool no = TBool.scalarOf(false)) { Ops tf = Ops.create(g); - Output placeholder = tf.placeholder(TBool.DTYPE).asOutput(); + Output placeholder = tf.placeholder(TBool.class).asOutput(); GraphOperation check = g.opBuilder("Assert", "assert") .addInput(placeholder) @@ -200,7 +200,7 @@ private static void testSetAttrShapeList(Shape[] shapes) { int[][] matrix = new int[][] {{0, 0}, {0, 0}}; Output queue = g.opBuilder("FIFOQueue", "queue") - .setAttr("component_types", new DataType[] {TInt32.DTYPE, TInt32.DTYPE}) + .setAttr("component_types", new DataType[] {DataType.DT_INT32, DataType.DT_INT32}) .setAttr("shapes", shapes) .build() .output(0); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java index a9eed79041a..d8ffc1a475b 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java @@ -27,6 +27,7 @@ import org.tensorflow.exceptions.TFInvalidArgumentException; import org.tensorflow.op.Ops; import org.tensorflow.op.linalg.MatMul; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.GraphDef; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -42,7 +43,7 @@ public void graphDefRoundTrip() { Ops tf = Ops.create(g); tf.withName("Y").linalg.matMul( tf.withName("A").constant(new int[2][2]), - tf.withName("X").placeholder(TInt32.DTYPE), + tf.withName("X").placeholder(TInt32.class), MatMul.transposeA(true).transposeB(false) ); graphDef = g.toGraphDef(); @@ -140,8 +141,8 @@ public void addGradientsToGraph() { Session s = new Session(g)) { Ops tf = Ops.create(g); - Output x1 = tf.placeholder(TFloat32.DTYPE).output(); - Output x2 = tf.placeholder(TFloat32.DTYPE).output(); + Output x1 = tf.placeholder(TFloat32.class).output(); + Output x2 = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x1).y(); Output y1 = tf.math.square(y0).y(); Output y2 = tf.math.addN(Arrays.asList(y0, x2)).sum(); @@ -149,13 +150,13 @@ public void addGradientsToGraph() { Output[] grads0 = g.addGradients(y1, toArray(x1)); assertNotNull(grads0); assertEquals(1, grads0.length); - assertEquals(TFloat32.DTYPE, grads0[0].dataType()); + assertEquals(DataType.DT_FLOAT, grads0[0].dataType()); Output[] grads1 = g.addGradients(y2, toArray(x1, x2)); assertNotNull(grads1); assertEquals(2, grads1.length); - assertEquals(TFloat32.DTYPE, grads1[0].dataType()); - assertEquals(TFloat32.DTYPE, grads1[1].dataType()); + assertEquals(DataType.DT_FLOAT, grads1[0].dataType()); + assertEquals(DataType.DT_FLOAT, grads1[1].dataType()); try (TFloat32 c1 = TFloat32.scalarOf(3.0f); TFloat32 c2 = TFloat32.scalarOf(2.0f); @@ -182,14 +183,14 @@ public void addGradientSumsToGraph() { Session s = new Session(g)) { Ops tf = Ops.create(g); - Output x = tf.placeholder(TFloat32.DTYPE).output(); + Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); Output[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null); assertNotNull(grad); assertEquals(1, grad.length); - assertEquals(TFloat32.DTYPE, grad[0].dataType()); + assertEquals(DataType.DT_FLOAT, grad[0].dataType()); try (TFloat32 c = TFloat32.scalarOf(3.0f); TFloat32 output = (TFloat32)s.runner() @@ -208,19 +209,19 @@ public void addGradientsWithInitialValuesToGraph() { Session s = new Session(g)) { Ops tf = Ops.create(g); - Output x = tf.placeholder(TFloat32.DTYPE).output(); + Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); Output[] grad0 = g.addGradients(y1, toArray(y0)); assertNotNull(grad0); assertEquals(1, grad0.length); - assertEquals(TFloat32.DTYPE, grad0[0].dataType()); + assertEquals(DataType.DT_FLOAT, grad0[0].dataType()); Output[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0])); assertNotNull(grad1); assertEquals(1, grad1.length); - assertEquals(TFloat32.DTYPE, grad1[0].dataType()); + assertEquals(DataType.DT_FLOAT, grad1[0].dataType()); try (TFloat32 c = TFloat32.scalarOf(3.0f); TFloat32 output = (TFloat32)s.runner() @@ -238,7 +239,7 @@ public void validateGradientsNames() { try (Graph g = new Graph()) { Ops tf = Ops.create(g); - Output x = tf.placeholder(TFloat32.DTYPE).output(); + Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null); @@ -267,7 +268,7 @@ public void buildWhileLoopSingleInput() { Session s = new Session(g)) { Ops tf = Ops.create(g); - Output input = tf.placeholder(TInt32.DTYPE).output(); + Output input = tf.placeholder(TInt32.class).output(); @SuppressWarnings("unchecked") Output[] loopOutputs = g.whileLoop( @@ -299,8 +300,8 @@ public void buildWhileLoopMultipleInputs() { Session s = new Session(g)) { Ops tf = Ops.create(g); - Output input1 = tf.placeholder(TInt32.DTYPE).output(); - Output input2 = tf.placeholder(TInt32.DTYPE).output(); + Output input1 = tf.placeholder(TInt32.class).output(); + Output input2 = tf.placeholder(TInt32.class).output(); Output[] inputs = toArray(input1, input2); @SuppressWarnings("unchecked") diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/RawTensorTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/RawTensorTest.java new file mode 100644 index 00000000000..0d2d8af8b1c --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/RawTensorTest.java @@ -0,0 +1,90 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TString; + +public class RawTensorTest { + + @Test + public void rawToTypedTensor() { + RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), -1); + TFloat32 floatTensor = (TFloat32)rawTensor.asTypedTensor(); + assertSame(floatTensor.asRawTensor(), rawTensor); + try { + TInt32 intTensor = (TInt32)rawTensor.asTypedTensor(); + fail(); + } catch (ClassCastException e) { + // ok + } + } + + @Test + public void allocateTensorWithSize() { + try (RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), 16)) { + assertEquals(16, rawTensor.numBytes()); + } + try (RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), 100)) { + assertEquals(100, rawTensor.numBytes()); + } + try (RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), 10)) { + fail(); + } catch (IllegalArgumentException e) { + // ok + } + try (RawTensor rawTensor = RawTensor.allocate(TString.class, Shape.of(2, 2), 100)) { + assertEquals(100, rawTensor.numBytes()); + } + } + + @Test + public void allocateTensorWithoutSize() { + try (RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), -1)) { + assertEquals(16, rawTensor.numBytes()); + // ok + } + try (RawTensor rawTensor = RawTensor.allocate(TString.class, Shape.of(2, 2), -1)) { + fail(); + } catch (IllegalArgumentException e) { + // ok + } + } + + @Test + public void failToAllocateTensorFromUnknownShape() { + try { + RawTensor.allocate(TFloat32.class, Shape.of(3, -1, 3), -1); + fail(); + } catch (IllegalArgumentException e) { + // ok + } + try { + RawTensor.allocate(TString.class, Shape.unknown(), 100); + fail(); + } catch (IllegalArgumentException e) { + // ok + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 417205f1f38..cd8ac7e2ae4 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -305,16 +305,16 @@ public void pythonTfFunction() { } private static Signature buildGraphWithVariables(Ops tf, Shape xShape) { - Placeholder x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(xShape)); + Placeholder x = tf.placeholder(TFloat32.class, Placeholder.shape(xShape)); Variable y = tf - .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.DTYPE)); + .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.class)); ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1)); Init init = tf.init(); return Signature.builder().input("input", x).output("reducedSum", z).build(); } private static Signature buildIdentityGraph(Ops tf, String signatureKey) { - Placeholder x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + Placeholder x = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.scalar())); Identity xprime = tf.identity(x); return Signature.builder().key(signatureKey).input("x", x).output("x", xprime).build(); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index 0556c1ff17f..b1928bff51c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -165,7 +165,7 @@ public void runInit() { try (Graph g = new Graph()) { Ops tf = Ops.create(g); - Variable var1 = tf.variable(Shape.scalar(), TInt32.DTYPE); + Variable var1 = tf.variable(Shape.scalar(), TInt32.class); tf.initAdd(tf.assign(var1, tf.constant(10))); Variable var2 = tf.variable(tf.constant(20)); Add add = tf.math.add(var1, var2); @@ -185,7 +185,7 @@ public void runInitByName() { try (Graph g = new Graph()) { Ops tf = Ops.create(g); - Variable var1 = tf.variable(Shape.scalar(), TInt32.DTYPE); + Variable var1 = tf.variable(Shape.scalar(), TInt32.class); tf.initAdd(tf.assign(var1, tf.constant(10))); Variable var2 = tf.variable(tf.constant(20)); Add add = tf.math.add(var1, var2); @@ -212,8 +212,8 @@ public void save() throws IOException { Path testFolder = Files.createTempDirectory("tf-session-save-test"); try (Graph g = new Graph()) { Ops tf = Ops.create(g); - Variable x = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.DTYPE)); - Variable y = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.DTYPE)); + Variable x = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); + Variable y = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); Init init = tf.init(); try (Session s = new Session(g)) { @@ -241,7 +241,7 @@ private static ConfigProto singleThreadConfigProto() { private static void transpose_A_times_X(Ops tf, int[][] a) { tf.withName("Y").linalg.matMul( tf.withName("A").constant(a), - tf.withName("X").placeholder(TInt32.DTYPE), + tf.withName("X").placeholder(TInt32.class), MatMul.transposeA(true).transposeB(false) ); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java index 8931ecbbde1..e1436358a68 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java @@ -14,17 +14,11 @@ ==============================================================================*/ package org.tensorflow; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Init; -import org.tensorflow.op.core.Placeholder; -import org.tensorflow.op.math.Add; -import org.tensorflow.op.math.Sub; -import org.tensorflow.types.TFloat32; public class SignatureTest { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java index 4da9bce9e90..9415a986222 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java @@ -30,9 +30,6 @@ import java.nio.IntBuffer; import java.nio.LongBuffer; import org.junit.jupiter.api.Test; -import org.tensorflow.op.Ops; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.DataBuffers; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.FloatNdArray; @@ -40,7 +37,11 @@ import org.tensorflow.ndarray.LongNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.op.Ops; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; @@ -71,14 +72,14 @@ public void createWithRawData() { // validate creating a tensor using a raw data byte buffers { - try (TBool t = Tensor.of(TBool.DTYPE, bools_shape, DataBuffers.of(bools_))) { + try (TBool t = Tensor.of(TBool.class, bools_shape, DataBuffers.of(bools_))) { boolean[] actual = new boolean[bools_.length]; t.read(DataBuffers.of(actual)); assertArrayEquals(bools, actual); } // note: the buffer is expected to contain raw TF_STRING (as per C API) - try (TString t = Tensor.of(TString.DTYPE, strings_shape, DataBuffers.of(strings_))) { + try (TString t = Tensor.of(TString.class, strings_shape, DataBuffers.of(strings_))) { assertEquals(strings, t.getObject()); } } @@ -95,7 +96,7 @@ public void createWithRawData() { } // validate shape checking - try (TBool t = Tensor.of(TBool.DTYPE, Shape.of(bools_.length * 2), DataBuffers.of(bools_))) { + try (TBool t = Tensor.of(TBool.class, Shape.of(bools_.length * 2), DataBuffers.of(bools_))) { fail("should have failed on incompatible buffer"); } catch (IllegalArgumentException e) { // expected @@ -267,44 +268,51 @@ public void readFromRawData() { @Test public void scalars() { try (TFloat32 t = TFloat32.scalarOf(2.718f)) { - assertEquals(TFloat32.DTYPE, t.dataType()); + assertEquals(TFloat32.class, t.type()); + assertEquals(DataType.DT_FLOAT, t.dataType()); assertEquals(0, t.shape().numDimensions()); assertEquals(2.718f, t.getFloat(), EPSILON_F); } try (TFloat64 t = TFloat64.scalarOf(3.1415)) { - assertEquals(TFloat64.DTYPE, t.dataType()); + assertEquals(TFloat64.class, t.type()); + assertEquals(DataType.DT_DOUBLE, t.dataType()); assertEquals(0, t.shape().numDimensions()); assertEquals(3.1415, t.getDouble(), EPSILON); } try (TInt32 t = TInt32.scalarOf(-33)) { - assertEquals(TInt32.DTYPE, t.dataType()); + assertEquals(TInt32.class, t.type()); + assertEquals(DataType.DT_INT32, t.dataType()); assertEquals(0, t.shape().numDimensions()); assertEquals(-33, t.getInt()); } try (TInt64 t = TInt64.scalarOf(8589934592L)) { - assertEquals(TInt64.DTYPE, t.dataType()); + assertEquals(TInt64.class, t.type()); + assertEquals(DataType.DT_INT64, t.dataType()); assertEquals(0, t.shape().numDimensions()); assertEquals(8589934592L, t.getLong()); } try (TBool t = TBool.scalarOf(true)) { - assertEquals(TBool.DTYPE, t.dataType()); + assertEquals(TBool.class, t.type()); + assertEquals(DataType.DT_BOOL, t.dataType()); assertEquals(0, t.shape().numDimensions()); assertTrue(t.getBoolean()); } try (TString t = TString.scalarOf("sombrero")) { - assertEquals(TString.DTYPE, t.dataType()); + assertEquals(TString.class, t.type()); + assertEquals(DataType.DT_STRING, t.dataType()); assertEquals(0, t.shape().numDimensions()); assertEquals("sombrero", t.getObject()); } final byte[] bytes = {1, 2, 3, 4}; try (TString t = TString.tensorOfBytes(NdArrays.scalarOfObject(bytes))) { - assertEquals(TString.DTYPE, t.dataType()); + assertEquals(TString.class, t.type()); + assertEquals(DataType.DT_STRING, t.dataType()); assertEquals(0, t.shape().numDimensions()); assertArrayEquals(bytes, t.asBytes().getObject()); } @@ -314,7 +322,8 @@ public void scalars() { public void nDimensional() { DoubleNdArray vector = StdArrays.ndCopyOf(new double[]{1.414, 2.718, 3.1415}); try (TFloat64 t = TFloat64.tensorOf(vector)) { - assertEquals(TFloat64.DTYPE, t.dataType()); + assertEquals(TFloat64.class, t.type()); + assertEquals(DataType.DT_DOUBLE, t.dataType()); assertEquals(1, t.shape().numDimensions()); assertEquals(3, t.shape().size(0)); assertEquals(vector, t); @@ -322,7 +331,8 @@ public void nDimensional() { IntNdArray matrix = StdArrays.ndCopyOf(new int[][]{{1, 2, 3}, {4, 5, 6}}); try (TInt32 t = TInt32.tensorOf(matrix)) { - assertEquals(TInt32.DTYPE, t.dataType()); + assertEquals(TInt32.class, t.type()); + assertEquals(DataType.DT_INT32, t.dataType()); assertEquals(2, t.shape().numDimensions()); assertEquals(2, t.shape().size(0)); assertEquals(3, t.shape().size(1)); @@ -333,7 +343,8 @@ public void nDimensional() { {{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}}, }); try (TInt64 t = TInt64.tensorOf(threeD)) { - assertEquals(TInt64.DTYPE, t.dataType()); + assertEquals(TInt64.class, t.type()); + assertEquals(DataType.DT_INT64, t.dataType()); assertEquals(3, t.shape().numDimensions()); assertEquals(2, t.shape().size(0)); assertEquals(5, t.shape().size(1)); @@ -347,7 +358,8 @@ public void nDimensional() { {{{false, true, false, true}, {false, true, true, false}}}, }); try (TBool t = TBool.tensorOf(fourD)) { - assertEquals(TBool.DTYPE, t.dataType()); + assertEquals(TBool.class, t.type()); + assertEquals(DataType.DT_BOOL, t.dataType()); assertEquals(4, t.shape().numDimensions()); assertEquals(3, t.shape().size(0)); assertEquals(1, t.shape().size(1)); @@ -366,7 +378,8 @@ public void testNDimensionalStringTensor() { } } try (TString t = TString.tensorOf(matrix)) { - assertEquals(TString.DTYPE, t.dataType()); + assertEquals(TString.class, t.type()); + assertEquals(DataType.DT_STRING, t.dataType()); assertEquals(2, t.shape().numDimensions()); assertEquals(4, t.shape().size(0)); assertEquals(3, t.shape().size(1)); @@ -376,7 +389,8 @@ public void testNDimensionalStringTensor() { NdArray byteMatrix = NdArrays.ofObjects(byte[].class, matrix.shape()); matrix.scalars().forEachIndexed((i, s) -> byteMatrix.setObject(s.getObject().getBytes(UTF_8), i)); try (TString t = TString.tensorOfBytes(byteMatrix)) { - assertEquals(TString.DTYPE, t.dataType()); + assertEquals(TString.class, t.type()); + assertEquals(DataType.DT_STRING, t.dataType()); assertEquals(2, t.shape().numDimensions()); assertEquals(4, t.shape().size(0)); assertEquals(3, t.shape().size(1)); @@ -389,7 +403,8 @@ public void testNDimensionalStringTensor() { public void testUint8TensorFromArray() { byte[] vector = new byte[] {1, 2, 3, 4}; try (TUint8 t = TUint8.vectorOf(vector)) { - assertEquals(TUint8.DTYPE, t.dataType()); + assertEquals(TUint8.class, t.type()); + assertEquals(DataType.DT_UINT8, t.dataType()); assertEquals(1, t.shape().numDimensions()); assertEquals(4, t.shape().size(0)); @@ -403,7 +418,8 @@ public void testUint8TensorFromArray() { public void testCreateFromArrayOfBoxed() { Integer[] vector = new Integer[] {1, 2, 3, 4}; try (TInt32 t = TInt32.tensorOf(Shape.of(4), d -> d.write(DataBuffers.ofObjects(vector)))) { - assertEquals(TInt32.DTYPE, t.dataType()); + assertEquals(TInt32.class, t.type()); + assertEquals(DataType.DT_INT32, t.dataType()); assertEquals(1, t.shape().numDimensions()); assertEquals(4, t.shape().size(0)); @@ -445,14 +461,14 @@ public void tensorWithZeroDimension() { @Test public void allocateTensorWithSize() { - try (TInt32 t = Tensor.of(TInt32.DTYPE, Shape.of(2, 2, 2), 8 * TInt32.DTYPE.byteSize())) { + try (TInt32 t = Tensor.of(TInt32.class, Shape.of(2, 2, 2), 8 * 4)) { // ok } - try (TInt32 t = Tensor.of(TInt32.DTYPE, Shape.of(2, 2, 2), 9 * TInt32.DTYPE.byteSize())) { + try (TInt32 t = Tensor.of(TInt32.class, Shape.of(2, 2, 2), 9 * 4)) { // ok (size requested is larger that minimum space required) } try { - Tensor.of(TInt32.DTYPE, Shape.of(2, 2, 2), 8 * TInt32.DTYPE.byteSize() - 1); + Tensor.of(TInt32.class, Shape.of(2, 2, 2), 8 * 4 - 1); fail(); } catch (IllegalArgumentException e) { // as expected @@ -499,6 +515,7 @@ public void fromHandle() { final FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{{1, 2, 3}, {4, 5, 6}}); try (TFloat32 src = TFloat32.tensorOf(matrix)) { TFloat32 cpy = (TFloat32)RawTensor.fromHandle(src.asRawTensor().nativeHandle()).asTypedTensor(); + assertEquals(src.type(), cpy.type()); assertEquals(src.dataType(), cpy.dataType()); assertEquals(src.shape().numDimensions(), cpy.shape().numDimensions()); assertEquals(src.shape(), cpy.shape()); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java index c97dc83f510..62881dcee8c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java @@ -23,7 +23,6 @@ import org.tensorflow.Graph; import org.tensorflow.Output; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TType; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java index ede179740e5..b1ebd469eb3 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java @@ -22,10 +22,9 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.ndarray.Shape; import org.tensorflow.types.TInt32; public final class GeneratedOperationsTest { @@ -70,7 +69,7 @@ public void testControlDependencies() { try (Graph g = new Graph(); Session sess = new Session(g)) { Ops ops = Ops.create(g); - Operand variable = ops.variable(Shape.scalar(), TInt32.DTYPE); + Operand variable = ops.variable(Shape.scalar(), TInt32.class); Operand initVariable = ops.assign(variable, ops.constant(0)); ArrayList controls = new ArrayList<>(); controls.add(ops.assign(variable, ops.constant(3))); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java index 6bab99095e4..80150b64bb6 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -37,7 +37,7 @@ public void createGradients() { Session sess = new Session(g)) { Ops tf = Ops.create(g); - Output x = tf.placeholder(TFloat32.DTYPE).output(); + Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); @@ -64,7 +64,7 @@ public void createGradientsWithSum() { Session sess = new Session(g)) { Ops tf = Ops.create(g); - Output x = tf.placeholder(TFloat32.DTYPE).output(); + Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); @@ -89,7 +89,7 @@ public void createGradientsWithInitialValues() { Session sess = new Session(g)) { Ops tf = Ops.create(g); - Output x = tf.placeholder(TFloat32.DTYPE).output(); + Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); @@ -115,7 +115,7 @@ public void validateGradientsNames() { try (Graph g = new Graph()) { Ops tf = Ops.create(g).withSubScope("sub"); - Output x = tf.placeholder(TFloat32.DTYPE).output(); + Output x = tf.placeholder(TFloat32.class).output(); Output y = tf.math.square(x).y(); Gradients grad0 = Gradients.create(tf.scope(), y, Arrays.asList(x)); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java index 083beca923c..39c04c942af 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java @@ -14,18 +14,14 @@ ==============================================================================*/ package org.tensorflow.op.core; -import java.util.concurrent.atomic.AtomicInteger; - import static org.junit.jupiter.api.Assertions.assertEquals; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; - -import org.junit.jupiter.api.TestTemplate; -import org.tensorflow.Graph; import org.tensorflow.EagerSession; +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.op.Scope; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -40,11 +36,11 @@ public void testFlatten_Operand() { Session session = new Session(g)) { Scope scope = new Scope(g); Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Shape expResult = Shape.create(scope, operand, TInt64.DTYPE); + Shape expResult = Shape.create(scope, operand, TInt64.class); Operand reshaped = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); Operand actual = Shapes.flatten(scope, reshaped); - Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); + Shape tfshape = Shape.create(scope, actual, TInt64.class); AtomicInteger index = new AtomicInteger(); try (TInt64 result1 = (TInt64)session.runner().fetch(tfshape.asOutput()).run().get(0); @@ -63,11 +59,11 @@ public void testFlatten_Shape() { try (EagerSession session = EagerSession.create()) { Scope scope = new Scope(session); Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Shape expShape = Shape.create(scope, operand, TInt64.DTYPE); + Shape expShape = Shape.create(scope, operand, TInt64.class); Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); - Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); - Operand flattened = Shapes.flatten(scope, tfshape, TInt64.DTYPE); + Shape tfshape = Shape.create(scope, actual, TInt64.class); + Operand flattened = Shapes.flatten(scope, tfshape, TInt64.class); AtomicInteger index = new AtomicInteger(); flattened @@ -89,8 +85,8 @@ public void testSize_Shape() { Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); - Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); - Operand size = Shapes.size(scope, tfshape, TInt64.DTYPE); + Shape tfshape = Shape.create(scope, actual, TInt64.class); + Operand size = Shapes.size(scope, tfshape, TInt64.class); AtomicInteger index = new AtomicInteger(); try (TInt64 result1 = (TInt64)session.runner().fetch(size.asOutput()).run().get(0)) { @@ -405,7 +401,7 @@ public void testPrependLong() { Scope scope = new Scope(g); Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); - Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); + Shape tfshape = Shape.create(scope, actual, TInt64.class); Operand prepend = Shapes.prepend(scope, tfshape, 1L); AtomicInteger index = new AtomicInteger(); @@ -462,8 +458,8 @@ public void testPrependShapeTInt64() { Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); Operand actual2 = Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); - Shape tfshape1 = Shape.create(scope, actual1, TInt64.DTYPE); - Shape tfshape2 = Shape.create(scope, actual2, TInt64.DTYPE); + Shape tfshape1 = Shape.create(scope, actual1, TInt64.class); + Shape tfshape2 = Shape.create(scope, actual2, TInt64.class); Operand prepend = Shapes.prepend(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); @@ -487,7 +483,7 @@ public void testAppendLong() { Scope scope = new Scope(g); Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); - Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); + Shape tfshape = Shape.create(scope, actual, TInt64.class); Operand append = Shapes.append(scope, tfshape, 2L); AtomicInteger index = new AtomicInteger(); @@ -568,8 +564,8 @@ public void testAppendShapeTInt64() { Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); Operand actual2 = Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); - Shape tfshape1 = Shape.create(scope, actual1, TInt64.DTYPE); - Shape tfshape2 = Shape.create(scope, actual2, TInt64.DTYPE); + Shape tfshape1 = Shape.create(scope, actual1, TInt64.class); + Shape tfshape2 = Shape.create(scope, actual2, TInt64.class); Operand append = Shapes.append(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java index 204bd4b10f3..4121baf3af1 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -23,7 +23,6 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.op.Scope; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; @@ -41,7 +40,7 @@ public void createIntZeros() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TInt32.DTYPE); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TInt32.class); try (TInt32 result = (TInt32)sess.runner().fetch(op).run().get(0)) { result.scalars().forEach(s -> assertEquals(0, s.getInt())); } @@ -54,7 +53,7 @@ public void createFloatZeros() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.DTYPE); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.class); try (TFloat32 result = (TFloat32)sess.runner().fetch(op.asOutput()).run().get(0)) { result.scalars().forEach(s -> assertEquals(0.0f, s.getFloat(), 0)); } @@ -67,7 +66,7 @@ public void createDoubleZeros() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat64.DTYPE); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat64.class); try (TFloat64 result = (TFloat64)sess.runner().fetch(op.asOutput()).run().get(0)) { result.scalars().forEach(s -> assertEquals(0.0f, s.getDouble(), 0)); } @@ -80,7 +79,7 @@ public void createLongZeros() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TInt64.DTYPE); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TInt64.class); try (TInt64 result = (TInt64)sess.runner().fetch(op.asOutput()).run().get(0)) { result.scalars().forEach(s -> assertEquals(0L, s.getLong())); } @@ -93,7 +92,7 @@ public void createBooleanZeros() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TBool.DTYPE); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TBool.class); try (TBool result = (TBool)sess.runner().fetch(op.asOutput()).run().get(0)) { result.scalars().forEach(s -> assertFalse(s.getBoolean())); } @@ -106,7 +105,7 @@ public void createUint8Zeros() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TUint8.DTYPE); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TUint8.class); try (TUint8 result = (TUint8)sess.runner().fetch(op.asOutput()).run().get(0)) { result.scalars().forEach(s -> assertEquals(0, s.getByte())); } @@ -119,7 +118,7 @@ public void createStringZeros() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TString.DTYPE); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TString.class); try (TString result = (TString)sess.runner().fetch(op.asOutput()).run().get(0)) { result.scalars().forEach(s -> assertTrue(s.getObject().isEmpty())); } @@ -132,7 +131,7 @@ public void operationsComposingZerosAreCorrectlyNamed() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.DTYPE); + Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.class); List results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java index 9ea7f952f04..a2ab28b6219 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java @@ -17,8 +17,6 @@ package org.tensorflow.types; -import static org.junit.jupiter.api.Assertions.assertEquals; - import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java index ae3d7e8c896..2f2f16f2752 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TBool; @@ -89,9 +88,9 @@ public Operand call(Operand input) { Operand result = tf.nn.elu(input); if (alpha == 1.0) return result; else { - DataType dataType = input.asOutput().dataType(); - Operand y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), dataType)); - Operand cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), dataType)); + Class inputType = input.type(); + Operand y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), inputType)); + Operand cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), inputType)); return tf.select(cond, result, y); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java index a486cbdc601..0b7cf573b8e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TFloating; @@ -63,12 +62,12 @@ public HardSigmoid(Ops tf) { */ @Override public Operand call(Operand input) { - DataType dataType = input.asOutput().dataType(); - Operand point2 = tf.dtypes.cast(tf.constant(0.2), dataType); - Operand point5 = tf.dtypes.cast(tf.constant(0.5), dataType); + Class inputType = input.type(); + Operand point2 = tf.dtypes.cast(tf.constant(0.2), inputType); + Operand point5 = tf.dtypes.cast(tf.constant(0.5), inputType); Operand x = tf.math.add(tf.math.mul(input, point2), point5); return tf.clipByValue( - x, tf.dtypes.cast(tf.constant(0), dataType), tf.dtypes.cast(tf.constant(1), dataType)); + x, tf.dtypes.cast(tf.constant(0), inputType), tf.dtypes.cast(tf.constant(1), inputType)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java index c24cf71077d..aef6ebf2992 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.activations; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.op.math.Greater; @@ -98,8 +97,7 @@ public ReLU(Ops tf, float alpha, float maxValue, float threshold) { /** {@inheritDoc} */ @Override public Operand call(Operand input) { - - DataType dataType = input.asOutput().dataType(); + Class inputType = input.type(); boolean clipMax = !Float.isNaN(maxValue); Operand negativePart = null; @@ -110,7 +108,7 @@ public Operand call(Operand input) { if (threshold != 0) { negativePart = tf.nn.relu( - tf.math.add(tf.math.neg(input), tf.dtypes.cast(tf.constant(threshold), dataType))); + tf.math.add(tf.math.neg(input), tf.dtypes.cast(tf.constant(threshold), inputType))); } else { negativePart = tf.nn.relu(tf.math.neg(input)); } @@ -119,8 +117,8 @@ public Operand call(Operand input) { Operand lInput; if (threshold != 0) { // computes input for input > threshold else 0 - Greater greater = tf.math.greater(input, tf.dtypes.cast(tf.constant(threshold), dataType)); - lInput = tf.math.mul(input, tf.dtypes.cast(greater, dataType)); + Greater greater = tf.math.greater(input, tf.dtypes.cast(tf.constant(threshold), inputType)); + lInput = tf.math.mul(input, tf.dtypes.cast(greater, inputType)); } else if (maxValue == 6) { // if no threshold, then can use nn.relu6 native TF op for performance lInput = tf.nn.relu6(input); @@ -129,15 +127,15 @@ public Operand call(Operand input) { lInput = tf.nn.relu(input); } if (clipMax) { - Operand lmaxValue = tf.dtypes.cast(tf.constant(maxValue), dataType); - Operand zero = tf.dtypes.cast(tf.constant(0), dataType); + Operand lmaxValue = tf.dtypes.cast(tf.constant(maxValue), inputType); + Operand zero = tf.dtypes.cast(tf.constant(0), inputType); lInput = tf.clipByValue(lInput, zero, lmaxValue); } if (alpha != 0.) { lInput = tf.math.sub( - lInput, tf.math.mul(tf.dtypes.cast(tf.constant(alpha), dataType), negativePart)); + lInput, tf.math.mul(tf.dtypes.cast(tf.constant(alpha), inputType), negativePart)); } return lInput; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java index 007bcb01a40..7ac73f616e2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.data; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.data.impl.BatchDataset; import org.tensorflow.framework.data.impl.MapDataset; @@ -33,6 +32,7 @@ import java.util.Iterator; import java.util.List; import java.util.function.Function; +import org.tensorflow.types.family.TType; /** * Represents a potentially large list of independent elements (samples), and allows iteration and @@ -41,11 +41,11 @@ public abstract class Dataset implements Iterable>> { protected Ops tf; private Operand variant; - private List> outputTypes; + private List> outputTypes; private List outputShapes; public Dataset( - Ops tf, Operand variant, List> outputTypes, List outputShapes) { + Ops tf, Operand variant, List> outputTypes, List outputShapes) { if (tf == null) { throw new IllegalArgumentException("Ops accessor cannot be null."); } @@ -261,12 +261,12 @@ public DatasetIterator makeOneShotIterator() { * @param tf Ops Accessor * @param tensors A list of {@code Operand} representing components of this dataset (e.g. * features, labels) - * @param outputTypes A list of `DataType` objects representing the data type of each component of + * @param outputTypes A list of tensor type classes representing the data type of each component of * this dataset. * @return A new `Dataset` */ public static Dataset fromTensorSlices( - Ops tf, List> tensors, List> outputTypes) { + Ops tf, List> tensors, List> outputTypes) { return new TensorSliceDataset(tf, tensors, outputTypes); } @@ -288,7 +288,7 @@ public Operand getVariant() { } /** Get a list of output types for each component of this dataset. */ - public List> getOutputTypes() { + public List> getOutputTypes() { return this.outputTypes; } @@ -305,7 +305,7 @@ public Ops getOpsInstance() { public String toString() { return "Dataset{" + "outputTypes=" - + Arrays.toString(getOutputTypes().stream().map(DataType::name).toArray()) + + Arrays.toString(getOutputTypes().stream().map(Class::getSimpleName).toArray()) + ", outputShapes=" + Arrays.toString(getOutputShapes().stream().map(Shape::toString).toArray()) + "}"; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java index 0bad6a41214..a3aa290a8c8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.data; -import org.tensorflow.DataType; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.op.Op; @@ -25,6 +24,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import org.tensorflow.types.family.TType; /** * Represents the state of an iteration through a tf.data Datset. DatasetIterator is not a @@ -106,7 +106,7 @@ public class DatasetIterator implements Iterable>> { private Operand iteratorResource; private Op initializer; - protected List> outputTypes; + protected List> outputTypes; protected List outputShapes; /** @@ -115,7 +115,7 @@ public class DatasetIterator implements Iterable>> { * @param iteratorResource An Operand representing the iterator (e.g. constructed from * `tf.data.iterator` or `tf.data.anonymousIterator`) * @param initializer An `Op` that should be run to initialize this iterator - * @param outputTypes A list of `DataType` objects corresponding to the types of each component of + * @param outputTypes A list of classes corresponding to the tensor type of each component of * a dataset element. * @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of * a dataset element. @@ -124,7 +124,7 @@ public DatasetIterator( Ops tf, Operand iteratorResource, Op initializer, - List> outputTypes, + List> outputTypes, List outputShapes) { this.tf = tf; @@ -137,7 +137,7 @@ public DatasetIterator( public DatasetIterator( Ops tf, Operand iteratorResource, - List> outputTypes, + List> outputTypes, List outputShapes) { this.tf = tf; this.iteratorResource = iteratorResource; @@ -229,14 +229,14 @@ public Op makeInitializer(Dataset dataset) { * Creates a new iterator from a "structure" defined by `outputShapes` and `outputTypes`. * * @param tf Ops accessor - * @param outputTypes A list of `DataType` objects repesenting the types of each component of a + * @param outputTypes A list of classes repesenting the tensor type of each component of a * dataset element. * @param outputShapes A list of Shape objects representing the shape of each component of a * dataset element. * @return A new DatasetIterator */ public static DatasetIterator fromStructure( - Ops tf, List> outputTypes, List outputShapes) { + Ops tf, List> outputTypes, List outputShapes) { Operand iteratorResource = tf.scope().env() instanceof Graph ? tf.data.iterator(EMPTY_SHARED_NAME, "", outputTypes, outputShapes) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java index 925252c7298..6617c33eaf7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.data; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.ndarray.Shape; @@ -23,6 +22,7 @@ import java.util.ArrayList; import java.util.List; +import org.tensorflow.types.family.TType; /** * An optional represents the result of a dataset getNext operation that may fail, when the end of @@ -36,11 +36,11 @@ public Operand getOptionalVariant() { } private Operand optionalVariant; - private List> outputTypes; + private List> outputTypes; private List outputShapes; public DatasetOptional( - Ops tf, Operand optionalVariant, List> outputTypes, List outputShapes) { + Ops tf, Operand optionalVariant, List> outputTypes, List outputShapes) { this.tf = tf; this.optionalVariant = optionalVariant; this.outputTypes = outputTypes; @@ -75,7 +75,7 @@ public List> getValue() { public static DatasetOptional fromComponents( Ops tf, List> components, - List> outputTypes, + List> outputTypes, List outputShapes) { Operand optionalVariant = tf.data.optionalFromValue(components); return new DatasetOptional(tf, optionalVariant, outputTypes, outputShapes); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/BatchDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/BatchDataset.java index 277b049cf6f..f0561b2e61e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/BatchDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/BatchDataset.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.data.impl; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.data.Dataset; import org.tensorflow.op.Ops; @@ -25,6 +24,7 @@ import org.tensorflow.types.TInt64; import java.util.List; +import org.tensorflow.types.family.TType; public class BatchDataset extends Dataset { public BatchDataset( @@ -32,7 +32,7 @@ public BatchDataset( Operand variant, Constant batchSize, Constant dropRemainder, - List> outputTypes, + List> outputTypes, List outputShapes) { super( tf, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/SkipDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/SkipDataset.java index 6731bac60b3..63b4208480b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/SkipDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/SkipDataset.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.data.impl; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.data.Dataset; import org.tensorflow.op.Ops; @@ -24,6 +23,7 @@ import org.tensorflow.types.TInt64; import java.util.List; +import org.tensorflow.types.family.TType; public class SkipDataset extends Dataset { @@ -31,7 +31,7 @@ public SkipDataset( Ops tf, Operand variant, Constant count, - List> outputTypes, + List> outputTypes, List outputShapes) { super( tf, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TFRecordDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TFRecordDataset.java index ed721b13ebf..00297152e90 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TFRecordDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TFRecordDataset.java @@ -34,7 +34,7 @@ public TFRecordDataset( super( tf, tf.data.tfRecordDataset(filenames, compressionType, bufferSize), - Collections.singletonList(TString.DTYPE), + Collections.singletonList(TString.class), Collections.singletonList(Shape.scalar())); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TakeDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TakeDataset.java index 08c57d44a73..39ca9759e74 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TakeDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TakeDataset.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.data.impl; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.data.Dataset; import org.tensorflow.op.Ops; @@ -24,6 +23,7 @@ import org.tensorflow.types.TInt64; import java.util.List; +import org.tensorflow.types.family.TType; public class TakeDataset extends Dataset { @@ -31,7 +31,7 @@ public TakeDataset( Ops tf, Operand variant, Constant count, - List> outputTypes, + List> outputTypes, List outputShapes) { super( tf, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java index 495014f1753..46639ea2aad 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.data.impl; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.data.Dataset; import org.tensorflow.op.Ops; @@ -23,10 +22,11 @@ import java.util.List; import java.util.stream.Collectors; +import org.tensorflow.types.family.TType; public class TensorSliceDataset extends Dataset { - public TensorSliceDataset(Ops tf, List> components, List> outputTypes) { + public TensorSliceDataset(Ops tf, List> components, List> outputTypes) { super(tf, makeVariant(tf, components, outputTypes), outputTypes, outputShapes(components)); } @@ -35,7 +35,7 @@ private static List outputShapes(List> components) { } private static Operand makeVariant( - Ops tf, List> components, List> outputTypes) { + Ops tf, List> components, List> outputTypes) { if (!(components.size() == outputTypes.size())) { throw new IllegalArgumentException( "Lists `tensors` and `dtypes` must have the same number of elements."); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TextLineDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TextLineDataset.java index 4ef47825211..c9a26304778 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TextLineDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TextLineDataset.java @@ -34,7 +34,7 @@ public TextLineDataset( super( tf, tf.data.textLineDataset(filenames, compressionType, bufferSize), - Collections.singletonList(TString.DTYPE), + Collections.singletonList(TString.class), Collections.singletonList(Shape.scalar())); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java index b4544de9bd0..4a2df86d74b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java @@ -14,10 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; /** @@ -29,7 +30,7 @@ * Constant<TFloat32> initializer = * new org.tensorflow.framework.initializers.Constant<>(tf, 3f); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); *

* * @param The Type for the call operation @@ -85,17 +86,17 @@ public Constant(Ops tf, boolean value) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!(dtype.isNumeric() || dtype.isBoolean())) { - throw new IllegalArgumentException("DataType must be numeric or boolean: " + dtype.name()); + public Operand call(Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { + throw new IllegalArgumentException("Tensor type must be numeric or boolean: " + type.getSimpleName()); } switch (valueType) { case LONG: - return tf.fill(dims, tf.dtypes.cast(tf.constant(longValue), dtype)); + return tf.fill(dims, tf.dtypes.cast(tf.constant(longValue), type)); case DOUBLE: - return tf.fill(dims, tf.dtypes.cast(tf.constant(doubleValue), dtype)); + return tf.fill(dims, tf.dtypes.cast(tf.constant(doubleValue), type)); default: - return tf.fill(dims, tf.dtypes.cast(tf.constant(booleanValue), dtype)); + return tf.fill(dims, tf.dtypes.cast(tf.constant(booleanValue), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java index 3d5d37b91d3..290e4e80b57 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java @@ -16,8 +16,7 @@ package org.tensorflow.framework.initializers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; +import org.tensorflow.types.family.TFloating; /** * The Glorot initializer, also called Xavier initializer. @@ -44,7 +43,7 @@ * new org.tensorflow.framework.initializers.Glorot<>(tf, * Distribution.TRUNCATED_NORMAL, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * *

Glorot Uniform: @@ -55,7 +54,7 @@ * new org.tensorflow.framework.initializers.Glorot<>(tf, * Distribution.UNIFORM, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * *

NOTE: @@ -66,11 +65,10 @@ *

* * @param The TType for the call operation - * @param The TNumber for the call operation * @see VarianceScaling.Distribution * @see Glorot et al., 2010 */ -public class Glorot extends VarianceScaling { +public class Glorot extends VarianceScaling { public static final double SCALE = 1.0; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java index ce99da80bf7..9b1a0887af0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java @@ -15,8 +15,7 @@ package org.tensorflow.framework.initializers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; +import org.tensorflow.types.family.TFloating; /** * He initializer. @@ -39,7 +38,7 @@ * new org.tensorflow.framework.initializers.He<>(tf, * Distribution.TRUNCATED_NORMAL, seed);); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * *

He Uniform: @@ -50,7 +49,7 @@ * new org.tensorflow.framework.initializers.He<>(tf, * Distribution.UNIFORM, seed);); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * *

NOTE: @@ -61,12 +60,11 @@ *

* * @param The TType for the call operation - * @param The TNumber for the call operation * @see He * et al., 2015 */ -public class He extends VarianceScaling { +public class He extends VarianceScaling { public static final double SCALE = 2.0; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java index 34e6cd790f4..f672c9f1e85 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java @@ -14,13 +14,12 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.utils.ShapeUtils; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TType; +import org.tensorflow.types.family.TFloating; /** * Initializer that generates the identity matrix. @@ -33,12 +32,12 @@ * Identity<TFloat32> initializer = * new org.tensorflow.framework.initializers.Identity<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation */ -public class Identity extends BaseInitializer { +public class Identity extends BaseInitializer { public static final double GAIN_DEFAULT = 1.0; private final double gain; @@ -66,10 +65,7 @@ public Identity(Ops tf, double gain) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!dtype.isFloating()) { - throw new IllegalArgumentException("DataType must be a float type: " + dtype.name()); - } + public Operand call(Operand dims, Class type) { Shape shape = ShapeUtils.toShape(tf.scope(), dims); if (shape.numDimensions() != 2) { throw new IllegalArgumentException("2D matrix required, got " + shape.numDimensions()); @@ -79,9 +75,9 @@ public Operand call(Operand dims, DataType dtype) { Shape diagShape = Shape.of(diagSize); Operand op; - Operand zero = tf.dtypes.cast(tf.constant(0), dtype); + Operand zero = tf.dtypes.cast(tf.constant(0), type); Operand diagOnes = - tf.fill(tf.constant(diagShape.asArray()), tf.dtypes.cast(tf.constant(1.0), dtype)); + tf.fill(tf.constant(diagShape.asArray()), tf.dtypes.cast(tf.constant(1.0), type)); if (isSquare) { op = tf.linalg.matrixDiag( @@ -91,10 +87,10 @@ public Operand call(Operand dims, DataType dtype) { tf.constant((int) shape.size(1)), zero); } else { - Operand zeroMatrix = tf.zeros(dims, dtype); + Operand zeroMatrix = tf.zeros(dims, type); op = tf.linalg.matrixSetDiag(zeroMatrix, diagOnes, tf.constant(0)); } - return tf.math.mul(op, tf.dtypes.cast(tf.constant(gain), dtype)); + return tf.math.mul(op, tf.dtypes.cast(tf.constant(gain), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java index 59dce1fc02e..4beb218783b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -30,8 +29,8 @@ public interface Initializer { * Generates the operation used to perform the initialization. * * @param dims the shape dimensions - * @param dtype the data type + * @param type the type of tensor * @return An operand for the initialization. */ - Operand call(Operand dims, DataType dtype); + Operand call(Operand dims, Class type); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java index e2268412fc3..38e68ef688b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java @@ -15,8 +15,7 @@ package org.tensorflow.framework.initializers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; +import org.tensorflow.types.family.TFloating; /** * LeCun normal initializer. @@ -42,7 +41,7 @@ * new org.tensorflow.framework.initializers.LeCunNormal<>(tf, * Distribution.TRUNCATED_NORMAL, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * *

LeCun Uniform: @@ -53,7 +52,7 @@ * new org.tensorflow.framework.initializers.LeCunNormal<>(tf, * Distribution.UNIFORM, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * * * @@ -68,7 +67,6 @@ *

* * @param The TType for the call operation - * @param The TNumber for the call operation * @see Self-Normalizing * Neural Networks, Klambauer et al., 2017 @@ -76,7 +74,7 @@ * al., 1998 * @see VarianceScaling.Distribution */ -public class LeCun extends VarianceScaling { +public class LeCun extends VarianceScaling { /** * Creates a LeCunNormal Initializer diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java index b78f34e3d35..b8eb0c418e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java @@ -14,10 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; /** @@ -29,7 +30,7 @@ * Ones<TFloat32> initializer = * new org.tensorflow.framework.initializers.Ones<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -45,7 +46,7 @@ public class Ones extends BaseInitializer { * Ones<TFloat32> initializer = * new org.tensorflow.framework.initializers.Ones<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param tf the TensorFlow Ops @@ -56,10 +57,10 @@ public Ones(Ops tf) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!(dtype.isNumeric() || dtype.isBoolean())) { - throw new IllegalArgumentException("DataType must be numeric or boolean: " + dtype.name()); + public Operand call(Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { + throw new IllegalArgumentException("Tensor type must be numeric or boolean: " + type.getSimpleName()); } - return tf.fill(dims, tf.dtypes.cast(tf.constant(1.0), dtype)); + return tf.fill(dims, tf.dtypes.cast(tf.constant(1.0), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java index 48e2c56d5be..a5b466e118e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.framework.utils.ShapeUtils; @@ -22,8 +21,7 @@ import org.tensorflow.op.Ops; import org.tensorflow.op.linalg.Qr; import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; +import org.tensorflow.types.family.TFloating; /** * Initializer that generates an orthogonal matrix. @@ -44,13 +42,12 @@ * Orthogonal<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.Orthogonal<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation - * @param The TNumber for the call operation */ -public class Orthogonal extends BaseInitializer { +public class Orthogonal extends BaseInitializer { public static final double GAIN_DEFAULT = 1.0; @@ -84,10 +81,7 @@ public Orthogonal(Ops tf, double gain, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!dtype.isFloating()) { - throw new IllegalArgumentException("Expected floating point type, got " + dtype.name()); - } + public Operand call(Operand dims, Class type) { Shape dimsShape = ShapeUtils.toShape(tf.scope(), dims); if (dimsShape.numDimensions() < 2) { throw new IllegalArgumentException( @@ -100,22 +94,17 @@ public Operand call(Operand dims, DataType dtype) { long numCols = dimsShape.size(i); Shape flatShape = Shape.of(Math.max(numRows, numCols), Math.min(numRows, numCols)); long[] seeds = {seed, 0}; - @SuppressWarnings("unchecked") - DataType numdType = (DataType) dtype; - @SuppressWarnings("unchecked") Operand op = - (Operand) - tf.random.statelessRandomNormal(tf.constant(flatShape), tf.constant(seeds), numdType); - + tf.random.statelessRandomNormal(tf.constant(flatShape), tf.constant(seeds), type); Qr.Options qrOptions = Qr.fullMatrices(false); Qr qrOp = tf.linalg.qr(op, qrOptions); Output qo = qrOp.q(); Output ro = qrOp.r(); Operand diagOp = - tf.linalg.matrixDiagPart(ro, tf.constant(0), tf.dtypes.cast(tf.constant(0), dtype)); + tf.linalg.matrixDiagPart(ro, tf.constant(0), tf.dtypes.cast(tf.constant(0), type)); Operand qop = tf.math.mul(qo, tf.math.sign(diagOp)); if (numRows < numCols) qop = tf.linalg.transpose(qop, null); - return tf.math.mul(qop, tf.dtypes.cast(tf.constant(this.gain), dtype)); + return tf.math.mul(qop, tf.dtypes.cast(tf.constant(this.gain), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java index f2d8a0d8e6e..38ab194a56b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java @@ -14,12 +14,10 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; +import org.tensorflow.types.family.TFloating; /** * Initializer that generates tensors with a normal distribution. @@ -31,13 +29,13 @@ * RandomNormal<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.RandomNormal<>(tf, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation - * @param The TNumber for the call operation */ -public class RandomNormal extends BaseInitializer { +public class RandomNormal extends BaseInitializer { + public static final double MEAN_DEFAULT = 0.0; public static final double STDDEV_DEFAULT = 1.0; @@ -87,16 +85,10 @@ public RandomNormal(Ops tf, double mean, double stddev, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!dtype.isNumeric()) - throw new IllegalArgumentException("The data type must be numeric. Found : " + dtype.name()); + public Operand call(Operand dims, Class type) { long[] seeds = {seed, 0}; - @SuppressWarnings("unchecked") - DataType numdType = (DataType) dtype; - @SuppressWarnings("unchecked") - Operand distOp = - (Operand) tf.random.statelessRandomNormal(dims, tf.constant(seeds), numdType); - Operand op = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.stddev), dtype)); - return tf.math.add(op, tf.dtypes.cast(tf.constant(mean), dtype)); + Operand distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), type); + Operand op = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.stddev), type)); + return tf.math.add(op, tf.dtypes.cast(tf.constant(mean), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java index b665729675d..787af15f709 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java @@ -14,13 +14,12 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.op.random.RandomUniformInt; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Initializer that generates tensors with a uniform distribution. @@ -32,13 +31,12 @@ * RandomUniform<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.RandomUniform<>(tf, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation - * @param The TNumber for the call operation */ -public class RandomUniform extends BaseInitializer { +public class RandomUniform extends BaseInitializer { public static final double MINVAL_DEFAULT = -0.05; public static final double MAXVAL_DEFAULT = 0.05; @@ -77,39 +75,28 @@ public RandomUniform(Ops tf, double minval, double maxval, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!dtype.isNumeric()) - throw new IllegalArgumentException("The data type must be numeric. Found : " + dtype.name()); - @SuppressWarnings("unchecked") - DataType numdType = (DataType) dtype; - Operand distOp; - - if (dtype.isInteger()) { + public Operand call(Operand dims, Class type) { + Operand distOp; + if (TIntegral.class.isAssignableFrom(type)) { RandomUniformInt.Options options = RandomUniformInt.seed(this.seed); distOp = tf.random.randomUniformInt( dims, - tf.dtypes.cast(tf.constant(this.minval), numdType), - tf.dtypes.cast(tf.constant(this.maxval), numdType), + tf.dtypes.cast(tf.constant(this.minval), type), + tf.dtypes.cast(tf.constant(this.maxval), type), options); - @SuppressWarnings("unchecked") - Operand distOpT = (Operand) distOp; - return distOpT; } else { long[] seeds = {seed, 0}; - distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), numdType); - @SuppressWarnings("unchecked") - Operand distOpT = (Operand) distOp; + distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), type); if (this.minval == 0) { if (this.maxval != 1.0) { - distOpT = tf.math.mul(distOpT, tf.dtypes.cast(tf.constant(this.maxval), dtype)); + distOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.maxval), type)); } } else { - distOpT = - tf.math.mul(distOpT, tf.dtypes.cast(tf.constant(this.maxval - this.minval), dtype)); - distOpT = tf.math.add(distOpT, tf.dtypes.cast(tf.constant(this.minval), dtype)); + distOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.maxval - this.minval), type)); + distOp = tf.math.add(distOp, tf.dtypes.cast(tf.constant(this.minval), type)); } - return distOpT; } + return distOp; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java index c71cf9a630e..d3cfec26338 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java @@ -14,12 +14,10 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; +import org.tensorflow.types.family.TFloating; /** * Initializer that generates a truncated normal distribution. @@ -31,13 +29,12 @@ * TruncatedNormal<TFloat32, TFloat32> initializer = * new org.tensorflow.framework.initializers.TruncatedNormal<>(tf, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation - * @param The TNumber for the call operation */ -public class TruncatedNormal extends BaseInitializer { +public class TruncatedNormal extends BaseInitializer { public static final double MEAN_DEFAULT = 0.0; public static final double STDDEV_DEFAULT = 0.05; @@ -76,17 +73,11 @@ public TruncatedNormal(Ops tf, double mean, double stddev, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!dtype.isNumeric()) - throw new IllegalArgumentException("The data type must be numeric. Found : " + dtype.name()); + public Operand call(Operand dims, Class type) { long[] seeds = {seed,0}; - @SuppressWarnings("unchecked") - DataType numdType = (DataType) dtype; - Operand distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), numdType); - @SuppressWarnings("unchecked") - Operand distOpT = (Operand) distOp; + Operand distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), type); return tf.math.add( - tf.math.mul(distOpT, tf.dtypes.cast(tf.constant(stddev), dtype)), - tf.dtypes.cast(tf.constant(mean), dtype)); + tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)), + tf.dtypes.cast(tf.constant(mean), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java index fd33adadd5c..5d951450505 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java @@ -14,20 +14,16 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.utils.ShapeUtils; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; +import org.tensorflow.types.family.TFloating; /** * Initializer capable of adapting its scale to the shape of weights tensors. * - *

- * *

With distribution=TRUNCATED_NORMAL or NORMAL, samples are drawn from * a truncated/untruncated normal distribution with a mean of zero and a standard deviation (after * truncation, if used) stddev = Math.sqrt(scale / n), where n is: @@ -50,15 +46,14 @@ * new org.tensorflow.framework.initializers.VarianceScaling<>( * tf, scale, Mode.FAN_IN, Distribution.UNIFORM, seed); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation - * @param The TNumber for the call operation * @see VarianceScaling.Mode * @see VarianceScaling.Distribution */ -public class VarianceScaling extends BaseInitializer { +public class VarianceScaling extends BaseInitializer { public static final double SCALE_DEFAULT = 1.0; public static final Mode MODE_DEFAULT = Mode.FAN_IN; @@ -102,10 +97,7 @@ public VarianceScaling(Ops tf, double scale, Mode mode, Distribution distributio /** {@inheritDoc} */ @Override - public Operand call(Operand dims, DataType dtype) { - if (!dtype.isFloating()) { - throw new IllegalArgumentException("Expected floating point type, got " + dtype.name()); - } + public Operand call(Operand dims, Class type) { Shape shape = ShapeUtils.toShape(this.tf.scope(), dims); double lscale = this.scale; double[] fans /* fanIn, fanOut */ = computeFans(shape); @@ -120,32 +112,28 @@ public Operand call(Operand dims, DataType dtype) { lscale /= Math.max(1., (fans[0] + fans[1]) / 2.); break; } - Operand distOp; - Operand mulOp = null; - @SuppressWarnings("unchecked") - DataType numdType = (DataType) dtype; + Operand distOp; + Operand mulOp = null; double stddev; long[] seeds = {seed, 0}; switch (distribution) { case TRUNCATED_NORMAL: - distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), numdType); + distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), type); stddev = Math.sqrt(lscale) / .87962566103423978; - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), numdType)); + mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); break; case NORMAL: - distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), numdType); + distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), type); stddev = Math.sqrt(lscale); - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), numdType)); + mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); break; case UNIFORM: - distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), numdType); + distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), type); stddev = Math.sqrt(3.0 * lscale); - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), numdType)); + mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); break; } - - // Need to cast TNumber to TType - return (Operand) mulOp; + return mulOp; } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java index 09dd512ffaa..4298493ac44 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; @@ -29,7 +28,7 @@ * Zeros<TFloat32> initializer = * new org.tensorflow.framework.initializers.Zeros<>(tf); * Operand<TFloat32> values = - * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.DTYPE); + * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * * @param The TType for the call operation @@ -46,7 +45,7 @@ public Zeros(Ops tf) { } @Override - public Operand call(Operand dims, DataType dtype) { + public Operand call(Operand dims, Class dtype) { return tf.zeros(dims, dtype); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index effdf990f71..c7edfcca24e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -217,8 +217,8 @@ public Operand call( getTF(), "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()), - cast(getTF(), getTF().constant(1), predictions.asOutput().dataType())); + cast(getTF(), getTF().constant(0), predictions.type()), + cast(getTF(), getTF().constant(1), predictions.type())); } else { lPredictions = predictions; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 6c03fa81b31..77c6ab2bf87 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -256,8 +256,8 @@ public Operand call( getTF(), "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()), - cast(getTF(), getTF().constant(1), predictions.asOutput().dataType())); + cast(getTF(), getTF().constant(0), predictions.type()), + cast(getTF(), getTF().constant(1), predictions.type())); } else { lPredictions = predictions; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 1e497783841..88b4a7aa056 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -124,16 +124,13 @@ public Hinge(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") - Operand tLabels = predictions.asOutput().dataType() == labels.asOutput().dataType() ? - (Operand)labels : - cast(tf, labels, predictions.asOutput().dataType()); + Operand tLabels = predictions.type() == labels.type() ? + (Operand)labels : cast(tf, labels, predictions.type()); tLabels = LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), - predictions.asOutput().dataType())); - + cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type())); Operand losses = Losses.hinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 5a9dbf5bf01..81d9e13c8a9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.losses; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.losses.impl.LossesHelper; @@ -51,7 +50,7 @@ public class Losses { */ public static Operand meanAbsoluteError( Ops tf, Operand labels, Operand predictions) { - Operand tLabels = cast(tf, labels, predictions.asOutput().dataType()); + Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -73,7 +72,7 @@ public static Operand meanAbsoluteErro */ public static Operand meanSquaredError( Ops tf, Operand labels, Operand predictions) { - Operand tLabels = cast(tf, labels, predictions.asOutput().dataType()); + Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -94,8 +93,8 @@ public static Operand meanSquaredError */ public static Operand meanAbsolutePercentageError( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Class predictionType = predictions.type(); + Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -103,8 +102,8 @@ public static Operand meanAbsolutePerc tf.math.abs( tf.math.div( tf.math.sub(tLabels, predictions), - tf.math.maximum(tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), dataType)))); - return tf.math.mul(cast(tf, tf.constant(100), dataType), tf.math.mean(diff, tf.constant(-1))); + tf.math.maximum(tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); + return tf.math.mul(cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); } /** @@ -121,14 +120,14 @@ public static Operand meanAbsolutePerc */ public static Operand meanSquaredLogarithmicError( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Class predictionType = predictions.type(); + Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); - Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); - Operand one = cast(tf, tf.constant(1), dataType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); + Operand one = cast(tf, tf.constant(1), predictionType); Operand firstLog = tf.math.log(tf.math.add(tf.math.maximum(predictions, epsilonConst), one)); Operand secondLog = tf.math.log(tf.math.add(tf.math.maximum(tLabels, epsilonConst), one)); @@ -152,8 +151,7 @@ public static Operand meanSquaredLogar */ public static Operand binaryCrossentropy( Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { - DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -193,9 +191,9 @@ private static Operand binaryCrossentropyHelper( } */ - DataType dataType = output.asOutput().dataType(); - Operand one = cast(tf, tf.constant(1), dataType); - Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); + Class outputType = output.type(); + Operand one = cast(tf, tf.constant(1), outputType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), outputType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); output = tf.clipByValue(output, epsilonConst, oneMinusEpsilonConst); @@ -231,8 +229,8 @@ public static Operand categoricalCross boolean fromLogits, float labelSmoothing, int axis) { - DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Class predictionType = predictions.type(); + Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -256,8 +254,8 @@ public static Operand categoricalCross } */ - Operand one = cast(tf, tf.constant(1), dataType); - Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); + Operand one = cast(tf, tf.constant(1), predictionType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); predictions = tf.math.div( @@ -284,13 +282,13 @@ public static Operand categoricalCross */ public static Operand categoricalHinge( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Class predictionType = predictions.type(); + Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand one = cast(tf, tf.constant(1), dataType); - Operand zero = cast(tf, tf.constant(0), dataType); + Operand one = cast(tf, tf.constant(1), predictionType); + Operand zero = cast(tf, tf.constant(0), predictionType); Operand pos = tf.reduceSum( @@ -330,8 +328,7 @@ public static Operand categoricalHinge */ public static Operand cosineSimilarity( Ops tf, Operand labels, Operand predictions, int axis) { - DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Operand tLabels = cast(tf, labels, predictions.type()); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); @@ -357,13 +354,13 @@ public static Operand cosineSimilarity */ public static Operand hinge( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Class predictionType = predictions.type(); + Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand one = cast(tf, tf.constant(1), dataType); - Operand zero = cast(tf, tf.constant(0), dataType); + Operand one = cast(tf, tf.constant(1), predictionType); + Operand zero = cast(tf, tf.constant(0), predictionType); tLabels = maybeConvertLabels(tf, tLabels); @@ -393,15 +390,15 @@ public static Operand hinge( */ public static Operand huber( Ops tf, Operand labels, Operand predictions, float delta) { - DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Class predictionType = predictions.type(); + Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); Operand error = tf.math.sub(predictions, tLabels); - Operand deltaConst = cast(tf, tf.constant(delta), dataType); - Operand point5 = cast(tf, tf.constant(0.5), dataType); + Operand deltaConst = cast(tf, tf.constant(delta), predictionType); + Operand point5 = cast(tf, tf.constant(0.5), predictionType); Operand absError = tf.math.abs(error); Operand quadratic = tf.math.minimum(absError, deltaConst); Operand linear = tf.math.sub(absError, quadratic); @@ -424,13 +421,13 @@ public static Operand huber( */ public static Operand kullbackLeiblerDivergence( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Class predictionType = predictions.type(); + Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand one = cast(tf, tf.constant(1), dataType); - Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); + Operand one = cast(tf, tf.constant(1), predictionType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); tLabels = tf.clipByValue(tLabels, epsilonConst, one); predictions = tf.clipByValue(predictions, epsilonConst, one); @@ -454,13 +451,13 @@ public static Operand kullbackLeiblerD */ public static Operand logCosh( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Class predictionType = predictions.type(); + Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand minusTwo = cast(tf, tf.constant(-2), dataType); - Operand two = cast(tf, tf.constant(2), dataType); + Operand minusTwo = cast(tf, tf.constant(-2), predictionType); + Operand two = cast(tf, tf.constant(2), predictionType); Operand diff = tf.math.sub(predictions, tLabels); Softplus softplus = tf.math.softplus(tf.math.mul(minusTwo, diff)); @@ -482,12 +479,12 @@ public static Operand logCosh( */ public static Operand poisson( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Class predictionType = predictions.type(); + Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); return tf.math.mean( tf.math.sub( @@ -509,9 +506,9 @@ public static Operand poisson( */ public static Operand sparseCategoricalCrossentropy( Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { - DataType dataType = predictions.asOutput().dataType(); - Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); - Operand one = cast(tf, tf.constant(1), dataType); + Class predictionType = predictions.type(); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); + Operand one = cast(tf, tf.constant(1), predictionType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); /* TODO need ability to walk back inputs @@ -546,7 +543,7 @@ public static Operand sparseCategorica predictions = tf.linalg.transpose(predictions, tf.constant(axisNew)); } - Operand iLabels = cast(tf, labels, TInt64.DTYPE); + Operand iLabels = cast(tf, labels, TInt64.class); // Try to adjust the shape so that rank of labels = rank of logits - 1. Shape labelsShape = labels.shape(); @@ -586,13 +583,13 @@ public static Operand sparseCategorica */ public static Operand squaredHinge( Ops tf, Operand labels, Operand predictions) { - DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Class predictionType = predictions.type(); + Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand one = cast(tf, tf.constant(1), dataType); - Operand zero = cast(tf, tf.constant(0), dataType); + Operand one = cast(tf, tf.constant(1), predictionType); + Operand zero = cast(tf, tf.constant(0), predictionType); tLabels = maybeConvertLabels(tf, tLabels); return tf.math.mean( @@ -614,9 +611,9 @@ public static Operand squaredHinge( */ private static Operand smoothBinaryLabels( Ops tf, Operand labels, float labelSmoothing) { - DataType dataType = labels.asOutput().dataType(); - Operand oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), dataType); - Operand halfSmoothing = cast(tf, tf.constant(0.5F * labelSmoothing), dataType); + Class labelType = labels.type(); + Operand oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), labelType); + Operand halfSmoothing = cast(tf, tf.constant(0.5F * labelSmoothing), labelType); return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), halfSmoothing); } @@ -633,12 +630,12 @@ private static Operand smoothBinaryLabels( */ private static Operand smoothCategoricalLabels( Ops tf, Operand labels, float labelSmoothing) { - DataType dataType = labels.asOutput().dataType(); - Operand smoothing = cast(tf, tf.constant(labelSmoothing), dataType); + Class labelType = labels.type(); + Operand smoothing = cast(tf, tf.constant(labelSmoothing), labelType); Shape labelsShape = labels.shape(); int numDims = labelsShape.numDimensions(); - Operand numClasses = cast(tf, tf.constant(labelsShape.size(numDims - 1)), dataType); - Operand oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), dataType); + Operand numClasses = cast(tf, tf.constant(labelsShape.size(numDims - 1)), labelType); + Operand oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), labelType); return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), tf.math.div(smoothing, numClasses)); } @@ -656,7 +653,7 @@ public static Operand l2Normalize(Ops tf, Operand x, i tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); Operand invNorm = tf.math.rsqrt( - tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.asOutput().dataType()))); + tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); return tf.math.mul(x, invNorm); } @@ -669,11 +666,11 @@ public static Operand l2Normalize(Ops tf, Operand x, i * @return the labels, possibly converted into -1/1. */ private static Operand maybeConvertLabels(Ops tf, Operand labels) { - DataType dataType = labels.asOutput().dataType(); + Class labelType = labels.type(); - Operand one = cast(tf, tf.constant(1), dataType); - Operand zero = cast(tf, tf.constant(0), dataType); - Operand two = cast(tf, tf.constant(2), dataType); + Operand one = cast(tf, tf.constant(1), labelType); + Operand zero = cast(tf, tf.constant(0), labelType); + Operand two = cast(tf, tf.constant(2), labelType); Operand areZeros = tf.math.equal(labels, zero); Operand areOnes = tf.math.equal(labels, one); Operand isBinary = diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index 5586a4da889..ea765e6f8fd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -205,8 +205,8 @@ public Operand call( getTF(), "predictions range check [0-1]", predictions, - cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()), - cast(getTF(), getTF().constant(1), predictions.asOutput().dataType())); + cast(getTF(), getTF().constant(0), predictions.type()), + cast(getTF(), getTF().constant(1), predictions.type())); } else { lPredictions = predictions; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index 182ce592e55..4ad4c1c726c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -125,15 +125,13 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") - Operand tLabels = predictions.asOutput().dataType() == labels.asOutput().dataType() ? - (Operand)labels : - cast(tf, labels, predictions.asOutput().dataType()); + Operand tLabels = predictions.type() == labels.type() ? + (Operand)labels : cast(tf, labels, predictions.type()); tLabels = LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), - predictions.asOutput().dataType())); + cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type())); Operand losses = Losses.squaredHinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index fbc92f14d31..10067db91ba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -14,7 +14,6 @@ =======================================================================*/ package org.tensorflow.framework.losses.impl; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Reduction; import org.tensorflow.ndarray.Shape; @@ -249,17 +248,17 @@ public static LossTuple removeSqueezableDimensions( */ public static Operand computeWeightedLoss( Ops tf, Operand loss, Reduction reduction, Operand sampleWeight) { - DataType dataType = loss.asOutput().dataType(); + Class inputType = loss.type(); if (sampleWeight == null) { - sampleWeight = cast(tf, tf.constant(1), dataType); + sampleWeight = cast(tf, tf.constant(1), inputType); } LossTuple result = squeezeOrExpandDimensions(tf, null, loss, sampleWeight); loss = result.getTarget(); sampleWeight = result.getSampleWeights(); - Operand weightedLosses = tf.math.mul(loss, cast(tf, sampleWeight, dataType)); + Operand weightedLosses = tf.math.mul(loss, cast(tf, sampleWeight, inputType)); loss = reduceWeightedLoss(tf, weightedLosses, reduction); - return cast(tf, loss, dataType); + return cast(tf, loss, inputType); } /** @@ -300,7 +299,7 @@ public static Operand safeMean( Ops tf, Operand losses, long numElements) { Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses)); return tf.math.divNoNan( - totalLoss, cast(tf, tf.constant(numElements), losses.asOutput().dataType())); + totalLoss, cast(tf, tf.constant(numElements), losses.type())); } /** @@ -386,7 +385,7 @@ public static Operand valueCheck( Ops tf, String prefix, Operand values, Operand allowedValues) { Operand flatValues = tf.reshape(values, tf.constant(Shape.of(values.shape().size()))); - SetDiff1d diff = tf.setDiff1d(flatValues, allowedValues, TInt32.DTYPE); + SetDiff1d diff = tf.setDiff1d(flatValues, allowedValues, TInt32.class); long diffSize = diff.out().shape().size(); if (diffSize != Shape.UNKNOWN_SIZE) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index 0adf5f58910..822eb490f22 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -141,10 +141,10 @@ protected void createSlots(List> variables) { */ private void createAdaDeltaSlot(Output v) { Operand accumulatorInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), ACCUMULATOR, accumulatorInitializer); Operand updateInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), ACCUMULATOR_UPDATE, updateInitializer); } @@ -157,9 +157,9 @@ protected Op applyDense(Output gradient, Output variable variable, accumSlot, accumUpdateSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), - tf.dtypes.cast(tf.constant(rho), gradient.dataType()), - tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), + tf.dtypes.cast(tf.constant(learningRate), gradient.type()), + tf.dtypes.cast(tf.constant(rho), gradient.type()), + tf.dtypes.cast(tf.constant(epsilon), gradient.type()), gradient); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 1a7f4675662..08f5f18a9cd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -131,7 +131,7 @@ protected void createSlots(List> variables) { */ private void createAdaGradSlot(Output v) { Operand initializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.type())); createSlot(v.asOutput(), ACCUMULATOR, initializer); } @@ -140,7 +140,7 @@ private void createAdaGradSlot(Output v) { protected Op applyDense(Output gradient, Output variable) { Variable slot = getSlot(variable, ACCUMULATOR).get(); return tf.train.applyAdagrad( - variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), gradient); + variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index f76217fda85..df624e41c4e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -187,7 +187,7 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdaGradDASlot(v); } - globalStep = tf.withName("adagrad-da-global-step").variable(Shape.scalar(), TInt64.DTYPE); + globalStep = tf.withName("adagrad-da-global-step").variable(Shape.scalar(), TInt64.class); Assign globalStepInitializer = tf.assign(globalStep, tf.constant(0L)); graph.addInitializer(globalStepInitializer); } @@ -199,10 +199,10 @@ protected void createSlots(List> variables) { * @param the datatype of the variable. */ private void createAdaGradDASlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), ACCUMULATOR, initializer); Operand sqInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.type())); createSlot(v.asOutput(), SQUARED_ACCUMULATOR, sqInitializer); } @@ -216,9 +216,9 @@ protected Op applyDense(Output gradient, Output variable gradSlot, gradSquaredSlot, gradient, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), - tf.dtypes.cast(tf.constant(l1Strength), gradient.dataType()), - tf.dtypes.cast(tf.constant(l2Strength), gradient.dataType()), + tf.dtypes.cast(tf.constant(learningRate), gradient.type()), + tf.dtypes.cast(tf.constant(l1Strength), gradient.type()), + tf.dtypes.cast(tf.constant(l2Strength), gradient.type()), globalStep); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java index 8f620678781..72598d12543 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java @@ -189,10 +189,10 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdamSlot(v.asOutput()); } - betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.DTYPE); + betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.class); Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne)); graph.addInitializer(betaOnePowerInit); - betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.DTYPE); + betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.class); Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo)); graph.addInitializer(betaTwoPowerInit); } @@ -215,10 +215,10 @@ protected Optional prepare(String scopeName) { */ private void createAdamSlot(Output v) { Operand firstMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); Operand secondMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); } @@ -231,12 +231,12 @@ protected Op applyDense(Output gradient, Output variable variable, firstMomentSlot, secondMomentSlot, - tf.dtypes.cast(betaOnePower, gradient.dataType()), - tf.dtypes.cast(betaTwoPower, gradient.dataType()), - tf.dtypes.cast(learningRateConst, gradient.dataType()), - tf.dtypes.cast(betaOneConst, gradient.dataType()), - tf.dtypes.cast(betaTwoConst, gradient.dataType()), - tf.dtypes.cast(epsilonConst, gradient.dataType()), + tf.dtypes.cast(betaOnePower, gradient.type()), + tf.dtypes.cast(betaTwoPower, gradient.type()), + tf.dtypes.cast(learningRateConst, gradient.type()), + tf.dtypes.cast(betaOneConst, gradient.type()), + tf.dtypes.cast(betaTwoConst, gradient.type()), + tf.dtypes.cast(epsilonConst, gradient.type()), gradient); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java index 335d83cedfa..cd95bb3bd07 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java @@ -137,7 +137,7 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdamaxSlot(v.asOutput()); } - betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.DTYPE); + betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.class); Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne)); ((Graph) tf.scope().env()).addInitializer(betaOnePowerInit); } @@ -150,10 +150,10 @@ protected void createSlots(List> variables) { */ private void createAdamaxSlot(Output v) { Operand firstMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); Operand secondMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); } @@ -167,11 +167,11 @@ protected Op applyDense(Output gradient, Output variable variable, firstMomentSlot, secondMomentSlot, - tf.dtypes.cast(betaOnePower, gradient.dataType()), - tf.dtypes.cast(learningRateConst, gradient.dataType()), - tf.dtypes.cast(betaOneConst, gradient.dataType()), - tf.dtypes.cast(betaTwoConst, gradient.dataType()), - tf.dtypes.cast(epsilonConst, gradient.dataType()), + tf.dtypes.cast(betaOnePower, gradient.type()), + tf.dtypes.cast(learningRateConst, gradient.type()), + tf.dtypes.cast(betaOneConst, gradient.type()), + tf.dtypes.cast(betaTwoConst, gradient.type()), + tf.dtypes.cast(epsilonConst, gradient.type()), gradient); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java index 04c34a2535e..66314d2ffe0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java @@ -230,10 +230,10 @@ protected void createSlots(List> variables) { */ private void createFtrlSlot(Output v) { Operand initializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue), v.type())); createSlot(v.asOutput(), ACCUMULATOR, initializer); Operand linearInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), LINEAR_ACCUMULATOR, linearInitializer); } @@ -248,12 +248,12 @@ protected Op applyDense(Output gradient, Output variable accumSlot, // accum linearSlot, // linear gradient, // gradient - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), // lr - tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.dataType()), // l1 - tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.dataType()), // l2 + tf.dtypes.cast(tf.constant(learningRate), gradient.type()), // lr + tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.type()), // l1 + tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.type()), // l2 tf.dtypes.cast( - tf.constant(l2ShrinkageRegularizationStrength), gradient.dataType()), // l2Shrinkage - tf.dtypes.cast(tf.constant(learningRatePower), gradient.dataType()), // lrPower + tf.constant(l2ShrinkageRegularizationStrength), gradient.type()), // l2Shrinkage + tf.dtypes.cast(tf.constant(learningRatePower), gradient.type()), // lrPower options); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java index e307855e636..a373b2e5b55 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java @@ -66,7 +66,7 @@ public GradientDescent(Graph graph, String name, float learningRate) { @Override protected Op applyDense(Output gradient, Output variable) { return tf.train.applyGradientDescent( - variable, tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), gradient); + variable, tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java index 111727d26fa..f6640409d60 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java @@ -125,7 +125,7 @@ protected void createSlots(List> variables) { * @param the data type of the variable */ private void createMomentumSlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), MOMENTUM, initializer); } @@ -136,9 +136,9 @@ protected Op applyDense(Output gradient, Output variable return tf.train.applyMomentum( variable, slot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient, - tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), + tf.dtypes.cast(tf.constant(momentum), gradient.type()), ApplyMomentum.useNesterov(useNesterov)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java index 48e5135c952..f9900a8ee78 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -1,6 +1,5 @@ package org.tensorflow.framework.optimizers; -import org.tensorflow.DataType; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; @@ -142,15 +141,15 @@ protected void createSlots(List> variables) { for (Output v : variables) { createNadamSlot(v.asOutput()); } - betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.DTYPE); + betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.class); Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne)); ((Graph) tf.scope().env()).addInitializer(betaOnePowerInit); - betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.DTYPE); + betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.class); Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo)); ((Graph) tf.scope().env()).addInitializer(betaTwoPowerInit); - momentum = tf.withName("momentum").variable(Shape.scalar(), TFloat32.DTYPE); + momentum = tf.withName("momentum").variable(Shape.scalar(), TFloat32.class); Assign momentumInit = tf.assign(momentum, tf.constant(1.0F)); ((Graph) tf.scope().env()).addInitializer(momentumInit); } @@ -163,14 +162,14 @@ protected void createSlots(List> variables) { */ private void createNadamSlot(Output v) { Operand firstMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); Operand secondMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); Operand momentumInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.type())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); } @@ -198,7 +197,7 @@ protected Optional prepare(String scopeName) { point5, tf.math.pow( decayBaseConst, - tf.math.mul(decayConst, tf.dtypes.cast(localStepConst, TFloat32.DTYPE)))))); + tf.math.mul(decayConst, tf.dtypes.cast(localStepConst, TFloat32.class)))))); mT1 = tf.math.mul( @@ -209,7 +208,7 @@ protected Optional prepare(String scopeName) { point5, tf.math.pow( decayBaseConst, - tf.math.mul(decayConst, tf.dtypes.cast(nextStepConst, TFloat32.DTYPE)))))); + tf.math.mul(decayConst, tf.dtypes.cast(nextStepConst, TFloat32.class)))))); Operand mScheduleNew = tf.math.mul(momentum, mT); @@ -222,57 +221,57 @@ protected Optional prepare(String scopeName) { oneMinusMScheduleNew = tf.math.sub(one, mScheduleNew); oneMinusMScheduleNext = tf.math.sub(one, mScheduleNext); vTPrimeDenominator = - tf.math.sub(one, tf.math.pow(betaTwoConst, tf.dtypes.cast(localStepConst, TFloat32.DTYPE))); + tf.math.sub(one, tf.math.pow(betaTwoConst, tf.dtypes.cast(localStepConst, TFloat32.class))); return Optional.empty(); } /** {@inheritDoc} */ @Override protected Op applyDense(Output gradient, Output variable) { - DataType dType = gradient.dataType(); + Class type = gradient.type(); Variable m = getSlot(variable, FIRST_MOMENT).get(); // first Moment Variable v = getSlot(variable, SECOND_MOMENT).get(); // Second Moment // gPrime = grad / coefficients['oneMinusMScheduleNew'] - Operand gPrime = tf.math.div(gradient, tf.dtypes.cast(oneMinusMScheduleNew, dType)); + Operand gPrime = tf.math.div(gradient, tf.dtypes.cast(oneMinusMScheduleNew, type)); // mT = (coefficients['beta_1_t'] * m + coefficients['one_minus_beta_1_t'] * grad) Operand mT = tf.math.add( - tf.math.mul(tf.dtypes.cast(betaOneConst, dType), m), - tf.math.mul(tf.dtypes.cast(oneMinusBeta1, dType), gradient)); + tf.math.mul(tf.dtypes.cast(betaOneConst, type), m), + tf.math.mul(tf.dtypes.cast(oneMinusBeta1, type), gradient)); // mT = state_ops.assign(m, mT, use_locking=self._use_locking) // update m mT = tf.assign(m, mT, Assign.useLocking(true)); // mTPrime = mT / coefficients['oneMinusMScheduleNext'] - Operand mTPrime = tf.math.div(mT, tf.dtypes.cast(oneMinusMScheduleNext, dType)); + Operand mTPrime = tf.math.div(mT, tf.dtypes.cast(oneMinusMScheduleNext, type)); // vT = (coefficients['beta_2_t'] * v + coefficients['one_minus_beta_2_t'] * // math_ops.square(grad)) Operand vT = tf.math.add( - tf.math.mul(tf.dtypes.cast(betaTwoConst, dType), v), - tf.math.mul(tf.dtypes.cast(oneMinusBeta2, dType), tf.math.square(gradient))); + tf.math.mul(tf.dtypes.cast(betaTwoConst, type), v), + tf.math.mul(tf.dtypes.cast(oneMinusBeta2, type), tf.math.square(gradient))); // vT = state_ops.assign(v, vT, use_locking=self._use_locking) // update v vT = tf.assign(v, vT, Assign.useLocking(true)); // vTPrime = vT / coefficients['vTPrimeDenominator'] - Operand vTPrime = tf.math.div(vT, tf.dtypes.cast(vTPrimeDenominator, dType)); + Operand vTPrime = tf.math.div(vT, tf.dtypes.cast(vTPrimeDenominator, type)); // m_t_bar = (coefficients['oneMinusMT'] * gPrime + coefficients['mT1'] * mTPrime) Operand m_t_bar = tf.math.add( - tf.math.mul(tf.dtypes.cast(oneMinusMT, dType), gPrime), - tf.math.mul(tf.dtypes.cast(mT1, dType), mTPrime)); + tf.math.mul(tf.dtypes.cast(oneMinusMT, type), gPrime), + tf.math.mul(tf.dtypes.cast(mT1, type), mTPrime)); // varT = var - coefficients['lr_t'] * m_t_bar / (math_ops.sqrt(vTPrime) + // coefficients['epsilon']) Operand varT = tf.math.sub( variable, tf.math.div( - tf.math.mul(tf.dtypes.cast(learningRateConst, dType), m_t_bar), - tf.math.add(tf.math.sqrt(vTPrime), tf.dtypes.cast(epsilonConst, dType)))); + tf.math.mul(tf.dtypes.cast(learningRateConst, type), m_t_bar), + tf.math.add(tf.math.sqrt(vTPrime), tf.dtypes.cast(epsilonConst, type)))); return tf.assign(variable, varT, Assign.useLocking(true)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index 70f065814f7..fdf56da4a67 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -220,7 +220,7 @@ private Optional> getSlot(String varName, String s protected void createSlot( Output variable, String slotName, Operand initializer) { Variable slot = - tf.withName(createName(variable, slotName)).variable(variable.shape(), variable.dataType()); + tf.withName(createName(variable, slotName)).variable(variable.shape(), variable.type()); Assign slotInit = tf.assign(slot, initializer); graph.addInitializer(slotInit); String varName = variable.op().name(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index 9a48a9b8a7a..b3729dc367f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -175,14 +175,14 @@ protected void createSlots(List> variables) { */ private void createRMSPropSlot(Output v) { Operand rmsInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.type())); createSlot(v.asOutput(), RMS, rmsInitializer); Operand momentumInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); if (centered) { Operand mgInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), MG, mgInitializer); } } @@ -199,20 +199,20 @@ protected Op applyDense(Output gradient, Output variable mgSlot, rmsSlot, momentumSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), - tf.dtypes.cast(tf.constant(decay), gradient.dataType()), - tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), - tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), + tf.dtypes.cast(tf.constant(learningRate), gradient.type()), + tf.dtypes.cast(tf.constant(decay), gradient.type()), + tf.dtypes.cast(tf.constant(momentum), gradient.type()), + tf.dtypes.cast(tf.constant(epsilon), gradient.type()), gradient); } return tf.train.applyRmsProp( variable, rmsSlot, momentumSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), - tf.dtypes.cast(tf.constant(decay), gradient.dataType()), - tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), - tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), + tf.dtypes.cast(tf.constant(learningRate), gradient.type()), + tf.dtypes.cast(tf.constant(decay), gradient.type()), + tf.dtypes.cast(tf.constant(momentum), gradient.type()), + tf.dtypes.cast(tf.constant(epsilon), gradient.type()), gradient); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java index aec75e6078a..b0fe48967dd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java @@ -15,7 +15,6 @@ */ package org.tensorflow.framework.utils; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TType; @@ -35,8 +34,8 @@ public class CastHelper { */ @SuppressWarnings("unchecked") public static Operand cast( - Ops tf, Operand value, DataType requiredType) { - return (value.asOutput().dataType() == requiredType) + Ops tf, Operand value, Class requiredType) { + return (value.type() == requiredType) ? (Operand) value : tf.dtypes.cast(value, requiredType); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java index 87a43941a16..4ca2c789f28 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java @@ -21,7 +21,7 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.TUint8; -import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TIntegral; import java.util.ArrayList; import java.util.Arrays; @@ -36,7 +36,7 @@ public class ShapeUtils { * @param dims the Operand containing the shape values * @return a new Shape based on an Operand that contains dimensions */ - public static Shape toShape(Scope scope, Operand dims) { + public static Shape toShape(Scope scope, Operand dims) { long[] longDims = getLongArray(scope, dims); return Shape.of(longDims); } @@ -62,12 +62,12 @@ public static int[] getIntArray(Scope scope, Operand dims) { * @return the long array * @throws java.lang.IllegalArgumentException if the dims type is not an integer */ - public static long[] getLongArray(Scope scope, Operand dims) { + public static long[] getLongArray(Scope scope, Operand dims) { if (scope.env().isEager()) { return getLongArray(dims.asTensor()); } try (Session session = new Session((Graph)scope.env()); - Tensor tensor = session.runner().fetch(dims).run().get(0)) { + TIntegral tensor = (TIntegral)session.runner().fetch(dims).run().get(0)) { return getLongArray(tensor); } } @@ -79,7 +79,7 @@ public static long[] getLongArray(Scope scope, Operand di * @return the long array * @throws java.lang.IllegalArgumentException if the dims type is not an integer */ - public static long[] getLongArray(Tensor dims) { + public static long[] getLongArray(T dims) { List result = new ArrayList<>(); if (dims instanceof TInt32) { ((TInt32)dims).scalars().forEach(s -> result.add((long) s.getInt())); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java index e608224a50d..914b94dfada 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java @@ -20,7 +20,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; import static org.junit.jupiter.api.Assertions.assertThrows; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java index a0fd1f60b47..1157c582168 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java @@ -20,7 +20,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; import static org.junit.jupiter.api.Assertions.assertThrows; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java index b1eaab8de22..35f57c47f66 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java @@ -20,7 +20,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; import static org.junit.jupiter.api.Assertions.assertThrows; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java index f54401515ab..a0aa2c4b453 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java @@ -90,8 +90,8 @@ public void testCallFloat16() { Ops tf = session.getTF(); ReLU instance = new ReLU<>(tf); Operand result = - instance.call(tf.dtypes.cast(tf.constant(input), TFloat16.DTYPE)); - session.evaluate(tf.dtypes.cast(tf.constant(expected), TFloat16.DTYPE), result); + instance.call(tf.dtypes.cast(tf.constant(input), TFloat16.class)); + session.evaluate(tf.dtypes.cast(tf.constant(expected), TFloat16.class), result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java index caba5c43ba8..8bad6f1f066 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java @@ -20,7 +20,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; import static org.junit.jupiter.api.Assertions.assertThrows; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java index ffb16cf077a..9dca622c3ec 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java @@ -20,7 +20,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; import static org.junit.jupiter.api.Assertions.assertThrows; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java index a3ff89cc407..05ec3a4f716 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java @@ -18,11 +18,8 @@ import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.ReduceMax; -import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; import static org.junit.jupiter.api.Assertions.assertThrows; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java index 5739bccd3d5..7576789320b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java @@ -20,7 +20,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; import static org.junit.jupiter.api.Assertions.assertThrows; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java index 48800d4dc1b..2d282e5dcf7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java @@ -17,7 +17,6 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; -import org.tensorflow.Tensor; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; @@ -40,16 +39,14 @@ public void testEagerBatchDataset() { Arrays.asList( tf.constant(testMatrix1), tf.constant(testMatrix2)), - Arrays.asList(TInt32.DTYPE, TInt32.DTYPE)) + Arrays.asList(TInt32.class, TInt32.class)) .batch(2); int count = 0; for (List> components : dataset) { try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); - TInt32 batch2 = - (TInt32)components.get(1).asTensor();) { - + TInt32 batch2 = (TInt32)components.get(1).asTensor()) { assertEquals(testMatrix1.slice(range(count, count + 2)), batch1); assertEquals(testMatrix2.slice(range(count, count + 2)), batch2); @@ -66,7 +63,7 @@ public void testDropLastBatch() { Arrays.asList( tf.constant(testMatrix1), tf.constant(testMatrix2)), - Arrays.asList(TInt32.DTYPE, TInt32.DTYPE)) + Arrays.asList(TInt32.class, TInt32.class)) .batch(3, true); int count = 0; @@ -74,9 +71,7 @@ public void testDropLastBatch() { try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); - TInt32 batch2 = - (TInt32)components.get(1).asTensor();) { - + TInt32 batch2 = (TInt32)components.get(1).asTensor()) { assertEquals(testMatrix1.slice(range(count, count + 3)), batch1); assertEquals(testMatrix2.slice(range(count, count + 3)), batch2); @@ -93,7 +88,7 @@ public void testKeepLastBatch() { Arrays.asList( tf.constant(testMatrix1), tf.constant(testMatrix2)), - Arrays.asList(TInt32.DTYPE, TInt32.DTYPE)) + Arrays.asList(TInt32.class, TInt32.class)) .batch(3, false); int count = 0; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java index 448e90a17ea..882a64ba54d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java @@ -16,11 +16,10 @@ package org.tensorflow.framework.data; import org.junit.jupiter.api.Test; -import org.tensorflow.DataType; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.Tensor; +import org.tensorflow.types.family.TType; import org.tensorflow.exceptions.TFOutOfRangeException; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; @@ -39,7 +38,7 @@ public void testGraphIteration() { List> tensors = Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)); - List> dataTypes = Arrays.asList(TInt32.DTYPE, TInt32.DTYPE); + List> dataTypes = Arrays.asList(TInt32.class, TInt32.class); Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes); DatasetIterator iterator = dataset.makeOneShotIterator(); @@ -77,14 +76,13 @@ public void testEagerIteration() { List> tensors = Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)); - List> dataTypes = Arrays.asList(TInt32.DTYPE, TInt32.DTYPE); + List> dataTypes = Arrays.asList(TInt32.class, TInt32.class); Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes); int count = 0; for (List> outputs : dataset) { try (TInt32 batch1 = (TInt32)outputs.get(0).asTensor(); - TInt32 batch2 = (TInt32)outputs.get(1).asTensor(); ) { - + TInt32 batch2 = (TInt32)outputs.get(1).asTensor()) { assertEquals(testMatrix1.get(count), batch1); assertEquals(testMatrix2.get(count), batch2); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java index ede0a1aa61d..5f203427563 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java @@ -17,11 +17,10 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.tensorflow.DataType; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.Tensor; +import org.tensorflow.types.family.TType; import org.tensorflow.exceptions.TFOutOfRangeException; import org.tensorflow.op.Ops; import org.tensorflow.ndarray.IntNdArray; @@ -60,12 +59,13 @@ public void testGraphIteration() { List> tensors = Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)); - List> dataTypes = Arrays.asList(TInt32.DTYPE, TInt32.DTYPE); + List> dataTypes = Arrays.asList(TInt32.class, TInt32.class); Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes) .mapAllComponents( - component -> tf.math.mul(component.asOutput().expect(TInt32.DTYPE), tf.constant(2))); + component -> + tf.math.mul(component.asOutput().expect(TInt32.class), tf.constant(2))); DatasetIterator iterator = dataset.makeOneShotIterator(); List> components = iterator.getNext(); @@ -104,18 +104,17 @@ public void testEagerIteration() { List> tensors = Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)); - List> dataTypes = Arrays.asList(TInt32.DTYPE, TInt32.DTYPE); + List> dataTypes = Arrays.asList(TInt32.class, TInt32.class); Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes) .mapAllComponents( - op -> tf.math.mul(op.asOutput().expect(TInt32.DTYPE), tf.constant(2))); + op -> tf.math.mul(op.asOutput().expect(TInt32.class), tf.constant(2))); int count = 0; for (List> outputs : dataset) { try (TInt32 XBatch = (TInt32)outputs.get(0).asTensor(); - TInt32 yBatch = (TInt32)outputs.get(1).asTensor(); ) { - + TInt32 yBatch = (TInt32)outputs.get(1).asTensor()) { assertEquals(mapped1.get(count), XBatch); assertEquals(mapped2.get(count), yBatch); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java index 6dc877cc6eb..d0cdb4527a5 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java @@ -17,7 +17,6 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; -import org.tensorflow.Tensor; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; @@ -35,14 +34,13 @@ public void testEagerSkipDataset() { Dataset.fromTensorSlices( tf, Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)), - Arrays.asList(TInt32.DTYPE, TInt32.DTYPE)) + Arrays.asList(TInt32.class, TInt32.class)) .skip(2); int count = 2; for (List> components : dataset) { try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); - TInt32 batch2 = - (TInt32)components.get(1).asTensor(); ) { + TInt32 batch2 = (TInt32)components.get(1).asTensor()) { assertEquals(testMatrix1.get(count), batch1); assertEquals(testMatrix2.get(count), batch2); count++; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java index 626fe719936..79a2e79c72e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java @@ -17,7 +17,6 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; -import org.tensorflow.Tensor; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; @@ -36,14 +35,13 @@ public void testEagerTakeDataset() { Dataset.fromTensorSlices( tf, Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)), - Arrays.asList(TInt32.DTYPE, TInt32.DTYPE)) + Arrays.asList(TInt32.class, TInt32.class)) .take(4); int count = 0; for (List> components : dataset) { try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); - TInt32 batch2 = (TInt32)components.get(1).asTensor(); ) { - + TInt32 batch2 = (TInt32)components.get(1).asTensor()) { assertEquals(testMatrix1.get(count), batch1); assertEquals(testMatrix2.get(count), batch2); count++; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java index 46e4232d5ae..4e81e0620e6 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java @@ -52,7 +52,7 @@ public void testCallUInt() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Constant instance = new Constant<>(tf, 0xf); - Operand operand = instance.call(tf.constant(shape), TUint8.DTYPE); + Operand operand = instance.call(tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -68,7 +68,7 @@ public void testCallInt() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Constant instance = new Constant<>(tf, 0xf); - Operand operand = instance.call(tf.constant(shape), TInt32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -84,7 +84,7 @@ public void testCallLong() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Constant instance = new Constant<>(tf, 0xffL); - Operand operand = instance.call(tf.constant(shape), TInt64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -98,7 +98,7 @@ public void testCallFloat() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Constant instance = new Constant<>(tf, 12.F); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -113,7 +113,7 @@ public void testCallDouble() { Shape shape = Shape.of(2, 2); Constant instance = new Constant<>(tf, 11.); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -130,7 +130,7 @@ public void testCallString() { Shape shape = Shape.of(2, 2); Constant instance = new Constant<>(tf, 22); - instance.call(tf.constant(shape), TString.DTYPE); + instance.call(tf.constant(shape), TString.class); fail("IllegalArgumentException should have been thrown for TString"); } }); @@ -146,7 +146,7 @@ public void testCallBool() { Boolean[] expected = {true, true, true, true}; Constant instance = new Constant<>(tf, true); - Operand operand = instance.call(tf.constant(shape), TBool.DTYPE); + Operand operand = instance.call(tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -159,8 +159,8 @@ public void testReproducible() { Shape shape = Shape.of(2, 2); Constant instance = new Constant<>(tf, 11.); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java index a68bf2a0a98..e9769806928 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java @@ -51,9 +51,9 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,8 +68,8 @@ public void testCallNormalDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -82,8 +82,8 @@ public void testCallUniformFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -97,8 +97,8 @@ public void testCallUniformDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -109,9 +109,9 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -122,9 +122,9 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -135,10 +135,10 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = + Glorot instance = new Glorot<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java index 468759d347f..8953fa3005e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java @@ -51,8 +51,8 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -66,8 +66,8 @@ public void testCallNormalDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -80,8 +80,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + He instance = new He<>(tf, Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -95,8 +95,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + He instance = new He<>(tf, Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -107,9 +107,9 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -120,9 +120,9 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + He instance = new He<>(tf, Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -133,9 +133,9 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + He instance = new He<>(tf, Distribution.NORMAL, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java index adb6c0c118a..6eee5473937 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java @@ -45,23 +45,6 @@ public void setUp() {} @AfterEach public void tearDown() {} - /** Test of call method, of class Orthogonal. */ - @Test - public void testCallInt() { - for (TestSession.Mode tfMode : tfModes) - assertThrows( - java.lang.IllegalArgumentException.class, - () -> { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - Shape shape = Shape.of(10, 10); - Identity instance = new Identity<>(tf, 2.); - instance.call(tf.constant(shape), TInt32.DTYPE); - fail("Should have thrown IllegalArgumentException on Integer type"); - } - }); - } - /** Test of call method, of class Constant. */ @Test public void testCallFloat() { @@ -82,7 +65,7 @@ public void testCallFloat() { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); Identity instance = new Identity<>(tf, 2.); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -108,7 +91,7 @@ public void testCallDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); Identity instance = new Identity<>(tf, 2.); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -121,8 +104,8 @@ public void testReproducible() { Shape shape = Shape.of(2, 2); Identity instance = new Identity<>(tf, 2.); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java index 6033f9e12a5..336850a5549 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java @@ -51,8 +51,8 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -66,8 +66,8 @@ public void testCallNormalDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -80,8 +80,8 @@ public void testCallUniformFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -95,8 +95,8 @@ public void testCallUniformDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -107,9 +107,9 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -120,9 +120,9 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -133,9 +133,9 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + LeCun instance = new LeCun<>(tf, Distribution.NORMAL, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java index bbd2ba3d384..053ba5dd7ff 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java @@ -52,7 +52,7 @@ public void testCallUInt() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TUint8.DTYPE); + Operand operand = instance.call(tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -66,7 +66,7 @@ public void testCallInt() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -80,7 +80,7 @@ public void testCallLong() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -94,7 +94,7 @@ public void testCallFloat() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -109,7 +109,7 @@ public void testCallDouble() { Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -126,7 +126,7 @@ public void testCallString() { Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - instance.call(tf.constant(shape), TString.DTYPE); + instance.call(tf.constant(shape), TString.class); fail("IllegalArgumentException should have been thrown for TString"); } }); @@ -141,7 +141,7 @@ public void testCallBool() { Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - Operand operand = instance.call(tf.constant(shape), TBool.DTYPE); + Operand operand = instance.call(tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -154,8 +154,8 @@ public void testReproducible() { Shape shape = Shape.of(2, 2); Ones instance = new Ones<>(tf); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java index a4fff5fd19c..22b89d9177c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java @@ -47,23 +47,6 @@ public void setUp() {} @AfterEach public void tearDown() {} - /** Test of call method, of class Orthogonal. */ - @Test - public void testCallInt() { - for (TestSession.Mode tfMode : tfModes) - assertThrows( - java.lang.IllegalArgumentException.class, - () -> { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - Shape shape = Shape.of(10, 10); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - instance.call(tf.constant(shape), TInt32.DTYPE); - fail("Should have thrown IllegalArgumentException on Integer type"); - } - }); - } - /** Test of call method, of class Orthogonal. */ @Test public void testCallFloat() { @@ -173,8 +156,8 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -288,8 +271,8 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -301,9 +284,9 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java index 50aec670503..3b2b3bdb243 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java @@ -52,9 +52,9 @@ public void testCalltestSoftmaxFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = + RandomNormal instance = new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,9 +68,9 @@ public void testCalltestSoftmaxDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = + RandomNormal instance = new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -82,10 +82,10 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = + RandomNormal instance = new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java index d3f9af74209..23e26083a9b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java @@ -53,9 +53,9 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = + RandomUniform instance = new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TInt32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -68,9 +68,9 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = + RandomUniform instance = new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -84,9 +84,9 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = + RandomUniform instance = new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -98,10 +98,10 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = + RandomUniform instance = new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java index 0a551df2f38..96bf915e199 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java @@ -52,9 +52,9 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = + TruncatedNormal instance = new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -68,9 +68,9 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = + TruncatedNormal instance = new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -82,10 +82,10 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = + TruncatedNormal instance = new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java index 77e0dd7afc7..159affb07e2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java @@ -50,14 +50,14 @@ public void testCallFloat1FanInTruncatedNormal() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -71,14 +71,14 @@ public void testCallDouble1FanInTruncatedNormal() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -91,14 +91,14 @@ public void testCallFloat1FanInNormal() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -112,14 +112,14 @@ public void testCalltestSoftmaxDouble1FanInNormal() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -132,10 +132,10 @@ public void testCalltestSoftmaxFloat1FanInUNIFORM() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -149,10 +149,10 @@ public void testCalltestSoftmaxDouble1FanInUNIFORM() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -164,11 +164,11 @@ public void testReproducible1() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -180,15 +180,15 @@ public void testReproducible2() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -200,15 +200,15 @@ public void testReproducible3() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_OUT, VarianceScaling.Distribution.TRUNCATED_NORMAL, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } @@ -220,11 +220,11 @@ public void testReproducible4() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = + VarianceScaling instance = new VarianceScaling<>( tf, 1.0, VarianceScaling.Mode.FAN_AVG, VarianceScaling.Distribution.UNIFORM, SEED); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java index 975678add19..21bad6ff360 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java @@ -49,7 +49,7 @@ public void testCallUInt() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TUint8.DTYPE); + Operand operand = instance.call(tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } } @@ -63,7 +63,7 @@ public void testCallInt() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } } @@ -77,7 +77,7 @@ public void testCallLong() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TInt64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } } @@ -91,7 +91,7 @@ public void testCallFloat() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat32.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } } @@ -106,7 +106,7 @@ public void testCallDouble() { Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } } @@ -120,7 +120,7 @@ public void testCallString() { Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TString.DTYPE); + Operand operand = instance.call(tf.constant(shape), TString.class); session.evaluateString(operand, String::isEmpty); } } @@ -135,7 +135,7 @@ public void testCallBool() { Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand = instance.call(tf.constant(shape), TBool.DTYPE); + Operand operand = instance.call(tf.constant(shape), TBool.class); session.evaluate(expected, operand); } } @@ -148,8 +148,8 @@ public void testReproducible() { Shape shape = Shape.of(2, 2); Zeros instance = new Zeros<>(tf); - Operand operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE); - Operand operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE); + Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); + Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java index 339157c99d1..86a3200ac81 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java @@ -89,8 +89,8 @@ public void testBasic() { float[] var1Init = {3.0F, 4.0F}; float[] fgrads = {grad, grad}; Shape shape = Shape.of(var0Init.length); - Variable var0 = tf.withName("var0").variable(shape, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java index ef9053ff1eb..d5b2657a4fc 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java @@ -64,8 +64,8 @@ public void testBasic() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java index 03717083efc..8182dc5b00d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java @@ -79,8 +79,8 @@ public void testBasic() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java index d2592026f12..49154882a0f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java @@ -16,7 +16,6 @@ import org.junit.jupiter.api.*; import org.tensorflow.Graph; -import org.tensorflow.Tensor; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -80,8 +79,8 @@ public void testBasic() { Ops tf = instance.getTF(); Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java index ac322f952db..60c17674dfe 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java @@ -16,7 +16,6 @@ import org.junit.jupiter.api.*; import org.tensorflow.Graph; -import org.tensorflow.Tensor; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -101,8 +100,8 @@ public void testBasic() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java index 597f8e52bcd..7698d76f957 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java @@ -76,8 +76,8 @@ public void testFtrlWithL1L2L2Shrinkage() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); @@ -141,8 +141,8 @@ public void testFtrlWithL1() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); @@ -207,8 +207,8 @@ public void testFtrlWithL1L2() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); @@ -273,8 +273,8 @@ public void doTestFtrlwithoutRegularization() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index 4362c54d815..aefcc537979 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -61,8 +61,8 @@ public void testBasic() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java index 3649fbd8287..80a8d9b5fd6 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java @@ -77,8 +77,8 @@ public void testBasic() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); @@ -130,8 +130,8 @@ public void testMomentum() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java index e064e793a37..849f2fbfec1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java @@ -16,7 +16,6 @@ import org.junit.jupiter.api.*; import org.tensorflow.Graph; -import org.tensorflow.Tensor; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -102,8 +101,8 @@ public void testBasic() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java index 202fb21ef68..3b002cd1dbe 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java @@ -87,8 +87,8 @@ public void testDense() { Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java index 1f5f2f16053..7884308c9fb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java @@ -84,8 +84,8 @@ public EagerSession getEagerSession() { /** {@inheritDoc} */ @Override public void evaluate(double expected, Operand input) { - DataType dtype = input.asOutput().dataType(); - if (dtype == TFloat32.DTYPE) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -95,7 +95,7 @@ public void evaluate(double expected, Operand input) { } index.set(0); o.asTensor().scalars().forEach(f -> assertEquals(expected, f.getFloat(), epsilon)); - } else if (dtype == TFloat64.DTYPE) { + } else if (inputType == TFloat64.class) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -105,7 +105,7 @@ public void evaluate(double expected, Operand input) { } index.set(0); o.asTensor().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); - } else if (dtype == TInt32.DTYPE) { + } else if (inputType == TInt32.class) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -115,7 +115,7 @@ public void evaluate(double expected, Operand input) { } index.set(0); o.asTensor().scalars().forEach(f -> assertEquals((int) expected, f.getInt())); - } else if (dtype == TInt64.DTYPE) { + } else if (inputType == TInt64.class) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -125,7 +125,7 @@ public void evaluate(double expected, Operand input) { } index.set(0); o.asTensor().scalars().forEach(f -> assertEquals((long) expected, f.getLong())); - } else if (dtype == TUint8.DTYPE) { + } else if (inputType == TUint8.class) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -146,8 +146,8 @@ public void evaluate(Number[] expected, Output input) { expected.length, size, () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - DataType dtype = input.dataType(); - if (dtype == TFloat32.DTYPE) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -162,7 +162,7 @@ public void evaluate(Number[] expected, Output input) { f -> assertEquals( expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); - } else if (dtype == TFloat64.DTYPE) { + } else if (inputType == TFloat64.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -177,7 +177,7 @@ public void evaluate(Number[] expected, Output input) { f -> assertEquals( expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); - } else if (dtype == TInt32.DTYPE) { + } else if (inputType == TInt32.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -189,7 +189,7 @@ public void evaluate(Number[] expected, Output input) { o.asTensor() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); - } else if (dtype == TInt64.DTYPE) { + } else if (inputType == TInt64.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -201,7 +201,7 @@ public void evaluate(Number[] expected, Output input) { o.asTensor() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); - } else if (dtype == TUint8.DTYPE) { + } else if (inputType == TUint8.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -219,8 +219,8 @@ public void evaluate(Number[] expected, Output input) { /** {@inheritDoc} */ @Override public void evaluate(FloatNdArray expected, Output input) { - DataType dtype = input.dataType(); - if (dtype == TFloat32.DTYPE) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { Output o = (Output) input; AtomicLong index = new AtomicLong(); if (debug) { @@ -233,7 +233,7 @@ public void evaluate(FloatNdArray expected, Output input) { .scalars() .forEach( f -> assertEquals(expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); - } else if (dtype == TFloat64.DTYPE) { + } else if (inputType == TFloat64.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -247,7 +247,7 @@ public void evaluate(FloatNdArray expected, Output input) { .forEach( f -> assertEquals(expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); - } else if (dtype == TInt32.DTYPE) { + } else if (inputType == TInt32.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -259,7 +259,7 @@ public void evaluate(FloatNdArray expected, Output input) { for (IntNdArray f : o.asTensor().scalars()) { assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt()); } - } else if (dtype == TInt64.DTYPE) { + } else if (inputType == TInt64.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -272,7 +272,7 @@ public void evaluate(FloatNdArray expected, Output input) { .scalars() .forEach( f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); - } else if (dtype == TUint8.DTYPE) { + } else if (inputType == TUint8.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { @@ -320,9 +320,9 @@ public void evaluateString(Output input, Predicate predicate) { @Override public void evaluate(Output input, Predicate predicate) { AtomicInteger index = new AtomicInteger(); - DataType dtype = input.asOutput().dataType(); + Class inputType = input.type(); boolean isScalar = input.shape().equals(Shape.scalar()); - if (dtype == TFloat32.DTYPE) { + if (inputType == TFloat32.class) { Output o = (Output) input; if (debug) { if (isScalar) { @@ -346,7 +346,7 @@ public void evaluate(Output input, Predicate predic .scalars() .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getFloat()))); } - } else if (dtype == TFloat64.DTYPE) { + } else if (inputType == TFloat64.class) { Output o = (Output) input; if (debug) { if (isScalar) { @@ -370,7 +370,7 @@ public void evaluate(Output input, Predicate predic .scalars() .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getDouble()))); } - } else if (dtype == TFloat16.DTYPE) { + } else if (inputType == TFloat16.class) { Output o = (Output) input; if (debug) { if (isScalar) { @@ -394,7 +394,7 @@ public void evaluate(Output input, Predicate predic .scalars() .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getFloat()))); } - } else if (dtype == TInt32.DTYPE) { + } else if (inputType == TInt32.class) { Output o = (Output) input; if (debug) { if (isScalar) { @@ -418,7 +418,7 @@ public void evaluate(Output input, Predicate predic .scalars() .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getInt()))); } - } else if (dtype == TInt64.DTYPE) { + } else if (inputType == TInt64.class) { Output o = (Output) input; if (debug) { if (isScalar) { @@ -442,7 +442,7 @@ public void evaluate(Output input, Predicate predic .scalars() .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getLong()))); } - } else if (dtype == TUint8.DTYPE) { + } else if (inputType == TUint8.class) { Output o = (Output) input; if (debug) { if (isScalar) { @@ -467,7 +467,7 @@ public void evaluate(Output input, Predicate predic .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getByte()))); } } else { - fail("Unexpected DataType: " + dtype); + fail("Unexpected Class: " + inputType); } } @@ -522,9 +522,9 @@ public void evaluate(Output expected, Output input) { : String.format( "expected shape (%s) != to input shape (%s)", expected.shape().toString(), input.shape().toString()); - DataType dtype = input.asOutput().dataType(); + Class inputType = input.asOutput().type(); boolean isScalar = input.shape().equals(Shape.scalar()); - if (dtype == TFloat32.DTYPE) { + if (inputType == TFloat32.class) { Output x = (Output) expected; Output o = (Output) input; AtomicInteger index = new AtomicInteger(); @@ -550,7 +550,7 @@ public void evaluate(Output expected, Output input) { .forEachIndexed( (idx, f) -> assertEquals(x.asTensor().getFloat(idx), f.getFloat(), epsilon)); } - } else if (dtype == TFloat64.DTYPE) { + } else if (inputType == TFloat64.class) { Output x = (Output) expected; Output o = (Output) input; AtomicInteger index = new AtomicInteger(); @@ -576,7 +576,7 @@ public void evaluate(Output expected, Output input) { .forEachIndexed( (idx, f) -> assertEquals(x.asTensor().getDouble(idx), f.getDouble(), epsilon)); } - } else if (dtype == TInt32.DTYPE) { + } else if (inputType == TInt32.class) { Output x = (Output) expected; Output o = (Output) input; AtomicInteger index = new AtomicInteger(); @@ -601,7 +601,7 @@ public void evaluate(Output expected, Output input) { .scalars() .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getInt(idx), f.getInt())); } - } else if (dtype == TInt64.DTYPE) { + } else if (inputType == TInt64.class) { Output x = (Output) expected; Output o = (Output) input; AtomicInteger index = new AtomicInteger(); @@ -626,7 +626,7 @@ public void evaluate(Output expected, Output input) { .scalars() .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getLong(idx), f.getLong())); } - } else if (dtype == TUint8.DTYPE) { + } else if (inputType == TUint8.class) { Output x = (Output) expected; Output o = (Output) input; AtomicInteger index = new AtomicInteger(); @@ -651,7 +651,7 @@ public void evaluate(Output expected, Output input) { .scalars() .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getByte(idx), f.getByte())); } - } else if (dtype == TString.DTYPE) { + } else if (inputType == TString.class) { Output x = (Output) expected; Output o = (Output) input; AtomicInteger index = new AtomicInteger(); @@ -676,7 +676,7 @@ public void evaluate(Output expected, Output input) { .scalars() .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getObject(idx), f.getObject())); } - } else if (dtype == TBool.DTYPE) { + } else if (inputType == TBool.class) { Output x = (Output) expected; Output o = (Output) input; AtomicInteger index = new AtomicInteger(); @@ -707,51 +707,51 @@ public void evaluate(Output expected, Output input) { /** {@inheritDoc} */ @Override public void print(PrintWriter writer, Output input) { - DataType dtype = input.asOutput().dataType(); - if (dtype == TFloat32.DTYPE) { + Class inputType = input.asOutput().type(); + if (inputType == TFloat32.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } else if (dtype == TFloat64.DTYPE) { + } else if (inputType == TFloat64.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } else if (dtype == TInt32.DTYPE) { + } else if (inputType == TInt32.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } else if (dtype == TInt64.DTYPE) { + } else if (inputType == TInt64.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } else if (dtype == TUint8.DTYPE) { + } else if (inputType == TUint8.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); - } else if (dtype == TString.DTYPE) { + } else if (inputType == TString.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); - } else if (dtype == TBool.DTYPE) { + } else if (inputType == TBool.class) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); } else { - writer.println("Unexpected DataType: " + dtype); + writer.println("Unexpected Class: " + inputType); } writer.flush(); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java index 6df79a4432a..33c4e064e69 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java @@ -126,8 +126,8 @@ public void run(Op op) { */ @Override public void evaluate(double expected, Operand input) { - DataType dtype = input.asOutput().dataType(); - if (dtype == TFloat32.DTYPE) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TFloat32 result = @@ -142,7 +142,7 @@ public void evaluate(double expected, Operand input) { (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { result.scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon)); } - } else if (dtype == TFloat64.DTYPE) { + } else if (inputType == TFloat64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TFloat64 result = @@ -157,7 +157,7 @@ public void evaluate(double expected, Operand input) { (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { result.scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); } - } else if (dtype == TInt32.DTYPE) { + } else if (inputType == TInt32.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TInt32 result = @@ -172,7 +172,7 @@ public void evaluate(double expected, Operand input) { (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { result.scalars().forEach(f -> assertEquals((int) expected, f.getInt())); } - } else if (dtype == TInt64.DTYPE) { + } else if (inputType == TInt64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TInt64 result = @@ -187,7 +187,7 @@ public void evaluate(double expected, Operand input) { (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { result.scalars().forEach(f -> assertEquals((long) expected, f.getLong())); } - } else if (dtype == TUint8.DTYPE) { + } else if (inputType == TUint8.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TUint8 result = @@ -203,7 +203,7 @@ public void evaluate(double expected, Operand input) { result.scalars().forEach(f -> assertEquals((long) expected, f.getByte())); } } else { - fail("Unexpected DataType: " + dtype); + fail("Unexpected type class: " + inputType); } } @@ -217,8 +217,8 @@ public void evaluate(Number[] expected, Output input) { expected.length, size, () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - DataType dtype = input.asOutput().dataType(); - if (dtype == TFloat32.DTYPE) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TFloat32 result = @@ -238,7 +238,7 @@ public void evaluate(Number[] expected, Output input) { assertEquals( expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); } - } else if (dtype == TFloat64.DTYPE) { + } else if (inputType == TFloat64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TFloat64 result = @@ -258,7 +258,7 @@ public void evaluate(Number[] expected, Output input) { assertEquals( expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); } - } else if (dtype == TInt32.DTYPE) { + } else if (inputType == TInt32.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TInt32 result = @@ -275,7 +275,7 @@ public void evaluate(Number[] expected, Output input) { .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); } - } else if (dtype == TInt64.DTYPE) { + } else if (inputType == TInt64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TInt64 result = @@ -292,7 +292,7 @@ public void evaluate(Number[] expected, Output input) { .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); } - } else if (dtype == TUint8.DTYPE) { + } else if (inputType == TUint8.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TUint8 result = @@ -310,7 +310,7 @@ public void evaluate(Number[] expected, Output input) { .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getByte())); } } else { - fail("Unexpected DataType: " + dtype); + fail("Unexpected type class: " + inputType); } } @@ -319,8 +319,8 @@ public void evaluate(Number[] expected, Output input) { */ @Override public void evaluate(FloatNdArray expected, Output input) { - DataType dtype = input.asOutput().dataType(); - if (dtype == TFloat32.DTYPE) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { AtomicLong index = new AtomicLong(); if (debug) { try (TFloat32 result = @@ -340,7 +340,7 @@ public void evaluate(FloatNdArray expected, Output input) { assertEquals( expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); } - } else if (dtype == TFloat64.DTYPE) { + } else if (inputType == TFloat64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TFloat64 result = @@ -360,7 +360,7 @@ public void evaluate(FloatNdArray expected, Output input) { assertEquals( expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); } - } else if (dtype == TInt32.DTYPE) { + } else if (inputType == TInt32.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TInt32 result = @@ -378,7 +378,7 @@ public void evaluate(FloatNdArray expected, Output input) { .forEach( f -> assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt())); } - } else if (dtype == TInt64.DTYPE) { + } else if (inputType == TInt64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TInt64 result = @@ -396,7 +396,7 @@ public void evaluate(FloatNdArray expected, Output input) { .forEach( f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); } - } else if (dtype == TUint8.DTYPE) { + } else if (inputType == TUint8.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TUint8 result = @@ -415,7 +415,7 @@ public void evaluate(FloatNdArray expected, Output input) { f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); } } else { - fail("Unexpected DataType: " + dtype); + fail("Unexpected type class: " + inputType); } } @@ -485,15 +485,15 @@ public void evaluate(Output expected, Output input) { "expected shape (%s) != to input shape (%s)", expected.shape().toString(), input.shape().toString()); AtomicInteger index = new AtomicInteger(); - DataType dtype = input.asOutput().dataType(); - if (!dtype.equals(expected.dataType())) { + Class inputType = input.type(); + if (!inputType.equals(expected.type())) { throw new IllegalArgumentException( String.format( "Both data type must be equal, inout = %s, expected = %s", - dtype, expected.dataType())); + inputType, expected.dataType())); } boolean isScalar = input.shape().equals(Shape.scalar()); - if (dtype == TFloat32.DTYPE) { + if (inputType == TFloat32.class) { final Output finalExpected = (Output) expected; if (debug) { try (TFloat32 result = @@ -531,7 +531,7 @@ public void evaluate(Output expected, Output input) { assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); } } - } else if (dtype == TFloat64.DTYPE) { + } else if (inputType == TFloat64.class) { final Output finalExpected = (Output) expected; if (debug) { try (TFloat64 result = @@ -569,7 +569,7 @@ public void evaluate(Output expected, Output input) { assertEquals(expectedResult.getDouble(idx), f.getDouble(), epsilon)); } } - } else if (dtype == TFloat16.DTYPE) { + } else if (inputType == TFloat16.class) { final Output finalExpected = (Output) expected; if (debug) { try (TFloat16 result = @@ -607,7 +607,7 @@ public void evaluate(Output expected, Output input) { assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); } } - } else if (dtype == TInt32.DTYPE) { + } else if (inputType == TInt32.class) { final Output finalExpected = (Output) expected; if (debug) { try (TInt32 result = @@ -642,7 +642,7 @@ public void evaluate(Output expected, Output input) { (idx, f) -> assertEquals(expectedResult.getInt(idx), f.getInt(), epsilon)); } } - } else if (dtype == TInt64.DTYPE) { + } else if (inputType == TInt64.class) { final Output finalExpected = (Output) expected; if (debug) { try (TInt64 result = @@ -680,7 +680,7 @@ public void evaluate(Output expected, Output input) { assertEquals(expectedResult.getLong(idx), f.getLong(), epsilon)); } } - } else if (dtype == TUint8.DTYPE) { + } else if (inputType == TUint8.class) { final Output finalExpected = (Output) expected; if (debug) { try (TUint8 result = @@ -718,7 +718,7 @@ public void evaluate(Output expected, Output input) { assertEquals(expectedResult.getByte(idx), f.getByte(), epsilon)); } } - } else if (dtype == TBool.DTYPE) { + } else if (inputType == TBool.class) { final Output finalExpected = (Output) expected; if (debug) { try (TBool result = @@ -755,7 +755,7 @@ public void evaluate(Output expected, Output input) { (idx, f) -> assertEquals(expectedResult.getBoolean(idx), f.getBoolean())); } } - } else if (dtype == TString.DTYPE) { + } else if (inputType == TString.class) { final Output finalExpected = (Output) expected; if (debug) { try (TString result = @@ -793,7 +793,7 @@ public void evaluate(Output expected, Output input) { } } } else { - fail("Unexpected DataType: " + dtype); + fail("Unexpected type class: " + inputType); } } @@ -841,9 +841,9 @@ public void evaluateString(Output input, Predicate predicate) { @Override public void evaluate(Output input, Predicate predicate) { AtomicInteger index = new AtomicInteger(); - DataType dtype = input.asOutput().dataType(); + Class inputType = input.type(); boolean isScalar = input.shape().equals(Shape.scalar()); - if (dtype == TFloat32.DTYPE) { + if (inputType == TFloat32.class) { if (debug) { try (TFloat32 result = (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { @@ -873,7 +873,7 @@ public void evaluate(Output input, Predicate predic .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getFloat()))); } } - } else if (dtype == TFloat64.DTYPE) { + } else if (inputType == TFloat64.class) { if (debug) { try (TFloat64 result = (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { @@ -903,7 +903,7 @@ public void evaluate(Output input, Predicate predic .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getDouble()))); } } - } else if (dtype == TInt32.DTYPE) { + } else if (inputType == TInt32.class) { if (debug) { try (TInt32 result = (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { @@ -932,7 +932,7 @@ public void evaluate(Output input, Predicate predic .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getInt()))); } } - } else if (dtype == TInt64.DTYPE) { + } else if (inputType == TInt64.class) { if (debug) { try (TInt64 result = (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { @@ -962,7 +962,7 @@ public void evaluate(Output input, Predicate predic .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getLong()))); } } - } else if (dtype == TUint8.DTYPE) { + } else if (inputType == TUint8.class) { if (debug) { try (TUint8 result = (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { @@ -993,7 +993,7 @@ public void evaluate(Output input, Predicate predic } } } else { - fail("Unexpected DataType: " + dtype); + fail("Unexpected type class: " + inputType); } } @@ -1004,8 +1004,8 @@ public void evaluate(Output input, Predicate predic public void print(PrintWriter writer, Output input) { boolean isScalar = input.shape().size() == 1; - DataType dtype = input.dataType(); - if (dtype == TFloat32.DTYPE) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { AtomicInteger index = new AtomicInteger(); try (TFloat32 result = (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { @@ -1018,7 +1018,7 @@ public void print(PrintWriter writer, Output input) { (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } } - } else if (dtype == TFloat64.DTYPE) { + } else if (inputType == TFloat64.class) { AtomicInteger index = new AtomicInteger(); try (TFloat64 result = @@ -1033,7 +1033,7 @@ public void print(PrintWriter writer, Output input) { (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } } - } else if (dtype == TInt32.DTYPE) { + } else if (inputType == TInt32.class) { AtomicInteger index = new AtomicInteger(); try (TInt32 result = @@ -1048,7 +1048,7 @@ public void print(PrintWriter writer, Output input) { (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } } - } else if (dtype == TInt64.DTYPE) { + } else if (inputType == TInt64.class) { AtomicInteger index = new AtomicInteger(); try (TInt64 result = @@ -1063,7 +1063,7 @@ public void print(PrintWriter writer, Output input) { (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } } - } else if (dtype == TUint8.DTYPE) { + } else if (inputType == TUint8.class) { AtomicInteger index = new AtomicInteger(); try (TUint8 result = @@ -1078,7 +1078,7 @@ public void print(PrintWriter writer, Output input) { (idx, f) -> writer.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); } } - } else if (dtype == TBool.DTYPE) { + } else if (inputType == TBool.class) { AtomicInteger index = new AtomicInteger(); try (TBool result = @@ -1093,7 +1093,7 @@ public void print(PrintWriter writer, Output input) { (idx, f) -> writer.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); } } - } else if (dtype == TString.DTYPE) { + } else if (inputType == TString.class) { AtomicInteger index = new AtomicInteger(); try (TString result = @@ -1109,7 +1109,7 @@ public void print(PrintWriter writer, Output input) { } } } else { - writer.println("Unexpected DataType: " + dtype); + writer.println("Unexpected type class: " + inputType); } writer.flush(); } diff --git a/tensorflow-framework/tensorflow-data.md b/tensorflow-framework/tensorflow-data.md index 99cd3321788..df0ec190b9f 100644 --- a/tensorflow-framework/tensorflow-data.md +++ b/tensorflow-framework/tensorflow-data.md @@ -39,8 +39,7 @@ FloatNdArray labels = NdArrays.vectorOf(0, 1, 1, 0); ``` A dataset can be constructed from a list of the constant `Operand`s generated -from this dataset, and a list of `DataType` objects corresponding -to the type of each component: +from this dataset, and a list of classes corresponding to the tensor type of each component: Note: Each of the input components must share the same first "batch" dimension. @@ -48,7 +47,7 @@ Note: Each of the input components must share the same first "batch" dimension. Ops tf = // ... TensorFlow Ops accessor (either graph or eager) Dataset dataset = Dataset.fromTensorSlices( Arrays.asList(tf.constant(features), tf.constant(labels)), - Arrays.asList(TInt32.DTYPE, TInt32.DTYPE) + Arrays.asList(TInt32.class, TInt32.class) ); ``` @@ -79,9 +78,6 @@ The primary use of a dataset is for iteration over its elements. Each row (or batch) element is represented as a list of tensor components, with type `List>`. The tensor components of this element can be accessed using `List.get(int index)`. -It is recommended to use `Tensor.expect(DataType dtype)` to restore types -to the retrieved tensors. - #### Using DatastetIterator The `DatasetIterator` class provides abstractions for creating and using iterators in graph and eager mode. These will be explained here; however @@ -89,7 +85,7 @@ end-users should only interact with `DatasetIterator` objects through the method provided in the `Dataset` class (examples to follow). To construct an iterator for a dataset of a specific structure, use -the static method `DatasetIterator.fromStructure(Ops tf, List> outputTypes, List outputShapes)`. This creates a `DatasetIterator` object +the static method `DatasetIterator.fromStructure(Ops tf, List> outputTypes, List outputShapes)`. This creates a `DatasetIterator` object which can be used with any dataset of a matching structure. Once a `DatasetIterator` is created, it can be initialized on a `Dataset` intsance using `DatasetIterator.makeInitializer(Dataset dataset)`. This will initialize (or re-initialize) the iterator to start at the beginning