26
26
import com .google .protobuf .InvalidProtocolBufferException ;
27
27
import java .util .Collections ;
28
28
import java .util .IdentityHashMap ;
29
+ import java .util .Locale ;
29
30
import java .util .Set ;
30
31
import java .util .stream .Collectors ;
31
32
import org .bytedeco .javacpp .PointerPointer ;
@@ -193,9 +194,8 @@ private static synchronized boolean hasGradient(String opType) {
193
194
* <p>Note that this only works with graph gradients, and will eventually be deprecated in favor
194
195
* of unified gradient support once it is fully supported by tensorflow core.
195
196
*
196
- * <p><i>Warning: Custom gradient registration is currently not supported on Windows and may crash
197
- * the JVM if trying to do so. See <a href=https://github.com/tensorflow/java/issues/486>GitHub
198
- * issue</a> related to this.</i>
197
+ * <p><i>Warning: Custom gradient registration is currently not supported on Windows, see <a
198
+ * href=https://github.com/tensorflow/java/issues/486>GitHub issue</a> for more info.</i>
199
199
*
200
200
* @param opType the type of op to register the gradient for. Should usually be an {@code OP_NAME}
201
201
* field, i.e. {@link Add#OP_NAME}.
@@ -205,6 +205,10 @@ private static synchronized boolean hasGradient(String opType) {
205
205
*/
206
206
public static synchronized boolean registerCustomGradient (
207
207
String opType , RawCustomGradient gradient ) {
208
+ if (isWindowsOs ()) {
209
+ throw new UnsupportedOperationException (
210
+ "Custom gradient registration is not supported on Windows systems." );
211
+ }
208
212
if (hasGradient (opType )) {
209
213
return false ;
210
214
}
@@ -220,9 +224,8 @@ public static synchronized boolean registerCustomGradient(
220
224
* generated op classes or custom op classes with the correct annotations. To operate on the
221
225
* {@link org.tensorflow.GraphOperation} directly use {@link RawCustomGradient}.
222
226
*
223
- * <p><i>Warning: Custom gradient registration is currently not supported on Windows and may crash
224
- * the JVM if trying to do so. See <a href=https://github.com/tensorflow/java/issues/486>GitHub
225
- * issue</a> related to this.</i>
227
+ * <p><i>Warning: Custom gradient registration is currently not supported on Windows, see <a
228
+ * href=https://github.com/tensorflow/java/issues/486>GitHub issue</a> for more info.</i>
226
229
*
227
230
* @param inputClass the inputs class of op to register the gradient for.
228
231
* @param gradient the gradient function to use
@@ -233,8 +236,11 @@ public static synchronized boolean registerCustomGradient(
233
236
*/
234
237
public static synchronized <T extends RawOpInputs <?>> boolean registerCustomGradient (
235
238
Class <T > inputClass , CustomGradient <T > gradient ) {
239
+ if (isWindowsOs ()) {
240
+ throw new UnsupportedOperationException (
241
+ "Custom gradient registration is not supported on Windows systems." );
242
+ }
236
243
OpInputsMetadata metadata = inputClass .getAnnotation (OpInputsMetadata .class );
237
-
238
244
if (metadata == null ) {
239
245
throw new IllegalArgumentException (
240
246
"Inputs Class "
@@ -261,4 +267,8 @@ public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGrad
261
267
gradientFuncs .add (g );
262
268
return true ;
263
269
}
270
+
271
+ private static boolean isWindowsOs () {
272
+ return System .getProperty ("os.name" , "" ).toLowerCase (Locale .ENGLISH ).startsWith ("win" );
273
+ }
264
274
}
0 commit comments