16
16
#include "util/arr_rm_alloc.h"
17
17
#include "util/dict.h"
18
18
19
+
20
+ static uint64_t RAI_TensorDictKeyHashFunction (const void * key ){
21
+ return AI_dictGenHashFunction (key , strlen ((char * )key ));
22
+ }
23
+
24
+ static int RAI_TensorDictKeyStrcmp (void * privdata , const void * key1 , const void * key2 ){
25
+ const char * strKey1 = key1 ;
26
+ const char * strKey2 = key2 ;
27
+ return strcmp (strKey1 , strKey2 ) == 0 ;
28
+ }
29
+
30
+ static void RAI_TensorDictKeyFree (void * privdata , void * key ){
31
+ RedisModule_Free (key );
32
+ }
33
+
34
+ static void * RAI_TensorDictKeyDup (void * privdata , const void * key ){
35
+ return RedisModule_Strdup ((char * )key );
36
+ }
37
+
38
+ static void RAI_TensorDictValFree (void * privdata , const void * obj ){
39
+ return RAI_TensorFree ((RAI_Tensor * )obj );
40
+ }
41
+
42
+
43
+ AI_dictType AI_dictTypeTensorVals = {
44
+ .hashFunction = RAI_TensorDictKeyHashFunction ,
45
+ .keyDup = RAI_TensorDictKeyDup ,
46
+ .valDup = NULL ,
47
+ .keyCompare = RAI_TensorDictKeyStrcmp ,
48
+ .keyDestructor = RAI_TensorDictKeyFree ,
49
+ .valDestructor = RAI_TensorDictValFree ,
50
+ };
51
+
52
+
19
53
/**
20
54
* Allocate the memory and initialise the RAI_DagOp.
21
55
* @param result Output parameter to capture allocated RAI_DagOp.
@@ -76,7 +110,7 @@ int RAI_InitRunInfo(RedisAI_RunInfo **result) {
76
110
return REDISMODULE_ERR ;
77
111
}
78
112
rinfo -> use_local_context = 0 ;
79
- rinfo -> dagTensorsContext = AI_dictCreate (& AI_dictTypeHeapStrings , NULL );
113
+ rinfo -> dagTensorsContext = AI_dictCreate (& AI_dictTypeTensorVals , NULL );
80
114
if (!(rinfo -> dagTensorsContext )) {
81
115
return REDISMODULE_ERR ;
82
116
}
@@ -116,6 +150,13 @@ void RAI_FreeDagOp(RedisModuleCtx *ctx, RAI_DagOp *dagOp) {
116
150
}
117
151
array_free (dagOp -> outTensors );
118
152
153
+ if (dagOp -> mctx ) {
154
+ RAI_ModelRunCtxFree (dagOp -> mctx , false);
155
+ }
156
+ if (dagOp -> sctx ) {
157
+ RAI_ScriptRunCtxFree (dagOp -> sctx , false);
158
+ }
159
+
119
160
RedisModule_Free (dagOp );
120
161
}
121
162
}
@@ -125,37 +166,48 @@ void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) {
125
166
return ;
126
167
}
127
168
if (rinfo -> mctx ) {
128
- RAI_ModelRunCtxFree (rinfo -> mctx );
169
+ RAI_ModelRunCtxFree (rinfo -> mctx , true );
129
170
}
130
171
if (rinfo -> sctx ) {
131
- RAI_ScriptRunCtxFree (rinfo -> sctx );
172
+ RAI_ScriptRunCtxFree (rinfo -> sctx , true );
132
173
}
133
174
RAI_FreeError (rinfo -> err );
134
175
135
176
if (rinfo -> dagTensorsContext ) {
136
177
AI_dictIterator * iter = AI_dictGetSafeIterator (rinfo -> dagTensorsContext );
137
- AI_dictEntry * stats_entry = AI_dictNext (iter );
178
+ AI_dictEntry * entry = AI_dictNext (iter );
138
179
RAI_Tensor * tensor = NULL ;
139
180
140
- while (stats_entry ) {
141
- tensor = AI_dictGetVal (stats_entry );
142
- char * key = (char * )AI_dictGetKey (stats_entry );
181
+ while (entry ) {
182
+ tensor = AI_dictGetVal (entry );
183
+ char * key = (char * )AI_dictGetKey (entry );
143
184
144
- if (tensor && key != NULL ) {
185
+ if (tensor && key != NULL ) {
145
186
// if the key is persistent then we should not delete it
146
187
AI_dictEntry * persistent_entry =
147
188
AI_dictFind (rinfo -> dagTensorsPersistentContext , key );
148
- // if the key was loaded from the keyspace then we should not delete
149
- // it
189
+ // if the key was loaded from the keyspace then we should not delete it
150
190
AI_dictEntry * loaded_entry =
151
191
AI_dictFind (rinfo -> dagTensorsLoadedContext , key );
192
+
152
193
if (persistent_entry == NULL && loaded_entry == NULL ) {
153
- RAI_TensorFree (tensor );
194
+ AI_dictDelete (rinfo -> dagTensorsContext , key );
195
+ }
196
+
197
+ if (persistent_entry ) {
198
+ AI_dictDelete (rinfo -> dagTensorsPersistentContext , key );
199
+ }
200
+ if (loaded_entry ) {
201
+ AI_dictDelete (rinfo -> dagTensorsLoadedContext , key );
154
202
}
155
203
}
156
- stats_entry = AI_dictNext (iter );
204
+ entry = AI_dictNext (iter );
157
205
}
158
206
AI_dictReleaseIterator (iter );
207
+
208
+ RedisModule_Free (rinfo -> dagTensorsContext );
209
+ RedisModule_Free (rinfo -> dagTensorsLoadedContext );
210
+ RedisModule_Free (rinfo -> dagTensorsPersistentContext );
159
211
}
160
212
161
213
if (rinfo -> dagOps ) {
0 commit comments