Skip to content

Commit ba3d471

Browse files
committed
Leverage the Java type system for typing tensors
1 parent 3f3f384 commit ba3d471

File tree

646 files changed

+2646
-3586
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

646 files changed

+2646
-3586
lines changed

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/java_defs.h

-3
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,6 @@ class Type {
9797
static Type IterableOf(const Type& type) {
9898
return Interface("Iterable").add_parameter(type);
9999
}
100-
static Type DataTypeOf(const Type& type) {
101-
return Class("DataType", "org.tensorflow").add_parameter(type);
102-
}
103100
static Type ForDataType(DataType data_type) {
104101
switch (data_type) {
105102
case DataType::DT_BOOL:

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc

+32-19
Original file line numberDiff line numberDiff line change
@@ -104,38 +104,51 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
104104
for (const AttributeSpec& attribute : op.attributes()) {
105105
out->push_back(attribute.var().type());
106106
out->push_back(attribute.jni_type());
107+
if (attribute.jni_type().name() == "DataType") {
108+
out->push_back(Type::Class("Operands", "org.tensorflow.op"));
109+
}
107110
if (attribute.has_default_value() &&
108111
attribute.type().kind() == Type::GENERIC) {
109112
out->push_back(Type::ForDataType(attribute.default_value()->type()));
110113
}
111114
}
112115
for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
116+
if (optional_attribute.jni_type().name() == "DataType") {
117+
out->push_back(Type::Class("Operands", "org.tensorflow.op"));
118+
}
113119
out->push_back(optional_attribute.var().type());
114120
}
115121
}
116122

117123
void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
118124
SourceWriter* writer) {
119125
string var_name = optional ? "opts." + attr.var().name() : attr.var().name();
120-
if (attr.iterable()) {
121-
string array_name = attr.var().name() + "Array";
122-
writer->AppendType(attr.jni_type())
123-
.Append("[] " + array_name + " = new ")
124-
.AppendType(attr.jni_type())
125-
.Append("[" + var_name + ".size()];")
126-
.EndLine()
127-
.BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
128-
.Append(array_name + "[i] = ");
129-
writer->Append(var_name + ".get(i);");
130-
writer->EndLine()
131-
.EndBlock()
132-
.Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
133-
.Append(array_name + ");")
134-
.EndLine();
135-
} else {
126+
if (attr.jni_type().name() == "DataType") {
136127
writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
137-
.Append(var_name + ");")
138-
.EndLine();
128+
.Append(attr.iterable() ? "Operands.toDataTypes(" : "Operands.toDataType(")
129+
.Append(attr.var().name() + "));")
130+
.EndLine();
131+
} else {
132+
if (attr.iterable()) {
133+
string array_name = attr.var().name() + "Array";
134+
writer->AppendType(attr.jni_type())
135+
.Append("[] " + array_name + " = new ")
136+
.AppendType(attr.jni_type())
137+
.Append("[" + var_name + ".size()];")
138+
.EndLine()
139+
.BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
140+
.Append(array_name + "[i] = ");
141+
writer->Append(var_name + ".get(i);");
142+
writer->EndLine()
143+
.EndBlock()
144+
.Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
145+
.Append(array_name + ");")
146+
.EndLine();
147+
} else {
148+
writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
149+
.Append(var_name + ");")
150+
.EndLine();
151+
}
139152
}
140153
}
141154

@@ -177,7 +190,7 @@ void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
177190
if (attr.type().kind() == Type::GENERIC &&
178191
default_types.find(attr.type().name()) != default_types.end()) {
179192
factory_statement << default_types.at(attr.type().name()).name()
180-
<< ".DTYPE";
193+
<< ".class";
181194
} else {
182195
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
183196
factory_statement << attr.var().name();

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc

+13-9
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,19 @@ class TypeResolver {
8181
std::pair<Type, Type> MakeTypePair(const Type& type) {
8282
return std::make_pair(type, type);
8383
}
84-
Type NextGeneric() {
84+
Type NextGeneric(const OpDef_AttrDef& attr_def) {
8585
char generic_letter = next_generic_letter_++;
8686
if (next_generic_letter_ > 'Z') {
8787
next_generic_letter_ = 'A';
8888
}
89-
return Type::Generic(string(1, generic_letter))
90-
.add_supertype(Type::Class("TType", "org.tensorflow.types.family"));
89+
return Type::Generic(string(1, generic_letter));
90+
}
91+
Type TypeFamilyOf(const OpDef_AttrDef& attr_def) {
92+
// TODO(karllessard) support more type families
93+
if (IsRealNumbers(attr_def.allowed_values())) {
94+
return Type::Interface("TNumber", "org.tensorflow.types.family");
95+
}
96+
return Type::Interface("TType", "org.tensorflow.types.family");
9197
}
9298
};
9399

@@ -155,11 +161,9 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
155161
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow"));
156162

157163
} else if (attr_type == "type") {
158-
Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
159-
if (IsRealNumbers(attr_def.allowed_values())) {
160-
type.add_supertype(Type::Class("TNumber", "org.tensorflow.types.family"));
161-
}
162-
types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow"));
164+
Type type = *iterable_out ? Type::Wildcard() : NextGeneric(attr_def);
165+
type.add_supertype(TypeFamilyOf(attr_def));
166+
types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow.proto.framework"));
163167

164168
} else {
165169
LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
@@ -305,7 +309,7 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
305309
bool iterable = false;
306310
std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
307311
Type var_type = types.first.kind() == Type::GENERIC
308-
? Type::DataTypeOf(types.first)
312+
? Type::ClassOf(types.first)
309313
: types.first;
310314
if (iterable) {
311315
var_type = Type::ListOf(var_type);

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.cc

+17-3
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) {
8585
SourceWriter& SourceWriter::AppendType(const Type& type) {
8686
if (type.wildcard()) {
8787
Append("?");
88+
WriteTypeBounds(type.supertypes());
8889
} else {
8990
Append(type.name());
9091
if (!type.parameters().empty()) {
@@ -321,14 +322,27 @@ SourceWriter& SourceWriter::WriteGenerics(
321322
Append(", ");
322323
}
323324
Append(pt->name());
324-
if (!pt->supertypes().empty()) {
325-
Append(" extends ").AppendType(pt->supertypes().front());
326-
}
325+
WriteTypeBounds(pt->supertypes());
327326
first = false;
328327
}
329328
return Append(">");
330329
}
331330

331+
SourceWriter& SourceWriter::WriteTypeBounds(
332+
const std::list<Type>& bounds) {
333+
bool first = true;
334+
for (const Type& bound : bounds) {
335+
if (first) {
336+
Append(" extends ");
337+
first = false;
338+
} else {
339+
Append(" & ");
340+
}
341+
AppendType(bound);
342+
}
343+
return *this;
344+
}
345+
332346
SourceWriter::GenericNamespace* SourceWriter::PushGenericNamespace(
333347
int modifiers) {
334348
GenericNamespace* generic_namespace;

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/source_writer.h

+1
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ class SourceWriter {
213213
SourceWriter& WriteJavadoc(const Javadoc& javadoc);
214214
SourceWriter& WriteAnnotations(const std::list<Annotation>& annotations);
215215
SourceWriter& WriteGenerics(const std::list<const Type*>& generics);
216+
SourceWriter& WriteTypeBounds(const std::list<Type>& bounds);
216217
GenericNamespace* PushGenericNamespace(int modifiers);
217218
void PopGenericNamespace();
218219
};

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataExperimentalOps.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
package org.tensorflow.op;
1919

2020
import java.util.List;
21-
import org.tensorflow.DataType;
2221
import org.tensorflow.Operand;
2322
import org.tensorflow.ndarray.Shape;
2423
import org.tensorflow.op.data.experimental.DataServiceDataset;
2524
import org.tensorflow.types.TInt64;
2625
import org.tensorflow.types.TString;
26+
import org.tensorflow.types.family.TType;
2727

2828
/**
2929
* An API for building {@code data.experimental} operations as {@link Op Op}s
@@ -57,7 +57,7 @@ public final class DataExperimentalOps {
5757
public DataServiceDataset dataServiceDataset(Operand<TInt64> datasetId,
5858
Operand<TString> processingMode, Operand<TString> address, Operand<TString> protocol,
5959
Operand<TString> jobName, Operand<TInt64> maxOutstandingRequests, Operand<?> iterationCounter,
60-
List<DataType<?>> outputTypes, List<Shape> outputShapes,
60+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes,
6161
DataServiceDataset.Options... options) {
6262
return DataServiceDataset.create(scope, datasetId, processingMode, address, protocol, jobName, maxOutstandingRequests, iterationCounter, outputTypes, outputShapes, options);
6363
}

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java

+20-20
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.tensorflow.op;
1919

2020
import java.util.List;
21-
import org.tensorflow.DataType;
2221
import org.tensorflow.Operand;
2322
import org.tensorflow.ndarray.Shape;
2423
import org.tensorflow.op.data.AnonymousIterator;
@@ -49,6 +48,7 @@
4948
import org.tensorflow.types.TBool;
5049
import org.tensorflow.types.TInt64;
5150
import org.tensorflow.types.TString;
51+
import org.tensorflow.types.family.TType;
5252

5353
/**
5454
* An API for building {@code data} operations as {@link Op Op}s
@@ -75,7 +75,7 @@ public final class DataOps {
7575
* @param outputShapes
7676
* @return a new instance of AnonymousIterator
7777
*/
78-
public AnonymousIterator anonymousIterator(List<DataType<?>> outputTypes,
78+
public AnonymousIterator anonymousIterator(List<Class<? extends TType>> outputTypes,
7979
List<Shape> outputShapes) {
8080
return AnonymousIterator.create(scope, outputTypes, outputShapes);
8181
}
@@ -93,8 +93,8 @@ public AnonymousIterator anonymousIterator(List<DataType<?>> outputTypes,
9393
* @return a new instance of BatchDataset
9494
*/
9595
public BatchDataset batchDataset(Operand<?> inputDataset, Operand<TInt64> batchSize,
96-
Operand<TBool> dropRemainder, List<DataType<?>> outputTypes, List<Shape> outputShapes,
97-
BatchDataset.Options... options) {
96+
Operand<TBool> dropRemainder, List<Class<? extends TType>> outputTypes,
97+
List<Shape> outputShapes, BatchDataset.Options... options) {
9898
return BatchDataset.create(scope, inputDataset, batchSize, dropRemainder, outputTypes, outputShapes, options);
9999
}
100100

@@ -129,7 +129,7 @@ public CSVDataset cSVDataset(Operand<TString> filenames, Operand<TString> compre
129129
* @return a new instance of ConcatenateDataset
130130
*/
131131
public ConcatenateDataset concatenateDataset(Operand<?> inputDataset, Operand<?> anotherDataset,
132-
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
132+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
133133
return ConcatenateDataset.create(scope, inputDataset, anotherDataset, outputTypes, outputShapes);
134134
}
135135

@@ -164,8 +164,8 @@ public DeserializeIterator deserializeIterator(Operand<?> resourceHandle, Operan
164164
* @param outputShapes
165165
* @return a new instance of Iterator
166166
*/
167-
public Iterator iterator(String sharedName, String container, List<DataType<?>> outputTypes,
168-
List<Shape> outputShapes) {
167+
public Iterator iterator(String sharedName, String container,
168+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
169169
return Iterator.create(scope, sharedName, container, outputTypes, outputShapes);
170170
}
171171

@@ -177,8 +177,8 @@ public Iterator iterator(String sharedName, String container, List<DataType<?>>
177177
* @param outputShapes
178178
* @return a new instance of IteratorGetNext
179179
*/
180-
public IteratorGetNext iteratorGetNext(Operand<?> iterator, List<DataType<?>> outputTypes,
181-
List<Shape> outputShapes) {
180+
public IteratorGetNext iteratorGetNext(Operand<?> iterator,
181+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
182182
return IteratorGetNext.create(scope, iterator, outputTypes, outputShapes);
183183
}
184184

@@ -191,7 +191,7 @@ public IteratorGetNext iteratorGetNext(Operand<?> iterator, List<DataType<?>> ou
191191
* @return a new instance of IteratorGetNextAsOptional
192192
*/
193193
public IteratorGetNextAsOptional iteratorGetNextAsOptional(Operand<?> iterator,
194-
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
194+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
195195
return IteratorGetNextAsOptional.create(scope, iterator, outputTypes, outputShapes);
196196
}
197197

@@ -208,8 +208,8 @@ public IteratorGetNextAsOptional iteratorGetNextAsOptional(Operand<?> iterator,
208208
* @param outputShapes
209209
* @return a new instance of IteratorGetNextSync
210210
*/
211-
public IteratorGetNextSync iteratorGetNextSync(Operand<?> iterator, List<DataType<?>> outputTypes,
212-
List<Shape> outputShapes) {
211+
public IteratorGetNextSync iteratorGetNextSync(Operand<?> iterator,
212+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
213213
return IteratorGetNextSync.create(scope, iterator, outputTypes, outputShapes);
214214
}
215215

@@ -255,8 +255,8 @@ public OptionalFromValue optionalFromValue(Iterable<Operand<?>> components) {
255255
* @param outputShapes
256256
* @return a new instance of OptionalGetValue
257257
*/
258-
public OptionalGetValue optionalGetValue(Operand<?> optional, List<DataType<?>> outputTypes,
259-
List<Shape> outputShapes) {
258+
public OptionalGetValue optionalGetValue(Operand<?> optional,
259+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
260260
return OptionalGetValue.create(scope, optional, outputTypes, outputShapes);
261261
}
262262

@@ -290,7 +290,7 @@ public OptionalNone optionalNone() {
290290
* @return a new instance of RangeDataset
291291
*/
292292
public RangeDataset rangeDataset(Operand<TInt64> start, Operand<TInt64> stop,
293-
Operand<TInt64> step, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
293+
Operand<TInt64> step, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
294294
return RangeDataset.create(scope, start, stop, step, outputTypes, outputShapes);
295295
}
296296

@@ -305,7 +305,7 @@ public RangeDataset rangeDataset(Operand<TInt64> start, Operand<TInt64> stop,
305305
* @return a new instance of RepeatDataset
306306
*/
307307
public RepeatDataset repeatDataset(Operand<?> inputDataset, Operand<TInt64> count,
308-
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
308+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
309309
return RepeatDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
310310
}
311311

@@ -332,7 +332,7 @@ public SerializeIterator serializeIterator(Operand<?> resourceHandle,
332332
* @return a new instance of SkipDataset
333333
*/
334334
public SkipDataset skipDataset(Operand<?> inputDataset, Operand<TInt64> count,
335-
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
335+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
336336
return SkipDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
337337
}
338338

@@ -348,7 +348,7 @@ public SkipDataset skipDataset(Operand<?> inputDataset, Operand<TInt64> count,
348348
* @return a new instance of TakeDataset
349349
*/
350350
public TakeDataset takeDataset(Operand<?> inputDataset, Operand<TInt64> count,
351-
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
351+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
352352
return TakeDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
353353
}
354354

@@ -409,8 +409,8 @@ public TfRecordDataset tfRecordDataset(Operand<TString> filenames,
409409
* @param outputShapes
410410
* @return a new instance of ZipDataset
411411
*/
412-
public ZipDataset zipDataset(Iterable<Operand<?>> inputDatasets, List<DataType<?>> outputTypes,
413-
List<Shape> outputShapes) {
412+
public ZipDataset zipDataset(Iterable<Operand<?>> inputDatasets,
413+
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
414414
return ZipDataset.create(scope, inputDatasets, outputTypes, outputShapes);
415415
}
416416

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DtypesOps.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
//
1818
package org.tensorflow.op;
1919

20-
import org.tensorflow.DataType;
2120
import org.tensorflow.Operand;
2221
import org.tensorflow.op.dtypes.AsString;
2322
import org.tensorflow.op.dtypes.Cast;
@@ -73,7 +72,7 @@ public <T extends TType> AsString asString(Operand<T> input, AsString.Options...
7372
* @param options carries optional attributes values
7473
* @return a new instance of Cast
7574
*/
76-
public <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, DataType<U> DstT,
75+
public <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, Class<U> DstT,
7776
Cast.Options... options) {
7877
return Cast.create(scope, x, DstT, options);
7978
}
@@ -102,7 +101,7 @@ public <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, DataType<U>
102101
* @return a new instance of Complex
103102
*/
104103
public <U extends TType, T extends TNumber> Complex<U> complex(Operand<T> real, Operand<T> imag,
105-
DataType<U> Tout) {
104+
Class<U> Tout) {
106105
return Complex.create(scope, real, imag, Tout);
107106
}
108107

0 commit comments

Comments
 (0)