diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9217a4279..cf135fdc3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -5,6 +5,7 @@ endif() ADD_LIBRARY(redisai_obj OBJECT util/dict.c util/queue.c + util/string_utils.c redisai.c run_info.c background_workers.c diff --git a/src/dag.c b/src/dag.c index f37506c45..a46c7b575 100644 --- a/src/dag.c +++ b/src/dag.c @@ -43,6 +43,7 @@ #include "util/arr_rm_alloc.h" #include "util/dict.h" #include "util/queue.h" +#include "util/string_utils.h" /** * Execution of a TENSORSET DAG step. @@ -57,7 +58,7 @@ void RedisAI_DagRunSession_TensorSet_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur const int parse_result = RAI_parseTensorSetArgs(NULL, currentOp->argv, currentOp->argc, &t, 0, currentOp->err); if (parse_result > 0) { - const char *key_string = RedisModule_StringPtrLen(currentOp->outkeys[0], NULL); + RedisModuleString *key_string = currentOp->outkeys[0]; RAI_ContextWriteLock(rinfo); AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, t); RAI_ContextUnlock(rinfo); @@ -76,7 +77,7 @@ void RedisAI_DagRunSession_TensorSet_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur * @return */ void RedisAI_DagRunSession_TensorGet_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp) { - const char *key_string = RedisModule_StringPtrLen(currentOp->inkeys[0], NULL); + RedisModuleString *key_string = currentOp->inkeys[0]; RAI_Tensor *t = NULL; RAI_ContextReadLock(rinfo); currentOp->result = RAI_getTensorFromLocalContext(NULL, rinfo->dagTensorsContext, key_string, @@ -108,8 +109,7 @@ void RedisAI_DagRunSession_ModelRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *curr for (uint i = 0; i < n_inkeys; i++) { RAI_Tensor *inputTensor; const int get_result = RAI_getTensorFromLocalContext( - NULL, rinfo->dagTensorsContext, RedisModule_StringPtrLen(currentOp->inkeys[i], NULL), - &inputTensor, currentOp->err); + NULL, rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err); if (get_result == REDISMODULE_ERR) { // We check for this outside the function // this check cannot be covered by tests @@ -156,7 +156,7 @@ void RedisAI_DagRunSession_ModelRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *curr const size_t noutputs = RAI_ModelRunCtxNumOutputs(currentOp->mctx); for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) { RAI_Tensor *tensor = RAI_ModelRunCtxOutputTensor(currentOp->mctx, outputNumber); - const char *key_string = RedisModule_StringPtrLen(currentOp->outkeys[outputNumber], NULL); + RedisModuleString *key_string = currentOp->outkeys[outputNumber]; tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL; AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, tensor); } @@ -196,8 +196,7 @@ void RedisAI_BatchedDagRunSession_ModelRun_Step(RedisAI_RunInfo **batched_rinfo, for (uint i = 0; i < n_inkeys; i++) { RAI_Tensor *inputTensor; const int get_result = RAI_getTensorFromLocalContext( - NULL, rinfo->dagTensorsContext, - RedisModule_StringPtrLen(currentOp->inkeys[i], NULL), &inputTensor, currentOp->err); + NULL, rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err); if (get_result == REDISMODULE_ERR) { // We check for this outside the function // this check cannot be covered by tests @@ -253,8 +252,7 @@ void RedisAI_BatchedDagRunSession_ModelRun_Step(RedisAI_RunInfo **batched_rinfo, const size_t noutputs = RAI_ModelRunCtxNumOutputs(currentOp->mctx); for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) { RAI_Tensor *tensor = RAI_ModelRunCtxOutputTensor(currentOp->mctx, outputNumber); - const char *key_string = - RedisModule_StringPtrLen(currentOp->outkeys[outputNumber], NULL); + RedisModuleString *key_string = currentOp->outkeys[outputNumber]; tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL; AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, tensor); } @@ -289,8 +287,7 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur for (uint i = 0; i < n_inkeys; i++) { RAI_Tensor *inputTensor; const int get_result = RAI_getTensorFromLocalContext( - NULL, rinfo->dagTensorsContext, RedisModule_StringPtrLen(currentOp->inkeys[i], NULL), - &inputTensor, currentOp->err); + NULL, rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err); if (get_result == REDISMODULE_ERR) { // We check for this outside the function // this check cannot be covered by tests @@ -320,7 +317,7 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur const size_t noutputs = RAI_ScriptRunCtxNumOutputs(currentOp->sctx); for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) { RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(currentOp->sctx, outputNumber); - const char *key_string = RedisModule_StringPtrLen(currentOp->outkeys[outputNumber], NULL); + RedisModuleString *key_string = currentOp->outkeys[outputNumber]; tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL; AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, tensor); } @@ -349,8 +346,7 @@ size_t RAI_DagOpBatchSize(RAI_DagOp *op, AI_dict *opTensorsContext) { for (size_t i = 0; i < ninputs; i++) { RAI_Tensor *input; - RAI_getTensorFromLocalContext( - NULL, opTensorsContext, RedisModule_StringPtrLen(op->inkeys[i], NULL), &input, op->err); + RAI_getTensorFromLocalContext(NULL, opTensorsContext, op->inkeys[i], &input, op->err); // We are expecting input != NULL, because we only reach this function if all inputs // are available in context for the current dagOp. We could be more defensive eventually. @@ -388,14 +384,10 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, AI_dict *op1TensorsContext, RAI_DagOp *op for (int i = 0; i < ninputs1; i++) { RAI_Tensor *input1; - RAI_getTensorFromLocalContext(NULL, op1TensorsContext, - RedisModule_StringPtrLen(op1->inkeys[i], NULL), &input1, - op1->err); + RAI_getTensorFromLocalContext(NULL, op1TensorsContext, op1->inkeys[i], &input1, op1->err); RAI_Tensor *input2; - RAI_getTensorFromLocalContext(NULL, op2TensorsContext, - RedisModule_StringPtrLen(op2->inkeys[i], NULL), &input2, - op2->err); + RAI_getTensorFromLocalContext(NULL, op2TensorsContext, op2->inkeys[i], &input2, op2->err); if (input1 == NULL || input2 == NULL) { return 0; @@ -463,8 +455,7 @@ void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, int *currentOpReady, *currentOpReady = 1; for (int i = 0; i < n_inkeys; i++) { - if (AI_dictFind(rinfo->dagTensorsContext, - RedisModule_StringPtrLen(currentOp_->inkeys[i], NULL)) == NULL) { + if (AI_dictFind(rinfo->dagTensorsContext, currentOp_->inkeys[i]) == NULL) { RAI_ContextUnlock(rinfo); *currentOpReady = 0; return; @@ -639,8 +630,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc case REDISAI_DAG_CMD_MODELRUN: { rinfo->dagReplyLength++; struct RedisAI_RunStats *rstats = NULL; - const char *runkey = RedisModule_StringPtrLen(currentOp->runkey, NULL); - RAI_GetRunStats(runkey, &rstats); + RAI_GetRunStats(currentOp->runkey, &rstats); if (currentOp->result == REDISMODULE_ERR) { RAI_SafeAddDataPoint(rstats, 0, 1, 1, 0); RedisModule_ReplyWithError(ctx, currentOp->err->detail_oneline); @@ -665,8 +655,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc case REDISAI_DAG_CMD_SCRIPTRUN: { rinfo->dagReplyLength++; struct RedisAI_RunStats *rstats = NULL; - const char *runkey = RedisModule_StringPtrLen(currentOp->runkey, NULL); - RAI_GetRunStats(runkey, &rstats); + RAI_GetRunStats(currentOp->runkey, &rstats); if (currentOp->result == REDISMODULE_ERR) { RAI_SafeAddDataPoint(rstats, 0, 1, 1, 0); RedisModule_ReplyWithError(ctx, currentOp->err->detail_oneline); @@ -698,7 +687,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc AI_dictIterator *persist_iter = AI_dictGetSafeIterator(rinfo->dagTensorsPersistedContext); AI_dictEntry *persist_entry = AI_dictNext(persist_iter); while (persist_entry) { - const char *persist_key_name = AI_dictGetKey(persist_entry); + RedisModuleString *persist_key_name = AI_dictGetKey(persist_entry); AI_dictEntry *tensor_entry = AI_dictFind(rinfo->dagTensorsContext, persist_key_name); @@ -710,13 +699,13 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc continue; } RedisModuleKey *key; - char *demangled_key_name = RedisModule_Strdup(persist_key_name); - demangled_key_name[strlen(persist_key_name) - 4] = 0; - RedisModuleString *tensor_keyname = - RedisModule_CreateString(ctx, demangled_key_name, strlen(demangled_key_name)); - const int status = - RAI_OpenKey_Tensor(ctx, tensor_keyname, &key, REDISMODULE_READ | REDISMODULE_WRITE); - RedisModule_Free(demangled_key_name); + size_t persist_key_len; + const char *persist_key_str = + RedisModule_StringPtrLen(persist_key_name, &persist_key_len); + RedisModuleString *demangled_key_name = + RedisModule_CreateString(NULL, persist_key_str, persist_key_len - 4); + const int status = RAI_OpenKey_Tensor(ctx, demangled_key_name, &key, + REDISMODULE_READ | REDISMODULE_WRITE); if (status == REDISMODULE_ERR) { RedisModule_ReplyWithError(ctx, "ERR could not save tensor"); rinfo->dagReplyLength++; @@ -729,7 +718,8 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc } } RedisModule_CloseKey(key); - RedisAI_ReplicateTensorSet(ctx, tensor_keyname, tensor); + RedisAI_ReplicateTensorSet(ctx, demangled_key_name, tensor); + RedisModule_FreeString(NULL, demangled_key_name); } else { RedisModule_ReplyWithError(ctx, "ERR specified persistent key that was not used in DAG"); @@ -738,13 +728,13 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc RedisModule_Log(ctx, "warning", "on DAGRUN's PERSIST pecified persistent key (%s) that " "was not used on DAG. Logging all local context keys", - persist_key_name); + RedisModule_StringPtrLen(persist_key_name, NULL)); AI_dictIterator *local_iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext); AI_dictEntry *local_entry = AI_dictNext(local_iter); while (local_entry) { - const char *localcontext_key_name = AI_dictGetKey(local_entry); + RedisModuleString *localcontext_key_name = AI_dictGetKey(local_entry); RedisModule_Log(ctx, "warning", "DAG's local context key (%s)", - localcontext_key_name); + RedisModule_StringPtrLen(localcontext_key_name, NULL)); local_entry = AI_dictNext(local_iter); } AI_dictReleaseIterator(local_iter); @@ -791,7 +781,8 @@ int RAI_parseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc int separator_flag = 0; size_t argpos = 2; for (; (argpos <= argc - 1) && (number_loaded_keys < n_keys); argpos++) { - const char *arg_string = RedisModule_StringPtrLen(argv[argpos], NULL); + size_t arg_len; + const char *arg_string = RedisModule_StringPtrLen(argv[argpos], &arg_len); if (!strcasecmp(arg_string, chaining_operator)) { separator_flag = 1; break; @@ -807,11 +798,14 @@ int RAI_parseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc return -1; } RedisModule_CloseKey(key); - char *dictKey = (char *)RedisModule_Alloc((strlen(arg_string) + 5) * sizeof(char)); - sprintf(dictKey, "%s%04d", arg_string, 1); + char buf[16]; + sprintf(buf, "%04d", 1); + RedisModuleString *dictKey = RedisModule_CreateStringFromString(NULL, argv[argpos]); + RedisModule_StringAppendBuffer(NULL, dictKey, buf, strlen(buf)); + AI_dictAdd(*localContextDict, (void *)dictKey, (void *)RAI_TensorGetShallowCopy(t)); AI_dictAdd(*loadedContextDict, (void *)dictKey, (void *)1); - RedisModule_Free(dictKey); + RedisModule_FreeString(NULL, dictKey); number_loaded_keys++; } } @@ -849,7 +843,7 @@ int RAI_parseDAGPersistArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int a separator_flag = 1; break; } else { - AI_dictAdd(*persistContextDict, (void *)arg_string, (void *)1); + AI_dictAdd(*persistContextDict, (void *)argv[argpos], (void *)1); number_loaded_keys++; } } @@ -1078,7 +1072,7 @@ static int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int RAI_Tensor *t; RedisModuleKey *key; for (size_t i = 0; i < array_len(op->inkeys); i++) { - const char *inkey = RedisModule_StringPtrLen(op->inkeys[i], NULL); + RedisModuleString *inkey = op->inkeys[i]; const int status = RAI_GetTensorFromKeyspace(ctx, op->inkeys[i], &key, &t, REDISMODULE_READ); if (status == REDISMODULE_ERR) { @@ -1088,8 +1082,10 @@ static int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int return REDISMODULE_ERR; } RedisModule_CloseKey(key); - char *dictKey = (char *)RedisModule_Alloc((strlen(inkey) + 5) * sizeof(char)); - sprintf(dictKey, "%s%04d", inkey, 1); + char buf[16]; + sprintf(buf, "%04d", 1); + RedisModuleString *dictKey = RedisModule_CreateStringFromString(NULL, inkey); + RedisModule_StringAppendBuffer(NULL, dictKey, buf, strlen(buf)); AI_dictAdd(rinfo->dagTensorsContext, (void *)dictKey, (void *)RAI_TensorGetShallowCopy(t)); AI_dictAdd(rinfo->dagTensorsLoadedContext, (void *)dictKey, (void *)1); @@ -1097,7 +1093,7 @@ static int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int } for (size_t i = 0; i < array_len(op->outkeys); i++) { - const char *outkey = RedisModule_StringPtrLen(op->outkeys[i], NULL); + RedisModuleString *outkey = op->outkeys[i]; AI_dictAdd(rinfo->dagTensorsPersistedContext, (void *)outkey, (void *)1); } } @@ -1113,7 +1109,7 @@ static int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int // mangle the names by appending a numerical suffix ":0001". After computing, we // demangle the keys in order to persist them. - AI_dict *mangled_tensors = AI_dictCreate(&AI_dictTypeHeapStrings, NULL); + AI_dict *mangled_tensors = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL); if (!mangled_tensors) { return REDISMODULE_ERR; } @@ -1122,13 +1118,14 @@ static int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsLoadedContext); AI_dictEntry *entry = AI_dictNext(iter); while (entry) { - char *key = (char *)AI_dictGetKey(entry); - char *demangled_key = RedisModule_Strdup(key); - demangled_key[strlen(key) - 4] = 0; + RedisModuleString *key = (RedisModuleString *)AI_dictGetKey(entry); + size_t key_len; + const char *key_str = RedisModule_StringPtrLen(key, &key_len); + RedisModuleString *demangled_key = RedisModule_CreateString(NULL, key_str, key_len - 4); int *instance = RedisModule_Alloc(sizeof(int)); *instance = 1; AI_dictAdd(mangled_tensors, (void *)demangled_key, (void *)instance); - RedisModule_Free(demangled_key); + RedisModule_FreeString(NULL, demangled_key); entry = AI_dictNext(iter); } AI_dictReleaseIterator(iter); @@ -1140,7 +1137,7 @@ static int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int RedisModuleString **mangled_inkeys = array_new(RedisModuleString *, array_len(currentOp->inkeys)); for (long long j = 0; j < array_len(currentOp->inkeys); j++) { - const char *key = RedisModule_StringPtrLen(currentOp->inkeys[j], NULL); + RedisModuleString *key = currentOp->inkeys[j]; AI_dictEntry *entry = AI_dictFind(mangled_tensors, key); if (!entry) { AI_dictRelease(mangled_tensors); @@ -1148,15 +1145,17 @@ static int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int return REDISMODULE_ERR; } int *instance = AI_dictGetVal(entry); - RedisModuleString *mangled_key = - RedisModule_CreateStringPrintf(ctx, "%s%04d", key, *instance); + char buf[16]; + sprintf(buf, "%04d", *instance); + RedisModuleString *mangled_key = RedisModule_CreateStringFromString(NULL, key); + RedisModule_StringAppendBuffer(NULL, mangled_key, buf, strlen(buf)); mangled_inkeys = array_append(mangled_inkeys, mangled_key); } RedisModuleString **mangled_outkeys = array_new(RedisModuleString *, array_len(currentOp->outkeys)); for (long long j = 0; j < array_len(currentOp->outkeys); j++) { - const char *key = RedisModule_StringPtrLen(currentOp->outkeys[j], NULL); + RedisModuleString *key = currentOp->outkeys[j]; AI_dictEntry *entry = AI_dictFind(mangled_tensors, key); int *instance = NULL; if (entry) { @@ -1167,8 +1166,10 @@ static int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int *instance = 1; AI_dictAdd(mangled_tensors, (void *)key, (void *)instance); } - RedisModuleString *mangled_key = - RedisModule_CreateStringPrintf(ctx, "%s%04d", key, *instance); + char buf[16]; + sprintf(buf, "%04d", *instance); + RedisModuleString *mangled_key = RedisModule_CreateStringFromString(NULL, key); + RedisModule_StringAppendBuffer(NULL, mangled_key, buf, strlen(buf)); mangled_outkeys = array_append(mangled_outkeys, mangled_key); } @@ -1179,26 +1180,27 @@ static int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int currentOp->outkeys = mangled_outkeys; } - AI_dict *mangled_persisted = AI_dictCreate(&AI_dictTypeHeapStrings, NULL); + AI_dict *mangled_persisted = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL); { AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsPersistedContext); AI_dictEntry *entry = AI_dictNext(iter); while (entry) { - char *key = (char *)AI_dictGetKey(entry); + RedisModuleString *key = (RedisModuleString *)AI_dictGetKey(entry); AI_dictEntry *mangled_entry = AI_dictFind(mangled_tensors, key); if (!mangled_entry) { AI_dictRelease(mangled_tensors); AI_dictRelease(mangled_persisted); - RedisModule_ReplyWithError(ctx, "ERR PERSIST key cannot be found in DAG"); AI_dictReleaseIterator(iter); RedisModule_ReplyWithError(ctx, "ERR PERSIST key cannot be found in DAG"); return REDISMODULE_ERR; } int *instance = AI_dictGetVal(mangled_entry); - RedisModuleString *mangled_key = - RedisModule_CreateStringPrintf(ctx, "%s%04d", key, *instance); - const char *mangled_key_str = RedisModule_StringPtrLen(mangled_key, NULL); - AI_dictAdd(mangled_persisted, (void *)mangled_key_str, (void *)1); + char buf[16]; + sprintf(buf, "%04d", *instance); + RedisModuleString *mangled_key = RedisModule_CreateStringFromString(NULL, key); + RedisModule_StringAppendBuffer(NULL, mangled_key, buf, strlen(buf)); + + AI_dictAdd(mangled_persisted, (void *)mangled_key, (void *)1); entry = AI_dictNext(iter); } AI_dictReleaseIterator(iter); diff --git a/src/model.c b/src/model.c index 0e46471c4..a13261fe9 100644 --- a/src/model.c +++ b/src/model.c @@ -17,6 +17,7 @@ #include "stats.h" #include "util/arr_rm_alloc.h" #include "util/dict.h" +#include "util/string_utils.h" #include RedisModuleType *RedisAI_ModelType = NULL; @@ -31,7 +32,7 @@ static void *RAI_Model_RdbLoad(struct RedisModuleIO *io, int encver) { RAI_Backend backend = RedisModule_LoadUnsigned(io); const char *devicestr = RedisModule_LoadStringBuffer(io, NULL); - const char *tag = RedisModule_LoadStringBuffer(io, NULL); + RedisModuleString *tag = RedisModule_LoadString(io); const size_t batchsize = RedisModule_LoadUnsigned(io); const size_t minbatchsize = RedisModule_LoadUnsigned(io); @@ -113,7 +114,7 @@ static void *RAI_Model_RdbLoad(struct RedisModuleIO *io, int encver) { RedisModuleString *stats_keystr = RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io)); const char *stats_devicestr = RedisModule_Strdup(devicestr); - const char *stats_tag = RedisModule_Strdup(tag); + RedisModuleString *stats_tag = RAI_HoldString(NULL, tag); model->infokey = RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_MODEL, backend, stats_devicestr, stats_tag); @@ -143,7 +144,7 @@ static void RAI_Model_RdbSave(RedisModuleIO *io, void *value) { RedisModule_SaveUnsigned(io, model->backend); RedisModule_SaveStringBuffer(io, model->devicestr, strlen(model->devicestr) + 1); - RedisModule_SaveStringBuffer(io, model->tag, strlen(model->tag) + 1); + RedisModule_SaveString(io, model->tag); RedisModule_SaveUnsigned(io, model->opts.batchsize); RedisModule_SaveUnsigned(io, model->opts.minbatchsize); RedisModule_SaveUnsigned(io, model->ninputs); @@ -221,7 +222,7 @@ static void RAI_Model_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, voi const char *backendstr = RAI_BackendName(model->backend); - RedisModule_EmitAOF(aof, "AI.MODELSET", "slccclclcvcvcv", key, backendstr, model->devicestr, + RedisModule_EmitAOF(aof, "AI.MODELSET", "sccsclclcvcvcv", key, backendstr, model->devicestr, model->tag, "BATCHSIZE", model->opts.batchsize, "MINBATCHSIZE", model->opts.minbatchsize, "INPUTS", inputs_, model->ninputs, "OUTPUTS", outputs_, model->noutputs, "BLOB", buffers_, n_chunks); @@ -285,7 +286,7 @@ int RAI_ModelInit(RedisModuleCtx *ctx) { return RedisAI_ModelType != NULL; } -RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, const char *tag, +RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModuleString *tag, RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs, const char **outputs, const char *modeldef, size_t modellen, RAI_Error *err) { @@ -321,7 +322,11 @@ RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, const cha } if (model) { - model->tag = RedisModule_Strdup(tag); + if (tag) { + model->tag = RAI_HoldString(NULL, tag); + } else { + model->tag = RedisModule_CreateString(NULL, "", 0); + } } return model; @@ -361,7 +366,7 @@ void RAI_ModelFree(RAI_Model *model, RAI_Error *err) { return; } - RedisModule_Free(model->tag); + RedisModule_FreeString(NULL, model->tag); RAI_RemoveStatsEntry(model->infokey); @@ -588,12 +593,12 @@ int RedisAI_Parse_ModelRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString * is_input = 1; outputs_flag_count = 1; } else { - RedisModule_RetainString(ctx, argv[argpos]); + RedisModuleString *arg = RAI_HoldString(ctx, argv[argpos]); if (is_input == 0) { - *inkeys = array_append(*inkeys, argv[argpos]); + *inkeys = array_append(*inkeys, arg); ninputs++; } else { - *outkeys = array_append(*outkeys, argv[argpos]); + *outkeys = array_append(*outkeys, arg); noutputs++; } } diff --git a/src/model.h b/src/model.h index 719cc5ab5..51cc992c6 100644 --- a/src/model.h +++ b/src/model.h @@ -49,7 +49,7 @@ int RAI_ModelInit(RedisModuleCtx *ctx); * failures * @return RAI_Model model structure on success, or NULL if failed */ -RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, const char *tag, +RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModuleString *tag, RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs, const char **outputs, const char *modeldef, size_t modellen, RAI_Error *err); diff --git a/src/model_struct.h b/src/model_struct.h index 2a5758a38..1cfdb5879 100644 --- a/src/model_struct.h +++ b/src/model_struct.h @@ -22,7 +22,7 @@ typedef struct RAI_Model { void *session; RAI_Backend backend; char *devicestr; - char *tag; + RedisModuleString *tag; RAI_ModelOpts opts; char **inputs; size_t ninputs; diff --git a/src/redisai.c b/src/redisai.c index c2d45f5c1..0205cc954 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -26,6 +26,7 @@ #include "run_info.h" #include "util/arr_rm_alloc.h" #include "util/dict.h" +#include "util/string_utils.h" #include "util/queue.h" #include "version.h" @@ -184,9 +185,9 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, return RedisModule_ReplyWithError(ctx, "ERR Invalid DEVICE"); } - const char *tag = ""; + RedisModuleString *tag = NULL; if (AC_AdvanceIfMatch(&ac, "TAG")) { - AC_GetString(&ac, &tag, NULL, 0); + AC_GetRString(&ac, &tag, 0); } unsigned long long batchsize = 0; @@ -470,7 +471,8 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, RedisModule_ReplyWithCString(ctx, mto->devicestr); RedisModule_ReplyWithCString(ctx, "tag"); - RedisModule_ReplyWithCString(ctx, mto->tag ? mto->tag : ""); + RedisModuleString *empty_tag = RedisModule_CreateString(ctx, "", 0); + RedisModule_ReplyWithString(ctx, mto->tag ? mto->tag : empty_tag); RedisModule_ReplyWithCString(ctx, "batchsize"); RedisModule_ReplyWithLongLong(ctx, (long)mto->opts.batchsize); @@ -539,7 +541,7 @@ int RedisAI_ModelScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv long long nkeys; RedisModuleString **keys; - const char **tags; + RedisModuleString **tags; RAI_ListStatsEntries(RAI_MODEL, &nkeys, &keys, &tags); RedisModule_ReplyWithArray(ctx, nkeys); @@ -547,7 +549,7 @@ int RedisAI_ModelScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv for (long long i = 0; i < nkeys; i++) { RedisModule_ReplyWithArray(ctx, 2); RedisModule_ReplyWithString(ctx, keys[i]); - RedisModule_ReplyWithCString(ctx, tags[i]); + RedisModule_ReplyWithString(ctx, tags[i]); } RedisModule_Free(keys); @@ -633,7 +635,7 @@ int RedisAI_ScriptGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv RedisModule_ReplyWithCString(ctx, "device"); RedisModule_ReplyWithCString(ctx, sto->devicestr); RedisModule_ReplyWithCString(ctx, "tag"); - RedisModule_ReplyWithCString(ctx, sto->tag); + RedisModule_ReplyWithString(ctx, sto->tag); if (source) { RedisModule_ReplyWithCString(ctx, "source"); RedisModule_ReplyWithCString(ctx, sto->scriptdef); @@ -682,9 +684,9 @@ int RedisAI_ScriptSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv const char *devicestr; AC_GetString(&ac, &devicestr, NULL, 0); - const char *tag = ""; + RedisModuleString *tag = NULL; if (AC_AdvanceIfMatch(&ac, "TAG")) { - AC_GetString(&ac, &tag, NULL, 0); + AC_GetRString(&ac, &tag, 0); } if (AC_IsAtEnd(&ac)) { @@ -780,7 +782,7 @@ int RedisAI_ScriptScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg long long nkeys; RedisModuleString **keys; - const char **tags; + RedisModuleString **tags; RAI_ListStatsEntries(RAI_SCRIPT, &nkeys, &keys, &tags); RedisModule_ReplyWithArray(ctx, nkeys); @@ -788,7 +790,7 @@ int RedisAI_ScriptScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg for (long long i = 0; i < nkeys; i++) { RedisModule_ReplyWithArray(ctx, 2); RedisModule_ReplyWithString(ctx, keys[i]); - RedisModule_ReplyWithCString(ctx, tags[i]); + RedisModule_ReplyWithString(ctx, tags[i]); } RedisModule_Free(keys); @@ -803,7 +805,7 @@ int RedisAI_ScriptScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg int RedisAI_Info_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { if (argc != 2 && argc != 3) return RedisModule_WrongArity(ctx); - const char *runkey = RedisModule_StringPtrLen(argv[1], NULL); + RedisModuleString *runkey = argv[1]; struct RedisAI_RunStats *rstats = NULL; if (RAI_GetRunStats(runkey, &rstats) == REDISMODULE_ERR) { return RedisModule_ReplyWithError(ctx, "ERR cannot find run info for key"); @@ -833,7 +835,11 @@ int RedisAI_Info_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int RedisModule_ReplyWithCString(ctx, "device"); RedisModule_ReplyWithCString(ctx, rstats->devicestr); RedisModule_ReplyWithCString(ctx, "tag"); - RedisModule_ReplyWithCString(ctx, rstats->tag); + if (rstats->tag) { + RedisModule_ReplyWithString(ctx, rstats->tag); + } else { + RedisModule_ReplyWithCString(ctx, ""); + } RedisModule_ReplyWithCString(ctx, "duration"); RedisModule_ReplyWithLongLong(ctx, rstats->duration_us); RedisModule_ReplyWithCString(ctx, "samples"); @@ -1209,9 +1215,7 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_ERR; } - run_stats = AI_dictCreate(&AI_dictTypeHeapStrings, NULL); + run_stats = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL); return REDISMODULE_OK; } - -extern AI_dictType AI_dictTypeHeapStrings; diff --git a/src/run_info.c b/src/run_info.c index bfe20be41..9dbb96ab5 100644 --- a/src/run_info.c +++ b/src/run_info.c @@ -15,33 +15,19 @@ #include "tensor.h" #include "util/arr_rm_alloc.h" #include "util/dict.h" - -static uint64_t RAI_TensorDictKeyHashFunction(const void *key) { - return AI_dictGenHashFunction(key, strlen((char *)key)); -} - -static int RAI_TensorDictKeyStrcmp(void *privdata, const void *key1, const void *key2) { - const char *strKey1 = key1; - const char *strKey2 = key2; - return strcmp(strKey1, strKey2) == 0; -} - -static void RAI_TensorDictKeyFree(void *privdata, void *key) { RedisModule_Free(key); } - -static void *RAI_TensorDictKeyDup(void *privdata, const void *key) { - return RedisModule_Strdup((char *)key); -} +#include "util/string_utils.h" +#include static void RAI_TensorDictValFree(void *privdata, void *obj) { return RAI_TensorFree((RAI_Tensor *)obj); } AI_dictType AI_dictTypeTensorVals = { - .hashFunction = RAI_TensorDictKeyHashFunction, - .keyDup = RAI_TensorDictKeyDup, + .hashFunction = RAI_RStringsHashFunction, + .keyDup = RAI_RStringsKeyDup, .valDup = NULL, - .keyCompare = RAI_TensorDictKeyStrcmp, - .keyDestructor = RAI_TensorDictKeyFree, + .keyCompare = RAI_RStringsKeyCompare, + .keyDestructor = RAI_RStringsKeyDestructor, .valDestructor = RAI_TensorDictValFree, }; @@ -105,11 +91,11 @@ int RAI_InitRunInfo(RedisAI_RunInfo **result) { if (!(rinfo->dagTensorsContext)) { return REDISMODULE_ERR; } - rinfo->dagTensorsLoadedContext = AI_dictCreate(&AI_dictTypeHeapStrings, NULL); + rinfo->dagTensorsLoadedContext = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL); if (!(rinfo->dagTensorsLoadedContext)) { return REDISMODULE_ERR; } - rinfo->dagTensorsPersistedContext = AI_dictCreate(&AI_dictTypeHeapStrings, NULL); + rinfo->dagTensorsPersistedContext = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL); if (!(rinfo->dagTensorsPersistedContext)) { return REDISMODULE_ERR; } diff --git a/src/script.c b/src/script.c index 75ef0c0a2..ec75bdfdd 100644 --- a/src/script.c +++ b/src/script.c @@ -13,6 +13,7 @@ #include "script_struct.h" #include "stats.h" #include "util/arr_rm_alloc.h" +#include "util/string_utils.h" #include "version.h" #include @@ -28,7 +29,7 @@ static void *RAI_Script_RdbLoad(struct RedisModuleIO *io, int encver) { RAI_Error err = {0}; const char *devicestr = RedisModule_LoadStringBuffer(io, NULL); - const char *tag = RedisModule_LoadStringBuffer(io, NULL); + RedisModuleString *tag = RedisModule_LoadString(io); size_t len; char *scriptdef = RedisModule_LoadStringBuffer(io, &len); @@ -58,12 +59,13 @@ static void *RAI_Script_RdbLoad(struct RedisModuleIO *io, int encver) { RedisModuleString *stats_keystr = RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io)); const char *stats_devicestr = RedisModule_Strdup(devicestr); - const char *stats_tag = RedisModule_Strdup(tag); + + tag = RAI_HoldString(NULL, tag); script->infokey = RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_SCRIPT, RAI_BACKEND_TORCH, - stats_devicestr, stats_tag); + stats_devicestr, tag); - RedisModule_Free(stats_keystr); + RedisModule_FreeString(NULL, stats_keystr); return script; } @@ -74,14 +76,14 @@ static void RAI_Script_RdbSave(RedisModuleIO *io, void *value) { size_t len = strlen(script->scriptdef) + 1; RedisModule_SaveStringBuffer(io, script->devicestr, strlen(script->devicestr) + 1); - RedisModule_SaveStringBuffer(io, script->tag, strlen(script->tag) + 1); + RedisModule_SaveString(io, script->tag); RedisModule_SaveStringBuffer(io, script->scriptdef, len); } static void RAI_Script_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, void *value) { RAI_Script *script = (RAI_Script *)value; - RedisModule_EmitAOF(aof, "AI.SCRIPTSET", "scccc", key, script->devicestr, script->tag, "SOURCE", + RedisModule_EmitAOF(aof, "AI.SCRIPTSET", "scscc", key, script->devicestr, script->tag, "SOURCE", script->scriptdef); } @@ -107,7 +109,7 @@ int RAI_ScriptInit(RedisModuleCtx *ctx) { return RedisAI_ScriptType != NULL; } -RAI_Script *RAI_ScriptCreate(const char *devicestr, const char *tag, const char *scriptdef, +RAI_Script *RAI_ScriptCreate(const char *devicestr, RedisModuleString *tag, const char *scriptdef, RAI_Error *err) { if (!RAI_backends.torch.script_create) { RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TORCH"); @@ -116,7 +118,11 @@ RAI_Script *RAI_ScriptCreate(const char *devicestr, const char *tag, const char RAI_Script *script = RAI_backends.torch.script_create(devicestr, scriptdef, err); if (script) { - script->tag = RedisModule_Strdup(tag); + if (tag) { + script->tag = RAI_HoldString(NULL, tag); + } else { + script->tag = RedisModule_CreateString(NULL, "", 0); + } } return script; @@ -132,7 +138,7 @@ void RAI_ScriptFree(RAI_Script *script, RAI_Error *err) { return; } - RedisModule_Free(script->tag); + RedisModule_FreeString(NULL, script->tag); RAI_RemoveStatsEntry(script->infokey); diff --git a/src/script.h b/src/script.h index 3a00436e7..ec43e42a4 100644 --- a/src/script.h +++ b/src/script.h @@ -36,7 +36,7 @@ int RAI_ScriptInit(RedisModuleCtx *ctx); * failures * @return RAI_Script script structure on success, or NULL if failed */ -RAI_Script *RAI_ScriptCreate(const char *devicestr, const char *tag, const char *scriptdef, +RAI_Script *RAI_ScriptCreate(const char *devicestr, RedisModuleString *tag, const char *scriptdef, RAI_Error *err); /** diff --git a/src/script_struct.h b/src/script_struct.h index 71dfcb9cc..7efd1f5be 100644 --- a/src/script_struct.h +++ b/src/script_struct.h @@ -12,7 +12,7 @@ typedef struct RAI_Script { // We keep it here at the moment, until we have a // CUDA allocator for dlpack char *devicestr; - char *tag; + RedisModuleString *tag; long long refCount; void *infokey; } RAI_Script; diff --git a/src/stats.c b/src/stats.c index 0c2510bd4..8d3a0d1de 100644 --- a/src/stats.c +++ b/src/stats.c @@ -8,6 +8,7 @@ */ #include "stats.h" +#include "util/string_utils.h" #include @@ -24,31 +25,28 @@ long long ustime(void) { mstime_t mstime(void) { return ustime() / 1000; } void *RAI_AddStatsEntry(RedisModuleCtx *ctx, RedisModuleString *key, RAI_RunType runtype, - RAI_Backend backend, const char *devicestr, const char *tag) { - const char *infokey = RedisModule_StringPtrLen(key, NULL); - + RAI_Backend backend, const char *devicestr, RedisModuleString *tag) { struct RedisAI_RunStats *rstats = NULL; rstats = RedisModule_Calloc(1, sizeof(struct RedisAI_RunStats)); - RedisModule_RetainString(ctx, key); - rstats->key = key; + rstats->key = RAI_HoldString(NULL, key); rstats->type = runtype; rstats->backend = backend; rstats->devicestr = RedisModule_Strdup(devicestr); - rstats->tag = RedisModule_Strdup(tag); + rstats->tag = RAI_HoldString(NULL, tag); - AI_dictAdd(run_stats, (void *)infokey, (void *)rstats); + AI_dictAdd(run_stats, (void *)key, (void *)rstats); - return (void *)infokey; + return (void *)key; } void RAI_ListStatsEntries(RAI_RunType type, long long *nkeys, RedisModuleString ***keys, - const char ***tags) { + RedisModuleString ***tags) { AI_dictIterator *stats_iter = AI_dictGetSafeIterator(run_stats); long long stats_size = AI_dictSize(run_stats); *keys = RedisModule_Calloc(stats_size, sizeof(RedisModuleString *)); - *tags = RedisModule_Calloc(stats_size, sizeof(const char *)); + *tags = RedisModule_Calloc(stats_size, sizeof(RedisModuleString *)); *nkeys = 0; @@ -109,13 +107,13 @@ void RAI_FreeRunStats(struct RedisAI_RunStats *rstats) { RedisModule_Free(rstats->devicestr); } if (rstats->tag) { - RedisModule_Free(rstats->tag); + RedisModule_FreeString(NULL, rstats->tag); } RedisModule_Free(rstats); } } -int RAI_GetRunStats(const char *runkey, struct RedisAI_RunStats **rstats) { +int RAI_GetRunStats(RedisModuleString *runkey, struct RedisAI_RunStats **rstats) { int result = 1; if (run_stats == NULL) { return result; diff --git a/src/stats.h b/src/stats.h index e524e9144..97a6af0ca 100644 --- a/src/stats.h +++ b/src/stats.h @@ -21,7 +21,7 @@ struct RedisAI_RunStats { RAI_RunType type; RAI_Backend backend; char *devicestr; - char *tag; + RedisModuleString *tag; long long duration_us; long long samples; long long calls; @@ -47,7 +47,7 @@ mstime_t mstime(void); * @return */ void *RAI_AddStatsEntry(RedisModuleCtx *ctx, RedisModuleString *key, RAI_RunType type, - RAI_Backend backend, const char *devicestr, const char *tag); + RAI_Backend backend, const char *devicestr, RedisModuleString *tag); /** * Removes the statistical entry with the provided unique stats identifier @@ -67,7 +67,7 @@ void RAI_RemoveStatsEntry(void *infokey); * @param tags output variable containing the list of returned tags */ void RAI_ListStatsEntries(RAI_RunType type, long long *nkeys, RedisModuleString ***keys, - const char ***tags); + RedisModuleString ***tags); /** * @@ -97,7 +97,7 @@ void RAI_FreeRunStats(struct RedisAI_RunStats *rstats); * @param rstats * @return 0 on success, or 1 if the the run stats with runkey does not exist */ -int RAI_GetRunStats(const char *runkey, struct RedisAI_RunStats **rstats); +int RAI_GetRunStats(RedisModuleString *runkey, struct RedisAI_RunStats **rstats); void RedisAI_FreeRunStats(RedisModuleCtx *ctx, struct RedisAI_RunStats *rstats); diff --git a/src/tensor.c b/src/tensor.c index d60a67cb7..e8dc0300b 100644 --- a/src/tensor.c +++ b/src/tensor.c @@ -681,7 +681,7 @@ int RAI_GetTensorFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, R * Return REDISMODULE_OK if the tensor value is present at the localContextDict. */ int RAI_getTensorFromLocalContext(RedisModuleCtx *ctx, AI_dict *localContextDict, - const char *localContextKey, RAI_Tensor **tensor, + RedisModuleString *localContextKey, RAI_Tensor **tensor, RAI_Error *error) { int result = REDISMODULE_ERR; AI_dictEntry *tensor_entry = AI_dictFind(localContextDict, localContextKey); diff --git a/src/tensor.h b/src/tensor.h index 54b885205..f6beab986 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -337,7 +337,7 @@ int RAI_GetTensorFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, R * @return REDISMODULE_OK on success, or REDISMODULE_ERR if failed */ int RAI_getTensorFromLocalContext(RedisModuleCtx *ctx, AI_dict *localContextDict, - const char *localContextKey, RAI_Tensor **tensor, + RedisModuleString *localContextKey, RAI_Tensor **tensor, RAI_Error *error); /** diff --git a/src/util/dict.c b/src/util/dict.c index b61372bd6..59843a276 100644 --- a/src/util/dict.c +++ b/src/util/dict.c @@ -49,39 +49,6 @@ #include "siphash.c.inc" -static uint64_t stringsHashFunction(const void *key) { - return AI_dictGenHashFunction(key, strlen((char *)key)); -} - -static int stringsKeyCompare(void *privdata, const void *key1, const void *key2) { - const char *strKey1 = key1; - const char *strKey2 = key2; - - return strcmp(strKey1, strKey2) == 0; -} - -static void stringsKeyDestructor(void *privdata, void *key) { RA_FREE(key); } - -static void *stringsKeyDup(void *privdata, const void *key) { return RA_STRDUP((char *)key); } - -AI_dictType AI_dictTypeHeapStringsVals = { - .hashFunction = stringsHashFunction, - .keyDup = stringsKeyDup, - .valDup = NULL, - .keyCompare = stringsKeyCompare, - .keyDestructor = stringsKeyDestructor, - .valDestructor = stringsKeyDestructor, -}; - -AI_dictType AI_dictTypeHeapStrings = { - .hashFunction = stringsHashFunction, - .keyDup = stringsKeyDup, - .valDup = NULL, - .keyCompare = stringsKeyCompare, - .keyDestructor = stringsKeyDestructor, - .valDestructor = NULL, -}; - /* Using dictEnableResize() / dictDisableResize() we make possible to * enable/disable resizing of the hash table as needed. This is very important * for Redis, as we use copy-on-write and don't want to move too much memory diff --git a/src/util/dict.h b/src/util/dict.h index 1f72bb06c..74f149772 100644 --- a/src/util/dict.h +++ b/src/util/dict.h @@ -189,7 +189,4 @@ unsigned long AI_dictScan(AI_dict *d, unsigned long v, AI_dictScanFunction *fn, uint64_t AI_dictGetHash(AI_dict *d, const void *key); AI_dictEntry **AI_dictFindEntryRefByPtrAndHash(AI_dict *d, const void *oldptr, uint64_t hash); -extern AI_dictType AI_dictTypeHeapStrings; -extern AI_dictType AI_dictTypeHeapStringsVals; - #endif /* __DICT_H */ diff --git a/src/util/string_utils.c b/src/util/string_utils.c new file mode 100644 index 000000000..904c4af94 --- /dev/null +++ b/src/util/string_utils.c @@ -0,0 +1,90 @@ +#include "string_utils.h" +#include "dict.h" +#include +#include "../redisai_memory.h" + +RedisModuleString *RAI_HoldString(RedisModuleCtx *ctx, RedisModuleString *str) { + if (str == NULL) { + return NULL; + } + RedisModuleString *out; + if (RMAPI_FUNC_SUPPORTED(RedisModule_HoldString)) { + out = RedisModule_HoldString(NULL, str); + } else { + RedisModule_RetainString(NULL, str); + out = str; + } + return out; +} + +uint64_t RAI_StringsHashFunction(const void *key) { + return AI_dictGenHashFunction(key, strlen((char *)key)); +} + +int RAI_StringsKeyCompare(void *privdata, const void *key1, const void *key2) { + const char *strKey1 = key1; + const char *strKey2 = key2; + + return strcmp(strKey1, strKey2) == 0; +} + +void RAI_StringsKeyDestructor(void *privdata, void *key) { RA_FREE(key); } + +void *RAI_StringsKeyDup(void *privdata, const void *key) { return RA_STRDUP((char *)key); } + +AI_dictType AI_dictTypeHeapStringsVals = { + .hashFunction = RAI_StringsHashFunction, + .keyDup = RAI_StringsKeyDup, + .valDup = NULL, + .keyCompare = RAI_StringsKeyCompare, + .keyDestructor = RAI_StringsKeyDestructor, + .valDestructor = RAI_StringsKeyDestructor, +}; + +AI_dictType AI_dictTypeHeapStrings = { + .hashFunction = RAI_StringsHashFunction, + .keyDup = RAI_StringsKeyDup, + .valDup = NULL, + .keyCompare = RAI_StringsKeyCompare, + .keyDestructor = RAI_StringsKeyDestructor, + .valDestructor = NULL, +}; + +uint64_t RAI_RStringsHashFunction(const void *key) { + size_t len; + const char *buffer = RedisModule_StringPtrLen((RedisModuleString *)key, &len); + return AI_dictGenHashFunction(buffer, len); +} + +int RAI_RStringsKeyCompare(void *privdata, const void *key1, const void *key2) { + RedisModuleString *strKey1 = (RedisModuleString *)key1; + RedisModuleString *strKey2 = (RedisModuleString *)key2; + + return RedisModule_StringCompare(strKey1, strKey2) == 0; +} + +void RAI_RStringsKeyDestructor(void *privdata, void *key) { + RedisModule_FreeString(NULL, (RedisModuleString *)key); +} + +void *RAI_RStringsKeyDup(void *privdata, const void *key) { + return RedisModule_CreateStringFromString(NULL, (RedisModuleString *)key); +} + +AI_dictType AI_dictTypeHeapRStringsVals = { + .hashFunction = RAI_RStringsHashFunction, + .keyDup = RAI_RStringsKeyDup, + .valDup = NULL, + .keyCompare = RAI_RStringsKeyCompare, + .keyDestructor = RAI_RStringsKeyDestructor, + .valDestructor = RAI_RStringsKeyDestructor, +}; + +AI_dictType AI_dictTypeHeapRStrings = { + .hashFunction = RAI_RStringsHashFunction, + .keyDup = RAI_RStringsKeyDup, + .valDup = NULL, + .keyCompare = RAI_RStringsKeyCompare, + .keyDestructor = RAI_RStringsKeyDestructor, + .valDestructor = NULL, +}; diff --git a/src/util/string_utils.h b/src/util/string_utils.h new file mode 100644 index 000000000..53338e9ef --- /dev/null +++ b/src/util/string_utils.h @@ -0,0 +1,20 @@ +#include "redismodule.h" +#include "dict.h" + +RedisModuleString *RAI_HoldString(RedisModuleCtx *ctx, RedisModuleString *str); + +uint64_t RAI_StringsHashFunction(const void *key); +int RAI_StringsKeyCompare(void *privdata, const void *key1, const void *key2); +void RAI_StringsKeyDestructor(void *privdata, void *key); +void *RAI_StringsKeyDup(void *privdata, const void *key); + +uint64_t RAI_RStringsHashFunction(const void *key); +int RAI_RStringsKeyCompare(void *privdata, const void *key1, const void *key2); +void RAI_RStringsKeyDestructor(void *privdata, void *key); +void *RAI_RStringsKeyDup(void *privdata, const void *key); + +extern AI_dictType AI_dictTypeHeapStrings; +extern AI_dictType AI_dictTypeHeapStringsVals; + +extern AI_dictType AI_dictTypeHeapRStrings; +extern AI_dictType AI_dictTypeHeapRStringsVals; diff --git a/tests/flow/tests_dag.py b/tests/flow/tests_dag.py index 55f741da9..d84be1d1c 100644 --- a/tests/flow/tests_dag.py +++ b/tests/flow/tests_dag.py @@ -1056,7 +1056,8 @@ def test_dagrun_modelrun_multidevice_resnet_ensemble_alias(env): try: ret = con.execute_command( 'AI.DAGRUN', - 'PERSIST', '1', class_key_0, '|>', + 'PERSIST', '1', class_key_0, + '|>', 'AI.TENSORSET', image_key, 'UINT8', img.shape[1], img.shape[0], 3, 'BLOB', img.tobytes(), '|>', 'AI.SCRIPTRUN', script_name_0, 'pre_process_3ch', @@ -1117,17 +1118,18 @@ def test_dagrun_modelrun_multidevice_resnet_ensemble_alias(env): ret = con.execute_command( 'AI.DAGRUN', - 'PERSIST', '1', class_key_0, '|>', + 'PERSIST', '1', class_key_0, + '|>', 'AI.TENSORSET', image_key, 'UINT8', img.shape[1], img.shape[0], 3, 'BLOB', img.tobytes(), '|>', 'AI.SCRIPTRUN', script_name_0, 'pre_process_3ch', 'INPUTS', image_key, 'OUTPUTS', temp_key1, - '|>', + '|>', 'AI.MODELRUN', model_name_0, 'INPUTS', temp_key1, 'OUTPUTS', temp_key2_0, - '|>', + '|>', 'AI.MODELRUN', model_name_1, 'INPUTS', temp_key1, 'OUTPUTS', temp_key2_1,