Skip to content

Commit ee1da7a

Browse files
committed
Add some protections for edge cases to avoid race conditions in the onnx kill switch
1 parent 7a5d18d commit ee1da7a

File tree

1 file changed

+22
-14
lines changed

1 file changed

+22
-14
lines changed

src/backends/onnx_timeout.c

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ int RAI_InitGlobalRunSessionsORT() {
1616
OnnxRunSessionCtx **run_sessions_array =
1717
array_new(OnnxRunSessionCtx *, RAI_working_threads_num);
1818
for (size_t i = 0; i < RAI_working_threads_num; i++) {
19-
OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx));
20-
entry->runState = RedisModule_Calloc(1, sizeof(entry->runState));
19+
OnnxRunSessionCtx *entry = RedisModule_Alloc(sizeof(OnnxRunSessionCtx));
20+
entry->runState = RedisModule_Alloc(sizeof(entry->runState));
2121
*entry->runState = RUN_SESSION_AVAILABLE;
22+
entry->queuingTime = LLONG_MAX;
2223
run_sessions_array = array_append(run_sessions_array, entry);
2324
}
2425
onnx_global_run_sessions->OnnxRunSessions = run_sessions_array;
@@ -44,9 +45,10 @@ int RAI_AddNewDeviceORT(const char *device_str) {
4445
// initialized to NULL.
4546
size_t size = RedisAI_GetNumThreadsPerQueue();
4647
for (size_t i = 0; i < size; i++) {
47-
OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx));
48-
entry->runState = RedisModule_Calloc(1, sizeof(entry->runState));
48+
OnnxRunSessionCtx *entry = RedisModule_Alloc(sizeof(OnnxRunSessionCtx));
49+
entry->runState = RedisModule_Alloc(sizeof(entry->runState));
4950
*entry->runState = RUN_SESSION_AVAILABLE;
51+
entry->queuingTime = LLONG_MAX;
5052
run_sessions_array = array_append(run_sessions_array, entry);
5153
}
5254
onnx_global_run_sessions->OnnxRunSessions = run_sessions_array;
@@ -65,7 +67,10 @@ void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t s
6567
long long timeout = RedisAI_GetModelExecutionTimeout();
6668
for (size_t i = 0; i < len; i++) {
6769
// Check if a sessions is running for too long, and kill it if is still active.
68-
if (curr_time - run_sessions_ctx[i]->queuingTime > timeout) {
70+
// If entry doesn't contain active session, its queueing time is LLONG_MAX
71+
// (thus the following condition will always be evaluated as false)
72+
if (curr_time - __atomic_load_n(&(run_sessions_ctx[i]->queuingTime), __ATOMIC_RELAXED) >
73+
timeout) {
6974
if (__sync_bool_compare_and_swap(run_sessions_ctx[i]->runState, RUN_SESSION_ACTIVE,
7075
RUN_SESSION_INVALID)) {
7176
// Set termination flag, validate that ONNX API succeeded (returns NULL)
@@ -91,10 +96,11 @@ void RAI_ActivateRunSessionCtxORT(OrtRunOptions *new_run_options, long *run_sess
9196
}
9297
OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[*run_session_index];
9398
RedisModule_Assert(*entry->runState == RUN_SESSION_AVAILABLE);
99+
RedisModule_Assert(entry->queuingTime == LLONG_MAX);
94100

95101
// Update the entry with the current session data.
96102
entry->runOptions = new_run_options;
97-
entry->queuingTime = mstime();
103+
__atomic_store_n(&(entry->queuingTime), mstime(), __ATOMIC_RELAXED);
98104
__atomic_store_n(entry->runState, RUN_SESSION_ACTIVE, __ATOMIC_RELAXED);
99105
pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock));
100106
}
@@ -104,14 +110,16 @@ void RAI_ResetRunSessionCtxORT(long run_session_index) {
104110
pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock));
105111
OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[run_session_index];
106112

107-
// Busy wait until we get a valid state, as we might access this entry from
108-
// the main thread callback and call ONNX API to terminate the run session.
109-
RunSessionState state;
110-
do {
111-
state = __atomic_load_n(entry->runState, __ATOMIC_RELAXED);
112-
} while (state != RUN_SESSION_ACTIVE && state != RUN_SESSION_TERMINATED);
113-
113+
// In most cases, state will be ACTIVE at this point, and we want to turn in to
114+
// AVAILABLE atomically, so we won't call the kill switch at the same time.
115+
if (!__sync_bool_compare_and_swap(entry->runState, RUN_SESSION_ACTIVE, RUN_SESSION_AVAILABLE)) {
116+
// If state was not ACTIVE, it is INVALID/TERMINATED, due to a timeout that
117+
// has occurred. We do busy wait until the state is set to TERMINATE.
118+
while (!__sync_bool_compare_and_swap(entry->runState, RUN_SESSION_TERMINATED,
119+
RUN_SESSION_AVAILABLE))
120+
;
121+
}
122+
__atomic_store_n(&(entry->queuingTime), LLONG_MAX, __ATOMIC_RELAXED);
114123
ort->ReleaseRunOptions(entry->runOptions);
115-
__atomic_store_n(entry->runState, RUN_SESSION_AVAILABLE, __ATOMIC_RELAXED);
116124
pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock));
117125
}

0 commit comments

Comments
 (0)