@@ -77,31 +77,98 @@ DLDataType RAI_GetDLDataTypeFromORT(ONNXTensorElementDataType dtype) {
  return (DLDataType){ .bits = 0 };
}

-OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
-  // TODO: create outside and pass?
-  OrtAllocatorInfo* allocator_info;
-  OrtStatus* status;
-  status = OrtCreateCpuAllocatorInfo(OrtArenaAllocator, OrtMemTypeDefault, &allocator_info);
-  if (status != NULL) {
-    goto error;
+// OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
+//   // TODO: create outside and pass?
+//   OrtAllocatorInfo* allocator_info;
+//   OrtStatus* status;
+//   status = OrtCreateCpuAllocatorInfo(OrtArenaAllocator, OrtMemTypeDefault, &allocator_info);
+//   if (status != NULL) {
+//     goto error;
+//   }
+//
+//   OrtValue* out;
+//   status = OrtCreateTensorWithDataAsOrtValue(
+//       allocator_info,
+//       t->tensor.dl_tensor.data,
+//       RAI_TensorByteSize(t),
+//       t->tensor.dl_tensor.shape,
+//       t->tensor.dl_tensor.ndim,
+//       RAI_GetOrtDataTypeFromDL(t->tensor.dl_tensor.dtype),
+//       &out);
+//
+//   if (status != NULL) {
+//     OrtReleaseAllocatorInfo(allocator_info);
+//     goto error;
+//   }
+//
+//   OrtReleaseAllocatorInfo(allocator_info);
+//
+//   return out;
+//
+// error:
+//   RAI_SetError(error, RAI_EMODELCREATE, OrtGetErrorMessage(status));
+//   OrtReleaseStatus(status);
+//   return NULL;
+// }
+
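+// RAI_OrtValueFromTensors concatenates `count` RAI tensors along the first
+// (batch) dimension into a single OrtValue allocated through `allocator`.
+// It assumes (without checking) that all tensors share the same dtype and
+// the same non-batch dimensions.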
+OrtValue* RAI_OrtValueFromTensors(RAI_Tensor** ts, size_t count, OrtAllocator* allocator, RAI_Error *error) {
+  if (count == 0) {
+    return NULL;
+  }
+
+  size_t batch_size = 0;
+  size_t batch_byte_size = 0;
+
+  for (size_t i = 0; i < count; i++) {
+    batch_size += ts[i]->tensor.dl_tensor.shape[0];
+    batch_byte_size += RAI_TensorByteSize(ts[i]);
  }

+  RAI_Tensor* t0 = ts[0];
+
+  int ndim = t0->tensor.dl_tensor.ndim;
+  int64_t batched_shape[ndim];
+
+  for (size_t i = 0; i < ndim; i++) {
+    batched_shape[i] = t0->tensor.dl_tensor.shape[i];
+  }
+
+  batched_shape[0] = batch_size;
+
+  OrtStatus* status = NULL;
+
  OrtValue* out;
-  status = OrtCreateTensorWithDataAsOrtValue(
-      allocator_info,
-      t->tensor.dl_tensor.data,
-      RAI_TensorByteSize(t),
-      t->tensor.dl_tensor.shape,
-      t->tensor.dl_tensor.ndim,
-      RAI_GetOrtDataTypeFromDL(t->tensor.dl_tensor.dtype),
+  // status = OrtCreateTensorWithDataAsOrtValue(
+  //     allocator_info,
+  //     t->tensor.dl_tensor.data,
+  //     RAI_TensorByteSize(t),
+  //     batched_shape,
+  //     t->tensor.dl_tensor.ndim,
+  //     RAI_GetOrtDataTypeFromDL(t->tensor.dl_tensor.dtype),
+  //     &out);
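+  // The batched tensor is allocated by ONNX Runtime itself via `allocator`
+  // (OrtCreateTensorAsOrtValue), rather than wrapping a single tensor's data
+  // as the commented-out call above did.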
+  status = OrtCreateTensorAsOrtValue(
+      allocator,
+      batched_shape,
+      t0->tensor.dl_tensor.ndim,
+      RAI_GetOrtDataTypeFromDL(t0->tensor.dl_tensor.dtype),
      &out);
-
  if (status != NULL) {
-    OrtReleaseAllocatorInfo(allocator_info);
    goto error;
  }
+
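+  // Fill the ORT-owned buffer by appending each input tensor's bytes in order.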
+  char* ort_data;
+  status = OrtGetTensorMutableData(out, (void**)&ort_data);
+  if (status != NULL) {
+    goto error;
+  }
+
+  size_t offset = 0;
+  for (size_t i = 0; i < count; i++) {
+    // advance the destination offset so each tensor is appended rather than
+    // overwriting the previous one
+    memcpy(ort_data + offset, RAI_TensorData(ts[i]), RAI_TensorByteSize(ts[i]));
+    offset += RAI_TensorByteSize(ts[i]);
+  }

-  OrtReleaseAllocatorInfo(allocator_info);
+  if (status != NULL) {
+    goto error;
+  }

  return out;

@@ -111,7 +178,7 @@ OrtValue* RAI_OrtValueFromTensor(RAI_Tensor* t, RAI_Error *error) {
  return NULL;
}

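+// batch_offset and batch_size select which slice of the batched ORT output
+// (along dim 0) is copied into the returned RAI tensor.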
-RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
+RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_t batch_size, RAI_Error *error) {
  OrtStatus* status = NULL;

  RAI_Tensor* ret = NULL;
@@ -152,18 +219,23 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
  status = OrtGetTensorElementType(info, &ort_dtype);
  if (status != NULL) goto error;

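+  // dim 0 of the ORT output spans the whole concatenated batch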
+  int64_t total_batch_size = dims[0];
+
  shape = RedisModule_Calloc(ndims, sizeof(*shape));
  strides = RedisModule_Calloc(ndims, sizeof(*strides));
-  for (int64_t i = 0; i < ndims; ++i)
+  for (int64_t i = 0; i < ndims; ++i)
  {
    shape[i] = dims[i];
    strides[i] = 1;
  }
+  shape[0] = batch_size;
  for (int64_t i = ndims - 2; i >= 0; --i)
  {
    strides[i] *= strides[i + 1] * shape[i + 1];
  }

+  // size_t sample_bytesize = TF_TensorByteSize(tensor) / total_batch_size;
+
  DLDataType dtype = RAI_GetDLDataTypeFromORT(ort_dtype);
#ifdef RAI_COPY_RUN_OUTPUT
  char* ort_data;
@@ -178,8 +250,13 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, RAI_Error *error) {
  }

  size_t len = dtype.bits * elem_count;
-  char* data = RedisModule_Calloc(len, sizeof(*data));
-  memcpy(data, ort_data, len);
+
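+  // Derive the per-sample byte size from the full output, then copy out only
+  // the slice that belongs to this batch.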
+  size_t total_bytesize = len * sizeof(char);
+  size_t sample_bytesize = total_bytesize / total_batch_size;
+  size_t batch_bytesize = sample_bytesize * batch_size;
+
+  char* data = RedisModule_Calloc(batch_bytesize, sizeof(*data));
+  // batch_offset counts samples, so scale it by sample_bytesize to get bytes
+  memcpy(data, ort_data + batch_offset * sample_bytesize, batch_bytesize);
#endif

  OrtReleaseTensorTypeAndShapeInfo(info);
@@ -345,6 +422,24 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) {
    return 1;
  }

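+  // Record each batch's size (dim 0 of its first input tensor) and its
+  // starting row offset, so the batched outputs can be split per batch later.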
+  const size_t nbatches = array_len(mctx->batches);
+  if (nbatches == 0) {
+    RAI_SetError(error, RAI_EMODELRUN, "No batches to run\n");
+    return 1;
+  }
+
+  size_t batch_sizes[nbatches];
+  size_t batch_offsets[nbatches];
+  if (array_len(mctx->batches[0].inputs) > 0) {
+    for (size_t b = 0; b < nbatches; ++b) {
+      batch_sizes[b] = RAI_TensorDim(mctx->batches[b].inputs[0].tensor, 0);
+    }
+    batch_offsets[0] = 0;
+    for (size_t b = 1; b < nbatches; ++b) {
+      // offsets accumulate: each batch starts where the previous one ended
+      batch_offsets[b] = batch_offsets[b - 1] + batch_sizes[b - 1];
+    }
+  }
+
  OrtStatus* status = NULL;

  OrtAllocator* allocator;
@@ -374,8 +469,8 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) {
  OrtValue* inputs[n_input_nodes];
  OrtValue* outputs[n_output_nodes];

-  size_t ninputs = array_len(mctx->inputs);
-  size_t noutputs = array_len(mctx->outputs);
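+  // every batch is expected to carry the same inputs and outputs, so batch 0
+  // is used as the representative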
+  size_t ninputs = array_len(mctx->batches[0].inputs);
+  size_t noutputs = array_len(mctx->batches[0].outputs);

  if (ninputs != n_input_nodes) {
    char msg[70];
@@ -403,7 +498,14 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) {

    input_names[i] = input_name;

-    inputs[i] = RAI_OrtValueFromTensor(mctx->inputs[i].tensor, error);
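+    // gather the i-th input of every batch and concatenate them into a
+    // single batched OrtValue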
+    RAI_Tensor* batched_input_tensors[nbatches];
+    for (size_t b = 0; b < nbatches; b++) {
+      batched_input_tensors[b] = mctx->batches[b].inputs[i].tensor;
+    }
+
+    // TODO: batches
+    // inputs[i] = RAI_OrtValueFromTensor(mctx->inputs[i].tensor, error);
+    inputs[i] = RAI_OrtValueFromTensors(batched_input_tensors, nbatches, allocator, error);
    if (error->code != RAI_OK) {
      OrtReleaseStatus(status);
      OrtReleaseAllocator(allocator);
@@ -456,20 +558,40 @@ int RAI_ModelRunORT(RAI_ModelRunCtx *mctx, RAI_Error *error) {
  }

  for (size_t i = 0; i < n_output_nodes; i++) {
-    RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], error);
-    if (error->code != RAI_OK) {
-      OrtReleaseStatus(status);
-      OrtReleaseAllocator(allocator);
-      return 1;
-    }
-    if (output_tensor) {
-      mctx->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
-      RAI_TensorFree(output_tensor);
-    }
-    else {
-      printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n");
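+    // split the batched ORT output back into one tensor per batch, using the
+    // offsets and sizes recorded before the run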
+    // TODO batched
+    for (size_t b = 0; b < nbatches; b++) {
+      RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
+      if (error->code != RAI_OK) {
+        // TODO: check everything is deallocated here
+        OrtReleaseStatus(status);
+        OrtReleaseAllocator(allocator);
+        return 1;
+      }
+      if (output_tensor) {
+        mctx->batches[b].outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
+        RAI_TensorFree(output_tensor);
+      }
+      else {
+        printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n");
+      }
    }
+
    OrtReleaseValue(outputs[i]);
+
+    // // RAI_Tensor *output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], error);
+    // if (error->code != RAI_OK) {
+    //   OrtReleaseStatus(status);
+    //   OrtReleaseAllocator(allocator);
+    //   return 1;
+    // }
+    // if (output_tensor) {
+    //   mctx->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
+    //   RAI_TensorFree(output_tensor);
+    // }
+    // else {
+    //   printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported).\n");
+    // }
+    // OrtReleaseValue(outputs[i]);
  }

  for (size_t i = 0; i < n_input_nodes; i++) {