@@ -16,9 +16,10 @@ int RAI_InitGlobalRunSessionsORT() {
16
16
OnnxRunSessionCtx * * run_sessions_array =
17
17
array_new (OnnxRunSessionCtx * , RAI_working_threads_num );
18
18
for (size_t i = 0 ; i < RAI_working_threads_num ; i ++ ) {
19
- OnnxRunSessionCtx * entry = RedisModule_Calloc ( 1 , sizeof (OnnxRunSessionCtx ));
20
- entry -> runState = RedisModule_Calloc ( 1 , sizeof (entry -> runState ));
19
+ OnnxRunSessionCtx * entry = RedisModule_Alloc ( sizeof (OnnxRunSessionCtx ));
20
+ entry -> runState = RedisModule_Alloc ( sizeof (entry -> runState ));
21
21
* entry -> runState = RUN_SESSION_AVAILABLE ;
22
+ entry -> queuingTime = LLONG_MAX ;
22
23
run_sessions_array = array_append (run_sessions_array , entry );
23
24
}
24
25
onnx_global_run_sessions -> OnnxRunSessions = run_sessions_array ;
@@ -44,9 +45,10 @@ int RAI_AddNewDeviceORT(const char *device_str) {
44
45
// initialized to NULL.
45
46
size_t size = RedisAI_GetNumThreadsPerQueue ();
46
47
for (size_t i = 0 ; i < size ; i ++ ) {
47
- OnnxRunSessionCtx * entry = RedisModule_Calloc ( 1 , sizeof (OnnxRunSessionCtx ));
48
- entry -> runState = RedisModule_Calloc ( 1 , sizeof (entry -> runState ));
48
+ OnnxRunSessionCtx * entry = RedisModule_Alloc ( sizeof (OnnxRunSessionCtx ));
49
+ entry -> runState = RedisModule_Alloc ( sizeof (entry -> runState ));
49
50
* entry -> runState = RUN_SESSION_AVAILABLE ;
51
+ entry -> queuingTime = LLONG_MAX ;
50
52
run_sessions_array = array_append (run_sessions_array , entry );
51
53
}
52
54
onnx_global_run_sessions -> OnnxRunSessions = run_sessions_array ;
@@ -65,7 +67,10 @@ void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t s
65
67
long long timeout = RedisAI_GetModelExecutionTimeout ();
66
68
for (size_t i = 0 ; i < len ; i ++ ) {
67
69
// Check if a sessions is running for too long, and kill it if is still active.
68
- if (curr_time - run_sessions_ctx [i ]-> queuingTime > timeout ) {
70
+ // If entry doesn't contain active session, its queueing time is LLONG_MAX
71
+ // (thus the following condition will always be evaluated as false)
72
+ if (curr_time - __atomic_load_n (& (run_sessions_ctx [i ]-> queuingTime ), __ATOMIC_RELAXED ) >
73
+ timeout ) {
69
74
if (__sync_bool_compare_and_swap (run_sessions_ctx [i ]-> runState , RUN_SESSION_ACTIVE ,
70
75
RUN_SESSION_INVALID )) {
71
76
// Set termination flag, validate that ONNX API succeeded (returns NULL)
@@ -91,10 +96,11 @@ void RAI_ActivateRunSessionCtxORT(OrtRunOptions *new_run_options, long *run_sess
91
96
}
92
97
OnnxRunSessionCtx * entry = onnx_global_run_sessions -> OnnxRunSessions [* run_session_index ];
93
98
RedisModule_Assert (* entry -> runState == RUN_SESSION_AVAILABLE );
99
+ RedisModule_Assert (entry -> queuingTime == LLONG_MAX );
94
100
95
101
// Update the entry with the current session data.
96
102
entry -> runOptions = new_run_options ;
97
- entry -> queuingTime = mstime ();
103
+ __atomic_store_n ( & ( entry -> queuingTime ), mstime (), __ATOMIC_RELAXED );
98
104
__atomic_store_n (entry -> runState , RUN_SESSION_ACTIVE , __ATOMIC_RELAXED );
99
105
pthread_rwlock_unlock (& (onnx_global_run_sessions -> rwlock ));
100
106
}
@@ -112,6 +118,7 @@ void RAI_ResetRunSessionCtxORT(long run_session_index) {
112
118
} while (state != RUN_SESSION_ACTIVE && state != RUN_SESSION_TERMINATED );
113
119
114
120
ort -> ReleaseRunOptions (entry -> runOptions );
121
+ __atomic_store_n (& (entry -> queuingTime ), LLONG_MAX , __ATOMIC_RELAXED );
115
122
__atomic_store_n (entry -> runState , RUN_SESSION_AVAILABLE , __ATOMIC_RELAXED );
116
123
pthread_rwlock_unlock (& (onnx_global_run_sessions -> rwlock ));
117
124
}
0 commit comments