Skip to content

Commit 740bbba

Browse files
lantigaalonre24
andauthored
Binary safe strings (#538)
* Mark locations that need change for binary strings * Sweeping change to binary safe strings * Format code * Fixes * Remove repeated reply * Formatting * Fix stray NULL * Rely on HoldString where possible * Update readies * Address PR comments * Free memory bug fix. Co-authored-by: alonre24 <[email protected]>
1 parent f707a56 commit 740bbba

19 files changed

+261
-183
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ endif()
55
ADD_LIBRARY(redisai_obj OBJECT
66
util/dict.c
77
util/queue.c
8+
util/string_utils.c
89
redisai.c
910
run_info.c
1011
background_workers.c

src/dag.c

Lines changed: 67 additions & 65 deletions
Large diffs are not rendered by default.

src/model.c

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "stats.h"
1818
#include "util/arr_rm_alloc.h"
1919
#include "util/dict.h"
20+
#include "util/string_utils.h"
2021
#include <pthread.h>
2122

2223
RedisModuleType *RedisAI_ModelType = NULL;
@@ -31,7 +32,7 @@ static void *RAI_Model_RdbLoad(struct RedisModuleIO *io, int encver) {
3132
RAI_Backend backend = RedisModule_LoadUnsigned(io);
3233
const char *devicestr = RedisModule_LoadStringBuffer(io, NULL);
3334

34-
const char *tag = RedisModule_LoadStringBuffer(io, NULL);
35+
RedisModuleString *tag = RedisModule_LoadString(io);
3536

3637
const size_t batchsize = RedisModule_LoadUnsigned(io);
3738
const size_t minbatchsize = RedisModule_LoadUnsigned(io);
@@ -113,7 +114,7 @@ static void *RAI_Model_RdbLoad(struct RedisModuleIO *io, int encver) {
113114
RedisModuleString *stats_keystr =
114115
RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io));
115116
const char *stats_devicestr = RedisModule_Strdup(devicestr);
116-
const char *stats_tag = RedisModule_Strdup(tag);
117+
RedisModuleString *stats_tag = RAI_HoldString(NULL, tag);
117118

118119
model->infokey =
119120
RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_MODEL, backend, stats_devicestr, stats_tag);
@@ -143,7 +144,7 @@ static void RAI_Model_RdbSave(RedisModuleIO *io, void *value) {
143144

144145
RedisModule_SaveUnsigned(io, model->backend);
145146
RedisModule_SaveStringBuffer(io, model->devicestr, strlen(model->devicestr) + 1);
146-
RedisModule_SaveStringBuffer(io, model->tag, strlen(model->tag) + 1);
147+
RedisModule_SaveString(io, model->tag);
147148
RedisModule_SaveUnsigned(io, model->opts.batchsize);
148149
RedisModule_SaveUnsigned(io, model->opts.minbatchsize);
149150
RedisModule_SaveUnsigned(io, model->ninputs);
@@ -221,7 +222,7 @@ static void RAI_Model_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, voi
221222

222223
const char *backendstr = RAI_BackendName(model->backend);
223224

224-
RedisModule_EmitAOF(aof, "AI.MODELSET", "slccclclcvcvcv", key, backendstr, model->devicestr,
225+
RedisModule_EmitAOF(aof, "AI.MODELSET", "sccsclclcvcvcv", key, backendstr, model->devicestr,
225226
model->tag, "BATCHSIZE", model->opts.batchsize, "MINBATCHSIZE",
226227
model->opts.minbatchsize, "INPUTS", inputs_, model->ninputs, "OUTPUTS",
227228
outputs_, model->noutputs, "BLOB", buffers_, n_chunks);
@@ -285,7 +286,7 @@ int RAI_ModelInit(RedisModuleCtx *ctx) {
285286
return RedisAI_ModelType != NULL;
286287
}
287288

288-
RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, const char *tag,
289+
RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModuleString *tag,
289290
RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs,
290291
const char **outputs, const char *modeldef, size_t modellen,
291292
RAI_Error *err) {
@@ -321,7 +322,11 @@ RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, const cha
321322
}
322323

323324
if (model) {
324-
model->tag = RedisModule_Strdup(tag);
325+
if (tag) {
326+
model->tag = RAI_HoldString(NULL, tag);
327+
} else {
328+
model->tag = RedisModule_CreateString(NULL, "", 0);
329+
}
325330
}
326331

327332
return model;
@@ -361,7 +366,7 @@ void RAI_ModelFree(RAI_Model *model, RAI_Error *err) {
361366
return;
362367
}
363368

364-
RedisModule_Free(model->tag);
369+
RedisModule_FreeString(NULL, model->tag);
365370

366371
RAI_RemoveStatsEntry(model->infokey);
367372

@@ -588,12 +593,12 @@ int RedisAI_Parse_ModelRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString *
588593
is_input = 1;
589594
outputs_flag_count = 1;
590595
} else {
591-
RedisModule_RetainString(ctx, argv[argpos]);
596+
RedisModuleString *arg = RAI_HoldString(ctx, argv[argpos]);
592597
if (is_input == 0) {
593-
*inkeys = array_append(*inkeys, argv[argpos]);
598+
*inkeys = array_append(*inkeys, arg);
594599
ninputs++;
595600
} else {
596-
*outkeys = array_append(*outkeys, argv[argpos]);
601+
*outkeys = array_append(*outkeys, arg);
597602
noutputs++;
598603
}
599604
}

src/model.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ int RAI_ModelInit(RedisModuleCtx *ctx);
4949
* failures
5050
* @return RAI_Model model structure on success, or NULL if failed
5151
*/
52-
RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, const char *tag,
52+
RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModuleString *tag,
5353
RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs,
5454
const char **outputs, const char *modeldef, size_t modellen,
5555
RAI_Error *err);

src/model_struct.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ typedef struct RAI_Model {
2222
void *session;
2323
RAI_Backend backend;
2424
char *devicestr;
25-
char *tag;
25+
RedisModuleString *tag;
2626
RAI_ModelOpts opts;
2727
char **inputs;
2828
size_t ninputs;

src/redisai.c

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "run_info.h"
2727
#include "util/arr_rm_alloc.h"
2828
#include "util/dict.h"
29+
#include "util/string_utils.h"
2930
#include "util/queue.h"
3031
#include "version.h"
3132

@@ -184,9 +185,9 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
184185
return RedisModule_ReplyWithError(ctx, "ERR Invalid DEVICE");
185186
}
186187

187-
const char *tag = "";
188+
RedisModuleString *tag = NULL;
188189
if (AC_AdvanceIfMatch(&ac, "TAG")) {
189-
AC_GetString(&ac, &tag, NULL, 0);
190+
AC_GetRString(&ac, &tag, 0);
190191
}
191192

192193
unsigned long long batchsize = 0;
@@ -470,7 +471,8 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
470471
RedisModule_ReplyWithCString(ctx, mto->devicestr);
471472

472473
RedisModule_ReplyWithCString(ctx, "tag");
473-
RedisModule_ReplyWithCString(ctx, mto->tag ? mto->tag : "");
474+
RedisModuleString *empty_tag = RedisModule_CreateString(ctx, "", 0);
475+
RedisModule_ReplyWithString(ctx, mto->tag ? mto->tag : empty_tag);
474476

475477
RedisModule_ReplyWithCString(ctx, "batchsize");
476478
RedisModule_ReplyWithLongLong(ctx, (long)mto->opts.batchsize);
@@ -539,15 +541,15 @@ int RedisAI_ModelScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
539541

540542
long long nkeys;
541543
RedisModuleString **keys;
542-
const char **tags;
544+
RedisModuleString **tags;
543545
RAI_ListStatsEntries(RAI_MODEL, &nkeys, &keys, &tags);
544546

545547
RedisModule_ReplyWithArray(ctx, nkeys);
546548

547549
for (long long i = 0; i < nkeys; i++) {
548550
RedisModule_ReplyWithArray(ctx, 2);
549551
RedisModule_ReplyWithString(ctx, keys[i]);
550-
RedisModule_ReplyWithCString(ctx, tags[i]);
552+
RedisModule_ReplyWithString(ctx, tags[i]);
551553
}
552554

553555
RedisModule_Free(keys);
@@ -633,7 +635,7 @@ int RedisAI_ScriptGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
633635
RedisModule_ReplyWithCString(ctx, "device");
634636
RedisModule_ReplyWithCString(ctx, sto->devicestr);
635637
RedisModule_ReplyWithCString(ctx, "tag");
636-
RedisModule_ReplyWithCString(ctx, sto->tag);
638+
RedisModule_ReplyWithString(ctx, sto->tag);
637639
if (source) {
638640
RedisModule_ReplyWithCString(ctx, "source");
639641
RedisModule_ReplyWithCString(ctx, sto->scriptdef);
@@ -682,9 +684,9 @@ int RedisAI_ScriptSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
682684
const char *devicestr;
683685
AC_GetString(&ac, &devicestr, NULL, 0);
684686

685-
const char *tag = "";
687+
RedisModuleString *tag = NULL;
686688
if (AC_AdvanceIfMatch(&ac, "TAG")) {
687-
AC_GetString(&ac, &tag, NULL, 0);
689+
AC_GetRString(&ac, &tag, 0);
688690
}
689691

690692
if (AC_IsAtEnd(&ac)) {
@@ -780,15 +782,15 @@ int RedisAI_ScriptScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg
780782

781783
long long nkeys;
782784
RedisModuleString **keys;
783-
const char **tags;
785+
RedisModuleString **tags;
784786
RAI_ListStatsEntries(RAI_SCRIPT, &nkeys, &keys, &tags);
785787

786788
RedisModule_ReplyWithArray(ctx, nkeys);
787789

788790
for (long long i = 0; i < nkeys; i++) {
789791
RedisModule_ReplyWithArray(ctx, 2);
790792
RedisModule_ReplyWithString(ctx, keys[i]);
791-
RedisModule_ReplyWithCString(ctx, tags[i]);
793+
RedisModule_ReplyWithString(ctx, tags[i]);
792794
}
793795

794796
RedisModule_Free(keys);
@@ -803,7 +805,7 @@ int RedisAI_ScriptScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg
803805
int RedisAI_Info_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
804806
if (argc != 2 && argc != 3)
805807
return RedisModule_WrongArity(ctx);
806-
const char *runkey = RedisModule_StringPtrLen(argv[1], NULL);
808+
RedisModuleString *runkey = argv[1];
807809
struct RedisAI_RunStats *rstats = NULL;
808810
if (RAI_GetRunStats(runkey, &rstats) == REDISMODULE_ERR) {
809811
return RedisModule_ReplyWithError(ctx, "ERR cannot find run info for key");
@@ -833,7 +835,11 @@ int RedisAI_Info_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
833835
RedisModule_ReplyWithCString(ctx, "device");
834836
RedisModule_ReplyWithCString(ctx, rstats->devicestr);
835837
RedisModule_ReplyWithCString(ctx, "tag");
836-
RedisModule_ReplyWithCString(ctx, rstats->tag);
838+
if (rstats->tag) {
839+
RedisModule_ReplyWithString(ctx, rstats->tag);
840+
} else {
841+
RedisModule_ReplyWithCString(ctx, "");
842+
}
837843
RedisModule_ReplyWithCString(ctx, "duration");
838844
RedisModule_ReplyWithLongLong(ctx, rstats->duration_us);
839845
RedisModule_ReplyWithCString(ctx, "samples");
@@ -1209,9 +1215,7 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
12091215
return REDISMODULE_ERR;
12101216
}
12111217

1212-
run_stats = AI_dictCreate(&AI_dictTypeHeapStrings, NULL);
1218+
run_stats = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);
12131219

12141220
return REDISMODULE_OK;
12151221
}
1216-
1217-
extern AI_dictType AI_dictTypeHeapStrings;

src/run_info.c

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,33 +15,19 @@
1515
#include "tensor.h"
1616
#include "util/arr_rm_alloc.h"
1717
#include "util/dict.h"
18-
19-
static uint64_t RAI_TensorDictKeyHashFunction(const void *key) {
20-
return AI_dictGenHashFunction(key, strlen((char *)key));
21-
}
22-
23-
static int RAI_TensorDictKeyStrcmp(void *privdata, const void *key1, const void *key2) {
24-
const char *strKey1 = key1;
25-
const char *strKey2 = key2;
26-
return strcmp(strKey1, strKey2) == 0;
27-
}
28-
29-
static void RAI_TensorDictKeyFree(void *privdata, void *key) { RedisModule_Free(key); }
30-
31-
static void *RAI_TensorDictKeyDup(void *privdata, const void *key) {
32-
return RedisModule_Strdup((char *)key);
33-
}
18+
#include "util/string_utils.h"
19+
#include <pthread.h>
3420

3521
static void RAI_TensorDictValFree(void *privdata, void *obj) {
3622
return RAI_TensorFree((RAI_Tensor *)obj);
3723
}
3824

3925
AI_dictType AI_dictTypeTensorVals = {
40-
.hashFunction = RAI_TensorDictKeyHashFunction,
41-
.keyDup = RAI_TensorDictKeyDup,
26+
.hashFunction = RAI_RStringsHashFunction,
27+
.keyDup = RAI_RStringsKeyDup,
4228
.valDup = NULL,
43-
.keyCompare = RAI_TensorDictKeyStrcmp,
44-
.keyDestructor = RAI_TensorDictKeyFree,
29+
.keyCompare = RAI_RStringsKeyCompare,
30+
.keyDestructor = RAI_RStringsKeyDestructor,
4531
.valDestructor = RAI_TensorDictValFree,
4632
};
4733

@@ -105,11 +91,11 @@ int RAI_InitRunInfo(RedisAI_RunInfo **result) {
10591
if (!(rinfo->dagTensorsContext)) {
10692
return REDISMODULE_ERR;
10793
}
108-
rinfo->dagTensorsLoadedContext = AI_dictCreate(&AI_dictTypeHeapStrings, NULL);
94+
rinfo->dagTensorsLoadedContext = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);
10995
if (!(rinfo->dagTensorsLoadedContext)) {
11096
return REDISMODULE_ERR;
11197
}
112-
rinfo->dagTensorsPersistedContext = AI_dictCreate(&AI_dictTypeHeapStrings, NULL);
98+
rinfo->dagTensorsPersistedContext = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);
11399
if (!(rinfo->dagTensorsPersistedContext)) {
114100
return REDISMODULE_ERR;
115101
}

src/script.c

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "script_struct.h"
1414
#include "stats.h"
1515
#include "util/arr_rm_alloc.h"
16+
#include "util/string_utils.h"
1617
#include "version.h"
1718
#include <pthread.h>
1819

@@ -28,7 +29,7 @@ static void *RAI_Script_RdbLoad(struct RedisModuleIO *io, int encver) {
2829
RAI_Error err = {0};
2930

3031
const char *devicestr = RedisModule_LoadStringBuffer(io, NULL);
31-
const char *tag = RedisModule_LoadStringBuffer(io, NULL);
32+
RedisModuleString *tag = RedisModule_LoadString(io);
3233

3334
size_t len;
3435
char *scriptdef = RedisModule_LoadStringBuffer(io, &len);
@@ -58,12 +59,13 @@ static void *RAI_Script_RdbLoad(struct RedisModuleIO *io, int encver) {
5859
RedisModuleString *stats_keystr =
5960
RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io));
6061
const char *stats_devicestr = RedisModule_Strdup(devicestr);
61-
const char *stats_tag = RedisModule_Strdup(tag);
62+
63+
tag = RAI_HoldString(NULL, tag);
6264

6365
script->infokey = RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_SCRIPT, RAI_BACKEND_TORCH,
64-
stats_devicestr, stats_tag);
66+
stats_devicestr, tag);
6567

66-
RedisModule_Free(stats_keystr);
68+
RedisModule_FreeString(NULL, stats_keystr);
6769

6870
return script;
6971
}
@@ -74,14 +76,14 @@ static void RAI_Script_RdbSave(RedisModuleIO *io, void *value) {
7476
size_t len = strlen(script->scriptdef) + 1;
7577

7678
RedisModule_SaveStringBuffer(io, script->devicestr, strlen(script->devicestr) + 1);
77-
RedisModule_SaveStringBuffer(io, script->tag, strlen(script->tag) + 1);
79+
RedisModule_SaveString(io, script->tag);
7880
RedisModule_SaveStringBuffer(io, script->scriptdef, len);
7981
}
8082

8183
static void RAI_Script_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, void *value) {
8284
RAI_Script *script = (RAI_Script *)value;
8385

84-
RedisModule_EmitAOF(aof, "AI.SCRIPTSET", "scccc", key, script->devicestr, script->tag, "SOURCE",
86+
RedisModule_EmitAOF(aof, "AI.SCRIPTSET", "scscc", key, script->devicestr, script->tag, "SOURCE",
8587
script->scriptdef);
8688
}
8789

@@ -107,7 +109,7 @@ int RAI_ScriptInit(RedisModuleCtx *ctx) {
107109
return RedisAI_ScriptType != NULL;
108110
}
109111

110-
RAI_Script *RAI_ScriptCreate(const char *devicestr, const char *tag, const char *scriptdef,
112+
RAI_Script *RAI_ScriptCreate(const char *devicestr, RedisModuleString *tag, const char *scriptdef,
111113
RAI_Error *err) {
112114
if (!RAI_backends.torch.script_create) {
113115
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TORCH");
@@ -116,7 +118,11 @@ RAI_Script *RAI_ScriptCreate(const char *devicestr, const char *tag, const char
116118
RAI_Script *script = RAI_backends.torch.script_create(devicestr, scriptdef, err);
117119

118120
if (script) {
119-
script->tag = RedisModule_Strdup(tag);
121+
if (tag) {
122+
script->tag = RAI_HoldString(NULL, tag);
123+
} else {
124+
script->tag = RedisModule_CreateString(NULL, "", 0);
125+
}
120126
}
121127

122128
return script;
@@ -132,7 +138,7 @@ void RAI_ScriptFree(RAI_Script *script, RAI_Error *err) {
132138
return;
133139
}
134140

135-
RedisModule_Free(script->tag);
141+
RedisModule_FreeString(NULL, script->tag);
136142

137143
RAI_RemoveStatsEntry(script->infokey);
138144

src/script.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ int RAI_ScriptInit(RedisModuleCtx *ctx);
3636
* failures
3737
* @return RAI_Script script structure on success, or NULL if failed
3838
*/
39-
RAI_Script *RAI_ScriptCreate(const char *devicestr, const char *tag, const char *scriptdef,
39+
RAI_Script *RAI_ScriptCreate(const char *devicestr, RedisModuleString *tag, const char *scriptdef,
4040
RAI_Error *err);
4141

4242
/**

src/script_struct.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ typedef struct RAI_Script {
1212
// We keep it here at the moment, until we have a
1313
// CUDA allocator for dlpack
1414
char *devicestr;
15-
char *tag;
15+
RedisModuleString *tag;
1616
long long refCount;
1717
void *infokey;
1818
} RAI_Script;

0 commit comments

Comments
 (0)