Skip to content

Refactor RAI_getModel/ScriptFromKeyspace #573

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/DAG/dag_builder.c
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ int RAI_DAGAddRunOp(RAI_DAGRunCtx *run_info, RAI_DAGRunOp *DAGop, RAI_Error *err
return REDISMODULE_OK;
}

int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err) {
int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name) {

RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
RAI_DagOp *op;
Expand Down
2 changes: 1 addition & 1 deletion src/DAG/dag_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ int RAI_DAGAddTensorSet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor
* @param runInfo The DAG to append this op into.
* @param tensor The tensor to set.
*/
int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err);
int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name);

/**
* @brief Add ops to a DAG from string (according to the command syntax). In case of a valid
Expand Down
7 changes: 2 additions & 5 deletions src/command_parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModule
return REDISMODULE_ERR;
}
size_t argpos = 1;
RedisModuleKey *modelKey;
const int status =
RAI_GetModelFromKeyspace(ctx, argv[argpos], &modelKey, model, REDISMODULE_READ, error);
const int status = RAI_GetModelFromKeyspace(ctx, argv[argpos], model, REDISMODULE_READ, error);
if (status == REDISMODULE_ERR) {
return REDISMODULE_ERR;
}
Expand Down Expand Up @@ -172,9 +170,8 @@ static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **
return REDISMODULE_ERR;
}
size_t argpos = 1;
RedisModuleKey *scriptKey;
const int status =
RAI_GetScriptFromKeyspace(ctx, argv[argpos], &scriptKey, script, REDISMODULE_READ, error);
RAI_GetScriptFromKeyspace(ctx, argv[argpos], script, REDISMODULE_READ, error);
if (status == REDISMODULE_ERR) {
return REDISMODULE_ERR;
}
Expand Down
18 changes: 9 additions & 9 deletions src/model.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@
/* Return REDISMODULE_ERR if there was an error getting the Model.
* Return REDISMODULE_OK if the model value stored at key was correctly
* returned and available at *model variable. */
int RAI_GetModelFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RedisModuleKey **key,
RAI_Model **model, int mode, RAI_Error *err) {
*key = RedisModule_OpenKey(ctx, keyName, mode);
if (RedisModule_KeyType(*key) == REDISMODULE_KEYTYPE_EMPTY) {
RedisModule_CloseKey(*key);
int RAI_GetModelFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RAI_Model **model,
int mode, RAI_Error *err) {
RedisModuleKey *key = RedisModule_OpenKey(ctx, keyName, mode);
if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) {
RedisModule_CloseKey(key);
RAI_SetError(err, RAI_EMODELRUN, "ERR model key is empty");
return REDISMODULE_ERR;
}
if (RedisModule_ModuleTypeGetType(*key) != RedisAI_ModelType) {
RedisModule_CloseKey(*key);
if (RedisModule_ModuleTypeGetType(key) != RedisAI_ModelType) {
RedisModule_CloseKey(key);
RAI_SetError(err, RAI_EMODELRUN, REDISMODULE_ERRORMSG_WRONGTYPE);
return REDISMODULE_ERR;
}
*model = RedisModule_ModuleTypeGetValue(*key);
RedisModule_CloseKey(*key);
*model = RedisModule_ModuleTypeGetValue(key);
RedisModule_CloseKey(key);
return REDISMODULE_OK;
}

Expand Down
6 changes: 2 additions & 4 deletions src/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,15 @@ int RAI_ModelSerialize(RAI_Model *model, char **buffer, size_t *len, RAI_Error *
*
* @param ctx Context in which Redis modules operate
* @param keyName key name
* @param key models's key handle. On success it contains an handle representing
* a Redis key with the requested access mode
* @param model destination model structure
* @param mode key access mode
* @param error contains the error in case of problem with retrival
* @return REDISMODULE_OK if the model value stored at key was correctly
* returned and available at *model variable, or REDISMODULE_ERR if there was
* an error getting the Model
*/
int RAI_GetModelFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RedisModuleKey **key,
RAI_Model **model, int mode, RAI_Error *err);
int RAI_GetModelFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RAI_Model **model,
int mode, RAI_Error *err);

/**
* When a module command is called in order to obtain the position of
Expand Down
18 changes: 7 additions & 11 deletions src/redisai.c
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,7 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,

RAI_Error err = {0};
RAI_Model *mto;
RedisModuleKey *key;
const int status = RAI_GetModelFromKeyspace(ctx, argv[1], &key, &mto, REDISMODULE_READ, &err);
const int status = RAI_GetModelFromKeyspace(ctx, argv[1], &mto, REDISMODULE_READ, &err);
if (status == REDISMODULE_ERR) {
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(&err));
RAI_ClearError(&err);
Expand Down Expand Up @@ -521,17 +520,16 @@ int RedisAI_ModelDel_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
return RedisModule_WrongArity(ctx);

RAI_Model *mto;
RedisModuleKey *key;
RAI_Error err = {0};
const int status = RAI_GetModelFromKeyspace(ctx, argv[1], &key, &mto,
REDISMODULE_READ | REDISMODULE_WRITE, &err);
const int status =
RAI_GetModelFromKeyspace(ctx, argv[1], &mto, REDISMODULE_READ | REDISMODULE_WRITE, &err);
if (status == REDISMODULE_ERR) {
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(&err));
RAI_ClearError(&err);
return REDISMODULE_ERR;
}

key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_WRITE);
RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_WRITE);
RedisModule_DeleteKey(key);
RedisModule_CloseKey(key);
RedisModule_ReplicateVerbatim(ctx);
Expand Down Expand Up @@ -605,9 +603,8 @@ int RedisAI_ScriptGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
return RedisModule_WrongArity(ctx);

RAI_Script *sto;
RedisModuleKey *key;
RAI_Error err = {0};
const int status = RAI_GetScriptFromKeyspace(ctx, argv[1], &key, &sto, REDISMODULE_READ, &err);
const int status = RAI_GetScriptFromKeyspace(ctx, argv[1], &sto, REDISMODULE_READ, &err);
if (status == REDISMODULE_ERR) {
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(&err));
RAI_ClearError(&err);
Expand Down Expand Up @@ -656,15 +653,14 @@ int RedisAI_ScriptDel_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
return RedisModule_WrongArity(ctx);

RAI_Script *sto;
RedisModuleKey *key;
RAI_Error err = {0};
const int status = RAI_GetScriptFromKeyspace(ctx, argv[1], &key, &sto, REDISMODULE_WRITE, &err);
const int status = RAI_GetScriptFromKeyspace(ctx, argv[1], &sto, REDISMODULE_WRITE, &err);
if (status == REDISMODULE_ERR) {
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(&err));
RAI_ClearError(&err);
return REDISMODULE_ERR;
}
key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_WRITE);
RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_WRITE);
RedisModule_DeleteKey(key);
RedisModule_CloseKey(key);

Expand Down
4 changes: 1 addition & 3 deletions src/redisai.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ REDISAI_API void MODULE_API_FUNC(RedisAI_ModelFree)(RAI_Model *model, RAI_Error
REDISAI_API RAI_ModelRunCtx *MODULE_API_FUNC(RedisAI_ModelRunCtxCreate)(RAI_Model *model);
REDISAI_API int MODULE_API_FUNC(RedisAI_GetModelFromKeyspace)(RedisModuleCtx *ctx,
RedisModuleString *keyName,
RedisModuleKey **key,
RAI_Model **model, int mode,
RAI_Error *err);
REDISAI_API int MODULE_API_FUNC(RedisAI_ModelRunCtxAddInput)(RAI_ModelRunCtx *mctx,
Expand Down Expand Up @@ -136,7 +135,6 @@ REDISAI_API RAI_Script *MODULE_API_FUNC(RedisAI_ScriptCreate)(char *devicestr, c
RAI_Error *err);
REDISAI_API int MODULE_API_FUNC(RedisAI_GetScriptFromKeyspace)(RedisModuleCtx *ctx,
RedisModuleString *keyName,
RedisModuleKey **key,
RAI_Script **script, int mode,
RAI_Error *err);
REDISAI_API void MODULE_API_FUNC(RedisAI_ScriptFree)(RAI_Script *script, RAI_Error *err);
Expand Down Expand Up @@ -175,7 +173,7 @@ REDISAI_API int MODULE_API_FUNC(RedisAI_DAGLoadTensor)(RAI_DAGRunCtx *run_info,
REDISAI_API int MODULE_API_FUNC(RedisAI_DAGAddTensorSet)(RAI_DAGRunCtx *run_info,
const char *t_name, RAI_Tensor *tensor);
REDISAI_API int MODULE_API_FUNC(RedisAI_DAGAddTensorGet)(RAI_DAGRunCtx *run_info,
const char *t_name, RAI_Error *err);
const char *t_name);
REDISAI_API int MODULE_API_FUNC(RedisAI_DAGAddOpsFromString)(RAI_DAGRunCtx *run_info,
const char *dag, RAI_Error *err);
REDISAI_API size_t MODULE_API_FUNC(RedisAI_DAGNumOps)(RAI_DAGRunCtx *run_info);
Expand Down
18 changes: 9 additions & 9 deletions src/script.c
Original file line number Diff line number Diff line change
Expand Up @@ -154,21 +154,21 @@ RAI_Script *RAI_ScriptGetShallowCopy(RAI_Script *script) {
/* Return REDISMODULE_ERR if there was an error getting the Script.
* Return REDISMODULE_OK if the model value stored at key was correctly
* returned and available at *model variable. */
int RAI_GetScriptFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RedisModuleKey **key,
RAI_Script **script, int mode, RAI_Error *err) {
*key = RedisModule_OpenKey(ctx, keyName, mode);
if (RedisModule_KeyType(*key) == REDISMODULE_KEYTYPE_EMPTY) {
RedisModule_CloseKey(*key);
int RAI_GetScriptFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RAI_Script **script,
int mode, RAI_Error *err) {
RedisModuleKey *key = RedisModule_OpenKey(ctx, keyName, mode);
if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) {
RedisModule_CloseKey(key);
RAI_SetError(err, RAI_ESCRIPTRUN, "ERR script key is empty");
return REDISMODULE_ERR;
}
if (RedisModule_ModuleTypeGetType(*key) != RedisAI_ScriptType) {
RedisModule_CloseKey(*key);
if (RedisModule_ModuleTypeGetType(key) != RedisAI_ScriptType) {
RedisModule_CloseKey(key);
RAI_SetError(err, RAI_ESCRIPTRUN, REDISMODULE_ERRORMSG_WRONGTYPE);
return REDISMODULE_ERR;
}
*script = RedisModule_ModuleTypeGetValue(*key);
RedisModule_CloseKey(*key);
*script = RedisModule_ModuleTypeGetValue(key);
RedisModule_CloseKey(key);
return REDISMODULE_OK;
}

Expand Down
6 changes: 2 additions & 4 deletions src/script.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,14 @@ RAI_Script *RAI_ScriptGetShallowCopy(RAI_Script *script);
*
* @param ctx Context in which Redis modules operate
* @param keyName key name
* @param key script's key handle. On success it contains an handle representing
* a Redis key with the requested access mode
* @param script destination script structure
* @param mode key access mode
* @return REDISMODULE_OK if the script value stored at key was correctly
* returned and available at *script variable, or REDISMODULE_ERR if there was
* an error getting the Script
*/
int RAI_GetScriptFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RedisModuleKey **key,
RAI_Script **script, int mode, RAI_Error *err);
int RAI_GetScriptFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RAI_Script **script,
int mode, RAI_Error *err);

/**
* When a module command is called in order to obtain the position of
Expand Down
12 changes: 6 additions & 6 deletions tests/module/DAG_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ int testKeysMismatchError(RedisModuleCtx *ctx) {
RAI_Tensor *t = (RAI_Tensor *)_getFromKeySpace(ctx, "a{1}");
RedisAI_DAGLoadTensor(run_info, "input", t);

RedisAI_DAGAddTensorGet(run_info, "non existing tensor", err);
RedisAI_DAGAddTensorGet(run_info, "non existing tensor");
int status = RedisAI_DAGRun(run_info, _DAGFinishFuncError, NULL, err);
if(!_assertError(err, status, "ERR INPUT key cannot be found in DAG")) {
goto cleanup;
Expand Down Expand Up @@ -183,7 +183,7 @@ int testBuildDAGFromString(RedisModuleCtx *ctx) {
goto cleanup;
}
RedisModule_Assert(RedisAI_DAGNumOps(run_info) == 3);
RedisAI_DAGAddTensorGet(run_info, "input1", results.error);
RedisAI_DAGAddTensorGet(run_info, "input1");
RedisModule_Assert(RedisAI_DAGNumOps(run_info) == 4);

pthread_mutex_lock(&global_lock);
Expand Down Expand Up @@ -227,7 +227,7 @@ int testSimpleDAGRun(RedisModuleCtx *ctx) {
goto cleanup;
}

RedisAI_DAGAddTensorGet(run_info, "output", results.error);
RedisAI_DAGAddTensorGet(run_info, "output");
pthread_mutex_lock(&global_lock);
if (RedisAI_DAGRun(run_info, _DAGFinishFunc, &results, results.error) != REDISMODULE_OK) {
pthread_mutex_unlock(&global_lock);
Expand Down Expand Up @@ -280,7 +280,7 @@ int testSimpleDAGRun2(RedisModuleCtx *ctx) {
goto cleanup;
}

RedisAI_DAGAddTensorGet(run_info, "output", results.error);
RedisAI_DAGAddTensorGet(run_info, "output");
pthread_mutex_lock(&global_lock);
if (RedisAI_DAGRun(run_info, _DAGFinishFunc, &results, results.error) != REDISMODULE_OK) {
pthread_mutex_unlock(&global_lock);
Expand Down Expand Up @@ -330,7 +330,7 @@ int testSimpleDAGRun2Error(RedisModuleCtx *ctx) {
goto cleanup;
}

RedisAI_DAGAddTensorGet(run_info, "output", results.error);
RedisAI_DAGAddTensorGet(run_info, "output");
pthread_mutex_lock(&global_lock);
if (RedisAI_DAGRun(run_info, _DAGFinishFunc, &results, results.error) != REDISMODULE_OK) {
pthread_mutex_unlock(&global_lock);
Expand Down Expand Up @@ -392,7 +392,7 @@ int testDAGResnet(RedisModuleCtx *ctx) {
RedisAI_DAGRunOpAddOutput(script_op, "output:{{1}}");
RedisAI_DAGAddRunOp(run_info, script_op, results.error);

RedisAI_DAGAddTensorGet(run_info, "output:{{1}}", results.error);
RedisAI_DAGAddTensorGet(run_info, "output:{{1}}");

pthread_mutex_lock(&global_lock);
if (RedisAI_DAGRun(run_info, _DAGFinishFunc, &results, results.error) != REDISMODULE_OK) {
Expand Down