3
3
#include "background_workers.h"
4
4
#include "util/string_utils.h"
5
5
6
- void _DAG_SetTensorsInLocalContext (RedisAI_RunInfo * rinfo ) {
7
- for (size_t i = 0 ; i < rinfo -> dagOpCount ; i ++ ) {
8
- RAI_DagOp * op = rinfo -> dagOps [i ];
9
- if (op -> commandType == REDISAI_DAG_CMD_TENSORSET ) {
10
- // Insert the tensor with its mangled (unique) name.
11
- void * t = (void * )RAI_TensorGetShallowCopy (op -> outTensor );
12
- AI_dictReplace (rinfo -> dagTensorsContext , (void * )op -> outkeys [0 ], t );
13
- }
14
- }
15
- }
16
-
17
- int MangleTensorsNames (RedisAI_RunInfo * rinfo ) {
18
-
19
- int res = REDISMODULE_ERR ;
20
- AI_dict * mangled_tensors = AI_dictCreate (& AI_dictTypeHeapRStrings , NULL );
6
+ int ValidatePersistKeys (RedisAI_RunInfo * rinfo , AI_dict * tensorsNamesToInd ,
7
+ AI_dict * persistTensorsNames ) {
21
8
22
9
{
23
- AI_dictIterator * iter = AI_dictGetSafeIterator (rinfo -> dagTensorsContext );
24
- AI_dictEntry * entry = AI_dictNext ( iter ) ;
25
- while (entry ) {
26
- RedisModuleString * key = (RedisModuleString * )AI_dictGetKey (entry );
27
- size_t key_len ;
28
- const char * key_str = RedisModule_StringPtrLen ( key , & key_len );
29
- RedisModuleString * demangled_key = RedisModule_CreateString ( NULL , key_str , key_len - 4 );
30
- int * instance = RedisModule_Alloc ( sizeof ( int ) );
31
- * instance = 1 ;
32
- AI_dictAdd ( mangled_tensors , ( void * ) demangled_key , ( void * ) instance );
33
- RedisModule_FreeString ( NULL , demangled_key );
34
- entry = AI_dictNext ( iter );
10
+ AI_dictIterator * iter = AI_dictGetSafeIterator (persistTensorsNames );
11
+ AI_dictEntry * persist_entry ;
12
+ while (( persist_entry = AI_dictNext ( iter )) ) {
13
+ RedisModuleString * persist_key = (RedisModuleString * )AI_dictGetKey (persist_entry );
14
+ AI_dictEntry * entry = AI_dictFind ( tensorsNamesToInd , persist_key ) ;
15
+ if (! entry ) {
16
+ RAI_SetError ( rinfo -> err , RAI_EDAGRUN , "ERR PERSIST key cannot be found in DAG" );
17
+ AI_dictReleaseIterator ( iter );
18
+ return REDISMODULE_ERR ;
19
+ }
20
+ size_t index = ( size_t ) AI_dictGetVal ( entry );
21
+ AI_dictReplace ( persistTensorsNames , ( void * ) persist_key , ( void * ) index );
35
22
}
36
23
AI_dictReleaseIterator (iter );
37
24
}
25
+ return REDISMODULE_OK ;
26
+ }
27
+
28
+ int MapTensorsKeysToIndices (RedisAI_RunInfo * rinfo , AI_dict * tensorsNamesToInd ) {
38
29
39
30
for (long long i = 0 ; i < array_len (rinfo -> dagOps ); i ++ ) {
40
31
RAI_DagOp * currentOp = rinfo -> dagOps [i ];
41
32
42
- RedisModuleString * * mangled_inkeys =
43
- array_new (RedisModuleString * , array_len (currentOp -> inkeys ));
44
33
for (long long j = 0 ; j < array_len (currentOp -> inkeys ); j ++ ) {
45
34
RedisModuleString * key = currentOp -> inkeys [j ];
46
- AI_dictEntry * entry = AI_dictFind (mangled_tensors , key );
35
+ AI_dictEntry * entry = AI_dictFind (tensorsNamesToInd , key );
47
36
if (!entry ) {
48
- array_free (mangled_inkeys );
49
37
RAI_SetError (rinfo -> err , RAI_EDAGRUN , "ERR INPUT key cannot be found in DAG" );
50
- goto cleanup ;
38
+ return REDISMODULE_ERR ;
51
39
}
52
- int * instance = AI_dictGetVal (entry );
53
- char buf [16 ];
54
- sprintf (buf , "%04d" , * instance );
55
- RedisModuleString * mangled_key = RedisModule_CreateStringFromString (NULL , key );
56
- RedisModule_StringAppendBuffer (NULL , mangled_key , buf , strlen (buf ));
57
- mangled_inkeys = array_append (mangled_inkeys , mangled_key );
40
+ size_t ind = (size_t )AI_dictGetVal (entry );
41
+ currentOp -> inkeys_indices = array_append (currentOp -> inkeys_indices , ind );
58
42
}
59
43
60
- RedisModuleString * * mangled_outkeys =
61
- array_new (RedisModuleString * , array_len (currentOp -> outkeys ));
62
44
for (long long j = 0 ; j < array_len (currentOp -> outkeys ); j ++ ) {
63
45
RedisModuleString * key = currentOp -> outkeys [j ];
64
- AI_dictEntry * entry = AI_dictFind (mangled_tensors , key );
65
- int * instance = NULL ;
66
- if (entry ) {
67
- instance = AI_dictGetVal (entry );
68
- * instance += 1 ;
69
- } else {
70
- instance = RedisModule_Alloc (sizeof (int ));
71
- * instance = 1 ;
72
- AI_dictAdd (mangled_tensors , (void * )key , (void * )instance );
73
- }
74
- char buf [16 ];
75
- sprintf (buf , "%04d" , * instance );
76
- RedisModuleString * mangled_key = RedisModule_CreateStringFromString (NULL , key );
77
- RedisModule_StringAppendBuffer (NULL , mangled_key , buf , strlen (buf ));
78
- mangled_outkeys = array_append (mangled_outkeys , mangled_key );
79
- }
80
-
81
- if (currentOp -> inkeys ) {
82
- for (size_t j = 0 ; j < array_len (currentOp -> inkeys ); j ++ ) {
83
- RedisModule_FreeString (NULL , currentOp -> inkeys [j ]);
84
- }
85
- array_free (currentOp -> inkeys );
86
- }
87
-
88
- if (currentOp -> outkeys ) {
89
- for (size_t j = 0 ; j < array_len (currentOp -> outkeys ); j ++ ) {
90
- RedisModule_FreeString (NULL , currentOp -> outkeys [j ]);
91
- }
92
- array_free (currentOp -> outkeys );
93
- }
94
-
95
- currentOp -> inkeys = mangled_inkeys ;
96
- currentOp -> outkeys = mangled_outkeys ;
97
- }
46
+ size_t ind = array_len (rinfo -> dagSharedTensors );
98
47
99
- AI_dict * mangled_persisted = AI_dictCreate (& AI_dictTypeHeapRStrings , NULL );
100
- {
101
- AI_dictIterator * iter = AI_dictGetSafeIterator (rinfo -> dagTensorsPersistedContext );
102
- AI_dictEntry * entry = AI_dictNext (iter );
103
- while (entry ) {
104
- RedisModuleString * key = (RedisModuleString * )AI_dictGetKey (entry );
105
- AI_dictEntry * mangled_entry = AI_dictFind (mangled_tensors , key );
106
- if (!mangled_entry ) {
107
- AI_dictRelease (mangled_persisted );
108
- AI_dictReleaseIterator (iter );
109
- RAI_SetError (rinfo -> err , RAI_EDAGRUN , "ERR PERSIST key cannot be found in DAG" );
110
- goto cleanup ;
111
- }
112
- if (AI_dictFind (mangled_persisted , key ) != NULL ) {
113
- AI_dictRelease (mangled_persisted );
114
- AI_dictReleaseIterator (iter );
115
- RAI_SetError (rinfo -> err , RAI_EDAGRUN , "ERR PERSIST keys must be unique" );
116
- goto cleanup ;
48
+ // Add a new empty place holder in the array for an output tensor.
49
+ // If this is a TENSORSET op, the tensor is already realized.
50
+ if (currentOp -> commandType == REDISAI_DAG_CMD_TENSORSET ) {
51
+ RAI_Tensor * t = RAI_TensorGetShallowCopy (currentOp -> outTensor );
52
+ rinfo -> dagSharedTensors = array_append (rinfo -> dagSharedTensors , t );
53
+ } else {
54
+ rinfo -> dagSharedTensors = array_append (rinfo -> dagSharedTensors , NULL );
117
55
}
118
- int * instance = AI_dictGetVal (mangled_entry );
119
- char buf [16 ];
120
- sprintf (buf , "%04d" , * instance );
121
- RedisModuleString * mangled_key = RedisModule_CreateStringFromString (NULL , key );
122
- RedisModule_StringAppendBuffer (NULL , mangled_key , buf , strlen (buf ));
123
- AI_dictAdd (mangled_persisted , (void * )mangled_key , (void * )1 );
124
- RedisModule_FreeString (NULL , mangled_key );
125
- entry = AI_dictNext (iter );
56
+ currentOp -> outkeys_indices = array_append (currentOp -> outkeys_indices , ind );
57
+ AI_dictReplace (tensorsNamesToInd , (void * )key , (void * )ind );
126
58
}
127
- AI_dictReleaseIterator (iter );
128
59
}
129
-
130
- AI_dictRelease (rinfo -> dagTensorsPersistedContext );
131
- rinfo -> dagTensorsPersistedContext = mangled_persisted ;
132
-
133
- for (long long i = 0 ; i < array_len (rinfo -> dagOps ); i ++ ) {
134
- if (rinfo -> dagOps [i ]-> devicestr == NULL ) {
135
- rinfo -> dagOps [i ]-> devicestr = "CPU" ;
136
- }
137
- }
138
- // Tensors from TENSORSET ops are ready to be put in DAG local context under their mangled
139
- // names.
140
- _DAG_SetTensorsInLocalContext (rinfo );
141
- res = REDISMODULE_OK ;
142
-
143
- cleanup : {
144
- AI_dictIterator * iter = AI_dictGetSafeIterator (mangled_tensors );
145
- AI_dictEntry * entry = AI_dictNext (iter );
146
- while (entry ) {
147
- int * val = (int * )AI_dictGetVal (entry );
148
- RedisModule_Free (val );
149
- entry = AI_dictNext (iter );
150
- }
151
- AI_dictReleaseIterator (iter );
152
- }
153
- AI_dictRelease (mangled_tensors );
154
- return res ;
60
+ return REDISMODULE_OK ;
155
61
}
156
62
157
63
// Add Shallow copies of the DAG run info to the devices' queues.
@@ -242,7 +148,7 @@ int RAI_DAGRun(RAI_DAGRunCtx *run_info, RAI_OnFinishCB DAGAsyncFinish, void *pri
242
148
}
243
149
// Make the inkeys and outkeys of the DAG ops unique, to ensure that the operations
244
150
// will be execute in the right order.
245
- if (MangleTensorsNames (rinfo ) != REDISMODULE_OK ) {
151
+ if (MapTensorsKeysToIndices (rinfo , rinfo -> tensorsNamesToIndices ) != REDISMODULE_OK ) {
246
152
RAI_SetError (err , rinfo -> err -> code , rinfo -> err -> detail );
247
153
return REDISMODULE_ERR ;
248
154
}
@@ -269,16 +175,13 @@ size_t RAI_DAGNumOutputs(RAI_OnFinishCtx *finish_ctx) {
269
175
const RAI_Tensor * RAI_DAGOutputTensor (RAI_OnFinishCtx * finish_ctx , size_t index ) {
270
176
size_t tensor_get_op_ind = -1 ;
271
177
RedisAI_RunInfo * rinfo = (RedisAI_RunInfo * )finish_ctx ;
178
+
272
179
for (size_t i = 0 ; i < rinfo -> dagOpCount ; i ++ ) {
273
180
RAI_DagOp * op = rinfo -> dagOps [i ];
274
181
if (op -> commandType == REDISAI_DAG_CMD_TENSORGET ) {
275
182
tensor_get_op_ind ++ ;
276
183
if (tensor_get_op_ind == index ) {
277
- RAI_Tensor * t ;
278
- int res = RAI_getTensorFromLocalContext (rinfo -> dagTensorsContext , op -> inkeys [0 ], & t ,
279
- op -> err );
280
- RedisModule_Assert (res == REDISMODULE_OK );
281
- return t ;
184
+ return Dag_GetTensorFromGlobalCtx (rinfo , op -> inkeys_indices [0 ]);
282
185
}
283
186
}
284
187
}
0 commit comments