Skip to content

Commit e4e5d98

Browse files
authored
Merge pull request #852 from RedisAI/fix_tf_invalid_output_delete
fixed invalid delete of outputs after execution error in TF
2 parents cdab22e + f3e815e commit e4e5d98

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

src/backends/tensorflow.c

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,10 @@ int RAI_ModelRunTF(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error)
530530
outputTensorsValues, noutputs, NULL /* target_opers */, 0 /* ntargets */,
531531
NULL /* run_Metadata */, status);
532532

533+
bool delete_output = true;
533534
if (TF_GetCode(status) != TF_OK) {
534535
RAI_SetError(error, RAI_EMODELRUN, TF_Message(status));
536+
delete_output = false;
535537
goto cleanup;
536538
}
537539

@@ -575,8 +577,10 @@ int RAI_ModelRunTF(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error)
575577
}
576578
TF_DeleteTensor(inputTensorsValues[i]);
577579
}
578-
for (size_t i = 0; i < noutputs; i++) {
579-
TF_DeleteTensor(outputTensorsValues[i]);
580+
if (delete_output) {
581+
for (size_t i = 0; i < noutputs; i++) {
582+
TF_DeleteTensor(outputTensorsValues[i]);
583+
}
580584
}
581585
return res;
582586
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:98d6bb6cdfe22525a894c67b3992dc3e88b06730661817ee79c62c20b9dd09dd
3+
size 174203

tests/flow/tests_tensorflow.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,3 +759,14 @@ def run():
759759
env.assertEqual(out_values, [b'this is', b'the first batch'])
760760
out_values = con.execute_command('AI.TENSORGET', 'second_batch{1}', 'VALUES')
761761
env.assertEqual(out_values, [b'that is', b'the second batch'])
762+
763+
@skip_if_no_TF
764+
def test_bad_execution_model(env):
765+
con = get_connection(env, '{1}')
766+
767+
model_pb = load_file_content('frozen_bad_model.pb')
768+
ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'TF', DEVICE, 'INPUTS', 1, 'x', 'OUTPUTS', 1, 'Identity', 'BLOB', model_pb)
769+
env.assertEqual(ret, b'OK')
770+
con.execute_command('AI.TENSORSET', 'my_str_tensor{1}', 'STRING', 4, 'BLOB', "how do I extract keys from a dict into a list?\x00debug public static void main(string[] args) {...}\x00should I use def main()\x00type hinting for list?\x00")
771+
env.assertEqual(ret, b'OK')
772+
check_error(env, con, 'AI.MODELEXECUTE', 'm{1}', 'INPUTS', 1, 'my_str_tensor{1}', 'OUTPUTS', 1, 'foo{1}')

0 commit comments

Comments
 (0)