|
| 1 | +#include "dag_builder.h" |
| 2 | +#include "run_info.h" |
| 3 | +#include "dag_parser.h" |
| 4 | +#include "string_utils.h" |
| 5 | +#include "modelRun_ctx.h" |
| 6 | + |
| 7 | +// Store the given arguments from the string in argv array and their amount in argc. |
| 8 | +int _StringToRMArray(const char *dag, RedisModuleString ***argv, int *argc, RAI_Error *err) { |
| 9 | + |
| 10 | + char dag_string[strlen(dag) + 1]; |
| 11 | + strcpy(dag_string, dag); |
| 12 | + |
| 13 | + char *token = strtok(dag_string, " "); |
| 14 | + if (strcmp(token, "|>") != 0) { |
| 15 | + RAI_SetError(err, RAI_EDAGBUILDER, "DAG op should start with: '|>' "); |
| 16 | + return REDISMODULE_ERR; |
| 17 | + } |
| 18 | + |
| 19 | + while (token != NULL) { |
| 20 | + RedisModuleString *RS_token = RedisModule_CreateString(NULL, token, strlen(token)); |
| 21 | + *argv = array_append(*argv, RS_token); |
| 22 | + (*argc)++; |
| 23 | + token = strtok(NULL, " "); |
| 24 | + } |
| 25 | + return REDISMODULE_OK; |
| 26 | +} |
| 27 | + |
| 28 | +int RAI_DAGLoadTensor(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor *tensor) { |
| 29 | + |
| 30 | + RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info; |
| 31 | + RedisModuleString *key_name = RedisModule_CreateString(NULL, t_name, strlen(t_name)); |
| 32 | + // Add the tensor under its "mangled" key name to the DAG local context dict. |
| 33 | + char buf[16]; |
| 34 | + sprintf(buf, "%04d", 1); |
| 35 | + RedisModule_StringAppendBuffer(NULL, key_name, buf, strlen(buf)); |
| 36 | + AI_dictAdd(rinfo->dagTensorsContext, (void *)key_name, |
| 37 | + (void *)RAI_TensorGetShallowCopy(tensor)); |
| 38 | + RedisModule_FreeString(NULL, key_name); |
| 39 | + |
| 40 | + return REDISMODULE_OK; |
| 41 | +} |
| 42 | + |
| 43 | +RAI_DAGRunCtx *RAI_DAGRunCtxCreate(void) { |
| 44 | + RedisAI_RunInfo *rinfo; |
| 45 | + RAI_InitRunInfo(&rinfo); |
| 46 | + return (RAI_DAGRunCtx *)rinfo; |
| 47 | +} |
| 48 | + |
| 49 | +RAI_DAGRunOp *RAI_DAGCreateModelRunOp(RAI_Model *model) { |
| 50 | + RAI_ModelRunCtx *mctx = RAI_ModelRunCtxCreate(model); |
| 51 | + RAI_DagOp *op; |
| 52 | + RAI_InitDagOp(&op); |
| 53 | + |
| 54 | + op->commandType = REDISAI_DAG_CMD_MODELRUN; |
| 55 | + op->mctx = mctx; |
| 56 | + op->devicestr = model->devicestr; |
| 57 | + op->runkey = RAI_HoldString(NULL, (RedisModuleString *)model->infokey); |
| 58 | + return (RAI_DAGRunOp *)op; |
| 59 | +} |
| 60 | + |
| 61 | +RAI_DAGRunOp *RAI_DAGCreateScriptRunOp(RAI_Script *script, const char *func_name) { |
| 62 | + RAI_ScriptRunCtx *sctx = RAI_ScriptRunCtxCreate(script, func_name); |
| 63 | + RAI_DagOp *op; |
| 64 | + RAI_InitDagOp(&op); |
| 65 | + |
| 66 | + op->commandType = REDISAI_DAG_CMD_SCRIPTRUN; |
| 67 | + op->sctx = sctx; |
| 68 | + op->devicestr = script->devicestr; |
| 69 | + op->runkey = RAI_HoldString(NULL, (RedisModuleString *)script->infokey); |
| 70 | + return (RAI_DAGRunOp *)op; |
| 71 | +} |
| 72 | + |
| 73 | +int RAI_DAGRunOpAddInput(RAI_DAGRunOp *DAGOp, const char *input) { |
| 74 | + RAI_DagOp *op = (RAI_DagOp *)DAGOp; |
| 75 | + RedisModuleString *inkey = RedisModule_CreateString(NULL, input, strlen(input)); |
| 76 | + op->inkeys = array_append(op->inkeys, inkey); |
| 77 | + return REDISMODULE_OK; |
| 78 | +} |
| 79 | + |
| 80 | +int RAI_DAGRunOpAddOutput(RAI_DAGRunOp *DAGOp, const char *output) { |
| 81 | + RAI_DagOp *op = (RAI_DagOp *)DAGOp; |
| 82 | + RedisModuleString *outkey = RedisModule_CreateString(NULL, output, strlen(output)); |
| 83 | + op->outkeys = array_append(op->outkeys, outkey); |
| 84 | + return REDISMODULE_OK; |
| 85 | +} |
| 86 | + |
| 87 | +int RAI_DAGAddRunOp(RAI_DAGRunCtx *run_info, RAI_DAGRunOp *DAGop, RAI_Error *err) { |
| 88 | + |
| 89 | + RAI_DagOp *op = (RAI_DagOp *)DAGop; |
| 90 | + RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info; |
| 91 | + if (op->mctx) { |
| 92 | + RAI_Model *model = op->mctx->model; |
| 93 | + if (ModelGetNumInputs(model) != array_len(op->inkeys)) { |
| 94 | + RAI_SetError(err, RAI_EDAGBUILDER, |
| 95 | + "Number of keys given as INPUTS does not match model definition"); |
| 96 | + return REDISMODULE_ERR; |
| 97 | + } |
| 98 | + if (ModelGetNumOutputs(model) != array_len(op->outkeys)) { |
| 99 | + RAI_SetError(err, RAI_EDAGBUILDER, |
| 100 | + "Number of keys given as OUTPUTS does not match model definition"); |
| 101 | + return REDISMODULE_ERR; |
| 102 | + } |
| 103 | + } |
| 104 | + rinfo->dagOps = array_append(rinfo->dagOps, op); |
| 105 | + |
| 106 | + return REDISMODULE_OK; |
| 107 | +} |
| 108 | + |
| 109 | +int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err) { |
| 110 | + |
| 111 | + RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info; |
| 112 | + RAI_DagOp *op; |
| 113 | + RAI_InitDagOp(&op); |
| 114 | + rinfo->dagOps = array_append(rinfo->dagOps, op); |
| 115 | + op->commandType = REDISAI_DAG_CMD_TENSORGET; |
| 116 | + op->devicestr = "CPU"; |
| 117 | + RedisModuleString *name = RedisModule_CreateString(NULL, t_name, strlen(t_name)); |
| 118 | + op->inkeys = array_append(op->inkeys, name); |
| 119 | + return REDISMODULE_OK; |
| 120 | +} |
| 121 | + |
| 122 | +int RAI_DAGAddTensorSet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor *tensor) { |
| 123 | + |
| 124 | + RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info; |
| 125 | + RAI_DagOp *op; |
| 126 | + RAI_InitDagOp(&op); |
| 127 | + rinfo->dagOps = array_append(rinfo->dagOps, op); |
| 128 | + op->commandType = REDISAI_DAG_CMD_TENSORSET; |
| 129 | + op->devicestr = "CPU"; |
| 130 | + RedisModuleString *name = RedisModule_CreateString(NULL, t_name, strlen(t_name)); |
| 131 | + op->outkeys = array_append(op->outkeys, name); |
| 132 | + op->outTensor = RAI_TensorGetShallowCopy(tensor); |
| 133 | + return REDISMODULE_OK; |
| 134 | +} |
| 135 | + |
| 136 | +int RAI_DAGAddOpsFromString(RAI_DAGRunCtx *run_info, const char *dag, RAI_Error *err) { |
| 137 | + |
| 138 | + int res = REDISMODULE_ERR; |
| 139 | + RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info; |
| 140 | + array_new_on_stack(RAI_DagOp *, 10, new_ops); |
| 141 | + array_new_on_stack(RedisModuleString *, 100, argv); |
| 142 | + int argc = 0; |
| 143 | + if (_StringToRMArray(dag, &argv, &argc, err) != REDISMODULE_OK) { |
| 144 | + goto cleanup; |
| 145 | + } |
| 146 | + |
| 147 | + RAI_DagOp *op; |
| 148 | + for (size_t i = 0; i < argc; i++) { |
| 149 | + const char *arg_string = RedisModule_StringPtrLen(argv[i], NULL); |
| 150 | + if (strcmp(arg_string, "|>") == 0 && i < argc - 1) { |
| 151 | + RAI_InitDagOp(&op); |
| 152 | + new_ops = array_append(new_ops, op); |
| 153 | + op->argv = &argv[i + 1]; |
| 154 | + } else { |
| 155 | + op->argc++; |
| 156 | + } |
| 157 | + } |
| 158 | + |
| 159 | + if (ParseDAGOps(rinfo, new_ops) != REDISMODULE_OK) { |
| 160 | + RAI_SetError(err, RAI_GetErrorCode(rinfo->err), RAI_GetError(rinfo->err)); |
| 161 | + goto cleanup; |
| 162 | + } |
| 163 | + rinfo->dagOpCount = array_len(rinfo->dagOps); |
| 164 | + res = REDISMODULE_OK; |
| 165 | + |
| 166 | +cleanup: |
| 167 | + array_free(new_ops); |
| 168 | + for (size_t i = 0; i < argc; i++) { |
| 169 | + RedisModule_FreeString(NULL, argv[i]); |
| 170 | + } |
| 171 | + array_free(argv); |
| 172 | + return res; |
| 173 | +} |
| 174 | + |
| 175 | +size_t RAI_DAGNumOps(RAI_DAGRunCtx *run_info) { |
| 176 | + RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info; |
| 177 | + return array_len(rinfo->dagOps); |
| 178 | +} |
| 179 | + |
| 180 | +void RAI_DAGRunOpFree(RAI_DAGRunOp *dagOp) { |
| 181 | + RAI_DagOp *op = (RAI_DagOp *)dagOp; |
| 182 | + RAI_FreeDagOp(op); |
| 183 | +} |
| 184 | + |
| 185 | +void RAI_DAGFree(RAI_DAGRunCtx *run_info) { |
| 186 | + RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info; |
| 187 | + RAI_FreeRunInfo(rinfo); |
| 188 | +} |
0 commit comments