
Commit cf3d12d

Add batching for ONNX and ONNX-ML
1 parent fab4d3a · commit cf3d12d

File tree

3 files changed: +152 additions, -39 deletions


src/backends/onnxruntime.c

Lines changed: 140 additions & 34 deletions
@@ -77,32 +77,98 @@ DLDataType RAI_GetDLDataTypeFromORT(ONNXTensorElementDataType dtype) {
   return (DLDataType){ .bits = 0 };
 }

-OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
-  // TODO: create outside and pass?
+// OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
+//   // TODO: create outside and pass?
+//   const OrtApi* ort = OrtGetApiBase()->GetApi(1);
+//   OrtMemoryInfo* memory_info;
+//   OrtStatus* status;
+//   status = ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info);
+//   if (status != NULL) {
+//     goto error;
+//   }
+//
+//   OrtValue* out;
+//   status = OrtCreateTensorWithDataAsOrtValue(
+//       allocator_info,
+//       t->tensor.dl_tensor.data,
+//       RAI_TensorByteSize(t),
+//       t->tensor.dl_tensor.shape,
+//       t->tensor.dl_tensor.ndim,
+//       RAI_GetOrtDataTypeFromDL(t->tensor.dl_tensor.dtype),
+//       &out);
+//
+//   if (status != NULL) {
+//     OrtReleaseAllocatorInfo(allocator_info);
+//     goto error;
+//   }
+//
+//   OrtReleaseAllocatorInfo(allocator_info);
+//
+//   return out;
+//
+// error:
+//   RAI_SetError(error, RAI_EMODELCREATE, OrtGetErrorMessage(status));
+//   OrtReleaseStatus(status);
+//   return NULL;
+// }
+
+OrtValue* RAI_OrtValueFromTensors(RAI_Tensor** ts, size_t count, RAI_Error *error) {
+  OrtStatus* status = NULL;
   const OrtApi* ort = OrtGetApiBase()->GetApi(1);
-  OrtMemoryInfo* memory_info;
-  OrtStatus* status;
-  status = ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info);
+
+  OrtAllocator *allocator;
+  status = ort->GetAllocatorWithDefaultOptions(&allocator);
   if (status != NULL) {
-    goto error;
+    return NULL;
+  }
+
+  if (count == 0) {
+    return NULL;
+  }
+
+  size_t batch_size = 0;
+  size_t batch_byte_size = 0;
+
+  for (size_t i=0; i<count; i++) {
+    batch_size += ts[i]->tensor.dl_tensor.shape[0];
+    batch_byte_size += RAI_TensorByteSize(ts[i]);
   }

+  RAI_Tensor* t0 = ts[0];
+
+  int ndim = t0->tensor.dl_tensor.ndim;
+  int64_t batched_shape[ndim];
+
+  for (size_t i=0; i<ndim; i++) {
+    batched_shape[i] = t0->tensor.dl_tensor.shape[i];
+  }
+
+  batched_shape[0] = batch_size;
+
   OrtValue* out;
-  status = ort->CreateTensorWithDataAsOrtValue(
-      memory_info,
-      t->tensor.dl_tensor.data,
-      RAI_TensorByteSize(t),
-      t->tensor.dl_tensor.shape,
-      t->tensor.dl_tensor.ndim,
-      RAI_GetOrtDataTypeFromDL(t->tensor.dl_tensor.dtype),
+  status = ort->CreateTensorAsOrtValue(
+      allocator,
+      batched_shape,
+      t0->tensor.dl_tensor.ndim,
+      RAI_GetOrtDataTypeFromDL(t0->tensor.dl_tensor.dtype),
       &out);
-
   if (status != NULL) {
-    ort->ReleaseMemoryInfo(memory_info);
     goto error;
   }
+
+  char *ort_data;
+  status = ort->GetTensorMutableData(out, (void **)&ort_data);
+  if (status != NULL) {
+    goto error;
+  }
+
+  for (size_t i=0; i<count; i++) {
+    memcpy(ort_data, RAI_TensorData(ts[i]), RAI_TensorByteSize(ts[i]));
+  }

-  ort->ReleaseMemoryInfo(memory_info);
+  if (status != NULL) {
+    goto error;
+  }

   return out;

@@ -112,7 +178,7 @@ OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
   return NULL;
 }

-RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
+RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_t batch_size, RAI_Error *error) {
   OrtStatus* status = NULL;
   const OrtApi* ort = OrtGetApiBase()->GetApi(1);

@@ -154,18 +220,23 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
   status = ort->GetTensorElementType(info, &ort_dtype);
   if (status != NULL) goto error;

+  int64_t total_batch_size = dims[0];
+
   shape = RedisModule_Calloc(ndims, sizeof(*shape));
   strides = RedisModule_Calloc(ndims, sizeof(*strides));
-  for (int64_t i = 0; i < ndims; ++i)
+  for (int64_t i=0; i<ndims; ++i)
   {
     shape[i] = dims[i];
     strides[i] = 1;
   }
+  shape[0] = batch_size;
   for (int64_t i = ndims - 2; i >= 0; --i)
   {
     strides[i] *= strides[i + 1] * shape[i + 1];
   }

+  // size_t sample_bytesize = TF_TensorByteSize(tensor) / total_batch_size;
+
   DLDataType dtype = RAI_GetDLDataTypeFromORT(ort_dtype);
 #ifdef RAI_COPY_RUN_OUTPUT
   char *ort_data;
@@ -180,8 +251,13 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
   }

   size_t len = dtype.bits * elem_count;
-  char *data = RedisModule_Calloc(len, sizeof(*data));
-  memcpy(data, ort_data, len);
+
+  size_t total_bytesize = len * sizeof(char);
+  size_t sample_bytesize = total_bytesize / total_batch_size;
+  size_t batch_bytesize = sample_bytesize * batch_size;
+
+  char *data = RedisModule_Calloc(batch_bytesize, sizeof(*data));
+  memcpy(data, ort_data + batch_offset, batch_bytesize);
 #endif

   ort->ReleaseTensorTypeAndShapeInfo(info);
@@ -354,6 +430,24 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error)
     return 1;
   }

+  const size_t nbatches = array_len(mctx->batches);
+  if (nbatches == 0) {
+    RAI_SetError(error, RAI_EMODELRUN, "No batches to run\n");
+    return 1;
+  }
+
+  size_t batch_sizes[nbatches];
+  size_t batch_offsets[nbatches];
+  if (array_len(mctx->batches[0].inputs) > 0) {
+    for (size_t b=0; b<nbatches; ++b) {
+      batch_sizes[b] = RAI_TensorDim(mctx->batches[b].inputs[0].tensor, 0);
+    }
+    batch_offsets[0] = 0;
+    for (size_t b=1; b<nbatches; ++b) {
+      batch_offsets[b] = batch_sizes[b-1];
+    }
+  }
+
   OrtStatus *status = NULL;

   OrtAllocator *allocator;
@@ -381,8 +475,8 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error)
   OrtValue *inputs[n_input_nodes];
   OrtValue *outputs[n_output_nodes];

-  size_t ninputs = array_len(mctx->inputs);
-  size_t noutputs = array_len(mctx->outputs);
+  size_t ninputs = array_len(mctx->batches[0].inputs);
+  size_t noutputs = array_len(mctx->batches[0].outputs);

   if (ninputs != n_input_nodes) {
     char msg[70];
@@ -407,7 +501,14 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error)

     input_names[i] = input_name;

-    inputs[i] = RAI_OrtValueFromTensor(mctx->inputs[i].tensor, error);
+    RAI_Tensor* batched_input_tensors[nbatches];
+    for (size_t b=0; b<nbatches; b++) {
+      batched_input_tensors[b] = mctx->batches[b].inputs[i].tensor;
+    }
+
+    // TODO: batches
+    // inputs[i] = RAI_OrtValueFromTensor(mctx->inputs[i].tensor, error);
+    inputs[i] = RAI_OrtValueFromTensors(batched_input_tensors, nbatches, error);
     if (error->code != RAI_OK) {
       ort->ReleaseStatus(status);
       return 1;
@@ -457,18 +558,23 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error)
   }

   for (size_t i = 0; i < n_output_nodes; i++) {
-    RAI_Tensor *output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], error);
-    if (error->code != RAI_OK) {
-      ort->ReleaseStatus(status);
-      return 1;
-    }
-    if (output_tensor) {
-      mctx->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
-      RAI_TensorFree(output_tensor);
-    }
-    else {
-      printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n");
+    // TODO batched
+    for (size_t b=0; b<nbatches; b++) {
+      RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
+      if (error->code != RAI_OK) {
+        // TODO: check everything is deallocated here
+        ort->ReleaseStatus(status);
+        return 1;
+      }
+      if (output_tensor) {
+        mctx->batches[b].outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
+        RAI_TensorFree(output_tensor);
+      }
+      else {
+        printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n");
+      }
     }
+
     ort->ReleaseValue(outputs[i]);
   }


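The batching scheme in RAI_OrtValueFromTensors and RAI_TensorCreateFromOrtValue above amounts to gathering each request's rows into one contiguous buffer and scattering results back out by per-request row counts and byte offsets. A minimal standalone sketch of that bookkeeping, using plain malloc'd float buffers rather than RedisAI's RAI_Tensor and OrtValue types (all names below are hypothetical):

/* Standalone illustration (not RedisAI code): gather N per-request buffers
 * into one contiguous batch, then scatter results back out using
 * per-request row counts and cumulative row offsets. */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

int main(void) {
  /* three requests with 1, 2 and 3 rows of 4 floats each */
  const size_t nreq = 3;
  const size_t rows[] = {1, 2, 3};
  const size_t row_bytes = 4 * sizeof(float);

  /* per-request input buffers, filled with a recognizable value */
  float *in[3];
  size_t total_rows = 0;
  for (size_t i = 0; i < nreq; i++) {
    in[i] = malloc(rows[i] * row_bytes);
    for (size_t j = 0; j < rows[i] * 4; j++) in[i][j] = (float)(i + 1);
    total_rows += rows[i];
  }

  /* offsets are cumulative row counts: offsets[b] = rows[0] + ... + rows[b-1] */
  size_t offsets[3];
  offsets[0] = 0;
  for (size_t b = 1; b < nreq; b++) offsets[b] = offsets[b - 1] + rows[b - 1];

  /* gather: copy each request at its byte offset into the batched buffer */
  float *batched = malloc(total_rows * row_bytes);
  for (size_t b = 0; b < nreq; b++) {
    memcpy((char *)batched + offsets[b] * row_bytes, in[b], rows[b] * row_bytes);
  }

  /* scatter: slice the batched buffer back into per-request views */
  for (size_t b = 0; b < nreq; b++) {
    const float *slice = batched + offsets[b] * 4;
    printf("request %zu: first value %.1f, %zu rows\n", b, slice[0], rows[b]);
  }

  for (size_t i = 0; i < nreq; i++) free(in[i]);
  free(batched);
  return 0;
}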
src/backends/tensorflow.c

Lines changed: 5 additions & 2 deletions
@@ -172,9 +172,11 @@ TF_Tensor* RAI_TFTensorFromTensors(RAI_Tensor** ts, size_t count){
   }

   size_t batch_size = 0;
+  size_t batch_byte_size = 0;

   for (size_t i=0; i<count; i++) {
     batch_size += ts[i]->tensor.dl_tensor.shape[0];
+    batch_byte_size += RAI_TensorByteSize(ts[i]);
   }

   RAI_Tensor* t0 = ts[0];
@@ -192,7 +194,7 @@ TF_Tensor* RAI_TFTensorFromTensors(RAI_Tensor** ts, size_t count){
       RAI_GetTFDataTypeFromDL(t0->tensor.dl_tensor.dtype),
       batched_shape,
       t0->tensor.dl_tensor.ndim,
-      RAI_TensorByteSize(t0));
+      batch_byte_size);

   size_t offset = 0;
   for (size_t i=0; i<count; i++) {
@@ -422,9 +424,10 @@ void RAI_ModelFreeTF(RAI_Model* model, RAI_Error* error) {

 int RAI_ModelRunTF(RAI_ModelRunCtx* mctx, RAI_Error *error) {
   TF_Status *status = TF_NewStatus();
-  const size_t nbatches = array_len(mctx->batches);

+  const size_t nbatches = array_len(mctx->batches);
   if (nbatches == 0) {
+    RAI_SetError(error, RAI_EMODELRUN, "No batches to run\n");
     return 1;
   }

src/redisai.c
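The tensorflow.c change sizes the batched tensor allocation by the sum of all inputs' byte sizes instead of the first tensor's size alone. A small hypothetical illustration of why the accumulation matters (the byte sizes stand in for RAI_TensorByteSize results):

/* Hypothetical illustration: the batched allocation must be sized by the sum
 * of the per-request byte sizes, not by the first request alone. */
#include <stdio.h>

int main(void) {
  const size_t count = 3;
  const size_t byte_sizes[] = {64, 128, 256};  /* stand-ins for RAI_TensorByteSize(ts[i]) */

  size_t batch_byte_size = 0;
  for (size_t i = 0; i < count; i++) {
    batch_byte_size += byte_sizes[i];
  }

  /* sizing by byte_sizes[0] alone (64) would truncate the last two requests */
  printf("batched allocation: %zu bytes\n", batch_byte_size);  /* prints 448 */
  return 0;
}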

Lines changed: 7 additions & 3 deletions
@@ -813,11 +813,15 @@ void *RedisAI_RunSession(struct RedisAI_RunInfo **batch_rinfo) {
   for (long long i=0; i<array_len(batch_rinfo); i++) {
     struct RedisAI_RunInfo *rinfo = batch_rinfo[i];
     if (mctx) {
-      printf("BATCH %d\n", i);
-      printf("TENSOR %p\n", i);
       size_t noutputs = RAI_ModelRunCtxNumOutputs(mctx);
       for (long long o=0; o<noutputs; o++) {
-        rinfo->mctx->batches[0].outputs[o].tensor = RAI_TensorGetShallowCopy(mctx->batches[i].outputs[o].tensor);
+        RAI_Tensor* tensor = mctx->batches[i].outputs[o].tensor;
+        if (tensor) {
+          rinfo->mctx->batches[0].outputs[o].tensor = RAI_TensorGetShallowCopy(tensor);
+        }
+        else {
+          rinfo->mctx->batches[0].outputs[o].tensor = NULL;
+        }
       }
     }
     else if (sctx) {

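The redisai.c change removes the debug printf calls and guards the scatter of batched outputs back to each request against NULL tensors (for example, a non-tensor backend output). A small standalone sketch of that guard, with a hypothetical shallow_copy standing in for RAI_TensorGetShallowCopy:

/* Hypothetical illustration of the NULL-guarded scatter: copy a batch output
 * back to its originating request only when the backend produced a tensor. */
#include <stdio.h>

typedef struct { int refcount; } Tensor;

/* stand-in for RAI_TensorGetShallowCopy: bump a refcount, return the same object */
static Tensor *shallow_copy(Tensor *t) {
  t->refcount++;
  return t;
}

int main(void) {
  Tensor produced = { .refcount = 1 };
  Tensor *batch_outputs[] = { &produced, NULL };   /* second output missing */
  Tensor *request_outputs[2] = { NULL, NULL };

  for (size_t i = 0; i < 2; i++) {
    if (batch_outputs[i]) {
      request_outputs[i] = shallow_copy(batch_outputs[i]);
    } else {
      request_outputs[i] = NULL;                   /* leave the slot empty */
    }
  }

  printf("first output copied: %s, refcount=%d\n",
         request_outputs[0] ? "yes" : "no", produced.refcount);
  printf("second output copied: %s\n", request_outputs[1] ? "yes" : "no");
  return 0;
}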