21
21
#include "util/arr_rm_alloc.h"
22
22
#include "util/dict.h"
23
23
#include "util/queue.h"
24
- #include <ctype.h>
25
- #include <errno.h>
26
24
#include <pthread.h>
27
25
#include <stdio.h>
28
- #include <stdlib.h>
29
26
#include <string.h>
30
27
#include <unistd.h>
28
+ #include <errno.h>
29
+ #include <stdlib.h>
30
+ #include <ctype.h>
31
31
32
32
int freeRunQueueInfo (RunQueueInfo * info ) {
33
33
int result = REDISMODULE_OK ;
@@ -132,7 +132,6 @@ void *RedisAI_Run_ThreadMain(void *arg) {
132
132
// There might be more than one thread operating on the same
133
133
// queue, according to the THREADS_PER_QUEUE config variable.
134
134
long long run_queue_len = queueLength (run_queue_info -> run_queue );
135
-
136
135
while (run_queue_len > 0 ) {
137
136
// We first peek the front of the queue
138
137
queueItem * item = queueFront (run_queue_info -> run_queue );
@@ -176,15 +175,17 @@ void *RedisAI_Run_ThreadMain(void *arg) {
176
175
if (timedOut == 1 ) {
177
176
queueEvict (run_queue_info -> run_queue , item );
178
177
178
+ RedisAI_RunInfo * orig = rinfo -> orig_copy ;
179
179
long long dagRefCount = RAI_DagRunInfoFreeShallowCopy (rinfo );
180
- if (dagRefCount == 0 && rinfo -> client ) {
181
- RedisModule_UnblockClient (rinfo -> client , rinfo );
180
+ if (dagRefCount == 0 ) {
181
+ RedisAI_OnFinishCtx finish_ctx = (RedisAI_RunInfo * )orig ;
182
+ orig -> OnFinish (finish_ctx , orig -> private_data );
182
183
}
183
184
184
185
queueItem * evicted_item = item ;
185
186
item = item -> next ;
186
187
RedisModule_Free (evicted_item );
187
-
188
+ // Continue with the next item in queue (if exists)
188
189
continue ;
189
190
}
190
191
}
@@ -244,9 +245,9 @@ void *RedisAI_Run_ThreadMain(void *arg) {
244
245
int currentOpReady , currentOpBatchable ;
245
246
RedisAI_DagCurrentOpInfo (rinfo , & currentOpReady , & currentOpBatchable );
246
247
247
- // If any of the inputs of the current op is not in the context, it
248
- // means that some parent ops did not execute. In this case we don't
249
- // schedule to run, but we will place the entry back on the queue
248
+ // If any of the inputs of the current op is not in the context, it means
249
+ // that some parent ops did not execute. In this case we don't schedule
250
+ // to run, but we will place the entry back on the queue
250
251
if (currentOpReady == 0 ) {
251
252
do_run = 0 ;
252
253
do_retry = 1 ;
@@ -256,18 +257,16 @@ void *RedisAI_Run_ThreadMain(void *arg) {
256
257
// If we made it this far, we will run the currentOp
257
258
do_run = 1 ;
258
259
259
- // If the current op is not batchable (that is, if it's not a modelrun
260
- // or if it's a modelrun but batchsize was set to 0), we stop looking
261
- // further
260
+ // If the current op is not batchable (that is, if it's not a modelrun or
261
+ // if it's a modelrun but batchsize was set to 0), we stop looking further
262
262
if (currentOpBatchable == 0 ) {
263
263
break ;
264
264
}
265
265
266
266
// If we are here, then we scheduled to run and we currently have an
267
267
// operation that can be batched.
268
268
269
- // Since the current op can be batched, then we collect info on
270
- // batching, namely
269
+ // Since the current op can be batched, then we collect info on batching, namely
271
270
// - batchsize
272
271
// - minbatchsize
273
272
// - minbatchtimeout
@@ -276,8 +275,8 @@ void *RedisAI_Run_ThreadMain(void *arg) {
276
275
RedisAI_DagOpBatchInfo (rinfo , currentOp , & batchsize , & minbatchsize ,
277
276
& minbatchtimeout , & inbatchsize );
278
277
279
- // Get the size of the batch so far, that is, the size of the first
280
- // input tensor in the 0-th dimension
278
+ // Get the size of the batch so far, that is, the size of the first input
279
+ // tensor in the 0-th dimension
281
280
size_t current_batchsize = inbatchsize ;
282
281
283
282
// If the size is zero or if it already exceeds the desired batch size
@@ -396,8 +395,8 @@ void *RedisAI_Run_ThreadMain(void *arg) {
396
395
RedisAI_DagRunSessionStep (batch_rinfo [0 ], run_queue_info -> devicestr );
397
396
}
398
397
399
- // Lock the queue again: we're done operating on evicted items only, we
400
- // need to update the queue with the new information after run
398
+ // Lock the queue again: we're done operating on evicted items only, we need
399
+ // to update the queue with the new information after run
401
400
pthread_mutex_lock (& run_queue_info -> run_queue_mutex );
402
401
403
402
// Run is over, now iterate over the run info structs in the batch
@@ -415,9 +414,11 @@ void *RedisAI_Run_ThreadMain(void *arg) {
415
414
// If there was an error and the reference count for the dag
416
415
// has gone to zero and the client is still around, we unblock
417
416
if (dagError ) {
417
+ RedisAI_RunInfo * orig = rinfo -> orig_copy ;
418
418
long long dagRefCount = RAI_DagRunInfoFreeShallowCopy (rinfo );
419
- if (dagRefCount == 0 && rinfo -> client ) {
420
- RedisModule_UnblockClient (rinfo -> client , rinfo );
419
+ if (dagRefCount == 0 ) {
420
+ RedisAI_OnFinishCtx finish_ctx = (RedisAI_RunInfo * )orig ;
421
+ orig -> OnFinish (finish_ctx , orig -> private_data );
421
422
}
422
423
} else {
423
424
rinfo -> dagDeviceCompleteOpCount += 1 ;
@@ -426,29 +427,30 @@ void *RedisAI_Run_ThreadMain(void *arg) {
426
427
}
427
428
}
428
429
429
- // We initialize variables where we'll store the fact hat, after the
430
- // current run, all ops for the device or all ops in the dag could be
431
- // complete. This way we can avoid placing the op back on the queue if
432
- // there's nothing left to do.
430
+ // We initialize variables where we'll store the fact hat, after the current
431
+ // run, all ops for the device or all ops in the dag could be complete. This
432
+ // way we can avoid placing the op back on the queue if there's nothing left
433
+ // to do.
433
434
int device_complete_after_run = RedisAI_DagDeviceComplete (batch_rinfo [0 ]);
434
435
int dag_complete_after_run = RedisAI_DagComplete (batch_rinfo [0 ]);
435
436
436
437
long long dagRefCount = -1 ;
437
-
438
+ RedisAI_RunInfo * orig ;
438
439
if (device_complete == 1 || device_complete_after_run == 1 ) {
439
440
RedisAI_RunInfo * evicted_rinfo = (RedisAI_RunInfo * )(evicted_items [0 ]-> value );
440
- // We decrease and get the reference count for the DAG
441
+ orig = evicted_rinfo -> orig_copy ;
442
+ // We decrease and get the reference count for the DAG.
441
443
dagRefCount = RAI_DagRunInfoFreeShallowCopy (evicted_rinfo );
442
444
}
443
445
444
446
// If the DAG was complete, then it's time to unblock the client
445
447
if (do_unblock == 1 || dag_complete_after_run == 1 ) {
446
- RedisAI_RunInfo * evicted_rinfo = (RedisAI_RunInfo * )(evicted_items [0 ]-> value );
447
448
448
- // If the reference count for the DAG is zero and the client is still
449
- // around, then we actually unblock the client
450
- if (dagRefCount == 0 && evicted_rinfo -> client ) {
451
- RedisModule_UnblockClient (evicted_rinfo -> client , evicted_rinfo );
449
+ // If the reference count for the DAG is zero and the client is still around,
450
+ // then we actually unblock the client
451
+ if (dagRefCount == 0 ) {
452
+ RedisAI_OnFinishCtx finish_ctx = (RedisAI_RunInfo * )orig ;
453
+ orig -> OnFinish (finish_ctx , orig -> private_data );
452
454
}
453
455
}
454
456
@@ -463,31 +465,28 @@ void *RedisAI_Run_ThreadMain(void *arg) {
463
465
queueItem * next_item = queuePop (run_queue_info -> run_queue );
464
466
RedisAI_RunInfo * next_rinfo = (RedisAI_RunInfo * )next_item -> value ;
465
467
// Push the DAG to the front of the queue, and then the item we just
466
- // popped in front of it, so that it becomes the first item in the
467
- // queue. The rationale is, since the DAG needs to wait for other
468
- // workers, we are giving way to the next item and we'll get back to
469
- // the DAG when that is done
468
+ // popped in front of it, so that it becomes the first item in the queue.
469
+ // The rationale is, since the DAG needs to wait for other workers, we are
470
+ // giving way to the next item and we'll get back to the DAG when that is done
470
471
queuePushFront (run_queue_info -> run_queue , evicted_rinfo );
471
472
queuePushFront (run_queue_info -> run_queue , next_rinfo );
472
473
}
473
474
// If there's nothing else in the queue
474
475
else {
475
476
// We push the DAG back at the front
476
477
queuePushFront (run_queue_info -> run_queue , evicted_rinfo );
477
- // Since there's nothing else on the queue we just break out and give
478
- // other workers a chance to produce the inputs needed for this DAG
479
- // step
478
+ // Since there's nothing else on the queue we just break out and give other
479
+ // workers a chance to produce the inputs needed for this DAG step
480
480
break ;
481
481
}
482
482
}
483
483
484
- // If the op was ran successfully and without any error, then put the
485
- // entry back on the queue unless all ops for the device have been
486
- // executed
484
+ // If the op was ran successfully and without any error, then put the entry back
485
+ // on the queue unless all ops for the device have been executed
487
486
if (do_run == 1 && run_error == 0 ) {
488
487
// Here we iterate backwards to keep the first evicted on top
489
- // A side effect of this is that we are potentially changing priority in
490
- // the queue We could solve this using a priority queue, TODO for later
488
+ // A side effect of this is that we are potentially changing priority in the queue
489
+ // We could solve this using a priority queue, TODO for later
491
490
for (long long i = array_len (evicted_items ) - 1 ; i >= 0 ; i -- ) {
492
491
// Get the current evicted run info
493
492
RedisAI_RunInfo * evicted_rinfo = (RedisAI_RunInfo * )(evicted_items [i ]-> value );
@@ -500,24 +499,22 @@ void *RedisAI_Run_ThreadMain(void *arg) {
500
499
}
501
500
}
502
501
503
- // TODO now we can figure out of the device is complete or the dag is
504
- // complete if (dag_complete_op_count == evicted_rinfo[0]->dagOpCount) ->
505
- // ublock, free if (dag_device_complete_op_count ==
506
- // evicted_rinfo[0]->dagDeviceOpCount) -> device complete
502
+ // TODO now we can figure out of the device is complete or the dag is complete
503
+ // if (dag_complete_op_count == evicted_rinfo[0]->dagOpCount) -> ublock, free
504
+ // if (dag_device_complete_op_count == evicted_rinfo[0]->dagDeviceOpCount) -> device
505
+ // complete
507
506
508
- // If there's nothing else to do for the DAG in the current worker or if
509
- // an error occurred in any worker, we just move on
507
+ // If there's nothing else to do for the DAG in the current worker or if an error
508
+ // occurred in any worker, we just move on
510
509
if (device_complete == 1 || device_complete_after_run == 1 || do_unblock == 1 ||
511
510
run_error == 1 ) {
512
511
for (long long i = 0 ; i < array_len (evicted_items ); i ++ ) {
513
512
RedisModule_Free (evicted_items [i ]);
514
513
}
515
514
}
516
-
517
515
run_queue_len = queueLength (run_queue_info -> run_queue );
518
516
}
519
517
}
520
-
521
518
array_free (evicted_items );
522
519
array_free (batch_rinfo );
523
520
}
0 commit comments