Skip to content

Commit ea8e972

Browse files
committed
- Replace the REDISAI_MAIN macro in redisai.h with REDISAI_EXTERN.
- Added tests for errors in dag parsing. - Bug fix in parse timeout
1 parent 7db6b5d commit ea8e972

File tree

9 files changed

+106
-28
lines changed

9 files changed

+106
-28
lines changed

src/DAG/dag_parser.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ static int _ParseDAGPersistArgs(RedisModuleCtx *ctx, RedisModuleString **argv, i
127127

128128
static int _parseTimeout(RedisModuleString **argv, int argc, long long *timeout, RAI_Error *err) {
129129

130-
if (argc == 0) {
130+
if (argc < 2) {
131131
RAI_SetError(err, RAI_EDAGBUILDER, "ERR No value provided for TIMEOUT");
132132
return REDISMODULE_ERR;
133133
}

src/command_parser.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ static int _parseTimeout(RedisModuleString *timeout_arg, RAI_Error *error, long
1717
return REDISMODULE_OK;
1818
}
1919

20-
static int _ModelRunCommand_ParseArgs(RedisModuleString **argv, int argc, RedisModuleCtx *ctx,
20+
static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModuleString **argv,
2121
RAI_Model **model, RAI_Error *error,
2222
RedisModuleString ***inkeys, RedisModuleString ***outkeys,
2323
RedisModuleString **runkey, long long *timeout) {
@@ -130,7 +130,7 @@ int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModu
130130
RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(NULL);
131131
RAI_Model *model;
132132
long long timeout = 0;
133-
if (_ModelRunCommand_ParseArgs(argv, argc, ctx, &model, rinfo->err, &currentOp->inkeys,
133+
if (_ModelRunCommand_ParseArgs(ctx, argc, argv, &model, rinfo->err, &currentOp->inkeys,
134134
&currentOp->outkeys, &currentOp->runkey,
135135
&timeout) == REDISMODULE_ERR) {
136136
goto cleanup;

src/redisai.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,7 @@ static int RedisAI_RegisterApi(RedisModuleCtx *ctx) {
990990
REGISTER_API(GetAsModelRunCtx, ctx);
991991

992992
REGISTER_API(ScriptCreate, ctx);
993+
REGISTER_API(GetScriptFromKeyspace, ctx);
993994
REGISTER_API(ScriptFree, ctx);
994995
REGISTER_API(ScriptRunCtxCreate, ctx);
995996
REGISTER_API(ScriptRunCtxAddInput, ctx);

src/redisai.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#define REDISAI_LLAPI_VERSION 1
88
#define MODULE_API_FUNC(x) (*x)
99

10-
#ifndef REDISAI_MAIN
10+
#ifdef REDISAI_EXTERN
1111
#define REDISAI_API extern
1212
#endif
1313

@@ -134,6 +134,11 @@ REDISAI_API RAI_ModelRunCtx *MODULE_API_FUNC(RedisAI_GetAsModelRunCtx)(RAI_OnFin
134134
REDISAI_API RAI_Script *MODULE_API_FUNC(RedisAI_ScriptCreate)(char *devicestr, char *tag,
135135
const char *scriptdef,
136136
RAI_Error *err);
137+
REDISAI_API int MODULE_API_FUNC(RedisAI_GetScriptFromKeyspace)(RedisModuleCtx *ctx,
138+
RedisModuleString *keyName,
139+
RedisModuleKey **key,
140+
RAI_Script **script, int mode,
141+
RAI_Error *err);
137142
REDISAI_API void MODULE_API_FUNC(RedisAI_ScriptFree)(RAI_Script *script, RAI_Error *err);
138143
REDISAI_API RAI_ScriptRunCtx *MODULE_API_FUNC(RedisAI_ScriptRunCtxCreate)(RAI_Script *script,
139144
const char *fnname);
@@ -258,6 +263,7 @@ static int RedisAI_Initialize(RedisModuleCtx *ctx) {
258263
REDISAI_MODULE_INIT_FUNCTION(ctx, GetAsModelRunCtx);
259264

260265
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptCreate);
266+
REDISAI_MODULE_INIT_FUNCTION(ctx, GetScriptFromKeyspace);
261267
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptFree);
262268
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxCreate);
263269
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddInput);

tests/flow/tests_dag.py

Lines changed: 89 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,43 @@ def test_dag_load(env):
2424
ret = con.execute_command(command)
2525
env.assertEqual(ret, [b'OK'])
2626

27+
2728
def test_dag_load_errors(env):
2829
con = env.getConnection()
2930

31+
# ERR wrong number of arguments for LOAD
32+
try:
33+
command = "AI.DAGRUN PERSIST 1 no_tensor{1} LOAD 1"
34+
35+
ret = con.execute_command(command)
36+
except Exception as e:
37+
exception = e
38+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
39+
env.assertEqual("wrong number of arguments for LOAD in 'AI.DAGRUN' command",exception.__str__())
40+
41+
# ERR invalid or negative number of arguments for LOAD
42+
try:
43+
command = "AI.DAGRUN LOAD notnumber{{1}} |> AI.TENSORGET 'no_tensor{1}'"
44+
45+
ret = con.execute_command(command)
46+
except Exception as e:
47+
exception = e
48+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
49+
env.assertEqual("invalid or negative value found in number of keys to LOAD",exception.__str__())
50+
51+
con.execute_command('AI.TENSORSET', 'a{{1}}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
52+
53+
# ERR number of keys to LOAD does not match the number of given arguments.
54+
try:
55+
command = "AI.DAGRUN LOAD 2 a{{1}}"
56+
57+
ret = con.execute_command(command)
58+
except Exception as e:
59+
exception = e
60+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
61+
env.assertEqual("number of keys to LOAD in AI.DAGRUN command does not match the number of "
62+
"given arguments", exception.__str__())
63+
3064
# ERR tensor key is empty
3165
try:
3266
command = "AI.DAGRUN "\
@@ -55,51 +89,91 @@ def test_dag_load_errors(env):
5589
env.assertEqual("WRONGTYPE Operation against a key holding the wrong kind of value",exception.__str__())
5690

5791

58-
def test_dag_common_errors(env):
92+
def test_dag_persist_errors(env):
5993
con = env.getConnection()
94+
con.execute_command('AI.TENSORSET', 'a{{1}}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
6095

61-
# ERR unsupported command within DAG
96+
# ERR wrong number of arguments for PERSIST
6297
try:
63-
command = "AI.DAGRUN |> "\
64-
"AI.DONTEXIST tensor1{{1}} FLOAT 1 2 VALUES 5 10"
98+
command = "AI.DAGRUN LOAD 1 a{{1}} PERSIST 1"
6599

66100
ret = con.execute_command(command)
67101
except Exception as e:
68102
exception = e
69103
env.assertEqual(type(exception), redis.exceptions.ResponseError)
70-
env.assertEqual("unsupported command within DAG",exception.__str__())
104+
env.assertEqual("wrong number of arguments for PERSIST in 'AI.DAGRUN' command",exception.__str__())
71105

72-
# ERR wrong number of arguments for 'AI.DAGRUN' command
106+
# ERR invalid or negative value found in number of keys to PERSIST
73107
try:
74-
command = "AI.DAGRUN "
108+
command = "AI.DAGRUN PERSIST notnumber{{1}} |> " \
109+
"AI.TENSORSET tensor1 FLOAT 1 2 VALUES 5 10"
75110

76111
ret = con.execute_command(command)
77112
except Exception as e:
78113
exception = e
79114
env.assertEqual(type(exception), redis.exceptions.ResponseError)
80-
env.assertEqual("wrong number of arguments for 'AI.DAGRUN' command",exception.__str__())
115+
env.assertEqual("invalid or negative value found in number of keys to PERSIST",exception.__str__())
81116

82-
# ERR invalid or negative value found in number of keys to PERSIST
117+
# ERR number of keys to PERSIST does not match the number of given arguments.
118+
try:
119+
command = "AI.DAGRUN PERSIST 2 tensor1{1}"
120+
121+
ret = con.execute_command(command)
122+
except Exception as e:
123+
exception = e
124+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
125+
env.assertEqual("number of keys to PERSIST in AI.DAGRUN command does not match the number of "
126+
"given arguments", exception.__str__())
127+
128+
129+
def test_dag_timeout_errors(env):
130+
con = env.getConnection()
131+
132+
# ERR no value provided for timeout
83133
try:
84-
command = "AI.DAGRUN PERSIST notnumber{{1}} |> "\
134+
command = "AI.DAGRUN PERSIST 1 no_tensor{1} TIMEOUT"
135+
136+
ret = con.execute_command(command)
137+
except Exception as e:
138+
exception = e
139+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
140+
env.assertEqual("No value provided for TIMEOUT",exception.__str__())
141+
142+
# ERR invalid timeout value
143+
try:
144+
command = "AI.DAGRUN TIMEOUT notnumber{{1}} |> " \
85145
"AI.TENSORSET tensor1 FLOAT 1 2 VALUES 5 10"
86146

87147
ret = con.execute_command(command)
88148
except Exception as e:
89149
exception = e
90150
env.assertEqual(type(exception), redis.exceptions.ResponseError)
91-
env.assertEqual("invalid or negative value found in number of keys to PERSIST",exception.__str__())
151+
env.assertEqual("Invalid value for TIMEOUT",exception.__str__())
92152

93-
# ERR invalid or negative value found in number of keys to LOAD
153+
154+
def test_dag_common_errors(env):
155+
con = env.getConnection()
156+
157+
# ERR unsupported command within DAG
94158
try:
95-
command = "AI.DAGRUN LOAD notnumber{{1}} |> "\
96-
"AI.TENSORSET tensor1 FLOAT 1 2 VALUES 5 10"
159+
command = "AI.DAGRUN |> "\
160+
"AI.DONTEXIST tensor1{{1}} FLOAT 1 2 VALUES 5 10"
97161

98162
ret = con.execute_command(command)
99163
except Exception as e:
100164
exception = e
101165
env.assertEqual(type(exception), redis.exceptions.ResponseError)
102-
env.assertEqual("invalid or negative value found in number of keys to LOAD",exception.__str__())
166+
env.assertEqual("unsupported command within DAG",exception.__str__())
167+
168+
# ERR wrong number of arguments for 'AI.DAGRUN' command
169+
try:
170+
command = "AI.DAGRUN "
171+
172+
ret = con.execute_command(command)
173+
except Exception as e:
174+
exception = e
175+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
176+
env.assertEqual("wrong number of arguments for 'AI.DAGRUN' command",exception.__str__())
103177

104178
# ERR DAG with no ops
105179
command = "AI.TENSORSET volatile_tensor1 FLOAT 1 2 VALUES 5 10"

tests/module/DAG_utils.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
#include "redisai.h"
1+
#define REDISAI_EXTERN 1
2+
#include "DAG_utils.h"
23
#include <errno.h>
34
#include <string.h>
45
#include <pthread.h>
56
#include <stdlib.h>
67
#include "util/arr.h"
7-
#include "DAG_utils.h"
88

99
pthread_mutex_t global_lock = PTHREAD_MUTEX_INITIALIZER;
1010
pthread_cond_t global_cond = PTHREAD_COND_INITIALIZER;

tests/module/DAG_utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
2-
#include "../../src/redisai.h"
2+
#include "redisai.h"
33
#include <pthread.h>
4+
#include <stdbool.h>
45

56
#define LLAPIMODULE_OK 0
67
#define LLAPIMODULE_ERR 1

tests/module/LLAPI.c

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11

22
#define REDISMODULE_MAIN
3-
#define REDISAI_MAIN 1
43

5-
#include "redisai.h"
4+
#include "DAG_utils.h"
65
#include <errno.h>
76
#include <string.h>
8-
#include <stdbool.h>
9-
#include "DAG_utils.h"
7+
108

119
typedef enum LLAPI_status {LLAPI_RUN_NONE = 0,
1210
LLAPI_RUN_SUCCESS,

tests/unit/unit_tests_err.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ extern "C" {
66
#include "src/err.h"
77
}
88

9-
#define REDISAI_MAIN
10-
119
class ErrorStructTest : public ::testing::Test {
1210
protected:
1311
static void SetUpTestCase() {

0 commit comments

Comments
 (0)