Skip to content

Commit 7db6b5d

Browse files
committed
- Setting a global error in DAG runInfo when one of its ops returns with an error. this is the one that is returned in LLAPI.
- Add macro to redisai.h to resolve double includes issues. - Update arr.h, support creating dynamic arrays on stack (in addition). - PR fixes, test module refactor
1 parent 102b25c commit 7db6b5d

17 files changed

+774
-695
lines changed

src/DAG/dag.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,9 @@ void RedisAI_DagRunSessionStep(RedisAI_RunInfo *rinfo, const char *devicestr) {
571571

572572
if (currentOp->result != REDISMODULE_OK) {
573573
__atomic_store_n(rinfo->dagError, 1, __ATOMIC_RELAXED);
574+
RAI_ContextWriteLock(rinfo);
575+
RAI_SetError(rinfo->err, RAI_GetErrorCode(currentOp->err), RAI_GetError(currentOp->err));
576+
RAI_ContextUnlock(rinfo);
574577
}
575578
}
576579

@@ -599,6 +602,8 @@ void RedisAI_BatchedDagRunSessionStep(RedisAI_RunInfo **batched_rinfo, const cha
599602

600603
if (currentOp->result != REDISMODULE_OK) {
601604
__atomic_store_n(rinfo->dagError, 1, __ATOMIC_RELAXED);
605+
RAI_SetError(rinfo->err, RAI_GetErrorCode(currentOp->err),
606+
RAI_GetError(currentOp->err));
602607
}
603608
}
604609
return;

src/DAG/dag_builder.c

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,27 @@
44
#include "string_utils.h"
55
#include "modelRun_ctx.h"
66

7+
// Store the given arguments from the string in argv array and their amount in argc.
8+
int _StringToRMArray(const char *dag, RedisModuleString ***argv, int *argc, RAI_Error *err) {
9+
10+
char dag_string[strlen(dag) + 1];
11+
strcpy(dag_string, dag);
12+
13+
char *token = strtok(dag_string, " ");
14+
if (strcmp(token, "|>") != 0) {
15+
RAI_SetError(err, RAI_EDAGBUILDER, "DAG op should start with: '|>' ");
16+
return REDISMODULE_ERR;
17+
}
18+
19+
while (token != NULL) {
20+
RedisModuleString *RS_token = RedisModule_CreateString(NULL, token, strlen(token));
21+
*argv = array_append(*argv, RS_token);
22+
(*argc)++;
23+
token = strtok(NULL, " ");
24+
}
25+
return REDISMODULE_OK;
26+
}
27+
728
int RAI_DAGLoadTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor *tensor) {
829

930
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
@@ -116,51 +137,43 @@ int RAI_DAGAddOpsFromString(RAI_DAGRunCtx *run_info, const char *dag, RAI_Error
116137

117138
int res = REDISMODULE_ERR;
118139
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
140+
array_new_on_stack(RAI_DagOp *, 10, new_ops);
141+
array_new_on_stack(RedisModuleString *, 100, argv);
119142
int argc = 0;
120-
char dag_string[strlen(dag) + 1];
121-
strcpy(dag_string, dag);
122-
123-
char *token = strtok(dag_string, " ");
124-
if (strcmp(token, "|>") != 0) {
125-
RAI_SetError(err, RAI_EDAGBUILDER, "DAG op should start with: '|>' ");
126-
return res;
127-
}
128-
RedisModuleString **argv = array_new(RedisModuleString *, 2);
129-
while (token != NULL) {
130-
RedisModuleString *RS_token = RedisModule_CreateString(NULL, token, strlen(token));
131-
argv = array_append(argv, RS_token);
132-
argc++;
133-
token = strtok(NULL, " ");
143+
if (_StringToRMArray(dag, &argv, &argc, err) != REDISMODULE_OK) {
144+
goto cleanup;
134145
}
135146

136-
size_t num_ops_before = array_len(rinfo->dagOps);
137-
size_t new_ops = 0;
138147
RAI_DagOp *op;
139148
for (size_t i = 0; i < argc; i++) {
140149
const char *arg_string = RedisModule_StringPtrLen(argv[i], NULL);
141150
if (strcmp(arg_string, "|>") == 0 && i < argc - 1) {
142151
RAI_InitDagOp(&op);
143-
rinfo->dagOps = array_append(rinfo->dagOps, op);
144-
new_ops++;
152+
new_ops = array_append(new_ops, op);
145153
op->argv = &argv[i + 1];
146154
} else {
147155
op->argc++;
148156
}
149157
}
150158

151-
if (ParseDAGOps(rinfo, num_ops_before, new_ops) != REDISMODULE_OK) {
152-
// Remove all ops that where added before the error and go back to the initial state.
159+
if (ParseDAGOps(rinfo, new_ops) != REDISMODULE_OK) {
160+
// Remove all ops that where created.
153161
RAI_SetError(err, RAI_GetErrorCode(rinfo->err), RAI_GetError(rinfo->err));
154-
for (size_t i = num_ops_before; i < array_len(rinfo->dagOps); i++) {
155-
RAI_FreeDagOp(rinfo->dagOps[i]);
162+
for (size_t i = 0; i < array_len(new_ops); i++) {
163+
RAI_FreeDagOp(new_ops[i]);
156164
}
157-
rinfo->dagOps = array_trimm_len(rinfo->dagOps, num_ops_before);
158165
goto cleanup;
159166
}
167+
168+
// Copy the new op pointers to the DAG run info.
169+
for (size_t i = 0; i < array_len(new_ops); i++) {
170+
rinfo->dagOps = array_append(rinfo->dagOps, new_ops[i]);
171+
}
160172
rinfo->dagOpCount = array_len(rinfo->dagOps);
161173
res = REDISMODULE_OK;
162174

163175
cleanup:
176+
array_free(new_ops);
164177
for (size_t i = 0; i < argc; i++) {
165178
RedisModule_FreeString(NULL, argv[i]);
166179
}

src/DAG/dag_execute.c

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -285,16 +285,11 @@ RAI_Tensor *RAI_DAGOutputTensor(RAI_OnFinishCtx *finish_ctx, size_t index) {
285285
return NULL;
286286
}
287287

288-
int RAI_DAGRunError(RAI_OnFinishCtx *finish_ctx) {
288+
bool RAI_DAGRunError(RAI_OnFinishCtx *finish_ctx) {
289289
return *((RedisAI_RunInfo *)finish_ctx)->dagError;
290290
}
291291

292-
RAI_Error *RAI_DAGCopyOpStatus(RAI_OnFinishCtx *finish_ctx, size_t index) {
292+
RAI_Error *RAI_DAGGetError(RAI_OnFinishCtx *finish_ctx) {
293293
RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)finish_ctx;
294-
RedisModule_Assert(index < rinfo->dagOpCount);
295-
RAI_Error *err;
296-
RAI_InitError(&err);
297-
RAI_SetError(err, RAI_GetErrorCode(rinfo->dagOps[index]->err),
298-
RAI_GetError(rinfo->dagOps[index]->err));
299-
return err;
294+
return rinfo->err;
300295
}

src/DAG/dag_execute.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,14 @@ size_t RAI_DAGNumOutputs(RAI_OnFinishCtx *finish_ctx);
4545
RAI_Tensor *RAI_DAGOutputTensor(RAI_OnFinishCtx *finish_ctx, size_t index);
4646

4747
/**
48-
* @brief Returns 1 if (at least) one of the DAG ops encountered an error.
48+
* @brief Returns true if (at least) one of the DAG ops encountered an error.
4949
*/
50-
int RAI_DAGRunError(RAI_OnFinishCtx *finish_ctx);
50+
bool RAI_DAGRunError(RAI_OnFinishCtx *finish_ctx);
5151

5252
/**
53-
* @brief This can be called in the finish CB, returns the status of a certain in a DAG.
53+
* @brief This can be called in the finish CB, to get DAG error details.
5454
* @param finish_ctx This represents the DAG runInfo at the end of the run.
55-
* @param index Index of a specific op in the DAG.
56-
* @retval returns an object that represents the i'th op status, from which a user can
55+
* @retval returns an object that represents the DAG status, from which a user can
5756
* obtain the error code (error code is "OK" if no error has occurred) and error details.
5857
*/
59-
RAI_Error *RAI_DAGCopyOpStatus(RAI_OnFinishCtx *finish_ctx, size_t index);
58+
RAI_Error *RAI_DAGGetError(RAI_OnFinishCtx *finish_ctx);

src/DAG/dag_parser.c

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,10 @@ static int _parseTimeout(RedisModuleString **argv, int argc, long long *timeout,
139139
return REDISMODULE_OK;
140140
}
141141

142-
static RAI_DagOp *_AddEmptyOp(RedisAI_RunInfo *rinfo) {
142+
static RAI_DagOp *_AddEmptyOp(RAI_DagOp ***ops) {
143143
RAI_DagOp *currentDagOp;
144144
RAI_InitDagOp(&currentDagOp);
145-
rinfo->dagOps = array_append(rinfo->dagOps, currentDagOp);
145+
*ops = array_append(*ops, currentDagOp);
146146
return currentDagOp;
147147
}
148148

@@ -160,10 +160,10 @@ int _CollectOpArgs(RedisModuleString **argv, int argc, int arg_pos, RAI_DagOp *o
160160
return op->argc;
161161
}
162162

163-
int ParseDAGOps(RedisAI_RunInfo *rinfo, size_t first_op_ind, size_t num_ops) {
163+
int ParseDAGOps(RedisAI_RunInfo *rinfo, RAI_DagOp **ops) {
164164

165-
for (long long i = 0; i < num_ops; i++) {
166-
RAI_DagOp *currentOp = rinfo->dagOps[i + first_op_ind];
165+
for (long long i = 0; i < array_len(ops); i++) {
166+
RAI_DagOp *currentOp = ops[i];
167167
// The first op arg is the command name.
168168
const char *arg_string = RedisModule_StringPtrLen(currentOp->argv[0], NULL);
169169

@@ -211,6 +211,7 @@ int ParseDAGOps(RedisAI_RunInfo *rinfo, size_t first_op_ind, size_t num_ops) {
211211
int ParseDAGRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleString **argv,
212212
int argc, bool dag_ro) {
213213

214+
int res = REDISMODULE_ERR;
214215
if (argc < 4) {
215216
if (dag_ro) {
216217
RAI_SetError(rinfo->err, RAI_EDAGBUILDER,
@@ -219,14 +220,15 @@ int ParseDAGRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleS
219220
RAI_SetError(rinfo->err, RAI_EDAGBUILDER,
220221
"ERR wrong number of arguments for 'AI.DAGRUN' command");
221222
}
222-
goto cleanup;
223+
return res;
223224
}
224225

225226
int chainingOpCount = 0;
226227
int arg_pos = 1;
227228
bool load_complete = false;
228229
bool persist_complete = false;
229230
bool timeout_complete = false;
231+
array_new_on_stack(RAI_DagOp *, 10, dag_ops);
230232

231233
// The first arg is "AI.DAGRUN", so we go over from the next arg.
232234
while (arg_pos < argc) {
@@ -273,7 +275,7 @@ int ParseDAGRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleS
273275
}
274276

275277
if (!strcasecmp(arg_string, "|>") && arg_pos < argc - 1) {
276-
RAI_DagOp *currentOp = _AddEmptyOp(rinfo);
278+
RAI_DagOp *currentOp = _AddEmptyOp(&dag_ops);
277279
chainingOpCount++;
278280
int args_num = _CollectOpArgs(argv, argc, ++arg_pos, currentOp);
279281
arg_pos += args_num;
@@ -283,19 +285,30 @@ int ParseDAGRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleS
283285
RAI_SetError(rinfo->err, RAI_EDAGBUILDER, "ERR Invalid DAGRUN command");
284286
goto cleanup;
285287
}
286-
rinfo->dagOpCount = array_len(rinfo->dagOps);
287-
if (rinfo->dagOpCount < 1) {
288+
289+
if (array_len(dag_ops) < 1) {
288290
RAI_SetError(rinfo->err, RAI_EDAGBUILDER, "ERR DAG is empty");
289291
goto cleanup;
290292
}
291-
if (ParseDAGOps(rinfo, 0, rinfo->dagOpCount) != REDISMODULE_OK) {
293+
294+
if (ParseDAGOps(rinfo, dag_ops) != REDISMODULE_OK) {
295+
for (size_t i = 0; i < array_len(dag_ops); i++) {
296+
RAI_FreeDagOp(dag_ops[i]);
297+
}
292298
goto cleanup;
293299
}
300+
// After validating all the ops, insert them to the DAG.
301+
for (size_t i = 0; i < array_len(dag_ops); i++) {
302+
rinfo->dagOps = array_append(rinfo->dagOps, dag_ops[i]);
303+
}
304+
rinfo->dagOpCount = array_len(rinfo->dagOps);
305+
294306
if (MangleTensorsNames(rinfo) != REDISMODULE_OK) {
295307
goto cleanup;
296308
}
297-
return REDISMODULE_OK;
309+
res = REDISMODULE_OK;
298310

299311
cleanup:
300-
return REDISMODULE_ERR;
312+
array_free(dag_ops);
313+
return res;
301314
}

src/DAG/dag_parser.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,12 @@ int ParseDAGRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleS
1515
int argc, bool dag_ro);
1616

1717
/**
18-
* @brief Parse the arguments of ops in the DAGRUN command and build (or extend) the DagOp object
19-
* accordingly.
20-
* @param rinfo The DAG run info with its op, where every op has an argv field that points to an
21-
* array of RedisModule strings the represents the op, and an argc field which is the number of
18+
* @brief Parse the arguments of the given ops in the DAGRUN command and build every op accordingly.
19+
* @param rinfo The DAG run info that will be populated with the ops if they are valid.
20+
* with its op,
21+
* @param ops A local array of ops, where every op has an argv field that points to an
22+
* array of RedisModule strings arguments, and an argc field which is the number of
2223
* args.
23-
* @param first_op_ind The index of the first op in the for which we parse its argument and build
24-
* it.
25-
* @param num_ops The number of ops in the DAG the need to be parsed.
2624
* @return Returns REDISMODULE_OK if the command is valid, REDISMODULE_ERR otherwise.
2725
*/
28-
int ParseDAGOps(RedisAI_RunInfo *rinfo, size_t first_op_ind, size_t num_ops);
26+
int ParseDAGOps(RedisAI_RunInfo *rinfo, RAI_DagOp **ops);

src/command_parser.c

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModu
128128
int res = REDISMODULE_ERR;
129129
// Build a ModelRunCtx from command.
130130
RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(NULL);
131-
// int lock_status = RedisModule_ThreadSafeContextTryLock(ctx);
132131
RAI_Model *model;
133132
long long timeout = 0;
134133
if (_ModelRunCommand_ParseArgs(argv, argc, ctx, &model, rinfo->err, &currentOp->inkeys,
@@ -157,9 +156,6 @@ int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModu
157156
res = REDISMODULE_OK;
158157

159158
cleanup:
160-
// if (lock_status == REDISMODULE_OK) {
161-
// RedisModule_ThreadSafeContextUnlock(ctx);
162-
//}
163159
RedisModule_FreeThreadSafeContext(ctx);
164160
return res;
165161
}
@@ -392,8 +388,7 @@ int RedisAI_ExecuteCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
392388
rinfo->OnFinish = DAG_ReplyAndUnblock;
393389
rinfo->client = RedisModule_BlockClient(ctx, RedisAI_DagRun_Reply, NULL, RunInfo_FreeData, 0);
394390
if (DAG_InsertDAGToQueue(rinfo) != REDISMODULE_OK) {
395-
RedisModule_ReplyWithError(ctx, rinfo->err->detail_oneline);
396-
RAI_FreeRunInfo(rinfo);
391+
RedisModule_UnblockClient(rinfo->client, rinfo);
397392
return REDISMODULE_ERR;
398393
}
399394
return REDISMODULE_OK;

src/redisai.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,7 @@ static int RedisAI_RegisterApi(RedisModuleCtx *ctx) {
953953
REGISTER_API(GetError, ctx);
954954
REGISTER_API(GetErrorOneLine, ctx);
955955
REGISTER_API(GetErrorCode, ctx);
956+
REGISTER_API(SetError, ctx);
956957

957958
REGISTER_API(TensorCreate, ctx);
958959
REGISTER_API(TensorCreateByConcatenatingTensors, ctx);
@@ -975,6 +976,7 @@ static int RedisAI_RegisterApi(RedisModuleCtx *ctx) {
975976
REGISTER_API(ModelCreate, ctx);
976977
REGISTER_API(ModelFree, ctx);
977978
REGISTER_API(ModelRunCtxCreate, ctx);
979+
REGISTER_API(GetModelFromKeyspace, ctx);
978980
REGISTER_API(ModelRunCtxAddInput, ctx);
979981
REGISTER_API(ModelRunCtxAddOutput, ctx);
980982
REGISTER_API(ModelRunCtxNumOutputs, ctx);
@@ -1017,7 +1019,7 @@ static int RedisAI_RegisterApi(RedisModuleCtx *ctx) {
10171019
REGISTER_API(DAGNumOutputs, ctx);
10181020
REGISTER_API(DAGOutputTensor, ctx);
10191021
REGISTER_API(DAGRunError, ctx);
1020-
REGISTER_API(DAGCopyOpStatus, ctx);
1022+
REGISTER_API(DAGGetError, ctx);
10211023
REGISTER_API(DAGRunOpFree, ctx);
10221024
REGISTER_API(DAGFree, ctx);
10231025

0 commit comments

Comments
 (0)