From 7db58cbbca9d4c6cad3861abd940653735bb5087 Mon Sep 17 00:00:00 2001 From: daiki0321 Date: Sat, 31 May 2025 12:42:27 +0200 Subject: [PATCH] Support onnx runtime for wasi-nn. --- build-scripts/config_common.cmake | 7 +- core/iwasm/libraries/wasi-nn/README.md | 3 +- .../wasi-nn/cmake/Findonnxruntime.cmake | 85 ++ .../libraries/wasi-nn/cmake/wasi_nn.cmake | 27 + core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 14 + .../wasi-nn/src/wasi_nn_onnxruntime.cpp | 765 ++++++++++++++++++ core/iwasm/libraries/wasi-nn/test/utils.c | 174 ++-- 7 files changed, 999 insertions(+), 76 deletions(-) create mode 100644 core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake create mode 100644 core/iwasm/libraries/wasi-nn/src/wasi_nn_onnxruntime.cpp diff --git a/build-scripts/config_common.cmake b/build-scripts/config_common.cmake index 1a77d4cac8..1ccb990a87 100644 --- a/build-scripts/config_common.cmake +++ b/build-scripts/config_common.cmake @@ -511,7 +511,8 @@ if (WAMR_BUILD_WASI_NN EQUAL 1) # Variant backends if (NOT WAMR_BUILD_WASI_NN_TFLITE EQUAL 1 AND NOT WAMR_BUILD_WASI_NN_OPENVINO EQUAL 1 AND - NOT WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1) + NOT WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1 AND + NOT WAMR_BUILD_WASI_NN_ONNXRUNTIME EQUAL 1) message (FATAL_ERROR " Need to select a backend for WASI-NN") endif () @@ -527,6 +528,10 @@ if (WAMR_BUILD_WASI_NN EQUAL 1) message (" WASI-NN: backend llamacpp enabled") add_definitions (-DWASM_ENABLE_WASI_NN_LLAMACPP) endif () + if (WAMR_BUILD_WASI_NN_ONNXRUNTIME EQUAL 1) + message (" WASI-NN: backend onnxruntime enabled") + add_definitions (-DWASM_ENABLE_WASI_NN_ONNXRUNTIME) + endif () # Variant devices if (WAMR_BUILD_WASI_NN_ENABLE_GPU EQUAL 1) message (" WASI-NN: GPU enabled") diff --git a/core/iwasm/libraries/wasi-nn/README.md b/core/iwasm/libraries/wasi-nn/README.md index 99a7664676..7e381af787 100644 --- a/core/iwasm/libraries/wasi-nn/README.md +++ b/core/iwasm/libraries/wasi-nn/README.md @@ -26,6 +26,7 @@ $ cmake -DWAMR_BUILD_WASI_NN=1 ... - `WAMR_BUILD_WASI_NN_TFLITE`. This option designates TensorFlow Lite as the backend. - `WAMR_BUILD_WASI_NN_OPENVINO`. This option designates OpenVINO as the backend. - `WAMR_BUILD_WASI_NN_LLAMACPP`. This option designates Llama.cpp as the backend. +- `WAMR_BUILD_WASI_NN_ONNXRUNTIME`. This option designates ONNX Runtime as the backend. ### Wasm @@ -151,7 +152,7 @@ docker run \ Supported: -- Graph encoding: `tensorflowlite`, `openvino` and `ggml` +- Graph encoding: `tensorflowlite`, `openvino`, `ggml` and `onnx` - Execution target: `cpu` for all. `gpu` and `tpu` for `tensorflowlite`. - Tensor type: `fp32`. diff --git a/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake b/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake new file mode 100644 index 0000000000..f2e7649b75 --- /dev/null +++ b/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake @@ -0,0 +1,85 @@ +# Copyright (C) 2019 Intel Corporation. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Find ONNX Runtime library +# +# This module defines the following variables: +# +# :: +# +# onnxruntime_FOUND - True if onnxruntime is found +# onnxruntime_INCLUDE_DIRS - Include directories for onnxruntime +# onnxruntime_LIBRARIES - List of libraries for onnxruntime +# onnxruntime_VERSION - Version of onnxruntime +# +# :: +# +# Example usage: +# +# find_package(onnxruntime) +# if(onnxruntime_FOUND) +# target_link_libraries(app onnxruntime) +# endif() + +# First try to find ONNX Runtime using the CMake config file +find_package(onnxruntime CONFIG QUIET) +if(onnxruntime_FOUND) + message(STATUS "Found ONNX Runtime via CMake config: ${onnxruntime_DIR}") + return() +endif() + +# If not found via CMake config, try to find manually +find_path(onnxruntime_INCLUDE_DIR + NAMES onnxruntime_c_api.h + PATHS + /usr/include + /usr/local/include + /opt/onnxruntime/include + $ENV{ONNXRUNTIME_ROOT}/include + ${CMAKE_CURRENT_LIST_DIR}/../../../../.. + /home/ubuntu/onnxruntime/onnxruntime-linux-x64-1.16.3/include + PATH_SUFFIXES onnxruntime +) + +find_library(onnxruntime_LIBRARY + NAMES onnxruntime + PATHS + /usr/lib + /usr/local/lib + /opt/onnxruntime/lib + $ENV{ONNXRUNTIME_ROOT}/lib + ${CMAKE_CURRENT_LIST_DIR}/../../../../.. + /home/ubuntu/onnxruntime/onnxruntime-linux-x64-1.16.3/lib +) + +# Try to determine version from header file +if(onnxruntime_INCLUDE_DIR) + file(STRINGS "${onnxruntime_INCLUDE_DIR}/onnxruntime_c_api.h" onnxruntime_version_str + REGEX "^#define[\t ]+ORT_API_VERSION[\t ]+[0-9]+") + + if(onnxruntime_version_str) + string(REGEX REPLACE "^#define[\t ]+ORT_API_VERSION[\t ]+([0-9]+)" "\\1" + onnxruntime_VERSION "${onnxruntime_version_str}") + endif() +endif() + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(onnxruntime + REQUIRED_VARS onnxruntime_LIBRARY onnxruntime_INCLUDE_DIR + VERSION_VAR onnxruntime_VERSION +) + +if(onnxruntime_FOUND) + set(onnxruntime_LIBRARIES ${onnxruntime_LIBRARY}) + set(onnxruntime_INCLUDE_DIRS ${onnxruntime_INCLUDE_DIR}) + + if(NOT TARGET onnxruntime) + add_library(onnxruntime UNKNOWN IMPORTED) + set_target_properties(onnxruntime PROPERTIES + IMPORTED_LOCATION "${onnxruntime_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_INCLUDE_DIRS}" + ) + endif() +endif() + +mark_as_advanced(onnxruntime_INCLUDE_DIR onnxruntime_LIBRARY) diff --git a/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake b/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake index b771b1c402..65a2d4ab46 100644 --- a/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake +++ b/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake @@ -109,3 +109,30 @@ if(WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1) install(TARGETS wasi_nn_llamacpp DESTINATION lib) endif() + +# - onnxruntime +if(WAMR_BUILD_WASI_NN_ONNXRUNTIME EQUAL 1) + find_package(onnxruntime REQUIRED) + enable_language(CXX) + + add_library( + wasi_nn_onnxruntime + SHARED + ${WASI_NN_ROOT}/src/wasi_nn_onnxruntime.cpp + ) + + target_include_directories( + wasi_nn_onnxruntime + PUBLIC + ${onnxruntime_INCLUDE_DIRS} + ) + + target_link_libraries( + wasi_nn_onnxruntime + PUBLIC + vmlib + onnxruntime + ) + + install(TARGETS wasi_nn_onnxruntime DESTINATION lib) +endif() diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 4697e931b0..4ba8beaa39 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -24,6 +24,7 @@ #define TFLITE_BACKEND_LIB "libwasi_nn_tflite.so" #define 
OPENVINO_BACKEND_LIB "libwasi_nn_openvino.so"
 #define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp.so"
+#define ONNXRUNTIME_BACKEND_LIB "libwasi_nn_onnxruntime.so"
 
 /* Global variables */
 struct backends_api_functions {
@@ -212,6 +213,17 @@ choose_a_backend()
         return openvino;
     }
 
+#ifndef NDEBUG
+    NN_WARN_PRINTF("%s", dlerror());
+#endif
+
+    handle = dlopen(ONNXRUNTIME_BACKEND_LIB, RTLD_LAZY);
+    if (handle) {
+        NN_INFO_PRINTF("Using onnxruntime backend");
+        dlclose(handle);
+        return onnx;
+    }
+
 #ifndef NDEBUG
     NN_WARN_PRINTF("%s", dlerror());
 #endif
@@ -335,6 +347,8 @@ graph_encoding_to_backend_lib_name(graph_encoding encoding)
             return TFLITE_BACKEND_LIB;
         case ggml:
             return LLAMACPP_BACKEND_LIB;
+        case onnx:
+            return ONNXRUNTIME_BACKEND_LIB;
         default:
             return NULL;
     }
diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnxruntime.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnxruntime.cpp
new file mode 100644
index 0000000000..c1d1400686
--- /dev/null
+++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnxruntime.cpp
@@ -0,0 +1,765 @@
+/*
+ * Copyright (C) 2019 Intel Corporation. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ */
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <mutex>
+#include <vector>
+#include <unordered_map>
+
+#include "wasi_nn_private.h"
+#include "wasi_nn.h"
+#include "utils/logger.h"
+#include "onnxruntime_c_api.h"
+
+/* Maximum number of graphs and execution contexts */
+#define MAX_GRAPHS 10
+#define MAX_CONTEXTS 10
+
+/* ONNX Runtime context structure */
+typedef struct {
+    OrtEnv *env;
+    OrtSessionOptions *session_options;
+    OrtAllocator *allocator;
+    const OrtApi *ort_api;
+    std::mutex mutex;
+    bool is_initialized;
+} OnnxRuntimeContext;
+
+/* Graph structure */
+typedef struct {
+    OrtSession *session;
+    bool is_initialized;
+} OnnxRuntimeGraph;
+
+/* Execution context structure */
+typedef struct {
+    OrtMemoryInfo *memory_info;
+    std::vector<const char *> input_names;
+    std::vector<const char *> output_names;
+    std::unordered_map<uint32_t, OrtValue *> inputs;
+    std::unordered_map<uint32_t, OrtValue *> outputs;
+    OnnxRuntimeGraph *graph;
+    bool is_initialized;
+} OnnxRuntimeExecCtx;
+
+/* Global variables */
+static OnnxRuntimeContext g_ort_ctx;
+static OnnxRuntimeGraph g_graphs[MAX_GRAPHS];
+static OnnxRuntimeExecCtx g_exec_ctxs[MAX_CONTEXTS];
+
+/* Helper functions */
+static void
+check_status_and_log(OrtStatus *status)
+{
+    if (status != nullptr) {
+        const char *msg = g_ort_ctx.ort_api->GetErrorMessage(status);
+        NN_ERR_PRINTF("ONNX Runtime error: %s", msg);
+        g_ort_ctx.ort_api->ReleaseStatus(status);
+    }
+}
+
+static wasi_nn_error
+convert_ort_error_to_wasi_nn_error(OrtStatus *status)
+{
+    if (status == nullptr) {
+        return success;
+    }
+
+    wasi_nn_error err;
+    OrtErrorCode code = g_ort_ctx.ort_api->GetErrorCode(status);
+    const char *msg = g_ort_ctx.ort_api->GetErrorMessage(status);
+
+    NN_ERR_PRINTF("ONNX Runtime error: %s", msg);
+
+    switch (code) {
+        case ORT_INVALID_ARGUMENT:
+            err = invalid_argument;
+            break;
+        case ORT_RUNTIME_EXCEPTION:
+            err = runtime_error;
+            break;
+        case ORT_NOT_IMPLEMENTED:
+            err = unsupported_operation;
+            break;
+        case ORT_INVALID_PROTOBUF:
+            err = invalid_encoding;
+            break;
+        case ORT_MODEL_LOADED:
+            err = too_large;
+            break;
+        case ORT_INVALID_GRAPH:
+            err = invalid_encoding;
+            break;
+        default:
+            err = unknown;
+            break;
+    }
+
+    g_ort_ctx.ort_api->ReleaseStatus(status);
+    return err;
+}
+
+static tensor_type
+convert_ort_type_to_wasi_nn_type(ONNXTensorElementDataType ort_type)
+{
+    switch (ort_type) {
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
+            return fp32;
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
+            
return fp16; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return fp64; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return u8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return i32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return i64; +#else + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return up8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ip32; +#endif + default: + NN_WARN_PRINTF("Unsupported ONNX tensor type: %d", ort_type); + return fp32; // Default to fp32 + } +} + +static ONNXTensorElementDataType +convert_wasi_nn_type_to_ort_type(tensor_type type) +{ + switch (type) { + case fp32: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + case fp16: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + case fp64: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; + case u8: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + case i32: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; + case i64: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; +#else + case up8: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + case ip32: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; +#endif + default: + NN_WARN_PRINTF("Unsupported wasi-nn tensor type: %d", type); + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; // Default to float + } +} + +static size_t +get_tensor_element_size(tensor_type type) +{ + switch (type) { + case fp32: + return 4; + case fp16: + return 2; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + case fp64: + return 8; + case u8: + return 1; + case i32: + return 4; + case i64: + return 8; +#else + case up8: + return 1; + case ip32: + return 4; +#endif + default: + NN_WARN_PRINTF("Unsupported tensor type: %d", type); + return 4; // Default to 4 bytes (float) + } +} + +/* Backend API implementation */ + +extern "C" { + +__attribute__((visibility("default"))) wasi_nn_error +init_backend(void **onnx_ctx) +{ + std::lock_guard lock(g_ort_ctx.mutex); + + if (g_ort_ctx.is_initialized) { + *onnx_ctx = &g_ort_ctx; + return success; + } + + g_ort_ctx.ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + if (!g_ort_ctx.ort_api) { + NN_ERR_PRINTF("Failed to get ONNX Runtime API"); + return unknown; + } + + NN_INFO_PRINTF("Creating ONNX Runtime environment..."); + OrtStatus *status = g_ort_ctx.ort_api->CreateEnv(ORT_LOGGING_LEVEL_VERBOSE, "wasi-nn", &g_ort_ctx.env); + if (status != nullptr) { + const char* error_message = g_ort_ctx.ort_api->GetErrorMessage(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + NN_ERR_PRINTF("Failed to create ONNX Runtime environment: %s", error_message); + g_ort_ctx.ort_api->ReleaseStatus(status); + return err; + } + NN_INFO_PRINTF("ONNX Runtime environment created successfully"); + + status = g_ort_ctx.ort_api->CreateSessionOptions(&g_ort_ctx.session_options); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env); + NN_ERR_PRINTF("Failed to create ONNX Runtime session options"); + return err; + } + + status = g_ort_ctx.ort_api->SetSessionGraphOptimizationLevel(g_ort_ctx.session_options, ORT_ENABLE_BASIC); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + g_ort_ctx.ort_api->ReleaseSessionOptions(g_ort_ctx.session_options); + g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env); + NN_ERR_PRINTF("Failed to set graph optimization level"); + return err; + } + + status = g_ort_ctx.ort_api->GetAllocatorWithDefaultOptions(&g_ort_ctx.allocator); + if (status != 
nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + g_ort_ctx.ort_api->ReleaseSessionOptions(g_ort_ctx.session_options); + g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env); + NN_ERR_PRINTF("Failed to get default allocator"); + return err; + } + + for (int i = 0; i < MAX_GRAPHS; i++) { + g_graphs[i].is_initialized = false; + g_graphs[i].session = nullptr; + } + + for (int i = 0; i < MAX_CONTEXTS; i++) { + g_exec_ctxs[i].is_initialized = false; + g_exec_ctxs[i].memory_info = nullptr; + g_exec_ctxs[i].graph = nullptr; + g_exec_ctxs[i].input_names.clear(); + g_exec_ctxs[i].output_names.clear(); + g_exec_ctxs[i].inputs.clear(); + g_exec_ctxs[i].outputs.clear(); + } + + g_ort_ctx.is_initialized = true; + *onnx_ctx = &g_ort_ctx; + + NN_INFO_PRINTF("ONNX Runtime backend initialized"); + return success; +} + +__attribute__((visibility("default"))) wasi_nn_error +deinit_backend(void *onnx_ctx) +{ + OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx; + std::lock_guard lock(ctx->mutex); + + if (!ctx->is_initialized) { + return success; + } + + for (int i = 0; i < MAX_GRAPHS; i++) { + if (g_graphs[i].is_initialized) { + ctx->ort_api->ReleaseSession(g_graphs[i].session); + g_graphs[i].is_initialized = false; + } + } + + for (int i = 0; i < MAX_CONTEXTS; i++) { + if (g_exec_ctxs[i].is_initialized) { + for (auto &input : g_exec_ctxs[i].inputs) { + ctx->ort_api->ReleaseValue(input.second); + } + for (auto &output : g_exec_ctxs[i].outputs) { + ctx->ort_api->ReleaseValue(output.second); + } + ctx->ort_api->ReleaseMemoryInfo(g_exec_ctxs[i].memory_info); + g_exec_ctxs[i].is_initialized = false; + } + } + + ctx->ort_api->ReleaseSessionOptions(ctx->session_options); + ctx->ort_api->ReleaseEnv(ctx->env); + ctx->is_initialized = false; + + NN_INFO_PRINTF("ONNX Runtime backend deinitialized"); + return success; +} + +__attribute__((visibility("default"))) wasi_nn_error +load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding, + execution_target target, graph *g) +{ + if (encoding != onnx) { + NN_ERR_PRINTF("Unsupported encoding: %d", encoding); + return invalid_encoding; + } + + if (target != cpu) { + NN_ERR_PRINTF("Only CPU target is supported"); + return unsupported_operation; + } + + OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx; + std::lock_guard lock(ctx->mutex); + + int graph_index = -1; + for (int i = 0; i < MAX_GRAPHS; i++) { + if (!g_graphs[i].is_initialized) { + graph_index = i; + break; + } + } + + if (graph_index == -1) { + NN_ERR_PRINTF("Maximum number of graphs reached"); + return runtime_error; + } + + if (builder->size == 0 || builder->buf == NULL) { + NN_ERR_PRINTF("No model data provided"); + return invalid_argument; + } + + NN_INFO_PRINTF("[ONNX Runtime] Loading model of size %zu bytes...", builder->buf[0].size); + + if (builder->buf[0].size > 16) { + NN_INFO_PRINTF("Model header bytes: %02x %02x %02x %02x %02x %02x %02x %02x", + ((uint8_t*)builder->buf[0].buf)[0], ((uint8_t*)builder->buf[0].buf)[1], + ((uint8_t*)builder->buf[0].buf)[2], ((uint8_t*)builder->buf[0].buf)[3], + ((uint8_t*)builder->buf[0].buf)[4], ((uint8_t*)builder->buf[0].buf)[5], + ((uint8_t*)builder->buf[0].buf)[6], ((uint8_t*)builder->buf[0].buf)[7]); + } + + OrtStatus *status = ctx->ort_api->CreateSessionFromArray( + ctx->env, builder->buf[0].buf, builder->buf[0].size, + ctx->session_options, &g_graphs[graph_index].session); + + if (status != nullptr) { + const char* error_message = ctx->ort_api->GetErrorMessage(status); + wasi_nn_error err = 
convert_ort_error_to_wasi_nn_error(status); + NN_ERR_PRINTF("Failed to create ONNX Runtime session: %s", error_message); + ctx->ort_api->ReleaseStatus(status); + return err; + } + + NN_INFO_PRINTF("ONNX Runtime session created successfully"); + + g_graphs[graph_index].is_initialized = true; + *g = graph_index; + + NN_INFO_PRINTF("ONNX model loaded as graph %d", graph_index); + return success; +} + +__attribute__((visibility("default"))) wasi_nn_error +load_by_name(void *onnx_ctx, const char *name, graph *g) +{ + OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx; + std::lock_guard lock(ctx->mutex); + + int graph_index = -1; + for (int i = 0; i < MAX_GRAPHS; i++) { + if (!g_graphs[i].is_initialized) { + graph_index = i; + break; + } + } + + if (graph_index == -1) { + NN_ERR_PRINTF("Maximum number of graphs reached"); + return runtime_error; + } + + OrtStatus *status = ctx->ort_api->CreateSession( + ctx->env, name, ctx->session_options, &g_graphs[graph_index].session); + + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + NN_ERR_PRINTF("Failed to create ONNX Runtime session from file: %s", name); + return err; + } + + g_graphs[graph_index].is_initialized = true; + *g = graph_index; + + NN_INFO_PRINTF("ONNX model loaded from file %s as graph %d", name, graph_index); + return success; +} + +__attribute__((visibility("default"))) wasi_nn_error +init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx) +{ + if (g >= MAX_GRAPHS || !g_graphs[g].is_initialized) { + NN_ERR_PRINTF("Invalid graph handle: %d", g); + return invalid_argument; + } + + OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; + std::lock_guard lock(ort_ctx->mutex); + + int ctx_index = -1; + for (int i = 0; i < MAX_CONTEXTS; i++) { + if (!g_exec_ctxs[i].is_initialized) { + ctx_index = i; + break; + } + } + + if (ctx_index == -1) { + NN_ERR_PRINTF("Maximum number of execution contexts reached"); + return runtime_error; + } + + OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx_index]; + exec_ctx->graph = &g_graphs[g]; + + OrtStatus *status = ort_ctx->ort_api->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &exec_ctx->memory_info); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + NN_ERR_PRINTF("Failed to create CPU memory info"); + return err; + } + + size_t num_input_nodes; + status = ort_ctx->ort_api->SessionGetInputCount(exec_ctx->graph->session, &num_input_nodes); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info); + NN_ERR_PRINTF("Failed to get input count"); + return err; + } + + for (size_t i = 0; i < num_input_nodes; i++) { + char *input_name; + status = ort_ctx->ort_api->SessionGetInputName(exec_ctx->graph->session, i, ort_ctx->allocator, &input_name); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info); + NN_ERR_PRINTF("Failed to get input name"); + return err; + } + exec_ctx->input_names.push_back(input_name); + } + + size_t num_output_nodes; + status = ort_ctx->ort_api->SessionGetOutputCount(exec_ctx->graph->session, &num_output_nodes); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info); + for (const char *name : exec_ctx->input_names) { + ort_ctx->allocator->Free(ort_ctx->allocator, 
(void *)name);
+            }
+        NN_ERR_PRINTF("Failed to get output count");
+        return err;
+    }
+
+    for (size_t i = 0; i < num_output_nodes; i++) {
+        char *output_name;
+        status = ort_ctx->ort_api->SessionGetOutputName(exec_ctx->graph->session, i, ort_ctx->allocator, &output_name);
+        if (status != nullptr) {
+            wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
+            ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info);
+            for (const char *name : exec_ctx->input_names) {
+                ort_ctx->allocator->Free(ort_ctx->allocator, (void *)name);
+            }
+            NN_ERR_PRINTF("Failed to get output name");
+            return err;
+        }
+        exec_ctx->output_names.push_back(output_name);
+    }
+
+    exec_ctx->is_initialized = true;
+    *ctx = ctx_index;
+
+    NN_INFO_PRINTF("Execution context %d initialized for graph %d", ctx_index, g);
+    return success;
+}
+
+__attribute__((visibility("default"))) wasi_nn_error
+set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index, tensor *input_tensor)
+{
+    if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) {
+        NN_ERR_PRINTF("Invalid execution context handle: %d", ctx);
+        return invalid_argument;
+    }
+
+    if (index >= g_exec_ctxs[ctx].input_names.size()) {
+        NN_ERR_PRINTF("Invalid input index: %d (max: %zu)", index, g_exec_ctxs[ctx].input_names.size() - 1);
+        return invalid_argument;
+    }
+
+    OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
+    std::lock_guard lock(ort_ctx->mutex);
+    OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx];
+
+    size_t num_dims = input_tensor->dimensions->size;
+    int64_t *ort_dims = (int64_t *)malloc(num_dims * sizeof(int64_t));
+    if (!ort_dims) {
+        NN_ERR_PRINTF("Failed to allocate memory for tensor dimensions");
+        return runtime_error;
+    }
+
+    for (size_t i = 0; i < num_dims; i++) {
+        ort_dims[i] = input_tensor->dimensions->buf[i];
+    }
+
+    ONNXTensorElementDataType ort_type = convert_wasi_nn_type_to_ort_type(input_tensor->type);
+
+    OrtValue *input_value = nullptr;
+    size_t total_elements = 1;
+    for (size_t i = 0; i < num_dims; i++) {
+        total_elements *= input_tensor->dimensions->buf[i];
+    }
+
+    OrtStatus *status = ort_ctx->ort_api->CreateTensorWithDataAsOrtValue(
+        exec_ctx->memory_info, input_tensor->data,
+        get_tensor_element_size(input_tensor->type) * total_elements,
+        ort_dims, num_dims, ort_type, &input_value);
+
+    free(ort_dims);
+
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
+        NN_ERR_PRINTF("Failed to create input tensor");
+        return err;
+    }
+
+    if (exec_ctx->inputs.count(index) > 0) {
+        ort_ctx->ort_api->ReleaseValue(exec_ctx->inputs[index]);
+    }
+    exec_ctx->inputs[index] = input_value;
+
+    NN_INFO_PRINTF("Input tensor set for context %d, index %d", ctx, index);
+    return success;
+}
+
+__attribute__((visibility("default"))) wasi_nn_error
+compute(void *onnx_ctx, graph_execution_context ctx)
+{
+    if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) {
+        NN_ERR_PRINTF("Invalid execution context handle: %d", ctx);
+        return invalid_argument;
+    }
+
+    OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
+    std::lock_guard lock(ort_ctx->mutex);
+    OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx];
+
+    std::vector<OrtValue *> input_values;
+    std::vector<const char *> input_names;
+
+    for (size_t i = 0; i < exec_ctx->input_names.size(); i++) {
+        if (exec_ctx->inputs.count(i) == 0) {
+            NN_ERR_PRINTF("Input tensor not set for index %zu", i);
+            return invalid_argument;
+        }
+        input_values.push_back(exec_ctx->inputs[i]);
+        input_names.push_back(exec_ctx->input_names[i]);
+    }
+
+    for (auto &output : 
exec_ctx->outputs) {
+        ort_ctx->ort_api->ReleaseValue(output.second);
+    }
+    exec_ctx->outputs.clear();
+
+    std::vector<OrtValue *> output_values(exec_ctx->output_names.size());
+
+    OrtStatus *status = ort_ctx->ort_api->Run(
+        exec_ctx->graph->session, nullptr,
+        input_names.data(), input_values.data(), input_values.size(),
+        exec_ctx->output_names.data(), exec_ctx->output_names.size(), output_values.data());
+
+    for (size_t i = 0; i < output_values.size(); i++) {
+        exec_ctx->outputs[i] = output_values[i];
+    }
+
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
+        NN_ERR_PRINTF("Failed to run inference");
+        return err;
+    }
+
+    NN_INFO_PRINTF("Inference computed for context %d", ctx);
+    return success;
+}
+
+__attribute__((visibility("default"))) wasi_nn_error
+get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index, tensor_data out_buffer, uint32_t *out_buffer_size)
+{
+    if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) {
+        NN_ERR_PRINTF("Invalid execution context handle: %d", ctx);
+        return invalid_argument;
+    }
+
+    if (index >= g_exec_ctxs[ctx].output_names.size()) {
+        NN_ERR_PRINTF("Invalid output index: %d (max: %zu)", index, g_exec_ctxs[ctx].output_names.size() - 1);
+        return invalid_argument;
+    }
+
+    OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
+    std::lock_guard lock(ort_ctx->mutex);
+    OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx];
+
+    OrtValue *output_value = exec_ctx->outputs[index];
+    if (!output_value) {
+        NN_ERR_PRINTF("Output tensor not available for index %d", index);
+        return runtime_error;
+    }
+
+    OrtTensorTypeAndShapeInfo *tensor_info;
+    OrtStatus *status = ort_ctx->ort_api->GetTensorTypeAndShape(output_value, &tensor_info);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
+        NN_ERR_PRINTF("Failed to get tensor type and shape");
+        return err;
+    }
+
+    ONNXTensorElementDataType element_type;
+    status = ort_ctx->ort_api->GetTensorElementType(tensor_info, &element_type);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
+        ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
+        NN_ERR_PRINTF("Failed to get tensor element type");
+        return err;
+    }
+
+    size_t num_dims;
+    status = ort_ctx->ort_api->GetDimensionsCount(tensor_info, &num_dims);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
+        ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
+        NN_ERR_PRINTF("Failed to get tensor dimensions count");
+        return err;
+    }
+
+    int64_t *dims = (int64_t *)malloc(num_dims * sizeof(int64_t));
+    if (!dims) {
+        ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
+        NN_ERR_PRINTF("Failed to allocate memory for tensor dimensions");
+        return runtime_error;
+    }
+
+    status = ort_ctx->ort_api->GetDimensions(tensor_info, dims, num_dims);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
+        free(dims);
+        ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
+        NN_ERR_PRINTF("Failed to get tensor dimensions");
+        return err;
+    }
+
+    size_t tensor_size;
+    status = ort_ctx->ort_api->GetTensorShapeElementCount(tensor_info, &tensor_size);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
+        free(dims);
+        ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
+        NN_ERR_PRINTF("Failed to get tensor element count");
+        return err;
+    }
+
+    NN_INFO_PRINTF("Output tensor dimensions: ");
+    
for (size_t i = 0; i < num_dims; i++) {
+        NN_INFO_PRINTF("  dim[%zu] = %lld", i, (long long)dims[i]);
+    }
+    NN_INFO_PRINTF("Total elements: %zu", tensor_size);
+
+    ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
+    free(dims);
+
+    if (tensor_size == 0) {
+        NN_ERR_PRINTF("Tensor is empty (zero elements)");
+        return runtime_error;
+    }
+
+    void *tensor_data = nullptr;
+    status = ort_ctx->ort_api->GetTensorMutableData(output_value, &tensor_data);
+    if (status != nullptr) {
+        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
+        NN_ERR_PRINTF("Failed to get tensor data");
+        return err;
+    }
+
+    if (tensor_data == nullptr) {
+        NN_ERR_PRINTF("Tensor data pointer is null");
+        return runtime_error;
+    }
+
+    size_t element_size;
+    switch (element_type) {
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
+            element_size = sizeof(float);
+            break;
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
+            element_size = sizeof(uint16_t);
+            break;
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
+            element_size = sizeof(double);
+            break;
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
+            element_size = sizeof(int32_t);
+            break;
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
+            element_size = sizeof(int64_t);
+            break;
+        case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
+            element_size = sizeof(uint8_t);
+            break;
+        default:
+            NN_ERR_PRINTF("Unsupported tensor element type: %d", element_type);
+            return unsupported_operation;
+    }
+
+    size_t output_size_bytes = tensor_size * element_size;
+
+    NN_INFO_PRINTF("Output tensor size: %zu elements, element size: %zu bytes, total: %zu bytes",
+                   tensor_size, element_size, output_size_bytes);
+
+    if (*out_buffer_size < output_size_bytes) {
+        NN_ERR_PRINTF("Output buffer too small: %u bytes provided, %zu bytes needed",
+                      *out_buffer_size, output_size_bytes);
+        *out_buffer_size = output_size_bytes;
+        return invalid_argument;
+    }
+
+    if (tensor_data == nullptr) {
+        NN_ERR_PRINTF("Tensor data is null");
+        return runtime_error;
+    }
+
+    if (out_buffer == nullptr) {
+        NN_ERR_PRINTF("Output buffer is null");
+        return invalid_argument;
+    }
+
+    memcpy(out_buffer, tensor_data, output_size_bytes);
+    *out_buffer_size = output_size_bytes;
+
+    NN_INFO_PRINTF("Output tensor retrieved for context %d, index %d, size %zu bytes",
+                   ctx, index, output_size_bytes);
+    return success;
+}
+
+} /* End of extern "C" */
diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c
index 9e43ec9854..68dbd9fa9c 100644
--- a/core/iwasm/libraries/wasi-nn/test/utils.c
+++ b/core/iwasm/libraries/wasi-nn/test/utils.c
@@ -3,62 +3,90 @@
  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  */
 
-#include "utils.h"
-#include "logger.h"
-#include "wasi_nn.h"
-
 #include <stdio.h>
 #include <stdlib.h>
+#include <string.h>
+#include <stdint.h>
+
+#include "wasi_nn.h"
+#include "utils.h"
+
+#define NN_ERR_PRINTF(...)   \
+    do {                     \
+        printf("Error: ");   \
+        printf(__VA_ARGS__); \
+        printf("\n");        \
+    } while (0)
+
+#define NN_WARN_PRINTF(...)  \
+    do {                     \
+        printf("Warning: "); \
+        printf(__VA_ARGS__); \
+        printf("\n");        \
+    } while (0)
+
+#define NN_INFO_PRINTF(...) 
\ + do { \ + printf("Info: "); \ + printf(__VA_ARGS__); \ + printf("\n"); \ + } while (0) wasi_nn_error wasm_load(char *model_name, graph *g, execution_target target) { - FILE *pFile = fopen(model_name, "r"); - if (pFile == NULL) + FILE *pFile = fopen(model_name, "rb"); + if (pFile == NULL) { + NN_ERR_PRINTF("Error opening file %s", model_name); return invalid_argument; + } - uint8_t *buffer; - size_t result; + fseek(pFile, 0, SEEK_END); + long lSize = ftell(pFile); + rewind(pFile); - // allocate memory to contain the whole file: - buffer = (uint8_t *)malloc(sizeof(uint8_t) * MAX_MODEL_SIZE); - if (buffer == NULL) { + if (lSize > MAX_MODEL_SIZE) { + NN_ERR_PRINTF("Model size too large: %ld", lSize); fclose(pFile); return too_large; } - result = fread(buffer, 1, MAX_MODEL_SIZE, pFile); - if (result <= 0) { + uint8_t *buffer = (uint8_t *)malloc(lSize); + if (buffer == NULL) { + NN_ERR_PRINTF("Memory allocation error"); fclose(pFile); - free(buffer); - return too_large; + return runtime_error; } - graph_builder_array arr; - - arr.size = 1; - arr.buf = (graph_builder *)malloc(sizeof(graph_builder)); - if (arr.buf == NULL) { + size_t result = fread(buffer, 1, lSize, pFile); + if (result != lSize) { + NN_ERR_PRINTF("Error reading file"); fclose(pFile); free(buffer); - return too_large; + return runtime_error; } - arr.buf[0].size = result; - arr.buf[0].buf = buffer; + graph_builder builder; + builder.buf = buffer; + builder.size = result; - wasi_nn_error res = load(&arr, tensorflowlite, target, g); + graph_builder_array arr; + arr.size = 1; + arr.buf = &builder; + + graph_encoding encoding = tensorflowlite; + const char *ext = strrchr(model_name, '.'); + if (ext && strcmp(ext, ".onnx") == 0) { + encoding = onnx; + } else if (ext && strcmp(ext, ".tflite") == 0) { + encoding = tensorflowlite; + } + + wasi_nn_error res = load(&arr, encoding, target, g); fclose(pFile); free(buffer); - free(arr.buf); - return res; -} -wasi_nn_error -wasm_load_by_name(const char *model_name, graph *g) -{ - wasi_nn_error res = load_by_name(model_name, g); return res; } @@ -73,20 +101,14 @@ wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim) { tensor_dimensions dims; dims.size = INPUT_TENSOR_DIMS; - dims.buf = (uint32_t *)malloc(dims.size * sizeof(uint32_t)); - if (dims.buf == NULL) - return too_large; + dims.buf = dim; - tensor tensor; - tensor.dimensions = &dims; - for (int i = 0; i < tensor.dimensions->size; ++i) - tensor.dimensions->buf[i] = dim[i]; - tensor.type = fp32; - tensor.data = (uint8_t *)input_tensor; - wasi_nn_error err = set_input(ctx, 0, &tensor); + tensor input; + input.dimensions = &dims; + input.type = fp32; + input.data = (uint8_t *)input_tensor; - free(dims.buf); - return err; + return set_input(ctx, 0, &input); } wasi_nn_error @@ -107,44 +129,48 @@ run_inference(execution_target target, float *input, uint32_t *input_size, uint32_t *output_size, char *model_name, uint32_t num_output_tensors) { - graph graph; - if (wasm_load(model_name, &graph, target) != success) { - NN_ERR_PRINTF("Error when loading model."); - exit(1); + graph g; + wasi_nn_error err = wasm_load(model_name, &g, target); + if (err != success) { + NN_ERR_PRINTF("Error when loading model: %d.", err); + return NULL; } graph_execution_context ctx; - if (wasm_init_execution_context(graph, &ctx) != success) { - NN_ERR_PRINTF("Error when initialixing execution context."); - exit(1); + err = wasm_init_execution_context(g, &ctx); + if (err != success) { + NN_ERR_PRINTF("Error when initializing execution context: 
%d.", err); + return NULL; } - if (wasm_set_input(ctx, input, input_size) != success) { - NN_ERR_PRINTF("Error when setting input tensor."); - exit(1); + err = wasm_set_input(ctx, input, input_size); + if (err != success) { + NN_ERR_PRINTF("Error when setting input tensor: %d.", err); + return NULL; } - if (wasm_compute(ctx) != success) { - NN_ERR_PRINTF("Error when running inference."); - exit(1); + err = wasm_compute(ctx); + if (err != success) { + NN_ERR_PRINTF("Error when computing: %d.", err); + return NULL; } - float *out_tensor = (float *)malloc(sizeof(float) * MAX_OUTPUT_TENSOR_SIZE); + float *out_tensor = (float *)malloc(MAX_OUTPUT_TENSOR_SIZE); if (out_tensor == NULL) { - NN_ERR_PRINTF("Error when allocating memory for output tensor."); - exit(1); + NN_ERR_PRINTF("Memory allocation error"); + return NULL; } uint32_t offset = 0; for (int i = 0; i < num_output_tensors; ++i) { - *output_size = MAX_OUTPUT_TENSOR_SIZE - *output_size; - if (wasm_get_output(ctx, i, &out_tensor[offset], output_size) + uint32_t remaining_size = MAX_OUTPUT_TENSOR_SIZE - (offset * sizeof(float)); + if (wasm_get_output(ctx, i, &out_tensor[offset], &remaining_size) != success) { NN_ERR_PRINTF("Error when getting index %d.", i); break; } - offset += *output_size; + offset += remaining_size / sizeof(float); } *output_size = offset; return out_tensor; @@ -153,18 +179,18 @@ run_inference(execution_target target, float *input, uint32_t *input_size, input_info create_input(int *dims) { - input_info input = { .dim = NULL, .input_tensor = NULL, .elements = 1 }; - - input.dim = malloc(INPUT_TENSOR_DIMS * sizeof(uint32_t)); - if (input.dim) - for (int i = 0; i < INPUT_TENSOR_DIMS; ++i) { - input.dim[i] = dims[i]; - input.elements *= dims[i]; - } + input_info info; + uint32_t elements = 1; + for (int i = 0; i < INPUT_TENSOR_DIMS; ++i) { + elements *= dims[i]; + } - input.input_tensor = malloc(input.elements * sizeof(float)); - for (int i = 0; i < input.elements; ++i) - input.input_tensor[i] = i; + info.input_tensor = (float *)malloc(elements * sizeof(float)); + info.dim = (uint32_t *)malloc(INPUT_TENSOR_DIMS * sizeof(uint32_t)); + for (int i = 0; i < INPUT_TENSOR_DIMS; ++i) { + info.dim[i] = dims[i]; + } + info.elements = elements; - return input; + return info; }