Skip to content

Commit aeef254

Browse files
alonre24DvirDukhan
authored andcommitted
Merge pull request #582 from RedisAI/Turn_DAG_local_context_into_array
Turn dag local context dict into array
1 parent 27ca674 commit aeef254

File tree

14 files changed

+271
-370
lines changed

14 files changed

+271
-370
lines changed

src/DAG/dag.c

Lines changed: 84 additions & 145 deletions
Large diffs are not rendered by default.

src/DAG/dag.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,24 @@ void RedisAI_DagOpBatchInfo(RedisAI_RunInfo *rinfo, RAI_DagOp *op, size_t *batch
8888
void RedisAI_DagOpBatchingMatch(RedisAI_RunInfo *rinfo1, RAI_DagOp *op1, RedisAI_RunInfo *rinfo2,
8989
RAI_DagOp *op2, int *batched, size_t *inbatchsize);
9090

91+
/**
92+
* @brief Get a tensor from the dag local context in a given index
93+
* (this access to a shared array, require read lock)
94+
* @param rinfo The DAG runInfo.
95+
* @param index The index of the tensor in the Dag shared array to return
96+
* @return The tensor of the given index (NULL is returned if this tensor hasn't been realized yet)
97+
*/
98+
RAI_Tensor *Dag_GetTensorFromGlobalCtx(RedisAI_RunInfo *rinfo, size_t index);
99+
100+
/**
101+
* @brief Shallow copy and set a tensor in the dag local context in a given index.
102+
* (this access to a shared array, require write lock)
103+
* @param rinfo The DAG runInfo.
104+
* @param index The index to put in the given tensor in the Dag shared array.
105+
* @param t The tensor to shallow copy and store in the given index.
106+
*/
107+
void Dag_SetTensorInGlobalCtx(RedisAI_RunInfo *rinfo, size_t index, RAI_Tensor *t);
108+
91109
/**
92110
* Run the first unrealized DAG operation in rinfo for the given device.
93111
* @param rinfo context in which RedisAI blocking commands operate.

src/DAG/dag_builder.c

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,18 @@ int RAI_DAGLoadTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor *t
2929

3030
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
3131
RedisModuleString *key_name = RedisModule_CreateString(NULL, t_name, strlen(t_name));
32-
// Add the tensor under its "mangled" key name to the DAG local context dict.
33-
char buf[16];
34-
sprintf(buf, "%04d", 1);
35-
RedisModule_StringAppendBuffer(NULL, key_name, buf, strlen(buf));
36-
AI_dictAdd(rinfo->dagTensorsContext, (void *)key_name,
37-
(void *)RAI_TensorGetShallowCopy(tensor));
32+
33+
// Cannot load more than one tensor under the same name
34+
if (AI_dictFind(rinfo->tensorsNamesToIndices, key_name) != NULL) {
35+
RedisModule_FreeString(NULL, key_name);
36+
return REDISMODULE_ERR;
37+
}
38+
39+
// Add the tensor to the DAG shared tensors and map its name to the relevant index.
40+
size_t index = array_len(rinfo->dagSharedTensors);
41+
AI_dictAdd(rinfo->tensorsNamesToIndices, (void *)key_name, (void *)index);
42+
RAI_TensorGetShallowCopy(tensor);
43+
rinfo->dagSharedTensors = array_append(rinfo->dagSharedTensors, (void *)tensor);
3844
RedisModule_FreeString(NULL, key_name);
3945

4046
return REDISMODULE_OK;

src/DAG/dag_execute.c

Lines changed: 36 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -3,155 +3,61 @@
33
#include "background_workers.h"
44
#include "util/string_utils.h"
55

6-
void _DAG_SetTensorsInLocalContext(RedisAI_RunInfo *rinfo) {
7-
for (size_t i = 0; i < rinfo->dagOpCount; i++) {
8-
RAI_DagOp *op = rinfo->dagOps[i];
9-
if (op->commandType == REDISAI_DAG_CMD_TENSORSET) {
10-
// Insert the tensor with its mangled (unique) name.
11-
void *t = (void *)RAI_TensorGetShallowCopy(op->outTensor);
12-
AI_dictReplace(rinfo->dagTensorsContext, (void *)op->outkeys[0], t);
13-
}
14-
}
15-
}
16-
17-
int MangleTensorsNames(RedisAI_RunInfo *rinfo) {
18-
19-
int res = REDISMODULE_ERR;
20-
AI_dict *mangled_tensors = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);
6+
int ValidatePersistKeys(RedisAI_RunInfo *rinfo, AI_dict *tensorsNamesToInd,
7+
AI_dict *persistTensorsNames) {
218

229
{
23-
AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext);
24-
AI_dictEntry *entry = AI_dictNext(iter);
25-
while (entry) {
26-
RedisModuleString *key = (RedisModuleString *)AI_dictGetKey(entry);
27-
size_t key_len;
28-
const char *key_str = RedisModule_StringPtrLen(key, &key_len);
29-
RedisModuleString *demangled_key = RedisModule_CreateString(NULL, key_str, key_len - 4);
30-
int *instance = RedisModule_Alloc(sizeof(int));
31-
*instance = 1;
32-
AI_dictAdd(mangled_tensors, (void *)demangled_key, (void *)instance);
33-
RedisModule_FreeString(NULL, demangled_key);
34-
entry = AI_dictNext(iter);
10+
AI_dictIterator *iter = AI_dictGetSafeIterator(persistTensorsNames);
11+
AI_dictEntry *persist_entry;
12+
while ((persist_entry = AI_dictNext(iter))) {
13+
RedisModuleString *persist_key = (RedisModuleString *)AI_dictGetKey(persist_entry);
14+
AI_dictEntry *entry = AI_dictFind(tensorsNamesToInd, persist_key);
15+
if (!entry) {
16+
RAI_SetError(rinfo->err, RAI_EDAGRUN, "ERR PERSIST key cannot be found in DAG");
17+
AI_dictReleaseIterator(iter);
18+
return REDISMODULE_ERR;
19+
}
20+
size_t index = (size_t)AI_dictGetVal(entry);
21+
AI_dictReplace(persistTensorsNames, (void *)persist_key, (void *)index);
3522
}
3623
AI_dictReleaseIterator(iter);
3724
}
25+
return REDISMODULE_OK;
26+
}
27+
28+
int MapTensorsKeysToIndices(RedisAI_RunInfo *rinfo, AI_dict *tensorsNamesToInd) {
3829

3930
for (long long i = 0; i < array_len(rinfo->dagOps); i++) {
4031
RAI_DagOp *currentOp = rinfo->dagOps[i];
4132

42-
RedisModuleString **mangled_inkeys =
43-
array_new(RedisModuleString *, array_len(currentOp->inkeys));
4433
for (long long j = 0; j < array_len(currentOp->inkeys); j++) {
4534
RedisModuleString *key = currentOp->inkeys[j];
46-
AI_dictEntry *entry = AI_dictFind(mangled_tensors, key);
35+
AI_dictEntry *entry = AI_dictFind(tensorsNamesToInd, key);
4736
if (!entry) {
48-
array_free(mangled_inkeys);
4937
RAI_SetError(rinfo->err, RAI_EDAGRUN, "ERR INPUT key cannot be found in DAG");
50-
goto cleanup;
38+
return REDISMODULE_ERR;
5139
}
52-
int *instance = AI_dictGetVal(entry);
53-
char buf[16];
54-
sprintf(buf, "%04d", *instance);
55-
RedisModuleString *mangled_key = RedisModule_CreateStringFromString(NULL, key);
56-
RedisModule_StringAppendBuffer(NULL, mangled_key, buf, strlen(buf));
57-
mangled_inkeys = array_append(mangled_inkeys, mangled_key);
40+
size_t ind = (size_t)AI_dictGetVal(entry);
41+
currentOp->inkeys_indices = array_append(currentOp->inkeys_indices, ind);
5842
}
5943

60-
RedisModuleString **mangled_outkeys =
61-
array_new(RedisModuleString *, array_len(currentOp->outkeys));
6244
for (long long j = 0; j < array_len(currentOp->outkeys); j++) {
6345
RedisModuleString *key = currentOp->outkeys[j];
64-
AI_dictEntry *entry = AI_dictFind(mangled_tensors, key);
65-
int *instance = NULL;
66-
if (entry) {
67-
instance = AI_dictGetVal(entry);
68-
*instance += 1;
69-
} else {
70-
instance = RedisModule_Alloc(sizeof(int));
71-
*instance = 1;
72-
AI_dictAdd(mangled_tensors, (void *)key, (void *)instance);
73-
}
74-
char buf[16];
75-
sprintf(buf, "%04d", *instance);
76-
RedisModuleString *mangled_key = RedisModule_CreateStringFromString(NULL, key);
77-
RedisModule_StringAppendBuffer(NULL, mangled_key, buf, strlen(buf));
78-
mangled_outkeys = array_append(mangled_outkeys, mangled_key);
79-
}
80-
81-
if (currentOp->inkeys) {
82-
for (size_t j = 0; j < array_len(currentOp->inkeys); j++) {
83-
RedisModule_FreeString(NULL, currentOp->inkeys[j]);
84-
}
85-
array_free(currentOp->inkeys);
86-
}
87-
88-
if (currentOp->outkeys) {
89-
for (size_t j = 0; j < array_len(currentOp->outkeys); j++) {
90-
RedisModule_FreeString(NULL, currentOp->outkeys[j]);
91-
}
92-
array_free(currentOp->outkeys);
93-
}
94-
95-
currentOp->inkeys = mangled_inkeys;
96-
currentOp->outkeys = mangled_outkeys;
97-
}
46+
size_t ind = array_len(rinfo->dagSharedTensors);
9847

99-
AI_dict *mangled_persisted = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);
100-
{
101-
AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsPersistedContext);
102-
AI_dictEntry *entry = AI_dictNext(iter);
103-
while (entry) {
104-
RedisModuleString *key = (RedisModuleString *)AI_dictGetKey(entry);
105-
AI_dictEntry *mangled_entry = AI_dictFind(mangled_tensors, key);
106-
if (!mangled_entry) {
107-
AI_dictRelease(mangled_persisted);
108-
AI_dictReleaseIterator(iter);
109-
RAI_SetError(rinfo->err, RAI_EDAGRUN, "ERR PERSIST key cannot be found in DAG");
110-
goto cleanup;
111-
}
112-
if (AI_dictFind(mangled_persisted, key) != NULL) {
113-
AI_dictRelease(mangled_persisted);
114-
AI_dictReleaseIterator(iter);
115-
RAI_SetError(rinfo->err, RAI_EDAGRUN, "ERR PERSIST keys must be unique");
116-
goto cleanup;
48+
// Add a new empty place holder in the array for an output tensor.
49+
// If this is a TENSORSET op, the tensor is already realized.
50+
if (currentOp->commandType == REDISAI_DAG_CMD_TENSORSET) {
51+
RAI_Tensor *t = RAI_TensorGetShallowCopy(currentOp->outTensor);
52+
rinfo->dagSharedTensors = array_append(rinfo->dagSharedTensors, t);
53+
} else {
54+
rinfo->dagSharedTensors = array_append(rinfo->dagSharedTensors, NULL);
11755
}
118-
int *instance = AI_dictGetVal(mangled_entry);
119-
char buf[16];
120-
sprintf(buf, "%04d", *instance);
121-
RedisModuleString *mangled_key = RedisModule_CreateStringFromString(NULL, key);
122-
RedisModule_StringAppendBuffer(NULL, mangled_key, buf, strlen(buf));
123-
AI_dictAdd(mangled_persisted, (void *)mangled_key, (void *)1);
124-
RedisModule_FreeString(NULL, mangled_key);
125-
entry = AI_dictNext(iter);
56+
currentOp->outkeys_indices = array_append(currentOp->outkeys_indices, ind);
57+
AI_dictReplace(tensorsNamesToInd, (void *)key, (void *)ind);
12658
}
127-
AI_dictReleaseIterator(iter);
12859
}
129-
130-
AI_dictRelease(rinfo->dagTensorsPersistedContext);
131-
rinfo->dagTensorsPersistedContext = mangled_persisted;
132-
133-
for (long long i = 0; i < array_len(rinfo->dagOps); i++) {
134-
if (rinfo->dagOps[i]->devicestr == NULL) {
135-
rinfo->dagOps[i]->devicestr = "CPU";
136-
}
137-
}
138-
// Tensors from TENSORSET ops are ready to be put in DAG local context under their mangled
139-
// names.
140-
_DAG_SetTensorsInLocalContext(rinfo);
141-
res = REDISMODULE_OK;
142-
143-
cleanup : {
144-
AI_dictIterator *iter = AI_dictGetSafeIterator(mangled_tensors);
145-
AI_dictEntry *entry = AI_dictNext(iter);
146-
while (entry) {
147-
int *val = (int *)AI_dictGetVal(entry);
148-
RedisModule_Free(val);
149-
entry = AI_dictNext(iter);
150-
}
151-
AI_dictReleaseIterator(iter);
152-
}
153-
AI_dictRelease(mangled_tensors);
154-
return res;
60+
return REDISMODULE_OK;
15561
}
15662

15763
// Add Shallow copies of the DAG run info to the devices' queues.
@@ -242,7 +148,7 @@ int RAI_DAGRun(RAI_DAGRunCtx *run_info, RAI_OnFinishCB DAGAsyncFinish, void *pri
242148
}
243149
// Make the inkeys and outkeys of the DAG ops unique, to ensure that the operations
244150
// will be execute in the right order.
245-
if (MangleTensorsNames(rinfo) != REDISMODULE_OK) {
151+
if (MapTensorsKeysToIndices(rinfo, rinfo->tensorsNamesToIndices) != REDISMODULE_OK) {
246152
RAI_SetError(err, rinfo->err->code, rinfo->err->detail);
247153
return REDISMODULE_ERR;
248154
}
@@ -269,16 +175,13 @@ size_t RAI_DAGNumOutputs(RAI_OnFinishCtx *finish_ctx) {
269175
const RAI_Tensor *RAI_DAGOutputTensor(RAI_OnFinishCtx *finish_ctx, size_t index) {
270176
size_t tensor_get_op_ind = -1;
271177
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)finish_ctx;
178+
272179
for (size_t i = 0; i < rinfo->dagOpCount; i++) {
273180
RAI_DagOp *op = rinfo->dagOps[i];
274181
if (op->commandType == REDISAI_DAG_CMD_TENSORGET) {
275182
tensor_get_op_ind++;
276183
if (tensor_get_op_ind == index) {
277-
RAI_Tensor *t;
278-
int res = RAI_getTensorFromLocalContext(rinfo->dagTensorsContext, op->inkeys[0], &t,
279-
op->err);
280-
RedisModule_Assert(res == REDISMODULE_OK);
281-
return t;
184+
return Dag_GetTensorFromGlobalCtx(rinfo, op->inkeys_indices[0]);
282185
}
283186
}
284187
}

src/DAG/dag_execute.h

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,39 @@
44
#include "run_info.h"
55

66
/**
7-
* @brief We are given a DAG runInfo of a sequence of operations, each with its own
7+
@brief We are given a DAG runInfo of a sequence of operations, each with its own
88
input and output keys. The names of the keys will be used to look whether the
99
inputs to a DAG operation have all been realized by previous operations (or if
1010
they are available as part of LOADed keys from keyspace).
1111
This strategy is fine if keys are not aliased, that is, if a command's output
1212
overwrites the key of a previous command. This would trick DAG operations into
1313
thinking that their input is ready when it's not.
14-
To overcome this, we make key names unique, so that names are not aliased. We
15-
mangle the names by appending a numerical suffix ":0001". After computing, we
16-
demangle the keys in order to persist them.*/
17-
int MangleTensorsNames(RedisAI_RunInfo *rinfo);
14+
To overcome this, we map the input and output tensors for every operation to indices,
15+
in the following way. For every input of an operation having the key "x", we map the index
16+
for which "x" was last mapped to when, it was an output of a previous operation.
17+
For every output of an operation "y", we map the next available index in the array.
18+
Every entry in the DAG array contains NULL (except for tensors that where loaded
19+
before the DAG run starts).
20+
@param rinfo The DAG runInfo.
21+
@param tensorsNamesToInd A dict mapping every key name of a tensor that appeared
22+
in DAG operation, to the maximal index of the DAG shared array for which they were mapped to.
23+
@returns REDISMODULE_ERR if there exists an operation for which one of the input
24+
tensors didn't appear as an output of a previous operation, REDISMODULE_OK otherwise
25+
*/
26+
int MapTensorsKeysToIndices(RedisAI_RunInfo *rinfo, AI_dict *tensorsNamesToInd);
27+
28+
/**
29+
* @brief Validates that tensors key names to persist appeared in the DAG operations.
30+
* @param rinfo The DAG runInfo.
31+
* @param tensorsNamesToInd A dict mapping every key name of a tensor that appeared
32+
* in DAG operation, to the maximal index of the DAG shared array for which they were mapped to.
33+
* @param persistTensorsNames A hash table the contains the names of the tensors
34+
* to persist when the DAG run is finished.
35+
* @return REDISMODULE_ERR if there exists a tensor key to persist that didn't
36+
* appear in DAG operation, REDISMODULE_OK otherwise
37+
*/
38+
int ValidatePersistKeys(RedisAI_RunInfo *rinfo, AI_dict *tensorsNamesToInd,
39+
AI_dict *persistTensorsNames);
1840

1941
/**
2042
* @brief Run asynchronously a DAG. This will validate that the sequence of DAG ops

0 commit comments

Comments
 (0)