From 3a98d7772ecc46b39326e81e1a34ee1d7768e68f Mon Sep 17 00:00:00 2001 From: filipecosta90 Date: Tue, 16 Jun 2020 18:27:30 +0100 Subject: [PATCH 1/4] [add] added mobilenet model use case to quickly check for leaks on modelruns with large tensors (dagrun and modelrun variations) --- test/includes.py | 32 ++++++- .../mobilenet_v1_100_224_cpu_NxHxWxC.pb | 3 + .../mobilenet_v1_100_224_gpu_NxHxWxC.pb | 3 + ...bilenet_v1_100_224_gpu_NxHxWxC_fp16_trt.pb | 3 + .../mobilenet/mobilenet_v2_1.4_224_frozen.pb | 3 + test/test_data/mobilenet/model_saver.py | 49 +++++++++++ test/tests_sanitizer.py | 87 +++++++++++++++++++ test/tests_tensorflow.py | 15 +--- test/unit/.gitignore | 8 ++ test/unit/CMakeLists.txt | 20 +++++ 10 files changed, 207 insertions(+), 16 deletions(-) create mode 100644 test/test_data/mobilenet/mobilenet_v1_100_224_cpu_NxHxWxC.pb create mode 100644 test/test_data/mobilenet/mobilenet_v1_100_224_gpu_NxHxWxC.pb create mode 100644 test/test_data/mobilenet/mobilenet_v1_100_224_gpu_NxHxWxC_fp16_trt.pb create mode 100644 test/test_data/mobilenet/mobilenet_v2_1.4_224_frozen.pb create mode 100644 test/test_data/mobilenet/model_saver.py create mode 100644 test/tests_sanitizer.py create mode 100644 test/unit/.gitignore create mode 100644 test/unit/CMakeLists.txt diff --git a/test/includes.py b/test/includes.py index 5f72c8636..237f556a6 100755 --- a/test/includes.py +++ b/test/includes.py @@ -16,6 +16,7 @@ except: pass +MAX_ITERATIONS = 2 if os.environ.get("MAX_ITERATIONS") == None else os.environ.get("MAX_ITERATIONS") TEST_TF = os.environ.get("TEST_TF") != "0" and os.environ.get("WITH_TF") != "0" TEST_TFLITE = os.environ.get("TEST_TFLITE") != "0" and os.environ.get("WITH_TFLITE") != "0" TEST_PT = os.environ.get("TEST_PT") != "0" and os.environ.get("WITH_PT") != "0" @@ -24,7 +25,7 @@ DEVICE = os.environ.get('DEVICE', 'CPU').upper().encode('utf-8', 'ignore').decode('utf-8') VALGRIND = os.environ.get("VALGRIND") == "1" print(f"Running tests on {DEVICE}\n") - +print(f"Using a max of {MAX_ITERATIONS} iterations per test\n") # change this to make inference tests longer MAX_TRANSACTIONS=100 @@ -91,12 +92,35 @@ def load_resnet_test_data(): return model_pb, script, labels, img +def load_mobilenet_v1_test_data(): + test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') + labels_filename = os.path.join(test_data_path, 'imagenet_class_index.json') + image_filename = os.path.join(test_data_path, 'panda.jpg') + model_filename = os.path.join(test_data_path, 'mobilenet/mobilenet_v1_100_224_cpu_NxHxWxC.pb') + input_var = 'input' + output_var = 'MobilenetV1/Predictions/Reshape_1' + + with open(model_filename, 'rb') as f: + model_pb = f.read() + + with open(labels_filename, 'r') as f: + labels = json.load(f) + + img_height, img_width = 224, 224 + + img = imread(image_filename) + img = resize(img, (img_height, img_width), mode='constant', anti_aliasing=True) + img = img.astype(np.float32) + + return model_pb, input_var, output_var, labels, img -def load_mobilenet_test_data(): +def load_mobilenet_v2_test_data(): test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') labels_filename = os.path.join(test_data_path, 'imagenet_class_index.json') image_filename = os.path.join(test_data_path, 'panda.jpg') - model_filename = os.path.join(test_data_path, 'mobilenet_v2_1.4_224_frozen.pb') + model_filename = os.path.join(test_data_path, 'mobilenet/mobilenet_v2_1.4_224_frozen.pb') + input_var = 'input' + output_var = 'MobilenetV2/Predictions/Reshape_1' with open(model_filename, 'rb') as f: model_pb = f.read() @@ -110,7 +134,7 @@ def load_mobilenet_test_data(): img = resize(img, (img_height, img_width), mode='constant', anti_aliasing=True) img = img.astype(np.float32) - return model_pb, labels, img + return model_pb, input_var, output_var, labels, img def load_creditcardfraud_data(env,max_tensors=10000): test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') diff --git a/test/test_data/mobilenet/mobilenet_v1_100_224_cpu_NxHxWxC.pb b/test/test_data/mobilenet/mobilenet_v1_100_224_cpu_NxHxWxC.pb new file mode 100644 index 000000000..30ba12b3c --- /dev/null +++ b/test/test_data/mobilenet/mobilenet_v1_100_224_cpu_NxHxWxC.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bbb2752038ff1749d2b55988bb5f6e999a799c19413a0691b82d29f7aec0bab3 +size 17198345 diff --git a/test/test_data/mobilenet/mobilenet_v1_100_224_gpu_NxHxWxC.pb b/test/test_data/mobilenet/mobilenet_v1_100_224_gpu_NxHxWxC.pb new file mode 100644 index 000000000..2e8871769 --- /dev/null +++ b/test/test_data/mobilenet/mobilenet_v1_100_224_gpu_NxHxWxC.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1fe206dfd3cff261cf403b5757abec886da445a80056e55310ddac0b2805a3b +size 17198345 diff --git a/test/test_data/mobilenet/mobilenet_v1_100_224_gpu_NxHxWxC_fp16_trt.pb b/test/test_data/mobilenet/mobilenet_v1_100_224_gpu_NxHxWxC_fp16_trt.pb new file mode 100644 index 000000000..197733e00 --- /dev/null +++ b/test/test_data/mobilenet/mobilenet_v1_100_224_gpu_NxHxWxC_fp16_trt.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd925f4b59d8d5035ccb2ecdfbf9b0f47a5ba3acfa81bd5a18536f69021df74a +size 34277746 diff --git a/test/test_data/mobilenet/mobilenet_v2_1.4_224_frozen.pb b/test/test_data/mobilenet/mobilenet_v2_1.4_224_frozen.pb new file mode 100644 index 000000000..41e3481fd --- /dev/null +++ b/test/test_data/mobilenet/mobilenet_v2_1.4_224_frozen.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:111479258f3841c93d0a7a377c976c24e8281077818991931429d2277dd88590 +size 24508794 diff --git a/test/test_data/mobilenet/model_saver.py b/test/test_data/mobilenet/model_saver.py new file mode 100644 index 000000000..4ba9da4b7 --- /dev/null +++ b/test/test_data/mobilenet/model_saver.py @@ -0,0 +1,49 @@ +import tensorflow as tf +import tensorflow_hub as hub +import ml2rt +import argparse +import sys + +url = 'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/quantops/classification/3' +model_name = 'mobilenet_v1_100_224' +module = hub.Module(url) +batch_size = 1 +number_channels = 3 +height, width = hub.get_expected_image_size(module) +input_var = 'input' +output_var = 'MobilenetV1/Predictions/Reshape_1' + +parser = argparse.ArgumentParser() +parser.add_argument('--gpu', action="store_true", default=False) +parser.add_argument('--input-shape', default="NxHxWxC", type=str) +args = parser.parse_args() +device = 'gpu' if args.gpu else 'cpu' + +gpu_available = tf.test.is_gpu_available( + cuda_only=True, min_cuda_compute_capability=None +) + +if gpu_available is False and args.gpu: + print("No CUDA GPUs found. Exiting...") + sys.exit(1) + +var_converter = tf.compat.v1.graph_util.convert_variables_to_constants + +if args.input_shape == "NxHxWxC": + print("Saving N x H x W x C (1, 224, 224, 3) (with channels_last data format)") + images = tf.compat.v1.placeholder(tf.float32, shape=( + batch_size, height, width, number_channels), name=input_var) +elif args.input_shape == "NxHxWxC": + print("Saving N x C x H x W (1, 3, 224, 224)") + images = tf.placeholder(tf.float32, shape=( + batch_size, number_channels, height, width), name=input_var) +else: + print("inputs shape is either NxHxWxC or NxCxHxW. Exiting...") + sys.exit(1) + +logits = module(images) +logits = tf.identity(logits, output_var) +with tf.compat.v1.Session() as sess: + sess.run([tf.compat.v1.global_variables_initializer()]) + ml2rt.save_tensorflow(sess, '{model_name}_{device}_{input_shape}.pb'.format( + model_name=model_name, device=device, input_shape=args.input_shape), output=[output_var]) diff --git a/test/tests_sanitizer.py b/test/tests_sanitizer.py new file mode 100644 index 000000000..c4ed29044 --- /dev/null +++ b/test/tests_sanitizer.py @@ -0,0 +1,87 @@ +import redis +from functools import wraps +import multiprocessing as mp +from includes import * + +''' +python -m RLTest --test tests_sanitizer.py --module path/to/redisai.so +''' + + +def test_sanitizer_dagrun_mobilenet_v1(env): + if (not TEST_TF or not TEST_PT): + return + con = env.getConnection() + mem_allocator = con.execute_command('info', 'memory')['mem_allocator'] + if 'jemalloc' in mem_allocator: + print("exiting sanitizer test given we're not using stdlib allocator") + return + + model_name = 'mobilenet_v1' + model_pb, input_var, output_var, labels, img = load_mobilenet_v1_test_data() + + ret = con.execute_command('AI.MODELSET', model_name, 'TF', DEVICE, + 'INPUTS', input_var, + 'OUTPUTS', output_var, + 'BLOB', model_pb) + env.assertEqual(ret, b'OK') + + for opnumber in range(1, MAX_ITERATIONS): + image_key = 'image' + temp_key1 = 'temp_key1' + temp_key2 = 'temp_key2' + class_key = 'output' + + ret = con.execute_command( + 'AI.DAGRUN', '|>', + 'AI.TENSORSET', image_key, 'FLOAT', 1, 224, 224, 3, 'BLOB', img.tobytes(), '|>', + 'AI.MODELRUN', model_name, + 'INPUTS', image_key, + 'OUTPUTS', class_key, '|>', + 'AI.TENSORGET', class_key, 'blob' + ) + env.assertEqual([b'OK', b'OK'], ret[:2]) + env.assertEqual(1001.0, len(ret[2])/4) + + +def test_sanitizer_modelrun_mobilenet_v1(env): + if (not TEST_TF or not TEST_PT): + return + con = env.getConnection() + mem_allocator = con.execute_command('info', 'memory')['mem_allocator'] + if 'jemalloc' in mem_allocator: + print("exiting sanitizer test given we're not using stdlib allocator") + return + + model_name = 'mobilenet_v1' + model_pb, input_var, output_var, labels, img = load_mobilenet_v1_test_data() + + ret = con.execute_command('AI.MODELSET', model_name, 'TF', DEVICE, + 'INPUTS', input_var, + 'OUTPUTS', output_var, + 'BLOB', model_pb) + env.assertEqual(ret, b'OK') + + for opnumber in range(1, MAX_ITERATIONS): + image_key = 'image' + temp_key1 = 'temp_key1' + temp_key2 = 'temp_key2' + class_key = 'output' + ret = con.execute_command( + 'AI.TENSORSET', image_key, 'FLOAT', 1, 224, 224, 3, 'BLOB', img.tobytes() + ) + env.assertEqual(b'OK', ret) + + ret = con.execute_command( + 'AI.MODELRUN', model_name, + 'INPUTS', image_key, + 'OUTPUTS', class_key + ) + + env.assertEqual(b'OK', ret) + + ret = con.execute_command( + 'AI.TENSORGET', class_key, 'blob' + ) + + env.assertEqual(1001.0, len(ret)/4) diff --git a/test/tests_tensorflow.py b/test/tests_tensorflow.py index f17efa427..7883a8ffa 100644 --- a/test/tests_tensorflow.py +++ b/test/tests_tensorflow.py @@ -24,10 +24,7 @@ def wrapper(env, *args, **kwargs): def test_run_mobilenet(env): con = env.getConnection() - input_var = 'input' - output_var = 'MobilenetV2/Predictions/Reshape_1' - - model_pb, labels, img = load_mobilenet_test_data() + model_pb, input_var, output_var, labels, img = load_mobilenet_v2_test_data() con.execute_command('AI.MODELSET', 'mobilenet', 'TF', DEVICE, 'INPUTS', input_var, 'OUTPUTS', output_var, 'BLOB', model_pb) @@ -94,10 +91,7 @@ def test_run_mobilenet_multiproc(env): con = env.getConnection() - input_var = 'input' - output_var = 'MobilenetV2/Predictions/Reshape_1' - - model_pb, labels, img = load_mobilenet_test_data() + model_pb, input_var, output_var, labels, img = load_mobilenet_v2_test_data() con.execute_command('AI.MODELSET', 'mobilenet', 'TF', DEVICE, 'INPUTS', input_var, 'OUTPUTS', output_var, 'BLOB', model_pb) ensureSlaveSynced(con, env) @@ -627,10 +621,7 @@ def test_tensorflow_modelrun_with_batch_and_minbatch(env): minbatch_size = 2 model_name = 'model' another_model_name = 'another_model' - inputvar = 'input' - outputvar = 'MobilenetV2/Predictions/Reshape_1' - - model_pb, labels, img = load_mobilenet_test_data() + model_pb, input_var, output_var, labels, img = load_mobilenet_v2_test_data() con.execute_command('AI.MODELSET', model_name, 'TF', DEVICE, 'BATCHSIZE', batch_size, 'MINBATCHSIZE', minbatch_size, diff --git a/test/unit/.gitignore b/test/unit/.gitignore new file mode 100644 index 000000000..c8c326b5e --- /dev/null +++ b/test/unit/.gitignore @@ -0,0 +1,8 @@ +# Unit test binaries +*.run + +# Object files +*.o +*.ko +*.obj +*.elf \ No newline at end of file diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt new file mode 100644 index 000000000..ce8a8c05c --- /dev/null +++ b/test/unit/CMakeLists.txt @@ -0,0 +1,20 @@ + + add_subdirectory("${PROJECT_SOURCE_DIR}/deps/googletest" "deps/googletest") + + macro(package_add_test TESTNAME) + # create an exectuable in which the tests will be stored + add_executable(${TESTNAME} ${ARGN}) + # link the Google test infrastructure, mocking library, and a default main fuction to + # the test executable. Remove g_test_main if writing your own main function. + target_link_libraries(${TESTNAME} gtest gmock gtest_main) + # gtest_discover_tests replaces gtest_add_tests, + # see https://cmake.org/cmake/help/v3.10/module/GoogleTest.html for more options to pass to it + gtest_discover_tests(${TESTNAME} + # set a working directory so your project root so that you can find test data via paths relative to the project root + WORKING_DIRECTORY ${PROJECT_DIR} + PROPERTIES VS_DEBUGGER_WORKING_DIRECTORY "${PROJECT_DIR}" + ) + set_target_properties(${TESTNAME} PROPERTIES FOLDER tests) +endmacro() + +package_add_test(test1 test1.cpp) \ No newline at end of file From 2e885a63079b28e1a32a01c04854bcfbd074b0dd Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Tue, 16 Jun 2020 23:42:51 +0200 Subject: [PATCH 2/4] Fix run info memleaks --- src/model.c | 19 +++++++++++-------- src/model.h | 3 ++- src/run_info.c | 33 ++++++++++++++++++++++++++------- src/script.c | 19 +++++++++++-------- src/script.h | 3 ++- 5 files changed, 52 insertions(+), 25 deletions(-) diff --git a/src/model.c b/src/model.c index 73810f863..0fa7b6794 100644 --- a/src/model.c +++ b/src/model.c @@ -424,17 +424,20 @@ RAI_Tensor* RAI_ModelRunCtxOutputTensor(RAI_ModelRunCtx* mctx, size_t index) { return mctx->outputs[index].tensor; } -void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx) { - for (size_t i=0; iinputs); ++i) { - RAI_TensorFree(mctx->inputs[i].tensor); - } - array_free(mctx->inputs); +void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx, int freeTensors) { + if (freeTensors) { + for (size_t i=0; iinputs); ++i) { + RAI_TensorFree(mctx->inputs[i].tensor); + } - for (size_t i = 0 ; i < array_len(mctx->outputs) ; ++i) { - if (mctx->outputs[i].tensor) { - RAI_TensorFree(mctx->outputs[i].tensor); + for (size_t i = 0 ; i < array_len(mctx->outputs) ; ++i) { + if (mctx->outputs[i].tensor) { + RAI_TensorFree(mctx->outputs[i].tensor); + } } } + + array_free(mctx->inputs); array_free(mctx->outputs); RAI_Error err = {0}; diff --git a/src/model.h b/src/model.h index cb9c98610..99599dead 100644 --- a/src/model.h +++ b/src/model.h @@ -79,8 +79,9 @@ RAI_ModelRunCtx* RAI_ModelRunCtxCreate(RAI_Model* model); * work * * @param mctx + * @param freeTensors free input and output tensors or leave them allocated */ -void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx); +void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx, int freeTensors); /** * Allocates a RAI_ModelCtxParam data structure, and enforces a shallow copy of diff --git a/src/run_info.c b/src/run_info.c index 42070306b..0825346a6 100644 --- a/src/run_info.c +++ b/src/run_info.c @@ -116,6 +116,13 @@ void RAI_FreeDagOp(RedisModuleCtx *ctx, RAI_DagOp *dagOp) { } array_free(dagOp->outTensors); + if (dagOp->mctx) { + RAI_ModelRunCtxFree(dagOp->mctx, false); + } + if (dagOp->sctx) { + RAI_ScriptRunCtxFree(dagOp->sctx, false); + } + RedisModule_Free(dagOp); } } @@ -125,21 +132,21 @@ void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) { return; } if (rinfo->mctx) { - RAI_ModelRunCtxFree(rinfo->mctx); + RAI_ModelRunCtxFree(rinfo->mctx, true); } if (rinfo->sctx) { - RAI_ScriptRunCtxFree(rinfo->sctx); + RAI_ScriptRunCtxFree(rinfo->sctx, true); } RAI_FreeError(rinfo->err); if (rinfo->dagTensorsContext) { AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext); - AI_dictEntry *stats_entry = AI_dictNext(iter); + AI_dictEntry *entry = AI_dictNext(iter); RAI_Tensor *tensor = NULL; - while (stats_entry) { - tensor = AI_dictGetVal(stats_entry); - char *key = (char *)AI_dictGetKey(stats_entry); + while (entry) { + tensor = AI_dictGetVal(entry); + char *key = (char *)AI_dictGetKey(entry); if (tensor&&key!=NULL) { // if the key is persistent then we should not delete it @@ -149,13 +156,25 @@ void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) { // it AI_dictEntry *loaded_entry = AI_dictFind(rinfo->dagTensorsLoadedContext, key); + if (persistent_entry == NULL && loaded_entry == NULL) { RAI_TensorFree(tensor); } + + if (persistent_entry) { + AI_dictDelete(rinfo->dagTensorsPersistentContext, key); + } + if (loaded_entry) { + AI_dictDelete(rinfo->dagTensorsLoadedContext, key); + } } - stats_entry = AI_dictNext(iter); + entry = AI_dictNext(iter); } AI_dictReleaseIterator(iter); + + RedisModule_Free(rinfo->dagTensorsContext); + RedisModule_Free(rinfo->dagTensorsLoadedContext); + RedisModule_Free(rinfo->dagTensorsPersistentContext); } if (rinfo->dagOps) { diff --git a/src/script.c b/src/script.c index 9419b712a..eeb82f913 100644 --- a/src/script.c +++ b/src/script.c @@ -202,17 +202,20 @@ RAI_Tensor* RAI_ScriptRunCtxOutputTensor(RAI_ScriptRunCtx* sctx, size_t index) { return sctx->outputs[index].tensor; } -void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx* sctx) { - for (size_t i = 0; i < array_len(sctx->inputs); ++i) { - RAI_TensorFree(sctx->inputs[i].tensor); - } - array_free(sctx->inputs); +void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx* sctx, int freeTensors) { + if (freeTensors) { + for (size_t i = 0; i < array_len(sctx->inputs); ++i) { + RAI_TensorFree(sctx->inputs[i].tensor); + } - for (size_t i = 0; i < array_len(sctx->outputs); ++i) { - if (sctx->outputs[i].tensor) { - RAI_TensorFree(sctx->outputs[i].tensor); + for (size_t i = 0; i < array_len(sctx->outputs); ++i) { + if (sctx->outputs[i].tensor) { + RAI_TensorFree(sctx->outputs[i].tensor); + } } } + + array_free(sctx->inputs); array_free(sctx->outputs); RedisModule_Free(sctx->fnname); diff --git a/src/script.h b/src/script.h index ce15aaa43..79208305d 100644 --- a/src/script.h +++ b/src/script.h @@ -119,8 +119,9 @@ RAI_Tensor* RAI_ScriptRunCtxOutputTensor(RAI_ScriptRunCtx* sctx, size_t index); * work * * @param sctx + * @param freeTensors free input and output tensors or leave them allocated */ -void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx* sctx); +void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx* sctx, int freeTensors); /** * Given the input script context, run associated script From 9e060bbcce55eb8f9737205d9fcf0044fd01f378 Mon Sep 17 00:00:00 2001 From: filipecosta90 Date: Wed, 17 Jun 2020 02:46:45 +0100 Subject: [PATCH 3/4] [fix] fixed reference count on ai.dagrun and ai.dagrunro for tensor structure. Added AI_dictType AI_dictTypeTensorVals with proper valDestructor --- src/dag.c | 11 +++++++++- src/run_info.c | 43 +++++++++++++++++++++++++++++++++++----- test/tests_sanitizer.py | 10 +++++----- test/tests_tensorflow.py | 8 ++++---- 4 files changed, 57 insertions(+), 15 deletions(-) diff --git a/src/dag.c b/src/dag.c index 6c0d39487..50b92b950 100644 --- a/src/dag.c +++ b/src/dag.c @@ -91,6 +91,16 @@ void *RedisAI_DagRunSession(RedisAI_RunInfo *rinfo) { currentOp->result = REDISMODULE_ERR; } } + // since we've increased the reference count prior modelrun we need to decrease it + const size_t ninputs = RAI_ModelRunCtxNumInputs(currentOp->mctx); + for (size_t inputNumber = 0; inputNumber < ninputs; inputNumber++) { + RAI_Tensor *tensor = + RAI_ModelRunCtxInputTensor(currentOp->mctx, inputNumber); + if (tensor) { + RAI_TensorFree(tensor); + } + } + } else { currentOp->result = REDISMODULE_ERR; } @@ -243,7 +253,6 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, } RedisModule_CloseKey(key); RedisAI_ReplicateTensorSet(ctx, tensor_keyname, tensor); - // TODO: free Tensor } else { RedisModule_ReplyWithError( ctx, "ERR specified persistent key that was not used on DAG"); diff --git a/src/run_info.c b/src/run_info.c index 0825346a6..af4f5bdc6 100644 --- a/src/run_info.c +++ b/src/run_info.c @@ -16,6 +16,40 @@ #include "util/arr_rm_alloc.h" #include "util/dict.h" + +static uint64_t RAI_TensorDictKeyHashFunction(const void *key){ + return AI_dictGenHashFunction(key, strlen((char*)key)); +} + +static int RAI_TensorDictKeyStrcmp(void *privdata, const void *key1, const void *key2){ + const char* strKey1 = key1; + const char* strKey2 = key2; + return strcmp(strKey1, strKey2) == 0; +} + +static void RAI_TensorDictKeyFree(void *privdata, void *key){ + RedisModule_Free(key); +} + +static void* RAI_TensorDictKeyDup(void *privdata, const void *key){ + return RedisModule_Strdup((char*)key); +} + +static void RAI_TensorDictValFree(void *privdata, const void *obj){ + return RAI_TensorFree((RAI_Tensor*)obj); +} + + +AI_dictType AI_dictTypeTensorVals = { + .hashFunction = RAI_TensorDictKeyHashFunction, + .keyDup = RAI_TensorDictKeyDup, + .valDup = NULL, + .keyCompare = RAI_TensorDictKeyStrcmp, + .keyDestructor = RAI_TensorDictKeyFree, + .valDestructor = RAI_TensorDictValFree, +}; + + /** * Allocate the memory and initialise the RAI_DagOp. * @param result Output parameter to capture allocated RAI_DagOp. @@ -76,7 +110,7 @@ int RAI_InitRunInfo(RedisAI_RunInfo **result) { return REDISMODULE_ERR; } rinfo->use_local_context = 0; - rinfo->dagTensorsContext = AI_dictCreate(&AI_dictTypeHeapStrings, NULL); + rinfo->dagTensorsContext = AI_dictCreate(&AI_dictTypeTensorVals, NULL); if (!(rinfo->dagTensorsContext)) { return REDISMODULE_ERR; } @@ -148,17 +182,16 @@ void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) { tensor = AI_dictGetVal(entry); char *key = (char *)AI_dictGetKey(entry); - if (tensor&&key!=NULL) { + if (tensor && key != NULL) { // if the key is persistent then we should not delete it AI_dictEntry *persistent_entry = AI_dictFind(rinfo->dagTensorsPersistentContext, key); - // if the key was loaded from the keyspace then we should not delete - // it + // if the key was loaded from the keyspace then we should not delete it AI_dictEntry *loaded_entry = AI_dictFind(rinfo->dagTensorsLoadedContext, key); if (persistent_entry == NULL && loaded_entry == NULL) { - RAI_TensorFree(tensor); + AI_dictDelete(rinfo->dagTensorsContext, key); } if (persistent_entry) { diff --git a/test/tests_sanitizer.py b/test/tests_sanitizer.py index c4ed29044..4f0c45ca4 100644 --- a/test/tests_sanitizer.py +++ b/test/tests_sanitizer.py @@ -27,17 +27,17 @@ def test_sanitizer_dagrun_mobilenet_v1(env): env.assertEqual(ret, b'OK') for opnumber in range(1, MAX_ITERATIONS): - image_key = 'image' - temp_key1 = 'temp_key1' - temp_key2 = 'temp_key2' + image_key = 'image{}'.format(opnumber) class_key = 'output' ret = con.execute_command( 'AI.DAGRUN', '|>', - 'AI.TENSORSET', image_key, 'FLOAT', 1, 224, 224, 3, 'BLOB', img.tobytes(), '|>', + 'AI.TENSORSET', image_key, 'FLOAT', 1, 224, 224, 3, 'BLOB', img.tobytes(), + '|>', 'AI.MODELRUN', model_name, 'INPUTS', image_key, - 'OUTPUTS', class_key, '|>', + 'OUTPUTS', class_key, + '|>', 'AI.TENSORGET', class_key, 'blob' ) env.assertEqual([b'OK', b'OK'], ret[:2]) diff --git a/test/tests_tensorflow.py b/test/tests_tensorflow.py index 7883a8ffa..97ea50014 100644 --- a/test/tests_tensorflow.py +++ b/test/tests_tensorflow.py @@ -625,8 +625,8 @@ def test_tensorflow_modelrun_with_batch_and_minbatch(env): con.execute_command('AI.MODELSET', model_name, 'TF', DEVICE, 'BATCHSIZE', batch_size, 'MINBATCHSIZE', minbatch_size, - 'INPUTS', inputvar, - 'OUTPUTS', outputvar, + 'INPUTS', input_var, + 'OUTPUTS', output_var, 'BLOB', model_pb) con.execute_command('AI.TENSORSET', 'input', 'FLOAT', 1, img.shape[1], img.shape[0], img.shape[2], @@ -649,8 +649,8 @@ def run(name=model_name, output_name='output'): con.execute_command('AI.MODELSET', another_model_name, 'TF', DEVICE, 'BATCHSIZE', batch_size, 'MINBATCHSIZE', minbatch_size, - 'INPUTS', inputvar, - 'OUTPUTS', outputvar, + 'INPUTS', input_var, + 'OUTPUTS', output_var, 'BLOB', model_pb) p1b = mp.Process(target=run, args=(another_model_name, 'final1')) From 9e6f4162087700c6bb21462238953d35d87185c5 Mon Sep 17 00:00:00 2001 From: filipecosta90 Date: Wed, 17 Jun 2020 02:53:14 +0100 Subject: [PATCH 4/4] [fix] removed WIP unit test folder (not for this issue) --- test/unit/.gitignore | 8 -------- test/unit/CMakeLists.txt | 20 -------------------- 2 files changed, 28 deletions(-) delete mode 100644 test/unit/.gitignore delete mode 100644 test/unit/CMakeLists.txt diff --git a/test/unit/.gitignore b/test/unit/.gitignore deleted file mode 100644 index c8c326b5e..000000000 --- a/test/unit/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -# Unit test binaries -*.run - -# Object files -*.o -*.ko -*.obj -*.elf \ No newline at end of file diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt deleted file mode 100644 index ce8a8c05c..000000000 --- a/test/unit/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ - - add_subdirectory("${PROJECT_SOURCE_DIR}/deps/googletest" "deps/googletest") - - macro(package_add_test TESTNAME) - # create an exectuable in which the tests will be stored - add_executable(${TESTNAME} ${ARGN}) - # link the Google test infrastructure, mocking library, and a default main fuction to - # the test executable. Remove g_test_main if writing your own main function. - target_link_libraries(${TESTNAME} gtest gmock gtest_main) - # gtest_discover_tests replaces gtest_add_tests, - # see https://cmake.org/cmake/help/v3.10/module/GoogleTest.html for more options to pass to it - gtest_discover_tests(${TESTNAME} - # set a working directory so your project root so that you can find test data via paths relative to the project root - WORKING_DIRECTORY ${PROJECT_DIR} - PROPERTIES VS_DEBUGGER_WORKING_DIRECTORY "${PROJECT_DIR}" - ) - set_target_properties(${TESTNAME} PROPERTIES FOLDER tests) -endmacro() - -package_add_test(test1 test1.cpp) \ No newline at end of file