Skip to content

Commit f85623e

Browse files
authored
Use standard Java type system for typing tensors
* Leverage the Java type system for typing tensors * Cleanup some obsolete imports and comments * Restrict tensor types on some initializers * Document a few exception cases and other cleanups
1 parent 3f3f384 commit f85623e

File tree

659 files changed

+2594
-3727
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

659 files changed

+2594
-3727
lines changed

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/java_defs.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,6 @@ class Type {
9797
static Type IterableOf(const Type& type) {
9898
return Interface("Iterable").add_parameter(type);
9999
}
100-
static Type DataTypeOf(const Type& type) {
101-
return Class("DataType", "org.tensorflow").add_parameter(type);
102-
}
103100
static Type ForDataType(DataType data_type) {
104101
switch (data_type) {
105102
case DataType::DT_BOOL:

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -103,39 +103,55 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
103103
}
104104
for (const AttributeSpec& attribute : op.attributes()) {
105105
out->push_back(attribute.var().type());
106-
out->push_back(attribute.jni_type());
106+
if (attribute.jni_type().name() == "DataType") {
107+
out->push_back(Type::Class("Operands", "org.tensorflow.op"));
108+
} else {
109+
out->push_back(attribute.jni_type());
110+
}
107111
if (attribute.has_default_value() &&
108112
attribute.type().kind() == Type::GENERIC) {
109113
out->push_back(Type::ForDataType(attribute.default_value()->type()));
110114
}
111115
}
112116
for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
117+
if (optional_attribute.jni_type().name() == "DataType") {
118+
out->push_back(Type::Class("Operands", "org.tensorflow.op"));
119+
} else {
120+
out->push_back(optional_attribute.jni_type());
121+
}
113122
out->push_back(optional_attribute.var().type());
114123
}
115124
}
116125

117126
void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
118127
SourceWriter* writer) {
119128
string var_name = optional ? "opts." + attr.var().name() : attr.var().name();
120-
if (attr.iterable()) {
121-
string array_name = attr.var().name() + "Array";
122-
writer->AppendType(attr.jni_type())
123-
.Append("[] " + array_name + " = new ")
124-
.AppendType(attr.jni_type())
125-
.Append("[" + var_name + ".size()];")
126-
.EndLine()
127-
.BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
128-
.Append(array_name + "[i] = ");
129-
writer->Append(var_name + ".get(i);");
130-
writer->EndLine()
131-
.EndBlock()
132-
.Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
133-
.Append(array_name + ");")
134-
.EndLine();
135-
} else {
129+
if (attr.jni_type().name() == "DataType") {
136130
writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
137-
.Append(var_name + ");")
138-
.EndLine();
131+
.Append(attr.iterable() ? "Operands.toDataTypes(" : "Operands.toDataType(")
132+
.Append(attr.var().name() + "));")
133+
.EndLine();
134+
} else {
135+
if (attr.iterable()) {
136+
string array_name = attr.var().name() + "Array";
137+
writer->AppendType(attr.jni_type())
138+
.Append("[] " + array_name + " = new ")
139+
.AppendType(attr.jni_type())
140+
.Append("[" + var_name + ".size()];")
141+
.EndLine()
142+
.BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
143+
.Append(array_name + "[i] = ");
144+
writer->Append(var_name + ".get(i);");
145+
writer->EndLine()
146+
.EndBlock()
147+
.Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
148+
.Append(array_name + ");")
149+
.EndLine();
150+
} else {
151+
writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
152+
.Append(var_name + ");")
153+
.EndLine();
154+
}
139155
}
140156
}
141157

@@ -177,7 +193,7 @@ void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
177193
if (attr.type().kind() == Type::GENERIC &&
178194
default_types.find(attr.type().name()) != default_types.end()) {
179195
factory_statement << default_types.at(attr.type().name()).name()
180-
<< ".DTYPE";
196+
<< ".class";
181197
} else {
182198
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
183199
factory_statement << attr.var().name();

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,19 @@ class TypeResolver {
8181
std::pair<Type, Type> MakeTypePair(const Type& type) {
8282
return std::make_pair(type, type);
8383
}
84-
Type NextGeneric() {
84+
Type NextGeneric(const OpDef_AttrDef& attr_def) {
8585
char generic_letter = next_generic_letter_++;
8686
if (next_generic_letter_ > 'Z') {
8787
next_generic_letter_ = 'A';
8888
}
89-
return Type::Generic(string(1, generic_letter))
90-
.add_supertype(Type::Class("TType", "org.tensorflow.types.family"));
89+
return Type::Generic(string(1, generic_letter));
90+
}
91+
Type TypeFamilyOf(const OpDef_AttrDef& attr_def) {
92+
// TODO(karllessard) support more type families
93+
if (IsRealNumbers(attr_def.allowed_values())) {
94+
return Type::Interface("TNumber", "org.tensorflow.types.family");
95+
}
96+
return Type::Interface("TType", "org.tensorflow.types.family");
9197
}
9298
};
9399

@@ -155,11 +161,9 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
155161
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow"));
156162

157163
} else if (attr_type == "type") {
158-
Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
159-
if (IsRealNumbers(attr_def.allowed_values())) {
160-
type.add_supertype(Type::Class("TNumber", "org.tensorflow.types.family"));
161-
}
162-
types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow"));
164+
Type type = *iterable_out ? Type::Wildcard() : NextGeneric(attr_def);
165+
type.add_supertype(TypeFamilyOf(attr_def));
166+
types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow.proto.framework"));
163167

164168
} else {
165169
LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
@@ -305,7 +309,7 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
305309
bool iterable = false;
306310
std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
307311
Type var_type = types.first.kind() == Type::GENERIC
308-
? Type::DataTypeOf(types.first)
312+
? Type::ClassOf(types.first)
309313
: types.first;
310314
if (iterable) {
311315
var_type = Type::ListOf(var_type);

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.cc

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) {
8585
SourceWriter& SourceWriter::AppendType(const Type& type) {
8686
if (type.wildcard()) {
8787
Append("?");
88+
WriteTypeBounds(type.supertypes());
8889
} else {
8990
Append(type.name());
9091
if (!type.parameters().empty()) {
@@ -321,14 +322,27 @@ SourceWriter& SourceWriter::WriteGenerics(
321322
Append(", ");
322323
}
323324
Append(pt->name());
324-
if (!pt->supertypes().empty()) {
325-
Append(" extends ").AppendType(pt->supertypes().front());
326-
}
325+
WriteTypeBounds(pt->supertypes());
327326
first = false;
328327
}
329328
return Append(">");
330329
}
331330

331+
SourceWriter& SourceWriter::WriteTypeBounds(
332+
const std::list<Type>& bounds) {
333+
bool first = true;
334+
for (const Type& bound : bounds) {
335+
if (first) {
336+
Append(" extends ");
337+
first = false;
338+
} else {
339+
Append(" & ");
340+
}
341+
AppendType(bound);
342+
}
343+
return *this;
344+
}
345+
332346
SourceWriter::GenericNamespace* SourceWriter::PushGenericNamespace(
333347
int modifiers) {
334348
GenericNamespace* generic_namespace;

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ class SourceWriter {
213213
SourceWriter& WriteJavadoc(const Javadoc& javadoc);
214214
SourceWriter& WriteAnnotations(const std::list<Annotation>& annotations);
215215
SourceWriter& WriteGenerics(const std::list<const Type*>& generics);
216+
SourceWriter& WriteTypeBounds(const std::list<Type>& bounds);
216217
GenericNamespace* PushGenericNamespace(int modifiers);
217218
void PopGenericNamespace();
218219
};

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataExperimentalOps.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
package org.tensorflow.op;
1919

2020
import java.util.List;
21-
import org.tensorflow.DataType;
2221
import org.tensorflow.Operand;
2322
import org.tensorflow.ndarray.Shape;
2423
import org.tensorflow.op.data.experimental.DataServiceDataset;
2524
import org.tensorflow.types.TInt64;
2625
import org.tensorflow.types.TString;
26+
import org.tensorflow.types.family.TType;
2727

2828
/**
2929
* An API for building {@code data.experimental} operations as {@link Op Op}s
@@ -57,7 +57,7 @@ public final class DataExperimentalOps {
5757
public DataServiceDataset dataServiceDataset(Operand<TInt64> datasetId,
5858
Operand<TString> processingMode, Operand<TString> address, Operand<TString> protocol,
5959
Operand<TString> jobName, Operand<TInt64> maxOutstandingRequests, Operand<?> iterationCounter,
60-
List<DataType<?>> outputTypes, List<Shape> outputShapes,
60+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes,
6161
DataServiceDataset.Options... options) {
6262
return DataServiceDataset.create(scope, datasetId, processingMode, address, protocol, jobName, maxOutstandingRequests, iterationCounter, outputTypes, outputShapes, options);
6363
}

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.tensorflow.op;
1919

2020
import java.util.List;
21-
import org.tensorflow.DataType;
2221
import org.tensorflow.Operand;
2322
import org.tensorflow.ndarray.Shape;
2423
import org.tensorflow.op.data.AnonymousIterator;
@@ -49,6 +48,7 @@
4948
import org.tensorflow.types.TBool;
5049
import org.tensorflow.types.TInt64;
5150
import org.tensorflow.types.TString;
51+
import org.tensorflow.types.family.TType;
5252

5353
/**
5454
* An API for building {@code data} operations as {@link Op Op}s
@@ -75,7 +75,7 @@ public final class DataOps {
7575
* @param outputShapes
7676
* @return a new instance of AnonymousIterator
7777
*/
78-
public AnonymousIterator anonymousIterator(List<DataType<?>> outputTypes,
78+
public AnonymousIterator anonymousIterator(List<Class<? extends TType>> outputTypes,
7979
List<Shape> outputShapes) {
8080
return AnonymousIterator.create(scope, outputTypes, outputShapes);
8181
}
@@ -93,8 +93,8 @@ public AnonymousIterator anonymousIterator(List<DataType<?>> outputTypes,
9393
* @return a new instance of BatchDataset
9494
*/
9595
public BatchDataset batchDataset(Operand<?> inputDataset, Operand<TInt64> batchSize,
96-
Operand<TBool> dropRemainder, List<DataType<?>> outputTypes, List<Shape> outputShapes,
97-
BatchDataset.Options... options) {
96+
Operand<TBool> dropRemainder, List<Class<? extends TType>> outputTypes,
97+
List<Shape> outputShapes, BatchDataset.Options... options) {
9898
return BatchDataset.create(scope, inputDataset, batchSize, dropRemainder, outputTypes, outputShapes, options);
9999
}
100100

@@ -129,7 +129,7 @@ public CSVDataset cSVDataset(Operand<TString> filenames, Operand<TString> compre
129129
* @return a new instance of ConcatenateDataset
130130
*/
131131
public ConcatenateDataset concatenateDataset(Operand<?> inputDataset, Operand<?> anotherDataset,
132-
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
132+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
133133
return ConcatenateDataset.create(scope, inputDataset, anotherDataset, outputTypes, outputShapes);
134134
}
135135

@@ -164,8 +164,8 @@ public DeserializeIterator deserializeIterator(Operand<?> resourceHandle, Operan
164164
* @param outputShapes
165165
* @return a new instance of Iterator
166166
*/
167-
public Iterator iterator(String sharedName, String container, List<DataType<?>> outputTypes,
168-
List<Shape> outputShapes) {
167+
public Iterator iterator(String sharedName, String container,
168+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
169169
return Iterator.create(scope, sharedName, container, outputTypes, outputShapes);
170170
}
171171

@@ -177,8 +177,8 @@ public Iterator iterator(String sharedName, String container, List<DataType<?>>
177177
* @param outputShapes
178178
* @return a new instance of IteratorGetNext
179179
*/
180-
public IteratorGetNext iteratorGetNext(Operand<?> iterator, List<DataType<?>> outputTypes,
181-
List<Shape> outputShapes) {
180+
public IteratorGetNext iteratorGetNext(Operand<?> iterator,
181+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
182182
return IteratorGetNext.create(scope, iterator, outputTypes, outputShapes);
183183
}
184184

@@ -191,7 +191,7 @@ public IteratorGetNext iteratorGetNext(Operand<?> iterator, List<DataType<?>> ou
191191
* @return a new instance of IteratorGetNextAsOptional
192192
*/
193193
public IteratorGetNextAsOptional iteratorGetNextAsOptional(Operand<?> iterator,
194-
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
194+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
195195
return IteratorGetNextAsOptional.create(scope, iterator, outputTypes, outputShapes);
196196
}
197197

@@ -208,8 +208,8 @@ public IteratorGetNextAsOptional iteratorGetNextAsOptional(Operand<?> iterator,
208208
* @param outputShapes
209209
* @return a new instance of IteratorGetNextSync
210210
*/
211-
public IteratorGetNextSync iteratorGetNextSync(Operand<?> iterator, List<DataType<?>> outputTypes,
212-
List<Shape> outputShapes) {
211+
public IteratorGetNextSync iteratorGetNextSync(Operand<?> iterator,
212+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
213213
return IteratorGetNextSync.create(scope, iterator, outputTypes, outputShapes);
214214
}
215215

@@ -255,8 +255,8 @@ public OptionalFromValue optionalFromValue(Iterable<Operand<?>> components) {
255255
* @param outputShapes
256256
* @return a new instance of OptionalGetValue
257257
*/
258-
public OptionalGetValue optionalGetValue(Operand<?> optional, List<DataType<?>> outputTypes,
259-
List<Shape> outputShapes) {
258+
public OptionalGetValue optionalGetValue(Operand<?> optional,
259+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
260260
return OptionalGetValue.create(scope, optional, outputTypes, outputShapes);
261261
}
262262

@@ -290,7 +290,7 @@ public OptionalNone optionalNone() {
290290
* @return a new instance of RangeDataset
291291
*/
292292
public RangeDataset rangeDataset(Operand<TInt64> start, Operand<TInt64> stop,
293-
Operand<TInt64> step, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
293+
Operand<TInt64> step, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
294294
return RangeDataset.create(scope, start, stop, step, outputTypes, outputShapes);
295295
}
296296

@@ -305,7 +305,7 @@ public RangeDataset rangeDataset(Operand<TInt64> start, Operand<TInt64> stop,
305305
* @return a new instance of RepeatDataset
306306
*/
307307
public RepeatDataset repeatDataset(Operand<?> inputDataset, Operand<TInt64> count,
308-
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
308+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
309309
return RepeatDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
310310
}
311311

@@ -332,7 +332,7 @@ public SerializeIterator serializeIterator(Operand<?> resourceHandle,
332332
* @return a new instance of SkipDataset
333333
*/
334334
public SkipDataset skipDataset(Operand<?> inputDataset, Operand<TInt64> count,
335-
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
335+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
336336
return SkipDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
337337
}
338338

@@ -348,7 +348,7 @@ public SkipDataset skipDataset(Operand<?> inputDataset, Operand<TInt64> count,
348348
* @return a new instance of TakeDataset
349349
*/
350350
public TakeDataset takeDataset(Operand<?> inputDataset, Operand<TInt64> count,
351-
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
351+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
352352
return TakeDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
353353
}
354354

@@ -409,8 +409,8 @@ public TfRecordDataset tfRecordDataset(Operand<TString> filenames,
409409
* @param outputShapes
410410
* @return a new instance of ZipDataset
411411
*/
412-
public ZipDataset zipDataset(Iterable<Operand<?>> inputDatasets, List<DataType<?>> outputTypes,
413-
List<Shape> outputShapes) {
412+
public ZipDataset zipDataset(Iterable<Operand<?>> inputDatasets,
413+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
414414
return ZipDataset.create(scope, inputDatasets, outputTypes, outputShapes);
415415
}
416416

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DtypesOps.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
//
1818
package org.tensorflow.op;
1919

20-
import org.tensorflow.DataType;
2120
import org.tensorflow.Operand;
2221
import org.tensorflow.op.dtypes.AsString;
2322
import org.tensorflow.op.dtypes.Cast;
@@ -73,7 +72,7 @@ public <T extends TType> AsString asString(Operand<T> input, AsString.Options...
7372
* @param options carries optional attributes values
7473
* @return a new instance of Cast
7574
*/
76-
public <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, DataType<U> DstT,
75+
public <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, Class<U> DstT,
7776
Cast.Options... options) {
7877
return Cast.create(scope, x, DstT, options);
7978
}
@@ -102,7 +101,7 @@ public <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, DataType<U>
102101
* @return a new instance of Complex
103102
*/
104103
public <U extends TType, T extends TNumber> Complex<U> complex(Operand<T> real, Operand<T> imag,
105-
DataType<U> Tout) {
104+
Class<U> Tout) {
106105
return Complex.create(scope, real, imag, Tout);
107106
}
108107

0 commit comments

Comments
 (0)