diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 6098ef465e..2ba5910575 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -98,6 +98,8 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx) NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded); NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->backend); + bh_assert(!wasi_nn_ctx->busy); + /* deinit() the backend */ if (wasi_nn_ctx->is_backend_ctx_initialized) { wasi_nn_error res; @@ -105,6 +107,7 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx) wasi_nn_ctx->backend_ctx); } + os_mutex_destroy(&wasi_nn_ctx->lock); wasm_runtime_free(wasi_nn_ctx); } @@ -150,6 +153,11 @@ wasi_nn_initialize_context() } memset(wasi_nn_ctx, 0, sizeof(WASINNContext)); + if (os_mutex_init(&wasi_nn_ctx->lock)) { + NN_ERR_PRINTF("Error when initializing a lock for WASI-NN context"); + wasm_runtime_free(wasi_nn_ctx); + return NULL; + } return wasi_nn_ctx; } @@ -176,6 +184,35 @@ wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance) return wasi_nn_ctx; } +static WASINNContext * +lock_ctx(wasm_module_inst_t instance) +{ + WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); + if (wasi_nn_ctx == NULL) { + return NULL; + } + os_mutex_lock(&wasi_nn_ctx->lock); + if (wasi_nn_ctx->busy) { + os_mutex_unlock(&wasi_nn_ctx->lock); + return NULL; + } + wasi_nn_ctx->busy = true; + os_mutex_unlock(&wasi_nn_ctx->lock); + return wasi_nn_ctx; +} + +static void +unlock_ctx(WASINNContext *wasi_nn_ctx) +{ + if (wasi_nn_ctx == NULL) { + return; + } + os_mutex_lock(&wasi_nn_ctx->lock); + bh_assert(wasi_nn_ctx->busy); + wasi_nn_ctx->busy = false; + os_mutex_unlock(&wasi_nn_ctx->lock); +} + void wasi_nn_destroy() { @@ -401,7 +438,7 @@ detect_and_load_backend(graph_encoding backend_hint, static wasi_nn_error ensure_backend(wasm_module_inst_t instance, graph_encoding encoding, - WASINNContext **wasi_nn_ctx_ptr) + WASINNContext *wasi_nn_ctx) { wasi_nn_error res; @@ -412,7 +449,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding, goto fail; } - WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); if (wasi_nn_ctx->is_backend_ctx_initialized) { if (wasi_nn_ctx->backend != loaded_backend) { res = unsupported_operation; @@ -430,7 +466,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding, wasi_nn_ctx->is_backend_ctx_initialized = true; } - *wasi_nn_ctx_ptr = wasi_nn_ctx; return success; fail: return res; @@ -458,17 +493,23 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder, if (!instance) return runtime_error; + WASINNContext *wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + graph_builder_array builder_native = { 0 }; #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 if (success != (res = graph_builder_array_app_native( instance, builder, builder_wasm_size, &builder_native))) - return res; + goto fail; #else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ if (success != (res = graph_builder_array_app_native(instance, builder, &builder_native))) - return res; + goto fail; #endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ if (!wasm_runtime_validate_native_addr(instance, g, @@ -478,8 +519,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder, goto fail; } - WASINNContext *wasi_nn_ctx; - res = ensure_backend(instance, encoding, &wasi_nn_ctx); + res = ensure_backend(instance, encoding, wasi_nn_ctx); if (res != success) goto fail; @@ -494,6 +534,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder, // XXX: Free intermediate structure pointers if (builder_native.buf) wasm_runtime_free(builder_native.buf); + unlock_ctx(wasi_nn_ctx); return res; } @@ -527,18 +568,26 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name); - WASINNContext *wasi_nn_ctx; - res = ensure_backend(instance, autodetect, &wasi_nn_ctx); + WASINNContext *wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + + res = ensure_backend(instance, autodetect, wasi_nn_ctx); if (res != success) - return res; + goto fail; call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, wasi_nn_ctx->backend_ctx, name, name_len, g); if (res != success) - return res; + goto fail; wasi_nn_ctx->is_model_loaded = true; - return success; + res = success; +fail: + unlock_ctx(wasi_nn_ctx); + return res; } wasi_nn_error @@ -576,19 +625,28 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name, NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config); - WASINNContext *wasi_nn_ctx; - res = ensure_backend(instance, autodetect, &wasi_nn_ctx); + WASINNContext *wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + + res = ensure_backend(instance, autodetect, wasi_nn_ctx); if (res != success) - return res; + goto fail; + ; call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res, wasi_nn_ctx->backend_ctx, name, name_len, config, config_len, g); if (res != success) - return res; + goto fail; wasi_nn_ctx->is_model_loaded = true; - return success; + res = success; +fail: + unlock_ctx(wasi_nn_ctx); + return res; } wasi_nn_error @@ -602,20 +660,27 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g, return runtime_error; } - WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); - wasi_nn_error res; + WASINNContext *wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + if (success != (res = is_model_initialized(wasi_nn_ctx))) - return res; + goto fail; if (!wasm_runtime_validate_native_addr( instance, ctx, (uint64)sizeof(graph_execution_context))) { NN_ERR_PRINTF("ctx is invalid"); - return invalid_argument; + res = invalid_argument; + goto fail; } call_wasi_nn_func(wasi_nn_ctx->backend, init_execution_context, res, wasi_nn_ctx->backend_ctx, g, ctx); +fail: + unlock_ctx(wasi_nn_ctx); return res; } @@ -630,17 +695,21 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx, return runtime_error; } - WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); - wasi_nn_error res; + WASINNContext *wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + if (success != (res = is_model_initialized(wasi_nn_ctx))) - return res; + goto fail; tensor input_tensor_native = { 0 }; if (success != (res = tensor_app_native(instance, input_tensor, &input_tensor_native))) - return res; + goto fail; call_wasi_nn_func(wasi_nn_ctx->backend, set_input, res, wasi_nn_ctx->backend_ctx, ctx, index, @@ -648,7 +717,8 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx, // XXX: Free intermediate structure pointers if (input_tensor_native.dimensions) wasm_runtime_free(input_tensor_native.dimensions); - +fail: + unlock_ctx(wasi_nn_ctx); return res; } @@ -662,14 +732,20 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx) return runtime_error; } - WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); - wasi_nn_error res; + WASINNContext *wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + if (success != (res = is_model_initialized(wasi_nn_ctx))) - return res; + goto fail; call_wasi_nn_func(wasi_nn_ctx->backend, compute, res, wasi_nn_ctx->backend_ctx, ctx); +fail: + unlock_ctx(wasi_nn_ctx); return res; } @@ -692,16 +768,21 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx, return runtime_error; } - WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); - wasi_nn_error res; + WASINNContext *wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + if (success != (res = is_model_initialized(wasi_nn_ctx))) - return res; + goto fail; if (!wasm_runtime_validate_native_addr(instance, output_tensor_size, (uint64)sizeof(uint32_t))) { NN_ERR_PRINTF("output_tensor_size is invalid"); - return invalid_argument; + res = invalid_argument; + goto fail; } #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 @@ -714,6 +795,8 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx, wasi_nn_ctx->backend_ctx, ctx, index, output_tensor, output_tensor_size); #endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ +fail: + unlock_ctx(wasi_nn_ctx); return res; } diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h index fcca310238..a20ad1718c 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h @@ -9,7 +9,11 @@ #include "wasi_nn_types.h" #include "wasm_export.h" +#include "bh_platform.h" + typedef struct { + korp_mutex lock; + bool busy; bool is_backend_ctx_initialized; bool is_model_loaded; graph_encoding backend;