Skip to content

Commit 1eeb776

Browse files
committed
Avoid splitting outputs in batches when nbatches == 1
1 parent fcdf0d2 commit 1eeb776

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

src/backends/onnxruntime.c

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ OrtValue* RAI_OrtValueFromTensors(RAI_Tensor** ts, size_t count, RAI_Error *erro
163163
return NULL;
164164
}
165165

166-
RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_t batch_size, RAI_Error *error) {
166+
RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, long long batch_size, RAI_Error *error) {
167167
OrtStatus* status = NULL;
168168
const OrtApi* ort = OrtGetApiBase()->GetApi(1);
169169

@@ -214,7 +214,12 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_
214214
shape[i] = dims[i];
215215
strides[i] = 1;
216216
}
217-
shape[0] = batch_size;
217+
if (batch_size != -1) {
218+
shape[0] = batch_size;
219+
}
220+
else {
221+
batch_size = total_batch_size;
222+
}
218223
for (int64_t i = ndims - 2; i >= 0; --i)
219224
{
220225
strides[i] *= strides[i + 1] * shape[i + 1];
@@ -529,14 +534,30 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error)
529534
}
530535

531536
for (size_t i = 0; i < n_output_nodes; i++) {
532-
for (size_t b=0; b<nbatches; b++) {
533-
RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
537+
if (nbatches > 1) {
538+
for (size_t b=0; b<nbatches; b++) {
539+
RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
540+
if (error->code != RAI_OK) {
541+
ort->ReleaseStatus(status);
542+
return 1;
543+
}
544+
if (output_tensor) {
545+
mctxs[b]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
546+
RAI_TensorFree(output_tensor);
547+
}
548+
else {
549+
printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported)");
550+
}
551+
}
552+
}
553+
else {
554+
RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], 0, -1, error);
534555
if (error->code != RAI_OK) {
535556
ort->ReleaseStatus(status);
536557
return 1;
537558
}
538559
if (output_tensor) {
539-
mctxs[b]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
560+
mctxs[0]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
540561
RAI_TensorFree(output_tensor);
541562
}
542563
else {

src/backends/tensorflow.c

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ DLDataType RAI_GetDLDataTypeFromTF(TF_DataType dtype) {
7979
return (DLDataType){ .bits = 0 };
8080
}
8181

82-
RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor, size_t batch_offset, size_t batch_size) {
82+
RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor, size_t batch_offset, long long batch_size) {
8383
RAI_Tensor* ret = RedisModule_Calloc(1, sizeof(*ret));
8484

8585
DLContext ctx = (DLContext){
@@ -97,7 +97,12 @@ RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor, size_t batch_offset,
9797
shape[i] = TF_Dim(tensor, i);
9898
strides[i] = 1;
9999
}
100-
shape[0] = batch_size;
100+
if (batch_size != -1) {
101+
shape[0] = batch_size;
102+
}
103+
else {
104+
batch_size = total_batch_size;
105+
}
101106
for (int64_t i = ndims-2 ; i >= 0 ; --i) {
102107
strides[i] *= strides[i+1] * shape[i+1];
103108
}
@@ -531,8 +536,13 @@ int RAI_ModelRunTF(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
531536
}
532537

533538
for(size_t i=0; i<noutputs; ++i) {
534-
for (size_t b=0; b<nbatches; b++) {
535-
mctxs[b]->outputs[i].tensor = RAI_TensorCreateFromTFTensor(outputTensorsValues[i], batch_offsets[b], batch_sizes[b]);
539+
if (nbatches > 1) {
540+
for (size_t b=0; b<nbatches; b++) {
541+
mctxs[b]->outputs[i].tensor = RAI_TensorCreateFromTFTensor(outputTensorsValues[i], batch_offsets[b], batch_sizes[b]);
542+
}
543+
}
544+
else {
545+
mctxs[0]->outputs[i].tensor = RAI_TensorCreateFromTFTensor(outputTensorsValues[i], 0, -1);
536546
}
537547
TF_DeleteTensor(outputTensorsValues[i]);
538548
}

0 commit comments

Comments
 (0)