diff --git a/src/model.c b/src/model.c index 3a469ee78..401a27ead 100644 --- a/src/model.c +++ b/src/model.c @@ -645,3 +645,7 @@ int RedisAI_Parse_ModelRun_RedisCommand(RedisModuleCtx *ctx, } return argpos; } + +RedisModuleType *RAI_ModelRedisType(void) { + return RedisAI_ModelType; +} diff --git a/src/model.h b/src/model.h index 90b019dc4..cb9c98610 100644 --- a/src/model.h +++ b/src/model.h @@ -225,4 +225,11 @@ int RedisAI_Parse_ModelRun_RedisCommand( RAI_ModelRunCtx** mctx, RedisModuleString*** outkeys, RAI_Model** mto, int useLocalContext, AI_dict** localContextDict, int use_chaining_operator, const char* chaining_operator, RAI_Error* error); + +/** + * @brief Returns the redis module type representing a model. + * @return redis module type representing a model. + */ +RedisModuleType *RAI_ModelRedisType(void); + #endif /* SRC_MODEL_H_ */ diff --git a/src/redisai.c b/src/redisai.c index c9796f1ad..0c1f83c9d 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -966,6 +966,7 @@ static int RedisAI_RegisterApi(RedisModuleCtx* ctx) { REGISTER_API(TensorDim, ctx); REGISTER_API(TensorByteSize, ctx); REGISTER_API(TensorData, ctx); + REGISTER_API(TensorRedisType, ctx); REGISTER_API(ModelCreate, ctx); REGISTER_API(ModelFree, ctx); @@ -978,17 +979,20 @@ static int RedisAI_RegisterApi(RedisModuleCtx* ctx) { REGISTER_API(ModelRun, ctx); REGISTER_API(ModelSerialize, ctx); REGISTER_API(ModelGetShallowCopy, ctx); + REGISTER_API(ModelRedisType, ctx); REGISTER_API(ScriptCreate, ctx); REGISTER_API(ScriptFree, ctx); REGISTER_API(ScriptRunCtxCreate, ctx); REGISTER_API(ScriptRunCtxAddInput, ctx); + REGISTER_API(ScriptRunCtxAddInputList, ctx); REGISTER_API(ScriptRunCtxAddOutput, ctx); REGISTER_API(ScriptRunCtxNumOutputs, ctx); REGISTER_API(ScriptRunCtxOutputTensor, ctx); REGISTER_API(ScriptRunCtxFree, ctx); REGISTER_API(ScriptRun, ctx); REGISTER_API(ScriptGetShallowCopy, ctx); + REGISTER_API(ScriptRedisType, ctx); return REDISMODULE_OK; } diff --git a/src/redisai.h b/src/redisai.h index cdf63bb7f..b93b112f8 100644 --- a/src/redisai.h +++ b/src/redisai.h @@ -78,6 +78,7 @@ int MODULE_API_FUNC(RedisAI_TensorNumDims)(RAI_Tensor* t); long long MODULE_API_FUNC(RedisAI_TensorDim)(RAI_Tensor* t, int dim); size_t MODULE_API_FUNC(RedisAI_TensorByteSize)(RAI_Tensor* t); char* MODULE_API_FUNC(RedisAI_TensorData)(RAI_Tensor* t); +RedisModuleType* MODULE_API_FUNC(RedisAI_TensorRedisType)(void); RAI_Model* MODULE_API_FUNC(RedisAI_ModelCreate)(int backend, char* devicestr, char* tag, RAI_ModelOpts opts, size_t ninputs, const char **inputs, @@ -93,17 +94,20 @@ void MODULE_API_FUNC(RedisAI_ModelRunCtxFree)(RAI_ModelRunCtx* mctx); int MODULE_API_FUNC(RedisAI_ModelRun)(RAI_ModelRunCtx** mctx, long long n, RAI_Error* err); RAI_Model* MODULE_API_FUNC(RedisAI_ModelGetShallowCopy)(RAI_Model* model); int MODULE_API_FUNC(RedisAI_ModelSerialize)(RAI_Model *model, char **buffer, size_t *len, RAI_Error *err); +RedisModuleType* MODULE_API_FUNC(RedisAI_ModelRedisType)(void); RAI_Script* MODULE_API_FUNC(RedisAI_ScriptCreate)(char* devicestr, char* tag, const char* scriptdef, RAI_Error* err); void MODULE_API_FUNC(RedisAI_ScriptFree)(RAI_Script* script, RAI_Error* err); RAI_ScriptRunCtx* MODULE_API_FUNC(RedisAI_ScriptRunCtxCreate)(RAI_Script* script, const char *fnname); -int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddInput)(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor); +int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddInput)(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor, RAI_Error* err); +int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddInputList)(RAI_ScriptRunCtx* sctx, RAI_Tensor** inputTensors, size_t len, RAI_Error* err); int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddOutput)(RAI_ScriptRunCtx* sctx); size_t MODULE_API_FUNC(RedisAI_ScriptRunCtxNumOutputs)(RAI_ScriptRunCtx* sctx); RAI_Tensor* MODULE_API_FUNC(RedisAI_ScriptRunCtxOutputTensor)(RAI_ScriptRunCtx* sctx, size_t index); void MODULE_API_FUNC(RedisAI_ScriptRunCtxFree)(RAI_ScriptRunCtx* sctx); int MODULE_API_FUNC(RedisAI_ScriptRun)(RAI_ScriptRunCtx* sctx, RAI_Error* err); RAI_Script* MODULE_API_FUNC(RedisAI_ScriptGetShallowCopy)(RAI_Script* script); +RedisModuleType* MODULE_API_FUNC(RedisAI_ScriptRedisType)(void); int MODULE_API_FUNC(RedisAI_GetLLAPIVersion)(); @@ -145,6 +149,7 @@ static int RedisAI_Initialize(RedisModuleCtx* ctx){ REDISAI_MODULE_INIT_FUNCTION(ctx, TensorDim); REDISAI_MODULE_INIT_FUNCTION(ctx, TensorByteSize); REDISAI_MODULE_INIT_FUNCTION(ctx, TensorData); + REDISAI_MODULE_INIT_FUNCTION(ctx, TensorRedisType); REDISAI_MODULE_INIT_FUNCTION(ctx, ModelCreate); REDISAI_MODULE_INIT_FUNCTION(ctx, ModelFree); @@ -157,17 +162,20 @@ static int RedisAI_Initialize(RedisModuleCtx* ctx){ REDISAI_MODULE_INIT_FUNCTION(ctx, ModelRun); REDISAI_MODULE_INIT_FUNCTION(ctx, ModelGetShallowCopy); REDISAI_MODULE_INIT_FUNCTION(ctx, ModelSerialize); + REDISAI_MODULE_INIT_FUNCTION(ctx, ModelRedisType); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptCreate); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptFree); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxCreate); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddInput); + REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddInputList); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddOutput); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxNumOutputs); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxOutputTensor); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxFree); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRun); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptGetShallowCopy); + REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRedisType); if(RedisAI_GetLLAPIVersion() < REDISAI_LLAPI_VERSION){ return REDISMODULE_ERR; diff --git a/src/script.c b/src/script.c index 0c0fa4d7b..873d2ae7f 100644 --- a/src/script.c +++ b/src/script.c @@ -156,21 +156,40 @@ RAI_ScriptRunCtx* RAI_ScriptRunCtxCreate(RAI_Script* script, } static int Script_RunCtxAddParam(RAI_ScriptRunCtx* sctx, - RAI_ScriptCtxParam* paramArr, + RAI_ScriptCtxParam** paramArr, RAI_Tensor* tensor) { RAI_ScriptCtxParam param = { .tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL, }; - paramArr = array_append(paramArr, param); + *paramArr = array_append(*paramArr, param); return 1; } -int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor) { - return Script_RunCtxAddParam(sctx, sctx->inputs, inputTensor); +int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor, RAI_Error* err) { + if(sctx->variadic != -1) { + RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Already encountered a variable size list of tensors"); + return 0; + } + return Script_RunCtxAddParam(sctx, &sctx->inputs, inputTensor); +} + +int RAI_ScriptRunCtxAddInputList(RAI_ScriptRunCtx* sctx, RAI_Tensor** inputTensors, size_t len, RAI_Error* err) { + // If this is the first time a list is added, set the variadic, else return an error. + if(sctx->variadic == -1) { + sctx->variadic = array_len(sctx->inputs); + } + else { + RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Already encountered a variable size list of tensors"); + return 0; + } + for(size_t i=0; i < len; i++){ + Script_RunCtxAddParam(sctx, &sctx->inputs, inputTensors[i]); + } + return 1; } int RAI_ScriptRunCtxAddOutput(RAI_ScriptRunCtx* sctx) { - return Script_RunCtxAddParam(sctx, sctx->outputs, NULL); + return Script_RunCtxAddParam(sctx, &sctx->outputs, NULL); } size_t RAI_ScriptRunCtxNumOutputs(RAI_ScriptRunCtx* sctx) { @@ -271,7 +290,8 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx, int is_input = 0; int outputs_flag_count = 0; size_t argpos = 4; - + // Keep variadic local variable as the calls for RAI_ScriptRunCtxAddInput check if (*sctx)->variadic already assigned. + size_t variadic = (*sctx)->variadic; for (; argpos <= argc - 1; argpos++) { const char *arg_string = RedisModule_StringPtrLen(argv[argpos], NULL); if(!arg_string){ @@ -288,7 +308,11 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx, outputs_flag_count = 1; } else { if (!strcasecmp(arg_string, "$")) { - (*sctx)->variadic = argpos - 4; + if(variadic > -1) { + RedisAI_ReplyOrSetError(ctx,error,RAI_ESCRIPTRUN, "ERR Already encountered a variable size list of tensors"); + return -1; + } + variadic = argpos - 4; continue; } RedisModule_RetainString(ctx, argv[argpos]); @@ -310,10 +334,7 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx, return -1; } } - if (!RAI_ScriptRunCtxAddInput(*sctx, inputTensor)) { - RedisAI_ReplyOrSetError(ctx, error, RAI_ESCRIPTRUN, "ERR Input key not found"); - return -1; - } + if (!RAI_ScriptRunCtxAddInput(*sctx, inputTensor, error)) return -1; } else { if (!RAI_ScriptRunCtxAddOutput(*sctx)) { RedisAI_ReplyOrSetError(ctx, error, RAI_ESCRIPTRUN, "ERR Output key not found"); @@ -323,6 +344,8 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx, } } } + // In case variadic position found, set it in the context. + (*sctx)->variadic = variadic; return argpos; } @@ -335,4 +358,8 @@ void RedisAI_ReplyOrSetError(RedisModuleCtx *ctx, RAI_Error *error, RAI_ErrorCod } else { RedisModule_ReplyWithError(ctx, errorMessage); } -} \ No newline at end of file +} + +RedisModuleType *RAI_ScriptRedisType(void) { + return RedisAI_ScriptType; +} diff --git a/src/script.h b/src/script.h index 87b230457..ce15aaa43 100644 --- a/src/script.h +++ b/src/script.h @@ -67,9 +67,25 @@ RAI_ScriptRunCtx* RAI_ScriptRunCtxCreate(RAI_Script* script, * * @param sctx input RAI_ScriptRunCtx to add the input tensor * @param inputTensor input tensor structure - * @return returns 1 on success ( always returns success ) + * @param err error data structure to store error message in the case of + * failures + * @return returns 1 on success, 0 in case of error. */ -int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor); +int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor, RAI_Error* err); + +/** + * For each Allocates a RAI_ScriptCtxParam data structure, and enforces a shallow copy of + * the provided input tensor, adding it to the input tensors array of the + * RAI_ScriptRunCtx. + * + * @param sctx input RAI_ScriptRunCtx to add the input tensor + * @param inputTensors input tensors array + * @param len input tensors array len + * @param err error data structure to store error message in the case of + * failures + * @return returns 1 on success, 0 in case of error. + */ +int RAI_ScriptRunCtxAddInputList(RAI_ScriptRunCtx* sctx, RAI_Tensor** inputTensors, size_t len, RAI_Error* err); /** * Allocates a RAI_ScriptCtxParam data structure, and sets the tensor reference @@ -192,4 +208,10 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx, */ void RedisAI_ReplyOrSetError(RedisModuleCtx *ctx, RAI_Error *error, RAI_ErrorCode code, const char* errorMessage ); +/** + * @brief Returns the redis module type representing a script. + * @return redis module type representing a script. + */ +RedisModuleType *RAI_ScriptRedisType(void); + #endif /* SRC_SCRIPT_H_ */ diff --git a/src/tensor.c b/src/tensor.c index 16ac618ee..341e4acd8 100644 --- a/src/tensor.c +++ b/src/tensor.c @@ -1072,4 +1072,8 @@ int RAI_parseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar // return command arity as the number of processed args return argc; -} \ No newline at end of file +} + +RedisModuleType *RAI_TensorRedisType(void) { + return RedisAI_TensorType; +} diff --git a/src/tensor.h b/src/tensor.h index 2fe89e124..3269b18e9 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -378,4 +378,10 @@ int RAI_parseTensorSetArgs(RedisModuleCtx* ctx, RedisModuleString** argv, int RAI_parseTensorGetArgs(RedisModuleCtx* ctx, RedisModuleString** argv, int argc, RAI_Tensor* t); +/** + * @brief Returns the redis module type representing a tensor. + * @return redis module type representing a tensor. + */ +RedisModuleType *RAI_TensorRedisType(void); + #endif /* SRC_TENSOR_H_ */ diff --git a/test/tests_pytorch.py b/test/tests_pytorch.py index ff33b2715..6e603ccf8 100644 --- a/test/tests_pytorch.py +++ b/test/tests_pytorch.py @@ -641,6 +641,13 @@ def test_pytorch_scriptrun_errors(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) + + # "ERR Already encountered a variable size list of tensors" + try: + con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', '$', 'a', '$', 'b' 'OUTPUTS') + except Exception as e: + exception = e + env.assertEqual(type(exception), redis.exceptions.ResponseError) def test_pytorch_scriptinfo(env):