Skip to content

Commit a8f8dc0

Browse files
committed
Throw exception on custom gradient registration on Windows
1 parent f00a955 commit a8f8dc0

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import com.google.protobuf.InvalidProtocolBufferException;
2727
import java.util.Collections;
2828
import java.util.IdentityHashMap;
29+
import java.util.Locale;
2930
import java.util.Set;
3031
import java.util.stream.Collectors;
3132
import org.bytedeco.javacpp.PointerPointer;
@@ -193,9 +194,8 @@ private static synchronized boolean hasGradient(String opType) {
193194
* <p>Note that this only works with graph gradients, and will eventually be deprecated in favor
194195
* of unified gradient support once it is fully supported by tensorflow core.
195196
*
196-
* <p><i>Warning: Custom gradient registration is currently not supported on Windows and may crash
197-
* the JVM if trying to do so. See <a href=https://github.com/tensorflow/java/issues/486>GitHub
198-
* issue</a> related to this.</i>
197+
* <p><i>Warning: Custom gradient registration is currently not supported on Windows, see <a
198+
* href=https://github.com/tensorflow/java/issues/486>GitHub issue</a> for more info.</i>
199199
*
200200
* @param opType the type of op to register the gradient for. Should usually be an {@code OP_NAME}
201201
* field, i.e. {@link Add#OP_NAME}.
@@ -205,6 +205,10 @@ private static synchronized boolean hasGradient(String opType) {
205205
*/
206206
public static synchronized boolean registerCustomGradient(
207207
String opType, RawCustomGradient gradient) {
208+
if (isWindowsOs()) {
209+
throw new UnsupportedOperationException(
210+
"Custom gradient registration is not supported on Windows systems.");
211+
}
208212
if (hasGradient(opType)) {
209213
return false;
210214
}
@@ -220,9 +224,8 @@ public static synchronized boolean registerCustomGradient(
220224
* generated op classes or custom op classes with the correct annotations. To operate on the
221225
* {@link org.tensorflow.GraphOperation} directly use {@link RawCustomGradient}.
222226
*
223-
* <p><i>Warning: Custom gradient registration is currently not supported on Windows and may crash
224-
* the JVM if trying to do so. See <a href=https://github.com/tensorflow/java/issues/486>GitHub
225-
* issue</a> related to this.</i>
227+
* <p><i>Warning: Custom gradient registration is currently not supported on Windows, see <a
228+
* href=https://github.com/tensorflow/java/issues/486>GitHub issue</a> for more info.</i>
226229
*
227230
* @param inputClass the inputs class of op to register the gradient for.
228231
* @param gradient the gradient function to use
@@ -233,8 +236,11 @@ public static synchronized boolean registerCustomGradient(
233236
*/
234237
public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGradient(
235238
Class<T> inputClass, CustomGradient<T> gradient) {
239+
if (isWindowsOs()) {
240+
throw new UnsupportedOperationException(
241+
"Custom gradient registration is not supported on Windows systems.");
242+
}
236243
OpInputsMetadata metadata = inputClass.getAnnotation(OpInputsMetadata.class);
237-
238244
if (metadata == null) {
239245
throw new IllegalArgumentException(
240246
"Inputs Class "
@@ -261,4 +267,8 @@ public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGrad
261267
gradientFuncs.add(g);
262268
return true;
263269
}
270+
271+
private static boolean isWindowsOs() {
272+
return System.getProperty("os.name", "").toLowerCase(Locale.ENGLISH).startsWith("win");
273+
}
264274
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.Arrays;
2222
import org.junit.jupiter.api.Test;
2323
import org.junit.jupiter.api.condition.DisabledOnOs;
24+
import org.junit.jupiter.api.condition.EnabledOnOs;
2425
import org.junit.jupiter.api.condition.OS;
2526
import org.tensorflow.ndarray.index.Indices;
2627
import org.tensorflow.op.Ops;
@@ -44,8 +45,8 @@ public void testAlreadyExisting() {
4445
}));
4546
}
4647

47-
// FIXME: Since TF 2.10.1, this test is failing on Windows, because the whole JVM crashes when
48-
// calling the JavaCPP generated binding `NameMap.erase`. Disable it until we find a fix.
48+
// FIXME: Since TF 2.10.1, custom gradient registration is failing on Windows, see
49+
// https://github.com/tensorflow/java/issues/486
4950
@DisabledOnOs(OS.WINDOWS)
5051
@Test
5152
public void testCustomGradient() {
@@ -76,6 +77,26 @@ public void testCustomGradient() {
7677
}
7778
}
7879

80+
@EnabledOnOs(OS.WINDOWS)
81+
@Test
82+
public void testCustomGradientThrowsOnWindows() {
83+
assertThrows(
84+
UnsupportedOperationException.class,
85+
() ->
86+
TensorFlow.registerCustomGradient(
87+
NthElement.OP_NAME,
88+
(tf, op, gradInputs) ->
89+
Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f))));
90+
91+
assertThrows(
92+
UnsupportedOperationException.class,
93+
() ->
94+
TensorFlow.registerCustomGradient(
95+
NthElement.Inputs.class,
96+
(tf, op, gradInputs) ->
97+
Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f))));
98+
}
99+
79100
private static Output<?>[] toArray(Output<?>... outputs) {
80101
return outputs;
81102
}

0 commit comments

Comments
 (0)