Add support for batching #241

Merged: 2 commits, Feb 13, 2020
20 changes: 19 additions & 1 deletion docs/commands.md
@@ -87,12 +87,22 @@ AI.TENSORGET foo BLOB
Set a model.

```sql
AI.MODELSET model_key backend device [INPUTS name1 name2 ... OUTPUTS name1 name2 ...] model_blob
AI.MODELSET model_key backend device [BATCHSIZE n [MINBATCHSIZE m]] [INPUTS name1 name2 ... OUTPUTS name1 name2 ...] model_blob
```

* model_key - Key for storing the model
* backend - The backend corresponding to the model being set. Allowed values: `TF`, `TORCH`, `ONNX`.
* device - Device where the model is loaded and where the computation will run. Allowed values: `CPU`, `GPU`.
* BATCHSIZE n - Batch incoming requests from multiple clients if they hit the same model and their input tensors have the same
shape. Upon MODELRUN, the request queue is visited and input tensors from compatible requests are concatenated
along the 0-th (batch) dimension, until BATCHSIZE is exceeded. The model is then run on the entire batch, the
results are unpacked back to the individual requests, and the respective clients are unblocked.
If the batch size of the inputs of the first request in the queue already exceeds BATCHSIZE, that request is
served in any case. Default is 0 (no batching). See the examples and the MODELRUN sketch below.
* MINBATCHSIZE m - Do not execute a MODELRUN until the accumulated batch size has reached MINBATCHSIZE. This is primarily used
to force batching during testing, but it can also be used under normal operation; in that case, note that
requests for which MINBATCHSIZE is never reached will hang indefinitely.
Default is 0 (no minimum batch size).
* INPUTS name1 name2 ... - Names of the nodes in the provided graph corresponding to inputs [`TF` backend only]
* OUTPUTS name1 name2 ... - Names of the nodes in the provided graph corresponding to outputs [`TF` backend only]
* model_blob - Binary buffer containing the model protobuf saved from a supported backend
@@ -111,6 +121,14 @@ AI.MODELSET resnet18 TF CPU INPUTS in1 OUTPUTS linear4 < foo.pb
AI.MODELSET mnist_net ONNX CPU < mnist.onnx
```

```sql
AI.MODELSET mnist_net ONNX CPU BATCHSIZE 10 < mnist.onnx
```

```sql
AI.MODELSET resnet18 TF CPU BATCHSIZE 10 MINBATCHSIZE 6 INPUTS in1 OUTPUTS linear4 < foo.pb
```
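
For illustration, a hypothetical run against a model registered with batching enabled. This sketch assumes the `AI.MODELRUN` syntax documented later in this file and hypothetical tensor keys `mnist_in` and `mnist_out`; when several clients issue this command concurrently against `mnist_net`, their inputs are concatenated along the batch dimension and served by a single model execution.

```sql
AI.MODELRUN mnist_net INPUTS mnist_in OUTPUTS mnist_out
```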

## AI.MODELGET

Get a model.
8 changes: 4 additions & 4 deletions src/backends.c
@@ -74,7 +74,7 @@ int RAI_LoadBackend_TensorFlow(RedisModuleCtx *ctx, const char *path) {
}
init_backend(RedisModule_GetApi);

backend.model_create_with_nodes = (RAI_Model* (*)(RAI_Backend, const char*,
backend.model_create_with_nodes = (RAI_Model* (*)(RAI_Backend, const char*, RAI_ModelOpts,
size_t, const char**, size_t, const char**,
const char*, size_t, RAI_Error*))
(unsigned long) dlsym(handle, "RAI_ModelCreateTF");
@@ -140,7 +140,7 @@ int RAI_LoadBackend_TFLite(RedisModuleCtx *ctx, const char *path) {
}
init_backend(RedisModule_GetApi);

backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*,
backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*, RAI_ModelOpts,
const char*, size_t, RAI_Error*))
(unsigned long) dlsym(handle, "RAI_ModelCreateTFLite");
if (backend.model_create == NULL) {
@@ -205,7 +205,7 @@ int RAI_LoadBackend_Torch(RedisModuleCtx *ctx, const char *path) {
}
init_backend(RedisModule_GetApi);

backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*,
backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*, RAI_ModelOpts,
const char*, size_t, RAI_Error*))
(unsigned long) dlsym(handle, "RAI_ModelCreateTorch");
if (backend.model_create == NULL) {
@@ -294,7 +294,7 @@ int RAI_LoadBackend_ONNXRuntime(RedisModuleCtx *ctx, const char *path) {
}
init_backend(RedisModule_GetApi);

backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*,
backend.model_create = (RAI_Model* (*)(RAI_Backend, const char*, RAI_ModelOpts,
const char*, size_t, RAI_Error*))
(unsigned long) dlsym(handle, "RAI_ModelCreateORT");
if (backend.model_create == NULL) {
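
Each loader above now resolves a `model_create` (or `model_create_with_nodes`) symbol whose signature takes an additional `RAI_ModelOpts` argument. The struct's definition is not part of this diff; a minimal sketch of what it presumably carries, given the new BATCHSIZE and MINBATCHSIZE options:

```c
/* Sketch only: assumed contents of RAI_ModelOpts, which is defined elsewhere
 * in the module and is not shown in this diff. */
typedef struct RAI_ModelOpts {
  size_t batchsize;     /* BATCHSIZE from AI.MODELSET; 0 disables batching */
  size_t minbatchsize;  /* MINBATCHSIZE from AI.MODELSET; 0 means no minimum */
} RAI_ModelOpts;
```
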
4 changes: 2 additions & 2 deletions src/backends.h
@@ -8,10 +8,10 @@
#include "err.h"

typedef struct RAI_LoadedBackend {
RAI_Model* (*model_create_with_nodes)(RAI_Backend, const char*,
RAI_Model* (*model_create_with_nodes)(RAI_Backend, const char*, RAI_ModelOpts,
size_t, const char**, size_t, const char**,
const char*, size_t, RAI_Error*);
RAI_Model* (*model_create)(RAI_Backend, const char*,
RAI_Model* (*model_create)(RAI_Backend, const char*, RAI_ModelOpts,
const char*, size_t, RAI_Error*);
void (*model_free)(RAI_Model*, RAI_Error*);
int (*model_run)(RAI_ModelRunCtx*, RAI_Error*);
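
A hypothetical call site showing how the updated `model_create` pointer might be invoked once the batching options have been parsed from `AI.MODELSET`. Everything except the types declared in `backends.h` above is an assumption for illustration, including the parameter roles and the helper's name.

```c
#include "backends.h"

/* Hypothetical helper: forward parsed AI.MODELSET arguments to a loaded
 * backend through the model_create pointer of RAI_LoadedBackend above. */
static RAI_Model *create_model_with_opts(RAI_LoadedBackend *loaded,
                                         RAI_Backend backend,
                                         const char *devicestr,
                                         size_t batchsize, size_t minbatchsize,
                                         const char *modeldef, size_t modellen,
                                         RAI_Error *err) {
  RAI_ModelOpts opts = {
      .batchsize = batchsize,       /* 0 disables batching */
      .minbatchsize = minbatchsize  /* 0 disables the minimum */
  };
  /* Dispatch through the function pointer resolved via dlsym in backends.c. */
  return loaded->model_create(backend, devicestr, opts, modeldef, modellen, err);
}
```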