From 900a4eb80d45ab360ab9b79b70c0d554bc268189 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 4 Jan 2021 21:49:01 -0800 Subject: [PATCH 1/6] Basic booleanMask op Signed-off-by: Ryan Nett --- .../java/org/tensorflow/ndarray/Shape.java | 21 +++ .../org/tensorflow/op/core/BooleanMask.java | 123 ++++++++++++++++++ 2 files changed, 144 insertions(+) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java b/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java index a7e2dd0df82..85a905408c7 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java @@ -275,6 +275,27 @@ public Shape takeLast(int n) { return Shape.of(newDimensions); } + /** + * Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code begin} to {@code end}. + * @param begin Where to start the sub-shape. + * @param end Where to end the sub-shape, exclusive. + * @return the sub-shape bounded by begin and end. + */ + public Shape subShape(int begin, int end){ + if (end > numDimensions()) { + throw new ArrayIndexOutOfBoundsException( + "End index " + end + " out of bounds: shape only has " + numDimensions() + " dimensions."); + } + if (begin < 0) { + throw new ArrayIndexOutOfBoundsException( + "Begin index " + begin + " out of bounds: cannot be less than 0."); + } + + long[] newDimensions = new long[end - begin]; + System.arraycopy(dimensionSizes, begin, newDimensions, 0, end - begin); + return Shape.of(newDimensions); + } + /** * Returns a new Shape, with a new first dimension added. In order for this call to succeed, * {@link Shape#isUnknown()} must be {@code false}. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java new file mode 100644 index 00000000000..a77019f69da --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java @@ -0,0 +1,123 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Collections; +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Indices; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +@Operator +public abstract class BooleanMask { + + @Endpoint(name = "booleanMask") + public static Operand create(Scope scope, Operand x, Operand mask, + Options... options) { + + //TODO naming to match python + + int axis = 0; + if (options != null) { + for (Options opts : options) { + if (opts.axis != null) { + axis = opts.axis; + } + } + } + + if (axis < 0) { + axis += x.rank(); + } + + Shape maskShape = mask.shape(); + Shape tensorShape = x.shape(); + + if (maskShape.numDimensions() == 0) { + throw new IllegalArgumentException("Mask cannot be scalar."); + } + if (maskShape.hasUnknownDimension()) { + throw new IllegalArgumentException("Mask cannot have unknown number of dimensions"); + } + + Operand axisTensor = Constant.scalarOf(scope, axis); + Shape requiredMaskShape = tensorShape.subShape(axis, axis + maskShape.numDimensions()); + if (!requiredMaskShape.isCompatibleWith(maskShape)) { + throw new IllegalArgumentException( + "Mask shape " + maskShape + " is not compatible with required mask shape: " + requiredMaskShape + "."); + } + + org.tensorflow.op.core.Shape liveShape = org.tensorflow.op.core.Shape.create(scope, x); + + Operand leadingSize = ReduceProd.create(scope, + StridedSliceHelper.stridedSlice(scope, + liveShape, + Indices.range(axis, axis + maskShape.numDimensions()) + ), + Constant.arrayOf(scope, 0) + ); + + Operand tensor = Reshape.create(scope, x, Concat.create( + scope, + Arrays.asList( + StridedSliceHelper.stridedSlice(scope, liveShape, Indices.to(axis)), + Reshape.create(scope, leadingSize, Constant.arrayOf(scope, 1)), + StridedSliceHelper.stridedSlice(scope, liveShape, Indices.from(axis + maskShape.numDimensions())) + ), + Constant.scalarOf(scope, 0) + )); + Operand flatMask = Reshape.create(scope, mask, Constant.arrayOf(scope, -1)); + + Operand indices = Squeeze.create(scope, Where.create(scope, flatMask), Squeeze.axis(Collections.singletonList(1L))); + return Gather.create(scope, tensor, indices, axisTensor); + } + + /** + * Optional attributes for {@link org.tensorflow.op.core.BooleanMask} + */ + public static class Options { + + /** + * @param axis (Optional) The axis to mask from, or 0 if not set. + */ + public Options axis(Integer axis) { + this.axis = axis; + return this; + } + + /** + * @param axis (Optional) The axis to mask from, or 0 if not set. + */ + public Options axis(int axis) { + this.axis = axis; + return this; + } + + private Integer axis; + + private Options() { + } + } + +} From 76210462ad2d1e4ec458fb555b09e22519b504a8 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 6 Jan 2021 19:30:26 -0800 Subject: [PATCH 2/6] small fixes Signed-off-by: Ryan Nett --- .../src/gen/annotations/org/tensorflow/op/Ops.java | 9 +++++++++ .../main/java/org/tensorflow/op/core/BooleanMask.java | 5 +++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 529b0d99c39..b0740ac9a0f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -59,6 +59,7 @@ import org.tensorflow.op.core.BatchToSpace; import org.tensorflow.op.core.BatchToSpaceNd; import org.tensorflow.op.core.Bitcast; +import org.tensorflow.op.core.BooleanMask; import org.tensorflow.op.core.BroadcastDynamicShape; import org.tensorflow.op.core.BroadcastTo; import org.tensorflow.op.core.Bucketize; @@ -989,6 +990,14 @@ public Bitcast bitcast(Operand input, Clas return Bitcast.create(scope, input, type); } + /** + * empty + */ + public Operand booleanMask(Operand x, Operand mask, + BooleanMask.Options... options) { + return BooleanMask.create(scope, x, mask, options); + } + /** * Return the shape of s0 op s1 with broadcast. *

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java index a77019f69da..f9b8773f864 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java @@ -55,7 +55,7 @@ public static Operand create(Scope scope, Operand x, Ope Shape tensorShape = x.shape(); if (maskShape.numDimensions() == 0) { - throw new IllegalArgumentException("Mask cannot be scalar."); + throw new IllegalArgumentException("Mask cannot be a scalar."); } if (maskShape.hasUnknownDimension()) { throw new IllegalArgumentException("Mask cannot have unknown number of dimensions"); @@ -65,7 +65,7 @@ public static Operand create(Scope scope, Operand x, Ope Shape requiredMaskShape = tensorShape.subShape(axis, axis + maskShape.numDimensions()); if (!requiredMaskShape.isCompatibleWith(maskShape)) { throw new IllegalArgumentException( - "Mask shape " + maskShape + " is not compatible with required mask shape: " + requiredMaskShape + "."); + "Mask shape " + maskShape + " is not compatible with the required mask shape: " + requiredMaskShape + "."); } org.tensorflow.op.core.Shape liveShape = org.tensorflow.op.core.Shape.create(scope, x); @@ -87,6 +87,7 @@ public static Operand create(Scope scope, Operand x, Ope ), Constant.scalarOf(scope, 0) )); + Operand flatMask = Reshape.create(scope, mask, Constant.arrayOf(scope, -1)); Operand indices = Squeeze.create(scope, Where.create(scope, flatMask), Squeeze.axis(Collections.singletonList(1L))); From a9e2294003bb765c1325115aa7a19bb0a099b323 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 8 Jan 2021 18:18:23 -0800 Subject: [PATCH 3/6] Javadoc and test Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 23 ++++++- .../main/java/org/tensorflow/op/Scope.java | 19 ++++++ .../org/tensorflow/op/core/BooleanMask.java | 56 ++++++++++++++-- .../tensorflow/op/core/BooleanMaskTest.java | 67 +++++++++++++++++++ 4 files changed, 155 insertions(+), 10 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index b0740ac9a0f..73a2b81e64e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -991,11 +991,28 @@ public Bitcast bitcast(Operand input, Clas } /** - * empty + * Apply boolean mask to tensor. + *

+ * Numpy equivalent is `tensor[mask]`. + *

+ * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match + * the first K dimensions of {@code tensor}'s shape. We then have: + * {@code booleanMask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]} + * where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major order). + *

+ * The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default). + * In that case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match + * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape. + * + * @param scope + * @param tensor The tensor to mask. + * @param mask The mask to apply. + * @param options carries optional attributes values + * @return The masked tensor. */ - public Operand booleanMask(Operand x, Operand mask, + public Operand booleanMask(Operand tensor, Operand mask, BooleanMask.Options... options) { - return BooleanMask.create(scope, x, mask, options); + return BooleanMask.create(scope, tensor, mask, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java index 73fa340a487..85e283d9260 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java @@ -125,6 +125,25 @@ public Scope withName(String opName) { return new Scope(env, nameScope.withName(opName), controlDependencies, deviceSpec); } + /** + * Returns a new scope where added operations will be prefixed by this scope's op name + * (set by {@link #withName(String)}), or the given default if it is unset. This is intended to be used for + * composite ops. + * + *

Ops created with this scope will have {@code name/opName/} as the prefix. The actual + * name will be unique in the returned scope. All other properties are inherited from the current + * scope. + * + *

The default child scope name must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*} + * + * @param defaultName name of the sub scope if this scope's name hasn't been set. + * @return a new subscope + * @throws IllegalArgumentException if the name is invalid + */ + public Scope withNameAsSubScope(String defaultName){ + return new Scope(env, nameScope.withSubScope(nameScope.makeOpName(defaultName)), controlDependencies, deviceSpec); + } + /** * Return a new scope that uses the provided device specification for an op. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java index f9b8773f864..af3dd91e2e0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java @@ -21,6 +21,7 @@ import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.index.Indices; +import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; @@ -32,11 +33,31 @@ @Operator public abstract class BooleanMask { + /** + * Apply boolean mask to tensor. Returns the flat array of each element corresponding to a {@code true} in the mask. + *

+ * Numpy equivalent is `tensor[mask]`. + *

+ * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match + * the first K dimensions of {@code tensor}'s shape. We then have: + * {@code booleanMask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]} + * where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major order). + *

+ * The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default). + * In that case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match + * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape. + * + * @param scope + * @param tensor The tensor to mask. + * @param mask The mask to apply. + * @param options carries optional attributes values + * @return The masked tensor. + */ @Endpoint(name = "booleanMask") - public static Operand create(Scope scope, Operand x, Operand mask, + public static Operand create(Scope scope, Operand tensor, Operand mask, Options... options) { - //TODO naming to match python + scope = scope.withNameAsSubScope("BooleanMask"); int axis = 0; if (options != null) { @@ -48,11 +69,11 @@ public static Operand create(Scope scope, Operand x, Ope } if (axis < 0) { - axis += x.rank(); + axis += tensor.rank(); } Shape maskShape = mask.shape(); - Shape tensorShape = x.shape(); + Shape tensorShape = tensor.shape(); if (maskShape.numDimensions() == 0) { throw new IllegalArgumentException("Mask cannot be a scalar."); @@ -68,7 +89,7 @@ public static Operand create(Scope scope, Operand x, Ope "Mask shape " + maskShape + " is not compatible with the required mask shape: " + requiredMaskShape + "."); } - org.tensorflow.op.core.Shape liveShape = org.tensorflow.op.core.Shape.create(scope, x); + org.tensorflow.op.core.Shape liveShape = org.tensorflow.op.core.Shape.create(scope, tensor); Operand leadingSize = ReduceProd.create(scope, StridedSliceHelper.stridedSlice(scope, @@ -78,7 +99,7 @@ public static Operand create(Scope scope, Operand x, Ope Constant.arrayOf(scope, 0) ); - Operand tensor = Reshape.create(scope, x, Concat.create( + Operand flattened = Reshape.create(scope, tensor, Concat.create( scope, Arrays.asList( StridedSliceHelper.stridedSlice(scope, liveShape, Indices.to(axis)), @@ -91,7 +112,28 @@ public static Operand create(Scope scope, Operand x, Ope Operand flatMask = Reshape.create(scope, mask, Constant.arrayOf(scope, -1)); Operand indices = Squeeze.create(scope, Where.create(scope, flatMask), Squeeze.axis(Collections.singletonList(1L))); - return Gather.create(scope, tensor, indices, axisTensor); + return Gather.create(scope, flattened, indices, axisTensor); + } + + /** + * Used to indicate the axis to mask from. + * {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match + * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape. + * @param axis the axis to mask from. Uses 0 if null. + */ + public static Options axis(Integer axis){ + return new Options().axis(axis); + } + + + /** + * Used to indicate the axis to mask from. + * {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match + * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape. + * @param axis the axis to mask from. + */ + public static Options axis(int axis){ + return new Options().axis(axis); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java new file mode 100644 index 00000000000..ff50fb65ee6 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java @@ -0,0 +1,67 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.op.core; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.Test; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +public class BooleanMaskTest { + @Test + public void testBooleanMask(){ + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Operand input = Constant.arrayOf(scope, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + Operand input2 = ExpandDims.create(scope, input, Constant.scalarOf(scope, 0)); + + Operand mask = Constant.arrayOf(scope, true, true, false, false, true, true, true, false, false, false); + + Operand output1 = BooleanMask.create(scope, input, mask); + Operand output2 = BooleanMask.create(scope, input2, mask, BooleanMask.axis(1)); + + try (TFloat32 result = (TFloat32) sess.runner().fetch(output1).run().get(0)) { + // expected shape from Python tensorflow + assertEquals(Shape.of(5), result.shape()); + assertEquals(result.getFloat(0), 0); + assertEquals(result.getFloat(1), 1); + assertEquals(result.getFloat(2), 4); + assertEquals(result.getFloat(3), 5); + assertEquals(result.getFloat(4), 6); + } + + try (TFloat32 result = (TFloat32) sess.runner().fetch(output2).run().get(0)) { + // expected shape from Python tensorflow + assertEquals(Shape.of(5), result.shape()); + assertEquals(result.getFloat(0), 0); + assertEquals(result.getFloat(1), 1); + assertEquals(result.getFloat(2), 4); + assertEquals(result.getFloat(3), 5); + assertEquals(result.getFloat(4), 6); + } + } + } +} From afa900ef30390e1dd25c442ccaa15a5ee696b964 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 8 Jan 2021 19:29:17 -0800 Subject: [PATCH 4/6] Start of BooleanMaskUpdate Signed-off-by: Ryan Nett --- .../org/tensorflow/op/core/BooleanMask.java | 4 +- .../tensorflow/op/core/BooleanMaskUpdate.java | 153 ++++++++++++++++++ .../tensorflow/op/core/BooleanMaskTest.java | 20 +-- .../op/core/BooleanMaskUpdateTest.java | 94 +++++++++++ 4 files changed, 259 insertions(+), 12 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java index af3dd91e2e0..080e07a851c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java @@ -102,9 +102,9 @@ public static Operand create(Scope scope, Operand tensor Operand flattened = Reshape.create(scope, tensor, Concat.create( scope, Arrays.asList( - StridedSliceHelper.stridedSlice(scope, liveShape, Indices.to(axis)), + StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceTo(axis)), Reshape.create(scope, leadingSize, Constant.arrayOf(scope, 1)), - StridedSliceHelper.stridedSlice(scope, liveShape, Indices.from(axis + maskShape.numDimensions())) + StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceFrom(axis + maskShape.numDimensions())) ), Constant.scalarOf(scope, 0) )); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java new file mode 100644 index 00000000000..8acff79fe62 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java @@ -0,0 +1,153 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Collections; +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Indices; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +@Operator +public abstract class BooleanMaskUpdate { + + /** + * TODO + * + * @param scope + * @param tensor The tensor to mask. + * @param mask The mask to apply. + * @param value the new values + * @param options carries optional attributes values + * @return The masked tensor. + */ + @Endpoint(name = "booleanMaskUpdate") + public static Operand create(Scope scope, Operand tensor, Operand mask, Operand value, + Options... options) { + + scope = scope.withNameAsSubScope("BooleanMaskUpdate"); + + int axis = 0; + if (options != null) { + for (Options opts : options) { + if (opts.axis != null) { + axis = opts.axis; + } + } + } + + if (axis < 0) { + axis += tensor.rank(); + } + + Shape maskShape = mask.shape(); + Shape tensorShape = tensor.shape(); + + if (maskShape.numDimensions() == 0) { + throw new IllegalArgumentException("Mask cannot be a scalar."); + } + if (maskShape.hasUnknownDimension()) { + throw new IllegalArgumentException("Mask cannot have unknown number of dimensions"); + } + + Shape requiredMaskShape = tensorShape.subShape(axis, axis + maskShape.numDimensions()); + if (!requiredMaskShape.isCompatibleWith(maskShape)) { + throw new IllegalArgumentException( + "Mask shape " + maskShape + " is not compatible with the required mask shape: " + requiredMaskShape + "."); + } + + org.tensorflow.op.core.Shape liveShape = org.tensorflow.op.core.Shape.create(scope, tensor); + + Operand leadingSize = ReduceProd.create(scope, + StridedSliceHelper.stridedSlice(scope, + liveShape, + Indices.sliceTo(axis + maskShape.numDimensions()) + ), + Constant.arrayOf(scope, 0) + ); + + Operand reshaped = Reshape.create(scope, tensor, Concat.create( + scope, + Arrays.asList( + Reshape.create(scope, leadingSize, Constant.arrayOf(scope, 1)), + StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceFrom(axis + maskShape.numDimensions())) + ), + Constant.scalarOf(scope, 0) + )); + + Operand indices = Where.create(scope, mask); + //TODO I'd like to broadcast value to the required shape. Need to figure out the shape first + Operand newValue = TensorScatterNdUpdate.create(scope, reshaped, indices, value); + return Reshape.create(scope, newValue, liveShape); + } + + /** + * Used to indicate the axis to mask from. + * {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match + * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape. + * @param axis the axis to mask from. Uses 0 if null. + */ + public static Options axis(Integer axis){ + return new Options().axis(axis); + } + + + /** + * Used to indicate the axis to mask from. + * {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match + * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape. + * @param axis the axis to mask from. + */ + public static Options axis(int axis){ + return new Options().axis(axis); + } + + /** + * Optional attributes for {@link BooleanMaskUpdate} + */ + public static class Options { + + /** + * @param axis (Optional) The axis to mask from, or 0 if not set. + */ + public Options axis(Integer axis) { + this.axis = axis; + return this; + } + + /** + * @param axis (Optional) The axis to mask from, or 0 if not set. + */ + public Options axis(int axis) { + this.axis = axis; + return this; + } + + private Integer axis; + + private Options() { + } + } + +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java index ff50fb65ee6..a4d9293ccf8 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java @@ -46,21 +46,21 @@ public void testBooleanMask(){ try (TFloat32 result = (TFloat32) sess.runner().fetch(output1).run().get(0)) { // expected shape from Python tensorflow assertEquals(Shape.of(5), result.shape()); - assertEquals(result.getFloat(0), 0); - assertEquals(result.getFloat(1), 1); - assertEquals(result.getFloat(2), 4); - assertEquals(result.getFloat(3), 5); - assertEquals(result.getFloat(4), 6); + assertEquals(0, result.getFloat(0)); + assertEquals(1, result.getFloat(1)); + assertEquals(4, result.getFloat(2)); + assertEquals(5, result.getFloat(3)); + assertEquals(6, result.getFloat(4)); } try (TFloat32 result = (TFloat32) sess.runner().fetch(output2).run().get(0)) { // expected shape from Python tensorflow assertEquals(Shape.of(5), result.shape()); - assertEquals(result.getFloat(0), 0); - assertEquals(result.getFloat(1), 1); - assertEquals(result.getFloat(2), 4); - assertEquals(result.getFloat(3), 5); - assertEquals(result.getFloat(4), 6); + assertEquals(0, result.getFloat(0)); + assertEquals(1, result.getFloat(1)); + assertEquals(4, result.getFloat(2)); + assertEquals(5, result.getFloat(3)); + assertEquals(6, result.getFloat(4)); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java new file mode 100644 index 00000000000..6bdd0edf293 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java @@ -0,0 +1,94 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.op.core; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.Test; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Indices; +import org.tensorflow.op.Scope; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +public class BooleanMaskUpdateTest { + + @Test + public void testBooleanMaskUpdateSlice() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Operand input = Constant.tensorOf(scope, new int[][]{ {0, 0, 0}, {1, 1, 1}, {2, 2, 2}}); + + Operand mask = Constant.arrayOf(scope, true, false, false); + + Operand value = Constant.tensorOf(scope, new int[][]{{-1, -1, -1}}); + + Operand output = BooleanMaskUpdate.create(scope, input, mask, value); + + try (TFloat32 result = (TFloat32) sess.runner().fetch(output).run().get(0)) { + // expected shape from Python tensorflow + assertEquals(Shape.of(3, 3), result.shape()); + assertEquals(-1, result.getFloat(0, 0)); + assertEquals(-1, result.getFloat(0, 1)); + assertEquals(-1, result.getFloat(0, 2)); + assertEquals(1, result.getFloat(1, 0)); + assertEquals(1, result.getFloat(1, 1)); + assertEquals(1, result.getFloat(1, 2)); + assertEquals(2, result.getFloat(2, 0)); + assertEquals(2, result.getFloat(2, 1)); + assertEquals(2, result.getFloat(2, 2)); + } + } + } + + @Test + public void testBooleanMaskUpdateAxis() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Operand input = Constant.tensorOf(scope, new int[][][]{{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}}); + + Operand mask = Constant.arrayOf(scope, true, true, false, false, true, true, true, false, false, false); + + Operand value = Constant.arrayOf(scope, -1, -1, -1, -1, -1); + + Operand output = BooleanMaskUpdate.create(scope, input, mask, value, BooleanMaskUpdate.axis(2)); + + try (TFloat32 result = (TFloat32) sess.runner().fetch(output).run().get(0)) { + // expected shape from Python tensorflow + assertEquals(Shape.of(1, 1, 10), result.shape()); + assertEquals(-1, result.getFloat(0, 0, 0)); + assertEquals(-1, result.getFloat(0, 0, 1)); + assertEquals(2, result.getFloat(0, 0, 2)); + assertEquals(3, result.getFloat(0, 0, 3)); + assertEquals(-1, result.getFloat(0, 0, 4)); + assertEquals(-1, result.getFloat(0, 0, 5)); + assertEquals(-1, result.getFloat(0, 0, 6)); + assertEquals(7, result.getFloat(0, 0, 7)); + assertEquals(8, result.getFloat(0, 0, 8)); + assertEquals(9, result.getFloat(0, 0, 9)); + } + } + } +} From 59d2482e7d4df9c021fef983b2f0f1ce7b64f507 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 8 Jan 2021 20:18:19 -0800 Subject: [PATCH 5/6] optional broadcasting + test Signed-off-by: Ryan Nett --- .../org/tensorflow/op/core/BooleanMask.java | 8 -- .../tensorflow/op/core/BooleanMaskUpdate.java | 102 ++++++++++++++---- .../op/core/BooleanMaskUpdateTest.java | 66 +++++++----- 3 files changed, 124 insertions(+), 52 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java index 080e07a851c..39473ae45a4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java @@ -149,14 +149,6 @@ public Options axis(Integer axis) { return this; } - /** - * @param axis (Optional) The axis to mask from, or 0 if not set. - */ - public Options axis(int axis) { - this.axis = axis; - return this; - } - private Integer axis; private Options() { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java index 8acff79fe62..db64ac6f39e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java @@ -32,28 +32,74 @@ @Operator public abstract class BooleanMaskUpdate { + /* + Python: + def boolean_mask_update(tensor, mask, update, axis=0, name="boolean_mask_update"): + with tf.name_scope(name): + tensor = tf.convert_to_tensor(tensor, name="tensor") + mask = tf.convert_to_tensor(mask, name="mask") + update = tf.convert_to_tensor(update, name="value") + + shape_mask = mask.get_shape() + ndims_mask = shape_mask.ndims + shape_tensor = tensor.get_shape() + if ndims_mask == 0: + raise ValueError("mask cannot be scalar.") + if ndims_mask is None: + raise ValueError( + "Number of mask dimensions must be specified, even if some dimensions" + " are None. E.g. shape=[None] is ok, but shape=None is not.") + axis = 0 if axis is None else axis + axis_value = tf.constant(axis) + if axis_value is not None: + axis = axis_value + shape_tensor[axis:axis + ndims_mask].assert_is_compatible_with(shape_mask) + + leading_size = tf.reduce_prod(tf.shape(tensor)[:axis + ndims_mask], [0]) + innerShape = tf.shape(tensor)[axis + ndims_mask:] + + tensor = tf.reshape( + tensor, + tf.concat([ + [leading_size], + innerShape + ], 0)) + + indices = tf.where(mask) + + updateShape = tf.concat([tf.shape(indices)[:-1], innerShape], 0) + + update = tf.broadcast_to(update, updateShape) + result = tf.tensor_scatter_nd_update(tensor, indices, update) + return tf.reshape(result, shape_tensor) + */ + /** * TODO * - * @param scope * @param tensor The tensor to mask. * @param mask The mask to apply. - * @param value the new values + * @param updates the new values * @param options carries optional attributes values * @return The masked tensor. */ @Endpoint(name = "booleanMaskUpdate") - public static Operand create(Scope scope, Operand tensor, Operand mask, Operand value, + public static Operand create(Scope scope, Operand tensor, Operand mask, + Operand updates, Options... options) { scope = scope.withNameAsSubScope("BooleanMaskUpdate"); int axis = 0; + boolean broadcast = true; if (options != null) { for (Options opts : options) { if (opts.axis != null) { axis = opts.axis; } + if (opts.broadcast != null) { + broadcast = opts.broadcast; + } } } @@ -77,7 +123,7 @@ public static Operand create(Scope scope, Operand tensor "Mask shape " + maskShape + " is not compatible with the required mask shape: " + requiredMaskShape + "."); } - org.tensorflow.op.core.Shape liveShape = org.tensorflow.op.core.Shape.create(scope, tensor); + Operand liveShape = org.tensorflow.op.core.Shape.create(scope, tensor); Operand leadingSize = ReduceProd.create(scope, StridedSliceHelper.stridedSlice(scope, @@ -87,40 +133,55 @@ public static Operand create(Scope scope, Operand tensor Constant.arrayOf(scope, 0) ); + Operand innerShape = StridedSliceHelper + .stridedSlice(scope, liveShape, Indices.sliceFrom(axis + maskShape.numDimensions())); + Operand reshaped = Reshape.create(scope, tensor, Concat.create( scope, Arrays.asList( Reshape.create(scope, leadingSize, Constant.arrayOf(scope, 1)), - StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceFrom(axis + maskShape.numDimensions())) + innerShape ), Constant.scalarOf(scope, 0) )); Operand indices = Where.create(scope, mask); - //TODO I'd like to broadcast value to the required shape. Need to figure out the shape first - Operand newValue = TensorScatterNdUpdate.create(scope, reshaped, indices, value); + + if(broadcast) { + Operand indicesShape = org.tensorflow.op.core.Shape.create(scope, indices); + Operand batchShape = StridedSliceHelper.stridedSlice(scope, indicesShape, Indices.sliceTo(-1)); + + Operand updateShape = Concat.create( + scope, + Arrays.asList( + batchShape, + innerShape + ), + Constant.scalarOf(scope, 0) + ); + + updates = BroadcastTo.create(scope, updates, updateShape); + } + + Operand newValue = TensorScatterNdUpdate.create(scope, reshaped, indices, updates); return Reshape.create(scope, newValue, liveShape); } /** - * Used to indicate the axis to mask from. - * {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match + * Used to indicate the axis to mask from. {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape. + * * @param axis the axis to mask from. Uses 0 if null. */ - public static Options axis(Integer axis){ + public static Options axis(Integer axis) { return new Options().axis(axis); } - /** - * Used to indicate the axis to mask from. - * {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match - * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape. - * @param axis the axis to mask from. + * Whether to try broadcasting update. True by default. */ - public static Options axis(int axis){ - return new Options().axis(axis); + public static Options broadcast(Boolean broadcast){ + return new Options().broadcast(broadcast); } /** @@ -137,14 +198,15 @@ public Options axis(Integer axis) { } /** - * @param axis (Optional) The axis to mask from, or 0 if not set. + * @param broadcast (Optional) Whether to try broadcasting update. True by default. */ - public Options axis(int axis) { - this.axis = axis; + public Options broadcast(Boolean broadcast) { + this.broadcast = broadcast; return this; } private Integer axis; + private Boolean broadcast; private Options() { } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java index 6bdd0edf293..63187f98047 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java @@ -18,10 +18,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; +import java.util.List; import org.junit.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; +import org.tensorflow.Session.Run; +import org.tensorflow.Tensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.index.Indices; import org.tensorflow.op.Scope; @@ -37,7 +40,7 @@ public void testBooleanMaskUpdateSlice() { Session sess = new Session(g)) { Scope scope = new Scope(g); - Operand input = Constant.tensorOf(scope, new int[][]{ {0, 0, 0}, {1, 1, 1}, {2, 2, 2}}); + Operand input = Constant.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}}); Operand mask = Constant.arrayOf(scope, true, false, false); @@ -45,18 +48,25 @@ public void testBooleanMaskUpdateSlice() { Operand output = BooleanMaskUpdate.create(scope, input, mask, value); - try (TFloat32 result = (TFloat32) sess.runner().fetch(output).run().get(0)) { - // expected shape from Python tensorflow + Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1)); + + List results = sess.runner().fetch(output).fetch(bcastOutput).run(); + try (TInt32 result = (TInt32) results.get(0); + TInt32 bcastResult = (TInt32) results.get(1)) { + assertEquals(Shape.of(3, 3), result.shape()); - assertEquals(-1, result.getFloat(0, 0)); - assertEquals(-1, result.getFloat(0, 1)); - assertEquals(-1, result.getFloat(0, 2)); - assertEquals(1, result.getFloat(1, 0)); - assertEquals(1, result.getFloat(1, 1)); - assertEquals(1, result.getFloat(1, 2)); - assertEquals(2, result.getFloat(2, 0)); - assertEquals(2, result.getFloat(2, 1)); - assertEquals(2, result.getFloat(2, 2)); + + assertEquals(-1, result.getInt(0, 0)); + assertEquals(-1, result.getInt(0, 1)); + assertEquals(-1, result.getInt(0, 2)); + assertEquals(1, result.getInt(1, 0)); + assertEquals(1, result.getInt(1, 1)); + assertEquals(1, result.getInt(1, 2)); + assertEquals(2, result.getInt(2, 0)); + assertEquals(2, result.getInt(2, 1)); + assertEquals(2, result.getInt(2, 2)); + + assertEquals(result, bcastResult); } } } @@ -75,19 +85,27 @@ public void testBooleanMaskUpdateAxis() { Operand output = BooleanMaskUpdate.create(scope, input, mask, value, BooleanMaskUpdate.axis(2)); - try (TFloat32 result = (TFloat32) sess.runner().fetch(output).run().get(0)) { - // expected shape from Python tensorflow + Operand bcastOutput = BooleanMaskUpdate + .create(scope, input, mask, Constant.scalarOf(scope, -1), BooleanMaskUpdate.axis(2)); + + List results = sess.runner().fetch(output).fetch(bcastOutput).run(); + try (TInt32 result = (TInt32) results.get(0); + TInt32 bcastResult = (TInt32) results.get(1)) { + assertEquals(Shape.of(1, 1, 10), result.shape()); - assertEquals(-1, result.getFloat(0, 0, 0)); - assertEquals(-1, result.getFloat(0, 0, 1)); - assertEquals(2, result.getFloat(0, 0, 2)); - assertEquals(3, result.getFloat(0, 0, 3)); - assertEquals(-1, result.getFloat(0, 0, 4)); - assertEquals(-1, result.getFloat(0, 0, 5)); - assertEquals(-1, result.getFloat(0, 0, 6)); - assertEquals(7, result.getFloat(0, 0, 7)); - assertEquals(8, result.getFloat(0, 0, 8)); - assertEquals(9, result.getFloat(0, 0, 9)); + + assertEquals(-1, result.getInt(0, 0, 0)); + assertEquals(-1, result.getInt(0, 0, 1)); + assertEquals(2, result.getInt(0, 0, 2)); + assertEquals(3, result.getInt(0, 0, 3)); + assertEquals(-1, result.getInt(0, 0, 4)); + assertEquals(-1, result.getInt(0, 0, 5)); + assertEquals(-1, result.getInt(0, 0, 6)); + assertEquals(7, result.getInt(0, 0, 7)); + assertEquals(8, result.getInt(0, 0, 8)); + assertEquals(9, result.getInt(0, 0, 9)); + + assertEquals(result, bcastResult); } } } From 55979818c2277a0b86de151c405618b701569c84 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 10 Feb 2021 22:16:43 -0800 Subject: [PATCH 6/6] Javadoc and rebase Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 54 +++++++++++---- .../org/tensorflow/op/core/BooleanMask.java | 3 +- .../tensorflow/op/core/BooleanMaskUpdate.java | 66 ++++++------------- .../op/core/BooleanMaskUpdateTest.java | 40 ++++++++++- 4 files changed, 101 insertions(+), 62 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 73a2b81e64e..b0fb67b5ce1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -60,6 +60,7 @@ import org.tensorflow.op.core.BatchToSpaceNd; import org.tensorflow.op.core.Bitcast; import org.tensorflow.op.core.BooleanMask; +import org.tensorflow.op.core.BooleanMaskUpdate; import org.tensorflow.op.core.BroadcastDynamicShape; import org.tensorflow.op.core.BroadcastTo; import org.tensorflow.op.core.Bucketize; @@ -348,10 +349,10 @@ public final class Ops { public final SignalOps signal; - public final QuantizationOps quantization; - public final TrainOps train; + public final QuantizationOps quantization; + private final Scope scope; private Ops(Scope scope) { @@ -373,8 +374,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - quantization = new QuantizationOps(this); train = new TrainOps(this); + quantization = new QuantizationOps(this); } /** @@ -991,9 +992,9 @@ public Bitcast bitcast(Operand input, Clas } /** - * Apply boolean mask to tensor. + * Apply boolean mask to tensor. Returns the flat array of each element corresponding to a {@code true} in the mask. *

- * Numpy equivalent is `tensor[mask]`. + * Numpy equivalent is {@code tensor[mask]}. *

* In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match * the first K dimensions of {@code tensor}'s shape. We then have: @@ -1015,6 +1016,36 @@ public Operand booleanMask(Operand tensor, Operand + * Numpy equivalent is `tensor[mask] = updates`. + *

+ * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match the first K dimensions of + * {@code tensor}'s shape. We then have: {@code booleanMask(tensor, mask)[i, j1,...,jd] = + * tensor[i1,...,iK,j1,...,jd]} where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major + * order). + *

+ * The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default). In that + * case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match the first {@code axis + + * dim(mask)} dimensions of {@code tensor}'s shape. + *

+ * The shape of {@code updates} should be {@code [n, t_1, t_2, ...]} where {@code n} is the number of true values in + * {@code mask} and {@code t_i} is the {@code i}th dimension of {@code tensor} after {@code axis} and {@code mask}. + * {@code updates} will be broadcasted to this shape by default, which can be disabled using {@code options}. + * + * @param tensor The tensor to mask. + * @param mask The mask to apply. + * @param updates the new values + * @param options carries optional attributes values + * @return The masked tensor. + */ + public Operand booleanMaskUpdate(Operand tensor, Operand mask, + Operand updates, BooleanMaskUpdate.Options... options) { + return BooleanMaskUpdate.create(scope, tensor, mask, updates, options); + } + /** * Return the shape of s0 op s1 with broadcast. *

@@ -1860,13 +1891,14 @@ public Constant constant(Shape shape, IntDataBuffer data) { } /** - * Creates a scalar of {@code type}, with the value of {@code number}. - * {@code number} may be truncated if it does not fit in the target type. + * Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not + * fit in the target type. * * @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating}) * @param number the value of the tensor * @return a constant of the passed type - * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or unknown. + * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or + * unknown. */ public Constant constant(Class type, Number number) { return Constant.tensorOf(scope, type, number); @@ -1918,14 +1950,14 @@ public Constant constantOf(T tensor) { } /** - * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. - * {@code number} may be truncated if it does not fit in the target type. + * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be + * truncated if it does not fit in the target type. * * @param toMatch the operand providing the target type * @param number the value of the tensor * @return a constant with the same type as {@code toMatch} - * @see Ops#constant(Class, Number) * @throws IllegalArgumentException if the type is unknown (which should be impossible). + * @see Ops#constant(Class, Number) */ public Constant constantOfSameType(Operand toMatch, Number number) { return Constant.tensorOfSameType(scope, toMatch, number); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java index 39473ae45a4..85a41ef485f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java @@ -21,7 +21,6 @@ import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.index.Indices; -import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; @@ -36,7 +35,7 @@ public abstract class BooleanMask { /** * Apply boolean mask to tensor. Returns the flat array of each element corresponding to a {@code true} in the mask. *

- * Numpy equivalent is `tensor[mask]`. + * Numpy equivalent is {@code tensor[mask]}. *

* In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match * the first K dimensions of {@code tensor}'s shape. We then have: diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java index db64ac6f39e..a40ae7ab017 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java @@ -17,7 +17,6 @@ package org.tensorflow.op.core; import java.util.Arrays; -import java.util.Collections; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.index.Indices; @@ -32,50 +31,24 @@ @Operator public abstract class BooleanMaskUpdate { - /* - Python: - def boolean_mask_update(tensor, mask, update, axis=0, name="boolean_mask_update"): - with tf.name_scope(name): - tensor = tf.convert_to_tensor(tensor, name="tensor") - mask = tf.convert_to_tensor(mask, name="mask") - update = tf.convert_to_tensor(update, name="value") - - shape_mask = mask.get_shape() - ndims_mask = shape_mask.ndims - shape_tensor = tensor.get_shape() - if ndims_mask == 0: - raise ValueError("mask cannot be scalar.") - if ndims_mask is None: - raise ValueError( - "Number of mask dimensions must be specified, even if some dimensions" - " are None. E.g. shape=[None] is ok, but shape=None is not.") - axis = 0 if axis is None else axis - axis_value = tf.constant(axis) - if axis_value is not None: - axis = axis_value - shape_tensor[axis:axis + ndims_mask].assert_is_compatible_with(shape_mask) - - leading_size = tf.reduce_prod(tf.shape(tensor)[:axis + ndims_mask], [0]) - innerShape = tf.shape(tensor)[axis + ndims_mask:] - - tensor = tf.reshape( - tensor, - tf.concat([ - [leading_size], - innerShape - ], 0)) - - indices = tf.where(mask) - - updateShape = tf.concat([tf.shape(indices)[:-1], innerShape], 0) - - update = tf.broadcast_to(update, updateShape) - result = tf.tensor_scatter_nd_update(tensor, indices, update) - return tf.reshape(result, shape_tensor) - */ - /** - * TODO + * Updates a tensor at the masked values, and returns the updated tensor. Does not mutate the input tensors. {@code + * updates} will be broadcasted by default + *

+ * Numpy equivalent is `tensor[mask] = updates`. + *

+ * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match the first K dimensions of + * {@code tensor}'s shape. We then have: {@code booleanMask(tensor, mask)[i, j1,...,jd] = + * tensor[i1,...,iK,j1,...,jd]} where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major + * order). + *

+ * The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default). In that + * case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match the first {@code axis + + * dim(mask)} dimensions of {@code tensor}'s shape. + *

+ * The shape of {@code updates} should be {@code [n, t_1, t_2, ...]} where {@code n} is the number of true values in + * {@code mask} and {@code t_i} is the {@code i}th dimension of {@code tensor} after {@code axis} and {@code mask}. + * {@code updates} will be broadcasted to this shape by default, which can be disabled using {@code options}. * * @param tensor The tensor to mask. * @param mask The mask to apply. @@ -147,8 +120,9 @@ public static Operand create(Scope scope, Operand tensor Operand indices = Where.create(scope, mask); - if(broadcast) { + if (broadcast) { Operand indicesShape = org.tensorflow.op.core.Shape.create(scope, indices); + // this is the number of true values Operand batchShape = StridedSliceHelper.stridedSlice(scope, indicesShape, Indices.sliceTo(-1)); Operand updateShape = Concat.create( @@ -180,7 +154,7 @@ public static Options axis(Integer axis) { /** * Whether to try broadcasting update. True by default. */ - public static Options broadcast(Boolean broadcast){ + public static Options broadcast(Boolean broadcast) { return new Options().broadcast(broadcast); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java index 63187f98047..ab852bbffb2 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java @@ -23,13 +23,10 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.Session.Run; import org.tensorflow.Tensor; import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.index.Indices; import org.tensorflow.op.Scope; import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; public class BooleanMaskUpdateTest { @@ -71,6 +68,43 @@ public void testBooleanMaskUpdateSlice() { } } + @Test + public void testBooleanMaskUpdateSliceWithBroadcast() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Operand input = Constant.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}}); + + Operand mask = Constant.arrayOf(scope, true, false, false); + + Operand value = Constant.vectorOf(scope, new int[]{-1, -1, -1}); + + Operand output = BooleanMaskUpdate.create(scope, input, mask, value); + + Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1)); + + List results = sess.runner().fetch(output).fetch(bcastOutput).run(); + try (TInt32 result = (TInt32) results.get(0); + TInt32 bcastResult = (TInt32) results.get(1)) { + + assertEquals(Shape.of(3, 3), result.shape()); + + assertEquals(-1, result.getInt(0, 0)); + assertEquals(-1, result.getInt(0, 1)); + assertEquals(-1, result.getInt(0, 2)); + assertEquals(1, result.getInt(1, 0)); + assertEquals(1, result.getInt(1, 1)); + assertEquals(1, result.getInt(1, 2)); + assertEquals(2, result.getInt(2, 0)); + assertEquals(2, result.getInt(2, 1)); + assertEquals(2, result.getInt(2, 2)); + + assertEquals(result, bcastResult); + } + } + } + @Test public void testBooleanMaskUpdateAxis() { try (Graph g = new Graph();