Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions src/backends/onnx_timeout.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ int RAI_InitGlobalRunSessionsORT() {
OnnxRunSessionCtx **run_sessions_array =
array_new(OnnxRunSessionCtx *, RAI_working_threads_num);
for (size_t i = 0; i < RAI_working_threads_num; i++) {
OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx));
entry->runState = RedisModule_Calloc(1, sizeof(entry->runState));
OnnxRunSessionCtx *entry = RedisModule_Alloc(sizeof(OnnxRunSessionCtx));
entry->runState = RedisModule_Alloc(sizeof(entry->runState));
*entry->runState = RUN_SESSION_AVAILABLE;
entry->queuingTime = LLONG_MAX;
run_sessions_array = array_append(run_sessions_array, entry);
}
onnx_global_run_sessions->OnnxRunSessions = run_sessions_array;
Expand All @@ -44,9 +45,10 @@ int RAI_AddNewDeviceORT(const char *device_str) {
// initialized to NULL.
size_t size = RedisAI_GetNumThreadsPerQueue();
for (size_t i = 0; i < size; i++) {
OnnxRunSessionCtx *entry = RedisModule_Calloc(1, sizeof(OnnxRunSessionCtx));
entry->runState = RedisModule_Calloc(1, sizeof(entry->runState));
OnnxRunSessionCtx *entry = RedisModule_Alloc(sizeof(OnnxRunSessionCtx));
entry->runState = RedisModule_Alloc(sizeof(entry->runState));
*entry->runState = RUN_SESSION_AVAILABLE;
entry->queuingTime = LLONG_MAX;
run_sessions_array = array_append(run_sessions_array, entry);
}
onnx_global_run_sessions->OnnxRunSessions = run_sessions_array;
Expand All @@ -65,7 +67,10 @@ void RAI_EnforceTimeoutORT(RedisModuleCtx *ctx, RedisModuleEvent eid, uint64_t s
long long timeout = RedisAI_GetModelExecutionTimeout();
for (size_t i = 0; i < len; i++) {
// Check if a sessions is running for too long, and kill it if is still active.
if (curr_time - run_sessions_ctx[i]->queuingTime > timeout) {
// If entry doesn't contain active session, its queueing time is LLONG_MAX
// (thus the following condition will always be evaluated as false)
if (curr_time - __atomic_load_n(&(run_sessions_ctx[i]->queuingTime), __ATOMIC_RELAXED) >
timeout) {
if (__sync_bool_compare_and_swap(run_sessions_ctx[i]->runState, RUN_SESSION_ACTIVE,
RUN_SESSION_INVALID)) {
// Set termination flag, validate that ONNX API succeeded (returns NULL)
Expand All @@ -91,10 +96,11 @@ void RAI_ActivateRunSessionCtxORT(OrtRunOptions *new_run_options, long *run_sess
}
OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[*run_session_index];
RedisModule_Assert(*entry->runState == RUN_SESSION_AVAILABLE);
RedisModule_Assert(entry->queuingTime == LLONG_MAX);

// Update the entry with the current session data.
entry->runOptions = new_run_options;
entry->queuingTime = mstime();
__atomic_store_n(&(entry->queuingTime), mstime(), __ATOMIC_RELAXED);
__atomic_store_n(entry->runState, RUN_SESSION_ACTIVE, __ATOMIC_RELAXED);
pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock));
}
Expand All @@ -104,14 +110,16 @@ void RAI_ResetRunSessionCtxORT(long run_session_index) {
pthread_rwlock_rdlock(&(onnx_global_run_sessions->rwlock));
OnnxRunSessionCtx *entry = onnx_global_run_sessions->OnnxRunSessions[run_session_index];

// Busy wait until we get a valid state, as we might access this entry from
// the main thread callback and call ONNX API to terminate the run session.
RunSessionState state;
do {
state = __atomic_load_n(entry->runState, __ATOMIC_RELAXED);
} while (state != RUN_SESSION_ACTIVE && state != RUN_SESSION_TERMINATED);

// In most cases, state will be ACTIVE at this point, and we want to turn in to
// AVAILABLE atomically, so we won't call the kill switch at the same time.
if (!__sync_bool_compare_and_swap(entry->runState, RUN_SESSION_ACTIVE, RUN_SESSION_AVAILABLE)) {
// If state was not ACTIVE, it is INVALID/TERMINATED, due to a timeout that
// has occurred. We do busy wait until the state is set to TERMINATE.
while (!__sync_bool_compare_and_swap(entry->runState, RUN_SESSION_TERMINATED,
RUN_SESSION_AVAILABLE))
;
}
__atomic_store_n(&(entry->queuingTime), LLONG_MAX, __ATOMIC_RELAXED);
ort->ReleaseRunOptions(entry->runOptions);
__atomic_store_n(entry->runState, RUN_SESSION_AVAILABLE, __ATOMIC_RELAXED);
pthread_rwlock_unlock(&(onnx_global_run_sessions->rwlock));
}