@@ -163,7 +163,7 @@ OrtValue* RAI_OrtValueFromTensors(RAI_Tensor** ts, size_t count, RAI_Error *erro
163
163
return NULL ;
164
164
}
165
165
166
- RAI_Tensor * RAI_TensorCreateFromOrtValue (OrtValue * v , size_t batch_offset , size_t batch_size , RAI_Error * error ) {
166
+ RAI_Tensor * RAI_TensorCreateFromOrtValue (OrtValue * v , size_t batch_offset , long long batch_size , RAI_Error * error ) {
167
167
OrtStatus * status = NULL ;
168
168
const OrtApi * ort = OrtGetApiBase ()-> GetApi (1 );
169
169
@@ -214,7 +214,12 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_
214
214
shape [i ] = dims [i ];
215
215
strides [i ] = 1 ;
216
216
}
217
- shape [0 ] = batch_size ;
217
+ if (batch_size != -1 ) {
218
+ shape [0 ] = batch_size ;
219
+ }
220
+ else {
221
+ batch_size = total_batch_size ;
222
+ }
218
223
for (int64_t i = ndims - 2 ; i >= 0 ; -- i )
219
224
{
220
225
strides [i ] *= strides [i + 1 ] * shape [i + 1 ];
@@ -529,14 +534,30 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error)
529
534
}
530
535
531
536
for (size_t i = 0 ; i < n_output_nodes ; i ++ ) {
532
- for (size_t b = 0 ; b < nbatches ; b ++ ) {
533
- RAI_Tensor * output_tensor = RAI_TensorCreateFromOrtValue (outputs [i ], batch_offsets [b ], batch_sizes [b ], error );
537
+ if (nbatches > 1 ) {
538
+ for (size_t b = 0 ; b < nbatches ; b ++ ) {
539
+ RAI_Tensor * output_tensor = RAI_TensorCreateFromOrtValue (outputs [i ], batch_offsets [b ], batch_sizes [b ], error );
540
+ if (error -> code != RAI_OK ) {
541
+ ort -> ReleaseStatus (status );
542
+ return 1 ;
543
+ }
544
+ if (output_tensor ) {
545
+ mctxs [b ]-> outputs [i ].tensor = RAI_TensorGetShallowCopy (output_tensor );
546
+ RAI_TensorFree (output_tensor );
547
+ }
548
+ else {
549
+ printf ("ERR: non-tensor output from ONNX models, ignoring (currently unsupported)" );
550
+ }
551
+ }
552
+ }
553
+ else {
554
+ RAI_Tensor * output_tensor = RAI_TensorCreateFromOrtValue (outputs [i ], 0 , -1 , error );
534
555
if (error -> code != RAI_OK ) {
535
556
ort -> ReleaseStatus (status );
536
557
return 1 ;
537
558
}
538
559
if (output_tensor ) {
539
- mctxs [b ]-> outputs [i ].tensor = RAI_TensorGetShallowCopy (output_tensor );
560
+ mctxs [0 ]-> outputs [i ].tensor = RAI_TensorGetShallowCopy (output_tensor );
540
561
RAI_TensorFree (output_tensor );
541
562
}
542
563
else {
0 commit comments