Skip to content

Commit cfe2222

Browse files
committed
torch redis pass with new API
1 parent 4c29254 commit cfe2222

File tree

3 files changed

+33
-29
lines changed

3 files changed

+33
-29
lines changed

src/execution/command_parser.c

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,8 @@ static int _ScriptExecuteCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleStrin
409409
RedisModuleString ***outkeys, long long *timeout,
410410
size_t **listSizes) {
411411
bool timeout_set = false;
412-
int argpos;
413-
for (argpos = 3; argpos < argc; argpos++) {
412+
int argpos = 3;
413+
while (argpos < argc ) {
414414
const char *arg_string = RedisModule_StringPtrLen(argv[argpos], NULL);
415415

416416
// Parse timeout arg if given and store it in timeout
@@ -439,6 +439,7 @@ static int _ScriptExecuteCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleStrin
439439
"ERR Invalid argument for input count in AI.SCRIPTEXECUTE");
440440
return REDISMODULE_ERR;
441441
}
442+
argpos++;
442443
size_t first_input_pos = argpos;
443444
if (first_input_pos + ninputs > argc) {
444445
RAI_SetError(error, RAI_ESCRIPTRUN,
@@ -459,6 +460,7 @@ static int _ScriptExecuteCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleStrin
459460
"ERR Invalid argument for key count in AI.SCRIPTEXECUTE");
460461
return REDISMODULE_ERR;
461462
}
463+
argpos++;
462464
size_t first_input_pos = argpos;
463465
if (first_input_pos + ninputs > argc) {
464466
RAI_SetError(error, RAI_ESCRIPTRUN,
@@ -472,7 +474,7 @@ static int _ScriptExecuteCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleStrin
472474
continue;
473475
}
474476
if (!strcasecmp(arg_string, "OUTPUTS")) {
475-
if (array_len(outkeys) != 0) {
477+
if (array_len(*outkeys) != 0) {
476478
RAI_SetError(error, RAI_ESCRIPTRUN,
477479
"ERR Already encountered an OUTPUTS keyword in AI.SCRIPTEXECUTE");
478480
return REDISMODULE_ERR;
@@ -484,6 +486,7 @@ static int _ScriptExecuteCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleStrin
484486
"ERR Invalid argument for output count in AI.SCRIPTEXECUTE");
485487
return REDISMODULE_ERR;
486488
}
489+
argpos++;
487490
size_t first_output_pos = argpos;
488491
if (first_output_pos + noutputs > argc) {
489492
RAI_SetError(error, RAI_ESCRIPTRUN,
@@ -504,6 +507,7 @@ static int _ScriptExecuteCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleStrin
504507
"ERR Invalid argument for list input count in AI.SCRIPTEXECUTE");
505508
return REDISMODULE_ERR;
506509
}
510+
argpos++;
507511
size_t first_input_pos = argpos;
508512
if (first_input_pos + ninputs > argc) {
509513
RAI_SetError(error, RAI_ESCRIPTRUN,

tests/flow/test_data/redis_scripts.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,38 +23,38 @@ def redis_hash_to_tensor(redis_value: Any):
2323
return torch.cat(l, dim=0)
2424

2525
def test_redis_error():
26-
redis.execute("SET", "x")
26+
redis.execute("SET", "x{1}")
2727

2828
def test_int_set_get():
29-
redis.execute("SET", "x", "1")
30-
res = redis.execute("GET", "x",)
31-
redis.execute("DEL", "x")
29+
redis.execute("SET", "x{1}", "1")
30+
res = redis.execute("GET", "x{1}",)
31+
redis.execute("DEL", "x{1}")
3232
return redis_string_int_to_tensor(res)
3333

3434
def test_int_set_incr():
35-
redis.execute("SET", "x", "1")
36-
res = redis.execute("INCR", "x")
37-
redis.execute("DEL", "x")
35+
redis.execute("SET", "x{1}", "1")
36+
res = redis.execute("INCR", "x{1}")
37+
redis.execute("DEL", "x{1}")
3838
return redis_string_int_to_tensor(res)
3939

4040
def test_float_set_get():
41-
redis.execute("SET", "x", "1.1")
42-
res = redis.execute("GET", "x",)
43-
redis.execute("DEL", "x")
41+
redis.execute("SET", "x{1}", "1.1")
42+
res = redis.execute("GET", "x{1}",)
43+
redis.execute("DEL", "x{1}")
4444
return redis_string_float_to_tensor(res)
4545

4646
def test_int_list():
47-
redis.execute("RPUSH", "x", "1")
48-
redis.execute("RPUSH", "x", "2")
49-
res = redis.execute("LRANGE", "x", "0", "2")
50-
redis.execute("DEL", "x")
47+
redis.execute("RPUSH", "x{1}", "1")
48+
redis.execute("RPUSH", "x{1}", "2")
49+
res = redis.execute("LRANGE", "x{1}", "0", "2")
50+
redis.execute("DEL", "x{1}")
5151
return redis_int_list_to_tensor(res)
5252

5353

5454
def test_hash():
55-
redis.execute("HSET", "x", "field1", "1", "field2", "2")
56-
res = redis.execute("HVALS", "x")
57-
redis.execute("DEL", "x")
55+
redis.execute("HSET", "x{1}", "field1", "1", "field2", "2")
56+
res = redis.execute("HVALS", "x{1}")
57+
redis.execute("DEL", "x{1}")
5858
return redis_hash_to_tensor(res)
5959

6060

@@ -63,4 +63,4 @@ def test_set_key():
6363

6464

6565
def test_del_key():
66-
redis.execute("DEL", ["x"])
66+
redis.execute("DEL", ["x{1}"])

tests/flow/test_torchscript_extensions.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,42 +32,42 @@ def __init__(self):
3232
def test_redis_error(self):
3333
try:
3434
self.con.execute_command(
35-
'AI.SCRIPTRUN', 'redis_scripts', 'test_redis_error')
35+
'AI.SCRIPTEXECUTE', 'redis_scripts', 'test_redis_error', 'KEYS', 1, "x{1}")
3636
self.env.assertTrue(False)
3737
except:
3838
pass
3939

4040
def test_simple_test_set(self):
4141
self.con.execute_command(
42-
'AI.SCRIPTRUN', 'redis_scripts{1}', 'test_set_key')
42+
'AI.SCRIPTEXECUTE', 'redis_scripts{1}', 'test_set_key', 'KEYS', 1, "x{1}")
4343
self.env.assertEqual(b"1", self.con.get("x{1}"))
4444

4545
def test_int_set_get(self):
46-
self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts{1}', 'test_int_set_get', 'OUTPUTS', 'y{1}')
46+
self.con.execute_command('AI.SCRIPTEXECUTE', 'redis_scripts{1}', 'test_int_set_get', 'KEYS', 1, "x{1}", 'OUTPUTS', 1, 'y{1}')
4747
y = self.con.execute_command('AI.TENSORGET', 'y{1}', 'meta' ,'VALUES')
4848
self.env.assertEqual(y, [b"dtype", b"INT64", b"shape", [], b"values", [1]] )
4949

5050
def test_int_set_incr(self):
51-
self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts{1}', 'test_int_set_incr', 'OUTPUTS', 'y{1}')
51+
self.con.execute_command('AI.SCRIPTEXECUTE', 'redis_scripts{1}', 'test_int_set_incr', 'KEYS', 1, "x{1}", 'OUTPUTS', 1, 'y{1}')
5252
y = self.con.execute_command('AI.TENSORGET', 'y{1}', 'meta' ,'VALUES')
5353
self.env.assertEqual(y, [b"dtype", b"INT64", b"shape", [], b"values", [2]] )
5454

5555
def test_float_get_set(self):
56-
self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts{1}', 'test_float_set_get', 'OUTPUTS', 'y{1}')
56+
self.con.execute_command('AI.SCRIPTEXECUTE', 'redis_scripts{1}', 'test_float_set_get', 'KEYS', 1, "x{1}", 'OUTPUTS', 1, 'y{1}')
5757
y = self.con.execute_command('AI.TENSORGET', 'y{1}', 'meta' ,'VALUES')
5858
self.env.assertEqual(y[0], b"dtype")
5959
self.env.assertEqual(y[1], b"FLOAT")
6060
self.env.assertEqual(y[2], b"shape")
6161
self.env.assertEqual(y[3], [])
6262
self.env.assertEqual(y[4], b"values")
6363
self.env.assertAlmostEqual(float(y[5][0]), 1.1, 0.1)
64-
64+
6565
def test_int_list(self):
66-
self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts{1}', 'test_int_list', 'OUTPUTS', 'y{1}')
66+
self.con.execute_command('AI.SCRIPTEXECUTE', 'redis_scripts{1}', 'test_int_list', 'KEYS', 1, "x{1}", 'OUTPUTS', 1, 'y{1}')
6767
y = self.con.execute_command('AI.TENSORGET', 'y{1}', 'meta' ,'VALUES')
6868
self.env.assertEqual(y, [b"dtype", b"INT64", b"shape", [2, 1], b"values", [1, 2]] )
6969

7070
def test_hash(self):
71-
self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts{1}', 'test_hash', 'OUTPUTS', 'y{1}')
71+
self.con.execute_command('AI.SCRIPTEXECUTE', 'redis_scripts{1}', 'test_hash', 'KEYS', 1, "x{1}", 'OUTPUTS', 1, 'y{1}')
7272
y = self.con.execute_command('AI.TENSORGET', 'y{1}', 'meta' ,'VALUES')
7373
self.env.assertEqual(y, [b"dtype", b"INT64", b"shape", [2, 1], b"values", [1, 2]] )

0 commit comments

Comments
 (0)