Skip to content

Commit b624b9d

Browse files
authored
Add possibility to provide a model in chunks (#338)

* Add possibility to provide a model in chunks
* Add test on chunked modelset
1 parent 925f2fb commit b624b9d

File tree

9 files changed

+189
-109
lines changed

9 files changed

+189
-109
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Note that Redis config is located at `/usr/local/etc/redis/redis.conf` which can
4747

4848
On the client, set the model
4949
```sh
50-
redis-cli -x AI.MODELSET foo TF CPU INPUTS a b OUTPUTS c < test/test_data/graph.pb
50+
redis-cli -x AI.MODELSET foo TF CPU INPUTS a b OUTPUTS c BLOB < test/test_data/graph.pb
5151
```
5252

5353
Then create the input tensors, run the computation graph and get the output tensor (see `load_model.sh`). Note the signatures:

docs/commands.md

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ AI.TENSORGET foo META VALUES
123123
Set a model.
124124

125125
```sql
126-
AI.MODELSET model_key backend device [TAG tag] [BATCHSIZE n [MINBATCHSIZE m]] [INPUTS name1 name2 ... OUTPUTS name1 name2 ...] model_blob
126+
AI.MODELSET model_key backend device [TAG tag] [BATCHSIZE n [MINBATCHSIZE m]] [INPUTS name1 name2 ... OUTPUTS name1 name2 ...] BLOB model_blob
127127
```
128128

129129
* model_key - Key for storing the model
@@ -142,28 +142,29 @@ AI.MODELSET model_key backend device [TAG tag] [BATCHSIZE n [MINBATCHSIZE m]] [I
142142
Default is 0 (no minimum batch size).
143143
* INPUTS name1 name2 ... - Name of the nodes in the provided graph corresponding to inputs [`TF` backend only]
144144
* OUTPUTS name1 name2 ... - Name of the nodes in the provided graph corresponding to outputs [`TF` backend only]
145-
* model_blob - Binary buffer containing the model protobuf saved from a supported backend
145+
* BLOB model_blob - Binary buffer containing the model protobuf saved from a supported backend. Since Redis supports strings
146+
up to 512MB, blobs for very large models need to be chunked, e.g. `BLOB chunk1 chunk2 ...`; the chunks are concatenated in the order given to rebuild the model.
146147

147148
### MODELSET Example
148149

149150
```sql
150-
AI.MODELSET resnet18 TORCH GPU < foo.pt
151+
AI.MODELSET resnet18 TORCH GPU BLOB < foo.pt
151152
```
152153

153154
```sql
154-
AI.MODELSET resnet18 TF CPU INPUTS in1 OUTPUTS linear4 < foo.pb
155+
AI.MODELSET resnet18 TF CPU INPUTS in1 OUTPUTS linear4 BLOB < foo.pb
155156
```
156157

157158
```sql
158-
AI.MODELSET mnist_net ONNX CPU TAG mnist:lenet:v0.1 < mnist.onnx
159+
AI.MODELSET mnist_net ONNX CPU TAG mnist:lenet:v0.1 BLOB < mnist.onnx
159160
```
160161

161162
```sql
162-
AI.MODELSET mnist_net ONNX CPU BATCHSIZE 10 < mnist.onnx
163+
AI.MODELSET mnist_net ONNX CPU BATCHSIZE 10 BLOB < mnist.onnx
163164
```
164165

165166
```sql
166-
AI.MODELSET resnet18 TF CPU BATCHSIZE 10 MINBATCHSIZE 6 INPUTS in1 OUTPUTS linear4 < foo.pb
167+
AI.MODELSET resnet18 TF CPU BATCHSIZE 10 MINBATCHSIZE 6 INPUTS in1 OUTPUTS linear4 BLOB < foo.pb
167168
```
168169

169170
## AI.MODELGET
@@ -284,13 +285,13 @@ AI._MODELSCAN
284285
Set a script.
285286

286287
```sql
287-
AI.SCRIPTSET script_key device [TAG tag] script_source
288+
AI.SCRIPTSET script_key device [TAG tag] SOURCE script_source
288289
```
289290

290291
* script_key - Key for storing the script
291292
* device - The device where the script will execute
292293
* TAG tag - Optional string tagging the script, such as a version number or other identifier
293-
* script_source - A string containing [TorchScript](https://pytorch.org/docs/stable/jit.html) source code
294+
* SOURCE script_source - A string containing [TorchScript](https://pytorch.org/docs/stable/jit.html) source code
294295

295296
### SCRIPTSET Example
296297

@@ -302,11 +303,11 @@ def addtwo(a, b):
302303
```
303304

304305
```sql
305-
AI.SCRIPTSET addscript GPU < addtwo.txt
306+
AI.SCRIPTSET addscript GPU SOURCE < addtwo.txt
306307
```
307308

308309
```sql
309-
AI.SCRIPTSET addscript GPU TAG myscript:v0.1 < addtwo.txt
310+
AI.SCRIPTSET addscript GPU TAG myscript:v0.1 SOURCE < addtwo.txt
310311
```
311312

312313
## AI.SCRIPTGET

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ docker run -p 6379:6379 -it --rm redisai/redisai
2424
On the client, load a backend (TF, TORCH or ONNX), and set the model
2525
```sh
2626
redis-cli AI.CONFIG LOADBACKEND TF install/backends/redisai_tensorflow/redisai_tensorflow.so
27-
redis-cli -x AI.MODELSET foo TF CPU INPUTS a b OUTPUTS c < test/test_data/graph.pb
27+
redis-cli -x AI.MODELSET foo TF CPU INPUTS a b OUTPUTS c BLOB < test/test_data/graph.pb
2828
```
2929

3030
Then create the input tensors, run the computation graph and get the output tensor (see `load_model.sh`). Note the signatures:

src/redisai.c

Lines changed: 66 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ int RedisAI_TensorGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
8686
}
8787

8888
/**
89-
* AI.MODELSET model_key backend device [TAG tag] [BATCHSIZE n [MINBATCHSIZE m]] [INPUTS name1 name2 ... OUTPUTS name1 name2 ...] model_blob
89+
* AI.MODELSET model_key backend device [TAG tag] [BATCHSIZE n [MINBATCHSIZE m]] [INPUTS name1 name2 ... OUTPUTS name1 name2 ...] BLOB model_blob
9090
*/
9191
int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
9292
RedisModule_AutoMemory(ctx);
@@ -121,7 +121,14 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
121121
const char* devicestr;
122122
AC_GetString(&ac, &devicestr, NULL, 0);
123123

124-
if (strlen(devicestr) > 10) {
124+
if (strlen(devicestr) > 10 ||
125+
strcasecmp(devicestr, "INPUTS") == 0 ||
126+
strcasecmp(devicestr, "OUTPUTS") == 0 ||
127+
strcasecmp(devicestr, "TAG") == 0 ||
128+
strcasecmp(devicestr, "BATCHSIZE") == 0 ||
129+
strcasecmp(devicestr, "MINBATCHSIZE") == 0 ||
130+
strcasecmp(devicestr, "BLOB") == 0
131+
) {
125132
return RedisModule_ReplyWithError(ctx, "ERR Invalid DEVICE");
126133
}
127134

@@ -150,21 +157,21 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
150157
}
151158
}
152159

153-
154160
if (AC_IsAtEnd(&ac)) {
155161
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing model BLOB");
156162
}
157163

158164
ArgsCursor optionsac;
159-
AC_GetSliceToOffset(&ac, &optionsac, argc-2);
165+
const char* blob_matches[] = {"BLOB"};
166+
AC_GetSliceUntilMatches(&ac, &optionsac, 1, blob_matches);
160167

161168
if (optionsac.argc == 0 && backend == RAI_BACKEND_TENSORFLOW) {
162169
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, INPUTS and OUTPUTS not specified");
163170
}
164171

165172
ArgsCursor inac = {0};
166173
ArgsCursor outac = {0};
167-
if (optionsac.argc > 0) {
174+
if (optionsac.argc > 0 && backend == RAI_BACKEND_TENSORFLOW) {
168175
if (!AC_AdvanceIfMatch(&optionsac, "INPUTS")) {
169176
return RedisModule_ReplyWithError(ctx, "ERR INPUTS not specified");
170177
}
@@ -202,9 +209,39 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
202209

203210
RAI_Model *model = NULL;
204211

212+
AC_AdvanceUntilMatches(&ac, 1, blob_matches);
213+
214+
if (AC_IsAtEnd(&ac)) {
215+
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing model BLOB");
216+
}
217+
218+
AC_Advance(&ac);
219+
220+
ArgsCursor blobsac;
221+
AC_GetSliceToEnd(&ac, &blobsac);
222+
205223
size_t modellen;
206-
const char *modeldef;
207-
AC_GetString(&ac, &modeldef, &modellen, 0);
224+
char *modeldef;
225+
226+
if (blobsac.argc == 1) {
227+
AC_GetString(&blobsac, (const char**)&modeldef, &modellen, 0);
228+
}
229+
else {
230+
const char *chunks[blobsac.argc];
231+
size_t chunklens[blobsac.argc];
232+
modellen = 0;
233+
while (!AC_IsAtEnd(&blobsac)) {
234+
AC_GetString(&blobsac, &chunks[blobsac.offset], &chunklens[blobsac.offset], 0);
235+
modellen += chunklens[blobsac.offset-1];
236+
}
237+
238+
modeldef = RedisModule_Calloc(modellen, sizeof(char));
239+
size_t offset = 0;
240+
for (size_t i=0; i<blobsac.argc; i++) {
241+
memcpy(modeldef + offset, chunks[i], chunklens[i]);
242+
offset += chunklens[i];
243+
}
244+
}
208245

209246
RAI_Error err = {0};
210247

@@ -223,6 +260,10 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
223260
model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs, outputs, modeldef, modellen, &err);
224261
}
225262

263+
if (blobsac.argc > 1) {
264+
RedisModule_Free(modeldef);
265+
}
266+
226267
if (err.code != RAI_OK) {
227268
#ifdef RAI_PRINT_BACKEND_ERRORS
228269
printf("ERR: %s\n", err.detail);
@@ -502,14 +543,14 @@ int RedisAI_ScriptRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
502543
ArgsCursor outac = {0};
503544

504545
if (!AC_AdvanceIfMatch(&ac, "INPUTS")) {
505-
return RedisModule_ReplyWithError(ctx, "INPUTS not specified");
546+
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, INPUTS not specified");
506547
}
507548

508549
const char* matches[] = {"OUTPUTS"};
509550
AC_GetSliceUntilMatches(&ac, &inac, 1, matches);
510551

511552
if (!AC_AdvanceIfMatch(&ac, "OUTPUTS")) {
512-
return RedisModule_ReplyWithError(ctx, "OUTPUTS not specified");
553+
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, OUTPUTS not specified");
513554
}
514555

515556
AC_GetSliceToEnd(&ac, &outac);
@@ -541,15 +582,15 @@ int RedisAI_ScriptRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
541582
if (!RAI_ScriptRunCtxAddInput(rinfo->sctx, t)) {
542583
RAI_FreeRunInfo(ctx,rinfo);
543584
RedisModule_CloseKey(key);
544-
return RedisModule_ReplyWithError(ctx, "Input key not found");
585+
return RedisModule_ReplyWithError(ctx, "ERR Input key not found");
545586
}
546587
}
547588

548589
for (size_t i=0; i<noutputs; i++) {
549590
if (!RAI_ScriptRunCtxAddOutput(rinfo->sctx)) {
550591
RAI_FreeRunInfo(ctx,rinfo);
551592
RedisModule_CloseKey(key);
552-
return RedisModule_ReplyWithError(ctx, "Output key not found");
593+
return RedisModule_ReplyWithError(ctx, "ERR Output key not found");
553594
}
554595
RedisModule_RetainString(ctx, outputs[i]);
555596
array_append(rinfo->outkeys,outputs[i]);
@@ -651,12 +692,12 @@ int RedisAI_ScriptDel_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
651692
}
652693

653694
/**
654-
* AI.SCRIPTSET script_key device [TAG tag] script_source
695+
* AI.SCRIPTSET script_key device [TAG tag] SOURCE script_source
655696
*/
656697
int RedisAI_ScriptSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
657698
RedisModule_AutoMemory(ctx);
658699

659-
if (argc != 4 && argc != 6) return RedisModule_WrongArity(ctx);
700+
if (argc != 5 && argc != 7) return RedisModule_WrongArity(ctx);
660701

661702
ArgsCursor ac;
662703
ArgsCursor_InitRString(&ac, argv+1, argc-1);
@@ -673,14 +714,21 @@ int RedisAI_ScriptSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
673714
}
674715

675716
if (AC_IsAtEnd(&ac)) {
676-
return RedisModule_ReplyWithError(ctx, "Insufficient arguments, missing script definition");
717+
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing script SOURCE");
677718
}
678719

679-
RAI_Script *script = NULL;
680-
681720
size_t scriptlen;
682-
const char *scriptdef;
683-
AC_GetString(&ac, &scriptdef, &scriptlen, 0);
721+
const char *scriptdef = NULL;
722+
723+
if (AC_AdvanceIfMatch(&ac, "SOURCE")) {
724+
AC_GetString(&ac, &scriptdef, &scriptlen, 0);
725+
}
726+
727+
if (scriptdef == NULL) {
728+
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing script SOURCE");
729+
}
730+
731+
RAI_Script *script = NULL;
684732

685733
RAI_Error err = {0};
686734
script = RAI_ScriptCreate(devicestr, tag, scriptdef, &err);

test/tests_dag.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def test_dag_modelrun_financialNet_errors(env):
144144
model_pb, creditcard_transactions, creditcard_referencedata = load_creditcardfraud_data(
145145
env)
146146
ret = con.execute_command('AI.MODELSET', 'financialNet', 'TF', "CPU",
147-
'INPUTS', 'transaction', 'reference', 'OUTPUTS', 'output', model_pb)
147+
'INPUTS', 'transaction', 'reference', 'OUTPUTS', 'output', 'BLOB', model_pb)
148148
env.assertEqual(ret, b'OK')
149149

150150
tensor_number=1
@@ -390,7 +390,7 @@ def test_dag_modelrun_financialNet_separate_tensorget(env):
390390
model_pb, creditcard_transactions, creditcard_referencedata = load_creditcardfraud_data(
391391
env)
392392
ret = con.execute_command('AI.MODELSET', 'financialNet', 'TF', "CPU",
393-
'INPUTS', 'transaction', 'reference', 'OUTPUTS', 'output', model_pb)
393+
'INPUTS', 'transaction', 'reference', 'OUTPUTS', 'output', 'BLOB', model_pb)
394394
env.assertEqual(ret, b'OK')
395395

396396
tensor_number = 1
@@ -432,7 +432,7 @@ def test_dag_modelrun_financialNet(env):
432432
model_pb, creditcard_transactions, creditcard_referencedata = load_creditcardfraud_data(
433433
env)
434434
ret = con.execute_command('AI.MODELSET', 'financialNet', 'TF', "CPU",
435-
'INPUTS', 'transaction', 'reference', 'OUTPUTS', 'output', model_pb)
435+
'INPUTS', 'transaction', 'reference', 'OUTPUTS', 'output', 'BLOB', model_pb)
436436
env.assertEqual(ret, b'OK')
437437

438438
tensor_number = 1
@@ -471,7 +471,7 @@ def test_dag_modelrun_financialNet_no_writes(env):
471471
model_pb, creditcard_transactions, creditcard_referencedata = load_creditcardfraud_data(
472472
env)
473473
ret = con.execute_command('AI.MODELSET', 'financialNet', 'TF', "CPU",
474-
'INPUTS', 'transaction', 'reference', 'OUTPUTS', 'output', model_pb)
474+
'INPUTS', 'transaction', 'reference', 'OUTPUTS', 'output', 'BLOB', model_pb)
475475
env.assertEqual(ret, b'OK')
476476

477477
tensor_number = 1
@@ -522,7 +522,7 @@ def test_dagro_modelrun_financialNet_no_writes_multiple_modelruns(env):
522522
model_pb, creditcard_transactions, creditcard_referencedata = load_creditcardfraud_data(
523523
env)
524524
ret = con.execute_command('AI.MODELSET', 'financialNet', 'TF', DEVICE,
525-
'INPUTS', 'transaction', 'reference', 'OUTPUTS', 'output', model_pb)
525+
'INPUTS', 'transaction', 'reference', 'OUTPUTS', 'output', 'BLOB', model_pb)
526526
env.assertEqual(ret, b'OK')
527527

528528
tensor_number = 1

0 commit comments

Comments
 (0)