Skip to content

Commit 9e060bb

Browse files
committed
[fix] fixed reference count on ai.dagrun and ai.dagrunro for tensor structure. Added AI_dictType AI_dictTypeTensorVals with proper valDestructor
1 parent c6c3dd0 commit 9e060bb

File tree

4 files changed

+57
-15
lines changed

4 files changed

+57
-15
lines changed

src/dag.c

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,16 @@ void *RedisAI_DagRunSession(RedisAI_RunInfo *rinfo) {
9191
currentOp->result = REDISMODULE_ERR;
9292
}
9393
}
94+
// since we've increased the reference count prior modelrun we need to decrease it
95+
const size_t ninputs = RAI_ModelRunCtxNumInputs(currentOp->mctx);
96+
for (size_t inputNumber = 0; inputNumber < ninputs; inputNumber++) {
97+
RAI_Tensor *tensor =
98+
RAI_ModelRunCtxInputTensor(currentOp->mctx, inputNumber);
99+
if (tensor) {
100+
RAI_TensorFree(tensor);
101+
}
102+
}
103+
94104
} else {
95105
currentOp->result = REDISMODULE_ERR;
96106
}
@@ -243,7 +253,6 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv,
243253
}
244254
RedisModule_CloseKey(key);
245255
RedisAI_ReplicateTensorSet(ctx, tensor_keyname, tensor);
246-
// TODO: free Tensor
247256
} else {
248257
RedisModule_ReplyWithError(
249258
ctx, "ERR specified persistent key that was not used on DAG");

src/run_info.c

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,40 @@
1616
#include "util/arr_rm_alloc.h"
1717
#include "util/dict.h"
1818

19+
20+
static uint64_t RAI_TensorDictKeyHashFunction(const void *key){
21+
return AI_dictGenHashFunction(key, strlen((char*)key));
22+
}
23+
24+
static int RAI_TensorDictKeyStrcmp(void *privdata, const void *key1, const void *key2){
25+
const char* strKey1 = key1;
26+
const char* strKey2 = key2;
27+
return strcmp(strKey1, strKey2) == 0;
28+
}
29+
30+
static void RAI_TensorDictKeyFree(void *privdata, void *key){
31+
RedisModule_Free(key);
32+
}
33+
34+
static void* RAI_TensorDictKeyDup(void *privdata, const void *key){
35+
return RedisModule_Strdup((char*)key);
36+
}
37+
38+
static void RAI_TensorDictValFree(void *privdata, const void *obj){
39+
return RAI_TensorFree((RAI_Tensor*)obj);
40+
}
41+
42+
43+
AI_dictType AI_dictTypeTensorVals = {
44+
.hashFunction = RAI_TensorDictKeyHashFunction,
45+
.keyDup = RAI_TensorDictKeyDup,
46+
.valDup = NULL,
47+
.keyCompare = RAI_TensorDictKeyStrcmp,
48+
.keyDestructor = RAI_TensorDictKeyFree,
49+
.valDestructor = RAI_TensorDictValFree,
50+
};
51+
52+
1953
/**
2054
* Allocate the memory and initialise the RAI_DagOp.
2155
* @param result Output parameter to capture allocated RAI_DagOp.
@@ -76,7 +110,7 @@ int RAI_InitRunInfo(RedisAI_RunInfo **result) {
76110
return REDISMODULE_ERR;
77111
}
78112
rinfo->use_local_context = 0;
79-
rinfo->dagTensorsContext = AI_dictCreate(&AI_dictTypeHeapStrings, NULL);
113+
rinfo->dagTensorsContext = AI_dictCreate(&AI_dictTypeTensorVals, NULL);
80114
if (!(rinfo->dagTensorsContext)) {
81115
return REDISMODULE_ERR;
82116
}
@@ -148,17 +182,16 @@ void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) {
148182
tensor = AI_dictGetVal(entry);
149183
char *key = (char *)AI_dictGetKey(entry);
150184

151-
if (tensor&&key!=NULL) {
185+
if (tensor && key != NULL) {
152186
// if the key is persistent then we should not delete it
153187
AI_dictEntry *persistent_entry =
154188
AI_dictFind(rinfo->dagTensorsPersistentContext, key);
155-
// if the key was loaded from the keyspace then we should not delete
156-
// it
189+
// if the key was loaded from the keyspace then we should not delete it
157190
AI_dictEntry *loaded_entry =
158191
AI_dictFind(rinfo->dagTensorsLoadedContext, key);
159192

160193
if (persistent_entry == NULL && loaded_entry == NULL) {
161-
RAI_TensorFree(tensor);
194+
AI_dictDelete(rinfo->dagTensorsContext, key);
162195
}
163196

164197
if (persistent_entry) {

test/tests_sanitizer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,17 @@ def test_sanitizer_dagrun_mobilenet_v1(env):
2727
env.assertEqual(ret, b'OK')
2828

2929
for opnumber in range(1, MAX_ITERATIONS):
30-
image_key = 'image'
31-
temp_key1 = 'temp_key1'
32-
temp_key2 = 'temp_key2'
30+
image_key = 'image{}'.format(opnumber)
3331
class_key = 'output'
3432

3533
ret = con.execute_command(
3634
'AI.DAGRUN', '|>',
37-
'AI.TENSORSET', image_key, 'FLOAT', 1, 224, 224, 3, 'BLOB', img.tobytes(), '|>',
35+
'AI.TENSORSET', image_key, 'FLOAT', 1, 224, 224, 3, 'BLOB', img.tobytes(),
36+
'|>',
3837
'AI.MODELRUN', model_name,
3938
'INPUTS', image_key,
40-
'OUTPUTS', class_key, '|>',
39+
'OUTPUTS', class_key,
40+
'|>',
4141
'AI.TENSORGET', class_key, 'blob'
4242
)
4343
env.assertEqual([b'OK', b'OK'], ret[:2])

test/tests_tensorflow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -625,8 +625,8 @@ def test_tensorflow_modelrun_with_batch_and_minbatch(env):
625625

626626
con.execute_command('AI.MODELSET', model_name, 'TF', DEVICE,
627627
'BATCHSIZE', batch_size, 'MINBATCHSIZE', minbatch_size,
628-
'INPUTS', inputvar,
629-
'OUTPUTS', outputvar,
628+
'INPUTS', input_var,
629+
'OUTPUTS', output_var,
630630
'BLOB', model_pb)
631631
con.execute_command('AI.TENSORSET', 'input',
632632
'FLOAT', 1, img.shape[1], img.shape[0], img.shape[2],
@@ -649,8 +649,8 @@ def run(name=model_name, output_name='output'):
649649

650650
con.execute_command('AI.MODELSET', another_model_name, 'TF', DEVICE,
651651
'BATCHSIZE', batch_size, 'MINBATCHSIZE', minbatch_size,
652-
'INPUTS', inputvar,
653-
'OUTPUTS', outputvar,
652+
'INPUTS', input_var,
653+
'OUTPUTS', output_var,
654654
'BLOB', model_pb)
655655

656656
p1b = mp.Process(target=run, args=(another_model_name, 'final1'))

0 commit comments

Comments
 (0)