From d0434de773d22587620e6cb299f8ce3a92b60de0 Mon Sep 17 00:00:00 2001 From: Samuel Audet Date: Fri, 26 Mar 2021 11:38:16 +0900 Subject: [PATCH] Fix memory leak occuring with new layout of string tensors Also fix and actually use JavaCPP deallocator for TF_Session as well --- .../src/main/java/org/tensorflow/Session.java | 13 ++------- .../buffer/ByteSequenceTensorBuffer.java | 2 -- .../internal/c_api/AbstractTF_Session.java | 14 +++++---- .../internal/c_api/AbstractTF_Tensor.java | 29 ++++++++++++++++++- 4 files changed, 40 insertions(+), 18 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index b67f4a611e6..8a554f98247 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -604,23 +604,16 @@ private static TF_Session allocate(TF_Graph graphHandle, String target, ConfigPr status.throwExceptionIfNotOK(); } - TF_Session session = TF_NewSession(graphHandle, opts, status); + TF_Session session = TF_Session.newSession(graphHandle, opts, status); status.throwExceptionIfNotOK(); - return session; + return session.retainReference(); } } private static void delete(TF_Session handle) { requireHandle(handle); - - try (PointerScope scope = new PointerScope()) { - TF_Status status = TF_Status.newStatus(); - TF_CloseSession(handle, status); - // Result of close is ignored, delete anyway. - TF_DeleteSession(handle, status); - status.throwExceptionIfNotOK(); - } + handle.releaseReference(); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java index 3bcbd7f8022..acaeaedbc11 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java @@ -19,7 +19,6 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_TString_Assign; import static org.tensorflow.internal.c_api.global.tensorflow.TF_TString_Copy; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_TString_Init; import static org.tensorflow.internal.c_api.global.tensorflow.TF_TString_GetDataPointer; import static org.tensorflow.internal.c_api.global.tensorflow.TF_TString_GetSize; @@ -127,7 +126,6 @@ private class InitDataWriter { void writeNext(byte[] bytes) { try (PointerScope scope = new PointerScope()) { TF_TString tstring = data.getPointer(index++); - TF_TString_Init(tstring); TF_TString_Copy(tstring, new BytePointer(bytes), bytes.length); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Session.java index 126acc1afbf..13bbd0bc3e5 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Session.java @@ -25,6 +25,7 @@ import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; +import org.bytedeco.javacpp.PointerScope; import org.bytedeco.javacpp.annotation.Properties; @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) @@ -33,11 +34,14 @@ protected static class DeleteDeallocator extends TF_Session implements Pointer.D DeleteDeallocator(TF_Session s) { super(s); } @Override public void deallocate() { if (!isNull()) { - TF_Status status = TF_Status.newStatus(); - TF_CloseSession(this, status); - // Result of close is ignored, delete anyway. - TF_DeleteSession(this, status); - setNull(); + try (PointerScope scope = new PointerScope()) { + TF_Status status = TF_Status.newStatus(); + TF_CloseSession(this, status); + // Result of close is ignored, delete anyway. + TF_DeleteSession(this, status); + status.throwExceptionIfNotOK(); + setNull(); + } } } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java index fba056c6dcb..b4b498f95ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java @@ -20,6 +20,12 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_AllocateTensor; import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteTensor; import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewTensor; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_STRING; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_TString_Dealloc; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_TString_Init; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorData; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorElementCount; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorType; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.annotation.Properties; @@ -28,7 +34,20 @@ public abstract class AbstractTF_Tensor extends Pointer { protected static class DeleteDeallocator extends TF_Tensor implements Pointer.Deallocator { DeleteDeallocator(TF_Tensor s) { super(s); } - @Override public void deallocate() { if (!isNull()) TF_DeleteTensor(this); setNull(); } + @Override public void deallocate() { + if (!isNull()) { + if (TF_TensorType(this) == TF_STRING) { + // we need to deallocate the strings themselves before deallocating the tensor memory + long n = TF_TensorElementCount(this); + TF_TString data = new TF_TString(TF_TensorData(this)); + for (int i = 0; i < n; i++) { + TF_TString_Dealloc(data.position(i)); + } + } + TF_DeleteTensor(this); + } + setNull(); + } } /** TensorFlow crashes if we don't pass it a deallocator, so... */ @@ -61,6 +80,14 @@ public static TF_Tensor newTensor(int dtype, long[] dims, Pointer data) { public static TF_Tensor allocateTensor(int dtype, long[] dims, long length) { TF_Tensor t = TF_AllocateTensor(dtype, dims, dims.length, length); if (t != null) { + if (TF_TensorType(t) == TF_STRING) { + // we need to initialize the strings themselves after allocating the tensor memory + long n = TF_TensorElementCount(t); + TF_TString data = new TF_TString(TF_TensorData(t)); + for (int i = 0; i < n; i++) { + TF_TString_Init(data.position(i)); + } + } t.deallocator(new DeleteDeallocator(t)); } return t;