Skip to content

Implements both Tensor and NdArray interfaces from same instance #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
8 changes: 4 additions & 4 deletions ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public static ByteNdArray vectorOf(byte... values) {
if (values == null) {
throw new IllegalArgumentException("Values cannot be null");
}
return wrap(DataBuffers.of(values, false, false), Shape.of(values.length));
return wrap(Shape.of(values.length), DataBuffers.of(values, false, false));
}

/**
Expand All @@ -81,19 +81,19 @@ public static ByteNdArray ofBytes(Shape shape) {
if (shape == null) {
throw new IllegalArgumentException("Shape cannot be null");
}
return wrap(DataBuffers.ofBytes(shape.size()), shape);
return wrap(shape, DataBuffers.ofBytes(shape.size()));
}

/**
* Wraps a buffer in a byte N-dimensional array of a given shape.
*
* @param buffer buffer to wrap
* @param shape shape of the array
* @param buffer buffer to wrap
* @return new byte N-dimensional array
* @throws IllegalArgumentException if shape is null, has unknown dimensions or has size bigger
* in the buffer size
*/
public static ByteNdArray wrap(ByteDataBuffer buffer, Shape shape) {
public static ByteNdArray wrap(Shape shape, ByteDataBuffer buffer) {
return ByteDenseNdArray.create(buffer, shape);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,12 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
writer->EndLine();
for (const ArgumentSpec& input : op.inputs()) {
if (input.iterable()) {
writer->Append("opBuilder.addInputList(Operands.asOutputs(" +
writer->Append("opBuilder.addInputList(Operands.asOutputs(scope, " +
input.var().name() + "));");
writer->EndLine();
} else {
writer->Append("opBuilder.addInput(" + input.var().name() +
".asOutput());");
".asOutput(scope));");
writer->EndLine();
}
}
Expand Down Expand Up @@ -348,8 +348,9 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
if (mode == OPERAND) {
bool cast2obj = output.type().wildcard();
Type return_type = Type::Class("Output", "org.tensorflow")
.add_parameter(cast2obj ? Type::Class("TType", "org.tensorflow.types.family") : output.type());
.add_parameter(cast2obj ? Type::Interface("TType", "org.tensorflow.types.family") : output.type());
Method as_output = Method::Create("asOutput", return_type)
.add_argument(Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")))
.add_annotation(Annotation::Create("Override"));
if (cast2obj) {
as_output.add_annotation(
Expand All @@ -366,7 +367,7 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
} else if (mode == LIST_OPERAND) {
Type operand = Type::Interface("Operand", "org.tensorflow");
if (output.type().wildcard()) {
operand.add_parameter(Type::Class("TType", "org.tensorflow.types.family"));
operand.add_parameter(Type::Interface("TType", "org.tensorflow.types.family"));
} else {
operand.add_parameter(output.type());
}
Expand Down Expand Up @@ -430,10 +431,9 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
RenderMode mode = DEFAULT;
if (op.outputs().size() == 1) {
const ArgumentSpec& output = op.outputs().front();
Type operand_type(output.type().wildcard() ? Type::Class("TType", "org.tensorflow.types.family")
Type operand_type(output.type().wildcard() ? Type::Interface("TType", "org.tensorflow.types.family")
: output.type());
Type operand_inf(Type::Interface("Operand", "org.tensorflow")
.add_parameter(operand_type));
Type operand_inf(Type::Interface("Operand", "org.tensorflow").add_parameter(operand_type));
if (output.iterable()) {
mode = LIST_OPERAND;
op_class.add_supertype(Type::IterableOf(operand_inf));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ class TypeResolver {
std::pair<Type, Type> TypesOf(const OpDef_AttrDef& attr_def,
bool* iterable_out);

// Returns the highest type family this attribute is part of
//
// For example, if the attribute is of type 'bool', the base 'TType' family
// returned. But if it represents a number, like a float or an integer,
// then 'TNumber' (which supersedes 'TType') is returned.
Type FamilyOf(const OpDef_AttrDef& attr_def);

// Returns true if the type of this attribute has already been resolved
bool IsAttributeVisited(const string& attr_name) {
return visited_attrs_.find(attr_name) != visited_attrs_.cend();
Expand All @@ -81,13 +88,12 @@ class TypeResolver {
std::pair<Type, Type> MakeTypePair(const Type& type) {
return std::make_pair(type, type);
}
Type NextGeneric() {
Type NextGeneric(const Type& typeFamily) {
char generic_letter = next_generic_letter_++;
if (next_generic_letter_ > 'Z') {
next_generic_letter_ = 'A';
}
return Type::Generic(string(1, generic_letter))
.add_supertype(Type::Class("TType", "org.tensorflow.types.family"));
return Type::Generic(string(1, generic_letter)).add_supertype(typeFamily);
}
};

Expand Down Expand Up @@ -156,10 +162,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
.add_parameter(Type::Wildcard()));

} else if (attr_type == "type") {
Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
if (IsRealNumbers(attr_def.allowed_values())) {
type.add_supertype(Type::Class("TNumber", "org.tensorflow.types.family"));
}
Type type = *iterable_out ? Type::Wildcard() : NextGeneric(FamilyOf(attr_def));
types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow"));

} else {
Expand All @@ -170,6 +173,14 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
return types;
}

Type TypeResolver::FamilyOf(const OpDef_AttrDef& attr_def) {
// TODO (karlllessard): add more type families
if (IsRealNumbers(attr_def.allowed_values())) {
return Type::Interface("TNumber", "org.tensorflow.types.family");
}
return Type::Interface("TType", "org.tensorflow.types.family");
}

string SnakeToCamelCase(const string& str, bool upper = false) {
string result;
bool cap = upper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,15 @@ SourceWriter& SourceWriter::WriteGenerics(
Append(", ");
}
Append(pt->name());
if (!pt->supertypes().empty()) {
Append(" extends ").AppendType(pt->supertypes().front());
bool first_bound = true;
for (const Type& bound : pt->supertypes()) {
if (first_bound) {
Append(" extends ");
first_bound = false;
} else {
Append(" & ");
}
AppendType(bound);
}
first = false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.tensorflow.EagerSession;
import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.BooleanNdArray;
import org.tensorflow.ndarray.ByteNdArray;
import org.tensorflow.ndarray.DoubleNdArray;
Expand Down Expand Up @@ -1052,6 +1051,17 @@ public <T extends TNumber> Bucketize bucketize(Operand<T> input, List<Float> bou
return Bucketize.create(scope, input, boundaries);
}

/**
* Create a constant from a Tensor.
*
* @param scope is a scope used to add the underlying operation.
* @param tensor a Tensor holding the constant value
* @return a constant of the same data type as `tensor`
*/
public <T extends TType> Constant<T> capture(T tensor) {
return Constant.create(scope, tensor);
}

/**
* Clips tensor values to a specified min and max.
* <p>
Expand Down Expand Up @@ -1687,17 +1697,6 @@ public Constant<TInt64> constant(Shape shape) {
return Constant.tensorOf(scope, shape);
}

/**
* Create a constant from a Tensor.
*
* @param scope is a scope used to add the underlying operation.
* @param tensor a Tensor holding the constant value
* @return a constant of the same data type as `tensor`
*/
public <T extends TType> Constant<T> constant(Tensor<T> tensor) {
return Constant.create(scope, tensor);
}

/**
* Creates a constant of {@code String} elements, using the given charset.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ private Options() {
@Endpoint(describeByClass = true)
public static AudioSpectrogram create(Scope scope, Operand<TFloat32> input, Long windowSize, Long stride, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("AudioSpectrogram", scope.makeOpName("AudioSpectrogram"));
opBuilder.addInput(input.asOutput());
opBuilder.addInput(input.asOutput(scope));
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder.setAttr("window_size", windowSize);
opBuilder.setAttr("stride", stride);
Expand Down Expand Up @@ -123,7 +123,7 @@ public Output<TFloat32> spectrogram() {
}

@Override
public Output<TFloat32> asOutput() {
public Output<TFloat32> asOutput(Scope scope) {
return spectrogram;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ private Options() {
@Endpoint(describeByClass = true)
public static DecodeWav create(Scope scope, Operand<TString> contents, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("DecodeWav", scope.makeOpName("DecodeWav"));
opBuilder.addInput(contents.asOutput());
opBuilder.addInput(contents.asOutput(scope));
opBuilder = scope.applyControlDependencies(opBuilder);
if (options != null) {
for (Options opts : options) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ public final class EncodeWav extends RawOp implements Operand<TString> {
@Endpoint(describeByClass = true)
public static EncodeWav create(Scope scope, Operand<TFloat32> audio, Operand<TInt32> sampleRate) {
OperationBuilder opBuilder = scope.env().opBuilder("EncodeWav", scope.makeOpName("EncodeWav"));
opBuilder.addInput(audio.asOutput());
opBuilder.addInput(sampleRate.asOutput());
opBuilder.addInput(audio.asOutput(scope));
opBuilder.addInput(sampleRate.asOutput(scope));
opBuilder = scope.applyControlDependencies(opBuilder);
return new EncodeWav(opBuilder.build());
}
Expand All @@ -68,7 +68,7 @@ public Output<TString> contents() {
}

@Override
public Output<TString> asOutput() {
public Output<TString> asOutput(Scope scope) {
return contents;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ private Options() {
@Endpoint(describeByClass = true)
public static Mfcc create(Scope scope, Operand<TFloat32> spectrogram, Operand<TInt32> sampleRate, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("Mfcc", scope.makeOpName("Mfcc"));
opBuilder.addInput(spectrogram.asOutput());
opBuilder.addInput(sampleRate.asOutput());
opBuilder.addInput(spectrogram.asOutput(scope));
opBuilder.addInput(sampleRate.asOutput(scope));
opBuilder = scope.applyControlDependencies(opBuilder);
if (options != null) {
for (Options opts : options) {
Expand Down Expand Up @@ -161,7 +161,7 @@ public Output<TFloat32> output() {
}

@Override
public Output<TFloat32> asOutput() {
public Output<TFloat32> asOutput(Scope scope) {
return output;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.Operator;
import org.tensorflow.types.family.TNumber;
import org.tensorflow.types.family.TType;

/**
* Elementwise computes the bitwise AND of `x` and `y`.
Expand Down Expand Up @@ -67,8 +66,8 @@ public final class BitwiseAnd<T extends TNumber> extends RawOp implements Operan
@Endpoint(describeByClass = true)
public static <T extends TNumber> BitwiseAnd<T> create(Scope scope, Operand<T> x, Operand<T> y) {
OperationBuilder opBuilder = scope.env().opBuilder("BitwiseAnd", scope.makeOpName("BitwiseAnd"));
opBuilder.addInput(x.asOutput());
opBuilder.addInput(y.asOutput());
opBuilder.addInput(x.asOutput(scope));
opBuilder.addInput(y.asOutput(scope));
opBuilder = scope.applyControlDependencies(opBuilder);
return new BitwiseAnd<T>(opBuilder.build());
}
Expand All @@ -80,7 +79,7 @@ public Output<T> z() {
}

@Override
public Output<T> asOutput() {
public Output<T> asOutput(Scope scope) {
return z;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.Operator;
import org.tensorflow.types.family.TNumber;
import org.tensorflow.types.family.TType;

/**
* Elementwise computes the bitwise OR of `x` and `y`.
Expand Down Expand Up @@ -67,8 +66,8 @@ public final class BitwiseOr<T extends TNumber> extends RawOp implements Operand
@Endpoint(describeByClass = true)
public static <T extends TNumber> BitwiseOr<T> create(Scope scope, Operand<T> x, Operand<T> y) {
OperationBuilder opBuilder = scope.env().opBuilder("BitwiseOr", scope.makeOpName("BitwiseOr"));
opBuilder.addInput(x.asOutput());
opBuilder.addInput(y.asOutput());
opBuilder.addInput(x.asOutput(scope));
opBuilder.addInput(y.asOutput(scope));
opBuilder = scope.applyControlDependencies(opBuilder);
return new BitwiseOr<T>(opBuilder.build());
}
Expand All @@ -80,7 +79,7 @@ public Output<T> z() {
}

@Override
public Output<T> asOutput() {
public Output<T> asOutput(Scope scope) {
return z;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.Operator;
import org.tensorflow.types.family.TNumber;
import org.tensorflow.types.family.TType;

/**
* Elementwise computes the bitwise XOR of `x` and `y`.
Expand Down Expand Up @@ -67,8 +66,8 @@ public final class BitwiseXor<T extends TNumber> extends RawOp implements Operan
@Endpoint(describeByClass = true)
public static <T extends TNumber> BitwiseXor<T> create(Scope scope, Operand<T> x, Operand<T> y) {
OperationBuilder opBuilder = scope.env().opBuilder("BitwiseXor", scope.makeOpName("BitwiseXor"));
opBuilder.addInput(x.asOutput());
opBuilder.addInput(y.asOutput());
opBuilder.addInput(x.asOutput(scope));
opBuilder.addInput(y.asOutput(scope));
opBuilder = scope.applyControlDependencies(opBuilder);
return new BitwiseXor<T>(opBuilder.build());
}
Expand All @@ -80,7 +79,7 @@ public Output<T> z() {
}

@Override
public Output<T> asOutput() {
public Output<T> asOutput(Scope scope) {
return z;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.Operator;
import org.tensorflow.types.family.TNumber;
import org.tensorflow.types.family.TType;

/**
* Invert (flip) each bit of supported types; for example, type `uint8` value 01010101 becomes 10101010.
Expand Down Expand Up @@ -87,7 +86,7 @@ public final class Invert<T extends TNumber> extends RawOp implements Operand<T>
@Endpoint(describeByClass = true)
public static <T extends TNumber> Invert<T> create(Scope scope, Operand<T> x) {
OperationBuilder opBuilder = scope.env().opBuilder("Invert", scope.makeOpName("Invert"));
opBuilder.addInput(x.asOutput());
opBuilder.addInput(x.asOutput(scope));
opBuilder = scope.applyControlDependencies(opBuilder);
return new Invert<T>(opBuilder.build());
}
Expand All @@ -99,7 +98,7 @@ public Output<T> y() {
}

@Override
public Output<T> asOutput() {
public Output<T> asOutput(Scope scope) {
return y;
}

Expand Down
Loading