Skip to content

Boolean mask ops #214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions ndarray/src/main/java/org/tensorflow/ndarray/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,27 @@ public Shape takeLast(int n) {
return Shape.of(newDimensions);
}

/**
 * Returns a shape of {@code end - begin} dimensions whose sizes are copied from this shape's
 * dimensions {@code begin} (inclusive) through {@code end} (exclusive).
 *
 * @param begin index of the first dimension to keep, inclusive; must not be negative.
 * @param end index one past the last dimension to keep; must not exceed {@link #numDimensions()}.
 * @return the sub-shape bounded by {@code begin} and {@code end}.
 * @throws ArrayIndexOutOfBoundsException if {@code end} is greater than the number of dimensions
 *     of this shape, or if {@code begin} is negative.
 * @throws IllegalArgumentException if {@code begin} is greater than {@code end}.
 */
public Shape subShape(int begin, int end) {
  if (end > numDimensions()) {
    throw new ArrayIndexOutOfBoundsException(
        "End index " + end + " out of bounds: shape only has " + numDimensions() + " dimensions.");
  }
  if (begin < 0) {
    throw new ArrayIndexOutOfBoundsException(
        "Begin index " + begin + " out of bounds: cannot be less than 0.");
  }
  // Reject an inverted range explicitly; otherwise the array allocation below would fail with an
  // uninformative NegativeArraySizeException.
  if (begin > end) {
    throw new IllegalArgumentException(
        "Begin index " + begin + " cannot be greater than end index " + end + ".");
  }

  int length = end - begin;
  long[] newDimensions = new long[length];
  System.arraycopy(dimensionSizes, begin, newDimensions, 0, length);
  return Shape.of(newDimensions);
}

/**
* Returns a new Shape, with a new first dimension added. In order for this call to succeed,
* {@link Shape#isUnknown()} must be {@code false}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
import org.tensorflow.op.core.BatchToSpace;
import org.tensorflow.op.core.BatchToSpaceNd;
import org.tensorflow.op.core.Bitcast;
import org.tensorflow.op.core.BooleanMask;
import org.tensorflow.op.core.BooleanMaskUpdate;
import org.tensorflow.op.core.BroadcastDynamicShape;
import org.tensorflow.op.core.BroadcastTo;
import org.tensorflow.op.core.Bucketize;
Expand Down Expand Up @@ -347,10 +349,10 @@ public final class Ops {

public final SignalOps signal;

public final QuantizationOps quantization;

public final TrainOps train;

public final QuantizationOps quantization;

private final Scope scope;

private Ops(Scope scope) {
Expand All @@ -372,8 +374,8 @@ private Ops(Scope scope) {
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
quantization = new QuantizationOps(this);
train = new TrainOps(this);
quantization = new QuantizationOps(this);
}

/**
Expand Down Expand Up @@ -989,6 +991,61 @@ public <U extends TType> Bitcast<U> bitcast(Operand<? extends TType> input, Clas
return Bitcast.create(scope, input, type);
}

/**
 * Applies a boolean mask to a tensor, returning the flat array of the elements that correspond to
 * a {@code true} entry of the mask.
 * <p>
 * The Numpy equivalent is {@code tensor[mask]}.
 * <p>
 * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and the shape of {@code mask} must match
 * the first K dimensions of the shape of {@code tensor}. The result then satisfies
 * {@code booleanMask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]}, where
 * {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} in row-major order.
 * <p>
 * An {@code axis} option may be supplied to indicate the axis to mask from (0 by default). In that
 * case, {@code axis + dim(mask) <= dim(tensor)} and the shape of {@code mask} must match the first
 * {@code axis + dim(mask)} dimensions of the shape of {@code tensor}.
 *
 * @param tensor The tensor to mask.
 * @param mask The mask to apply.
 * @param options carries optional attributes values
 * @return The masked tensor.
 */
public <T extends TType> Operand<T> booleanMask(Operand<T> tensor, Operand<TBool> mask,
    BooleanMask.Options... options) {
  // Pure delegation: the composite op is built by BooleanMask inside this Ops instance's scope.
  Operand<T> masked = BooleanMask.create(scope, tensor, mask, options);
  return masked;
}

/**
 * Returns a copy of {@code tensor} in which the values selected by {@code mask} are replaced by
 * {@code updates}. The input tensors are not mutated. {@code updates} is broadcasted by default.
 * <p>
 * The Numpy equivalent is {@code tensor[mask] = updates}.
 * <p>
 * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and the shape of {@code mask} must match
 * the first K dimensions of the shape of {@code tensor}. The masked elements are selected as in
 * {@code booleanMask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]}, where
 * {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} in row-major order.
 * <p>
 * An {@code axis} option may be supplied to indicate the axis to mask from (0 by default). In that
 * case, {@code axis + dim(mask) <= dim(tensor)} and the shape of {@code mask} must match the first
 * {@code axis + dim(mask)} dimensions of the shape of {@code tensor}.
 * <p>
 * The shape of {@code updates} should be {@code [n, t_1, t_2, ...]}, where {@code n} is the number
 * of true values in {@code mask} and {@code t_i} is the {@code i}th dimension of {@code tensor}
 * after {@code axis} and {@code mask}. {@code updates} is broadcasted to this shape by default;
 * broadcasting can be disabled through {@code options}.
 *
 * @param tensor The tensor to mask.
 * @param mask The mask to apply.
 * @param updates the new values
 * @param options carries optional attributes values
 * @return The masked tensor.
 */
public <T extends TType> Operand<T> booleanMaskUpdate(Operand<T> tensor, Operand<TBool> mask,
    Operand<T> updates, BooleanMaskUpdate.Options... options) {
  // Pure delegation: the composite op is built by BooleanMaskUpdate inside this Ops instance's scope.
  Operand<T> updated = BooleanMaskUpdate.create(scope, tensor, mask, updates, options);
  return updated;
}

/**
* Return the shape of s0 op s1 with broadcast.
* <p>
Expand Down Expand Up @@ -1834,13 +1891,14 @@ public Constant<TInt32> constant(Shape shape, IntDataBuffer data) {
}

/**
* Creates a scalar of {@code type}, with the value of {@code number}.
* {@code number} may be truncated if it does not fit in the target type.
* Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not
* fit in the target type.
*
* @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating})
* @param number the value of the tensor
* @return a constant of the passed type
* @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or unknown.
* @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or
* unknown.
*/
public <T extends TNumber> Constant<T> constant(Class<T> type, Number number) {
return Constant.tensorOf(scope, type, number);
Expand Down Expand Up @@ -1892,14 +1950,14 @@ public <T extends TType> Constant<T> constantOf(T tensor) {
}

/**
* Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}.
* {@code number} may be truncated if it does not fit in the target type.
* Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be
* truncated if it does not fit in the target type.
*
* @param toMatch the operand providing the target type
* @param number the value of the tensor
* @return a constant with the same type as {@code toMatch}
* @see Ops#constant(Class, Number)
* @throws IllegalArgumentException if the type is unknown (which should be impossible).
* @see Ops#constant(Class, Number)
*/
public <T extends TNumber> Constant<T> constantOfSameType(Operand<T> toMatch, Number number) {
return Constant.tensorOfSameType(scope, toMatch, number);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,25 @@ public Scope withName(String opName) {
return new Scope(env, nameScope.withName(opName), controlDependencies, deviceSpec);
}

/**
 * Returns a new scope in which added operations are prefixed by this scope's op name (as set by
 * {@link #withName(String)}), or by the given default when no op name has been set. Intended for
 * building composite ops.
 *
 * <p>Ops created with the returned scope will have {@code name/opName/} as their prefix. The
 * actual name will be unique in the returned scope. All other properties are inherited from the
 * current scope.
 *
 * <p>The default child scope name must match the regular expression
 * {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*}
 *
 * @param defaultName name of the sub scope if this scope's name hasn't been set.
 * @return a new subscope
 * @throws IllegalArgumentException if the name is invalid
 */
public Scope withNameAsSubScope(String defaultName) {
  // Resolve the op name first, then use it as the name of the child scope.
  String subScopeName = nameScope.makeOpName(defaultName);
  return new Scope(env, nameScope.withSubScope(subScopeName), controlDependencies, deviceSpec);
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be interesting to update our other composite ops to make use of this new method, like this one:

/**
* Return a new scope that uses the provided device specification for an op.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow.op.core;

import java.util.Arrays;
import java.util.Collections;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.Operator;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TType;

/**
 * Composite op that selects the elements of a tensor matching the {@code true} entries of a
 * boolean mask. Exposed through the {@code booleanMask} endpoint.
 */
@Operator
public abstract class BooleanMask {

  /**
   * Applies a boolean mask to a tensor, returning the flat array of the elements that correspond
   * to a {@code true} entry of the mask.
   * <p>
   * The Numpy equivalent is {@code tensor[mask]}.
   * <p>
   * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and the shape of {@code mask} must match
   * the first K dimensions of the shape of {@code tensor}. The result then satisfies
   * {@code booleanMask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]}, where
   * {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} in row-major order.
   * <p>
   * An {@code axis} option may be supplied to indicate the axis to mask from (0 by default). In
   * that case, {@code axis + dim(mask) <= dim(tensor)} and the shape of {@code mask} must match
   * the first {@code axis + dim(mask)} dimensions of the shape of {@code tensor}.
   *
   * @param scope the scope in which the ops are created; a sub scope is derived from it.
   * @param tensor The tensor to mask.
   * @param mask The mask to apply.
   * @param options carries optional attributes values
   * @return The masked tensor.
   * @throws IllegalArgumentException if the mask is a scalar, has an unknown dimension, or has a
   *     shape that is not compatible with the masked dimensions of {@code tensor}.
   */
  @Endpoint(name = "booleanMask")
  public static <T extends TType> Operand<T> create(Scope scope, Operand<T> tensor, Operand<TBool> mask,
      Options... options) {

    // Every op below is created in a dedicated sub scope. The creation order of the ops is kept
    // stable on purpose, since unique op names are handed out by the scope as ops are added.
    final Scope maskScope = scope.withNameAsSubScope("BooleanMask");

    // The last option carrying an explicit axis wins; without one the axis is 0.
    int axis = 0;
    if (options != null) {
      for (Options candidate : options) {
        if (candidate.axis != null) {
          axis = candidate.axis;
        }
      }
    }

    // A negative axis counts from the end of the tensor's dimensions.
    if (axis < 0) {
      axis += tensor.rank();
    }

    final Shape maskShape = mask.shape();
    final Shape tensorShape = tensor.shape();
    final int maskRank = maskShape.numDimensions();

    if (maskRank == 0) {
      throw new IllegalArgumentException("Mask cannot be a scalar.");
    }
    if (maskShape.hasUnknownDimension()) {
      throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
    }

    // One past the last tensor dimension covered by the mask.
    final int maskEnd = axis + maskRank;

    Operand<TInt32> gatherAxis = Constant.scalarOf(maskScope, axis);

    // Static check: the tensor dimensions covered by the mask must agree with the mask's shape.
    Shape expectedMaskShape = tensorShape.subShape(axis, maskEnd);
    if (!expectedMaskShape.isCompatibleWith(maskShape)) {
      throw new IllegalArgumentException(
          "Mask shape " + maskShape + " is not compatible with the required mask shape: " + expectedMaskShape + ".");
    }

    org.tensorflow.op.core.Shape<TInt32> runtimeShape = org.tensorflow.op.core.Shape.create(maskScope, tensor);

    // Number of elements spanned by the masked dimensions: the product of their runtime sizes.
    Operand<TInt32> maskedDims =
        StridedSliceHelper.stridedSlice(maskScope, runtimeShape, Indices.range(axis, maskEnd));
    Operand<TInt32> maskedSize = ReduceProd.create(maskScope, maskedDims, Constant.arrayOf(maskScope, 0));

    // Shape of the tensor once the masked dimensions are collapsed into a single one:
    // [dims before axis..., maskedSize, dims after the mask...].
    Operand<TInt32> dimsBefore =
        StridedSliceHelper.stridedSlice(maskScope, runtimeShape, Indices.sliceTo(axis));
    Operand<TInt32> collapsedDim = Reshape.create(maskScope, maskedSize, Constant.arrayOf(maskScope, 1));
    Operand<TInt32> dimsAfter =
        StridedSliceHelper.stridedSlice(maskScope, runtimeShape, Indices.sliceFrom(maskEnd));
    Operand<TInt32> collapsedShape = Concat.create(
        maskScope,
        Arrays.asList(dimsBefore, collapsedDim, dimsAfter),
        Constant.scalarOf(maskScope, 0));

    Operand<T> collapsedTensor = Reshape.create(maskScope, tensor, collapsedShape);

    // Flatten the mask to one dimension and gather along the axis at the positions of its true entries.
    Operand<TBool> flatMask = Reshape.create(maskScope, mask, Constant.arrayOf(maskScope, -1));
    Operand<TInt64> truePositions = Where.create(maskScope, flatMask);
    Operand<TInt64> gatherIndices =
        Squeeze.create(maskScope, truePositions, Squeeze.axis(Collections.singletonList(1L)));

    return Gather.create(maskScope, collapsedTensor, gatherIndices, gatherAxis);
  }

  /**
   * Used to indicate the axis to mask from.
   * {@code axis + dim(mask) <= dim(tensor)} and the shape of {@code mask} must match
   * the first {@code axis + dim(mask)} dimensions of the shape of {@code tensor}.
   *
   * @param axis the axis to mask from. Uses 0 if null.
   * @return a new {@link Options} instance carrying the given axis.
   */
  public static Options axis(Integer axis) {
    Options created = new Options();
    return created.axis(axis);
  }

  /**
   * Used to indicate the axis to mask from.
   * {@code axis + dim(mask) <= dim(tensor)} and the shape of {@code mask} must match
   * the first {@code axis + dim(mask)} dimensions of the shape of {@code tensor}.
   *
   * @param axis the axis to mask from.
   * @return a new {@link Options} instance carrying the given axis.
   */
  public static Options axis(int axis) {
    Options created = new Options();
    return created.axis(axis);
  }

  /**
   * Optional attributes for {@link org.tensorflow.op.core.BooleanMask}
   */
  public static class Options {

    /** The axis to mask from; {@code null} when it has not been set. */
    private Integer axis;

    /** Instances are obtained through the static {@code axis} factories of the enclosing class. */
    private Options() {
    }

    /**
     * Sets the axis to mask from.
     *
     * @param axis (Optional) The axis to mask from, or 0 if not set.
     * @return this instance, to allow chaining.
     */
    public Options axis(Integer axis) {
      this.axis = axis;
      return this;
    }
  }

}
Loading