Skip to content

Commit 873ef14

Browse files
committed
Add batching for ONNX and ONNX-ML
1 parent e5d2c5d commit 873ef14

File tree

3 files changed

+170
-41
lines changed

3 files changed

+170
-41
lines changed

src/backends/onnxruntime.c

Lines changed: 158 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -77,31 +77,98 @@ DLDataType RAI_GetDLDataTypeFromORT(ONNXTensorElementDataType dtype) {
7777
return (DLDataType){ .bits = 0 };
7878
}
7979

80-
OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
81-
// TODO: create outside and pass?
82-
OrtAllocatorInfo* allocator_info;
83-
OrtStatus* status;
84-
status = OrtCreateCpuAllocatorInfo(OrtArenaAllocator, OrtMemTypeDefault, &allocator_info);
85-
if (status != NULL) {
86-
goto error;
80+
// OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
81+
// // TODO: create outside and pass?
82+
// OrtAllocatorInfo* allocator_info;
83+
// OrtStatus* status;
84+
// status = OrtCreateCpuAllocatorInfo(OrtArenaAllocator, OrtMemTypeDefault, &allocator_info);
85+
// if (status != NULL) {
86+
// goto error;
87+
// }
88+
//
89+
// OrtValue* out;
90+
// status = OrtCreateTensorWithDataAsOrtValue(
91+
// allocator_info,
92+
// t->tensor.dl_tensor.data,
93+
// RAI_TensorByteSize(t),
94+
// t->tensor.dl_tensor.shape,
95+
// t->tensor.dl_tensor.ndim,
96+
// RAI_GetOrtDataTypeFromDL(t->tensor.dl_tensor.dtype),
97+
// &out);
98+
//
99+
// if (status != NULL) {
100+
// OrtReleaseAllocatorInfo(allocator_info);
101+
// goto error;
102+
// }
103+
//
104+
// OrtReleaseAllocatorInfo(allocator_info);
105+
//
106+
// return out;
107+
//
108+
// error:
109+
// RAI_SetError(error, RAI_EMODELCREATE, OrtGetErrorMessage(status));
110+
// OrtReleaseStatus(status);
111+
// return NULL;
112+
// }
113+
114+
OrtValue* RAI_OrtValueFromTensors(RAI_Tensor** ts, size_t count, OrtAllocator* allocator, RAI_Error *error) {
115+
if (count == 0) {
116+
return NULL;
117+
}
118+
119+
size_t batch_size = 0;
120+
size_t batch_byte_size = 0;
121+
122+
for (size_t i=0; i<count; i++) {
123+
batch_size += ts[i]->tensor.dl_tensor.shape[0];
124+
batch_byte_size += RAI_TensorByteSize(ts[i]);
87125
}
88126

127+
RAI_Tensor* t0 = ts[0];
128+
129+
int ndim = t0->tensor.dl_tensor.ndim;
130+
int64_t batched_shape[ndim];
131+
132+
for (size_t i=0; i<ndim; i++) {
133+
batched_shape[i] = t0->tensor.dl_tensor.shape[i];
134+
}
135+
136+
batched_shape[0] = batch_size;
137+
138+
OrtStatus* status = NULL;
139+
89140
OrtValue* out;
90-
status = OrtCreateTensorWithDataAsOrtValue(
91-
allocator_info,
92-
t->tensor.dl_tensor.data,
93-
RAI_TensorByteSize(t),
94-
t->tensor.dl_tensor.shape,
95-
t->tensor.dl_tensor.ndim,
96-
RAI_GetOrtDataTypeFromDL(t->tensor.dl_tensor.dtype),
141+
// status = OrtCreateTensorWithDataAsOrtValue(
142+
// allocator_info,
143+
// t->tensor.dl_tensor.data,
144+
// RAI_TensorByteSize(t),
145+
// batched_shape,
146+
// t->tensor.dl_tensor.ndim,
147+
// RAI_GetOrtDataTypeFromDL(t->tensor.dl_tensor.dtype),
148+
// &out);
149+
status = OrtCreateTensorAsOrtValue(
150+
allocator,
151+
batched_shape,
152+
t0->tensor.dl_tensor.ndim,
153+
RAI_GetOrtDataTypeFromDL(t0->tensor.dl_tensor.dtype),
97154
&out);
98-
99155
if (status != NULL) {
100-
OrtReleaseAllocatorInfo(allocator_info);
101156
goto error;
102157
}
158+
159+
char *ort_data;
160+
status = OrtGetTensorMutableData(out, (void **)&ort_data);
161+
if (status != NULL) {
162+
goto error;
163+
}
164+
165+
for (size_t i=0; i<count; i++) {
166+
memcpy(ort_data, RAI_TensorData(ts[i]), RAI_TensorByteSize(ts[i]));
167+
}
103168

104-
OrtReleaseAllocatorInfo(allocator_info);
169+
if (status != NULL) {
170+
goto error;
171+
}
105172

106173
return out;
107174

@@ -111,7 +178,7 @@ OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
111178
return NULL;
112179
}
113180

114-
RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
181+
RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_t batch_size, RAI_Error *error) {
115182
OrtStatus* status = NULL;
116183

117184
RAI_Tensor* ret = NULL;
@@ -152,18 +219,23 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
152219
status = OrtGetTensorElementType(info, &ort_dtype);
153220
if (status != NULL) goto error;
154221

222+
int64_t total_batch_size = dims[0];
223+
155224
shape = RedisModule_Calloc(ndims, sizeof(*shape));
156225
strides = RedisModule_Calloc(ndims, sizeof(*strides));
157-
for (int64_t i = 0; i < ndims; ++i)
226+
for (int64_t i=0; i<ndims; ++i)
158227
{
159228
shape[i] = dims[i];
160229
strides[i] = 1;
161230
}
231+
shape[0] = batch_size;
162232
for (int64_t i = ndims - 2; i >= 0; --i)
163233
{
164234
strides[i] *= strides[i + 1] * shape[i + 1];
165235
}
166236

237+
// size_t sample_bytesize = TF_TensorByteSize(tensor) / total_batch_size;
238+
167239
DLDataType dtype = RAI_GetDLDataTypeFromORT(ort_dtype);
168240
#ifdef RAI_COPY_RUN_OUTPUT
169241
char *ort_data;
@@ -178,8 +250,13 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
178250
}
179251

180252
size_t len = dtype.bits * elem_count;
181-
char *data = RedisModule_Calloc(len, sizeof(*data));
182-
memcpy(data, ort_data, len);
253+
254+
size_t total_bytesize = len * sizeof(char);
255+
size_t sample_bytesize = total_bytesize / total_batch_size;
256+
size_t batch_bytesize = sample_bytesize * batch_size;
257+
258+
char *data = RedisModule_Calloc(batch_bytesize, sizeof(*data));
259+
memcpy(data, ort_data + batch_offset, batch_bytesize);
183260
#endif
184261

185262
OrtReleaseTensorTypeAndShapeInfo(info);
@@ -345,6 +422,24 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) {
345422
return 1;
346423
}
347424

425+
const size_t nbatches = array_len(mctx->batches);
426+
if (nbatches == 0) {
427+
RAI_SetError(error, RAI_EMODELRUN, "No batches to run\n");
428+
return 1;
429+
}
430+
431+
size_t batch_sizes[nbatches];
432+
size_t batch_offsets[nbatches];
433+
if (array_len(mctx->batches[0].inputs) > 0) {
434+
for (size_t b=0; b<nbatches; ++b) {
435+
batch_sizes[b] = RAI_TensorDim(mctx->batches[b].inputs[0].tensor, 0);
436+
}
437+
batch_offsets[0] = 0;
438+
for (size_t b=1; b<nbatches; ++b) {
439+
batch_offsets[b] = batch_sizes[b-1];
440+
}
441+
}
442+
348443
OrtStatus *status = NULL;
349444

350445
OrtAllocator *allocator;
@@ -374,8 +469,8 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) {
374469
OrtValue *inputs[n_input_nodes];
375470
OrtValue *outputs[n_output_nodes];
376471

377-
size_t ninputs = array_len(mctx->inputs);
378-
size_t noutputs = array_len(mctx->outputs);
472+
size_t ninputs = array_len(mctx->batches[0].inputs);
473+
size_t noutputs = array_len(mctx->batches[0].outputs);
379474

380475
if (ninputs != n_input_nodes) {
381476
char msg[70];
@@ -403,7 +498,14 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) {
403498

404499
input_names[i] = input_name;
405500

406-
inputs[i] = RAI_OrtValueFromTensor(mctx->inputs[i].tensor, error);
501+
RAI_Tensor* batched_input_tensors[nbatches];
502+
for (size_t b=0; b<nbatches; b++) {
503+
batched_input_tensors[b] = mctx->batches[b].inputs[i].tensor;
504+
}
505+
506+
// TODO: batches
507+
// inputs[i] = RAI_OrtValueFromTensor(mctx->inputs[i].tensor, error);
508+
inputs[i] = RAI_OrtValueFromTensors(batched_input_tensors, nbatches, allocator, error);
407509
if (error->code != RAI_OK) {
408510
OrtReleaseStatus(status);
409511
OrtReleaseAllocator(allocator);
@@ -456,20 +558,40 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) {
456558
}
457559

458560
for (size_t i = 0; i < n_output_nodes; i++) {
459-
RAI_Tensor *output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], error);
460-
if (error->code != RAI_OK) {
461-
OrtReleaseStatus(status);
462-
OrtReleaseAllocator(allocator);
463-
return 1;
464-
}
465-
if (output_tensor) {
466-
mctx->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
467-
RAI_TensorFree(output_tensor);
468-
}
469-
else {
470-
printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n");
561+
// TODO batched
562+
for (size_t b=0; b<nbatches; b++) {
563+
RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
564+
if (error->code != RAI_OK) {
565+
// TODO: check everything is deallocated here
566+
OrtReleaseStatus(status);
567+
OrtReleaseAllocator(allocator);
568+
return 1;
569+
}
570+
if (output_tensor) {
571+
mctx->batches[b].outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
572+
RAI_TensorFree(output_tensor);
573+
}
574+
else {
575+
printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n");
576+
}
471577
}
578+
472579
OrtReleaseValue(outputs[i]);
580+
581+
// // RAI_Tensor *output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], error);
582+
// if (error->code != RAI_OK) {
583+
// OrtReleaseStatus(status);
584+
// OrtReleaseAllocator(allocator);
585+
// return 1;
586+
// }
587+
// if (output_tensor) {
588+
// mctx->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
589+
// RAI_TensorFree(output_tensor);
590+
// }
591+
// else {
592+
// printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n");
593+
// }
594+
// OrtReleaseValue(outputs[i]);
473595
}
474596

475597
for (size_t i = 0; i < n_input_nodes; i++) {

src/backends/tensorflow.c

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,11 @@ TF_Tensor* RAI_TFTensorFromTensors(RAI_Tensor** ts, size_t count){
172172
}
173173

174174
size_t batch_size = 0;
175+
size_t batch_byte_size = 0;
175176

176177
for (size_t i=0; i<count; i++) {
177178
batch_size += ts[i]->tensor.dl_tensor.shape[0];
179+
batch_byte_size += RAI_TensorByteSize(ts[i]);
178180
}
179181

180182
RAI_Tensor* t0 = ts[0];
@@ -192,7 +194,7 @@ TF_Tensor* RAI_TFTensorFromTensors(RAI_Tensor** ts, size_t count){
192194
RAI_GetTFDataTypeFromDL(t0->tensor.dl_tensor.dtype),
193195
batched_shape,
194196
t0->tensor.dl_tensor.ndim,
195-
RAI_TensorByteSize(t0));
197+
batch_byte_size);
196198

197199
size_t offset = 0;
198200
for (size_t i=0; i<count; i++) {
@@ -422,9 +424,10 @@ void RAI_ModelFreeTF(RAI_Model* model, RAI_Error* error) {
422424

423425
int RAI_ModelRunTF(RAI_ModelRunCtx* mctx, RAI_Error *error) {
424426
TF_Status *status = TF_NewStatus();
425-
const size_t nbatches = array_len(mctx->batches);
426427

428+
const size_t nbatches = array_len(mctx->batches);
427429
if (nbatches == 0) {
430+
RAI_SetError(error, RAI_EMODELRUN, "No batches to run\n");
428431
return 1;
429432
}
430433

src/redisai.c

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -813,11 +813,15 @@ void *RedisAI_RunSession(struct RedisAI_RunInfo **batch_rinfo) {
813813
for (long long i=0; i<array_len(batch_rinfo); i++) {
814814
struct RedisAI_RunInfo *rinfo = batch_rinfo[i];
815815
if (mctx) {
816-
printf("BATCH %d\n", i);
817-
printf("TENSOR %p\n", i);
818816
size_t noutputs = RAI_ModelRunCtxNumOutputs(mctx);
819817
for (long long o=0; o<noutputs; o++) {
820-
rinfo->mctx->batches[0].outputs[o].tensor = RAI_TensorGetShallowCopy(mctx->batches[i].outputs[o].tensor);
818+
RAI_Tensor* tensor = mctx->batches[i].outputs[o].tensor;
819+
if (tensor) {
820+
rinfo->mctx->batches[0].outputs[o].tensor = RAI_TensorGetShallowCopy(tensor);
821+
}
822+
else {
823+
rinfo->mctx->batches[0].outputs[o].tensor = NULL;
824+
}
821825
}
822826
}
823827
else if (sctx) {

0 commit comments

Comments
 (0)