@@ -91,7 +91,7 @@ static void Dag_LoadInputsToModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *curre
91
91
92
92
static void Dag_StoreOutputsFromModelRunCtx (RedisAI_RunInfo * rinfo , RAI_DagOp * currentOp ) {
93
93
94
- RAI_ContextReadLock (rinfo );
94
+ RAI_ContextWriteLock (rinfo );
95
95
const size_t noutputs = RAI_ModelRunCtxNumOutputs (currentOp -> mctx );
96
96
for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
97
97
RAI_Tensor * tensor = RAI_ModelRunCtxOutputTensor (currentOp -> mctx , outputNumber );
@@ -177,6 +177,9 @@ void RedisAI_BatchedDagRunSession_ModelRun_Step(RedisAI_RunInfo **batched_rinfo,
177
177
if (rinfo -> single_op_dag == 0 )
178
178
Dag_StoreOutputsFromModelRunCtx (rinfo , currentOp );
179
179
}
180
+ // Clear the result in case of an error.
181
+ if (result == REDISMODULE_ERR )
182
+ RAI_ClearError (& err );
180
183
}
181
184
182
185
/**
@@ -346,16 +349,20 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2,
346
349
return 1 ;
347
350
}
348
351
349
- int RedisAI_DagDeviceComplete (RedisAI_RunInfo * rinfo ) {
352
+ bool RedisAI_DagDeviceComplete (RedisAI_RunInfo * rinfo ) {
350
353
return rinfo -> dagDeviceCompleteOpCount == rinfo -> dagDeviceOpCount ;
351
354
}
352
355
353
- int RedisAI_DagComplete (RedisAI_RunInfo * rinfo ) {
356
+ bool RedisAI_DagComplete (RedisAI_RunInfo * rinfo ) {
354
357
int completeOpCount = __atomic_load_n (rinfo -> dagCompleteOpCount , __ATOMIC_RELAXED );
355
358
356
359
return completeOpCount == rinfo -> dagOpCount ;
357
360
}
358
361
362
+ bool RedisAI_DagError (RedisAI_RunInfo * rinfo ) {
363
+ return __atomic_load_n (rinfo -> dagError , __ATOMIC_RELAXED ) != 0 ;
364
+ }
365
+
359
366
RAI_DagOp * RedisAI_DagCurrentOp (RedisAI_RunInfo * rinfo ) {
360
367
if (rinfo -> dagDeviceCompleteOpCount == rinfo -> dagDeviceOpCount ) {
361
368
return NULL ;
@@ -364,21 +371,21 @@ RAI_DagOp *RedisAI_DagCurrentOp(RedisAI_RunInfo *rinfo) {
364
371
return rinfo -> dagDeviceOps [rinfo -> dagDeviceCompleteOpCount ];
365
372
}
366
373
367
- void RedisAI_DagCurrentOpInfo (RedisAI_RunInfo * rinfo , int * currentOpReady ,
368
- int * currentOpBatchable ) {
374
+ void RedisAI_DagCurrentOpInfo (RedisAI_RunInfo * rinfo , bool * currentOpReady ,
375
+ bool * currentOpBatchable ) {
369
376
RAI_DagOp * currentOp_ = RedisAI_DagCurrentOp (rinfo );
370
377
371
- * currentOpReady = 0 ;
372
- * currentOpBatchable = 0 ;
378
+ * currentOpReady = false ;
379
+ * currentOpBatchable = false ;
373
380
374
381
if (currentOp_ == NULL ) {
375
382
return ;
376
383
}
377
384
378
385
if (currentOp_ -> mctx && currentOp_ -> mctx -> model -> opts .batchsize > 0 ) {
379
- * currentOpBatchable = 1 ;
386
+ * currentOpBatchable = true ;
380
387
}
381
- * currentOpReady = 1 ;
388
+ * currentOpReady = true ;
382
389
// If this is a single op dag, the op is definitely ready.
383
390
if (rinfo -> single_op_dag == 1 )
384
391
return ;
@@ -389,7 +396,7 @@ void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, int *currentOpReady,
389
396
for (int i = 0 ; i < n_inkeys ; i ++ ) {
390
397
if (AI_dictFind (rinfo -> dagTensorsContext , currentOp_ -> inkeys [i ]) == NULL ) {
391
398
RAI_ContextUnlock (rinfo );
392
- * currentOpReady = 0 ;
399
+ * currentOpReady = false ;
393
400
return ;
394
401
}
395
402
}
@@ -577,7 +584,6 @@ static void _ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
577
584
const size_t noutputs = RAI_ModelRunCtxNumOutputs (op -> mctx );
578
585
for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
579
586
RAI_Tensor * tensor = RAI_ModelRunCtxOutputTensor (op -> mctx , outputNumber );
580
- tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
581
587
if (tensor )
582
588
_StoreTensorInKeySpace (ctx , tensor , op -> outkeys [outputNumber ], false);
583
589
}
@@ -587,7 +593,6 @@ static void _ScriptSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
587
593
const size_t noutputs = RAI_ScriptRunCtxNumOutputs (op -> sctx );
588
594
for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
589
595
RAI_Tensor * tensor = RAI_ScriptRunCtxOutputTensor (op -> sctx , outputNumber );
590
- tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
591
596
if (tensor )
592
597
_StoreTensorInKeySpace (ctx , tensor , op -> outkeys [outputNumber ], false);
593
598
}
@@ -600,7 +605,6 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
600
605
601
606
if (RAI_GetErrorCode (rinfo -> err ) == RAI_EDAGRUN ) {
602
607
RedisModule_ReplyWithError (ctx , RAI_GetErrorOneLine (rinfo -> err ));
603
- RAI_FreeRunInfo (rinfo );
604
608
return REDISMODULE_ERR ;
605
609
}
606
610
int dag_error = 0 ;
@@ -610,7 +614,6 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
610
614
611
615
if (* rinfo -> timedOut ) {
612
616
RedisModule_ReplyWithSimpleString (ctx , "TIMEDOUT" );
613
- RAI_FreeRunInfo (rinfo );
614
617
return REDISMODULE_OK ;
615
618
}
616
619
@@ -701,7 +704,6 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
701
704
if (rinfo -> single_op_dag == 0 ) {
702
705
RedisModule_ReplySetArrayLength (ctx , rinfo -> dagReplyLength );
703
706
}
704
- RAI_FreeRunInfo (rinfo );
705
707
return REDISMODULE_ERR ;
706
708
}
707
709
@@ -718,7 +720,6 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
718
720
}
719
721
}
720
722
721
- RAI_FreeRunInfo (rinfo );
722
723
return REDISMODULE_OK ;
723
724
}
724
725
@@ -746,11 +747,7 @@ int RedisAI_DagRun_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx, RedisMo
746
747
return REDISMODULE_OK ;
747
748
}
748
749
749
- void RunInfo_FreeData (RedisModuleCtx * ctx , void * rinfo ) {}
750
-
751
- void RedisAI_Disconnected (RedisModuleCtx * ctx , RedisModuleBlockedClient * bc ) {
752
- RedisModule_Log (ctx , "warning" , "Blocked client %p disconnected!" , (void * )bc );
753
- }
750
+ void RunInfo_FreeData (RedisModuleCtx * ctx , void * rinfo ) { RAI_FreeRunInfo (rinfo ); }
754
751
755
752
// Add Shallow copies of the DAG run info to the devices' queues.
756
753
// Return REDISMODULE_OK in case of success, REDISMODULE_ERR if (at least) one insert op had
0 commit comments