diff --git a/docs/commands.md b/docs/commands.md
index bc182a775..5af883776 100644
--- a/docs/commands.md
+++ b/docs/commands.md
@@ -514,7 +514,119 @@ OK
 The `AI.SCRIPTDEL` is equivalent to the [Redis `DEL` command](https://redis.io/commands/del) and should be used in its stead. This ensures compatibility with all deployment options (i.e., stand-alone vs. cluster, OSS vs. Enterprise).
 
+## AI.SCRIPTEXECUTE
+
+The **`AI.SCRIPTEXECUTE`** command runs a script stored as a key's value on its specified device. It accepts one or more inputs, where an input can be a tensor stored in RedisAI, an `int`, a `float`, or a `string`, and stores the script's outputs as RedisAI tensors if required.
+
+The run request is put in a queue and is executed asynchronously by a worker thread. The client that issued the run request is blocked until the script run completes. When needed, tensor data is automatically copied to the device prior to execution.
+
+A `TIMEOUT t` argument can be specified to cause a request to be removed from the queue after it has sat there for `t` milliseconds, meaning that the client is no longer interested in a result computed after that time (`TIMEDOUT` is returned in that case).
+
+!!! warning "Intermediate memory overhead"
+    The execution of scripts will generate intermediate tensors that are not allocated by the Redis allocator, but by whatever allocator is used in the TORCH backend (which may act on main memory or GPU memory, depending on the device), thus not being limited by the `maxmemory` configuration settings of Redis.
+
+**Redis API**
+
+```
+AI.SCRIPTEXECUTE <key> <function>
+KEYS n <key> [key ...]
+[INPUTS m <input> [input ...]]
+[LIST_INPUTS l <input> [input ...]]
+[OUTPUTS k <output> [output ...]]
+[TIMEOUT t]
+```
+
+_Arguments_
+
+* **key**: the script's key name
+* **function**: the name of the function to run
+* **KEYS**: either a sequence of key names that the script will access before, during, and after its execution, or a tag that all those keys share. `KEYS` is a mandatory scope in this command. Redis will verify that all potential key accesses are done on the right shard.
+* **INPUTS**: denotes the beginning of the input parameters list, followed by its length and one or more inputs. An input can be a tensor key name, a `string`, an `int`, or a `float`. The order of the inputs should match the order of their respective parameters in the function signature. Note that list inputs are passed within the **LIST_INPUTS** scope.
+* **LIST_INPUTS**: denotes the beginning of a list, followed by its length and one or more inputs. An input can be a tensor key name, a `string`, an `int`, or a `float`. The order of the inputs should match the order of their respective parameters in the function signature. Note that if more than one list is provided, the order of the lists should match the order of their respective parameters in the function signature.
+* **OUTPUTS**: denotes the beginning of the output tensors' key list, followed by its length and one or more key names
+* **TIMEOUT**: the time (in ms) after which the client is unblocked and a `TIMEDOUT` string is returned
+
+_Return_
+
+A simple 'OK' string, a simple `TIMEDOUT` string, or an error.
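To make the grammar concrete, the command's argument vector can be assembled programmatically. The following is a minimal illustrative sketch; the `build_scriptexecute` helper is hypothetical and not part of RedisAI or of any client library:

```python
def build_scriptexecute(key, function, keys, inputs=None, list_inputs=None,
                        outputs=None, timeout=None):
    """Assemble the AI.SCRIPTEXECUTE argument vector (illustrative sketch)."""
    # KEYS is mandatory: the key names (or a shared tag) the script may access.
    args = ["AI.SCRIPTEXECUTE", key, function, "KEYS", str(len(keys))] + list(keys)
    if inputs:
        args += ["INPUTS", str(len(inputs))] + [str(i) for i in inputs]
    # Each list argument gets its own LIST_INPUTS scope, in signature order.
    for lst in list_inputs or []:
        args += ["LIST_INPUTS", str(len(lst))] + [str(i) for i in lst]
    if outputs:
        args += ["OUTPUTS", str(len(outputs))] + list(outputs)
    if timeout is not None:
        args += ["TIMEOUT", str(timeout)]
    return args
```

For instance, the 'addtwo' invocation from the examples reduces to `build_scriptexecute("myscript", "addtwo", ["mytensor1", "mytensor2", "result"], inputs=["mytensor1", "mytensor2"], outputs=["result"])`.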
+
+**Examples**
+
+The following is an example of running the previously-created 'myscript' on two input tensors:
+
+```
+redis> AI.TENSORSET mytensor1 FLOAT 1 VALUES 40
+OK
+redis> AI.TENSORSET mytensor2 FLOAT 1 VALUES 2
+OK
+redis> AI.SCRIPTEXECUTE myscript addtwo KEYS 3 mytensor1 mytensor2 result INPUTS 2 mytensor1 mytensor2 OUTPUTS 1 result
+OK
+redis> AI.TENSORGET result VALUES
+1) FLOAT
+2) 1) (integer) 1
+3) 1) "42"
+```
+
+Note: the above command has a shorter version, given that all the keys are tagged with the same tag:
+
+```
+redis> AI.TENSORSET mytensor1{tag} FLOAT 1 VALUES 40
+OK
+redis> AI.TENSORSET mytensor2{tag} FLOAT 1 VALUES 2
+OK
+redis> AI.SCRIPTEXECUTE myscript{tag} addtwo KEYS 1 {tag} INPUTS 2 mytensor1{tag} mytensor2{tag} OUTPUTS 1 result{tag}
+OK
+redis> AI.TENSORGET result{tag} VALUES
+1) FLOAT
+2) 1) (integer) 1
+3) 1) "42"
+```
+
+If 'myscript' supports `List[Tensor]` arguments:
+```python
+def addn(a, args : List[Tensor]):
+    return a + torch.stack(args).sum()
+```
+
+```
+redis> AI.TENSORSET mytensor1{tag} FLOAT 1 VALUES 40
+OK
+redis> AI.TENSORSET mytensor2{tag} FLOAT 1 VALUES 1
+OK
+redis> AI.TENSORSET mytensor3{tag} FLOAT 1 VALUES 1
+OK
+redis> AI.SCRIPTEXECUTE myscript{tag} addn KEYS 1 {tag} INPUTS 1 mytensor1{tag} LIST_INPUTS 2 mytensor2{tag} mytensor3{tag} OUTPUTS 1 result{tag}
+OK
+redis> AI.TENSORGET result{tag} VALUES
+1) FLOAT
+2) 1) (integer) 1
+3) 1) "42"
+```
+
+### Redis Commands Support
+TorchScript in RedisAI now supports simple (non-blocking) Redis commands via the `redis.execute` API. The following toy script receives a key name (`x{1}`) and an `int` value (3). First, the script `SET`s the value in the key. Next, the script `GET`s the value back from the key and places it in a tensor, which is eventually stored under the key 'y{1}'. Note that the inputs are a `str` and an `int`.
+
+```
+def redis_string_int_to_tensor(redis_value: str):
+    return torch.tensor(int(redis_value))
+
+def int_set_get(key: str, value: int):
+    redis.execute("SET", key, str(value))
+    res = redis.execute("GET", key)
+    return redis_string_int_to_tensor(res)
+```
+```
+redis> AI.SCRIPTEXECUTE redis_scripts{1} int_set_get KEYS 1 {1} INPUTS 2 x{1} 3 OUTPUTS 1 y{1}
+OK
+redis> AI.TENSORGET y{1} VALUES
+1) (integer) 3
+```
+
+!!! warning "Intermediate memory overhead"
+    The execution of scripts may generate intermediate tensors that are not allocated by the Redis allocator, but by whatever allocator is used in the backends (which may act on main memory or GPU memory, depending on the device), thus not being limited by the `maxmemory` configuration settings of Redis.
+
 ## AI.SCRIPTRUN
 
+_This command is deprecated and will not be available in future versions. Consider using the `AI.SCRIPTEXECUTE` command instead._
+
 The **`AI.SCRIPTRUN`** command runs a script stored as a key's value on its specified device. It accepts one or more input tensors and store output tensors.
 
 The run request is put in a queue and is executed asynchronously by a worker thread. The client that had issued the run request is blocked until the script run is completed. When needed, tensors data is automatically copied to the device prior to execution.
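Before running a function, RedisAI validates that the provided inputs can satisfy the TorchScript function's schema (the `torchMatchScriptSchema` routine later in this change set counts the expected tensors, ints, floats, and strings, expanding each list argument by the size of the corresponding `LIST_INPUTS` scope). A simplified Python sketch of that count-based check, with hypothetical names:

```python
from collections import Counter

def match_schema(schema, counts, list_sizes):
    """Count-based schema check (simplified sketch of torchMatchScriptSchema).

    schema     -- argument types in signature order, e.g. ["TENSOR", "INT_LIST"]
    counts     -- Counter of provided non-list inputs plus list elements,
                  e.g. Counter({"TENSOR": 3})
    list_sizes -- sizes of the provided LIST_INPUTS scopes, in order
    """
    expected = Counter()
    lists_seen = 0
    for arg in schema:
        if arg.endswith("_LIST"):
            if lists_seen >= len(list_sizes):
                return False, "Wrong number of lists"
            # A list argument consumes one whole LIST_INPUTS scope.
            expected[arg[:-len("_LIST")]] += list_sizes[lists_seen]
            lists_seen += 1
        else:
            expected[arg] += 1
    if lists_seen != len(list_sizes):
        return False, "Wrong number of lists"
    if expected != counts:
        return False, "Wrong number of parameters"
    return True, None
```

Unlike this sketch, the real C implementation reports the expected and actual counts in its error messages; the sketch only returns a generic reason.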
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a6361138a..901cc7dbd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -15,6 +15,7 @@ file (GLOB BACKEND_COMMON_SRC backends/util.c redis_ai_objects/err.c util/dict.c + util/dictionaries.c redis_ai_objects/tensor.c util/string_utils.c execution/utils.c @@ -22,20 +23,26 @@ file (GLOB BACKEND_COMMON_SRC ADD_LIBRARY(redisai_obj OBJECT util/dict.c + util/dictionaries.c util/queue.c util/string_utils.c redisai.c execution/command_parser.c - execution/deprecated.c + execution/parsing/deprecated.c + execution/parsing/dag_parser.c + execution/parsing/model_commands_parser.c + execution/parsing/script_commands_parser.c + execution/parsing/parse_utils.c execution/run_info.c execution/background_workers.c execution/utils.c config/config.c execution/DAG/dag.c - execution/DAG/dag_parser.c execution/DAG/dag_builder.c execution/DAG/dag_execute.c - execution/modelRun_ctx.c + execution/DAG/dag_op.c + execution/execution_contexts/modelRun_ctx.c + execution/execution_contexts/scriptRun_ctx.c backends/backends.c backends/util.c redis_ai_objects/model.c diff --git a/src/backends/libtorch_c/torch_c.cpp b/src/backends/libtorch_c/torch_c.cpp index 288a2275a..347b046e3 100644 --- a/src/backends/libtorch_c/torch_c.cpp +++ b/src/backends/libtorch_c/torch_c.cpp @@ -192,58 +192,7 @@ struct ModuleContext { int64_t device_id; }; -void torchRunModule(ModuleContext *ctx, const char *fnName, int variadic, long nInputs, - DLManagedTensor **inputs, long nOutputs, DLManagedTensor **outputs) { - // Checks device, if GPU then move input to GPU before running - // TODO: This will need to change at some point, as individual tensors will have their placement - // and script will only make sure that placement is correct - - torch::DeviceType device_type; - switch (ctx->device) { - case kDLCPU: - device_type = torch::kCPU; - break; - case kDLGPU: - device_type = torch::kCUDA; - break; - default: - throw 
std::runtime_error(std::string("Unsupported device ") + std::to_string(ctx->device)); - } - - torch::Device device(device_type, ctx->device_id); - - torch::jit::Stack stack; - - for (int i = 0; i < nInputs; i++) { - if (i == variadic) { - break; - } - DLTensor *input = &(inputs[i]->dl_tensor); - torch::Tensor tensor = fromDLPack(input); - stack.push_back(tensor.to(device)); - } - - if (variadic != -1) { - std::vector args; - for (int i = variadic; i < nInputs; i++) { - DLTensor *input = &(inputs[i]->dl_tensor); - torch::Tensor tensor = fromDLPack(input); - tensor.to(device); - args.emplace_back(tensor); - } - stack.push_back(args); - } - - if (ctx->module) { - torch::NoGradGuard guard; - torch::jit::script::Method method = ctx->module->get_method(fnName); - method.run(stack); - } else { - torch::NoGradGuard guard; - torch::jit::Function &fn = ctx->cu->get_function(fnName); - fn.run(stack); - } - +static void torchHandlOutputs(torch::jit::Stack& stack, const char* fnName, long nOutputs, DLManagedTensor **outputs) { torch::DeviceType output_device_type = torch::kCPU; torch::Device output_device(output_device_type, -1); @@ -284,30 +233,25 @@ void torchRunModule(ModuleContext *ctx, const char *fnName, int variadic, long n } } -} // namespace - -extern "C" void torchBasicTest() { - torch::Tensor mat = torch::rand({3, 3}); - std::cout << mat << std::endl; -} - -extern "C" DLManagedTensor *torchNewTensor(DLDataType dtype, long ndims, int64_t *shape, - int64_t *strides, char *data) { - // at::DeviceType device_type = getATenDeviceType(kDLCPU); - at::ScalarType stype = toScalarType(dtype); - torch::Device device(getATenDeviceType(kDLCPU), -1); - torch::Tensor tensor = - torch::from_blob(data, at::IntArrayRef(shape, ndims), at::IntArrayRef(strides, ndims), - // torch::device(at::DeviceType::CPU).dtype(stype)); - torch::device(device).dtype(stype)); +void torchRunModule(ModuleContext *ctx, const char *fnName, torch::jit::Stack& stack, long nOutputs, DLManagedTensor 
**outputs){ - DLManagedTensor *dl_tensor = toManagedDLPack(tensor); + if (ctx->module) { + torch::NoGradGuard guard; + torch::jit::script::Method method = ctx->module->get_method(fnName); + method.run(stack); + } else { + torch::NoGradGuard guard; + torch::jit::Function &fn = ctx->cu->get_function(fnName); + fn.run(stack); + } - return dl_tensor; + torchHandlOutputs(stack, fnName, nOutputs, outputs); } +} // namespace + extern "C" void* torchCompileScript(const char* script, DLDeviceType device, int64_t device_id, - char **error, void* (*alloc)(size_t)) + char **error) { ModuleContext* ctx = new ModuleContext(); ctx->device = device; @@ -329,10 +273,7 @@ extern "C" void* torchCompileScript(const char* script, DLDeviceType device, int } catch(std::exception& e) { - size_t len = strlen(e.what()) +1; - *error = (char*)alloc(len * sizeof(char)); - strcpy(*error, e.what()); - (*error)[len-1] = '\0'; + *error = RedisModule_Strdup(e.what()); delete ctx; return NULL; } @@ -340,7 +281,7 @@ extern "C" void* torchCompileScript(const char* script, DLDeviceType device, int } extern "C" void *torchLoadModel(const char *graph, size_t graphlen, DLDeviceType device, - int64_t device_id, char **error, void *(*alloc)(size_t)) { + int64_t device_id, char **error) { std::string graphstr(graph, graphlen); std::istringstream graph_stream(graphstr, std::ios_base::binary); ModuleContext *ctx = new ModuleContext(); @@ -358,59 +299,234 @@ extern "C" void *torchLoadModel(const char *graph, size_t graphlen, DLDeviceType ctx->module = module; ctx->cu = nullptr; } catch (std::exception &e) { - size_t len = strlen(e.what()) + 1; - *error = (char *)alloc(len * sizeof(char)); - strcpy(*error, e.what()); - (*error)[len - 1] = '\0'; + *error = RedisModule_Strdup(e.what()); delete ctx; return NULL; } return ctx; } -extern "C" void torchRunScript(void *scriptCtx, const char *fnName, int variadic, long nInputs, - DLManagedTensor **inputs, long nOutputs, DLManagedTensor **outputs, - char **error, void 
*(*alloc)(size_t)) { +static torch::DeviceType getDeviceType(ModuleContext *ctx) { + switch (ctx->device) { + case kDLCPU: + return torch::kCPU; + case kDLGPU: + return torch::kCUDA; + default: + throw std::runtime_error(std::string("Unsupported device ") + std::to_string(ctx->device)); + } +} + +extern "C" bool torchMatchScriptSchema(TorchScriptFunctionArgumentType *schema ,size_t nArguments, TorchFunctionInputCtx* inputsCtx, char **error) { + char* buf; + int schemaListCount = 0; + size_t schemaTensorCount = 0; + size_t schemaIntCount = 0; + size_t schemaFloatCount = 0; + size_t schemaStringCount = 0; + size_t totalInputsCount = inputsCtx->tensorCount + inputsCtx->intCount + inputsCtx->floatCount + inputsCtx->stringCount; + if((totalInputsCount) < nArguments) { + asprintf(&buf, "Wrong number of inputs. Expected %ld but was %ld", nArguments, totalInputsCount); + goto cleanup; + } + for (size_t i = 0; i < nArguments; i++) { + switch (schema[i]) { + case TENSOR: + schemaTensorCount++; + break; + case INT: + schemaIntCount++; + break; + case FLOAT: + schemaFloatCount++; + break; + case STRING: + schemaStringCount++; + break; + case TENSOR_LIST: + schemaListCount++; + if(schemaListCount > inputsCtx->listCount) { + asprintf(&buf, "Wrong number of lists. Expected %d but was %ld", schemaListCount, inputsCtx->listCount); + goto cleanup; + } + schemaTensorCount+=inputsCtx->listSizes[schemaListCount-1]; + break; + case INT_LIST: + schemaListCount++; + if(schemaListCount > inputsCtx->listCount) { + asprintf(&buf, "Wrong number of lists. Expected %d but was %ld", schemaListCount, inputsCtx->listCount); + goto cleanup; + } + schemaIntCount+=inputsCtx->listSizes[schemaListCount-1]; + break; + case FLOAT_LIST: + schemaListCount++; + if(schemaListCount > inputsCtx->listCount) { + asprintf(&buf, "Wrong number of lists. 
Expected %d but was %ld", schemaListCount, inputsCtx->listCount); + goto cleanup; + } + schemaFloatCount+=inputsCtx->listSizes[schemaListCount-1]; + break; + case STRING_LIST: + schemaListCount++; + if(schemaListCount > inputsCtx->listCount) { + asprintf(&buf, "Wrong number of lists. Expected %d but was %ld", schemaListCount, inputsCtx->listCount); + goto cleanup; + } + schemaStringCount+=inputsCtx->listSizes[schemaListCount-1]; + break; + default: + asprintf(&buf, "Unkown type in script schema validation."); + goto cleanup; + } + } + if(schemaListCount != inputsCtx->listCount) { + asprintf(&buf, "Wrong number of lists. Expected %d but was %ld", schemaListCount, inputsCtx->listCount); + goto cleanup; + } + if(schemaTensorCount != inputsCtx->tensorCount || schemaIntCount != inputsCtx->intCount || schemaFloatCount != inputsCtx->floatCount || schemaStringCount!= inputsCtx->stringCount) { + asprintf(&buf, "Wrong number of parameters"); + goto cleanup; + } + + return true; + + cleanup: + *error = RedisModule_Strdup(buf); + free(buf); + return false; +} + +extern "C" void torchRunScript(void *scriptCtx, const char *fnName, + TorchScriptFunctionArgumentType* schema, size_t nArguments, + TorchFunctionInputCtx* inputsCtx, + DLManagedTensor **outputs,long nOutputs, + char **error) { ModuleContext *ctx = (ModuleContext *)scriptCtx; try { - torchRunModule(ctx, fnName, variadic, nInputs, inputs, nOutputs, outputs); + torch::DeviceType device_type = getDeviceType(ctx); + torch::Device device(device_type, ctx->device_id); + + torch::jit::Stack stack; + + size_t listsIdx = 0; + size_t tensorIdx = 0; + size_t intIdx = 0; + size_t floatIdx = 0; + size_t stringIdx = 0; + for(size_t i= 0; i < nArguments; i++) { + // In case of tensor. 
+ switch (schema[i]) { + case TENSOR: { + DLTensor *input = &(inputsCtx->tensorInputs[tensorIdx++]->dl_tensor); + torch::Tensor tensor = fromDLPack(input); + stack.push_back(tensor.to(device)); + break; + } + case TENSOR_LIST: { + std::vector args; + size_t argumentSize = inputsCtx->listSizes[listsIdx++]; + for (size_t j = 0; j < argumentSize; j++) { + DLTensor *input = &(inputsCtx->tensorInputs[tensorIdx++]->dl_tensor); + torch::Tensor tensor = fromDLPack(input); + tensor.to(device); + args.emplace_back(tensor); + } + stack.push_back(args); + break; + } + case STRING_LIST: { + std::vector args; + size_t argumentSize = inputsCtx->listSizes[listsIdx++]; + for (size_t j = 0; j < argumentSize; j++) { + const char* cstr = RedisModule_StringPtrLen(inputsCtx->stringsInputs[stringIdx++], NULL); + torch::string str = torch::string(cstr); + args.emplace_back(str); + } + stack.push_back(args); + break; + } + case INT_LIST: { + std::vector args; + size_t argumentSize = inputsCtx->listSizes[listsIdx++]; + for (size_t j = 0; j < argumentSize; j++) { + int32_t val = inputsCtx->intInputs[intIdx++]; + args.emplace_back(val); + } + stack.push_back(args); + break; + } + case FLOAT_LIST: { + std::vector args; + size_t argumentSize = inputsCtx->listSizes[listsIdx++]; + for (size_t j = 0; j < argumentSize; j++) { + float val = inputsCtx->floatInputs[floatIdx++]; + args.emplace_back(val); + } + stack.push_back(args); + break; + } + + case INT: { + int32_t val = inputsCtx->intInputs[intIdx++]; + stack.push_back(val); + break; + } + case FLOAT: { + float val = inputsCtx->floatInputs[floatIdx++]; + stack.push_back(val); + break; + } + case STRING: { + const char* cstr = RedisModule_StringPtrLen(inputsCtx->stringsInputs[stringIdx++], NULL); + torch::string str = torch::string(cstr); + stack.push_back(str); + break; + } + default: { + *error = RedisModule_Strdup("Unkown script input type"); + break; + } + } + } + + torchRunModule(ctx, fnName, stack, nOutputs, outputs); } catch 
(std::exception &e) { - size_t len = strlen(e.what()) + 1; - *error = (char *)alloc(len * sizeof(char)); - strcpy(*error, e.what()); - (*error)[len - 1] = '\0'; + *error = RedisModule_Strdup(e.what()); } } extern "C" void torchRunModel(void *modelCtx, long nInputs, DLManagedTensor **inputs, long nOutputs, - DLManagedTensor **outputs, char **error, void *(*alloc)(size_t)) { + DLManagedTensor **outputs, char **error) { ModuleContext *ctx = (ModuleContext *)modelCtx; try { - torchRunModule(ctx, "forward", -1, nInputs, inputs, nOutputs, outputs); + torch::DeviceType device_type = getDeviceType(ctx); + torch::Device device(device_type, ctx->device_id); + + torch::jit::Stack stack; + for (int i = 0; i < nInputs; i++) { + DLTensor *input = &(inputs[i]->dl_tensor); + torch::Tensor tensor = fromDLPack(input); + stack.push_back(tensor.to(device)); + } + torchRunModule(ctx, "forward", stack, nOutputs, outputs); } catch (std::exception &e) { - size_t len = strlen(e.what()) + 1; - *error = (char *)alloc(len * sizeof(char)); - strcpy(*error, e.what()); - (*error)[len - 1] = '\0'; + *error = RedisModule_Strdup(e.what()); } } -extern "C" void torchSerializeModel(void *modelCtx, char **buffer, size_t *len, char **error, - void *(*alloc)(size_t)) { +extern "C" void torchSerializeModel(void *modelCtx, char **buffer, size_t *len, char **error) { ModuleContext *ctx = (ModuleContext *)modelCtx; std::ostringstream out; try { ctx->module->save(out); auto out_str = out.str(); int size = out_str.size(); - *buffer = (char *)alloc(size); + *buffer = (char *)RedisModule_Alloc(size); memcpy(*buffer, out_str.c_str(), size); *len = size; } catch (std::exception &e) { - size_t len = strlen(e.what()) + 1; - *error = (char *)alloc(len * sizeof(char)); - strcpy(*error, e.what()); - (*error)[len - 1] = '\0'; + *error = RedisModule_Strdup(e.what()); } } @@ -421,7 +537,7 @@ extern "C" void torchDeallocContext(void *ctx) { } } -extern "C" void torchSetInterOpThreads(int num_threads, char **error, void 
*(*alloc)(size_t)) { +extern "C" void torchSetInterOpThreads(int num_threads, char **error) { int current_num_interop_threads = torch::get_num_interop_threads(); if (current_num_interop_threads != num_threads) { try { @@ -429,15 +545,12 @@ extern "C" void torchSetInterOpThreads(int num_threads, char **error, void *(*al } catch (std::exception) { std::string error_msg = "Cannot set number of inter-op threads after parallel work has started"; - size_t len = error_msg.length() + 1; - *error = (char *)alloc(len * sizeof(char)); - strcpy(*error, error_msg.c_str()); - (*error)[len - 1] = '\0'; + *error = RedisModule_Strdup(error_msg.c_str()); } } } -extern "C" void torchSetIntraOpThreads(int num_threads, char **error, void *(*alloc)(size_t)) { +extern "C" void torchSetIntraOpThreads(int num_threads, char **error) { int current_num_threads = torch::get_num_threads(); if (current_num_threads != num_threads) { try { @@ -445,10 +558,7 @@ extern "C" void torchSetIntraOpThreads(int num_threads, char **error, void *(*al } catch (std::exception) { std::string error_msg = "Cannot set number of intra-op threads after parallel work has started"; - size_t len = error_msg.length() + 1; - *error = (char *)alloc(len * sizeof(char)); - strcpy(*error, error_msg.c_str()); - (*error)[len - 1] = '\0'; + *error = RedisModule_Strdup(error_msg.c_str()); } } } @@ -491,6 +601,37 @@ static int getArgumentTensorCount(const c10::Argument& arg){ } } +static TorchScriptFunctionArgumentType getArgumentType(const c10::Argument& arg){ + switch (arg.type()->kind()) + { + case c10::TypeKind::TensorType: + return TENSOR; + case c10::TypeKind::IntType: + return INT; + case c10::TypeKind::FloatType: + return FLOAT; + case c10::TypeKind::StringType: + return STRING; + case c10::TypeKind::ListType: { + c10::ListTypePtr lt = arg.type()->cast(); + switch(lt->getElementType()->kind()) { + case c10::TypeKind::TensorType: + return TENSOR_LIST; + case c10::TypeKind::IntType: + return INT_LIST; + case 
c10::TypeKind::FloatType: + return FLOAT_LIST; + case c10::TypeKind::StringType: + return STRING_LIST; + default: + return UNKOWN; + } + } + default: + return UNKOWN; + } +} + extern "C" size_t torchModelNumOutputs(void *modelCtx, char** error) { ModuleContext *ctx = (ModuleContext *)modelCtx; size_t noutputs = 0; @@ -518,3 +659,26 @@ extern "C" const char* torchModelInputNameAtIndex(void* modelCtx, size_t index, } return ret; } + +extern "C" size_t torchScript_FunctionCount(void* scriptCtx) { + ModuleContext *ctx = (ModuleContext *)scriptCtx; + return ctx->cu->get_functions().size(); +} + +extern "C" const char* torchScript_FunctionName(void* scriptCtx, size_t fn_index) { + ModuleContext *ctx = (ModuleContext *)scriptCtx; + std::vector functions = ctx->cu->get_functions(); + return functions[fn_index]->name().c_str(); +} + +extern "C" size_t torchScript_FunctionArgumentCount(void* scriptCtx, size_t fn_index) { + ModuleContext *ctx = (ModuleContext *)scriptCtx; + std::vector functions = ctx->cu->get_functions(); + return functions[fn_index]->getSchema().arguments().size(); +} + +extern "C" TorchScriptFunctionArgumentType torchScript_FunctionArgumentype(void* scriptCtx, size_t fn_index, size_t arg_index) { + ModuleContext *ctx = (ModuleContext *)scriptCtx; + std::vector functions = ctx->cu->get_functions(); + return getArgumentType(ctx->cu->get_functions()[fn_index]->getSchema().arguments()[arg_index]); +} diff --git a/src/backends/libtorch_c/torch_c.h b/src/backends/libtorch_c/torch_c.h index 3f13e7de9..f1a627cf9 100644 --- a/src/backends/libtorch_c/torch_c.h +++ b/src/backends/libtorch_c/torch_c.h @@ -6,39 +6,186 @@ extern "C" { #endif -void torchBasicTest(); - -DLManagedTensor *torchNewTensor(DLDataType dtype, long ndims, int64_t *shape, int64_t *strides, - char *data); - -void *torchCompileScript(const char *script, DLDeviceType device, int64_t device_id, char **error, - void *(*alloc)(size_t)); - +#include "redis_ai_objects/script_struct.h" + +typedef struct 
TorchFunctionInputCtx { + DLManagedTensor **tensorInputs; + size_t tensorCount; + int32_t *intInputs; + size_t intCount; + float *floatInputs; + size_t floatCount; + RedisModuleString **stringsInputs; + size_t stringCount; + size_t *listSizes; + size_t listCount; +} TorchFunctionInputCtx; + +/** + * @brief Compiles a script string into a torch compilation unit stored in a module context. + * + * @param script Script string. + * @param device Device for the script to execute on. + * @param device_id Device id for the script to execute on. + * @param error Error string to be populated in case of an exception. + * @return void* ModuleContext pointer. + */ +void *torchCompileScript(const char *script, DLDeviceType device, int64_t device_id, char **error); + +/** + * @brief Loads a model from a model definition string and stores it in a module context. + * + * @param model Model definition string. + * @param modellen Length of the string. + * @param device Device for the model to execute on. + * @param device_id Device id for the model to execute on. + * @param error Error string to be populated in case of an exception. + * @return void* ModuleContext pointer. + */ void *torchLoadModel(const char *model, size_t modellen, DLDeviceType device, int64_t device_id, - char **error, void *(*alloc)(size_t)); - -void torchRunScript(void *scriptCtx, const char *fnName, int variadic, long nInputs, - DLManagedTensor **inputs, long nOutputs, DLManagedTensor **outputs, - char **error, void *(*alloc)(size_t)); - + char **error); + +/** + * @brief Validates SCRIPTEXECUTE or LLAPI script execution inputs according to the function schema. + * + * @param schema Function argument types (schema). + * @param nArguments Number of arguments in the function. + * @param inputsCtx Function execution context containing the information about given inputs. + * @param error Error string to be populated in case of an exception. 
+ * @return true If the user-provided inputs match the schema types and order. + * @return false Otherwise. + */ +bool torchMatchScriptSchema(TorchScriptFunctionArgumentType *schema, size_t nArguments, + TorchFunctionInputCtx *inputsCtx, char **error); + +/** + * @brief Executes a function in a script. + * @note Should be called after torchMatchScriptSchema verification. + * @param scriptCtx Script context. + * @param fnName Function name. + * @param schema Function argument types (schema). + * @param nArguments Number of arguments in the function. + * @param inputsCtx Function execution context containing the information about given inputs. + * @param outputs Array of output tensor (placeholders). + * @param nOutputs Number of output tensors. + * @param error Error string to be populated in case of an exception. + */ +void torchRunScript(void *scriptCtx, const char *fnName, TorchScriptFunctionArgumentType *schema, + size_t nArguments, TorchFunctionInputCtx *inputsCtx, DLManagedTensor **outputs, + long nOutputs, char **error); + +/** + * @brief Executes a model. + * + * @param modelCtx Model context. + * @param nInputs Number of tensor inputs. + * @param inputs Array of input tensors. + * @param nOutputs Number of output tensors. + * @param outputs Array of output tensor (placeholders). + * @param error Error string to be populated in case of an exception. + */ void torchRunModel(void *modelCtx, long nInputs, DLManagedTensor **inputs, long nOutputs, - DLManagedTensor **outputs, char **error, void *(*alloc)(size_t)); - -void torchSerializeModel(void *modelCtx, char **buffer, size_t *len, char **error, - void *(*alloc)(size_t)); - + DLManagedTensor **outputs, char **error); + +/** + * @brief Serializes a model into a string definition. + * + * @param modelCtx Model context. + * @param buffer Byte array to hold the definition. + * @param len Will store the length of the string. + * @param error Error string to be populated in case of an exception. 
+ */ +void torchSerializeModel(void *modelCtx, char **buffer, size_t *len, char **error); + +/** + * @brief Deallocates the created torch script/model object. + * + * @param ctx Object to free. + */ void torchDeallocContext(void *ctx); -void torchSetInterOpThreads(int num_threads, char **error, void *(*alloc)(size_t)); - -void torchSetIntraOpThreads(int num_threadsm, char **error, void *(*alloc)(size_t)); - +/** + * @brief Sets the number of inter-op threads for Torch backend. + * + * @param num_threads Number of inter-op threads. + * @param error Error string to be populated in case of an exception. + */ +void torchSetInterOpThreads(int num_threads, char **error); + +/** + * @brief Sets the number of intra-op threads for Torch backend. + * + * @param num_threads Number of intra-op threads. + * @param error Error string to be populated in case of an exception. + */ +void torchSetIntraOpThreads(int num_threads, char **error); + +/** + * @brief Returns the number of inputs of a model. + * + * @param modelCtx Model context. + * @param error Error string to be populated in case of an exception. + * @return size_t Number of inputs. + */ size_t torchModelNumInputs(void *modelCtx, char **error); +/** + * @brief Returns the name of the model input at index. + * + * @param modelCtx Model context. + * @param index Input index. + * @param error Error string to be populated in case of an exception. + * @return const char* Input name. + */ const char *torchModelInputNameAtIndex(void *modelCtx, size_t index, char **error); +/** + * @brief Returns the number of outputs of a model. + * + * @param modelCtx Model context. + * @param error Error string to be populated in case of an exception. + * @return size_t Number of outputs. + */ size_t torchModelNumOutputs(void *modelCtx, char **error); +/** + * @brief Returns the number of functions in the script. + * + * @param scriptCtx Script context. + * @return size_t Number of functions. 
+ */ +size_t torchScript_FunctionCount(void *scriptCtx); + +/** + * @brief Return the name of the function numbered fn_index in the script. + * + * @param scriptCtx Script context. + * @param fn_index Function number. + * @return const char* Function name. + */ +const char *torchScript_FunctionName(void *scriptCtx, size_t fn_index); + +/** + * @brief Return the number of arguments in the fuction numbered fn_index in the script. + * + * @param scriptCtx Script context. + * @param fn_index Function number. + * @return size_t Number of arguments. + */ +size_t torchScript_FunctionArgumentCount(void *scriptCtx, size_t fn_index); + +/** + * @brief Rerturns the type of the argument at arg_index of function numbered fn_index in the + * script. + * + * @param scriptCtx Script context. + * @param fn_index Function number. + * @param arg_index Argument number. + * @return TorchScriptFunctionArgumentType The type of the argument in RedisAI enum format. + */ +TorchScriptFunctionArgumentType torchScript_FunctionArgumentype(void *scriptCtx, size_t fn_index, + size_t arg_index); + #ifdef __cplusplus } #endif diff --git a/src/backends/torch.c b/src/backends/torch.c index a856c56d2..119c1c5ef 100644 --- a/src/backends/torch.c +++ b/src/backends/torch.c @@ -2,6 +2,7 @@ #include "backends/util.h" #include "backends/torch.h" #include "util/arr.h" +#include "util/dictionaries.h" #include "libtorch_c/torch_c.h" #include "redis_ai_objects/tensor.h" @@ -25,6 +26,7 @@ int RAI_InitBackendTorch(int (*get_api_fn)(const char *, void *)) { get_api_fn("RedisModule_ThreadSafeContextUnlock", ((void **)&RedisModule_ThreadSafeContextUnlock)); get_api_fn("RedisModule_FreeThreadSafeContext", ((void **)&RedisModule_FreeThreadSafeContext)); + get_api_fn("RedisModule_StringPtrLen", ((void **)&RedisModule_StringPtrLen)); return REDISMODULE_OK; } @@ -58,7 +60,7 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_ char *error_descr = NULL; if (opts.backends_inter_op_parallelism 
> 0) {
-        torchSetInterOpThreads(opts.backends_inter_op_parallelism, &error_descr, RedisModule_Alloc);
+        torchSetInterOpThreads(opts.backends_inter_op_parallelism, &error_descr);
     }
     if (error_descr != NULL) {
@@ -68,7 +70,7 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
     }
 
     if (opts.backends_intra_op_parallelism > 0) {
-        torchSetIntraOpThreads(opts.backends_intra_op_parallelism, &error_descr, RedisModule_Alloc);
+        torchSetIntraOpThreads(opts.backends_intra_op_parallelism, &error_descr);
     }
     if (error_descr) {
         RAI_SetError(error, RAI_EMODELCREATE, error_descr);
@@ -76,8 +78,7 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
         return NULL;
     }
 
-    void *model =
-        torchLoadModel(modeldef, modellen, dl_device, deviceid, &error_descr, RedisModule_Alloc);
+    void *model = torchLoadModel(modeldef, modellen, dl_device, deviceid, &error_descr);
 
     if (error_descr) {
         goto cleanup;
@@ -225,8 +226,7 @@ int RAI_ModelRunTorch(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
     }
 
     char *error_descr = NULL;
-    torchRunModel(mctxs[0]->model->model, ninputs, inputs_dl, noutputs, outputs_dl, &error_descr,
-                  RedisModule_Alloc);
+    torchRunModel(mctxs[0]->model->model, ninputs, inputs_dl, noutputs, outputs_dl, &error_descr);
 
     for (size_t i = 0; i < ninputs; ++i) {
         RAI_TensorFree(inputs[i]);
@@ -273,7 +273,7 @@ int RAI_ModelSerializeTorch(RAI_Model *model, char **buffer, size_t *len, RAI_Er
         *len = model->datalen;
     } else {
         char *error_descr = NULL;
-        torchSerializeModel(model->model, buffer, len, &error_descr, RedisModule_Alloc);
+        torchSerializeModel(model->model, buffer, len, &error_descr);
 
         if (*buffer == NULL) {
             RAI_SetError(error, RAI_EMODELSERIALIZE, error_descr);
@@ -309,8 +309,7 @@ RAI_Script *RAI_ScriptCreateTorch(const char *devicestr, const char *scriptdef,
     }
 
     char *error_descr = NULL;
-    void *script =
-        torchCompileScript(scriptdef, dl_device, deviceid, &error_descr, RedisModule_Alloc);
+    void *script = torchCompileScript(scriptdef, dl_device, deviceid, &error_descr);
 
     if (script == NULL) {
         RAI_SetError(error, RAI_ESCRIPTCREATE, error_descr);
@@ -323,6 +322,21 @@ RAI_Script *RAI_ScriptCreateTorch(const char *devicestr, const char *scriptdef,
     ret->scriptdef = RedisModule_Strdup(scriptdef);
     ret->devicestr = RedisModule_Strdup(devicestr);
     ret->refCount = 1;
+    ret->functionData = AI_dictCreate(&AI_dictType_String_ArrSimple, NULL);
+
+    size_t functionCount = torchScript_FunctionCount(script);
+    for (size_t i = 0; i < functionCount; i++) {
+        const char *name = torchScript_FunctionName(script, i);
+        size_t argCount = torchScript_FunctionArgumentCount(script, i);
+        TorchScriptFunctionArgumentType *argTypes =
+            array_new(TorchScriptFunctionArgumentType, argCount);
+        for (size_t j = 0; j < argCount; j++) {
+            TorchScriptFunctionArgumentType argType = torchScript_FunctionArgumentype(script, i, j);
+            argTypes = array_append(argTypes, argType);
+        }
+        AI_dictAdd(ret->functionData, (void *)name, (void *)argTypes);
+        array_free(argTypes);
+    }
 
     return ret;
 }
@@ -330,6 +344,7 @@ RAI_Script *RAI_ScriptCreateTorch(const char *devicestr, const char *scriptdef,
 void RAI_ScriptFreeTorch(RAI_Script *script, RAI_Error *error) {
 
     torchDeallocContext(script->script);
+    AI_dictRelease(script->functionData);
     RedisModule_Free(script->scriptdef);
     RedisModule_Free(script->devicestr);
     RedisModule_Free(script);
@@ -352,8 +367,36 @@ int RAI_ScriptRunTorch(RAI_ScriptRunCtx *sctx, RAI_Error *error) {
     }
 
     char *error_descr = NULL;
-    torchRunScript(sctx->script->script, sctx->fnname, sctx->variadic, nInputs, inputs, nOutputs,
-                   outputs, &error_descr, RedisModule_Alloc);
+
+    TorchScriptFunctionArgumentType *arguments =
+        AI_dictFetchValue(sctx->script->functionData, sctx->fnname);
+    if (!arguments) {
+        RAI_SetError(error, RAI_ESCRIPTRUN, "attempted to get undefined function");
+        RedisModule_Free(error_descr);
+        return 1;
+    }
+
+    // Create inputs context on stack.
+    TorchFunctionInputCtx inputsCtx = {0};
+    inputsCtx.tensorInputs = inputs;
+    inputsCtx.tensorCount = nInputs;
+    inputsCtx.intInputs = sctx->intInputs;
+    inputsCtx.intCount = array_len(sctx->intInputs);
+    inputsCtx.floatInputs = sctx->floatInputs;
+    inputsCtx.floatCount = array_len(sctx->floatInputs);
+    inputsCtx.stringsInputs = sctx->stringInputs;
+    inputsCtx.stringCount = array_len(sctx->stringInputs);
+    inputsCtx.listSizes = sctx->listSizes;
+    inputsCtx.listCount = array_len(sctx->listSizes);
+
+    if (!torchMatchScriptSchema(arguments, array_len(arguments), &inputsCtx, &error_descr)) {
+        RAI_SetError(error, RAI_ESCRIPTRUN, error_descr);
+        RedisModule_Free(error_descr);
+        return 1;
+    }
+
+    torchRunScript(sctx->script->script, sctx->fnname, arguments, array_len(arguments), &inputsCtx,
+                   outputs, nOutputs, &error_descr);
 
     if (error_descr) {
         RAI_SetError(error, RAI_ESCRIPTRUN, error_descr);
diff --git a/src/execution/DAG/dag.c b/src/execution/DAG/dag.c
index 0ac73c82c..377c40f82 100644
--- a/src/execution/DAG/dag.c
+++ b/src/execution/DAG/dag.c
@@ -33,7 +33,6 @@
 #include
 #include "redisai.h"
-#include "dag_parser.h"
 #include "rmutil/args.h"
 #include "rmutil/alloc.h"
 #include "util/arr.h"
@@ -41,8 +40,10 @@
 #include "util/queue.h"
 #include "util/string_utils.h"
 #include "execution/run_info.h"
-#include "execution/modelRun_ctx.h"
 #include "execution/background_workers.h"
+#include "execution/parsing/dag_parser.h"
+#include "execution/execution_contexts/modelRun_ctx.h"
+#include "execution/execution_contexts/scriptRun_ctx.h"
 #include "redis_ai_objects/model.h"
 #include "redis_ai_objects/stats.h"
 #include "redis_ai_objects/tensor.h"
@@ -100,7 +101,7 @@ static int _StoreTensorInKeySpace(RedisModuleCtx *ctx, RAI_Tensor *tensor,
     if (status == REDISMODULE_ERR) {
         return REDISMODULE_ERR;
     }
-    if (RedisModule_ModuleTypeSetValue(key, RedisAI_TensorType, tensor) != REDISMODULE_OK) {
+    if (RedisModule_ModuleTypeSetValue(key, RAI_TensorRedisType(), tensor) != REDISMODULE_OK) {
         RAI_SetError(err, RAI_EDAGRUN, "ERR could not save tensor");
         RedisModule_CloseKey(key);
         return REDISMODULE_ERR;
diff --git a/src/execution/DAG/dag_builder.c b/src/execution/DAG/dag_builder.c
index 01dabc4c4..fc636e68d 100644
--- a/src/execution/DAG/dag_builder.c
+++ b/src/execution/DAG/dag_builder.c
@@ -1,8 +1,9 @@
-#include "dag_parser.h"
 #include "dag_builder.h"
+#include "execution/parsing/dag_parser.h"
 #include "util/string_utils.h"
 #include "execution/run_info.h"
-#include "execution/modelRun_ctx.h"
+#include "execution/execution_contexts/modelRun_ctx.h"
+#include "execution/execution_contexts/scriptRun_ctx.h"
 
 // Store the given arguments from the string in argv array and their amount in argc.
 int _StringToRMArray(const char *dag, RedisModuleString ***argv, int *argc, RAI_Error *err) {
@@ -60,7 +61,7 @@ RAI_DAGRunOp *RAI_DAGCreateModelRunOp(RAI_Model *model) {
     op->commandType = REDISAI_DAG_CMD_MODELRUN;
     op->mctx = mctx;
     op->devicestr = model->devicestr;
-    op->runkey = RAI_HoldString(NULL, (RedisModuleString *)model->infokey);
+    op->runkey = RAI_HoldString((RedisModuleString *)model->infokey);
     return (RAI_DAGRunOp *)op;
 }
@@ -72,7 +73,7 @@ RAI_DAGRunOp *RAI_DAGCreateScriptRunOp(RAI_Script *script, const char *func_name
     op->commandType = REDISAI_DAG_CMD_SCRIPTRUN;
     op->sctx = sctx;
     op->devicestr = script->devicestr;
-    op->runkey = RAI_HoldString(NULL, (RedisModuleString *)script->infokey);
+    op->runkey = RAI_HoldString((RedisModuleString *)script->infokey);
     return (RAI_DAGRunOp *)op;
 }
diff --git a/src/execution/DAG/dag_op.c b/src/execution/DAG/dag_op.c
new file mode 100644
index 000000000..68ced701c
--- /dev/null
+++ b/src/execution/DAG/dag_op.c
@@ -0,0 +1,69 @@
+#include "dag_op.h"
+#include "util/arr.h"
+#include "execution/execution_contexts/modelRun_ctx.h"
+#include "execution/execution_contexts/scriptRun_ctx.h"
+/**
+ * Allocate the memory and initialise the RAI_DagOp.
+ * @param result Output parameter to capture allocated RAI_DagOp.
+ * @return REDISMODULE_OK on success, or REDISMODULE_ERR if the allocation
+ * failed.
+ */
+int RAI_InitDagOp(RAI_DagOp **result) {
+    RAI_DagOp *dagOp;
+    dagOp = (RAI_DagOp *)RedisModule_Calloc(1, sizeof(RAI_DagOp));
+
+    dagOp->commandType = REDISAI_DAG_CMD_NONE;
+    dagOp->runkey = NULL;
+    dagOp->inkeys = (RedisModuleString **)array_new(RedisModuleString *, 1);
+    dagOp->outkeys = (RedisModuleString **)array_new(RedisModuleString *, 1);
+    dagOp->inkeys_indices = array_new(size_t, 1);
+    dagOp->outkeys_indices = array_new(size_t, 1);
+    dagOp->outTensor = NULL;
+    dagOp->mctx = NULL;
+    dagOp->sctx = NULL;
+    dagOp->devicestr = NULL;
+    dagOp->duration_us = 0;
+    dagOp->result = -1;
+    RAI_InitError(&dagOp->err);
+    dagOp->argv = NULL;
+    dagOp->argc = 0;
+
+    *result = dagOp;
+    return REDISMODULE_OK;
+}
+
+void RAI_DagOpSetRunKey(RAI_DagOp *dagOp, RedisModuleString *runkey) { dagOp->runkey = runkey; }
+
+void RAI_FreeDagOp(RAI_DagOp *dagOp) {
+
+    RAI_FreeError(dagOp->err);
+    if (dagOp->runkey)
+        RedisModule_FreeString(NULL, dagOp->runkey);
+
+    if (dagOp->outTensor)
+        RAI_TensorFree(dagOp->outTensor);
+
+    if (dagOp->mctx) {
+        RAI_ModelRunCtxFree(dagOp->mctx);
+    }
+    if (dagOp->sctx) {
+        RAI_ScriptRunCtxFree(dagOp->sctx);
+    }
+
+    if (dagOp->inkeys) {
+        for (size_t i = 0; i < array_len(dagOp->inkeys); i++) {
+            RedisModule_FreeString(NULL, dagOp->inkeys[i]);
+        }
+        array_free(dagOp->inkeys);
+    }
+    array_free(dagOp->inkeys_indices);
+
+    if (dagOp->outkeys) {
+        for (size_t i = 0; i < array_len(dagOp->outkeys); i++) {
+            RedisModule_FreeString(NULL, dagOp->outkeys[i]);
+        }
+        array_free(dagOp->outkeys);
+    }
+    array_free(dagOp->outkeys_indices);
+    RedisModule_Free(dagOp);
+}
diff --git a/src/execution/DAG/dag_op.h b/src/execution/DAG/dag_op.h
new file mode 100644
index 000000000..5f3523f70
--- /dev/null
+++ b/src/execution/DAG/dag_op.h
@@ -0,0 +1,56 @@
+#pragma once
+#include "redismodule.h"
+#include "redis_ai_objects/err.h"
+#include "redis_ai_objects/script.h"
+#include "redis_ai_objects/model_struct.h"
+
+typedef enum DAGCommand {
+    REDISAI_DAG_CMD_NONE = 0,
+    REDISAI_DAG_CMD_TENSORSET,
+    REDISAI_DAG_CMD_TENSORGET,
+    REDISAI_DAG_CMD_MODELRUN,
+    REDISAI_DAG_CMD_SCRIPTRUN
+} DAGCommand;
+
+typedef struct RAI_DagOp {
+    DAGCommand commandType;
+    RedisModuleString *runkey;
+    RedisModuleString **inkeys;
+    RedisModuleString **outkeys;
+    size_t *inkeys_indices;
+    size_t *outkeys_indices;
+    RAI_Tensor *outTensor; // The tensor to upload in TENSORSET op.
+    RAI_ModelRunCtx *mctx;
+    RAI_ScriptRunCtx *sctx;
+    uint fmt; // This is relevant for TENSORGET op.
+    char *devicestr;
+    int result; // REDISMODULE_OK or REDISMODULE_ERR
+    long long duration_us;
+    RAI_Error *err;
+    RedisModuleString **argv;
+    int argc;
+} RAI_DagOp;
+
+/**
+ * Allocate the memory and initialise the RAI_DagOp.
+ * @param result Output parameter to capture allocated RAI_DagOp.
+ * @return REDISMODULE_OK on success, or REDISMODULE_ERR if the allocation
+ * failed.
+ */
+int RAI_InitDagOp(RAI_DagOp **result);
+
+/**
+ * Frees the memory allocated for the RAI_DagOp,
+ * including its keys arrays and its model/script run contexts.
+ * @param dagOp The RAI_DagOp to free.
+ */
+void RAI_FreeDagOp(RAI_DagOp *dagOp);
+
+/**
+ * @brief Sets the key name of the current dag op's execution subject. The subject is either a
+ * model or a script.
+ *
+ * @param dagOp Current op.
+ * @param runkey Subject key name.
+ */
+void RAI_DagOpSetRunKey(RAI_DagOp *dagOp, RedisModuleString *runkey);
diff --git a/src/execution/command_parser.c b/src/execution/command_parser.c
index cad28148c..c5d98875e 100644
--- a/src/execution/command_parser.c
+++ b/src/execution/command_parser.c
@@ -1,391 +1,11 @@
-#include "redismodule.h"
-#include "run_info.h"
 #include "command_parser.h"
-#include "DAG/dag.h"
-#include "DAG/dag_parser.h"
-#include "util/string_utils.h"
-#include "execution/modelRun_ctx.h"
-#include "deprecated.h"
-#include "utils.h"
-
-static int _ModelExecuteCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModuleString **argv,
-                                          RAI_Model **model, RAI_Error *error,
-                                          RedisModuleString ***inkeys, RedisModuleString ***outkeys,
-                                          RedisModuleString **runkey, long long *timeout) {
-
-    if (argc < 8) {
-        RAI_SetError(error, RAI_EMODELRUN,
-                     "ERR wrong number of arguments for 'AI.MODELEXECUTE' command");
-        return REDISMODULE_ERR;
-    }
-    size_t arg_pos = 1;
-    const int status = RAI_GetModelFromKeyspace(ctx, argv[arg_pos], model, REDISMODULE_READ, error);
-    if (status == REDISMODULE_ERR) {
-        return REDISMODULE_ERR;
-    }
-    *runkey = RAI_HoldString(NULL, argv[arg_pos++]);
-    const char *arg_string = RedisModule_StringPtrLen(argv[arg_pos++], NULL);
-
-    if (strcasecmp(arg_string, "INPUTS") != 0) {
-        RAI_SetError(error, RAI_EMODELRUN, "ERR INPUTS not specified");
-        return REDISMODULE_ERR;
-    }
-
-    long long ninputs = 0, noutputs = 0;
-    if (RedisModule_StringToLongLong(argv[arg_pos++], &ninputs) != REDISMODULE_OK) {
-        RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid argument for input_count");
-        return REDISMODULE_ERR;
-    }
-    if (ninputs <= 0) {
-        RAI_SetError(error, RAI_EMODELRUN, "ERR Input count must be a positive integer");
-        return REDISMODULE_ERR;
-    }
-    if ((*model)->ninputs != ninputs) {
-        RAI_SetError(error, RAI_EMODELRUN,
-                     "Number of keys given as INPUTS here does not match model definition");
-        return REDISMODULE_ERR;
-    }
-    // arg_pos = 4
-    size_t first_input_pos = arg_pos;
-    if (first_input_pos + ninputs > argc) {
-        RAI_SetError(
-            error, RAI_EMODELRUN,
-            "ERR number of input keys to AI.MODELEXECUTE command does not match the number of "
-            "given arguments");
-        return REDISMODULE_ERR;
-    }
-    for (; arg_pos < first_input_pos + ninputs; arg_pos++) {
-        *inkeys = array_append(*inkeys, RAI_HoldString(NULL, argv[arg_pos]));
-    }
-
-    if (argc == arg_pos ||
-        strcasecmp(RedisModule_StringPtrLen(argv[arg_pos++], NULL), "OUTPUTS") != 0) {
-        RAI_SetError(error, RAI_EMODELRUN, "ERR OUTPUTS not specified");
-        return REDISMODULE_ERR;
-    }
-    if (argc == arg_pos ||
-        RedisModule_StringToLongLong(argv[arg_pos++], &noutputs) != REDISMODULE_OK) {
-        RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid argument for output_count");
-    }
-    if (noutputs <= 0) {
-        RAI_SetError(error, RAI_EMODELRUN, "ERR Output count must be a positive integer");
-        return REDISMODULE_ERR;
-    }
-    if ((*model)->noutputs != noutputs) {
-        RAI_SetError(error, RAI_EMODELRUN,
-                     "Number of keys given as OUTPUTS here does not match model definition");
-        return REDISMODULE_ERR;
-    }
-    // arg_pos = ninputs+6, the argument that we already parsed are:
-    // AI.MODELEXECUTE INPUTS ... OUTPUTS
-    size_t first_output_pos = arg_pos;
-    if (first_output_pos + noutputs > argc) {
-        RAI_SetError(
-            error, RAI_EMODELRUN,
-            "ERR number of output keys to AI.MODELEXECUTE command does not match the number of "
-            "given arguments");
-        return REDISMODULE_ERR;
-    }
-    for (; arg_pos < first_output_pos + noutputs; arg_pos++) {
-        *outkeys = array_append(*outkeys, RAI_HoldString(NULL, argv[arg_pos]));
-    }
-    if (arg_pos == argc) {
-        return REDISMODULE_OK;
-    }
-
-    // Parse timeout arg if given and store it in timeout.
-    char *error_str;
-    arg_string = RedisModule_StringPtrLen(argv[arg_pos++], NULL);
-    if (!strcasecmp(arg_string, "TIMEOUT")) {
-        if (arg_pos == argc) {
-            RAI_SetError(error, RAI_EMODELRUN, "ERR No value provided for TIMEOUT");
-            return REDISMODULE_ERR;
-        }
-        if (ParseTimeout(argv[arg_pos++], error, timeout) == REDISMODULE_ERR)
-            return REDISMODULE_ERR;
-    } else {
-        error_str = RedisModule_Alloc(strlen("Invalid argument: ") + strlen(arg_string) + 1);
-        sprintf(error_str, "Invalid argument: %s", arg_string);
-        RAI_SetError(error, RAI_EMODELRUN, error_str);
-        RedisModule_Free(error_str);
-        return REDISMODULE_ERR;
-    }
-
-    // There are no more valid args to be processed.
-    if (arg_pos != argc) {
-        arg_string = RedisModule_StringPtrLen(argv[arg_pos], NULL);
-        error_str = RedisModule_Alloc(strlen("Invalid argument: ") + strlen(arg_string) + 1);
-        sprintf(error_str, "Invalid argument: %s", arg_string);
-        RAI_SetError(error, RAI_EMODELRUN, error_str);
-        RedisModule_Free(error_str);
-        return REDISMODULE_ERR;
-    }
-    return REDISMODULE_OK;
-}
-
-int ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkeys,
-                          RedisModuleString **outkeys, RAI_ModelRunCtx *mctx, RAI_Error *err) {
-
-    RAI_Model *model = mctx->model;
-    RAI_Tensor *t;
-    RedisModuleKey *key;
-    char *opname = NULL;
-    size_t ninputs = array_len(inkeys), noutputs = array_len(outkeys);
-    for (size_t i = 0; i < ninputs; i++) {
-        const int status =
-            RAI_GetTensorFromKeyspace(ctx, inkeys[i], &key, &t, REDISMODULE_READ, err);
-        if (status == REDISMODULE_ERR) {
-            return REDISMODULE_ERR;
-        }
-        if (model->inputs)
-            opname = model->inputs[i];
-        RAI_ModelRunCtxAddInput(mctx, opname, t);
-    }
-
-    for (size_t i = 0; i < noutputs; i++) {
-        if (model->outputs) {
-            opname = model->outputs[i];
-        }
-        if (!VerifyKeyInThisShard(ctx, outkeys[i])) { // Relevant for enterprise cluster.
-            RAI_SetError(err, RAI_EMODELRUN,
-                         "ERR CROSSSLOT Keys in request don't hash to the same slot");
-            return REDISMODULE_ERR;
-        }
-        RAI_ModelRunCtxAddOutput(mctx, opname);
-    }
-    return REDISMODULE_OK;
-}
-
-int ParseModelExecuteCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv,
-                             int argc) {
-
-    int res = REDISMODULE_ERR;
-    // Build a ModelRunCtx from command.
-    RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(NULL);
-    RAI_Model *model;
-    long long timeout = 0;
-    if (_ModelExecuteCommand_ParseArgs(ctx, argc, argv, &model, rinfo->err, &currentOp->inkeys,
-                                       &currentOp->outkeys, &currentOp->runkey,
-                                       &timeout) == REDISMODULE_ERR) {
-        goto cleanup;
-    }
-
-    if (timeout > 0 && !rinfo->single_op_dag) {
-        RAI_SetError(rinfo->err, RAI_EDAGBUILDER, "ERR TIMEOUT not allowed within a DAG command");
-        goto cleanup;
-    }
-
-    RAI_ModelRunCtx *mctx = RAI_ModelRunCtxCreate(model);
-    currentOp->commandType = REDISAI_DAG_CMD_MODELRUN;
-    currentOp->mctx = mctx;
-    currentOp->devicestr = mctx->model->devicestr;
-
-    if (rinfo->single_op_dag) {
-        rinfo->timeout = timeout;
-        // Set params in ModelRunCtx, bring inputs from key space.
-        if (ModelRunCtx_SetParams(ctx, currentOp->inkeys, currentOp->outkeys, mctx, rinfo->err) ==
-            REDISMODULE_ERR)
-            goto cleanup;
-    }
-    res = REDISMODULE_OK;
-
-cleanup:
-    RedisModule_FreeThreadSafeContext(ctx);
-    return res;
-}
-
-static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc,
-                                       RAI_Script **script, RAI_Error *error,
-                                       RedisModuleString ***inkeys, RedisModuleString ***outkeys,
-                                       RedisModuleString **runkey, char const **func_name,
-                                       long long *timeout, int *variadic) {
-
-    if (argc < 3) {
-        RAI_SetError(error, RAI_ESCRIPTRUN,
-                     "ERR wrong number of arguments for 'AI.SCRIPTRUN' command");
-        return REDISMODULE_ERR;
-    }
-    size_t argpos = 1;
-    const int status =
-        RAI_GetScriptFromKeyspace(ctx, argv[argpos], script, REDISMODULE_READ, error);
-    if (status == REDISMODULE_ERR) {
-        return REDISMODULE_ERR;
-    }
-    RAI_HoldString(NULL, argv[argpos]);
-    *runkey = argv[argpos];
-
-    const char *arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);
-    if (!strcasecmp(arg_string, "TIMEOUT") || !strcasecmp(arg_string, "INPUTS") ||
-        !strcasecmp(arg_string, "OUTPUTS")) {
-        RAI_SetError(error, RAI_ESCRIPTRUN, "ERR function name not specified");
-        return REDISMODULE_ERR;
-    }
-    *func_name = arg_string;
-
-    bool is_input = false;
-    bool is_output = false;
-    bool timeout_set = false;
-    bool inputs_done = false;
-    size_t ninputs = 0, noutputs = 0;
-    int varidic_start_pos = -1;
-
-    while (++argpos < argc) {
-        arg_string = RedisModule_StringPtrLen(argv[argpos], NULL);
-
-        // Parse timeout arg if given and store it in timeout
-        if (!strcasecmp(arg_string, "TIMEOUT") && !timeout_set) {
-            if (ParseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR)
-                return REDISMODULE_ERR;
-            timeout_set = true;
-            continue;
-        }
-
-        if (!strcasecmp(arg_string, "INPUTS")) {
-            if (inputs_done) {
-                RAI_SetError(error, RAI_ESCRIPTRUN,
-                             "ERR Already encountered an INPUTS section in SCRIPTRUN");
-                return REDISMODULE_ERR;
-            }
-            if (is_input) {
-                RAI_SetError(error, RAI_ESCRIPTRUN,
-                             "ERR Already encountered an INPUTS keyword in SCRIPTRUN");
-                return REDISMODULE_ERR;
-            }
-            is_input = true;
-            is_output = false;
-            continue;
-        }
-        if (!strcasecmp(arg_string, "OUTPUTS")) {
-            if (is_output) {
-                RAI_SetError(error, RAI_ESCRIPTRUN,
-                             "ERR Already encountered an OUTPUTS keyword in SCRIPTRUN");
-                return REDISMODULE_ERR;
-            }
-            is_input = false;
-            is_output = true;
-            inputs_done = true;
-            continue;
-        }
-        if (!strcasecmp(arg_string, "$")) {
-            if (!is_input) {
-                RAI_SetError(
-                    error, RAI_ESCRIPTRUN,
-                    "ERR Encountered a variable size list of tensors outside of input section");
-                return REDISMODULE_ERR;
-            }
-            if (varidic_start_pos > -1) {
-                RAI_SetError(error, RAI_ESCRIPTRUN,
-                             "ERR Already encountered a variable size list of tensors");
-                return REDISMODULE_ERR;
-            }
-            varidic_start_pos = ninputs;
-            continue;
-        }
-        // Parse argument name
-        RAI_HoldString(NULL, argv[argpos]);
-        if (is_input) {
-            ninputs++;
-            *inkeys = array_append(*inkeys, argv[argpos]);
-        } else if (is_output) {
-            noutputs++;
-            *outkeys = array_append(*outkeys, argv[argpos]);
-        } else {
-            RAI_SetError(error, RAI_ESCRIPTRUN, "ERR Unrecongnized parameter to SCRIPTRUN");
-            return REDISMODULE_ERR;
-        }
-    }
-    *variadic = varidic_start_pos;
-
-    return REDISMODULE_OK;
-}
-
-/**
- * Extract the params for the ScriptCtxRun object from AI.SCRIPTRUN arguments.
- *
- * @param ctx Context in which Redis modules operate.
- * @param inkeys Script input tensors keys, as an array of strings.
- * @param outkeys Script output tensors keys, as an array of strings.
- * @param sctx Destination Script context to store the parsed data.
- * @return REDISMODULE_OK in case of success, REDISMODULE_ERR otherwise.
- */
-
-static int _ScriptRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkeys,
-                                   RedisModuleString **outkeys, RAI_ScriptRunCtx *sctx,
-                                   RAI_Error *err) {
-
-    RAI_Tensor *t;
-    RedisModuleKey *key;
-    size_t ninputs = array_len(inkeys), noutputs = array_len(outkeys);
-    for (size_t i = 0; i < ninputs; i++) {
-        const int status =
-            RAI_GetTensorFromKeyspace(ctx, inkeys[i], &key, &t, REDISMODULE_READ, err);
-        if (status == REDISMODULE_ERR) {
-            RedisModule_Log(ctx, "warning", "could not load input tensor %s from keyspace",
-                            RedisModule_StringPtrLen(inkeys[i], NULL));
-            return REDISMODULE_ERR;
-        }
-        RAI_ScriptRunCtxAddInput(sctx, t, err);
-    }
-    for (size_t i = 0; i < noutputs; i++) {
-        RAI_ScriptRunCtxAddOutput(sctx);
-    }
-    return REDISMODULE_OK;
-}
-
-int ParseTimeout(RedisModuleString *timeout_arg, RAI_Error *error, long long *timeout) {
-
-    const int retval = RedisModule_StringToLongLong(timeout_arg, timeout);
-    if (retval != REDISMODULE_OK || *timeout <= 0) {
-        RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid value for TIMEOUT");
-        return REDISMODULE_ERR;
-    }
-    return REDISMODULE_OK;
-}
-
-int ParseScriptRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv,
-                          int argc) {
-
-    int res = REDISMODULE_ERR;
-    // Build a ScriptRunCtx from command.
-    RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(NULL);
-    // int lock_status = RedisModule_ThreadSafeContextTryLock(ctx);
-    RAI_Script *script;
-    const char *func_name = NULL;
-
-    long long timeout = 0;
-    int variadic = -1;
-    if (_ScriptRunCommand_ParseArgs(ctx, argv, argc, &script, rinfo->err, &currentOp->inkeys,
-                                    &currentOp->outkeys, &currentOp->runkey, &func_name, &timeout,
-                                    &variadic) == REDISMODULE_ERR) {
-        goto cleanup;
-    }
-    if (timeout > 0 && !rinfo->single_op_dag) {
-        RAI_SetError(rinfo->err, RAI_EDAGBUILDER, "ERR TIMEOUT not allowed within a DAG command");
-        goto cleanup;
-    }
-
-    RAI_ScriptRunCtx *sctx = RAI_ScriptRunCtxCreate(script, func_name);
-    sctx->variadic = variadic;
-    currentOp->sctx = sctx;
-    currentOp->commandType = REDISAI_DAG_CMD_SCRIPTRUN;
-    currentOp->devicestr = sctx->script->devicestr;
-
-    if (rinfo->single_op_dag) {
-        rinfo->timeout = timeout;
-        // Set params in ScriptRunCtx, bring inputs from key space.
-        if (_ScriptRunCtx_SetParams(ctx, currentOp->inkeys, currentOp->outkeys, sctx, rinfo->err) ==
-            REDISMODULE_ERR)
-            goto cleanup;
-    }
-    res = REDISMODULE_OK;
-
-cleanup:
-    // if (lock_status == REDISMODULE_OK) {
-    //    RedisModule_ThreadSafeContextUnlock(ctx);
-    //}
-    RedisModule_FreeThreadSafeContext(ctx);
-    return res;
-}
+#include "redismodule.h"
+#include "execution/run_info.h"
+#include "execution/DAG/dag.h"
+#include "execution/parsing/dag_parser.h"
+#include "execution/parsing/deprecated.h"
+#include "execution/parsing/model_commands_parser.h"
+#include "execution/parsing/script_commands_parser.h"
 
 int RedisAI_ExecuteCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc,
                            RunCommand command, bool ro_dag) {
@@ -415,6 +35,13 @@ int RedisAI_ExecuteCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
         rinfo->dagOps = array_append(rinfo->dagOps, scriptRunOp);
         status = ParseScriptRunCommand(rinfo, scriptRunOp, argv, argc);
         break;
+    case CMD_SCRIPTEXECUTE:
+        rinfo->single_op_dag = 1;
+        RAI_DagOp *scriptExecOp;
+        RAI_InitDagOp(&scriptExecOp);
+        rinfo->dagOps = array_append(rinfo->dagOps, scriptExecOp);
+        status = ParseScriptExecuteCommand(rinfo, scriptExecOp, argv, argc);
+        break;
     case CMD_DAGRUN:
         status = ParseDAGRunCommand(rinfo, ctx, argv, argc, ro_dag);
         break;
diff --git a/src/execution/command_parser.h b/src/execution/command_parser.h
index 5216f1bc7..5c10fcb8e 100644
--- a/src/execution/command_parser.h
+++ b/src/execution/command_parser.h
@@ -7,49 +7,10 @@
 typedef enum RunCommand {
     CMD_MODELRUN = 0,
     CMD_SCRIPTRUN,
     CMD_DAGRUN,
-    CMD_MODELEXECUTE
+    CMD_MODELEXECUTE,
+    CMD_SCRIPTEXECUTE
 } RunCommand;
 
-/**
- * @brief Parse and validate MODELEXECUTE command: create a modelRunCtx based on the model obtained
- * from the key space and save it in the op. The keys of the input and output tensors are stored in
- * the op's inkeys and outkeys arrays, the model key is saved in op's runkey, and the given timeout
- * is saved as well (if given, otherwise it is zero).
- * @return Returns REDISMODULE_OK if the command is valid, REDISMODULE_ERR otherwise.
- */
-int ParseModelExecuteCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv,
-                             int argc);
-
-/**
- * Extract the params for the ModelCtxRun object from AI.MODELEXECUTE arguments.
- *
- * @param ctx Context in which Redis modules operate
- * @param inkeys Model input tensors keys, as an array of strings
- * @param outkeys Model output tensors keys, as an array of strings
- * @param mctx Destination Model context to store the parsed data
- * @return REDISMODULE_OK in case of success, REDISMODULE_ERR otherwise
- */
-
-int ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkeys,
-                          RedisModuleString **outkeys, RAI_ModelRunCtx *mctx, RAI_Error *err);
-
-/**
- * @brief Parse and validate SCRIPTRUN command: create a scriptRunCtx based on the script obtained
- * from the key space and the function name given, and save it in the op. The keys of the input and
- * output tensors are stored in the op's inkeys and outkeys arrays, the script key is saved in op's
- * runkey, and the given timeout is saved as well (if given, otherwise it is zero).
- * @return Returns REDISMODULE_OK if the command is valid, REDISMODULE_ERR otherwise.
- */
-int ParseScriptRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv,
-                          int argc);
-
-/**
- * @brief Parse and validate TIMEOUT argument. If it is valid, store it in timeout.
- * Otherwise set an error.
- * @return Returns REDISMODULE_OK if the command is valid, REDISMODULE_ERR otherwise.
- */
-int ParseTimeout(RedisModuleString *timeout_arg, RAI_Error *error, long long *timeout);
-
 /**
  * @brief Parse and execute RedisAI run command. After parsing and validation, the resulted
  * runInfo (DAG) is queued and the client is blocked until the execution is complete (async
diff --git a/src/execution/execution_contexts/modelRun_ctx.c b/src/execution/execution_contexts/modelRun_ctx.c
new file mode 100644
index 000000000..f9816d2dd
--- /dev/null
+++ b/src/execution/execution_contexts/modelRun_ctx.c
@@ -0,0 +1,183 @@
+
+#include "modelRun_ctx.h"
+#include "util/string_utils.h"
+#include "execution/utils.h"
+#include "execution/DAG/dag.h"
+#include "execution/run_info.h"
+#include "backends/backends.h"
+
+static int _Model_RunCtxAddParam(RAI_ModelCtxParam **paramArr, const char *name,
+                                 RAI_Tensor *tensor) {
+
+    RAI_ModelCtxParam param = {
+        .name = name,
+        .tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL,
+    };
+    *paramArr = array_append(*paramArr, param);
+    return 1;
+}
+
+RAI_ModelRunCtx *RAI_ModelRunCtxCreate(RAI_Model *model) {
+#define PARAM_INITIAL_SIZE 10
+    RAI_ModelRunCtx *mctx = RedisModule_Calloc(1, sizeof(*mctx));
+    mctx->model = RAI_ModelGetShallowCopy(model);
+    mctx->inputs = array_new(RAI_ModelCtxParam, PARAM_INITIAL_SIZE);
+    mctx->outputs = array_new(RAI_ModelCtxParam, PARAM_INITIAL_SIZE);
+    return mctx;
+#undef PARAM_INITIAL_SIZE
+}
+
+int RAI_ModelRunCtxAddInput(RAI_ModelRunCtx *mctx, const char *inputName, RAI_Tensor *inputTensor) {
+    return _Model_RunCtxAddParam(&mctx->inputs, inputName, inputTensor);
+}
+
+int RAI_ModelRunCtxAddOutput(RAI_ModelRunCtx *mctx, const char *outputName) {
+    return _Model_RunCtxAddParam(&mctx->outputs, outputName, NULL);
+}
+
+size_t RAI_ModelRunCtxNumInputs(RAI_ModelRunCtx *mctx) { return array_len(mctx->inputs); }
+
+size_t RAI_ModelRunCtxNumOutputs(RAI_ModelRunCtx *mctx) { return array_len(mctx->outputs); }
+
+RAI_Tensor *RAI_ModelRunCtxInputTensor(RAI_ModelRunCtx *mctx, size_t index) {
+    assert(RAI_ModelRunCtxNumInputs(mctx) > index && index >= 0);
+    return mctx->inputs[index].tensor;
+}
+
+RAI_Tensor *RAI_ModelRunCtxOutputTensor(RAI_ModelRunCtx *mctx, size_t index) {
+    assert(RAI_ModelRunCtxNumOutputs(mctx) > index && index >= 0);
+    return mctx->outputs[index].tensor;
+}
+
+void RAI_ModelRunCtxFree(RAI_ModelRunCtx *mctx) {
+    for (size_t i = 0; i < array_len(mctx->inputs); ++i) {
+        RAI_TensorFree(mctx->inputs[i].tensor);
+    }
+
+    for (size_t i = 0; i < array_len(mctx->outputs); ++i) {
+        if (mctx->outputs[i].tensor) {
+            RAI_TensorFree(mctx->outputs[i].tensor);
+        }
+    }
+
+    array_free(mctx->inputs);
+    array_free(mctx->outputs);
+
+    RAI_Error err = {0};
+    RAI_ModelFree(mctx->model, &err);
+
+    if (err.code != RAI_OK) {
+        // TODO: take it to client somehow
+        RAI_ClearError(&err);
+    }
+    RedisModule_Free(mctx);
+}
+
+int ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkeys,
+                          RedisModuleString **outkeys, RAI_ModelRunCtx *mctx, RAI_Error *err) {
+
+    RAI_Model *model = mctx->model;
+    RAI_Tensor *t;
+    RedisModuleKey *key;
+    char *opname = NULL;
+    size_t ninputs = array_len(inkeys), noutputs = array_len(outkeys);
+    for (size_t i = 0; i < ninputs; i++) {
+        const int status =
+            RAI_GetTensorFromKeyspace(ctx, inkeys[i], &key, &t, REDISMODULE_READ, err);
+        if (status == REDISMODULE_ERR) {
+            return REDISMODULE_ERR;
+        }
+        if (model->inputs)
+            opname = model->inputs[i];
+        RAI_ModelRunCtxAddInput(mctx, opname, t);
+    }
+
+    for (size_t i = 0; i < noutputs; i++) {
+        if (model->outputs) {
+            opname = model->outputs[i];
+        }
+        if (!VerifyKeyInThisShard(ctx, outkeys[i])) { // Relevant for enterprise cluster.
+            RAI_SetError(err, RAI_EMODELRUN,
+                         "ERR CROSSSLOT Keys in request don't hash to the same slot");
+            return REDISMODULE_ERR;
+        }
+        RAI_ModelRunCtxAddOutput(mctx, opname);
+    }
+    return REDISMODULE_OK;
+}
+
+int RAI_ModelRun(RAI_ModelRunCtx **mctxs, long long n, RAI_Error *err) {
+    int ret;
+
+    if (n == 0) {
+        RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Nothing to run");
+        return REDISMODULE_ERR;
+    }
+
+    RAI_ModelRunCtx **mctxs_arr = array_newlen(RAI_ModelRunCtx *, n);
+    for (int i = 0; i < n; i++) {
+        mctxs_arr[i] = mctxs[i];
+    }
+
+    switch (mctxs_arr[0]->model->backend) {
+    case RAI_BACKEND_TENSORFLOW:
+        if (!RAI_backends.tf.model_run) {
+            RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TF");
+            return REDISMODULE_ERR;
+        }
+        ret = RAI_backends.tf.model_run(mctxs_arr, err);
+        break;
+    case RAI_BACKEND_TFLITE:
+        if (!RAI_backends.tflite.model_run) {
+            RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TFLITE");
+            return REDISMODULE_ERR;
+        }
+        ret = RAI_backends.tflite.model_run(mctxs_arr, err);
+        break;
+    case RAI_BACKEND_TORCH:
+        if (!RAI_backends.torch.model_run) {
+            RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TORCH");
+            return REDISMODULE_ERR;
+        }
+        ret = RAI_backends.torch.model_run(mctxs_arr, err);
+        break;
+    case RAI_BACKEND_ONNXRUNTIME:
+        if (!RAI_backends.onnx.model_run) {
+            RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: ONNX");
+            return REDISMODULE_ERR;
+        }
+        ret = RAI_backends.onnx.model_run(mctxs_arr, err);
+        break;
+    default:
+        RAI_SetError(err, RAI_EUNSUPPORTEDBACKEND, "ERR Unsupported backend");
+        return REDISMODULE_ERR;
+    }
+
+    array_free(mctxs_arr);
+
+    return ret;
+}
+
+int RAI_ModelRunAsync(RAI_ModelRunCtx *mctx, RAI_OnFinishCB ModelAsyncFinish, void *private_data) {
+
+    RedisAI_RunInfo *rinfo = NULL;
+    RAI_InitRunInfo(&rinfo);
+
+    rinfo->single_op_dag = 1;
+    rinfo->OnFinish = (RedisAI_OnFinishCB)ModelAsyncFinish;
+    rinfo->private_data = private_data;
+
+    RAI_DagOp *op;
+    RAI_InitDagOp(&op);
+    op->commandType = REDISAI_DAG_CMD_MODELRUN;
+    op->devicestr = mctx->model->devicestr;
+    op->mctx = mctx;
+
+    rinfo->dagOps = array_append(rinfo->dagOps, op);
+    rinfo->dagOpCount = 1;
+    if (DAG_InsertDAGToQueue(rinfo) != REDISMODULE_OK) {
+        RAI_FreeRunInfo(rinfo);
+        return REDISMODULE_ERR;
+    }
+    return REDISMODULE_OK;
+}
diff --git a/src/execution/modelRun_ctx.h b/src/execution/execution_contexts/modelRun_ctx.h
similarity index 55%
rename from src/execution/modelRun_ctx.h
rename to src/execution/execution_contexts/modelRun_ctx.h
index 289b1c6a1..86b9754e2 100644
--- a/src/execution/modelRun_ctx.h
+++ b/src/execution/execution_contexts/modelRun_ctx.h
@@ -76,3 +76,47 @@ RAI_Tensor *RAI_ModelRunCtxInputTensor(RAI_ModelRunCtx *mctx, size_t index);
  * @return RAI_Tensor
  */
 RAI_Tensor *RAI_ModelRunCtxOutputTensor(RAI_ModelRunCtx *mctx, size_t index);
+
+/**
+ * Extract the params for the ModelCtxRun object from AI.MODELEXECUTE arguments.
+ * + * @param ctx Context in which Redis modules operate + * @param inkeys Model input tensors keys, as an array of strings + * @param outkeys Model output tensors keys, as an array of strings + * @param mctx Destination Model context to store the parsed data + * @return REDISMODULE_OK in case of success, REDISMODULE_ERR otherwise + */ + +// TODO: Remove this once modelrunctx and scriptrunctx have common base struct. +int ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkeys, + RedisModuleString **outkeys, RAI_ModelRunCtx *mctx, RAI_Error *err); + +/** + * Given the input array of mctxs, run the associated backend + * session. If the input array of model context runs is larger than one, then + * each backend's `model_run` is responsible for concatenating tensors and running + * the model in batches with the size of the input array. On success, the + * tensors corresponding to outputs[0,noutputs-1] are placed in each + * RAI_ModelRunCtx output tensors array. Relies on each backend's `model_run` + * definition. + * + * @param mctxs array of input model contexts + * @param n length of input model contexts array + * @param error error data structure to store error message in the case of + * failures + * @return REDISMODULE_OK if the underlying backend `model_run` ran + * successfully, or REDISMODULE_ERR if failed. + */ +int RAI_ModelRun(RAI_ModelRunCtx **mctxs, long long n, RAI_Error *err); + +/** + * Insert the ModelRunCtx to the run queues so it will run asynchronously. + * + * @param mctx ModelRunCtx to execute + * @param ModelAsyncFinish A callback that will be called when the execution is finished. + * @param private_data This is going to be sent to the ModelAsyncFinish. + * @return REDISMODULE_OK if the mctx was inserted into the queues successfully, REDISMODULE_ERR + * otherwise. 
+ */ + +int RAI_ModelRunAsync(RAI_ModelRunCtx *mctx, RAI_OnFinishCB ModelAsyncFinish, void *private_data); diff --git a/src/execution/execution_contexts/scriptRun_ctx.c b/src/execution/execution_contexts/scriptRun_ctx.c new file mode 100644 index 000000000..6a66798c9 --- /dev/null +++ b/src/execution/execution_contexts/scriptRun_ctx.c @@ -0,0 +1,233 @@ +#include "scriptRun_ctx.h" +#include "redismodule.h" +#include "execution/utils.h" +#include "execution/DAG/dag.h" +#include "execution/run_info.h" +#include "backends/backends.h" +#include "util/string_utils.h" + +RAI_ScriptRunCtx *RAI_ScriptRunCtxCreate(RAI_Script *script, const char *fnname) { +#define PARAM_INITIAL_SIZE 10 + RAI_ScriptRunCtx *sctx = RedisModule_Calloc(1, sizeof(*sctx)); + sctx->script = RAI_ScriptGetShallowCopy(script); + sctx->inputs = array_new(RAI_ScriptCtxParam, PARAM_INITIAL_SIZE); + sctx->outputs = array_new(RAI_ScriptCtxParam, PARAM_INITIAL_SIZE); + sctx->fnname = RedisModule_Strdup(fnname); + sctx->listSizes = array_new(size_t, PARAM_INITIAL_SIZE); + sctx->intInputs = array_new(int32_t, PARAM_INITIAL_SIZE); + sctx->floatInputs = array_new(float, PARAM_INITIAL_SIZE); + sctx->stringInputs = array_new(RedisModuleString *, PARAM_INITIAL_SIZE); + return sctx; +} + +static int _Script_RunCtxAddParam(RAI_ScriptCtxParam **paramArr, RAI_Tensor *tensor) { + RAI_ScriptCtxParam param = { + .tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL, + }; + *paramArr = array_append(*paramArr, param); + return 1; +} + +// Deprecated. +int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor, RAI_Error *error) { + // Even if variadic is set, we still allow adding inputs via the LLAPI + _Script_RunCtxAddParam(&sctx->inputs, inputTensor); + return 1; +} + +// Deprecated. 
+int RAI_ScriptRunCtxAddInputList(RAI_ScriptRunCtx *sctx, RAI_Tensor **inputTensors, size_t len, + RAI_Error *err) { + for (size_t i = 0; i < len; i++) { + _Script_RunCtxAddParam(&sctx->inputs, inputTensors[i]); + } + sctx->listSizes = array_append(sctx->listSizes, len); + return 1; +} + +int RAI_ScriptRunCtxAddTensorInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor) { + _Script_RunCtxAddParam(&sctx->inputs, inputTensor); + return 1; +} + +int RAI_ScriptRunCtxAddIntInput(RAI_ScriptRunCtx *sctx, int32_t i) { + sctx->intInputs = array_append(sctx->intInputs, i); + return 1; +} + +int RAI_ScriptRunCtxAddFloatInput(RAI_ScriptRunCtx *sctx, float f) { + sctx->floatInputs = array_append(sctx->floatInputs, f); + return 1; +} + +int RAI_ScriptRunCtxAddRStringInput(RAI_ScriptRunCtx *sctx, RedisModuleString *s) { + sctx->stringInputs = array_append(sctx->stringInputs, RAI_HoldString(s)); + return 1; +} + +int RAI_ScriptRunCtxAddStringInput(RAI_ScriptRunCtx *sctx, const char *s, size_t len) { + RedisModuleString *rs = RedisModule_CreateString(NULL, s, len); + return RAI_ScriptRunCtxAddRStringInput(sctx, rs); +} + +int RAI_ScriptRunCtxAddTensorInputList(RAI_ScriptRunCtx *sctx, RAI_Tensor **inputTensors, + size_t count) { + int res = 1; + for (size_t i = 0; i < count; i++) { + res &= RAI_ScriptRunCtxAddTensorInput(sctx, inputTensors[i]); + } + return res; +} + +int RAI_ScriptRunCtxAddIntInputList(RAI_ScriptRunCtx *sctx, int32_t *intInputs, size_t count) { + int res = 1; + for (size_t i = 0; i < count; i++) { + res &= RAI_ScriptRunCtxAddIntInput(sctx, intInputs[i]); + } + return res; +} + +int RAI_ScriptRunCtxAddFloatInputList(RAI_ScriptRunCtx *sctx, float *floatInputs, size_t count) { + int res = 1; + for (size_t i = 0; i < count; i++) { + res &= RAI_ScriptRunCtxAddFloatInput(sctx, floatInputs[i]); + } + return res; + } + +int RAI_ScriptRunCtxAddRStringInputList(RAI_ScriptRunCtx *sctx, RedisModuleString **stringInputs, + size_t count) { + int res = 1; + for 
(size_t i = 0; i < count; i++) { + res &= RAI_ScriptRunCtxAddRStringInput(sctx, stringInputs[i]); + } + return res; +} + +int RAI_ScriptRunCtxAddStringInputList(RAI_ScriptRunCtx *sctx, const char **stringInputs, + size_t *lens, size_t count) { + int res = 1; + for (size_t i = 0; i < count; i++) { + res &= RAI_ScriptRunCtxAddStringInput(sctx, stringInputs[i], lens[i]); + } + return res; +} + +int RAI_ScriptRunCtxAddListSize(RAI_ScriptRunCtx *sctx, size_t len) { + sctx->listSizes = array_append(sctx->listSizes, len); + return 1; +} + +int RAI_ScriptRunCtxAddOutput(RAI_ScriptRunCtx *sctx) { + return _Script_RunCtxAddParam(&sctx->outputs, NULL); +} + +size_t RAI_ScriptRunCtxNumOutputs(RAI_ScriptRunCtx *sctx) { return array_len(sctx->outputs); } + +RAI_Tensor *RAI_ScriptRunCtxOutputTensor(RAI_ScriptRunCtx *sctx, size_t index) { + assert(RAI_ScriptRunCtxNumOutputs(sctx) > index && index >= 0); + return sctx->outputs[index].tensor; +} + +void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx *sctx) { + + for (size_t i = 0; i < array_len(sctx->inputs); ++i) { + RAI_TensorFree(sctx->inputs[i].tensor); + } + + for (size_t i = 0; i < array_len(sctx->outputs); ++i) { + if (sctx->outputs[i].tensor) { + RAI_TensorFree(sctx->outputs[i].tensor); + } + } + + for (size_t i = 0; i < array_len(sctx->stringInputs); ++i) { + RedisModule_FreeString(NULL, sctx->stringInputs[i]); + } + + array_free(sctx->inputs); + array_free(sctx->outputs); + array_free(sctx->listSizes); + array_free(sctx->stringInputs); + array_free(sctx->intInputs); + array_free(sctx->floatInputs); + + RedisModule_Free(sctx->fnname); + + RAI_Error err = {0}; + RAI_ScriptFree(sctx->script, &err); + + if (err.code != RAI_OK) { + // TODO: take it to client somehow + printf("ERR: %s\n", err.detail); + RAI_ClearError(&err); + } + + RedisModule_Free(sctx); +} + +int RAI_ScriptRun(RAI_ScriptRunCtx *sctx, RAI_Error *err) { + if (!RAI_backends.torch.script_run) { + RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TORCH"); + 
return REDISMODULE_ERR; + } + + return RAI_backends.torch.script_run(sctx, err); +} + +int RAI_ScriptRunAsync(RAI_ScriptRunCtx *sctx, RAI_OnFinishCB ScriptAsyncFinish, + void *private_data) { + + RedisAI_RunInfo *rinfo = NULL; + RAI_InitRunInfo(&rinfo); + + rinfo->single_op_dag = 1; + rinfo->OnFinish = (RedisAI_OnFinishCB)ScriptAsyncFinish; + rinfo->private_data = private_data; + + RAI_DagOp *op; + RAI_InitDagOp(&op); + + op->commandType = REDISAI_DAG_CMD_SCRIPTRUN; + op->devicestr = sctx->script->devicestr; + op->sctx = sctx; + + rinfo->dagOps = array_append(rinfo->dagOps, op); + rinfo->dagOpCount = 1; + if (DAG_InsertDAGToQueue(rinfo) != REDISMODULE_OK) { + RAI_FreeRunInfo(rinfo); + return REDISMODULE_ERR; + } + return REDISMODULE_OK; +} + +TorchScriptFunctionArgumentType *RAI_ScriptRunCtxGetSignature(RAI_ScriptRunCtx *sctx) { + return AI_dictFetchValue(sctx->script->functionData, sctx->fnname); +} + +size_t RAI_ScriptRunCtxGetInputListLen(RAI_ScriptRunCtx *sctx, size_t index) { + return sctx->listSizes[index]; +} + +int ScriptRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkeys, + RedisModuleString **outkeys, RAI_ScriptRunCtx *sctx, RAI_Error *err) { + + RAI_Tensor *t; + RedisModuleKey *key; + size_t ninputs = array_len(inkeys), noutputs = array_len(outkeys); + for (size_t i = 0; i < ninputs; i++) { + const int status = + RAI_GetTensorFromKeyspace(ctx, inkeys[i], &key, &t, REDISMODULE_READ, err); + if (status == REDISMODULE_ERR) { + RedisModule_Log(ctx, "warning", "could not load input tensor %s from keyspace", + RedisModule_StringPtrLen(inkeys[i], NULL)); + return REDISMODULE_ERR; + } + RAI_ScriptRunCtxAddInput(sctx, t, err); + } + for (size_t i = 0; i < noutputs; i++) { + RAI_ScriptRunCtxAddOutput(sctx); + } + return REDISMODULE_OK; +} diff --git a/src/execution/execution_contexts/scriptRun_ctx.h b/src/execution/execution_contexts/scriptRun_ctx.h new file mode 100644 index 000000000..73e2673a7 --- /dev/null +++ 
b/src/execution/execution_contexts/scriptRun_ctx.h @@ -0,0 +1,160 @@ +#pragma once +#include "redis_ai_objects/script.h" + +/** + * Allocates the RAI_ScriptRunCtx data structure required for async background + * work within `RedisAI_RunInfo` structure on RedisAI blocking commands + * + * @param script input script + * @param fnname function name to be used from the script + * @return the allocated RAI_ScriptRunCtx + */ +RAI_ScriptRunCtx *RAI_ScriptRunCtxCreate(RAI_Script *script, const char *fnname); + +/** + * Allocates a RAI_ScriptCtxParam data structure, and enforces a shallow copy of + * the provided input tensor, adding it to the input tensors array of the + * RAI_ScriptRunCtx. + * + * @param sctx input RAI_ScriptRunCtx to add the input tensor + * @param inputTensor input tensor structure + * @return returns 1 on success, 0 in case of error. + */ +int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor, RAI_Error *error); + +/** + * For each input tensor, allocates a RAI_ScriptCtxParam data structure and + * enforces a shallow copy of the tensor, adding it to the input tensors + * array of the RAI_ScriptRunCtx. + * + * @param sctx input RAI_ScriptRunCtx to add the input tensor + * @param inputTensors input tensors array + * @param len input tensors array len + * @return returns 1 on success, 0 in case of error. + */ +int RAI_ScriptRunCtxAddInputList(RAI_ScriptRunCtx *sctx, RAI_Tensor **inputTensors, size_t len, + RAI_Error *error); + +/** + * @brief Adds a list length to the given script context. + * + * @param sctx input RAI_ScriptRunCtx to add the list len. + * @param len input tensors array len + * @return int returns 1 on success, 0 in case of error. + */ +int RAI_ScriptRunCtxAddListSize(RAI_ScriptRunCtx *sctx, size_t len); + +/** + * Allocates a RAI_ScriptCtxParam data structure, and sets the tensor reference + * to NULL ( will be set after SCRIPTRUN ), adding it to the outputs tensors + * array of the RAI_ScriptRunCtx. 
+ * + * @param sctx input RAI_ScriptRunCtx to add the output tensor + * @return returns 1 on success ( always returns success ) + */ +int RAI_ScriptRunCtxAddOutput(RAI_ScriptRunCtx *sctx); + +/** + * Returns the total number of output tensors of the RAI_ScriptRunCtx + * + * @param sctx RAI_ScriptRunCtx + * @return the total number of output tensors of the RAI_ScriptRunCtx + */ +size_t RAI_ScriptRunCtxNumOutputs(RAI_ScriptRunCtx *sctx); + +/** + * Get the RAI_Tensor at the output array index position + * + * @param sctx RAI_ScriptRunCtx + * @param index input array index position + * @return RAI_Tensor + */ +RAI_Tensor *RAI_ScriptRunCtxOutputTensor(RAI_ScriptRunCtx *sctx, size_t index); + +/** + * Frees the RAI_ScriptRunCtx data structure used for async background + * work + * + * @param sctx + */ +void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx *sctx); + +/** + * Given the input script context, run the associated script + * session. On success, the tensors corresponding to outputs[0,noutputs-1] are + * placed in the RAI_ScriptRunCtx output tensors array. Relies on PyTorch's + * `script_run` definition. + * + * @param sctx input script context + * @param error error data structure to store error message in the case of + * failures + * @return REDISMODULE_OK if the underlying backend `script_run` ran + * successfully, or REDISMODULE_ERR if failed. + */ +int RAI_ScriptRun(RAI_ScriptRunCtx *sctx, RAI_Error *err); + +/** + * Insert the ScriptRunCtx to the run queues so it will run asynchronously. + * + * @param sctx ScriptRunCtx to execute + * @param ScriptAsyncFinish A callback that will be called when the execution is finished. + * @param private_data This is going to be sent to the ScriptAsyncFinish. + * @return REDISMODULE_OK if the sctx was inserted into the queues successfully, REDISMODULE_ERR + * otherwise. 
+ */ +int RAI_ScriptRunAsync(RAI_ScriptRunCtx *sctx, RAI_OnFinishCB ScriptAsyncFinish, + void *private_data); + +/** + * @brief Returns the current script run context's function signature + * + * @param sctx ScriptRunCtx + * @return TorchScriptFunctionArgumentType* NULL in case of no match, array of argument types according + * to the function signature + */ +TorchScriptFunctionArgumentType *RAI_ScriptRunCtxGetSignature(RAI_ScriptRunCtx *sctx); + +/** + * @brief Returns the length of the input list in the given index. + * + * @param sctx ScriptRunCtx. + * @param index Index of the list out of all the lists given as inputs. + * @return size_t length of the input list in the given index. + */ +size_t RAI_ScriptRunCtxGetInputListLen(RAI_ScriptRunCtx *sctx, size_t index); + +/** + * Extract the tensor parameters for the ScriptRunCtx object from AI.SCRIPTEXECUTE arguments. + * + * @param ctx Context in which Redis modules operate. + * @param inkeys Script input tensors keys, as an array of strings. + * @param outkeys Script output tensors keys, as an array of strings. + * @param sctx Destination Script context to store the parsed data. + * @return REDISMODULE_OK in case of success, REDISMODULE_ERR otherwise. 
+ */ + +int ScriptRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkeys, + RedisModuleString **outkeys, RAI_ScriptRunCtx *sctx, RAI_Error *err); + +int RAI_ScriptRunCtxAddTensorInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor); + +int RAI_ScriptRunCtxAddIntInput(RAI_ScriptRunCtx *sctx, int32_t i); + +int RAI_ScriptRunCtxAddFloatInput(RAI_ScriptRunCtx *sctx, float f); + +int RAI_ScriptRunCtxAddRStringInput(RAI_ScriptRunCtx *sctx, RedisModuleString *s); + +int RAI_ScriptRunCtxAddStringInput(RAI_ScriptRunCtx *sctx, const char *s, size_t len); + +int RAI_ScriptRunCtxAddTensorInputList(RAI_ScriptRunCtx *sctx, RAI_Tensor **inputTensors, + size_t count); + +int RAI_ScriptRunCtxAddIntInputList(RAI_ScriptRunCtx *sctx, int32_t *intInputs, size_t count); + +int RAI_ScriptRunCtxAddFloatInputList(RAI_ScriptRunCtx *sctx, float *floatInputs, size_t count); + +int RAI_ScriptRunCtxAddRStringInputList(RAI_ScriptRunCtx *sctx, RedisModuleString **stringInputs, + size_t count); + +int RAI_ScriptRunCtxAddStringInputList(RAI_ScriptRunCtx *sctx, const char **stringInputs, + size_t *lens, size_t count); diff --git a/src/execution/modelRun_ctx.c b/src/execution/modelRun_ctx.c deleted file mode 100644 index 5024bf130..000000000 --- a/src/execution/modelRun_ctx.c +++ /dev/null @@ -1,70 +0,0 @@ - -#include "modelRun_ctx.h" -#include "util/string_utils.h" - -static int _Model_RunCtxAddParam(RAI_ModelCtxParam **paramArr, const char *name, - RAI_Tensor *tensor) { - - RAI_ModelCtxParam param = { - .name = name, - .tensor = tensor ? 
RAI_TensorGetShallowCopy(tensor) : NULL, - }; - *paramArr = array_append(*paramArr, param); - return 1; -} - -RAI_ModelRunCtx *RAI_ModelRunCtxCreate(RAI_Model *model) { -#define PARAM_INITIAL_SIZE 10 - RAI_ModelRunCtx *mctx = RedisModule_Calloc(1, sizeof(*mctx)); - mctx->model = RAI_ModelGetShallowCopy(model); - mctx->inputs = array_new(RAI_ModelCtxParam, PARAM_INITIAL_SIZE); - mctx->outputs = array_new(RAI_ModelCtxParam, PARAM_INITIAL_SIZE); - return mctx; -#undef PARAM_INITIAL_SIZE -} - -int RAI_ModelRunCtxAddInput(RAI_ModelRunCtx *mctx, const char *inputName, RAI_Tensor *inputTensor) { - return _Model_RunCtxAddParam(&mctx->inputs, inputName, inputTensor); -} - -int RAI_ModelRunCtxAddOutput(RAI_ModelRunCtx *mctx, const char *outputName) { - return _Model_RunCtxAddParam(&mctx->outputs, outputName, NULL); -} - -size_t RAI_ModelRunCtxNumInputs(RAI_ModelRunCtx *mctx) { return array_len(mctx->inputs); } - -size_t RAI_ModelRunCtxNumOutputs(RAI_ModelRunCtx *mctx) { return array_len(mctx->outputs); } - -RAI_Tensor *RAI_ModelRunCtxInputTensor(RAI_ModelRunCtx *mctx, size_t index) { - assert(RAI_ModelRunCtxNumInputs(mctx) > index && index >= 0); - return mctx->inputs[index].tensor; -} - -RAI_Tensor *RAI_ModelRunCtxOutputTensor(RAI_ModelRunCtx *mctx, size_t index) { - assert(RAI_ModelRunCtxNumOutputs(mctx) > index && index >= 0); - return mctx->outputs[index].tensor; -} - -void RAI_ModelRunCtxFree(RAI_ModelRunCtx *mctx) { - for (size_t i = 0; i < array_len(mctx->inputs); ++i) { - RAI_TensorFree(mctx->inputs[i].tensor); - } - - for (size_t i = 0; i < array_len(mctx->outputs); ++i) { - if (mctx->outputs[i].tensor) { - RAI_TensorFree(mctx->outputs[i].tensor); - } - } - - array_free(mctx->inputs); - array_free(mctx->outputs); - - RAI_Error err = {0}; - RAI_ModelFree(mctx->model, &err); - - if (err.code != RAI_OK) { - // TODO: take it to client somehow - RAI_ClearError(&err); - } - RedisModule_Free(mctx); -} diff --git a/src/execution/DAG/dag_parser.c 
b/src/execution/parsing/dag_parser.c similarity index 97% rename from src/execution/DAG/dag_parser.c rename to src/execution/parsing/dag_parser.c index 55c3fba74..8c25d48e1 100644 --- a/src/execution/DAG/dag_parser.c +++ b/src/execution/parsing/dag_parser.c @@ -1,14 +1,14 @@ #include +#include "dag_parser.h" #include "redismodule.h" #include "util/dict.h" #include "util/string_utils.h" #include "redis_ai_objects/tensor.h" -#include "execution/modelRun_ctx.h" +#include "execution/execution_contexts/modelRun_ctx.h" #include "execution/command_parser.h" -#include "dag.h" -#include "dag_parser.h" -#include "dag_execute.h" -#include "execution/deprecated.h" +#include "execution/DAG/dag.h" +#include "execution/DAG/dag_execute.h" +#include "execution/parsing/deprecated.h" #include "execution/utils.h" /** @@ -171,7 +171,7 @@ int ParseDAGOps(RedisAI_RunInfo *rinfo, RAI_DagOp **ops) { if (!strcasecmp(arg_string, "AI.TENSORGET")) { currentOp->commandType = REDISAI_DAG_CMD_TENSORGET; currentOp->devicestr = "CPU"; - RAI_HoldString(NULL, currentOp->argv[1]); + RAI_HoldString(currentOp->argv[1]); currentOp->inkeys = array_append(currentOp->inkeys, currentOp->argv[1]); currentOp->fmt = ParseTensorGetArgs(rinfo->err, currentOp->argv, currentOp->argc); if (currentOp->fmt == TENSOR_NONE) @@ -181,7 +181,7 @@ int ParseDAGOps(RedisAI_RunInfo *rinfo, RAI_DagOp **ops) { if (!strcasecmp(arg_string, "AI.TENSORSET")) { currentOp->commandType = REDISAI_DAG_CMD_TENSORSET; currentOp->devicestr = "CPU"; - RAI_HoldString(NULL, currentOp->argv[1]); + RAI_HoldString(currentOp->argv[1]); currentOp->outkeys = array_append(currentOp->outkeys, currentOp->argv[1]); if (RAI_parseTensorSetArgs(currentOp->argv, currentOp->argc, ¤tOp->outTensor, 0, rinfo->err) == -1) diff --git a/src/execution/DAG/dag_parser.h b/src/execution/parsing/dag_parser.h similarity index 100% rename from src/execution/DAG/dag_parser.h rename to src/execution/parsing/dag_parser.h diff --git a/src/execution/deprecated.c 
b/src/execution/parsing/deprecated.c similarity index 67% rename from src/execution/deprecated.c rename to src/execution/parsing/deprecated.c index d0ad763f7..baded9014 100644 --- a/src/execution/deprecated.c +++ b/src/execution/parsing/deprecated.c @@ -1,12 +1,14 @@ #include "deprecated.h" -#include "modelRun_ctx.h" -#include "command_parser.h" #include "util/string_utils.h" -#include "execution/utils.h" #include "rmutil/args.h" #include "backends/backends.h" -#include "execution/background_workers.h" #include "redis_ai_objects/stats.h" +#include "execution/utils.h" +#include "execution/command_parser.h" +#include "execution/background_workers.h" +#include "execution/execution_contexts/modelRun_ctx.h" +#include "execution/execution_contexts/scriptRun_ctx.h" +#include "execution/parsing/parse_utils.h" static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModuleString **argv, RAI_Model **model, RAI_Error *error, @@ -23,7 +25,7 @@ static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModule if (status == REDISMODULE_ERR) { return REDISMODULE_ERR; } - RAI_HoldString(NULL, argv[argpos]); + RAI_HoldString(argv[argpos]); *runkey = argv[argpos]; const char *arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL); @@ -47,7 +49,7 @@ static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModule is_input = false; is_output = true; } else { - RAI_HoldString(NULL, argv[argpos]); + RAI_HoldString(argv[argpos]); if (is_input) { ninputs++; *inkeys = array_append(*inkeys, argv[argpos]); @@ -316,7 +318,7 @@ int ModelSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { int type = RedisModule_KeyType(key); if (type != REDISMODULE_KEYTYPE_EMPTY && !(type == REDISMODULE_KEYTYPE_MODULE && - RedisModule_ModuleTypeGetType(key) == RedisAI_ModelType)) { + RedisModule_ModuleTypeGetType(key) == RAI_ModelRedisType())) { RedisModule_CloseKey(key); RAI_ModelFree(model, &err); if (err.code != RAI_OK) { @@ -328,7 +330,7 
@@ int ModelSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); } - RedisModule_ModuleTypeSetValue(key, RedisAI_ModelType, model); + RedisModule_ModuleTypeSetValue(key, RAI_ModelRedisType(), model); model->infokey = RAI_AddStatsEntry(ctx, keystr, RAI_MODEL, backend, devicestr, tag); @@ -340,3 +342,151 @@ int ModelSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { return REDISMODULE_OK; } + +static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, + RAI_Error *error, RedisModuleString ***inkeys, + RedisModuleString ***outkeys, long long *timeout, + size_t **listSizes) { + + bool is_input = false; + bool is_output = false; + bool timeout_set = false; + bool inputs_done = false; + size_t ninputs = 0, noutputs = 0; + int varidic_start_pos = -1; + for (int argpos = 3; argpos < argc; argpos++) { + const char *arg_string = RedisModule_StringPtrLen(argv[argpos], NULL); + + // Parse timeout arg if given and store it in timeout + if (!strcasecmp(arg_string, "TIMEOUT")) { + if (timeout_set) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Already encountered a TIMEOUT section in SCRIPTRUN"); + return REDISMODULE_ERR; + } + if (ParseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR) + return REDISMODULE_ERR; + timeout_set = true; + continue; + } + + if (!strcasecmp(arg_string, "INPUTS")) { + if (inputs_done) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Already encountered an INPUTS section in SCRIPTRUN"); + return REDISMODULE_ERR; + } + if (is_input) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Already encountered an INPUTS keyword in SCRIPTRUN"); + return REDISMODULE_ERR; + } + is_input = true; + is_output = false; + continue; + } + if (!strcasecmp(arg_string, "OUTPUTS")) { + if (is_output) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Already encountered an OUTPUTS keyword in SCRIPTRUN"); + return REDISMODULE_ERR; + 
} + is_input = false; + is_output = true; + inputs_done = true; + continue; + } + if (!strcasecmp(arg_string, "$")) { + if (!is_input) { + RAI_SetError( + error, RAI_ESCRIPTRUN, + "ERR Encountered a variable size list of tensors outside of input section"); + return REDISMODULE_ERR; + } + if (varidic_start_pos > -1) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Already encountered a variable size list of tensors"); + return REDISMODULE_ERR; + } + varidic_start_pos = ninputs; + continue; + } + // Parse argument name + if (is_input) { + ninputs++; + *inkeys = array_append(*inkeys, RAI_HoldString(argv[argpos])); + } else if (is_output) { + noutputs++; + *outkeys = array_append(*outkeys, RAI_HoldString(argv[argpos])); + } else { + RAI_SetError(error, RAI_ESCRIPTRUN, "ERR Unrecognized parameter to SCRIPTRUN"); + return REDISMODULE_ERR; + } + } + if (varidic_start_pos != -1) { + *listSizes = array_append(*listSizes, ninputs - varidic_start_pos); + } + + return REDISMODULE_OK; +} + +int ParseScriptRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv, + int argc) { + + if (argc < 3) { + RAI_SetError(rinfo->err, RAI_ESCRIPTRUN, + "ERR wrong number of arguments for 'AI.SCRIPTRUN' command"); + return REDISMODULE_ERR; + } + + int res = REDISMODULE_ERR; + // Build a ScriptRunCtx from command. 
+ RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(NULL); + RAI_ScriptRunCtx *sctx = NULL; + RAI_Script *script = NULL; + RedisModuleString *scriptName = argv[1]; + RAI_GetScriptFromKeyspace(ctx, scriptName, &script, REDISMODULE_READ, rinfo->err); + if (!script) { + goto cleanup; + } + RAI_DagOpSetRunKey(currentOp, RAI_HoldString(argv[1])); + + const char *func_name = ScriptCommand_GetFunctionName(argv[2]); + if (!func_name) { + RAI_SetError(rinfo->err, RAI_ESCRIPTRUN, "ERR function name not specified"); + goto cleanup; + } + + sctx = RAI_ScriptRunCtxCreate(script, func_name); + long long timeout = 0; + if (_ScriptRunCommand_ParseArgs(ctx, argv, argc, rinfo->err, ¤tOp->inkeys, + ¤tOp->outkeys, &timeout, + &sctx->listSizes) == REDISMODULE_ERR) { + goto cleanup; + } + if (timeout > 0 && !rinfo->single_op_dag) { + RAI_SetError(rinfo->err, RAI_EDAGBUILDER, "ERR TIMEOUT not allowed within a DAG command"); + goto cleanup; + } + + if (rinfo->single_op_dag) { + rinfo->timeout = timeout; + // Set params in ScriptRunCtx, bring inputs from key space. 
+ if (ScriptRunCtx_SetParams(ctx, currentOp->inkeys, currentOp->outkeys, sctx, rinfo->err) == + REDISMODULE_ERR) + goto cleanup; + } + currentOp->sctx = sctx; + currentOp->commandType = REDISAI_DAG_CMD_SCRIPTRUN; + currentOp->devicestr = sctx->script->devicestr; + res = REDISMODULE_OK; + RedisModule_FreeThreadSafeContext(ctx); + return res; + +cleanup: + RedisModule_FreeThreadSafeContext(ctx); + if (sctx) { + RAI_ScriptRunCtxFree(sctx); + } + return res; +} diff --git a/src/execution/deprecated.h b/src/execution/parsing/deprecated.h similarity index 80% rename from src/execution/deprecated.h rename to src/execution/parsing/deprecated.h index 251696872..8cdfe639c 100644 --- a/src/execution/deprecated.h +++ b/src/execution/parsing/deprecated.h @@ -1,7 +1,7 @@ #pragma once #include "redismodule.h" -#include "run_info.h" +#include "execution/run_info.h" /** * @brief Parse and validate MODELRUN command: create a modelRunCtx based on the model obtained @@ -13,4 +13,7 @@ int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv, int argc); +int ParseScriptRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv, + int argc); + int ModelSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc); diff --git a/src/execution/parsing/model_commands_parser.c b/src/execution/parsing/model_commands_parser.c new file mode 100644 index 000000000..91c6eff45 --- /dev/null +++ b/src/execution/parsing/model_commands_parser.c @@ -0,0 +1,158 @@ +#include "model_commands_parser.h" +#include "redis_ai_objects/model.h" +#include "util/string_utils.h" +#include "execution/parsing/parse_utils.h" +#include "execution/execution_contexts/modelRun_ctx.h" + +static int _ModelExecuteCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModuleString **argv, + RAI_Model **model, RAI_Error *error, + RedisModuleString ***inkeys, RedisModuleString ***outkeys, + RedisModuleString **runkey, long long *timeout) { + + if (argc < 8) { + 
RAI_SetError(error, RAI_EMODELRUN, + "ERR wrong number of arguments for 'AI.MODELEXECUTE' command"); + return REDISMODULE_ERR; + } + size_t arg_pos = 1; + const int status = RAI_GetModelFromKeyspace(ctx, argv[arg_pos], model, REDISMODULE_READ, error); + if (status == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + *runkey = RAI_HoldString(argv[arg_pos++]); + const char *arg_string = RedisModule_StringPtrLen(argv[arg_pos++], NULL); + + if (strcasecmp(arg_string, "INPUTS") != 0) { + RAI_SetError(error, RAI_EMODELRUN, "ERR INPUTS not specified"); + return REDISMODULE_ERR; + } + + long long ninputs = 0, noutputs = 0; + if (RedisModule_StringToLongLong(argv[arg_pos++], &ninputs) != REDISMODULE_OK) { + RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid argument for input_count"); + return REDISMODULE_ERR; + } + if (ninputs <= 0) { + RAI_SetError(error, RAI_EMODELRUN, "ERR Input count must be a positive integer"); + return REDISMODULE_ERR; + } + if ((*model)->ninputs != ninputs) { + RAI_SetError(error, RAI_EMODELRUN, + "Number of keys given as INPUTS here does not match model definition"); + return REDISMODULE_ERR; + } + // arg_pos = 4 + size_t first_input_pos = arg_pos; + if (first_input_pos + ninputs > argc) { + RAI_SetError( + error, RAI_EMODELRUN, + "ERR number of input keys to AI.MODELEXECUTE command does not match the number of " + "given arguments"); + return REDISMODULE_ERR; + } + for (; arg_pos < first_input_pos + ninputs; arg_pos++) { + *inkeys = array_append(*inkeys, RAI_HoldString(argv[arg_pos])); + } + + if (argc == arg_pos || + strcasecmp(RedisModule_StringPtrLen(argv[arg_pos++], NULL), "OUTPUTS") != 0) { + RAI_SetError(error, RAI_EMODELRUN, "ERR OUTPUTS not specified"); + return REDISMODULE_ERR; + } + if (argc == arg_pos || + RedisModule_StringToLongLong(argv[arg_pos++], &noutputs) != REDISMODULE_OK) { + RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid argument for output_count"); + return REDISMODULE_ERR; + } + if (noutputs <= 0) { + RAI_SetError(error, RAI_EMODELRUN, "ERR Output 
count must be a positive integer"); + return REDISMODULE_ERR; + } + if ((*model)->noutputs != noutputs) { + RAI_SetError(error, RAI_EMODELRUN, + "Number of keys given as OUTPUTS here does not match model definition"); + return REDISMODULE_ERR; + } + // arg_pos = ninputs+6, the arguments that we already parsed are: + // AI.MODELEXECUTE INPUTS ... OUTPUTS + size_t first_output_pos = arg_pos; + if (first_output_pos + noutputs > argc) { + RAI_SetError( + error, RAI_EMODELRUN, + "ERR number of output keys to AI.MODELEXECUTE command does not match the number of " + "given arguments"); + return REDISMODULE_ERR; + } + for (; arg_pos < first_output_pos + noutputs; arg_pos++) { + *outkeys = array_append(*outkeys, RAI_HoldString(argv[arg_pos])); + } + if (arg_pos == argc) { + return REDISMODULE_OK; + } + + // Parse timeout arg if given and store it in timeout. + char *error_str; + arg_string = RedisModule_StringPtrLen(argv[arg_pos++], NULL); + if (!strcasecmp(arg_string, "TIMEOUT")) { + if (arg_pos == argc) { + RAI_SetError(error, RAI_EMODELRUN, "ERR No value provided for TIMEOUT"); + return REDISMODULE_ERR; + } + if (ParseTimeout(argv[arg_pos++], error, timeout) == REDISMODULE_ERR) + return REDISMODULE_ERR; + } else { + error_str = RedisModule_Alloc(strlen("Invalid argument: ") + strlen(arg_string) + 1); + sprintf(error_str, "Invalid argument: %s", arg_string); + RAI_SetError(error, RAI_EMODELRUN, error_str); + RedisModule_Free(error_str); + return REDISMODULE_ERR; + } + + // There are no more valid args to be processed. 
+ if (arg_pos != argc) { + arg_string = RedisModule_StringPtrLen(argv[arg_pos], NULL); + error_str = RedisModule_Alloc(strlen("Invalid argument: ") + strlen(arg_string) + 1); + sprintf(error_str, "Invalid argument: %s", arg_string); + RAI_SetError(error, RAI_EMODELRUN, error_str); + RedisModule_Free(error_str); + return REDISMODULE_ERR; + } + return REDISMODULE_OK; +} + +int ParseModelExecuteCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv, + int argc) { + + int res = REDISMODULE_ERR; + // Build a ModelRunCtx from command. + RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(NULL); + RAI_Model *model; + long long timeout = 0; + if (_ModelExecuteCommand_ParseArgs(ctx, argc, argv, &model, rinfo->err, &currentOp->inkeys, + &currentOp->outkeys, &currentOp->runkey, + &timeout) == REDISMODULE_ERR) { + goto cleanup; + } + + if (timeout > 0 && !rinfo->single_op_dag) { + RAI_SetError(rinfo->err, RAI_EDAGBUILDER, "ERR TIMEOUT not allowed within a DAG command"); + goto cleanup; + } + + RAI_ModelRunCtx *mctx = RAI_ModelRunCtxCreate(model); + currentOp->commandType = REDISAI_DAG_CMD_MODELRUN; + currentOp->mctx = mctx; + currentOp->devicestr = mctx->model->devicestr; + + if (rinfo->single_op_dag) { + rinfo->timeout = timeout; + // Set params in ModelRunCtx, bring inputs from key space.
+ if (ModelRunCtx_SetParams(ctx, currentOp->inkeys, currentOp->outkeys, mctx, rinfo->err) == + REDISMODULE_ERR) + goto cleanup; + } + res = REDISMODULE_OK; + +cleanup: + RedisModule_FreeThreadSafeContext(ctx); + return res; +} diff --git a/src/execution/parsing/model_commands_parser.h b/src/execution/parsing/model_commands_parser.h new file mode 100644 index 000000000..c288045b4 --- /dev/null +++ b/src/execution/parsing/model_commands_parser.h @@ -0,0 +1,13 @@ +#pragma once +#include "redismodule.h" +#include "execution/run_info.h" + +/** + * @brief Parse and validate MODELEXECUTE command: create a modelRunCtx based on the model obtained + * from the key space and save it in the op. The keys of the input and output tensors are stored in + * the op's inkeys and outkeys arrays, the model key is saved in op's runkey, and the given timeout + * is saved as well (if given, otherwise it is zero). + * @return Returns REDISMODULE_OK if the command is valid, REDISMODULE_ERR otherwise. + */ +int ParseModelExecuteCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv, + int argc); diff --git a/src/execution/parsing/parse_utils.c b/src/execution/parsing/parse_utils.c new file mode 100644 index 000000000..b1c23f176 --- /dev/null +++ b/src/execution/parsing/parse_utils.c @@ -0,0 +1,17 @@ +#include "parse_utils.h" +#include "string.h" + +int ParseTimeout(RedisModuleString *timeout_arg, RAI_Error *error, long long *timeout) { + + const int retval = RedisModule_StringToLongLong(timeout_arg, timeout); + if (retval != REDISMODULE_OK || *timeout <= 0) { + RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid value for TIMEOUT"); + return REDISMODULE_ERR; + } + return REDISMODULE_OK; +} + +const char *ScriptCommand_GetFunctionName(RedisModuleString *functionName) { + const char *functionName_cstr = RedisModule_StringPtrLen(functionName, NULL); + return functionName_cstr; +} diff --git a/src/execution/parsing/parse_utils.h b/src/execution/parsing/parse_utils.h new file 
mode 100644 index 000000000..1ddc23e22 --- /dev/null +++ b/src/execution/parsing/parse_utils.h @@ -0,0 +1,18 @@ +#pragma once +#include "redismodule.h" +#include "redis_ai_objects/err.h" + +/** + * @brief Parse and validate TIMEOUT argument. If it is valid, store it in timeout. + * Otherwise set an error. + * @return Returns REDISMODULE_OK if the command is valid, REDISMODULE_ERR otherwise. + */ +int ParseTimeout(RedisModuleString *timeout_arg, RAI_Error *error, long long *timeout); + +/** + * @brief Returns the function name argument as a C string. + * + * @param functionName the function name argument given in the command. + * @return const char* pointing to the function name. + */ +const char *ScriptCommand_GetFunctionName(RedisModuleString *functionName); diff --git a/src/execution/parsing/script_commands_parser.c b/src/execution/parsing/script_commands_parser.c new file mode 100644 index 000000000..20d22f2d2 --- /dev/null +++ b/src/execution/parsing/script_commands_parser.c @@ -0,0 +1,391 @@ +#include "script_commands_parser.h" +#include "parse_utils.h" +#include "execution/utils.h" +#include "util/string_utils.h" +#include "execution/execution_contexts/scriptRun_ctx.h" + +static bool _Script_buildInputsBySchema(RAI_ScriptRunCtx *sctx, RedisModuleString **inputs, + RedisModuleString ***inkeys, RAI_Error *error) { + int signatureListCount = 0; + + TorchScriptFunctionArgumentType *signature = RAI_ScriptRunCtxGetSignature(sctx); + if (!signature) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "Wrong function name provided to AI.SCRIPTEXECUTE command"); + return false; + } + size_t nlists = array_len(sctx->listSizes); + size_t nArguments = array_len(signature); + size_t nInputs = array_len(inputs); + size_t inputsIdx = 0; + size_t listIdx = 0; + + if (nInputs < nArguments) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "Wrong number of inputs provided to AI.SCRIPTEXECUTE command"); + return false; + } + for (size_t i = 0; i < nArguments; i++) { + switch (signature[i]) { + case UNKOWN: { + RAI_SetError(error, RAI_ESCRIPTRUN, + "Unsupported argument type in AI.SCRIPTEXECUTE command"); + return false; +
} + case TENSOR_LIST: { + // Collect the inputs tensor names from the current list + signatureListCount++; + if (signatureListCount > nlists) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "Wrong number of lists provided in AI.SCRIPTEXECUTE command"); + return false; + } + size_t listLen = RAI_ScriptRunCtxGetInputListLen(sctx, listIdx++); + for (size_t j = 0; j < listLen; j++) { + *inkeys = array_append(*inkeys, RAI_HoldString(inputs[inputsIdx++])); + } + break; + } + case INT_LIST: { + signatureListCount++; + if (signatureListCount > nlists) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "Wrong number of lists provided in AI.SCRIPTEXECUTE command"); + return false; + } + size_t listLen = RAI_ScriptRunCtxGetInputListLen(sctx, listIdx++); + for (size_t j = 0; j < listLen; j++) { + long long l; + RedisModule_StringToLongLong(inputs[inputsIdx++], &l); + RAI_ScriptRunCtxAddIntInput(sctx, (int32_t)l); + } + break; + } + case FLOAT_LIST: { + signatureListCount++; + if (signatureListCount > nlists) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "Wrong number of lists provided in AI.SCRIPTEXECUTE command"); + return false; + } + size_t listLen = RAI_ScriptRunCtxGetInputListLen(sctx, listIdx++); + for (size_t j = 0; j < listLen; j++) { + double d; + RedisModule_StringToDouble(inputs[inputsIdx++], &d); + RAI_ScriptRunCtxAddFloatInput(sctx, (float)d); + } + break; + } + case STRING_LIST: { + signatureListCount++; + if (signatureListCount > nlists) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "Wrong number of lists provided in AI.SCRIPTEXECUTE command"); + return false; + } + size_t listLen = RAI_ScriptRunCtxGetInputListLen(sctx, listIdx++); + for (size_t j = 0; j < listLen; j++) { + RAI_ScriptRunCtxAddRStringInput(sctx, inputs[inputsIdx++]); + } + break; + } + case INT: { + long long l; + RedisModule_StringToLongLong(inputs[inputsIdx++], &l); + RAI_ScriptRunCtxAddIntInput(sctx, (int32_t)l); + break; + } + case FLOAT: { + double d; + RedisModule_StringToDouble(inputs[inputsIdx++], &d); + 
RAI_ScriptRunCtxAddFloatInput(sctx, (float)d); + break; + } + case STRING: { + // Input is a string. + RAI_ScriptRunCtxAddRStringInput(sctx, inputs[inputsIdx++]); + break; + } + case TENSOR: + default: { + // Input is a tensor, add its name to the inkeys. + *inkeys = array_append(*inkeys, RAI_HoldString(inputs[inputsIdx++])); + break; + } + } + } + if (signatureListCount != nlists) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "Wrong number of lists provided in AI.SCRIPTEXECUTE command"); + return false; + } + + return true; +} + +static int _ScriptExecuteCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, + RAI_Error *error, RedisModuleString ***inkeys, + RedisModuleString ***outkeys, RAI_ScriptRunCtx *sctx, + long long *timeout, bool keysRequired) { + int argpos = 3; + bool inputsDone = false; + bool outputsDone = false; + bool KeysDone = false; + // Local input context to verify correctness. + array_new_on_stack(RedisModuleString *, 10, inputs); + if (keysRequired) { + const char *arg_string = RedisModule_StringPtrLen(argv[argpos], NULL); + if (!strcasecmp(arg_string, "KEYS")) { + KeysDone = true; + // Read key number. + argpos++; + if (argpos >= argc) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Invalid arguments provided to AI.SCRIPTEXECUTE"); + goto cleanup; + } + long long nkeys; + if (RedisModule_StringToLongLong(argv[argpos], &nkeys) != REDISMODULE_OK) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Invalid argument for key count in AI.SCRIPTEXECUTE"); + goto cleanup; + } + // Check validity of key numbers. + argpos++; + size_t first_input_pos = argpos; + if (first_input_pos + nkeys > argc) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR number of input keys to AI.SCRIPTEXECUTE command does not match " + "the number of given arguments"); + goto cleanup; + } + // Verify given keys in local shard. 
+ for (; argpos < first_input_pos + nkeys; argpos++) { + if (!VerifyKeyInThisShard(ctx, argv[argpos])) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR CROSSSLOT Keys in AI.SCRIPTEXECUTE request don't hash to the " + "same slot"); + goto cleanup; + } + } + } + // argv[3] is not KEYS. + else { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR KEYS scope must be provided first for AI.SCRIPTEXECUTE command"); + goto cleanup; + } + } + + while (argpos < argc) { + const char *arg_string = RedisModule_StringPtrLen(argv[argpos], NULL); + // See that no additional KEYS scope is provided. + if (!strcasecmp(arg_string, "KEYS")) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Already Encountered KEYS scope in current command"); + goto cleanup; + } + // Parse timeout arg if given and store it in timeout. + if (!strcasecmp(arg_string, "TIMEOUT")) { + argpos++; + if (argpos >= argc) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR No value provided for TIMEOUT in AI.SCRIPTEXECUTE"); + goto cleanup; + } + if (ParseTimeout(argv[argpos], error, timeout) == REDISMODULE_ERR) + goto cleanup; + // No other arguments expected after timeout. + break; + } + + if (!strcasecmp(arg_string, "INPUTS")) { + // Check for already given inputs. + if (inputsDone) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Already Encountered INPUTS scope in AI.SCRIPTEXECUTE command"); + goto cleanup; + } + inputsDone = true; + // Read input number. + argpos++; + if (argpos >= argc) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Invalid arguments provided to AI.SCRIPTEXECUTE"); + goto cleanup; + } + long long ninputs; + if (RedisModule_StringToLongLong(argv[argpos], &ninputs) != REDISMODULE_OK) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Invalid argument for input count in AI.SCRIPTEXECUTE"); + goto cleanup; + } + // Check validity of input numbers.
+ argpos++; + size_t first_input_pos = argpos; + if (first_input_pos + ninputs > argc) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR number of input keys to AI.SCRIPTEXECUTE command does not match " + "the number of given arguments"); + goto cleanup; + } + // Add to local input context. + for (; argpos < first_input_pos + ninputs; argpos++) { + inputs = array_append(inputs, RAI_HoldString(argv[argpos])); + } + continue; + } + if (!strcasecmp(arg_string, "OUTPUTS")) { + // Check for already given outputs. + if (outputsDone) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Already Encountered OUTPUTS scope in AI.SCRIPTEXECUTE command"); + goto cleanup; + } + // Update mask. + outputsDone = true; + // Read output number. + argpos++; + if (argpos >= argc) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Invalid arguments provided to AI.SCRIPTEXECUTE"); + goto cleanup; + } + long long noutputs; + if (RedisModule_StringToLongLong(argv[argpos], &noutputs) != REDISMODULE_OK) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Invalid argument for output count in AI.SCRIPTEXECUTE"); + goto cleanup; + } + // Check validity of output numbers. + argpos++; + size_t first_output_pos = argpos; + if (first_output_pos + noutputs > argc) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR number of output keys to AI.SCRIPTEXECUTE command does not match " + "the number of given arguments"); + goto cleanup; + } + for (; argpos < first_output_pos + noutputs; argpos++) { + *outkeys = array_append(*outkeys, RAI_HoldString(argv[argpos])); + } + continue; + } + if (!strcasecmp(arg_string, "LIST_INPUTS")) { + // Read list size. + argpos++; + if (argpos >= argc) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Invalid arguments provided to AI.SCRIPTEXECUTE"); + goto cleanup; + } + long long ninputs; + if (RedisModule_StringToLongLong(argv[argpos], &ninputs) != REDISMODULE_OK) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Invalid argument for list input count in AI.SCRIPTEXECUTE"); + goto cleanup; + } + // Check validity of current list size.
+ argpos++; + if (argpos >= argc) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Invalid arguments provided to AI.SCRIPTEXECUTE"); + goto cleanup; + } + size_t first_input_pos = argpos; + if (first_input_pos + ninputs > argc) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR number of list input keys to AI.SCRIPTEXECUTE command does not " + "match the number of given arguments"); + goto cleanup; + } + for (; argpos < first_input_pos + ninputs; argpos++) { + inputs = array_append(inputs, RAI_HoldString(argv[argpos])); + } + RAI_ScriptRunCtxAddListSize(sctx, ninputs); + continue; + } + + RAI_SetError(error, RAI_ESCRIPTRUN, "ERR Unrecognized parameter to AI.SCRIPTEXECUTE"); + goto cleanup; + } + if (argpos != argc) { + RAI_SetError(error, RAI_ESCRIPTRUN, "ERR Encountered problem parsing AI.SCRIPTEXECUTE"); + goto cleanup; + } + + if (!_Script_buildInputsBySchema(sctx, inputs, inkeys, error)) { + goto cleanup; + } + for (size_t i = 0; i < array_len(inputs); i++) { + RedisModule_FreeString(ctx, inputs[i]); + } + array_free(inputs); + + return REDISMODULE_OK; +cleanup: + for (size_t i = 0; i < array_len(inputs); i++) { + RedisModule_FreeString(ctx, inputs[i]); + } + array_free(inputs); + return REDISMODULE_ERR; +} + +int ParseScriptExecuteCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, + RedisModuleString **argv, int argc) { + + RAI_Error *error = rinfo->err; + if (argc < 3) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR wrong number of arguments for 'AI.SCRIPTEXECUTE' command"); + return REDISMODULE_ERR; + } + + int res = REDISMODULE_ERR; + // Build a ScriptRunCtx from command.
+ RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(NULL); + + RAI_Script *script = NULL; + RAI_ScriptRunCtx *sctx = NULL; + RedisModuleString *scriptName = argv[1]; + RAI_GetScriptFromKeyspace(ctx, scriptName, &script, REDISMODULE_READ, error); + if (!script) { + goto cleanup; + } + + RAI_DagOpSetRunKey(currentOp, RAI_HoldString(argv[1])); + + const char *func_name = ScriptCommand_GetFunctionName(argv[2]); + if (!func_name) { + RAI_SetError(rinfo->err, RAI_ESCRIPTRUN, "ERR function name not specified"); + goto cleanup; + } + + sctx = RAI_ScriptRunCtxCreate(script, func_name); + long long timeout = 0; + if (_ScriptExecuteCommand_ParseArgs(ctx, argv, argc, error, &currentOp->inkeys, + &currentOp->outkeys, sctx, &timeout, + rinfo->single_op_dag) == REDISMODULE_ERR) { + goto cleanup; + } + if (timeout > 0 && !rinfo->single_op_dag) { + RAI_SetError(error, RAI_EDAGBUILDER, "ERR TIMEOUT not allowed within a DAG command"); + goto cleanup; + } + + if (rinfo->single_op_dag) { + rinfo->timeout = timeout; + // Set params in ScriptRunCtx, bring inputs from key space.
+ if (ScriptRunCtx_SetParams(ctx, currentOp->inkeys, currentOp->outkeys, sctx, error) == + REDISMODULE_ERR) + goto cleanup; + } + res = REDISMODULE_OK; + RedisModule_FreeThreadSafeContext(ctx); + currentOp->sctx = sctx; + currentOp->commandType = REDISAI_DAG_CMD_SCRIPTRUN; + currentOp->devicestr = sctx->script->devicestr; + return res; + +cleanup: + RedisModule_FreeThreadSafeContext(ctx); + if (sctx) { + RAI_ScriptRunCtxFree(sctx); + } + return res; +} diff --git a/src/execution/parsing/script_commands_parser.h b/src/execution/parsing/script_commands_parser.h new file mode 100644 index 000000000..50ace86a8 --- /dev/null +++ b/src/execution/parsing/script_commands_parser.h @@ -0,0 +1,13 @@ +#pragma once +#include "redismodule.h" +#include "execution/run_info.h" + +/** + * @brief Parse and validate SCRIPTEXECUTE command: create a scriptRunCtx based on the script + * obtained from the key space and the function name given, and save it in the op. The keys of the + * input and output tensors are stored in the op's inkeys and outkeys arrays, the script key is + * saved in op's runkey, and the given timeout is saved as well (if given, otherwise it is zero). + * @return Returns REDISMODULE_OK if the command is valid, REDISMODULE_ERR otherwise. 
+ */ +int ParseScriptExecuteCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, + RedisModuleString **argv, int argc); diff --git a/src/execution/run_info.c b/src/execution/run_info.c index 970a676c5..4bb5672e7 100644 --- a/src/execution/run_info.c +++ b/src/execution/run_info.c @@ -11,12 +11,12 @@ #include "redismodule.h" #include "redis_ai_objects/err.h" #include "redis_ai_objects/model.h" -#include "execution/modelRun_ctx.h" +#include "execution/execution_contexts/modelRun_ctx.h" #include "redis_ai_objects/script.h" #include "redis_ai_objects/tensor.h" #include "redis_ai_objects/model_struct.h" #include "util/arr.h" -#include "util/dict.h" +#include "util/dictionaries.h" #include "util/string_utils.h" static void RAI_TensorDictValFree(void *privdata, void *obj) { @@ -32,36 +32,6 @@ AI_dictType AI_dictTypeTensorVals = { .valDestructor = RAI_TensorDictValFree, }; -/** - * Allocate the memory and initialise the RAI_DagOp. - * @param result Output parameter to capture allocated RAI_DagOp. - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if the allocation - * failed. - */ -int RAI_InitDagOp(RAI_DagOp **result) { - RAI_DagOp *dagOp; - dagOp = (RAI_DagOp *)RedisModule_Calloc(1, sizeof(RAI_DagOp)); - - dagOp->commandType = REDISAI_DAG_CMD_NONE; - dagOp->runkey = NULL; - dagOp->inkeys = (RedisModuleString **)array_new(RedisModuleString *, 1); - dagOp->outkeys = (RedisModuleString **)array_new(RedisModuleString *, 1); - dagOp->inkeys_indices = array_new(size_t, 1); - dagOp->outkeys_indices = array_new(size_t, 1); - dagOp->outTensor = NULL; - dagOp->mctx = NULL; - dagOp->sctx = NULL; - dagOp->devicestr = NULL; - dagOp->duration_us = 0; - dagOp->result = -1; - RAI_InitError(&dagOp->err); - dagOp->argv = NULL; - dagOp->argc = 0; - - *result = dagOp; - return REDISMODULE_OK; -} - /** * Allocate the memory and initialise the RedisAI_RunInfo. * @param result Output parameter to capture allocated RedisAI_RunInfo. 
@@ -105,40 +75,6 @@ int RAI_ShallowCopyDagRunInfo(RedisAI_RunInfo **result, RedisAI_RunInfo *src) { return REDISMODULE_OK; } -void RAI_FreeDagOp(RAI_DagOp *dagOp) { - - RAI_FreeError(dagOp->err); - if (dagOp->runkey) - RedisModule_FreeString(NULL, dagOp->runkey); - - if (dagOp->outTensor) - RAI_TensorFree(dagOp->outTensor); - - if (dagOp->mctx) { - RAI_ModelRunCtxFree(dagOp->mctx); - } - if (dagOp->sctx) { - RAI_ScriptRunCtxFree(dagOp->sctx); - } - - if (dagOp->inkeys) { - for (size_t i = 0; i < array_len(dagOp->inkeys); i++) { - RedisModule_FreeString(NULL, dagOp->inkeys[i]); - } - array_free(dagOp->inkeys); - } - array_free(dagOp->inkeys_indices); - - if (dagOp->outkeys) { - for (size_t i = 0; i < array_len(dagOp->outkeys); i++) { - RedisModule_FreeString(NULL, dagOp->outkeys[i]); - } - array_free(dagOp->outkeys); - } - array_free(dagOp->outkeys_indices); - RedisModule_Free(dagOp); -} - long long RAI_DagRunInfoFreeShallowCopy(RedisAI_RunInfo *rinfo) { long long ref_count = __atomic_sub_fetch(rinfo->dagRefCount, 1, __ATOMIC_RELAXED); RedisModule_Assert(ref_count >= 0 && "Tried to free the original RunInfo object"); @@ -205,84 +141,6 @@ void RAI_ContextUnlock(RedisAI_RunInfo *rinfo) { pthread_rwlock_unlock(rinfo->dagLock); } -size_t RAI_RunInfoBatchSize(struct RAI_DagOp *op) { - if (op->mctx == NULL) { - return -1; - } - - size_t ninputs = RAI_ModelRunCtxNumInputs(op->mctx); - - int batchsize = 0; - - if (ninputs == 0) { - return batchsize; - } - - for (size_t i = 0; i < ninputs; i++) { - RAI_Tensor *input = RAI_ModelRunCtxInputTensor(op->mctx, i); - - if (i == 0) { - batchsize = RAI_TensorDim(input, 0); - continue; - } - - if (batchsize != RAI_TensorDim(input, 0)) { - batchsize = 0; - break; - } - } - - return batchsize; -} - -int RAI_RunInfoBatchable(struct RAI_DagOp *op1, struct RAI_DagOp *op2) { - - if (op1->mctx == NULL || op2->mctx == NULL) { - return 0; - } - - if (op1->mctx->model != op2->mctx->model) { - return 0; - } - - const int ninputs1 = 
RAI_ModelRunCtxNumInputs(op1->mctx); - const int ninputs2 = RAI_ModelRunCtxNumInputs(op2->mctx); - - if (ninputs1 != ninputs2) { - return 0; - } - - for (int i = 0; i < ninputs1; i++) { - RAI_Tensor *input1 = RAI_ModelRunCtxInputTensor(op1->mctx, i); - RAI_Tensor *input2 = RAI_ModelRunCtxInputTensor(op2->mctx, i); - - int ndims1 = RAI_TensorNumDims(input1); - int ndims2 = RAI_TensorNumDims(input2); - - if (!RAI_TensorIsDataTypeEqual(input1, input2)) { - return 0; - } - - if (ndims1 != ndims2) { - return 0; - } - - if (ndims1 == 0) { - continue; - } - - for (int j = 1; j < ndims1; j++) { - int dim1 = RAI_TensorDim(input1, j); - int dim2 = RAI_TensorDim(input2, j); - if (dim1 != dim2) { - return 0; - } - } - } - - return 1; -} - RAI_ModelRunCtx *RAI_GetAsModelRunCtx(RedisAI_RunInfo *rinfo, RAI_Error *err) { RAI_DagOp *op = rinfo->dagOps[0]; diff --git a/src/execution/run_info.h b/src/execution/run_info.h index 4af7f192b..8d3edf84a 100644 --- a/src/execution/run_info.h +++ b/src/execution/run_info.h @@ -11,60 +11,14 @@ #include "redismodule.h" #include "redis_ai_objects/err.h" -#include "redis_ai_objects/model.h" -#include "redis_ai_objects/script.h" -#include "redis_ai_objects/model_struct.h" +#include "execution/DAG/dag_op.h" #include "util/arr.h" #include "util/dict.h" -typedef enum DAGCommand { - REDISAI_DAG_CMD_NONE = 0, - REDISAI_DAG_CMD_TENSORSET, - REDISAI_DAG_CMD_TENSORGET, - REDISAI_DAG_CMD_MODELRUN, - REDISAI_DAG_CMD_SCRIPTRUN -} DAGCommand; - -enum RedisAI_DAGMode { REDISAI_DAG_READONLY_MODE = 0, REDISAI_DAG_WRITE_MODE }; - -typedef struct RAI_DagOp { - int commandType; - RedisModuleString *runkey; - RedisModuleString **inkeys; - RedisModuleString **outkeys; - size_t *inkeys_indices; - size_t *outkeys_indices; - RAI_Tensor *outTensor; // The tensor to upload in TENSORSET op. - RAI_ModelRunCtx *mctx; - RAI_ScriptRunCtx *sctx; - uint fmt; // This is relevant for TENSORGET op. 
- char *devicestr; - int result; // REDISMODULE_OK or REDISMODULE_ERR - long long duration_us; - RAI_Error *err; - RedisModuleString **argv; - int argc; -} RAI_DagOp; - #ifdef __cplusplus extern "C" { #endif -/** - * Allocate the memory and initialise the RAI_DagOp. - * @param result Output parameter to capture allocated RAI_DagOp. - * @return REDISMODULE_OK on success, or REDISMODULE_ERR if the allocation - * failed. - */ -int RAI_InitDagOp(RAI_DagOp **result); - -/** - * Frees the memory allocated of RAI_DagOp - * @param ctx Context in which Redis modules operate - * @param RAI_DagOp context in which RedisAI command operates. - */ -void RAI_FreeDagOp(RAI_DagOp *dagOp); - typedef struct RedisAI_RunInfo RedisAI_RunInfo; /** @@ -164,24 +118,6 @@ void RAI_ContextWriteLock(RedisAI_RunInfo *rinfo); */ void RAI_ContextUnlock(RedisAI_RunInfo *rinfo); -/** - * Obtain the batch size for the provided DAG operation, that is, the - * size of the tensor in the zero-th dimension - * @param op DAG operation to operate on - * @return size of the batch for op - */ -size_t RAI_RunInfoBatchSize(struct RAI_DagOp *op); - -/** - * Find out whether two DAG operations are batchable. That means they must be - * two MODELRUN operations with the same model, where respective inputs have - * compatible shapes (all dimensions except the zero-th must match) - * @param op1 first DAG operation - * @param op2 second DAG operation - * @return 1 if batchable, 0 otherwise - */ -int RAI_RunInfoBatchable(struct RAI_DagOp *op1, struct RAI_DagOp *op2); - /** * Retreive the ModelRunCtx of a DAG runInfo that contains a single op of type * MODELRUN. 
diff --git a/src/execution/utils.c b/src/execution/utils.c index 3f754d36b..a4405dc53 100644 --- a/src/execution/utils.c +++ b/src/execution/utils.c @@ -58,7 +58,5 @@ bool VerifyKeyInThisShard(RedisModuleCtx *ctx, RedisModuleString *key_str) { } } } - RedisModule_Log(ctx, "warning", "%s doesn't exist in keyspace", - RedisModule_StringPtrLen(key_str, NULL)); return true; } diff --git a/src/redis_ai_objects/model.c b/src/redis_ai_objects/model.c index afb2248da..25a0169f9 100644 --- a/src/redis_ai_objects/model.c +++ b/src/redis_ai_objects/model.c @@ -19,9 +19,8 @@ #include "util/arr.h" #include "util/dict.h" #include "util/string_utils.h" -#include "execution/run_info.h" -#include "execution/DAG/dag.h" -#include "execution/utils.h" + +extern RedisModuleType *RedisAI_ModelType; RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModuleString *tag, RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs, @@ -60,7 +59,7 @@ RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModu if (model) { if (tag) { - model->tag = RAI_HoldString(NULL, tag); + model->tag = RAI_HoldString(tag); } else { model->tag = RedisModule_CreateString(NULL, "", 0); } @@ -109,58 +108,6 @@ void RAI_ModelFree(RAI_Model *model, RAI_Error *err) { RedisModule_Free(model); } -int RAI_ModelRun(RAI_ModelRunCtx **mctxs, long long n, RAI_Error *err) { - int ret; - - if (n == 0) { - RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Nothing to run"); - return REDISMODULE_ERR; - } - - RAI_ModelRunCtx **mctxs_arr = array_newlen(RAI_ModelRunCtx *, n); - for (int i = 0; i < n; i++) { - mctxs_arr[i] = mctxs[i]; - } - - switch (mctxs_arr[0]->model->backend) { - case RAI_BACKEND_TENSORFLOW: - if (!RAI_backends.tf.model_run) { - RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TF"); - return REDISMODULE_ERR; - } - ret = RAI_backends.tf.model_run(mctxs_arr, err); - break; - case RAI_BACKEND_TFLITE: - if (!RAI_backends.tflite.model_run) { - 
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TFLITE"); - return REDISMODULE_ERR; - } - ret = RAI_backends.tflite.model_run(mctxs_arr, err); - break; - case RAI_BACKEND_TORCH: - if (!RAI_backends.torch.model_run) { - RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TORCH"); - return REDISMODULE_ERR; - } - ret = RAI_backends.torch.model_run(mctxs_arr, err); - break; - case RAI_BACKEND_ONNXRUNTIME: - if (!RAI_backends.onnx.model_run) { - RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: ONNX"); - return REDISMODULE_ERR; - } - ret = RAI_backends.onnx.model_run(mctxs_arr, err); - break; - default: - RAI_SetError(err, RAI_EUNSUPPORTEDBACKEND, "ERR Unsupported backend"); - return REDISMODULE_ERR; - } - - array_free(mctxs_arr); - - return ret; -} - RAI_Model *RAI_ModelGetShallowCopy(RAI_Model *model) { __atomic_fetch_add(&model->refCount, 1, __ATOMIC_RELAXED); return model; @@ -297,32 +244,8 @@ int RAI_GetModelFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RA return REDISMODULE_OK; } -RedisModuleType *RAI_ModelRedisType(void) { return RedisAI_ModelType; } - size_t ModelGetNumInputs(RAI_Model *model) { return model->ninputs; } size_t ModelGetNumOutputs(RAI_Model *model) { return model->noutputs; } -int RAI_ModelRunAsync(RAI_ModelRunCtx *mctx, RAI_OnFinishCB ModelAsyncFinish, void *private_data) { - - RedisAI_RunInfo *rinfo = NULL; - RAI_InitRunInfo(&rinfo); - - rinfo->single_op_dag = 1; - rinfo->OnFinish = (RedisAI_OnFinishCB)ModelAsyncFinish; - rinfo->private_data = private_data; - - RAI_DagOp *op; - RAI_InitDagOp(&op); - op->commandType = REDISAI_DAG_CMD_MODELRUN; - op->devicestr = mctx->model->devicestr; - op->mctx = mctx; - - rinfo->dagOps = array_append(rinfo->dagOps, op); - rinfo->dagOpCount = 1; - if (DAG_InsertDAGToQueue(rinfo) != REDISMODULE_OK) { - RAI_FreeRunInfo(rinfo); - return REDISMODULE_ERR; - } - return REDISMODULE_OK; -} +RedisModuleType *RAI_ModelRedisType(void) { return 
RedisAI_ModelType; } diff --git a/src/redis_ai_objects/model.h b/src/redis_ai_objects/model.h index b87801198..5c2da3de9 100644 --- a/src/redis_ai_objects/model.h +++ b/src/redis_ai_objects/model.h @@ -17,16 +17,6 @@ #include "util/dict.h" #include "execution/run_info.h" -extern RedisModuleType *RedisAI_ModelType; - -/** - * Helper method to register the RedisModuleType type exported by the module. - * - * @param ctx Context in which Redis modules operate - * @return - */ -int RAI_ModelInit(RedisModuleCtx *ctx); - /** * Helper method to allocated and initialize a RAI_Model. Depending on the * backend it relies on either `model_create_with_nodes` or `model_create` @@ -62,24 +52,6 @@ RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModu */ void RAI_ModelFree(RAI_Model *model, RAI_Error *err); -/** - * Given the input array of mctxs, run the associated backend - * session. If the input array of model context runs is larger than one, then - * each backend's `model_run` is responsible for concatenating tensors, and run - * the model in batches with the size of the input array. On success, the - * tensors corresponding to outputs[0,noutputs-1] are placed in each - * RAI_ModelRunCtx output tensors array. Relies on each backend's `model_run` - * definition. - * - * @param mctxs array on input model contexts - * @param n length of input model contexts array - * @param error error data structure to store error message in the case of - * failures - * @return REDISMODULE_OK if the underlying backend `model_run` runned - * successfully, or REDISMODULE_ERR if failed. - */ -int RAI_ModelRun(RAI_ModelRunCtx **mctxs, long long n, RAI_Error *err); - /** * Every call to this function, will make the RAI_Model 'model' requiring an * additional call to RAI_ModelFree() in order to really free the model. 
@@ -142,12 +114,6 @@ int RedisAI_ModelRun_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx, Redis */ int ModelExecute_ReportKeysPositions(RedisModuleCtx *ctx, RedisModuleString **argv, int argc); -/** - * @brief Returns the redis module type representing a model. - * @return redis module type representing a model. - */ -RedisModuleType *RAI_ModelRedisType(void); - /** * @brief Returns the number of inputs in the model definition. */ @@ -157,14 +123,9 @@ size_t ModelGetNumInputs(RAI_Model *model); * @brief Returns the number of outputs in the model definition. */ size_t ModelGetNumOutputs(RAI_Model *model); + /** - * Insert the ModelRunCtx to the run queues so it will run asynchronously. - * - * @param mctx ModelRunCtx to execute - * @param ModelAsyncFinish A callback that will be called when the execution is finished. - * @param private_data This is going to be sent to to the ModelAsyncFinish. - * @return REDISMODULE_OK if the mctx was insert to the queues successfully, REDISMODULE_ERR - * otherwise. + * @brief Returns the redis module type representing a model. + * @return redis module type representing a model. 
*/ - -int RAI_ModelRunAsync(RAI_ModelRunCtx *mctx, RAI_OnFinishCB ModelAsyncFinish, void *private_data); +RedisModuleType *RAI_ModelRedisType(void); diff --git a/src/redis_ai_objects/script.c b/src/redis_ai_objects/script.c index aa90994aa..02039b739 100644 --- a/src/redis_ai_objects/script.c +++ b/src/redis_ai_objects/script.c @@ -18,6 +18,8 @@ #include "execution/DAG/dag.h" #include "execution/run_info.h" +extern RedisModuleType *RedisAI_ScriptType; + RAI_Script *RAI_ScriptCreate(const char *devicestr, RedisModuleString *tag, const char *scriptdef, RAI_Error *err) { if (!RAI_backends.torch.script_create) { @@ -28,7 +30,7 @@ RAI_Script *RAI_ScriptCreate(const char *devicestr, RedisModuleString *tag, cons if (script) { if (tag) { - script->tag = RAI_HoldString(NULL, tag); + script->tag = RAI_HoldString(tag); } else { script->tag = RedisModule_CreateString(NULL, "", 0); } @@ -54,98 +56,6 @@ void RAI_ScriptFree(RAI_Script *script, RAI_Error *err) { RAI_backends.torch.script_free(script, err); } -RAI_ScriptRunCtx *RAI_ScriptRunCtxCreate(RAI_Script *script, const char *fnname) { -#define PARAM_INITIAL_SIZE 10 - RAI_ScriptRunCtx *sctx = RedisModule_Calloc(1, sizeof(*sctx)); - sctx->script = RAI_ScriptGetShallowCopy(script); - sctx->inputs = array_new(RAI_ScriptCtxParam, PARAM_INITIAL_SIZE); - sctx->outputs = array_new(RAI_ScriptCtxParam, PARAM_INITIAL_SIZE); - sctx->fnname = RedisModule_Strdup(fnname); - sctx->variadic = -1; - return sctx; -} - -static int _Script_RunCtxAddParam(RAI_ScriptCtxParam **paramArr, RAI_Tensor *tensor) { - RAI_ScriptCtxParam param = { - .tensor = tensor ? 
RAI_TensorGetShallowCopy(tensor) : NULL, - }; - *paramArr = array_append(*paramArr, param); - return 1; -} - -int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor, RAI_Error *error) { - // Even if variadic is set, we still allow to add inputs in the LLAPI - return _Script_RunCtxAddParam(&sctx->inputs, inputTensor); -} - -int RAI_ScriptRunCtxAddInputList(RAI_ScriptRunCtx *sctx, RAI_Tensor **inputTensors, size_t len, - RAI_Error *err) { - if (sctx->variadic == -1) { - sctx->variadic = array_len(sctx->inputs); - } else { - RAI_SetError(err, RAI_EBACKENDNOTLOADED, - "ERR Already encountered a variable size list of tensors"); - return 0; - } - - int res; - for (size_t i = 0; i < len; i++) { - res = _Script_RunCtxAddParam(&sctx->inputs, inputTensors[i]); - if (res != 1) - return res; - } - return 1; -} - -int RAI_ScriptRunCtxAddOutput(RAI_ScriptRunCtx *sctx) { - return _Script_RunCtxAddParam(&sctx->outputs, NULL); -} - -size_t RAI_ScriptRunCtxNumOutputs(RAI_ScriptRunCtx *sctx) { return array_len(sctx->outputs); } - -RAI_Tensor *RAI_ScriptRunCtxOutputTensor(RAI_ScriptRunCtx *sctx, size_t index) { - assert(RAI_ScriptRunCtxNumOutputs(sctx) > index && index >= 0); - return sctx->outputs[index].tensor; -} - -void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx *sctx) { - - for (size_t i = 0; i < array_len(sctx->inputs); ++i) { - RAI_TensorFree(sctx->inputs[i].tensor); - } - - for (size_t i = 0; i < array_len(sctx->outputs); ++i) { - if (sctx->outputs[i].tensor) { - RAI_TensorFree(sctx->outputs[i].tensor); - } - } - - array_free(sctx->inputs); - array_free(sctx->outputs); - - RedisModule_Free(sctx->fnname); - - RAI_Error err = {0}; - RAI_ScriptFree(sctx->script, &err); - - if (err.code != RAI_OK) { - // TODO: take it to client somehow - printf("ERR: %s\n", err.detail); - RAI_ClearError(&err); - } - - RedisModule_Free(sctx); -} - -int RAI_ScriptRun(RAI_ScriptRunCtx *sctx, RAI_Error *err) { - if (!RAI_backends.torch.script_run) { - RAI_SetError(err, 
RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TORCH"); - return REDISMODULE_ERR; - } - - return RAI_backends.torch.script_run(sctx, err); -} - RAI_Script *RAI_ScriptGetShallowCopy(RAI_Script *script) { __atomic_fetch_add(&script->refCount, 1, __ATOMIC_RELAXED); return script; @@ -157,10 +67,23 @@ RAI_Script *RAI_ScriptGetShallowCopy(RAI_Script *script) { int RAI_GetScriptFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RAI_Script **script, int mode, RAI_Error *err) { RedisModuleKey *key = RedisModule_OpenKey(ctx, keyName, mode); + if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) { RedisModule_CloseKey(key); +#ifndef LITE + RedisModule_Log(ctx, "warning", "could not load %s from keyspace, key doesn't exist", + RedisModule_StringPtrLen(keyName, NULL)); RAI_SetError(err, RAI_EKEYEMPTY, "ERR script key is empty"); return REDISMODULE_ERR; +#else + if (VerifyKeyInThisShard(ctx, keyName)) { // Relevant for enterprise cluster. + RAI_SetError(err, RAI_EKEYEMPTY, "ERR script key is empty"); + } else { + RAI_SetError(err, RAI_EKEYEMPTY, + "ERR CROSSSLOT Keys in request don't hash to the same slot"); + } +#endif + return REDISMODULE_ERR; } if (RedisModule_ModuleTypeGetType(key) != RedisAI_ScriptType) { RedisModule_CloseKey(key); @@ -197,30 +120,60 @@ int RedisAI_ScriptRun_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx, return REDISMODULE_OK; } -RedisModuleType *RAI_ScriptRedisType(void) { return RedisAI_ScriptType; } - -int RAI_ScriptRunAsync(RAI_ScriptRunCtx *sctx, RAI_OnFinishCB ScriptAsyncFinish, - void *private_data) { - - RedisAI_RunInfo *rinfo = NULL; - RAI_InitRunInfo(&rinfo); - - rinfo->single_op_dag = 1; - rinfo->OnFinish = (RedisAI_OnFinishCB)ScriptAsyncFinish; - rinfo->private_data = private_data; - - RAI_DagOp *op; - RAI_InitDagOp(&op); - - op->commandType = REDISAI_DAG_CMD_SCRIPTRUN; - op->devicestr = sctx->script->devicestr; - op->sctx = sctx; - - rinfo->dagOps = array_append(rinfo->dagOps, op); - rinfo->dagOpCount = 1; - if 
(DAG_InsertDAGToQueue(rinfo) != REDISMODULE_OK) {
-        RAI_FreeRunInfo(rinfo);
+int RedisAI_ScriptExecute_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx,
+                                                           RedisModuleString **argv, int argc) {
+    // AI.SCRIPTEXECUTE script_name func KEYS n key....
+    if (argc < 6) {
         return REDISMODULE_ERR;
     }
-    return REDISMODULE_OK;
+    RedisModule_KeyAtPos(ctx, 1);
+    size_t argpos = 3;
+    long long count;
+    while (argpos < argc) {
+        const char *str = RedisModule_StringPtrLen(argv[argpos++], NULL);
+
+        // Inputs, outputs, keys, lists.
+        if ((!strcasecmp(str, "INPUTS")) || (!strcasecmp(str, "OUTPUTS")) ||
+            (!strcasecmp(str, "LIST_INPUTS")) || (!strcasecmp(str, "KEYS"))) {
+            bool updateKeyAtPos = false;
+            // The only scopes in which the input strings are guaranteed to be keys are the
+            // KEYS and OUTPUTS scopes.
+            if ((!strcasecmp(str, "OUTPUTS")) || (!strcasecmp(str, "KEYS"))) {
+                updateKeyAtPos = true;
+            }
+            if (argpos >= argc) {
+                return REDISMODULE_ERR;
+            }
+            if (RedisModule_StringToLongLong(argv[argpos++], &count) != REDISMODULE_OK) {
+                return REDISMODULE_ERR;
+            }
+            if (count <= 0) {
+                return REDISMODULE_ERR;
+            }
+            if (argpos + count >= argc) {
+                return REDISMODULE_ERR;
+            }
+            for (long long i = 0; i < count; i++) {
+                if (updateKeyAtPos) {
+                    RedisModule_KeyAtPos(ctx, argpos);
+                }
+                argpos++;
+            }
+            continue;
+        }
+        // Timeout
+        if (!strcasecmp(str, "TIMEOUT")) {
+            argpos++;
+            break;
+        }
+        // Undefined input.
+ return REDISMODULE_ERR; + } + if (argpos != argc) { + return REDISMODULE_ERR; + } else { + return REDISMODULE_OK; + } } + +RedisModuleType *RAI_ScriptRedisType(void) { return RedisAI_ScriptType; } diff --git a/src/redis_ai_objects/script.h b/src/redis_ai_objects/script.h index 0864714f9..f070124a7 100644 --- a/src/redis_ai_objects/script.h +++ b/src/redis_ai_objects/script.h @@ -9,21 +9,9 @@ #pragma once #include "err.h" -#include "redismodule.h" -#include "script_struct.h" #include "tensor.h" -#include "config/config.h" -#include "execution/run_info.h" - -extern RedisModuleType *RedisAI_ScriptType; - -/** - * Helper method to register the script type exported by the module. - * - * @param ctx Context in which Redis modules operate - * @return - */ -int RAI_ScriptInit(RedisModuleCtx *ctx); +#include "script_struct.h" +#include "redismodule.h" /** * Helper method to allocated and initialize a RAI_Script. Relies on Pytorch @@ -49,89 +37,6 @@ RAI_Script *RAI_ScriptCreate(const char *devicestr, RedisModuleString *tag, cons */ void RAI_ScriptFree(RAI_Script *script, RAI_Error *err); -/** - * Allocates the RAI_ScriptRunCtx data structure required for async background - * work within `RedisAI_RunInfo` structure on RedisAI blocking commands - * - * @param script input script - * @param fnname function name to used from the script - * @return RAI_ScriptRunCtx to be used within - */ -RAI_ScriptRunCtx *RAI_ScriptRunCtxCreate(RAI_Script *script, const char *fnname); - -/** - * Allocates a RAI_ScriptCtxParam data structure, and enforces a shallow copy of - * the provided input tensor, adding it to the input tensors array of the - * RAI_ScriptRunCtx. - * - * @param sctx input RAI_ScriptRunCtx to add the input tensor - * @param inputTensor input tensor structure - * @return returns 1 on success, 0 in case of error. 
- */ -int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor, RAI_Error *error); - -/** - * For each Allocates a RAI_ScriptCtxParam data structure, and enforces a - * shallow copy of the provided input tensor, adding it to the input tensors - * array of the RAI_ScriptRunCtx. - * - * @param sctx input RAI_ScriptRunCtx to add the input tensor - * @param inputTensors input tensors array - * @param len input tensors array len - * @return returns 1 on success, 0 in case of error. - */ -int RAI_ScriptRunCtxAddInputList(RAI_ScriptRunCtx *sctx, RAI_Tensor **inputTensors, size_t len, - RAI_Error *error); - -/** - * Allocates a RAI_ScriptCtxParam data structure, and sets the tensor reference - * to NULL ( will be set after SCRIPTRUN ), adding it to the outputs tensors - * array of the RAI_ScriptRunCtx. - * - * @param sctx input RAI_ScriptRunCtx to add the output tensor - * @return returns 1 on success ( always returns success ) - */ -int RAI_ScriptRunCtxAddOutput(RAI_ScriptRunCtx *sctx); - -/** - * Returns the total number of output tensors of the RAI_ScriptRunCtx - * - * @param sctx RAI_ScriptRunCtx - * @return the total number of output tensors of the RAI_ScriptRunCtx - */ -size_t RAI_ScriptRunCtxNumOutputs(RAI_ScriptRunCtx *sctx); - -/** - * Get the RAI_Tensor at the output array index position - * - * @param sctx RAI_ScriptRunCtx - * @param index input array index position - * @return RAI_Tensor - */ -RAI_Tensor *RAI_ScriptRunCtxOutputTensor(RAI_ScriptRunCtx *sctx, size_t index); - -/** - * Frees the RAI_ScriptRunCtx data structure used within for async background - * work - * - * @param sctx - */ -void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx *sctx); - -/** - * Given the input script context, run associated script - * session. On success, the tensors corresponding to outputs[0,noutputs-1] are - * placed in the RAI_ScriptRunCtx output tensors array. Relies on PyTorch's - * `script_run` definition. 
- * - * @param sctx input script context - * @param error error data structure to store error message in the case of - * failures - * @return REDISMODULE_OK if the underlying backend `script_run` ran - * successfully, or REDISMODULE_ERR if failed. - */ -int RAI_ScriptRun(RAI_ScriptRunCtx *sctx, RAI_Error *err); - /** * Every call to this function, will make the RAI_Script 'script' requiring an * additional call to RAI_ScriptFree() in order to really free the script. @@ -176,31 +81,23 @@ int RAI_GetScriptFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, R int RedisAI_ScriptRun_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx, RedisModuleString **argv, int argc); -#if 0 /** - * Helper method to reply if the ctx is not NULL or fallback and set the error in the RAI_Error structure + * When a module command is called in order to obtain the position of + * keys, since it was flagged as "getkeys-api" during the registration, + * the command implementation checks for this special call using the + * RedisModule_IsKeysPositionRequest() API and uses this function in + * order to report keys. + * No real execution is done on this special call. * @param ctx Context in which Redis modules operate - * @param error the RAI_Error data structure to be populated with the error details in case ctx is NULL - * @param code the error code - * @param errorMessage the error detail + * @param argv Redis command arguments, as an array of strings + * @param argc Redis command number of arguments + * @return */ -void RedisAI_ReplyOrSetError(RedisModuleCtx *ctx, RAI_Error *error, RAI_ErrorCode code, const char* errorMessage ); -#endif +int RedisAI_ScriptExecute_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx, + RedisModuleString **argv, int argc); /** * @brief Returns the redis module type representing a script. * @return redis module type representing a script. 
*/ RedisModuleType *RAI_ScriptRedisType(void); - -/** - * Insert the ScriptRunCtx to the run queues so it will run asynchronously. - * - * @param sctx SodelRunCtx to execute - * @param ScriptAsyncFinish A callback that will be called when the execution is finished. - * @param private_data This is going to be sent to to the ScriptAsyncFinish. - * @return REDISMODULE_OK if the sctx was insert to the queues successfully, REDISMODULE_ERR - * otherwise. - */ -int RAI_ScriptRunAsync(RAI_ScriptRunCtx *sctx, RAI_OnFinishCB ScriptAsyncFinish, - void *private_data); diff --git a/src/redis_ai_objects/script_struct.h b/src/redis_ai_objects/script_struct.h index a9759b2fa..c069ba65a 100644 --- a/src/redis_ai_objects/script_struct.h +++ b/src/redis_ai_objects/script_struct.h @@ -2,6 +2,19 @@ #include "config/config.h" #include "tensor_struct.h" +#include "util/dict.h" + +typedef enum { + UNKOWN, + TENSOR, + INT, + FLOAT, + STRING, + TENSOR_LIST, + INT_LIST, + FLOAT_LIST, + STRING_LIST +} TorchScriptFunctionArgumentType; typedef struct RAI_Script { void *script; @@ -14,6 +27,8 @@ typedef struct RAI_Script { RedisModuleString *tag; long long refCount; void *infokey; + AI_dict *functionData; // A dict to map between + // function name, and its schema. 
} RAI_Script; typedef struct RAI_ScriptCtxParam { @@ -26,5 +41,8 @@ typedef struct RAI_ScriptRunCtx { char *fnname; RAI_ScriptCtxParam *inputs; RAI_ScriptCtxParam *outputs; - int variadic; + size_t *listSizes; + int32_t *intInputs; + float *floatInputs; + RedisModuleString **stringInputs; } RAI_ScriptRunCtx; diff --git a/src/redis_ai_objects/stats.c b/src/redis_ai_objects/stats.c index 6e1c7c9d2..eaa51de23 100644 --- a/src/redis_ai_objects/stats.c +++ b/src/redis_ai_objects/stats.c @@ -27,11 +27,11 @@ void *RAI_AddStatsEntry(RedisModuleCtx *ctx, RedisModuleString *key, RAI_RunType RAI_Backend backend, const char *devicestr, RedisModuleString *tag) { struct RedisAI_RunStats *rstats = NULL; rstats = RedisModule_Calloc(1, sizeof(struct RedisAI_RunStats)); - rstats->key = RAI_HoldString(NULL, key); + rstats->key = RAI_HoldString(key); rstats->type = runtype; rstats->backend = backend; rstats->devicestr = RedisModule_Strdup(devicestr); - rstats->tag = RAI_HoldString(NULL, tag); + rstats->tag = RAI_HoldString(tag); AI_dictAdd(run_stats, (void *)key, (void *)rstats); diff --git a/src/redis_ai_objects/tensor.c b/src/redis_ai_objects/tensor.c index 7cd449f13..1cc648987 100644 --- a/src/redis_ai_objects/tensor.c +++ b/src/redis_ai_objects/tensor.c @@ -22,6 +22,8 @@ #include "util/string_utils.h" #include "execution/utils.h" +extern RedisModuleType *RedisAI_TensorType; + DLDataType RAI_TensorDataTypeFromString(const char *typestr) { if (strcasecmp(typestr, RAI_DATATYPE_STR_FLOAT) == 0) { return (DLDataType){.code = kDLFloat, .bits = 32, .lanes = 1}; @@ -193,7 +195,7 @@ RAI_Tensor *_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, size_t dtype } char *data = RedisModule_Alloc(nbytes); memcpy(data, blob, nbytes); - RAI_HoldString(NULL, rstr); + RAI_HoldString(rstr); RAI_Tensor *ret = RAI_TensorNew(); ret->tensor = (DLManagedTensor){.dl_tensor = (DLTensor){.device = device, @@ -559,10 +561,10 @@ int RAI_GetTensorFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, 
R
     if (RedisModule_KeyType(*key) == REDISMODULE_KEYTYPE_EMPTY) {
         RedisModule_CloseKey(*key);
         if (VerifyKeyInThisShard(ctx, keyName)) { // Relevant for enterprise cluster.
-            RAI_SetError(err, RAI_EKEYEMPTY, "ERR tensor key is empty");
+            RAI_SetError(err, RAI_EKEYEMPTY, "ERR tensor key is empty or in a different shard");
         } else {
             RAI_SetError(err, RAI_EKEYEMPTY,
-                         "ERR CROSSSLOT Keys in request don't hash to the same slot");
+                         "ERR CROSSSLOT Tensor key in request doesn't hash to the same slot");
         }
         return REDISMODULE_ERR;
     }
diff --git a/src/redis_ai_objects/tensor.h b/src/redis_ai_objects/tensor.h
index 00a0b2ba0..f220ba25c 100644
--- a/src/redis_ai_objects/tensor.h
+++ b/src/redis_ai_objects/tensor.h
@@ -38,8 +38,6 @@ static const char *RAI_DATATYPE_STR_UINT16 = "UINT16";
 #define TENSOR_BLOB (1 << 2)
 #define TENSOR_ILLEGAL_VALUES_BLOB (TENSOR_VALUES | TENSOR_BLOB)
-extern RedisModuleType *RedisAI_TensorType;
-
 /**
  * Helper method to register the tensor type exported by the module.
  *
diff --git a/src/redis_ai_types/model_type.h b/src/redis_ai_types/model_type.h
index 14723ca39..72913d004 100644
--- a/src/redis_ai_types/model_type.h
+++ b/src/redis_ai_types/model_type.h
@@ -2,4 +2,4 @@
 #include "redismodule.h"
-int ModelType_Register(RedisModuleCtx *ctx);
\ No newline at end of file
+int ModelType_Register(RedisModuleCtx *ctx);
diff --git a/src/redis_ai_types/script_type.h b/src/redis_ai_types/script_type.h
index 1f91c0bdd..f1d5825e6 100644
--- a/src/redis_ai_types/script_type.h
+++ b/src/redis_ai_types/script_type.h
@@ -2,4 +2,4 @@
 #include "redismodule.h"
-int ScriptType_Register(RedisModuleCtx *ctx);
\ No newline at end of file
+int ScriptType_Register(RedisModuleCtx *ctx);
diff --git a/src/redisai.c b/src/redisai.c
index e0f308b50..c2a81818a 100644
--- a/src/redisai.c
+++ b/src/redisai.c
@@ -14,9 +14,10 @@
 #include "execution/DAG/dag.h"
 #include "execution/DAG/dag_builder.h"
 #include "execution/DAG/dag_execute.h"
-#include "execution/deprecated.h"
+#include
"execution/parsing/deprecated.h" #include "redis_ai_objects/model.h" -#include "execution/modelRun_ctx.h" +#include "execution/execution_contexts/modelRun_ctx.h" +#include "execution/execution_contexts/scriptRun_ctx.h" #include "redis_ai_objects/script.h" #include "redis_ai_objects/stats.h" #include @@ -29,7 +30,7 @@ #include "rmutil/alloc.h" #include "rmutil/args.h" #include "util/arr.h" -#include "util/dict.h" +#include "util/dictionaries.h" #include "util/string_utils.h" #include "util/queue.h" #include "version.h" @@ -91,7 +92,7 @@ int RedisAI_TensorSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv return REDISMODULE_ERR; } - if (RedisModule_ModuleTypeSetValue(key, RedisAI_TensorType, t) != REDISMODULE_OK) { + if (RedisModule_ModuleTypeSetValue(key, RAI_TensorRedisType(), t) != REDISMODULE_OK) { RAI_TensorFree(t); RedisModule_CloseKey(key); return RedisModule_ReplyWithError(ctx, "ERR could not save tensor"); @@ -368,7 +369,7 @@ int RedisAI_ModelStore_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg int type = RedisModule_KeyType(key); if (type != REDISMODULE_KEYTYPE_EMPTY && !(type == REDISMODULE_KEYTYPE_MODULE && - RedisModule_ModuleTypeGetType(key) == RedisAI_ModelType)) { + RedisModule_ModuleTypeGetType(key) == RAI_ModelRedisType())) { RedisModule_CloseKey(key); RAI_ModelFree(model, &err); if (err.code != RAI_OK) { @@ -380,7 +381,7 @@ int RedisAI_ModelStore_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE); } - RedisModule_ModuleTypeSetValue(key, RedisAI_ModelType, model); + RedisModule_ModuleTypeSetValue(key, RAI_ModelRedisType(), model); model->infokey = RAI_AddStatsEntry(ctx, keystr, RAI_MODEL, backend, devicestr, tag); @@ -623,6 +624,15 @@ int RedisAI_ScriptRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv return RedisAI_ExecuteCommand(ctx, argv, argc, CMD_SCRIPTRUN, false); } +int RedisAI_ScriptExecute_RedisCommand(RedisModuleCtx *ctx, 
RedisModuleString **argv, int argc) {
+    if (RedisModule_IsKeysPositionRequest(ctx)) {
+        return RedisAI_ScriptExecute_IsKeysPositionRequest_ReportKeys(ctx, argv, argc);
+    }
+
+    // Convert the script execute command into a DAG command that contains a single op.
+    return RedisAI_ExecuteCommand(ctx, argv, argc, CMD_SCRIPTEXECUTE, false);
+}
+
 /**
  * AI.SCRIPTGET script_key [META] [SOURCE]
  */
@@ -780,12 +790,12 @@ int RedisAI_ScriptSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
     int type = RedisModule_KeyType(key);
     if (type != REDISMODULE_KEYTYPE_EMPTY &&
         !(type == REDISMODULE_KEYTYPE_MODULE &&
-          RedisModule_ModuleTypeGetType(key) == RedisAI_ScriptType)) {
+          RedisModule_ModuleTypeGetType(key) == RAI_ScriptRedisType())) {
         RedisModule_CloseKey(key);
         return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
     }
-    RedisModule_ModuleTypeSetValue(key, RedisAI_ScriptType, script);
+    RedisModule_ModuleTypeSetValue(key, RAI_ScriptRedisType(), script);
     script->infokey = RAI_AddStatsEntry(ctx, keystr, RAI_SCRIPT, RAI_BACKEND_TORCH, devicestr, tag);
@@ -1087,7 +1097,17 @@ static int RedisAI_RegisterApi(RedisModuleCtx *ctx) {
     REGISTER_API(ScriptFree, ctx);
     REGISTER_API(ScriptRunCtxCreate, ctx);
     REGISTER_API(ScriptRunCtxAddInput, ctx);
+    REGISTER_API(ScriptRunCtxAddTensorInput, ctx);
+    REGISTER_API(ScriptRunCtxAddIntInput, ctx);
+    REGISTER_API(ScriptRunCtxAddFloatInput, ctx);
+    REGISTER_API(ScriptRunCtxAddRStringInput, ctx);
+    REGISTER_API(ScriptRunCtxAddStringInput, ctx);
     REGISTER_API(ScriptRunCtxAddInputList, ctx);
+    REGISTER_API(ScriptRunCtxAddTensorInputList, ctx);
+    REGISTER_API(ScriptRunCtxAddIntInputList, ctx);
+    REGISTER_API(ScriptRunCtxAddFloatInputList, ctx);
+    REGISTER_API(ScriptRunCtxAddRStringInputList, ctx);
+    REGISTER_API(ScriptRunCtxAddStringInputList, ctx);
     REGISTER_API(ScriptRunCtxAddOutput, ctx);
     REGISTER_API(ScriptRunCtxNumOutputs, ctx);
     REGISTER_API(ScriptRunCtxOutputTensor, ctx);
@@ -1332,6 +1352,10 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
                                   "write deny-oom getkeys-api", 4, 4, 1) == REDISMODULE_ERR)
         return REDISMODULE_ERR;
+    if (RedisModule_CreateCommand(ctx, "ai.scriptexecute", RedisAI_ScriptExecute_RedisCommand,
+                                  "write deny-oom getkeys-api", 5, 5, 1) == REDISMODULE_ERR)
+        return REDISMODULE_ERR;
+
     if (RedisModule_CreateCommand(ctx, "ai._scriptscan", RedisAI_ScriptScan_RedisCommand,
                                   "readonly", 1, 1, 1) == REDISMODULE_ERR)
         return REDISMODULE_ERR;
diff --git a/src/redisai.h b/src/redisai.h
index c2bcac547..5a32c33f4 100644
--- a/src/redisai.h
+++ b/src/redisai.h
@@ -139,12 +139,48 @@ REDISAI_API int MODULE_API_FUNC(RedisAI_GetScriptFromKeyspace)(RedisModuleCtx *c
 REDISAI_API void MODULE_API_FUNC(RedisAI_ScriptFree)(RAI_Script *script, RAI_Error *err);
 REDISAI_API RAI_ScriptRunCtx *MODULE_API_FUNC(RedisAI_ScriptRunCtxCreate)(RAI_Script *script,
                                                                           const char *fnname);
+// Deprecated, use RedisAI_ScriptRunCtxAddTensorInput instead.
 REDISAI_API int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddInput)(RAI_ScriptRunCtx *sctx,
                                                               RAI_Tensor *inputTensor,
                                                               RAI_Error *err);
+
+// Deprecated, use RedisAI_ScriptRunCtxAddTensorInputList instead.
REDISAI_API int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddInputList)(RAI_ScriptRunCtx *sctx, RAI_Tensor **inputTensors, size_t len, RAI_Error *err); + +REDISAI_API int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddTensorInput)(RAI_ScriptRunCtx *sctx, + RAI_Tensor *inputTensor); + +REDISAI_API int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddIntInput)(RAI_ScriptRunCtx *sctx, int32_t i); + +REDISAI_API int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddFloatInput)(RAI_ScriptRunCtx *sctx, float f); + +REDISAI_API int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddRStringInput)(RAI_ScriptRunCtx *sctx, + RedisModuleString *s); + +REDISAI_API int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddStringInput)(RAI_ScriptRunCtx *sctx, + const char *s, size_t len); + +REDISAI_API int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddTensorInputList)(RAI_ScriptRunCtx *sctx, + RAI_Tensor **inputTensors, + size_t count); + +REDISAI_API int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddIntInputList)(RAI_ScriptRunCtx *sctx, + int32_t *intInputs, + size_t count); + +REDISAI_API int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddFloatInputList)(RAI_ScriptRunCtx *sctx, + float *floatInputs, + size_t count); + +REDISAI_API int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddRStringInputList)( + RAI_ScriptRunCtx *sctx, RedisModuleString **stringInputs, size_t count); + +REDISAI_API int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddStringInputList)(RAI_ScriptRunCtx *sctx, + const char **stringInputs, + size_t *lens, size_t count); + REDISAI_API int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddOutput)(RAI_ScriptRunCtx *sctx); REDISAI_API size_t MODULE_API_FUNC(RedisAI_ScriptRunCtxNumOutputs)(RAI_ScriptRunCtx *sctx); REDISAI_API RAI_Tensor *MODULE_API_FUNC(RedisAI_ScriptRunCtxOutputTensor)(RAI_ScriptRunCtx *sctx, @@ -264,7 +300,17 @@ static int RedisAI_Initialize(RedisModuleCtx *ctx) { REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptFree); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxCreate); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddInput); + REDISAI_MODULE_INIT_FUNCTION(ctx, 
ScriptRunCtxAddTensorInput); + REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddIntInput); + REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddFloatInput); + REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddRStringInput); + REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddStringInput); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddInputList); + REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddTensorInputList); + REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddIntInputList); + REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddFloatInputList); + REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddRStringInputList); + REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddStringInputList); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddOutput); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxNumOutputs); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxOutputTensor); diff --git a/src/util/arr.h b/src/util/arr.h index 5f8c22d9f..fd4427189 100644 --- a/src/util/arr.h +++ b/src/util/arr.h @@ -199,3 +199,10 @@ static void array_free(array_t arr) { RedisModule_Assert(array_hdr(arr)->len > 0); \ arr[--(array_hdr(arr)->len)]; \ }) + +/* Duplicate the array to the pointer dest. 
*/ +#define array_clone(dest, arr) \ + ({ \ + dest = array_newlen((array_hdr(arr)->elem_sz), array_len(arr)); \ + memcpy(dest, arr, (array_hdr(arr)->elem_sz) * (array_len(arr))); \ + }) diff --git a/src/util/dictionaries.c b/src/util/dictionaries.c new file mode 100644 index 000000000..3b2f684f1 --- /dev/null +++ b/src/util/dictionaries.c @@ -0,0 +1,38 @@ +#include "dictionaries.h" +#include "string_utils.h" +#include "arr.h" + +static array_t dict_arr_clone_fn(void *privdata, const void *arr) { + array_t dest; + array_clone(dest, (array_t)arr); + return dest; +} + +static void dict_arr_free_fn(void *privdata, void *arr) { array_free(arr); } + +AI_dictType AI_dictTypeHeapStrings = { + .hashFunction = RAI_StringsHashFunction, + .keyDup = RAI_StringsKeyDup, + .valDup = NULL, + .keyCompare = RAI_StringsKeyCompare, + .keyDestructor = RAI_StringsKeyDestructor, + .valDestructor = NULL, +}; + +AI_dictType AI_dictTypeHeapRStrings = { + .hashFunction = RAI_RStringsHashFunction, + .keyDup = RAI_RStringsKeyDup, + .valDup = NULL, + .keyCompare = RAI_RStringsKeyCompare, + .keyDestructor = RAI_RStringsKeyDestructor, + .valDestructor = NULL, +}; + +AI_dictType AI_dictType_String_ArrSimple = { + .hashFunction = RAI_StringsHashFunction, + .keyDup = RAI_StringsKeyDup, + .valDup = dict_arr_clone_fn, + .keyCompare = RAI_StringsKeyCompare, + .keyDestructor = RAI_StringsKeyDestructor, + .valDestructor = dict_arr_free_fn, +}; diff --git a/src/util/dictionaries.h b/src/util/dictionaries.h new file mode 100644 index 000000000..fe7233419 --- /dev/null +++ b/src/util/dictionaries.h @@ -0,0 +1,20 @@ +#pragma once +#include "dict.h" + +/** + * @brief Dictionary key type: const char*. value type: void*. + * + */ +extern AI_dictType AI_dictTypeHeapStrings; + +/** + * @brief Dictionary key type: RedisModuleString*. value type: void*. + * + */ +extern AI_dictType AI_dictTypeHeapRStrings; + +/** + * @brief Dictionary key type: const char*. value type: arr. 
+ * + */ +extern AI_dictType AI_dictType_String_ArrSimple; diff --git a/src/util/string_utils.c b/src/util/string_utils.c index 019367721..6b567c3ed 100644 --- a/src/util/string_utils.c +++ b/src/util/string_utils.c @@ -3,7 +3,7 @@ #include #include "util/redisai_memory.h" -RedisModuleString *RAI_HoldString(RedisModuleCtx *ctx, RedisModuleString *str) { +RedisModuleString *RAI_HoldString(RedisModuleString *str) { if (str == NULL) { return NULL; } @@ -32,24 +32,6 @@ void RAI_StringsKeyDestructor(void *privdata, void *key) { RA_FREE(key); } void *RAI_StringsKeyDup(void *privdata, const void *key) { return RA_STRDUP((char *)key); } -AI_dictType AI_dictTypeHeapStringsVals = { - .hashFunction = RAI_StringsHashFunction, - .keyDup = RAI_StringsKeyDup, - .valDup = NULL, - .keyCompare = RAI_StringsKeyCompare, - .keyDestructor = RAI_StringsKeyDestructor, - .valDestructor = RAI_StringsKeyDestructor, -}; - -AI_dictType AI_dictTypeHeapStrings = { - .hashFunction = RAI_StringsHashFunction, - .keyDup = RAI_StringsKeyDup, - .valDup = NULL, - .keyCompare = RAI_StringsKeyCompare, - .keyDestructor = RAI_StringsKeyDestructor, - .valDestructor = NULL, -}; - uint64_t RAI_RStringsHashFunction(const void *key) { size_t len; const char *buffer = RedisModule_StringPtrLen((RedisModuleString *)key, &len); @@ -70,21 +52,3 @@ void RAI_RStringsKeyDestructor(void *privdata, void *key) { void *RAI_RStringsKeyDup(void *privdata, const void *key) { return RedisModule_CreateStringFromString(NULL, (RedisModuleString *)key); } - -AI_dictType AI_dictTypeHeapRStringsVals = { - .hashFunction = RAI_RStringsHashFunction, - .keyDup = RAI_RStringsKeyDup, - .valDup = NULL, - .keyCompare = RAI_RStringsKeyCompare, - .keyDestructor = RAI_RStringsKeyDestructor, - .valDestructor = RAI_RStringsKeyDestructor, -}; - -AI_dictType AI_dictTypeHeapRStrings = { - .hashFunction = RAI_RStringsHashFunction, - .keyDup = RAI_RStringsKeyDup, - .valDup = NULL, - .keyCompare = RAI_RStringsKeyCompare, - .keyDestructor = 
RAI_RStringsKeyDestructor, - .valDestructor = NULL, -}; diff --git a/src/util/string_utils.h b/src/util/string_utils.h index 53338e9ef..835fc45e6 100644 --- a/src/util/string_utils.h +++ b/src/util/string_utils.h @@ -1,7 +1,7 @@ #include "redismodule.h" #include "dict.h" -RedisModuleString *RAI_HoldString(RedisModuleCtx *ctx, RedisModuleString *str); +RedisModuleString *RAI_HoldString(RedisModuleString *str); uint64_t RAI_StringsHashFunction(const void *key); int RAI_StringsKeyCompare(void *privdata, const void *key1, const void *key2); @@ -12,9 +12,3 @@ uint64_t RAI_RStringsHashFunction(const void *key); int RAI_RStringsKeyCompare(void *privdata, const void *key1, const void *key2); void RAI_RStringsKeyDestructor(void *privdata, void *key); void *RAI_RStringsKeyDup(void *privdata, const void *key); - -extern AI_dictType AI_dictTypeHeapStrings; -extern AI_dictType AI_dictTypeHeapStringsVals; - -extern AI_dictType AI_dictTypeHeapRStrings; -extern AI_dictType AI_dictTypeHeapRStringsVals; diff --git a/tests/flow/includes.py b/tests/flow/includes.py index 28e274258..09a4f8eb8 100755 --- a/tests/flow/includes.py +++ b/tests/flow/includes.py @@ -202,3 +202,11 @@ def check_error_message(env, con, error_msg, *command): exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) env.assertEqual(error_msg, str(exception)) + +def check_error(env, con, *command): + try: + con.execute_command(*command) + env.assertFalse(True) + except Exception as e: + exception = e + env.assertEqual(type(exception), redis.exceptions.ResponseError) diff --git a/tests/flow/pt_api_test.c b/tests/flow/pt_api_test.c deleted file mode 100644 index 48e1809b4..000000000 --- a/tests/flow/pt_api_test.c +++ /dev/null @@ -1,44 +0,0 @@ -#include -#include -#include -#include "dlpack/dlpack.h" - -int main() { - printf("Hello from libTorch C library version %s\n", "1.0"); - torchBasicTest(); - - void *ctx; - char *err = NULL; - - const char script[] = "\n\ - def foo(a): \n\ - return a * 2 
\n\ - "; - - ctx = torchCompileScript(script, kDLCPU, &err); - if (err) { - printf("ERR: %s\n", err); - free(err); - err = NULL; - return 1; - } - printf("Compiled: %p\n", ctx); - - DLDataType dtype = (DLDataType){.code = kDLFloat, .bits = 32, .lanes = 1}; - int64_t shape[1] = {1}; - int64_t strides[1] = {1}; - char data[4] = "0000"; - DLManagedTensor *input = torchNewTensor(dtype, 1, shape, strides, data); - DLManagedTensor *output; - torchRunScript(ctx, "foo", 1, &input, 1, &output, &err); - if (err) { - printf("ERR: %s\n", err); - free(err); - return 1; - } - - torchDeallocContext(ctx); - printf("Deallocated\n"); - - return 0; -} diff --git a/tests/flow/test_data/redis_scripts.py b/tests/flow/test_data/redis_scripts.py index c00810d9c..6fa402996 100644 --- a/tests/flow/test_data/redis_scripts.py +++ b/tests/flow/test_data/redis_scripts.py @@ -22,45 +22,53 @@ def redis_hash_to_tensor(redis_value: Any): l = [torch.tensor(int(str(v))).reshape(1,1) for v in values] return torch.cat(l, dim=0) -def test_redis_error(): - redis.execute("SET", "x") +def test_redis_error(key:str): + redis.execute("SET", key) -def test_int_set_get(): - redis.execute("SET", "x", "1") - res = redis.execute("GET", "x",) - redis.execute("DEL", "x") +def test_int_set_get(key:str, value:int): + redis.execute("SET", key, str(value)) + res = redis.execute("GET", key) + redis.execute("DEL", key) return redis_string_int_to_tensor(res) -def test_int_set_incr(): - redis.execute("SET", "x", "1") - res = redis.execute("INCR", "x") - redis.execute("DEL", "x") +def test_int_set_incr(key:str, value:int): + redis.execute("SET", key, str(value)) + res = redis.execute("INCR", key) + redis.execute("DEL", key) return redis_string_int_to_tensor(res) -def test_float_set_get(): - redis.execute("SET", "x", "1.1") - res = redis.execute("GET", "x",) - redis.execute("DEL", "x") +def test_float_set_get(key:str, value:float): + redis.execute("SET", key, str(value)) + res = redis.execute("GET", key) + 
redis.execute("DEL", key) return redis_string_float_to_tensor(res) -def test_int_list(): - redis.execute("RPUSH", "x", "1") - redis.execute("RPUSH", "x", "2") - res = redis.execute("LRANGE", "x", "0", "2") - redis.execute("DEL", "x") +def test_int_list(key:str, l:List[str]): + for value in l: + redis.execute("RPUSH", key, value) + res = redis.execute("LRANGE", key, "0", str(len(l))) + redis.execute("DEL", key) return redis_int_list_to_tensor(res) -def test_hash(): - redis.execute("HSET", "x", "field1", "1", "field2", "2") - res = redis.execute("HVALS", "x") - redis.execute("DEL", "x") +def test_str_list(key:str, l:List[str]): + for value in l: + redis.execute("RPUSH", key, value) + + +def test_hash(key:str, l:List[str]): + args = [key] + for s in l: + args.append(s) + redis.execute("HSET", args) + res = redis.execute("HVALS", key) + redis.execute("DEL", key) return redis_hash_to_tensor(res) -def test_set_key(): - redis.execute("SET", ["x{1}", "1"]) +def test_set_key(key:str, value:str): + redis.execute("SET", [key, value]) -def test_del_key(): - redis.execute("DEL", ["x"]) +def test_del_key(key:str): + redis.execute("DEL", [key]) diff --git a/tests/flow/test_data/script.txt b/tests/flow/test_data/script.txt index 34e4b9317..f6cd1d6a2 100644 --- a/tests/flow/test_data/script.txt +++ b/tests/flow/test_data/script.txt @@ -3,3 +3,6 @@ def bar(a, b): def bar_variadic(a, args : List[Tensor]): return args[0] + args[1] + +def bar_two_lists(a: List[Tensor], b:List[Tensor]): + return a[0] + b[0] diff --git a/tests/flow/test_torchscript_extensions.py b/tests/flow/test_torchscript_extensions.py index e3171d6f7..46a6a0fbc 100644 --- a/tests/flow/test_torchscript_extensions.py +++ b/tests/flow/test_torchscript_extensions.py @@ -32,28 +32,28 @@ def __init__(self): def test_redis_error(self): try: self.con.execute_command( - 'AI.SCRIPTRUN', 'redis_scripts', 'test_redis_error') + 'AI.SCRIPTEXECUTE', 'redis_scripts', 'test_redis_error', 'KEYS', 1, "x{1}", "INPUTS", 1, "x{1}") 
self.env.assertTrue(False) except: pass def test_simple_test_set(self): self.con.execute_command( - 'AI.SCRIPTRUN', 'redis_scripts{1}', 'test_set_key') + 'AI.SCRIPTEXECUTE', 'redis_scripts{1}', 'test_set_key', 'KEYS', 1, "x{1}", "INPUTS", 2, "x{1}", 1) self.env.assertEqual(b"1", self.con.get("x{1}")) def test_int_set_get(self): - self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts{1}', 'test_int_set_get', 'OUTPUTS', 'y{1}') + self.con.execute_command('AI.SCRIPTEXECUTE', 'redis_scripts{1}', 'test_int_set_get', 'KEYS', 1, "x{1}", "INPUTS", 2, "x{1}", 1, 'OUTPUTS', 1, 'y{1}') y = self.con.execute_command('AI.TENSORGET', 'y{1}', 'meta' ,'VALUES') self.env.assertEqual(y, [b"dtype", b"INT64", b"shape", [], b"values", [1]] ) def test_int_set_incr(self): - self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts{1}', 'test_int_set_incr', 'OUTPUTS', 'y{1}') + self.con.execute_command('AI.SCRIPTEXECUTE', 'redis_scripts{1}', 'test_int_set_incr', 'KEYS', 1, "x{1}", "INPUTS", 2, "x{1}", 1, 'OUTPUTS', 1, 'y{1}') y = self.con.execute_command('AI.TENSORGET', 'y{1}', 'meta' ,'VALUES') self.env.assertEqual(y, [b"dtype", b"INT64", b"shape", [], b"values", [2]] ) def test_float_get_set(self): - self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts{1}', 'test_float_set_get', 'OUTPUTS', 'y{1}') + self.con.execute_command('AI.SCRIPTEXECUTE', 'redis_scripts{1}', 'test_float_set_get', 'KEYS', 1, "x{1}", "INPUTS", 2, "x{1}", 1.1, 'OUTPUTS', 1, 'y{1}') y = self.con.execute_command('AI.TENSORGET', 'y{1}', 'meta' ,'VALUES') self.env.assertEqual(y[0], b"dtype") self.env.assertEqual(y[1], b"FLOAT") @@ -61,13 +61,18 @@ def test_float_get_set(self): self.env.assertEqual(y[3], []) self.env.assertEqual(y[4], b"values") self.env.assertAlmostEqual(float(y[5][0]), 1.1, 0.1) - + def test_int_list(self): - self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts{1}', 'test_int_list', 'OUTPUTS', 'y{1}') + self.con.execute_command('AI.SCRIPTEXECUTE', 'redis_scripts{1}', 'test_int_list', 'KEYS', 1, 
"int_list{1}", 'INPUTS', 1, "int_list{1}", "LIST_INPUTS", 2, "1", "2", 'OUTPUTS', 1, 'y{1}') y = self.con.execute_command('AI.TENSORGET', 'y{1}', 'meta' ,'VALUES') self.env.assertEqual(y, [b"dtype", b"INT64", b"shape", [2, 1], b"values", [1, 2]] ) + def test_str_list(self): + self.con.execute_command('AI.SCRIPTEXECUTE', 'redis_scripts{1}', 'test_str_list', 'KEYS', 1, "str_list{1}", 'INPUTS', 1, "str_list{1}", "LIST_INPUTS", 2, "1", "2") + res = self.con.execute_command("LRANGE", "str_list{1}", "0", "2") + self.env.assertEqual(res, [b"1", b"2"] ) + def test_hash(self): - self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts{1}', 'test_hash', 'OUTPUTS', 'y{1}') + self.con.execute_command('AI.SCRIPTEXECUTE', 'redis_scripts{1}', 'test_hash', 'KEYS', 1, "hash{1}", 'INPUTS', 1, "hash{1}", "LIST_INPUTS", 4, "field1", 1, "field2", 2, 'OUTPUTS', 1, 'y{1}') y = self.con.execute_command('AI.TENSORGET', 'y{1}', 'meta' ,'VALUES') self.env.assertEqual(y, [b"dtype", b"INT64", b"shape", [2, 1], b"values", [1, 2]] ) \ No newline at end of file diff --git a/tests/flow/tests_common.py b/tests/flow/tests_common.py index 460fc4fc3..ec37df748 100644 --- a/tests/flow/tests_common.py +++ b/tests/flow/tests_common.py @@ -233,7 +233,7 @@ def test_common_tensorget_error_replies(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("tensor key is empty",exception.__str__()) + env.assertEqual("tensor key is empty or in a different shard",exception.__str__()) # WRONGTYPE Operation against a key holding the wrong kind of value try: diff --git a/tests/flow/tests_dag.py b/tests/flow/tests_dag.py index a00c9d97e..642609635 100644 --- a/tests/flow/tests_dag.py +++ b/tests/flow/tests_dag.py @@ -72,7 +72,7 @@ def test_dag_load_errors(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("tensor key is empty",exception.__str__()) + env.assertEqual("tensor key 
is empty or in a different shard",exception.__str__()) # WRONGTYPE Operation against a key holding the wrong kind of value try: @@ -367,7 +367,7 @@ def test_dag_scriptrun_errors(env): except Exception as e: exception = e env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("function name not specified", exception.__str__()) + env.assertEqual("Unrecongnized parameter to SCRIPTRUN", exception.__str__()) def test_dag_modelrun_financialNet_errors(env): diff --git a/tests/flow/tests_deprecated_commands.py b/tests/flow/tests_deprecated_commands.py index e94df5f1c..0bf3c77e6 100644 --- a/tests/flow/tests_deprecated_commands.py +++ b/tests/flow/tests_deprecated_commands.py @@ -209,3 +209,189 @@ def test_modelset_modelrun_onnx(env): values = con.execute_command('AI.TENSORGET', 'b{1}', 'VALUES') argmax = max(range(len(values)), key=lambda i: values[i]) env.assertEqual(argmax, 1) + + +def test_pytorch_scriptrun(env): + if not TEST_PT: + env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) + return + + con = env.getConnection() + script = load_file_content('script.txt') + + ret = con.execute_command('AI.SCRIPTSET', 'myscript{1}', DEVICE, 'TAG', 'version1', 'SOURCE', script) + env.assertEqual(ret, b'OK') + + ret = con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + ret = con.execute_command('AI.TENSORSET', 'b{1}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + + ensureSlaveSynced(con, env) + + for _ in range( 0,100): + + ret = con.execute_command('AI.SCRIPTRUN', 'myscript{1}', 'bar', 'INPUTS', 'a{1}', 'b{1}', 'OUTPUTS', 'c{1}') + env.assertEqual(ret, b'OK') + + + ensureSlaveSynced(con, env) + + info = con.execute_command('AI.INFO', 'myscript{1}') + info_dict_0 = info_to_dict(info) + + env.assertEqual(info_dict_0['key'], 'myscript{1}') + env.assertEqual(info_dict_0['type'], 'SCRIPT') + env.assertEqual(info_dict_0['backend'], 
'TORCH') + env.assertEqual(info_dict_0['tag'], 'version1') + env.assertTrue(info_dict_0['duration'] > 0) + env.assertEqual(info_dict_0['samples'], -1) + env.assertEqual(info_dict_0['calls'], 100) + env.assertEqual(info_dict_0['errors'], 0) + + values = con.execute_command('AI.TENSORGET', 'c{1}', 'VALUES') + env.assertEqual(values, [b'4', b'6', b'4', b'6']) + + ensureSlaveSynced(con, env) + + if env.useSlaves: + con2 = env.getSlaveConnection() + values2 = con2.execute_command('AI.TENSORGET', 'c{1}', 'VALUES') + env.assertEqual(values2, values) + + +def test_pytorch_scriptrun_variadic(env): + if not TEST_PT: + env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) + return + + con = env.getConnection() + + script = load_file_content('script.txt') + + ret = con.execute_command('AI.SCRIPTSET', 'myscript{$}', DEVICE, 'TAG', 'version1', 'SOURCE', script) + env.assertEqual(ret, b'OK') + + ret = con.execute_command('AI.TENSORSET', 'a{$}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + ret = con.execute_command('AI.TENSORSET', 'b1{$}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + ret = con.execute_command('AI.TENSORSET', 'b2{$}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + + ensureSlaveSynced(con, env) + + for _ in range( 0,100): + ret = con.execute_command('AI.SCRIPTRUN', 'myscript{$}', 'bar_variadic', 'INPUTS', 'a{$}', '$', 'b1{$}', 'b2{$}', 'OUTPUTS', 'c{$}') + env.assertEqual(ret, b'OK') + + ensureSlaveSynced(con, env) + + info = con.execute_command('AI.INFO', 'myscript{$}') + info_dict_0 = info_to_dict(info) + + env.assertEqual(info_dict_0['key'], 'myscript{$}') + env.assertEqual(info_dict_0['type'], 'SCRIPT') + env.assertEqual(info_dict_0['backend'], 'TORCH') + env.assertEqual(info_dict_0['tag'], 'version1') + env.assertTrue(info_dict_0['duration'] > 0) + env.assertEqual(info_dict_0['samples'], -1) + env.assertEqual(info_dict_0['calls'], 100) + 
env.assertEqual(info_dict_0['errors'], 0) + + values = con.execute_command('AI.TENSORGET', 'c{$}', 'VALUES') + env.assertEqual(values, [b'4', b'6', b'4', b'6']) + + ensureSlaveSynced(con, env) + + if env.useSlaves: + con2 = env.getSlaveConnection() + values2 = con2.execute_command('AI.TENSORGET', 'c{$}', 'VALUES') + env.assertEqual(values2, values) + + +def test_pytorch_scriptrun_errors(env): + if not TEST_PT: + env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) + return + + con = env.getConnection() + + script = load_file_content('script.txt') + + ret = con.execute_command('AI.SCRIPTSET', 'ket{1}', DEVICE, 'TAG', 'asdf', 'SOURCE', script) + env.assertEqual(ret, b'OK') + + ret = con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + ret = con.execute_command('AI.TENSORSET', 'b{1}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + + ensureSlaveSynced(con, env) + + con.execute_command('DEL', 'EMPTY{1}') + # ERR no script at key from SCRIPTGET + check_error_message(env, con, "script key is empty", 'AI.SCRIPTGET', 'EMPTY{1}') + + con.execute_command('SET', 'NOT_SCRIPT{1}', 'BAR') + # ERR wrong type from SCRIPTGET + check_error_message(env, con, "WRONGTYPE Operation against a key holding the wrong kind of value", 'AI.SCRIPTGET', 'NOT_SCRIPT{1}') + + con.execute_command('DEL', 'EMPTY{1}') + # ERR no script at key from SCRIPTRUN + check_error_message(env, con, "script key is empty", 'AI.SCRIPTRUN', 'EMPTY{1}', 'bar', 'INPUTS', 'b{1}', 'OUTPUTS', 'c{1}') + + con.execute_command('SET', 'NOT_SCRIPT{1}', 'BAR') + # ERR wrong type from SCRIPTRUN + check_error_message(env, con, "WRONGTYPE Operation against a key holding the wrong kind of value", 'AI.SCRIPTRUN', 'NOT_SCRIPT{1}', 'bar', 'INPUTS', 'b{1}', 'OUTPUTS', 'c{1}') + + con.execute_command('DEL', 'EMPTY{1}') + # ERR Input key is empty + check_error_message(env, con, "tensor key is empty or in a 
different shard", 'AI.SCRIPTRUN', 'ket{1}', 'bar', 'INPUTS', 'EMPTY{1}', 'b{1}', 'OUTPUTS', 'c{1}') + + con.execute_command('SET', 'NOT_TENSOR{1}', 'BAR') + # ERR Input key not tensor + check_error_message(env, con, "WRONGTYPE Operation against a key holding the wrong kind of value", 'AI.SCRIPTRUN', 'ket{1}', 'bar', 'INPUTS', 'NOT_TENSOR{1}', 'b{1}', 'OUTPUTS', 'c{1}') + + check_error(env, con, 'AI.SCRIPTRUN', 'ket{1}', 'bar', 'INPUTS', 'b{1}', 'OUTPUTS', 'c{1}') + + check_error(env, con, 'AI.SCRIPTRUN', 'ket{1}', 'INPUTS', 'a{1}', 'b{1}', 'OUTPUTS', 'c{1}') + + check_error(env, con, 'AI.SCRIPTRUN', 'ket{1}', 'bar', 'INPUTS', 'b{1}', 'OUTPUTS') + + check_error(env, con, 'AI.SCRIPTRUN', 'ket{1}', 'bar', 'INPUTS', 'OUTPUTS') + +def test_pytorch_scriptrun_variadic_errors(env): + if not TEST_PT: + env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) + return + + con = env.getConnection() + + script = load_file_content('script.txt') + + ret = con.execute_command('AI.SCRIPTSET', 'ket{$}', DEVICE, 'TAG', 'asdf', 'SOURCE', script) + env.assertEqual(ret, b'OK') + + ret = con.execute_command('AI.TENSORSET', 'a{$}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + ret = con.execute_command('AI.TENSORSET', 'b{$}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + + ensureSlaveSynced(con, env) + + con.execute_command('DEL', 'EMPTY{$}') + # ERR Variadic input key is empty + check_error_message(env, con, "tensor key is empty or in a different shard", 'AI.SCRIPTRUN', 'ket{$}', 'bar_variadic', 'INPUTS', 'a{$}', '$', 'EMPTY{$}', 'b{$}', 'OUTPUTS', 'c{$}') + + con.execute_command('SET', 'NOT_TENSOR{$}', 'BAR') + # ERR Variadic input key not tensor + check_error_message(env, con, "WRONGTYPE Operation against a key holding the wrong kind of value", 'AI.SCRIPTRUN', 'ket{$}', 'bar_variadic', 'INPUTS', 'a{$}', '$' , 'NOT_TENSOR{$}', 'b{$}', 'OUTPUTS', 'c{$}') + + check_error(env, con, 'AI.SCRIPTRUN', 
'ket{$}', 'bar_variadic', 'INPUTS', 'b{$}', '${$}', 'OUTPUTS', 'c{$}') + + check_error(env, con, 'AI.SCRIPTRUN', 'ket{$}', 'bar_variadic', 'INPUTS', 'b{$}', '$', 'OUTPUTS') + + check_error(env, con, 'AI.SCRIPTRUN', 'ket{$}', 'bar_variadic', 'INPUTS', '$', 'OUTPUTS') + + check_error_message(env, con, "Already encountered a variable size list of tensors", 'AI.SCRIPTRUN', 'ket{$}', 'bar_variadic', 'INPUTS', '$', 'a{$}', '$', 'b{$}', 'OUTPUTS') diff --git a/tests/flow/tests_pytorch.py b/tests/flow/tests_pytorch.py index b9ded011b..33f52ff7d 100644 --- a/tests/flow/tests_pytorch.py +++ b/tests/flow/tests_pytorch.py @@ -86,12 +86,7 @@ def test_pytorch_modelrun(env): env.assertEqual(ret[1], b'TORCH') env.assertEqual(ret[3], b'CPU') - try: - con.execute_command('AI.MODELSTORE', 'm{1}', 'TORCH', DEVICE, 'BLOB', wrong_model_pb) - env.assertTrue(False) - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) + check_error(env, con, 'AI.MODELSTORE', 'm{1}', 'TORCH', DEVICE, 'BLOB', wrong_model_pb) con.execute_command('AI.MODELEXECUTE', 'm{1}', 'INPUTS', 2, 'a{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}') @@ -240,44 +235,21 @@ def test_pytorch_scriptset(env): con = env.getConnection() - try: - con.execute_command('AI.SCRIPTSET', 'ket{1}', DEVICE, 'SOURCE', 'return 1') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) + check_error(env, con, 'AI.SCRIPTSET', 'ket{1}', DEVICE, 'SOURCE', 'return 1') - try: - con.execute_command('AI.SCRIPTSET', 'nope') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) + check_error(env, con, 'AI.SCRIPTSET', 'nope') - try: - con.execute_command('AI.SCRIPTSET', 'nope', 'SOURCE') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) + check_error(env, con, 'AI.SCRIPTSET', 'nope', 'SOURCE') - try: - con.execute_command('AI.SCRIPTSET', 'more',
DEVICE) - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) + check_error(env, con, 'AI.SCRIPTSET', 'more', DEVICE) - test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') - script_filename = os.path.join(test_data_path, 'script.txt') - - with open(script_filename, 'rb') as f: - script = f.read() + script = load_file_content('script.txt') ret = con.execute_command('AI.SCRIPTSET', 'ket{1}', DEVICE, 'SOURCE', script) env.assertEqual(ret, b'OK') ensureSlaveSynced(con, env) - with open(script_filename, 'rb') as f: - script = f.read() - ret = con.execute_command('AI.SCRIPTSET', 'ket{1}', DEVICE, 'TAG', 'asdf', 'SOURCE', script) env.assertEqual(ret, b'OK') @@ -291,11 +263,7 @@ def test_pytorch_scriptdel(env): con = env.getConnection() - test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') - script_filename = os.path.join(test_data_path, 'script.txt') - - with open(script_filename, 'rb') as f: - script = f.read() + script = load_file_content('script.txt') ret = con.execute_command('AI.SCRIPTSET', 'ket{1}', DEVICE, 'SOURCE', script) env.assertEqual(ret, b'OK') @@ -313,37 +281,22 @@ def test_pytorch_scriptdel(env): con2 = env.getSlaveConnection() env.assertFalse(con2.execute_command('EXISTS', 'ket{1}')) + con.execute_command('DEL', 'EMPTY') # ERR no script at key from SCRIPTDEL - try: - con.execute_command('DEL', 'EMPTY') - con.execute_command('AI.SCRIPTDEL', 'EMPTY') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("script key is empty", exception.__str__()) + check_error_message(env, con, "script key is empty", 'AI.SCRIPTDEL', 'EMPTY') + con.execute_command('SET', 'NOT_SCRIPT', 'BAR') # ERR wrong type from SCRIPTDEL - try: - con.execute_command('SET', 'NOT_SCRIPT', 'BAR') - con.execute_command('AI.SCRIPTDEL', 'NOT_SCRIPT') - except Exception as e: - exception = e - env.assertEqual(type(exception), 
redis.exceptions.ResponseError) - env.assertEqual("WRONGTYPE Operation against a key holding the wrong kind of value", exception.__str__()) + check_error_message(env, con, "WRONGTYPE Operation against a key holding the wrong kind of value", 'AI.SCRIPTDEL', 'NOT_SCRIPT') - -def test_pytorch_scriptrun(env): +def test_pytorch_scriptexecute(env): if not TEST_PT: env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) return con = env.getConnection() - test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') - script_filename = os.path.join(test_data_path, 'script.txt') - - with open(script_filename, 'rb') as f: - script = f.read() + script = load_file_content('script.txt') ret = con.execute_command('AI.SCRIPTSET', 'myscript{1}', DEVICE, 'TAG', 'version1', 'SOURCE', script) env.assertEqual(ret, b'OK') @@ -357,7 +310,7 @@ def test_pytorch_scriptrun(env): for _ in range( 0,100): - ret = con.execute_command('AI.SCRIPTRUN', 'myscript{1}', 'bar', 'INPUTS', 'a{1}', 'b{1}', 'OUTPUTS', 'c{1}') + ret = con.execute_command('AI.SCRIPTEXECUTE', 'myscript{1}', 'bar', 'KEYS', 1, '{1}', 'INPUTS', 2, 'a{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}') env.assertEqual(ret, b'OK') @@ -386,18 +339,14 @@ def test_pytorch_scriptrun(env): env.assertEqual(values2, values) -def test_pytorch_scriptrun_variadic(env): +def test_pytorch_scriptexecute_list_input(env): if not TEST_PT: env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) return con = env.getConnection() - test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') - script_filename = os.path.join(test_data_path, 'script.txt') - - with open(script_filename, 'rb') as f: - script = f.read() + script = load_file_content('script.txt') ret = con.execute_command('AI.SCRIPTSET', 'myscript{$}', DEVICE, 'TAG', 'version1', 'SOURCE', script) env.assertEqual(ret, b'OK') @@ -412,7 +361,7 @@ def test_pytorch_scriptrun_variadic(env): ensureSlaveSynced(con, 
env) for _ in range( 0,100): - ret = con.execute_command('AI.SCRIPTRUN', 'myscript{$}', 'bar_variadic', 'INPUTS', 'a{$}', '$', 'b1{$}', 'b2{$}', 'OUTPUTS', 'c{$}') + ret = con.execute_command('AI.SCRIPTEXECUTE', 'myscript{$}', 'bar_variadic', 'KEYS', 1, '{$}', 'INPUTS', 1, 'a{$}', 'LIST_INPUTS', 2, 'b1{$}', 'b2{$}', 'OUTPUTS', 1, 'c{$}') env.assertEqual(ret, b'OK') ensureSlaveSynced(con, env) @@ -440,18 +389,61 @@ def test_pytorch_scriptrun_variadic(env): env.assertEqual(values2, values) -def test_pytorch_scriptrun_errors(env): +def test_pytorch_scriptexecute_multiple_list_input(env): if not TEST_PT: env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) return con = env.getConnection() - test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') - script_filename = os.path.join(test_data_path, 'script.txt') + script = load_file_content('script.txt') + + ret = con.execute_command('AI.SCRIPTSET', 'myscript{$}', DEVICE, 'TAG', 'version1', 'SOURCE', script) + env.assertEqual(ret, b'OK') + + ret = con.execute_command('AI.TENSORSET', 'a{$}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + ret = con.execute_command('AI.TENSORSET', 'b{$}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + + ensureSlaveSynced(con, env) + + for _ in range( 0,100): + ret = con.execute_command('AI.SCRIPTEXECUTE', 'myscript{$}', 'bar_two_lists', 'KEYS', 1, '{$}', 'LIST_INPUTS', 1, 'a{$}', 'LIST_INPUTS', 1, 'b{$}', 'OUTPUTS', 1, 'c{$}') + env.assertEqual(ret, b'OK') + + ensureSlaveSynced(con, env) + + info = con.execute_command('AI.INFO', 'myscript{$}') + info_dict_0 = info_to_dict(info) + + env.assertEqual(info_dict_0['key'], 'myscript{$}') + env.assertEqual(info_dict_0['type'], 'SCRIPT') + env.assertEqual(info_dict_0['backend'], 'TORCH') + env.assertEqual(info_dict_0['tag'], 'version1') + env.assertTrue(info_dict_0['duration'] > 0) + env.assertEqual(info_dict_0['samples'], -1) + 
env.assertEqual(info_dict_0['calls'], 100) + env.assertEqual(info_dict_0['errors'], 0) + + values = con.execute_command('AI.TENSORGET', 'c{$}', 'VALUES') + env.assertEqual(values, [b'4', b'6', b'4', b'6']) + + ensureSlaveSynced(con, env) + + if env.useSlaves: + con2 = env.getSlaveConnection() + values2 = con2.execute_command('AI.TENSORGET', 'c{$}', 'VALUES') + env.assertEqual(values2, values) + +def test_pytorch_scriptexecute_errors(env): + if not TEST_PT: + env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) + return + + con = env.getConnection() - with open(script_filename, 'rb') as f: - script = f.read() + script = load_file_content('script.txt') ret = con.execute_command('AI.SCRIPTSET', 'ket{1}', DEVICE, 'TAG', 'asdf', 'SOURCE', script) env.assertEqual(ret, b'OK') @@ -463,97 +455,56 @@ def test_pytorch_scriptrun_errors(env): ensureSlaveSynced(con, env) + con.execute_command('DEL', 'EMPTY{1}') # ERR no script at key from SCRIPTGET - try: - con.execute_command('DEL', 'EMPTY{1}') - con.execute_command('AI.SCRIPTGET', 'EMPTY{1}') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("script key is empty", exception.__str__()) + check_error_message(env, con, "script key is empty", 'AI.SCRIPTGET', 'EMPTY{1}') + con.execute_command('SET', 'NOT_SCRIPT{1}', 'BAR') # ERR wrong type from SCRIPTGET - try: - con.execute_command('SET', 'NOT_SCRIPT{1}', 'BAR') - con.execute_command('AI.SCRIPTGET', 'NOT_SCRIPT{1}') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("WRONGTYPE Operation against a key holding the wrong kind of value", exception.__str__()) - - # ERR no script at key from SCRIPTRUN - try: - con.execute_command('DEL', 'EMPTY{1}') - con.execute_command('AI.SCRIPTRUN', 'EMPTY{1}', 'bar', 'INPUTS', 'b{1}', 'OUTPUTS', 'c{1}') - except Exception as e: - exception = e - 
env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("script key is empty", exception.__str__()) - - # ERR wrong type from SCRIPTRUN - try: - con.execute_command('SET', 'NOT_SCRIPT{1}', 'BAR') - con.execute_command('AI.SCRIPTRUN', 'NOT_SCRIPT{1}', 'bar', 'INPUTS', 'b{1}', 'OUTPUTS', 'c{1}') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("WRONGTYPE Operation against a key holding the wrong kind of value", exception.__str__()) + check_error_message(env, con, "WRONGTYPE Operation against a key holding the wrong kind of value", 'AI.SCRIPTGET', 'NOT_SCRIPT{1}') + con.execute_command('DEL', 'EMPTY{1}') + # ERR no script at key from SCRIPTEXECUTE + check_error_message(env, con, "script key is empty", 'AI.SCRIPTEXECUTE', 'EMPTY{1}', 'bar', 'KEYS', 1 , '{1}', 'INPUTS', 1, 'b{1}', 'OUTPUTS', 1, 'c{1}') + + con.execute_command('SET', 'NOT_SCRIPT{1}', 'BAR') + # ERR wrong type from SCRIPTEXECUTE + check_error_message(env, con, "WRONGTYPE Operation against a key holding the wrong kind of value", 'AI.SCRIPTEXECUTE', 'NOT_SCRIPT{1}', 'bar', 'KEYS', 1 , '{1}', 'INPUTS', 1, 'b{1}', 'OUTPUTS', 1, 'c{1}') + + con.execute_command('DEL', 'EMPTY{1}') # ERR Input key is empty - try: - con.execute_command('DEL', 'EMPTY{1}') - con.execute_command('AI.SCRIPTRUN', 'ket{1}', 'bar', 'INPUTS', 'EMPTY{1}', 'b{1}', 'OUTPUTS', 'c{1}') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("tensor key is empty", exception.__str__()) + check_error_message(env, con, "tensor key is empty or in a different shard", 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1 , '{1}', 'INPUTS', 2, 'EMPTY{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}') + con.execute_command('SET', 'NOT_TENSOR{1}', 'BAR') # ERR Input key not tensor - try: - con.execute_command('SET', 'NOT_TENSOR{1}', 'BAR') - con.execute_command('AI.SCRIPTRUN', 'ket{1}', 'bar', 'INPUTS', 
'NOT_TENSOR{1}', 'b{1}', 'OUTPUTS', 'c{1}') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("WRONGTYPE Operation against a key holding the wrong kind of value", exception.__str__()) - - try: - con.execute_command('AI.SCRIPTRUN', 'ket{1}', 'bar', 'INPUTS', 'b{1}', 'OUTPUTS', 'c{1}') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - - try: - con.execute_command('AI.SCRIPTRUN', 'ket{1}', 'INPUTS', 'a{1}', 'b{1}', 'OUTPUTS', 'c{1}') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - - try: - con.execute_command('AI.SCRIPTRUN', 'ket{1}', 'bar', 'INPUTS', 'b{1}', 'OUTPUTS') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - - try: - con.execute_command('AI.SCRIPTRUN', 'ket{1}', 'bar', 'INPUTS', 'OUTPUTS') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - - -def test_pytorch_scriptrun_variadic_errors(env): + check_error_message(env, con, "WRONGTYPE Operation against a key holding the wrong kind of value", 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1 , '{1}', 'INPUTS', 2, 'NOT_TENSOR{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}') + + check_error(env, con, 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1 , '{1}', 'INPUTS', 1, 'b{1}', 'OUTPUTS', 1, 'c{1}') + + check_error(env, con, 'AI.SCRIPTEXECUTE', 'ket{1}', 'KEYS', 1 , '{1}', 'INPUTS', 2, 'a{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}') + + check_error(env, con, 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1 , '{1}', 'INPUTS', 1, 'b{1}', 'OUTPUTS') + + check_error(env, con, 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1 , '{1}', 'INPUTS', 'OUTPUTS') + + check_error_message(env, con, "KEYS scope must be provided first for AI.SCRIPTEXECUTE command", 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'INPUTS', 'OUTPUTS') + + if env.isCluster(): + # 
cross shard + check_error_message(env, con, "CROSSSLOT Keys in request don't hash to the same slot", 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1 , '{2}', 'INPUTS', 2, 'a{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}') + + # key doesn't exist + check_error_message(env, con, "tensor key is empty or in a different shard", 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1 , '{1}', 'INPUTS', 2, 'a{1}', 'b{2}', 'OUTPUTS', 1, 'c{1}') + + +def test_pytorch_scriptexecute_variadic_errors(env): if not TEST_PT: env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) return con = env.getConnection() - test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') - script_filename = os.path.join(test_data_path, 'script.txt') - - with open(script_filename, 'rb') as f: - script = f.read() + script = load_file_content('script.txt') ret = con.execute_command('AI.SCRIPTSET', 'ket{$}', DEVICE, 'TAG', 'asdf', 'SOURCE', script) env.assertEqual(ret, b'OK') @@ -565,48 +516,21 @@ def test_pytorch_scriptrun_variadic_errors(env): ensureSlaveSynced(con, env) + con.execute_command('DEL', 'EMPTY{$}') # ERR Variadic input key is empty - try: - con.execute_command('DEL', 'EMPTY{$}') - con.execute_command('AI.SCRIPTRUN', 'ket{$}', 'bar_variadic', 'INPUTS', 'a{$}', '$', 'EMPTY{$}', 'b{$}', 'OUTPUTS', 'c{$}') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("tensor key is empty", exception.__str__()) + check_error_message(env, con, "tensor key is empty or in a different shard", 'AI.SCRIPTEXECUTE', 'ket{$}', 'bar_variadic', 'KEYS', 1 , '{$}', 'INPUTS', 1, 'a{$}', "LIST_INPUTS", 2, 'EMPTY{$}', 'b{$}', 'OUTPUTS', 1, 'c{$}') + con.execute_command('SET', 'NOT_TENSOR{$}', 'BAR') # ERR Variadic input key not tensor - try: - con.execute_command('SET', 'NOT_TENSOR{$}', 'BAR') - con.execute_command('AI.SCRIPTRUN', 'ket{$}', 'bar_variadic', 'INPUTS', 'a{$}', '$' , 'NOT_TENSOR{$}', 'b{$}', 'OUTPUTS', 
'c{$}') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - env.assertEqual("WRONGTYPE Operation against a key holding the wrong kind of value", exception.__str__()) - - try: - con.execute_command('AI.SCRIPTRUN', 'ket{$}', 'bar_variadic', 'INPUTS', 'b{$}', '${$}', 'OUTPUTS', 'c{$}') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - - try: - con.execute_command('AI.SCRIPTRUN', 'ket{$}', 'bar_variadic', 'INPUTS', 'b{$}', '$', 'OUTPUTS') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - - try: - con.execute_command('AI.SCRIPTRUN', 'ket{$}', 'bar_variadic', 'INPUTS', '$', 'OUTPUTS') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) - - # "ERR Already encountered a variable size list of tensors" - try: - con.execute_command('AI.SCRIPTRUN', 'ket{$}', 'bar_variadic', 'INPUTS', '$', 'a{$}', '$', 'b{$}' 'OUTPUTS') - except Exception as e: - exception = e - env.assertEqual(type(exception), redis.exceptions.ResponseError) + check_error_message(env, con, "WRONGTYPE Operation against a key holding the wrong kind of value", 'AI.SCRIPTEXECUTE', 'ket{$}', 'bar_variadic', 'KEYS', 1 , '{$}', 'INPUTS', 1, 'a{$}', "LIST_INPUTS", 2, 'NOT_TENSOR{$}', 'b{$}', 'OUTPUTS', 1, 'c{$}') + + check_error(env, con, 'AI.SCRIPTEXECUTE', 'ket{$}', 'bar_variadic', 'KEYS', 1 , '{$}', 'INPUTS', 2, 'b{$}', '${$}', 'OUTPUTS', 1, 'c{$}') + + check_error(env, con, 'AI.SCRIPTEXECUTE', 'ket{$}', 'bar_variadic', 'KEYS', 1 , '{$}', 'INPUTS', 1, 'b{$}', 'LIST_INPUTS', 'OUTPUTS') + + check_error(env, con, 'AI.SCRIPTEXECUTE', 'ket{$}', 'bar_variadic', 'KEYS', 1 , '{$}', 'INPUTS', 'LIST_INPUTS', 'OUTPUTS') + + check_error(env, con, 'AI.SCRIPTEXECUTE', 'ket{$}', 'bar_variadic', 'KEYS', 1 , '{$}', 'LIST_INPUTS', 1, 'a{$}', 'LIST_INPUTS', 1, 'b{$}', 'OUTPUTS') def
test_pytorch_scriptinfo(env): @@ -619,11 +543,7 @@ def test_pytorch_scriptinfo(env): con = env.getConnection() - test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') - script_filename = os.path.join(test_data_path, 'script.txt') - - with open(script_filename, 'rb') as f: - script = f.read() + script = load_file_content('script.txt') ret = con.execute_command('AI.SCRIPTSET', 'ket_script{1}', DEVICE, 'SOURCE', script) env.assertEqual(ret, b'OK') @@ -637,7 +557,7 @@ def test_pytorch_scriptinfo(env): previous_duration = 0 for call in range(1, 100): - ret = con.execute_command('AI.SCRIPTRUN', 'ket_script{1}', 'bar', 'INPUTS', 'a{1}', 'b{1}', 'OUTPUTS', 'c{1}') + ret = con.execute_command('AI.SCRIPTEXECUTE', 'ket_script{1}', 'bar', 'KEYS', 1, '{1}', 'INPUTS', 2, 'a{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}') env.assertEqual(ret, b'OK') ensureSlaveSynced(con, env) @@ -665,7 +585,7 @@ def test_pytorch_scriptinfo(env): env.assertEqual(info_dict_0['errors'], 0) -def test_pytorch_scriptrun_disconnect(env): +def test_pytorch_scriptexecute_disconnect(env): if not TEST_PT: env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) return @@ -676,11 +596,7 @@ def test_pytorch_scriptrun_disconnect(env): con = env.getConnection() - test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') - script_filename = os.path.join(test_data_path, 'script.txt') - - with open(script_filename, 'rb') as f: - script = f.read() + script = load_file_content('script.txt') ret = con.execute_command('AI.SCRIPTSET', 'ket_script{1}', DEVICE, 'SOURCE', script) env.assertEqual(ret, b'OK') @@ -692,7 +608,7 @@ def test_pytorch_scriptrun_disconnect(env): ensureSlaveSynced(con, env) - ret = send_and_disconnect(('AI.SCRIPTRUN', 'ket_script{1}', 'bar', 'INPUTS', 'a{1}', 'b{1}', 'OUTPUTS', 'c{1}'), con) + ret = send_and_disconnect(('AI.SCRIPTEXECUTE', 'ket_script{1}', 'bar', 'KEYS', 1, '{1}', 'INPUTS', 2, 'a{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}'), con)
env.assertEqual(ret, None) @@ -734,11 +650,7 @@ def test_pytorch_modelscan_scriptscan(env): # ensure cleaned DB # env.flush() - test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') - model_filename = os.path.join(test_data_path, 'pt-minimal.pt') - - with open(model_filename, 'rb') as f: - model_pb = f.read() + model_pb = load_file_content('pt-minimal.pt') ret = con.execute_command('AI.MODELSTORE', 'm1', 'TORCH', DEVICE, 'TAG', 'm:v1', 'BLOB', model_pb) env.assertEqual(ret, b'OK') @@ -746,10 +658,7 @@ def test_pytorch_modelscan_scriptscan(env): ret = con.execute_command('AI.MODELSTORE', 'm2', 'TORCH', DEVICE, 'BLOB', model_pb) env.assertEqual(ret, b'OK') - script_filename = os.path.join(test_data_path, 'script.txt') - - with open(script_filename, 'rb') as f: - script = f.read() + script = load_file_content('script.txt') ret = con.execute_command('AI.SCRIPTSET', 's1', DEVICE, 'TAG', 's:v1', 'SOURCE', script) env.assertEqual(ret, b'OK') diff --git a/tests/flow/tests_tensorflow.py b/tests/flow/tests_tensorflow.py index 0000b4c4d..b2797a424 100644 --- a/tests/flow/tests_tensorflow.py +++ b/tests/flow/tests_tensorflow.py @@ -637,7 +637,7 @@ def functor_financialNet(env, key_max, repetitions): # env.debugPrint("AI.MODELRUN elapsed time(sec) {:6.2f}\tTotal ops {:10.2f}\tAvg. 
ops/sec {:10.2f}".format(elapsed_time, total_ops, avg_ops_sec), True) -def test_tensorflow_modelrun_scriptrun_resnet(env): +def test_tensorflow_modelexecute_script_execute_resnet(env): if (not TEST_TF or not TEST_PT): return con = env.getConnection() @@ -668,16 +668,16 @@ def test_tensorflow_modelrun_scriptrun_resnet(env): 'BLOB', img.tobytes()) env.assertEqual(ret, b'OK') - ret = con.execute_command('AI.SCRIPTRUN', script_name, - 'pre_process_3ch', 'INPUTS', image_key, 'OUTPUTS', temp_key1 ) + ret = con.execute_command('AI.SCRIPTEXECUTE', script_name, + 'pre_process_3ch', 'KEYS', '1', script_name, 'INPUTS', 1, image_key, 'OUTPUTS', 1, temp_key1 ) env.assertEqual(ret, b'OK') ret = con.execute_command('AI.MODELEXECUTE', model_name, 'INPUTS', 1, temp_key1, 'OUTPUTS', 1, temp_key2 ) env.assertEqual(ret, b'OK') - ret = con.execute_command('AI.SCRIPTRUN', script_name, - 'post_process', 'INPUTS', temp_key2, 'OUTPUTS', output_key ) + ret = con.execute_command('AI.SCRIPTEXECUTE', script_name, + 'post_process', 'KEYS', 1 ,script_name, 'INPUTS', 1, temp_key2, 'OUTPUTS', 1, output_key ) env.assertEqual(ret, b'OK') ensureSlaveSynced(con, env) diff --git a/tests/flow/tests_withGears.py b/tests/flow/tests_withGears.py index 179070653..a9e32b96c 100644 --- a/tests/flow/tests_withGears.py +++ b/tests/flow/tests_withGears.py @@ -180,7 +180,7 @@ async def ScriptRun_AsyncRunError(record): ret = con.execute_command('rg.trigger', 'ScriptRun_AsyncRunError_test3') # This should raise an exception - env.assertTrue(str(ret[0]).startswith("b'attempted to get undefined function bad_func")) + env.assertTrue(str(ret[0]).startswith("b'attempted to get undefined function")) @skip_if_gears_not_loaded @@ -299,7 +299,7 @@ async def DAGRun_addOpsFromString(record): ret = con.execute_command('rg.trigger', 'DAGRun_test4') # This should raise an exception - env.assertTrue(str(ret[0]).startswith("b'attempted to get undefined function no_func")) + env.assertTrue(str(ret[0]).startswith("b'attempted to get 
undefined function")) ret = con.execute_command('rg.trigger', 'DAGRun_test5') env.assertEqual(ret[0], b'test5_OK') diff --git a/tests/module/LLAPI.c b/tests/module/LLAPI.c index 2e9ba029e..685313536 100644 --- a/tests/module/LLAPI.c +++ b/tests/module/LLAPI.c @@ -216,7 +216,7 @@ int RAI_llapi_scriptRun(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) keyRedisStr = RedisModule_CreateString(ctx, keyNameStr, strlen(keyNameStr)); key = RedisModule_OpenKey(ctx, keyRedisStr, REDISMODULE_READ); RAI_Tensor *input1 = RedisModule_ModuleTypeGetValue(key); - RedisAI_ScriptRunCtxAddInput(sctx, input1, err); + RedisAI_ScriptRunCtxAddTensorInput(sctx, input1); RedisModule_FreeString(ctx, keyRedisStr); RedisModule_CloseKey(key); @@ -224,7 +224,7 @@ int RAI_llapi_scriptRun(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) keyRedisStr = RedisModule_CreateString(ctx, keyNameStr, strlen(keyNameStr)); key = RedisModule_OpenKey(ctx, keyRedisStr, REDISMODULE_READ); RAI_Tensor *input2 = RedisModule_ModuleTypeGetValue(key); - RedisAI_ScriptRunCtxAddInput(sctx, input2, err); + RedisAI_ScriptRunCtxAddTensorInput(sctx, input2); RedisModule_FreeString(ctx, keyRedisStr); RedisModule_CloseKey(key);