Skip to content

Commit 7ad0d18

Browse files
committed
Add batch size checks
1 parent 1eeb776 commit 7ad0d18

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

src/backends/onnxruntime.c

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,24 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error)
535535

536536
for (size_t i = 0; i < n_output_nodes; i++) {
537537
if (nbatches > 1) {
538+
OrtTensorTypeAndShapeInfo* info;
539+
status = ort->GetTensorTypeAndShape(outputs[i], &info);
540+
if (status != NULL) goto error;
541+
542+
size_t ndims;
543+
status = ort->GetDimensionsCount(info, &ndims);
544+
if (status != NULL) goto error;
545+
546+
int64_t dims[ndims];
547+
status = ort->GetDimensions(info, dims, ndims);
548+
if (status != NULL) goto error;
549+
550+
if (dims[0] != nbatches) {
551+
RAI_SetError(error, RAI_EMODELRUN, "ERR Model did not generate the expected batch size");
552+
ort->ReleaseStatus(status);
553+
return 1;
554+
}
555+
538556
for (size_t b=0; b<nbatches; b++) {
539557
RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
540558
if (error->code != RAI_OK) {

src/backends/tensorflow.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,15 @@ int RAI_ModelRunTF(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
537537

538538
for(size_t i=0; i<noutputs; ++i) {
539539
if (nbatches > 1) {
540+
const size_t ndims = TF_NumDims(outputTensorsValues[i]);
541+
const int64_t total_batch_size = TF_Dim(outputTensorsValues[i], 0);
542+
if (nbatches != total_batch_size) {
543+
TF_DeleteTensor(outputTensorsValues[i]);
544+
TF_DeleteStatus(status);
545+
RAI_SetError(error, RAI_EMODELRUN, "ERR Model did not generate the expected batch size");
546+
return 1;
547+
}
548+
540549
for (size_t b=0; b<nbatches; b++) {
541550
mctxs[b]->outputs[i].tensor = RAI_TensorCreateFromTFTensor(outputTensorsValues[i], batch_offsets[b], batch_sizes[b]);
542551
}

src/backends/torch.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ int RAI_ModelRunTorch(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
147147
}
148148
RAI_Tensor* output_tensor = RAI_TensorCreateFromDLTensor(outputs_dl[i]);
149149
if (nbatches > 1) {
150+
if (outputs_dl[i]->dl_tensor.shape[0] != nbatches) {
151+
RAI_SetError(error, RAI_EMODELRUN, "ERR Model did not generate the expected batch size");
152+
return 1;
153+
}
150154
for (size_t b=0; b<nbatches; b++) {
151155
mctxs[b]->outputs[i].tensor = RAI_TensorCreateBySlicingTensor(output_tensor, batch_offsets[b], batch_sizes[b]);
152156
}

0 commit comments

Comments
 (0)