@@ -77,32 +77,98 @@ DLDataType RAI_GetDLDataTypeFromORT(ONNXTensorElementDataType dtype) {
   return (DLDataType){ .bits = 0 };
 }

-OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
-  // TODO: create outside and pass?
+// OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
+//   // TODO: create outside and pass?
+//   const OrtApi* ort = OrtGetApiBase()->GetApi(1);
+//   OrtMemoryInfo* memory_info;
+//   OrtStatus* status;
+//   status = ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info);
+//   if (status != NULL) {
+//     goto error;
+//   }
+//
+//   OrtValue* out;
+//   status = OrtCreateTensorWithDataAsOrtValue(
+//       allocator_info,
+//       t->tensor.dl_tensor.data,
+//       RAI_TensorByteSize(t),
+//       t->tensor.dl_tensor.shape,
+//       t->tensor.dl_tensor.ndim,
+//       RAI_GetOrtDataTypeFromDL(t->tensor.dl_tensor.dtype),
+//       &out);
+//
+//   if (status != NULL) {
+//     OrtReleaseAllocatorInfo(allocator_info);
+//     goto error;
+//   }
+//
+//   OrtReleaseAllocatorInfo(allocator_info);
+//
+//   return out;
+//
+// error:
+//   RAI_SetError(error, RAI_EMODELCREATE, OrtGetErrorMessage(status));
+//   OrtReleaseStatus(status);
+//   return NULL;
+// }
+
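+// Builds one OrtValue that concatenates `count` RAI tensors along their first
+// (batch) dimension, so a single batched run can be submitted to ONNX Runtime.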
+OrtValue* RAI_OrtValueFromTensors(RAI_Tensor** ts, size_t count, RAI_Error *error) {
+  OrtStatus* status = NULL;
   const OrtApi* ort = OrtGetApiBase()->GetApi(1);
-  OrtMemoryInfo* memory_info;
-  OrtStatus* status;
-  status = ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info);
+
+  OrtAllocator* allocator;
+  status = ort->GetAllocatorWithDefaultOptions(&allocator);
   if (status != NULL) {
-    goto error;
+    return NULL;
+  }
+
+  if (count == 0) {
+    return NULL;
+  }
+
+  size_t batch_size = 0;
+  size_t batch_byte_size = 0;
+
+  for (size_t i = 0; i < count; i++) {
+    batch_size += ts[i]->tensor.dl_tensor.shape[0];
+    batch_byte_size += RAI_TensorByteSize(ts[i]);
   }

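+  // The batched value uses the first tensor's shape, with dim 0 replaced by
+  // the summed batch size of all inputs.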
+  RAI_Tensor* t0 = ts[0];
+
+  int ndim = t0->tensor.dl_tensor.ndim;
+  int64_t batched_shape[ndim];
+
+  for (size_t i = 0; i < ndim; i++) {
+    batched_shape[i] = t0->tensor.dl_tensor.shape[i];
+  }
+
+  batched_shape[0] = batch_size;
+
   OrtValue* out;
-  status = ort->CreateTensorWithDataAsOrtValue(
-      memory_info,
-      t->tensor.dl_tensor.data,
-      RAI_TensorByteSize(t),
-      t->tensor.dl_tensor.shape,
-      t->tensor.dl_tensor.ndim,
-      RAI_GetOrtDataTypeFromDL(t->tensor.dl_tensor.dtype),
+  status = ort->CreateTensorAsOrtValue(
+      allocator,
+      batched_shape,
+      t0->tensor.dl_tensor.ndim,
+      RAI_GetOrtDataTypeFromDL(t0->tensor.dl_tensor.dtype),
       &out);
-
   if (status != NULL) {
-    ort->ReleaseMemoryInfo(memory_info);
     goto error;
   }
+
+  char* ort_data;
+  status = ort->GetTensorMutableData(out, (void**)&ort_data);
+  if (status != NULL) {
+    goto error;
+  }
+
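+  // Copy each input tensor's buffer into the batched ORT buffer, advancing by
+  // that tensor's byte size so the samples are laid out back to back.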
+  size_t offset = 0;
+  for (size_t i = 0; i < count; i++) {
+    memcpy(ort_data + offset, RAI_TensorData(ts[i]), RAI_TensorByteSize(ts[i]));
+    offset += RAI_TensorByteSize(ts[i]);
+  }

-  ort->ReleaseMemoryInfo(memory_info);
+  if (status != NULL) {
+    goto error;
+  }

   return out;

@@ -112,7 +178,7 @@ OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
   return NULL;
 }

-RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
+RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_t batch_size, RAI_Error *error) {
   OrtStatus* status = NULL;
   const OrtApi* ort = OrtGetApiBase()->GetApi(1);

@@ -154,18 +220,23 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
   status = ort->GetTensorElementType(info, &ort_dtype);
   if (status != NULL) goto error;

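+  // dims[0] of the ORT output holds the full, concatenated batch size across
+  // all batches in this run.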
+  int64_t total_batch_size = dims[0];
+
   shape = RedisModule_Calloc(ndims, sizeof(*shape));
   strides = RedisModule_Calloc(ndims, sizeof(*strides));
-  for (int64_t i = 0; i < ndims; ++i)
+  for (int64_t i = 0; i < ndims; ++i)
   {
     shape[i] = dims[i];
     strides[i] = 1;
   }
+  shape[0] = batch_size;
   for (int64_t i = ndims - 2; i >= 0; --i)
   {
     strides[i] *= strides[i + 1] * shape[i + 1];
   }

+  // size_t sample_bytesize = TF_TensorByteSize(tensor) / total_batch_size;
+
   DLDataType dtype = RAI_GetDLDataTypeFromORT(ort_dtype);
 #ifdef RAI_COPY_RUN_OUTPUT
   char* ort_data;
@@ -180,8 +251,13 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
   }

   size_t len = dtype.bits * elem_count;
-  char* data = RedisModule_Calloc(len, sizeof(*data));
-  memcpy(data, ort_data, len);
+
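+  // batch_offset and batch_size are counted in samples; convert via the
+  // per-sample byte size before slicing this batch out of the batched output.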
+  size_t total_bytesize = len * sizeof(char);
+  size_t sample_bytesize = total_bytesize / total_batch_size;
+  size_t batch_bytesize = sample_bytesize * batch_size;
+
+  char* data = RedisModule_Calloc(batch_bytesize, sizeof(*data));
+  memcpy(data, ort_data + batch_offset * sample_bytesize, batch_bytesize);
 #endif

   ort->ReleaseTensorTypeAndShapeInfo(info);
@@ -354,6 +430,24 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error)
     return 1;
   }

+  const size_t nbatches = array_len(mctx->batches);
+  if (nbatches == 0) {
+    RAI_SetError(error, RAI_EMODELRUN, "No batches to run\n");
+    return 1;
+  }
+
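+  // batch_sizes[b] is dim 0 of batch b's first input; batch_offsets[b] is the
+  // running sample offset of batch b within the concatenated run
+  // (e.g. sizes {2, 3, 1} give offsets {0, 2, 5}).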
+  size_t batch_sizes[nbatches];
+  size_t batch_offsets[nbatches];
+  if (array_len(mctx->batches[0].inputs) > 0) {
+    for (size_t b = 0; b < nbatches; ++b) {
+      batch_sizes[b] = RAI_TensorDim(mctx->batches[b].inputs[0].tensor, 0);
+    }
+    batch_offsets[0] = 0;
+    for (size_t b = 1; b < nbatches; ++b) {
+      batch_offsets[b] = batch_offsets[b - 1] + batch_sizes[b - 1];
+    }
+  }
+
   OrtStatus* status = NULL;

   OrtAllocator* allocator;
@@ -381,8 +475,8 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error)
   OrtValue* inputs[n_input_nodes];
   OrtValue* outputs[n_output_nodes];

-  size_t ninputs = array_len(mctx->inputs);
-  size_t noutputs = array_len(mctx->outputs);
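+  // All batches in a run share the same model, so the expected input and
+  // output counts are read from the first batch.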
+  size_t ninputs = array_len(mctx->batches[0].inputs);
+  size_t noutputs = array_len(mctx->batches[0].outputs);

   if (ninputs != n_input_nodes) {
     char msg[70];
@@ -407,7 +501,14 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error)

     input_names[i] = input_name;

-    inputs[i] = RAI_OrtValueFromTensor(mctx->inputs[i].tensor, error);
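+    // Gather the i-th input of every batch so the tensors can be concatenated
+    // into a single batched OrtValue.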
+    RAI_Tensor* batched_input_tensors[nbatches];
+    for (size_t b = 0; b < nbatches; b++) {
+      batched_input_tensors[b] = mctx->batches[b].inputs[i].tensor;
+    }
+
+    // TODO: batches
+    // inputs[i] = RAI_OrtValueFromTensor(mctx->inputs[i].tensor, error);
+    inputs[i] = RAI_OrtValueFromTensors(batched_input_tensors, nbatches, error);
     if (error->code != RAI_OK) {
       ort->ReleaseStatus(status);
       return 1;
@@ -457,18 +558,23 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error)
   }

   for (size_t i = 0; i < n_output_nodes; i++) {
-    RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], error);
-    if (error->code != RAI_OK) {
-      ort->ReleaseStatus(status);
-      return 1;
-    }
-    if (output_tensor) {
-      mctx->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
-      RAI_TensorFree(output_tensor);
-    }
-    else {
-      printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n");
+    // TODO batched
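+    // Each batch receives its own slice of the batched output value, selected
+    // by its sample offset and size.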
+    for (size_t b = 0; b < nbatches; b++) {
+      RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
+      if (error->code != RAI_OK) {
+        // TODO: check everything is deallocated here
+        ort->ReleaseStatus(status);
+        return 1;
+      }
+      if (output_tensor) {
+        mctx->batches[b].outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
+        RAI_TensorFree(output_tensor);
+      }
+      else {
+        printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n");
+      }
     }
+
     ort->ReleaseValue(outputs[i]);
   }
