Skip to content

Commit 439373d

Browse files
authored
Boolean mask ops (#214)
1 parent ef6ece6 commit 439373d

File tree

7 files changed

+666
-9
lines changed

7 files changed

+666
-9
lines changed

ndarray/src/main/java/org/tensorflow/ndarray/Shape.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,27 @@ public Shape takeLast(int n) {
275275
return Shape.of(newDimensions);
276276
}
277277

278+
/**
279+
* Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code begin} to {@code end}.
280+
* @param begin Where to start the sub-shape.
281+
* @param end Where to end the sub-shape, exclusive.
282+
* @return the sub-shape bounded by begin and end.
283+
*/
284+
public Shape subShape(int begin, int end){
285+
if (end > numDimensions()) {
286+
throw new ArrayIndexOutOfBoundsException(
287+
"End index " + end + " out of bounds: shape only has " + numDimensions() + " dimensions.");
288+
}
289+
if (begin < 0) {
290+
throw new ArrayIndexOutOfBoundsException(
291+
"Begin index " + begin + " out of bounds: cannot be less than 0.");
292+
}
293+
294+
long[] newDimensions = new long[end - begin];
295+
System.arraycopy(dimensionSizes, begin, newDimensions, 0, end - begin);
296+
return Shape.of(newDimensions);
297+
}
298+
278299
/**
279300
* Returns a new Shape, with a new first dimension added. In order for this call to succeed,
280301
* {@link Shape#isUnknown()} must be {@code false}.

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
import org.tensorflow.op.core.BatchToSpace;
6060
import org.tensorflow.op.core.BatchToSpaceNd;
6161
import org.tensorflow.op.core.Bitcast;
62+
import org.tensorflow.op.core.BooleanMask;
63+
import org.tensorflow.op.core.BooleanMaskUpdate;
6264
import org.tensorflow.op.core.BroadcastDynamicShape;
6365
import org.tensorflow.op.core.BroadcastTo;
6466
import org.tensorflow.op.core.Bucketize;
@@ -362,10 +364,10 @@ public final class Ops {
362364

363365
public final SignalOps signal;
364366

365-
public final QuantizationOps quantization;
366-
367367
public final TrainOps train;
368368

369+
public final QuantizationOps quantization;
370+
369371
private final Scope scope;
370372

371373
private Ops(Scope scope) {
@@ -388,8 +390,8 @@ private Ops(Scope scope) {
388390
math = new MathOps(this);
389391
audio = new AudioOps(this);
390392
signal = new SignalOps(this);
391-
quantization = new QuantizationOps(this);
392393
train = new TrainOps(this);
394+
quantization = new QuantizationOps(this);
393395
}
394396

395397
/**
@@ -1005,6 +1007,61 @@ public <U extends TType> Bitcast<U> bitcast(Operand<? extends TType> input, Clas
10051007
return Bitcast.create(scope, input, type);
10061008
}
10071009

1010+
/**
1011+
* Apply boolean mask to tensor. Returns the flat array of each element corresponding to a {@code true} in the mask.
1012+
* <p>
1013+
* Numpy equivalent is {@code tensor[mask]}.
1014+
* <p>
1015+
* In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match
1016+
* the first K dimensions of {@code tensor}'s shape. We then have:
1017+
* {@code booleanMask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]}
1018+
* where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major order).
1019+
* <p>
1020+
* The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default).
1021+
* In that case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match
1022+
* the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape.
1023+
*
1024+
* @param scope
1025+
* @param tensor The tensor to mask.
1026+
* @param mask The mask to apply.
1027+
* @param options carries optional attributes values
1028+
* @return The masked tensor.
1029+
*/
1030+
public <T extends TType> Operand<T> booleanMask(Operand<T> tensor, Operand<TBool> mask,
1031+
BooleanMask.Options... options) {
1032+
return BooleanMask.create(scope, tensor, mask, options);
1033+
}
1034+
1035+
/**
1036+
* Updates a tensor at the masked values, and returns the updated tensor. Does not mutate the input tensors. {@code
1037+
* updates} will be broadcasted by default
1038+
* <p>
1039+
* Numpy equivalent is `tensor[mask] = updates`.
1040+
* <p>
1041+
* In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match the first K dimensions of
1042+
* {@code tensor}'s shape. We then have: {@code booleanMask(tensor, mask)[i, j1,...,jd] =
1043+
* tensor[i1,...,iK,j1,...,jd]} where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major
1044+
* order).
1045+
* <p>
1046+
* The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default). In that
1047+
* case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match the first {@code axis +
1048+
* dim(mask)} dimensions of {@code tensor}'s shape.
1049+
* <p>
1050+
* The shape of {@code updates} should be {@code [n, t_1, t_2, ...]} where {@code n} is the number of true values in
1051+
* {@code mask} and {@code t_i} is the {@code i}th dimension of {@code tensor} after {@code axis} and {@code mask}.
1052+
* {@code updates} will be broadcasted to this shape by default, which can be disabled using {@code options}.
1053+
*
1054+
* @param tensor The tensor to mask.
1055+
* @param mask The mask to apply.
1056+
* @param updates the new values
1057+
* @param options carries optional attributes values
1058+
* @return The masked tensor.
1059+
*/
1060+
public <T extends TType> Operand<T> booleanMaskUpdate(Operand<T> tensor, Operand<TBool> mask,
1061+
Operand<T> updates, BooleanMaskUpdate.Options... options) {
1062+
return BooleanMaskUpdate.create(scope, tensor, mask, updates, options);
1063+
}
1064+
10081065
/**
10091066
* Return the shape of s0 op s1 with broadcast.
10101067
* <p>
@@ -1850,13 +1907,14 @@ public Constant<TInt32> constant(Shape shape, IntDataBuffer data) {
18501907
}
18511908

18521909
/**
1853-
* Creates a scalar of {@code type}, with the value of {@code number}.
1854-
* {@code number} may be truncated if it does not fit in the target type.
1910+
* Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not
1911+
* fit in the target type.
18551912
*
18561913
* @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating})
18571914
* @param number the value of the tensor
18581915
* @return a constant of the passed type
1859-
* @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or unknown.
1916+
* @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or
1917+
* unknown.
18601918
*/
18611919
public <T extends TNumber> Constant<T> constant(Class<T> type, Number number) {
18621920
return Constant.tensorOf(scope, type, number);
@@ -1908,14 +1966,14 @@ public <T extends TType> Constant<T> constantOf(T tensor) {
19081966
}
19091967

19101968
/**
1911-
* Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}.
1912-
* {@code number} may be truncated if it does not fit in the target type.
1969+
* Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be
1970+
* truncated if it does not fit in the target type.
19131971
*
19141972
* @param toMatch the operand providing the target type
19151973
* @param number the value of the tensor
19161974
* @return a constant with the same type as {@code toMatch}
1917-
* @see Ops#constant(Class, Number)
19181975
* @throws IllegalArgumentException if the type is unknown (which should be impossible).
1976+
* @see Ops#constant(Class, Number)
19191977
*/
19201978
public <T extends TNumber> Constant<T> constantOfSameType(Operand<T> toMatch, Number number) {
19211979
return Constant.tensorOfSameType(scope, toMatch, number);

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,25 @@ public Scope withName(String opName) {
125125
return new Scope(env, nameScope.withName(opName), controlDependencies, deviceSpec);
126126
}
127127

128+
/**
129+
* Returns a new scope where added operations will be prefixed by this scope's op name
130+
* (set by {@link #withName(String)}), or the given default if it is unset. This is intended to be used for
131+
* composite ops.
132+
*
133+
* <p>Ops created with this scope will have {@code name/opName/} as the prefix. The actual
134+
* name will be unique in the returned scope. All other properties are inherited from the current
135+
* scope.
136+
*
137+
* <p>The default child scope name must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*}
138+
*
139+
* @param defaultName name of the sub scope if this scope's name hasn't been set.
140+
* @return a new subscope
141+
* @throws IllegalArgumentException if the name is invalid
142+
*/
143+
public Scope withNameAsSubScope(String defaultName){
144+
return new Scope(env, nameScope.withSubScope(nameScope.makeOpName(defaultName)), controlDependencies, deviceSpec);
145+
}
146+
128147
/**
129148
* Return a new scope that uses the provided device specification for an op.
130149
*
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
/*
2+
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
17+
package org.tensorflow.op.core;
18+
19+
import java.util.Arrays;
20+
import java.util.Collections;
21+
import org.tensorflow.Operand;
22+
import org.tensorflow.ndarray.Shape;
23+
import org.tensorflow.ndarray.index.Indices;
24+
import org.tensorflow.op.Scope;
25+
import org.tensorflow.op.annotation.Endpoint;
26+
import org.tensorflow.op.annotation.Operator;
27+
import org.tensorflow.types.TBool;
28+
import org.tensorflow.types.TInt32;
29+
import org.tensorflow.types.TInt64;
30+
import org.tensorflow.types.family.TType;
31+
32+
@Operator
33+
public abstract class BooleanMask {
34+
35+
/**
36+
* Apply boolean mask to tensor. Returns the flat array of each element corresponding to a {@code true} in the mask.
37+
* <p>
38+
* Numpy equivalent is {@code tensor[mask]}.
39+
* <p>
40+
* In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match
41+
* the first K dimensions of {@code tensor}'s shape. We then have:
42+
* {@code booleanMask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]}
43+
* where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major order).
44+
* <p>
45+
* The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default).
46+
* In that case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match
47+
* the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape.
48+
*
49+
* @param scope
50+
* @param tensor The tensor to mask.
51+
* @param mask The mask to apply.
52+
* @param options carries optional attributes values
53+
* @return The masked tensor.
54+
*/
55+
@Endpoint(name = "booleanMask")
56+
public static <T extends TType> Operand<T> create(Scope scope, Operand<T> tensor, Operand<TBool> mask,
57+
Options... options) {
58+
59+
scope = scope.withNameAsSubScope("BooleanMask");
60+
61+
int axis = 0;
62+
if (options != null) {
63+
for (Options opts : options) {
64+
if (opts.axis != null) {
65+
axis = opts.axis;
66+
}
67+
}
68+
}
69+
70+
if (axis < 0) {
71+
axis += tensor.rank();
72+
}
73+
74+
Shape maskShape = mask.shape();
75+
Shape tensorShape = tensor.shape();
76+
77+
if (maskShape.numDimensions() == 0) {
78+
throw new IllegalArgumentException("Mask cannot be a scalar.");
79+
}
80+
if (maskShape.hasUnknownDimension()) {
81+
throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
82+
}
83+
84+
Operand<TInt32> axisTensor = Constant.scalarOf(scope, axis);
85+
Shape requiredMaskShape = tensorShape.subShape(axis, axis + maskShape.numDimensions());
86+
if (!requiredMaskShape.isCompatibleWith(maskShape)) {
87+
throw new IllegalArgumentException(
88+
"Mask shape " + maskShape + " is not compatible with the required mask shape: " + requiredMaskShape + ".");
89+
}
90+
91+
org.tensorflow.op.core.Shape<TInt32> liveShape = org.tensorflow.op.core.Shape.create(scope, tensor);
92+
93+
Operand<TInt32> leadingSize = ReduceProd.create(scope,
94+
StridedSliceHelper.stridedSlice(scope,
95+
liveShape,
96+
Indices.range(axis, axis + maskShape.numDimensions())
97+
),
98+
Constant.arrayOf(scope, 0)
99+
);
100+
101+
Operand<T> flattened = Reshape.create(scope, tensor, Concat.create(
102+
scope,
103+
Arrays.asList(
104+
StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceTo(axis)),
105+
Reshape.create(scope, leadingSize, Constant.arrayOf(scope, 1)),
106+
StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceFrom(axis + maskShape.numDimensions()))
107+
),
108+
Constant.scalarOf(scope, 0)
109+
));
110+
111+
Operand<TBool> flatMask = Reshape.create(scope, mask, Constant.arrayOf(scope, -1));
112+
113+
Operand<TInt64> indices = Squeeze.create(scope, Where.create(scope, flatMask), Squeeze.axis(Collections.singletonList(1L)));
114+
return Gather.create(scope, flattened, indices, axisTensor);
115+
}
116+
117+
/**
118+
* Used to indicate the axis to mask from.
119+
* {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match
120+
* the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape.
121+
* @param axis the axis to mask from. Uses 0 if null.
122+
*/
123+
public static Options axis(Integer axis){
124+
return new Options().axis(axis);
125+
}
126+
127+
128+
/**
129+
* Used to indicate the axis to mask from.
130+
* {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match
131+
* the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape.
132+
* @param axis the axis to mask from.
133+
*/
134+
public static Options axis(int axis){
135+
return new Options().axis(axis);
136+
}
137+
138+
/**
139+
* Optional attributes for {@link org.tensorflow.op.core.BooleanMask}
140+
*/
141+
public static class Options {
142+
143+
/**
144+
* @param axis (Optional) The axis to mask from, or 0 if not set.
145+
*/
146+
public Options axis(Integer axis) {
147+
this.axis = axis;
148+
return this;
149+
}
150+
151+
private Integer axis;
152+
153+
private Options() {
154+
}
155+
}
156+
157+
}

0 commit comments

Comments
 (0)