Skip to content

[Type Refactor] Use Java type system instead of custom one for typing tensors #174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 30, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -97,9 +97,6 @@ class Type {
static Type IterableOf(const Type& type) {
return Interface("Iterable").add_parameter(type);
}
static Type DataTypeOf(const Type& type) {
return Class("DataType", "org.tensorflow").add_parameter(type);
}
static Type ForDataType(DataType data_type) {
switch (data_type) {
case DataType::DT_BOOL:
Original file line number Diff line number Diff line change
@@ -103,39 +103,55 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
}
for (const AttributeSpec& attribute : op.attributes()) {
out->push_back(attribute.var().type());
out->push_back(attribute.jni_type());
if (attribute.jni_type().name() == "DataType") {
out->push_back(Type::Class("Operands", "org.tensorflow.op"));
} else {
out->push_back(attribute.jni_type());
}
if (attribute.has_default_value() &&
attribute.type().kind() == Type::GENERIC) {
out->push_back(Type::ForDataType(attribute.default_value()->type()));
}
}
for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
if (optional_attribute.jni_type().name() == "DataType") {
out->push_back(Type::Class("Operands", "org.tensorflow.op"));
} else {
out->push_back(optional_attribute.jni_type());
}
out->push_back(optional_attribute.var().type());
}
}

void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
SourceWriter* writer) {
string var_name = optional ? "opts." + attr.var().name() : attr.var().name();
if (attr.iterable()) {
string array_name = attr.var().name() + "Array";
writer->AppendType(attr.jni_type())
.Append("[] " + array_name + " = new ")
.AppendType(attr.jni_type())
.Append("[" + var_name + ".size()];")
.EndLine()
.BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
.Append(array_name + "[i] = ");
writer->Append(var_name + ".get(i);");
writer->EndLine()
.EndBlock()
.Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
.Append(array_name + ");")
.EndLine();
} else {
if (attr.jni_type().name() == "DataType") {
writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
.Append(var_name + ");")
.EndLine();
.Append(attr.iterable() ? "Operands.toDataTypes(" : "Operands.toDataType(")
.Append(attr.var().name() + "));")
.EndLine();
} else {
if (attr.iterable()) {
string array_name = attr.var().name() + "Array";
writer->AppendType(attr.jni_type())
.Append("[] " + array_name + " = new ")
.AppendType(attr.jni_type())
.Append("[" + var_name + ".size()];")
.EndLine()
.BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
.Append(array_name + "[i] = ");
writer->Append(var_name + ".get(i);");
writer->EndLine()
.EndBlock()
.Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
.Append(array_name + ");")
.EndLine();
} else {
writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
.Append(var_name + ");")
.EndLine();
}
}
}

@@ -177,7 +193,7 @@ void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
if (attr.type().kind() == Type::GENERIC &&
default_types.find(attr.type().name()) != default_types.end()) {
factory_statement << default_types.at(attr.type().name()).name()
<< ".DTYPE";
<< ".class";
} else {
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
factory_statement << attr.var().name();
Original file line number Diff line number Diff line change
@@ -81,13 +81,19 @@ class TypeResolver {
std::pair<Type, Type> MakeTypePair(const Type& type) {
return std::make_pair(type, type);
}
Type NextGeneric() {
Type NextGeneric(const OpDef_AttrDef& attr_def) {
char generic_letter = next_generic_letter_++;
if (next_generic_letter_ > 'Z') {
next_generic_letter_ = 'A';
}
return Type::Generic(string(1, generic_letter))
.add_supertype(Type::Class("TType", "org.tensorflow.types.family"));
return Type::Generic(string(1, generic_letter));
}
Type TypeFamilyOf(const OpDef_AttrDef& attr_def) {
// TODO(karllessard) support more type families
if (IsRealNumbers(attr_def.allowed_values())) {
return Type::Interface("TNumber", "org.tensorflow.types.family");
}
return Type::Interface("TType", "org.tensorflow.types.family");
}
};

@@ -155,11 +161,9 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow"));

} else if (attr_type == "type") {
Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
if (IsRealNumbers(attr_def.allowed_values())) {
type.add_supertype(Type::Class("TNumber", "org.tensorflow.types.family"));
}
types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow"));
Type type = *iterable_out ? Type::Wildcard() : NextGeneric(attr_def);
type.add_supertype(TypeFamilyOf(attr_def));
types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow.proto.framework"));

} else {
LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
@@ -305,7 +309,7 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
bool iterable = false;
std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
Type var_type = types.first.kind() == Type::GENERIC
? Type::DataTypeOf(types.first)
? Type::ClassOf(types.first)
: types.first;
if (iterable) {
var_type = Type::ListOf(var_type);
Original file line number Diff line number Diff line change
@@ -85,6 +85,7 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) {
SourceWriter& SourceWriter::AppendType(const Type& type) {
if (type.wildcard()) {
Append("?");
WriteTypeBounds(type.supertypes());
} else {
Append(type.name());
if (!type.parameters().empty()) {
@@ -321,14 +322,27 @@ SourceWriter& SourceWriter::WriteGenerics(
Append(", ");
}
Append(pt->name());
if (!pt->supertypes().empty()) {
Append(" extends ").AppendType(pt->supertypes().front());
}
WriteTypeBounds(pt->supertypes());
first = false;
}
return Append(">");
}

SourceWriter& SourceWriter::WriteTypeBounds(
const std::list<Type>& bounds) {
bool first = true;
for (const Type& bound : bounds) {
if (first) {
Append(" extends ");
first = false;
} else {
Append(" & ");
}
AppendType(bound);
}
return *this;
}

SourceWriter::GenericNamespace* SourceWriter::PushGenericNamespace(
int modifiers) {
GenericNamespace* generic_namespace;
Original file line number Diff line number Diff line change
@@ -213,6 +213,7 @@ class SourceWriter {
SourceWriter& WriteJavadoc(const Javadoc& javadoc);
SourceWriter& WriteAnnotations(const std::list<Annotation>& annotations);
SourceWriter& WriteGenerics(const std::list<const Type*>& generics);
SourceWriter& WriteTypeBounds(const std::list<Type>& bounds);
GenericNamespace* PushGenericNamespace(int modifiers);
void PopGenericNamespace();
};
Original file line number Diff line number Diff line change
@@ -18,12 +18,12 @@
package org.tensorflow.op;

import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.data.experimental.DataServiceDataset;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;

/**
* An API for building {@code data.experimental} operations as {@link Op Op}s
@@ -57,7 +57,7 @@ public final class DataExperimentalOps {
public DataServiceDataset dataServiceDataset(Operand<TInt64> datasetId,
Operand<TString> processingMode, Operand<TString> address, Operand<TString> protocol,
Operand<TString> jobName, Operand<TInt64> maxOutstandingRequests, Operand<?> iterationCounter,
List<DataType<?>> outputTypes, List<Shape> outputShapes,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes,
DataServiceDataset.Options... options) {
return DataServiceDataset.create(scope, datasetId, processingMode, address, protocol, jobName, maxOutstandingRequests, iterationCounter, outputTypes, outputShapes, options);
}
Original file line number Diff line number Diff line change
@@ -18,7 +18,6 @@
package org.tensorflow.op;

import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.data.AnonymousIterator;
@@ -49,6 +48,7 @@
import org.tensorflow.types.TBool;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;

/**
* An API for building {@code data} operations as {@link Op Op}s
@@ -75,7 +75,7 @@ public final class DataOps {
* @param outputShapes
* @return a new instance of AnonymousIterator
*/
public AnonymousIterator anonymousIterator(List<DataType<?>> outputTypes,
public AnonymousIterator anonymousIterator(List<Class<? extends TType>> outputTypes,
List<Shape> outputShapes) {
return AnonymousIterator.create(scope, outputTypes, outputShapes);
}
@@ -93,8 +93,8 @@ public AnonymousIterator anonymousIterator(List<DataType<?>> outputTypes,
* @return a new instance of BatchDataset
*/
public BatchDataset batchDataset(Operand<?> inputDataset, Operand<TInt64> batchSize,
Operand<TBool> dropRemainder, List<DataType<?>> outputTypes, List<Shape> outputShapes,
BatchDataset.Options... options) {
Operand<TBool> dropRemainder, List<Class<? extends TType>> outputTypes,
List<Shape> outputShapes, BatchDataset.Options... options) {
return BatchDataset.create(scope, inputDataset, batchSize, dropRemainder, outputTypes, outputShapes, options);
}

@@ -129,7 +129,7 @@ public CSVDataset cSVDataset(Operand<TString> filenames, Operand<TString> compre
* @return a new instance of ConcatenateDataset
*/
public ConcatenateDataset concatenateDataset(Operand<?> inputDataset, Operand<?> anotherDataset,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return ConcatenateDataset.create(scope, inputDataset, anotherDataset, outputTypes, outputShapes);
}

@@ -164,8 +164,8 @@ public DeserializeIterator deserializeIterator(Operand<?> resourceHandle, Operan
* @param outputShapes
* @return a new instance of Iterator
*/
public Iterator iterator(String sharedName, String container, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public Iterator iterator(String sharedName, String container,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return Iterator.create(scope, sharedName, container, outputTypes, outputShapes);
}

@@ -177,8 +177,8 @@ public Iterator iterator(String sharedName, String container, List<DataType<?>>
* @param outputShapes
* @return a new instance of IteratorGetNext
*/
public IteratorGetNext iteratorGetNext(Operand<?> iterator, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public IteratorGetNext iteratorGetNext(Operand<?> iterator,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return IteratorGetNext.create(scope, iterator, outputTypes, outputShapes);
}

@@ -191,7 +191,7 @@ public IteratorGetNext iteratorGetNext(Operand<?> iterator, List<DataType<?>> ou
* @return a new instance of IteratorGetNextAsOptional
*/
public IteratorGetNextAsOptional iteratorGetNextAsOptional(Operand<?> iterator,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return IteratorGetNextAsOptional.create(scope, iterator, outputTypes, outputShapes);
}

@@ -208,8 +208,8 @@ public IteratorGetNextAsOptional iteratorGetNextAsOptional(Operand<?> iterator,
* @param outputShapes
* @return a new instance of IteratorGetNextSync
*/
public IteratorGetNextSync iteratorGetNextSync(Operand<?> iterator, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public IteratorGetNextSync iteratorGetNextSync(Operand<?> iterator,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return IteratorGetNextSync.create(scope, iterator, outputTypes, outputShapes);
}

@@ -255,8 +255,8 @@ public OptionalFromValue optionalFromValue(Iterable<Operand<?>> components) {
* @param outputShapes
* @return a new instance of OptionalGetValue
*/
public OptionalGetValue optionalGetValue(Operand<?> optional, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public OptionalGetValue optionalGetValue(Operand<?> optional,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return OptionalGetValue.create(scope, optional, outputTypes, outputShapes);
}

@@ -290,7 +290,7 @@ public OptionalNone optionalNone() {
* @return a new instance of RangeDataset
*/
public RangeDataset rangeDataset(Operand<TInt64> start, Operand<TInt64> stop,
Operand<TInt64> step, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
Operand<TInt64> step, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return RangeDataset.create(scope, start, stop, step, outputTypes, outputShapes);
}

@@ -305,7 +305,7 @@ public RangeDataset rangeDataset(Operand<TInt64> start, Operand<TInt64> stop,
* @return a new instance of RepeatDataset
*/
public RepeatDataset repeatDataset(Operand<?> inputDataset, Operand<TInt64> count,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return RepeatDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
}

@@ -332,7 +332,7 @@ public SerializeIterator serializeIterator(Operand<?> resourceHandle,
* @return a new instance of SkipDataset
*/
public SkipDataset skipDataset(Operand<?> inputDataset, Operand<TInt64> count,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return SkipDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
}

@@ -348,7 +348,7 @@ public SkipDataset skipDataset(Operand<?> inputDataset, Operand<TInt64> count,
* @return a new instance of TakeDataset
*/
public TakeDataset takeDataset(Operand<?> inputDataset, Operand<TInt64> count,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return TakeDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
}

@@ -409,8 +409,8 @@ public TfRecordDataset tfRecordDataset(Operand<TString> filenames,
* @param outputShapes
* @return a new instance of ZipDataset
*/
public ZipDataset zipDataset(Iterable<Operand<?>> inputDatasets, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public ZipDataset zipDataset(Iterable<Operand<?>> inputDatasets,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return ZipDataset.create(scope, inputDatasets, outputTypes, outputShapes);
}

Original file line number Diff line number Diff line change
@@ -17,7 +17,6 @@
//
package org.tensorflow.op;

import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.op.dtypes.AsString;
import org.tensorflow.op.dtypes.Cast;
@@ -73,7 +72,7 @@ public <T extends TType> AsString asString(Operand<T> input, AsString.Options...
* @param options carries optional attributes values
* @return a new instance of Cast
*/
public <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, DataType<U> DstT,
public <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, Class<U> DstT,
Cast.Options... options) {
return Cast.create(scope, x, DstT, options);
}
@@ -102,7 +101,7 @@ public <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, DataType<U>
* @return a new instance of Complex
*/
public <U extends TType, T extends TNumber> Complex<U> complex(Operand<T> real, Operand<T> imag,
DataType<U> Tout) {
Class<U> Tout) {
return Complex.create(scope, real, imag, Tout);
}

Loading