Skip to content

Binary safe strings #538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ endif()
ADD_LIBRARY(redisai_obj OBJECT
util/dict.c
util/queue.c
util/string_utils.c
redisai.c
run_info.c
background_workers.c
Expand Down
132 changes: 67 additions & 65 deletions src/dag.c

Large diffs are not rendered by default.

25 changes: 15 additions & 10 deletions src/model.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "stats.h"
#include "util/arr_rm_alloc.h"
#include "util/dict.h"
#include "util/string_utils.h"
#include <pthread.h>

RedisModuleType *RedisAI_ModelType = NULL;
Expand All @@ -31,7 +32,7 @@ static void *RAI_Model_RdbLoad(struct RedisModuleIO *io, int encver) {
RAI_Backend backend = RedisModule_LoadUnsigned(io);
const char *devicestr = RedisModule_LoadStringBuffer(io, NULL);

const char *tag = RedisModule_LoadStringBuffer(io, NULL);
RedisModuleString *tag = RedisModule_LoadString(io);

const size_t batchsize = RedisModule_LoadUnsigned(io);
const size_t minbatchsize = RedisModule_LoadUnsigned(io);
Expand Down Expand Up @@ -113,7 +114,7 @@ static void *RAI_Model_RdbLoad(struct RedisModuleIO *io, int encver) {
RedisModuleString *stats_keystr =
RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io));
const char *stats_devicestr = RedisModule_Strdup(devicestr);
const char *stats_tag = RedisModule_Strdup(tag);
RedisModuleString *stats_tag = RAI_HoldString(NULL, tag);

model->infokey =
RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_MODEL, backend, stats_devicestr, stats_tag);
Expand Down Expand Up @@ -143,7 +144,7 @@ static void RAI_Model_RdbSave(RedisModuleIO *io, void *value) {

RedisModule_SaveUnsigned(io, model->backend);
RedisModule_SaveStringBuffer(io, model->devicestr, strlen(model->devicestr) + 1);
RedisModule_SaveStringBuffer(io, model->tag, strlen(model->tag) + 1);
RedisModule_SaveString(io, model->tag);
RedisModule_SaveUnsigned(io, model->opts.batchsize);
RedisModule_SaveUnsigned(io, model->opts.minbatchsize);
RedisModule_SaveUnsigned(io, model->ninputs);
Expand Down Expand Up @@ -221,7 +222,7 @@ static void RAI_Model_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, voi

const char *backendstr = RAI_BackendName(model->backend);

RedisModule_EmitAOF(aof, "AI.MODELSET", "slccclclcvcvcv", key, backendstr, model->devicestr,
RedisModule_EmitAOF(aof, "AI.MODELSET", "sccsclclcvcvcv", key, backendstr, model->devicestr,
model->tag, "BATCHSIZE", model->opts.batchsize, "MINBATCHSIZE",
model->opts.minbatchsize, "INPUTS", inputs_, model->ninputs, "OUTPUTS",
outputs_, model->noutputs, "BLOB", buffers_, n_chunks);
Expand Down Expand Up @@ -285,7 +286,7 @@ int RAI_ModelInit(RedisModuleCtx *ctx) {
return RedisAI_ModelType != NULL;
}

RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, const char *tag,
RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModuleString *tag,
RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs,
const char **outputs, const char *modeldef, size_t modellen,
RAI_Error *err) {
Expand Down Expand Up @@ -321,7 +322,11 @@ RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, const cha
}

if (model) {
model->tag = RedisModule_Strdup(tag);
if (tag) {
model->tag = RAI_HoldString(NULL, tag);
} else {
model->tag = RedisModule_CreateString(NULL, "", 0);
}
}

return model;
Expand Down Expand Up @@ -361,7 +366,7 @@ void RAI_ModelFree(RAI_Model *model, RAI_Error *err) {
return;
}

RedisModule_Free(model->tag);
RedisModule_FreeString(NULL, model->tag);

RAI_RemoveStatsEntry(model->infokey);

Expand Down Expand Up @@ -588,12 +593,12 @@ int RedisAI_Parse_ModelRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString *
is_input = 1;
outputs_flag_count = 1;
} else {
RedisModule_RetainString(ctx, argv[argpos]);
RedisModuleString *arg = RAI_HoldString(ctx, argv[argpos]);
if (is_input == 0) {
*inkeys = array_append(*inkeys, argv[argpos]);
*inkeys = array_append(*inkeys, arg);
ninputs++;
} else {
*outkeys = array_append(*outkeys, argv[argpos]);
*outkeys = array_append(*outkeys, arg);
noutputs++;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ int RAI_ModelInit(RedisModuleCtx *ctx);
* failures
* @return RAI_Model model structure on success, or NULL if failed
*/
RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, const char *tag,
RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModuleString *tag,
RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs,
const char **outputs, const char *modeldef, size_t modellen,
RAI_Error *err);
Expand Down
2 changes: 1 addition & 1 deletion src/model_struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ typedef struct RAI_Model {
void *session;
RAI_Backend backend;
char *devicestr;
char *tag;
RedisModuleString *tag;
RAI_ModelOpts opts;
char **inputs;
size_t ninputs;
Expand Down
34 changes: 19 additions & 15 deletions src/redisai.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "run_info.h"
#include "util/arr_rm_alloc.h"
#include "util/dict.h"
#include "util/string_utils.h"
#include "util/queue.h"
#include "version.h"

Expand Down Expand Up @@ -184,9 +185,9 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
return RedisModule_ReplyWithError(ctx, "ERR Invalid DEVICE");
}

const char *tag = "";
RedisModuleString *tag = NULL;
if (AC_AdvanceIfMatch(&ac, "TAG")) {
AC_GetString(&ac, &tag, NULL, 0);
AC_GetRString(&ac, &tag, 0);
}

unsigned long long batchsize = 0;
Expand Down Expand Up @@ -470,7 +471,8 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
RedisModule_ReplyWithCString(ctx, mto->devicestr);

RedisModule_ReplyWithCString(ctx, "tag");
RedisModule_ReplyWithCString(ctx, mto->tag ? mto->tag : "");
RedisModuleString *empty_tag = RedisModule_CreateString(ctx, "", 0);
RedisModule_ReplyWithString(ctx, mto->tag ? mto->tag : empty_tag);

RedisModule_ReplyWithCString(ctx, "batchsize");
RedisModule_ReplyWithLongLong(ctx, (long)mto->opts.batchsize);
Expand Down Expand Up @@ -539,15 +541,15 @@ int RedisAI_ModelScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv

long long nkeys;
RedisModuleString **keys;
const char **tags;
RedisModuleString **tags;
RAI_ListStatsEntries(RAI_MODEL, &nkeys, &keys, &tags);

RedisModule_ReplyWithArray(ctx, nkeys);

for (long long i = 0; i < nkeys; i++) {
RedisModule_ReplyWithArray(ctx, 2);
RedisModule_ReplyWithString(ctx, keys[i]);
RedisModule_ReplyWithCString(ctx, tags[i]);
RedisModule_ReplyWithString(ctx, tags[i]);
}

RedisModule_Free(keys);
Expand Down Expand Up @@ -633,7 +635,7 @@ int RedisAI_ScriptGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
RedisModule_ReplyWithCString(ctx, "device");
RedisModule_ReplyWithCString(ctx, sto->devicestr);
RedisModule_ReplyWithCString(ctx, "tag");
RedisModule_ReplyWithCString(ctx, sto->tag);
RedisModule_ReplyWithString(ctx, sto->tag);
if (source) {
RedisModule_ReplyWithCString(ctx, "source");
RedisModule_ReplyWithCString(ctx, sto->scriptdef);
Expand Down Expand Up @@ -682,9 +684,9 @@ int RedisAI_ScriptSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
const char *devicestr;
AC_GetString(&ac, &devicestr, NULL, 0);

const char *tag = "";
RedisModuleString *tag = NULL;
if (AC_AdvanceIfMatch(&ac, "TAG")) {
AC_GetString(&ac, &tag, NULL, 0);
AC_GetRString(&ac, &tag, 0);
}

if (AC_IsAtEnd(&ac)) {
Expand Down Expand Up @@ -780,15 +782,15 @@ int RedisAI_ScriptScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg

long long nkeys;
RedisModuleString **keys;
const char **tags;
RedisModuleString **tags;
RAI_ListStatsEntries(RAI_SCRIPT, &nkeys, &keys, &tags);

RedisModule_ReplyWithArray(ctx, nkeys);

for (long long i = 0; i < nkeys; i++) {
RedisModule_ReplyWithArray(ctx, 2);
RedisModule_ReplyWithString(ctx, keys[i]);
RedisModule_ReplyWithCString(ctx, tags[i]);
RedisModule_ReplyWithString(ctx, tags[i]);
}

RedisModule_Free(keys);
Expand All @@ -803,7 +805,7 @@ int RedisAI_ScriptScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg
int RedisAI_Info_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
if (argc != 2 && argc != 3)
return RedisModule_WrongArity(ctx);
const char *runkey = RedisModule_StringPtrLen(argv[1], NULL);
RedisModuleString *runkey = argv[1];
struct RedisAI_RunStats *rstats = NULL;
if (RAI_GetRunStats(runkey, &rstats) == REDISMODULE_ERR) {
return RedisModule_ReplyWithError(ctx, "ERR cannot find run info for key");
Expand Down Expand Up @@ -833,7 +835,11 @@ int RedisAI_Info_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
RedisModule_ReplyWithCString(ctx, "device");
RedisModule_ReplyWithCString(ctx, rstats->devicestr);
RedisModule_ReplyWithCString(ctx, "tag");
RedisModule_ReplyWithCString(ctx, rstats->tag);
if (rstats->tag) {
RedisModule_ReplyWithString(ctx, rstats->tag);
} else {
RedisModule_ReplyWithCString(ctx, "");
}
RedisModule_ReplyWithCString(ctx, "duration");
RedisModule_ReplyWithLongLong(ctx, rstats->duration_us);
RedisModule_ReplyWithCString(ctx, "samples");
Expand Down Expand Up @@ -1209,9 +1215,7 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
return REDISMODULE_ERR;
}

run_stats = AI_dictCreate(&AI_dictTypeHeapStrings, NULL);
run_stats = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);

return REDISMODULE_OK;
}

extern AI_dictType AI_dictTypeHeapStrings;
30 changes: 8 additions & 22 deletions src/run_info.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,19 @@
#include "tensor.h"
#include "util/arr_rm_alloc.h"
#include "util/dict.h"

static uint64_t RAI_TensorDictKeyHashFunction(const void *key) {
return AI_dictGenHashFunction(key, strlen((char *)key));
}

static int RAI_TensorDictKeyStrcmp(void *privdata, const void *key1, const void *key2) {
const char *strKey1 = key1;
const char *strKey2 = key2;
return strcmp(strKey1, strKey2) == 0;
}

static void RAI_TensorDictKeyFree(void *privdata, void *key) { RedisModule_Free(key); }

static void *RAI_TensorDictKeyDup(void *privdata, const void *key) {
return RedisModule_Strdup((char *)key);
}
#include "util/string_utils.h"
#include <pthread.h>

static void RAI_TensorDictValFree(void *privdata, void *obj) {
return RAI_TensorFree((RAI_Tensor *)obj);
}

AI_dictType AI_dictTypeTensorVals = {
.hashFunction = RAI_TensorDictKeyHashFunction,
.keyDup = RAI_TensorDictKeyDup,
.hashFunction = RAI_RStringsHashFunction,
.keyDup = RAI_RStringsKeyDup,
.valDup = NULL,
.keyCompare = RAI_TensorDictKeyStrcmp,
.keyDestructor = RAI_TensorDictKeyFree,
.keyCompare = RAI_RStringsKeyCompare,
.keyDestructor = RAI_RStringsKeyDestructor,
.valDestructor = RAI_TensorDictValFree,
};

Expand Down Expand Up @@ -105,11 +91,11 @@ int RAI_InitRunInfo(RedisAI_RunInfo **result) {
if (!(rinfo->dagTensorsContext)) {
return REDISMODULE_ERR;
}
rinfo->dagTensorsLoadedContext = AI_dictCreate(&AI_dictTypeHeapStrings, NULL);
rinfo->dagTensorsLoadedContext = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);
if (!(rinfo->dagTensorsLoadedContext)) {
return REDISMODULE_ERR;
}
rinfo->dagTensorsPersistedContext = AI_dictCreate(&AI_dictTypeHeapStrings, NULL);
rinfo->dagTensorsPersistedContext = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);
if (!(rinfo->dagTensorsPersistedContext)) {
return REDISMODULE_ERR;
}
Expand Down
24 changes: 15 additions & 9 deletions src/script.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "script_struct.h"
#include "stats.h"
#include "util/arr_rm_alloc.h"
#include "util/string_utils.h"
#include "version.h"
#include <pthread.h>

Expand All @@ -28,7 +29,7 @@ static void *RAI_Script_RdbLoad(struct RedisModuleIO *io, int encver) {
RAI_Error err = {0};

const char *devicestr = RedisModule_LoadStringBuffer(io, NULL);
const char *tag = RedisModule_LoadStringBuffer(io, NULL);
RedisModuleString *tag = RedisModule_LoadString(io);

size_t len;
char *scriptdef = RedisModule_LoadStringBuffer(io, &len);
Expand Down Expand Up @@ -58,12 +59,13 @@ static void *RAI_Script_RdbLoad(struct RedisModuleIO *io, int encver) {
RedisModuleString *stats_keystr =
RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io));
const char *stats_devicestr = RedisModule_Strdup(devicestr);
const char *stats_tag = RedisModule_Strdup(tag);

tag = RAI_HoldString(NULL, tag);

script->infokey = RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_SCRIPT, RAI_BACKEND_TORCH,
stats_devicestr, stats_tag);
stats_devicestr, tag);

RedisModule_Free(stats_keystr);
RedisModule_FreeString(NULL, stats_keystr);

return script;
}
Expand All @@ -74,14 +76,14 @@ static void RAI_Script_RdbSave(RedisModuleIO *io, void *value) {
size_t len = strlen(script->scriptdef) + 1;

RedisModule_SaveStringBuffer(io, script->devicestr, strlen(script->devicestr) + 1);
RedisModule_SaveStringBuffer(io, script->tag, strlen(script->tag) + 1);
RedisModule_SaveString(io, script->tag);
RedisModule_SaveStringBuffer(io, script->scriptdef, len);
}

static void RAI_Script_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, void *value) {
RAI_Script *script = (RAI_Script *)value;

RedisModule_EmitAOF(aof, "AI.SCRIPTSET", "scccc", key, script->devicestr, script->tag, "SOURCE",
RedisModule_EmitAOF(aof, "AI.SCRIPTSET", "scscc", key, script->devicestr, script->tag, "SOURCE",
script->scriptdef);
}

Expand All @@ -107,7 +109,7 @@ int RAI_ScriptInit(RedisModuleCtx *ctx) {
return RedisAI_ScriptType != NULL;
}

RAI_Script *RAI_ScriptCreate(const char *devicestr, const char *tag, const char *scriptdef,
RAI_Script *RAI_ScriptCreate(const char *devicestr, RedisModuleString *tag, const char *scriptdef,
RAI_Error *err) {
if (!RAI_backends.torch.script_create) {
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TORCH");
Expand All @@ -116,7 +118,11 @@ RAI_Script *RAI_ScriptCreate(const char *devicestr, const char *tag, const char
RAI_Script *script = RAI_backends.torch.script_create(devicestr, scriptdef, err);

if (script) {
script->tag = RedisModule_Strdup(tag);
if (tag) {
script->tag = RAI_HoldString(NULL, tag);
} else {
script->tag = RedisModule_CreateString(NULL, "", 0);
}
}

return script;
Expand All @@ -132,7 +138,7 @@ void RAI_ScriptFree(RAI_Script *script, RAI_Error *err) {
return;
}

RedisModule_Free(script->tag);
RedisModule_FreeString(NULL, script->tag);

RAI_RemoveStatsEntry(script->infokey);

Expand Down
2 changes: 1 addition & 1 deletion src/script.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ int RAI_ScriptInit(RedisModuleCtx *ctx);
* failures
* @return RAI_Script script structure on success, or NULL if failed
*/
RAI_Script *RAI_ScriptCreate(const char *devicestr, const char *tag, const char *scriptdef,
RAI_Script *RAI_ScriptCreate(const char *devicestr, RedisModuleString *tag, const char *scriptdef,
RAI_Error *err);

/**
Expand Down
2 changes: 1 addition & 1 deletion src/script_struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ typedef struct RAI_Script {
// We keep it here at the moment, until we have a
// CUDA allocator for dlpack
char *devicestr;
char *tag;
RedisModuleString *tag;
long long refCount;
void *infokey;
} RAI_Script;
Expand Down
Loading