Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions src/redisai.c
Original file line number Diff line number Diff line change
Expand Up @@ -655,34 +655,31 @@ int RedisAI_ScriptGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
return REDISMODULE_ERR;
}

int meta = 0;
int source = 0;
bool meta = false; // Indicates whether META argument was given.
bool source = false; // Indicates whether SOURCE argument was given.
for (int i = 2; i < argc; i++) {
const char *optstr = RedisModule_StringPtrLen(argv[i], NULL);
if (!strcasecmp(optstr, "META")) {
meta = 1;
meta = true;
} else if (!strcasecmp(optstr, "SOURCE")) {
source = 1;
source = true;
}
}

if (!meta && !source) {
return RedisModule_ReplyWithError(ctx, "ERR no META or SOURCE specified");
}

// If only SOURCE arg was given, return only the script source.
if (!meta && source) {
RedisModule_ReplyWithCString(ctx, sto->scriptdef);
return REDISMODULE_OK;
}
// We return (META+SOURCE) if both args are given, or if none of them was given.
// The only case where we return only META data, is if META is given while SOURCE was not.
int out_entries = (source || !meta) ? 6 : 4;
RedisModule_ReplyWithArray(ctx, out_entries);

int outentries = source ? 6 : 4;

RedisModule_ReplyWithArray(ctx, outentries);
RedisModule_ReplyWithCString(ctx, "device");
RedisModule_ReplyWithCString(ctx, sto->devicestr);
RedisModule_ReplyWithCString(ctx, "tag");
RedisModule_ReplyWithString(ctx, sto->tag);
if (source) {
if (source || !meta) {
RedisModule_ReplyWithCString(ctx, "source");
RedisModule_ReplyWithCString(ctx, sto->scriptdef);
}
Expand Down
35 changes: 26 additions & 9 deletions tests/flow/tests_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,32 @@ def test_pytorch_scriptset(env):
ensureSlaveSynced(con, env)


def test_pytorch_scriptget(env):
if not TEST_PT:
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
return

con = env.getConnection()
con.execute_command('DEL', 'EMPTY{1}')
# ERR no script at key from SCRIPTGET
check_error_message(env, con, "script key is empty", 'AI.SCRIPTGET', 'EMPTY{1}')

con.execute_command('SET', 'NOT_SCRIPT{1}', 'BAR')
# ERR wrong type from SCRIPTGET
check_error_message(env, con, "WRONGTYPE Operation against a key holding the wrong kind of value", 'AI.SCRIPTGET', 'NOT_SCRIPT{1}')

script = load_file_content('script.txt')
ret = con.execute_command('AI.SCRIPTSET', 'ket{1}', DEVICE, 'TAG', 'asdf', 'SOURCE', script)
env.assertEqual(ret, b'OK')

# return meta + source
_, device, _, tag, _, source = con.execute_command('AI.SCRIPTGET', 'ket{1}')
env.assertEqual([device, tag, source], [b"CPU", b"asdf", script])
# return source only
source = con.execute_command('AI.SCRIPTGET', 'ket{1}', 'SOURCE')
env.assertEqual(source, script)


def test_pytorch_scriptdel(env):
if not TEST_PT:
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
Expand Down Expand Up @@ -453,17 +479,8 @@ def test_pytorch_scriptexecute_errors(env):
env.assertEqual(ret, b'OK')
ret = con.execute_command('AI.TENSORSET', 'b{1}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
env.assertEqual(ret, b'OK')

ensureSlaveSynced(con, env)

con.execute_command('DEL', 'EMPTY{1}')
# ERR no script at key from SCRIPTGET
check_error_message(env, con, "script key is empty", 'AI.SCRIPTGET', 'EMPTY{1}')

con.execute_command('SET', 'NOT_SCRIPT{1}', 'BAR')
# ERR wrong type from SCRIPTGET
check_error_message(env, con, "WRONGTYPE Operation against a key holding the wrong kind of value", 'AI.SCRIPTGET', 'NOT_SCRIPT{1}')

con.execute_command('DEL', 'EMPTY{1}')
# ERR no script at key from SCRIPTEXECUTE
check_error_message(env, con, "script key is empty", 'AI.SCRIPTEXECUTE', 'EMPTY{1}', 'bar', 'KEYS', 1 , '{1}', 'INPUTS', 1, 'b{1}', 'OUTPUTS', 1, 'c{1}')
Expand Down