diff --git a/CMakeLists.txt b/CMakeLists.txt
index 40af736..50751e2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -18,6 +18,33 @@ set(WLLAMA_SRC cpp/wllama.cpp
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/cpp)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/cpp/helpers)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/llama.cpp/include)
+include_directories(${CMAKE_CURRENT_SOURCE_DIR}/llama.cpp/examples/llava)
+#
+# multimodal
+#
+add_library(mtmd OBJECT
+  llama.cpp/examples/llava/mtmd.cpp
+  llama.cpp/examples/llava/mtmd.h
+  llama.cpp/examples/llava/clip.cpp
+  llama.cpp/examples/llava/clip.h
+  llama.cpp/examples/llava/clip-impl.h
+  )
+
+target_link_libraries(mtmd PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
+
+target_include_directories(mtmd PUBLIC llama.cpp/examples/llava)
+target_include_directories(mtmd PRIVATE llama.cpp)
+target_include_directories(mtmd PRIVATE llama.cpp/common) # for stb_image.h
+
+target_compile_features(mtmd PRIVATE cxx_std_17)
+target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h
+target_compile_options(mtmd PRIVATE -Wno-c++11-narrowing) # clip.cpp
+
+add_library(mtmd_static STATIC $<TARGET_OBJECTS:mtmd>)
+
+#
+# wllama target definition
+#
 add_executable(wllama ${WLLAMA_SRC})
-target_link_libraries(wllama PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(wllama PRIVATE ggml llama mtmd_static ${CMAKE_THREAD_LIBS_INIT})
diff --git a/cpp/actions.hpp b/cpp/actions.hpp
index 662dad5..8e6cd12 100644
--- a/cpp/actions.hpp
+++ b/cpp/actions.hpp
@@ -6,8 +6,12 @@
 #include
 #include
 #include
+#include <set>
+#include <unordered_map>
 #include "llama.h"
+#include "mtmd.h"
+#include "utils.hpp"
 #include "helpers/wcommon.h"
 #include "helpers/wsampling.h"
@@ -20,13 +24,47 @@ struct app_t
 {
-  llama_model *model;
-  llama_context *ctx;
+  llama_model *model = nullptr;
+  llama_context *ctx = nullptr;
   const llama_vocab *vocab;
   wcommon_sampler *ctx_sampling = nullptr;
   llama_batch batch = llama_batch_init(512, 0, 1);
   llama_tokens tokens;
+  int32_t n_batch;
   int32_t seed = LLAMA_DEFAULT_SEED;
+
+  // multimodal
+  mtmd_context *mctx = nullptr;
+  llama_token next_img_id = -1000; // count backwards: -1001, -1002, -1003, ...
+  std::unordered_map<llama_token, mtmd_image_tokens_ptr> img_cache;
+
+  llama_token find_cached_image_tokens_by_id(const std::string &id)
+  {
+    for (auto it = img_cache.begin(); it != img_cache.end(); ++it)
+    {
+      if (mtmd_image_tokens_get_id(it->second.get()) == id)
+        return it->first;
+    }
+    return LLAMA_TOKEN_NULL;
+  }
+
+  // remove unused images
+  void update_img_cache()
+  {
+    std::set<llama_token> active_img_ids;
+    for (auto &t : tokens)
+    {
+      if (t < 0)
+        active_img_ids.insert(t);
+    }
+    for (auto it = img_cache.begin(); it != img_cache.end();)
+    {
+      if (active_img_ids.find(it->first) == active_img_ids.end())
+        it = img_cache.erase(it);
+      else
+        ++it;
+    }
+  }
 };

 inline std::vector convert_string_to_buf(std::string &input)
@@ -211,6 +249,7 @@ glue_msg_load_res action_load(app_t &app, const char *req_raw)
   }
   // load model
+  app.n_batch = cparams.n_batch;
   app.model = llama_model_load_from_splits(
       model_paths_ptrs.data(), model_paths_ptrs.size(), mparams);
   if (app.model == nullptr)
@@ -260,8 +299,22 @@ glue_msg_load_res action_load(app_t &app, const char *req_raw)
   }
   kv_dump metadata = dump_metadata(app);
+  if (req.mmproj_path.not_null() && !req.mmproj_path.value.empty())
+  {
+    mtmd_context_params mparams;
+    mparams.use_gpu = false; // not yet supported
+    mparams.n_threads = cparams.n_threads;
+    app.mctx = mtmd_init_from_file(req.mmproj_path.value.c_str(), app.model, mparams);
+    if (app.mctx == nullptr)
+    {
+      free_all(app);
+      throw app_exception("Error while loading multimodal projector");
+    }
+  }
+
   glue_msg_load_res res;
   res.success.value = true;
+  res.has_mtmd.value = app.mctx != nullptr;
   res.n_ctx.value = cparams.n_ctx;
   res.n_batch.value = llama_n_batch(app.ctx);
   res.n_ubatch.value = llama_n_ubatch(app.ctx);
@@ -420,12 +473,66 @@ glue_msg_lookup_token_res action_lookup_token(app_t &app, const char *req_raw)
 }
 // tokenize an input string
+// if the input contains images, the preprocessed images will be cached until action_kv_remove/clear is called
+// images NOT being decoded will be cleared from the cache
 glue_msg_tokenize_res action_tokenize(app_t &app, const char *req_raw)
 {
   PARSE_REQ(glue_msg_tokenize_req);
   std::string &text = req.text.value;
   bool special = req.special.value;
-  llama_tokens tokens_list = wcommon_tokenize(app.vocab, text, false, special);
+  llama_tokens tokens_list;
+  if (app.mctx)
+  {
+    // multimodal
+    std::vector<mtmd_bitmap> mbitmaps;
+    for (size_t i = 0; i < req.bitmaps.arr.size(); i++)
+    {
+      mtmd_bitmap bitmap;
+      std::vector<unsigned char> tmp(req.bitmaps.arr[i].begin(), req.bitmaps.arr[i].end());
+      bitmap.id = std::to_string(fnv_hash(tmp.data(), tmp.size()));
+      bitmap.data = std::move(tmp);
+      bitmap.nx = req.bitmaps_x.arr[i];
+      bitmap.ny = req.bitmaps_y.arr[i];
+      mbitmaps.push_back(bitmap);
+    }
+    mtmd_input_text mtext{ /* text */ text, /* add_special */ false, /* parse_special */ special, };
+    mtmd_input_chunks chunks;
+    int32_t res = mtmd_tokenize(app.mctx, chunks, mtext, mbitmaps);
+    if (res != 0)
+    {
+      glue_msg_tokenize_res res;
+      res.success.value = false;
+      res.message.value = "mtmd_tokenize failed";
+      return res;
+    }
+    for (auto &chunk : chunks)
+    {
+      if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT)
+      {
+        tokens_list.insert(tokens_list.end(), chunk.tokens_text.begin(), chunk.tokens_text.end());
+      }
+      else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE)
+      {
+        // if the image is already in the cache, use the cached image id
+        std::string id = mtmd_image_tokens_get_id(chunk.tokens_image.get());
+        llama_token cached_id = app.find_cached_image_tokens_by_id(id);
+        llama_token img_id = cached_id == LLAMA_TOKEN_NULL ? app.next_img_id-- : cached_id;
+        for (size_t i = 0; i < mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get()); i++)
+          tokens_list.push_back(img_id);
+        if (cached_id == LLAMA_TOKEN_NULL)
+          app.img_cache[img_id] = std::move(chunk.tokens_image); // add to cache
+      }
+    }
+  }
+  else
+  {
+    // text only
+    tokens_list = wcommon_tokenize(app.vocab, text, false, special);
+  }
   glue_msg_tokenize_res res;
   res.success.value = true;
@@ -441,7 +548,9 @@ glue_msg_detokenize_res action_detokenize(app_t &app, const char *req_raw)
   std::stringstream output;
   for (auto id : tokens)
   {
-    output << wcommon_token_to_piece(app.ctx, id);
+    // TODO: if we encounter an image token, we should replace it with the image marker
+    if (id >= 0)
+      output << wcommon_token_to_piece(app.ctx, id);
   }
   std::string parsed_str = output.str();
@@ -459,9 +568,16 @@ glue_msg_decode_res action_decode(app_t &app, const char *req_raw)
   bool skip_logits = req.skip_logits.value;
   size_t i = 0;
   wcommon_batch_clear(app.batch);
+  glue_msg_decode_res res;
   for (auto id : tokens_list)
   {
-    bool grp_attn_enabled = false; // TODO: maybe remove grp_attn
+    if (id < 0)
+    {
+      res.success.value = false;
+      res.message.value = "image token not supported, use action_eval_image instead";
+      res.n_past.value = app.tokens.size();
+      return res;
+    }
     int32_t n_past = app.tokens.size();
     wcommon_batch_add(app.batch, id, n_past, {0}, false);
     app.tokens.push_back(id);
@@ -472,7 +588,6 @@ glue_msg_decode_res action_decode(app_t &app, const char *req_raw)
   {
     app.batch.logits[app.batch.n_tokens - 1] = true;
   }
-  glue_msg_decode_res res;
   if (llama_decode(app.ctx, app.batch) != 0)
   {
     res.success.value = false;
@@ -521,6 +636,66 @@ glue_msg_encode_res action_encode(app_t &app, const char *req_raw)
   return res;
 }
+glue_msg_eval_image_res action_eval_image(app_t &app, const char *req_raw)
+{
+  PARSE_REQ(glue_msg_eval_image_req);
+  glue_msg_eval_image_res res;
+  llama_pos n_past = app.tokens.size();
+  res.n_past.value = n_past;
+  if (app.mctx == nullptr)
+  {
+    res.success.value = false;
+    res.message.value = "this model does not have a multimodal projector";
+    return res;
+  }
+  llama_token img_id = req.cached_image_id.value;
+  auto it = app.img_cache.find(img_id);
+  if (it != app.img_cache.end())
+  {
+    size_t n_tokens = mtmd_image_tokens_get_n_tokens(it->second.get());
+    // TODO: we should get rid of this ugly code, by having mtmd_helper_eval accept a single chunk
+    mtmd_input_chunks chunks;
+    {
+      mtmd_input_chunk chunk0{ /* type */ MTMD_INPUT_CHUNK_TYPE_IMAGE, /* tokens_text */ {}, /* tokens_image */ std::move(it->second), };
+      mtmd_input_chunk chunk1{ /* type */ MTMD_INPUT_CHUNK_TYPE_TEXT, /* tokens_text */ {}, /* tokens_image */ nullptr, };
+      chunks.emplace_back(std::move(chunk0));
+      chunks.emplace_back(std::move(chunk1));
+    }
+    int32_t result = mtmd_helper_eval(app.mctx, app.ctx, chunks, n_past, 0, app.n_batch);
+    // remember to move it back to cache
+    // TODO: maybe we just need to keep the image id in the cache, not the whole image?
+    app.img_cache[img_id] = std::move(chunks[0].tokens_image);
+    if (result != 0)
+    {
+      res.success.value = false;
+      res.message.value = "mtmd_helper_eval() failed with status " + std::to_string(result);
+    }
+    else
+    {
+      for (size_t i = 0; i < n_tokens; i++)
+      {
+        app.tokens.push_back(img_id);
+      }
+      res.n_past.value = app.tokens.size();
+      res.success.value = true;
+    }
+  }
+  else
+  {
+    res.success.value = false;
+    res.message.value = "image not found in cache (maybe already removed by action_kv_remove/clear?)";
+  }
+  return res;
+}
+
 // decode the current logits and sample the new token
 glue_msg_sampling_sample_res action_sampling_sample(app_t &app, const char *req_raw)
 {
@@ -650,15 +825,22 @@ glue_msg_get_kv_remove_res action_kv_remove(app_t &app, const char *req_raw)
   const int n_keep = req.n_keep.value;
   const int n_discard = req.n_discard.value;
+  glue_msg_get_kv_remove_res res;
+
+  // TODO: make sure n_keep is not in the middle of an image
+
   if (n_discard > 0)
   {
     // TODO: this code branch is kinda broken, to be fixed later
-    const int n_past = app.tokens.size();
-    llama_kv_self_seq_rm(app.ctx, 0, n_keep, n_keep + n_discard);
-    llama_kv_self_seq_add(app.ctx, 0, n_keep + n_discard, n_past, -n_discard);
-    app.tokens.erase(
-        app.tokens.begin() + n_keep,
-        app.tokens.begin() + n_keep + n_discard);
+    // const int n_past = app.tokens.size();
+    // llama_kv_self_seq_rm(app.ctx, 0, n_keep, n_keep + n_discard);
+    // llama_kv_self_seq_add(app.ctx, 0, n_keep + n_discard, n_past, -n_discard);
+    // app.tokens.erase(
+    //     app.tokens.begin() + n_keep,
+    //     app.tokens.begin() + n_keep + n_discard);
+    res.success.value = false;
+    res.message.value = "n_discard > 0 is not supported yet";
+    return res;
   }
   else if (n_discard < 0)
   {
@@ -675,7 +857,6 @@ glue_msg_get_kv_remove_res action_kv_remove(app_t &app, const char *req_raw)
     }
   }
-  glue_msg_get_kv_remove_res res;
   res.success.value = true;
   res.n_past.value = app.tokens.size();
   return res;
@@ -694,6 +875,17 @@ glue_msg_get_kv_clear_res action_kv_clear(app_t &app, const char *req_raw)
   return res;
 }
+// remove dangling images in the cache
+glue_msg_img_cache_update_res action_img_cache_update(app_t &app, const char *req_raw)
+{
+  PARSE_REQ(glue_msg_img_cache_update_req);
+  app.update_img_cache();
+
+  glue_msg_img_cache_update_res res;
+  res.success.value = true;
+  return res;
+}
+
 /*
 // save current session
 json action_session_save(app_t &app, json &body)
@@ -756,6 +948,10 @@ glue_msg_status_res action_current_status(app_t &app, const char *req_raw)
   return res;
 }
+//
+// multimodal
+//
+
 //
 // benchmark & perplexity
 //
diff --git a/cpp/glue.hpp b/cpp/glue.hpp
index 2cdc81b..f1e8a49 100644
--- a/cpp/glue.hpp
+++ b/cpp/glue.hpp
@@ -491,6 +491,7 @@ struct glue_msg_load_req
 {
   GLUE_HANDLER("load_req")
   GLUE_FIELD(arr_str, model_paths)
+  GLUE_FIELD(str, mmproj_path)
   GLUE_FIELD(bool, n_ctx_auto)
   GLUE_FIELD(bool, use_mmap)
   GLUE_FIELD(bool, use_mlock)
@@ -519,6 +520,7 @@ struct glue_msg_load_res
 {
   GLUE_HANDLER("load_res")
   GLUE_FIELD(bool, success)
+  GLUE_FIELD(int, has_mtmd)
   GLUE_FIELD(int, n_ctx)
   GLUE_FIELD(int, n_batch)
   GLUE_FIELD(int, n_ubatch)
@@ -623,12 +625,16 @@ struct glue_msg_tokenize_req
   GLUE_HANDLER("tokn_req")
   GLUE_FIELD(str, text)
   GLUE_FIELD(bool, special)
+  GLUE_FIELD(arr_raw, bitmaps)
+  GLUE_FIELD(arr_int, bitmaps_x)
+  GLUE_FIELD(arr_int, bitmaps_y)
 };

 struct glue_msg_tokenize_res
 {
   GLUE_HANDLER("tokn_res")
   GLUE_FIELD(bool, success)
+  GLUE_FIELD(str, message)
   GLUE_FIELD(arr_int, tokens)
 };

@@ -682,6 +688,22 @@ struct glue_msg_encode_res

 /////////

+struct glue_msg_eval_image_req
+{
+  GLUE_HANDLER("eimg_req")
+  GLUE_FIELD(int, cached_image_id)
+};
+
+struct glue_msg_eval_image_res
+{
+  GLUE_HANDLER("eimg_res")
+  GLUE_FIELD(bool, success)
+  GLUE_FIELD(str, message)
+  GLUE_FIELD(int, n_past)
+};
+
+/////////
+
 struct glue_msg_sampling_sample_req
 {
   GLUE_HANDLER("ssam_req")
@@ -755,6 +777,7 @@ struct glue_msg_get_kv_remove_res
   GLUE_HANDLER("kvcr_res")
   GLUE_FIELD(int, n_past)
   GLUE_FIELD(bool, success)
+  GLUE_FIELD(str, message)
 };

 /////////

@@ -773,6 +796,19 @@ struct glue_msg_get_kv_clear_res

 /////////

+struct glue_msg_img_cache_update_req
+{
+  GLUE_HANDLER("icud_req")
+};
+
+struct glue_msg_img_cache_update_res
+{
+  GLUE_HANDLER("icud_res")
+  GLUE_FIELD(bool, success)
+};
+
+/////////
+
 struct glue_msg_session_save_req
 {
   GLUE_HANDLER("sesa_req")
diff --git a/cpp/utils.hpp b/cpp/utils.hpp
new file mode 100644
index 0000000..1a1bec5
--- /dev/null
+++ b/cpp/utils.hpp
@@ -0,0 +1,24 @@
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+#include "llama.h"
+#include "mtmd.h"
+
+//
+// utils
+//
+
+static uint64_t fnv_hash(const uint8_t *data, size_t len)
+{
+  const uint64_t fnv_prime = 0x100000001b3ULL;
+  uint64_t hash = 0xcbf29ce484222325ULL;
+
+  for (size_t i = 0; i < len; ++i)
+  {
+    hash ^= data[i];
+    hash *= fnv_prime;
+  }
+  return hash;
+}
diff --git a/cpp/wllama.cpp b/cpp/wllama.cpp
index 64a1e1d..2db7554 100644
--- a/cpp/wllama.cpp
+++ b/cpp/wllama.cpp
@@ -112,11 +112,13 @@ extern "C" const char *wllama_action(const char *name, const char *req_raw)
   WLLAMA_ACTION(detokenize)
   WLLAMA_ACTION(decode)
   WLLAMA_ACTION(encode)
+  WLLAMA_ACTION(eval_image)
   WLLAMA_ACTION(get_logits)
   WLLAMA_ACTION(embeddings)
   WLLAMA_ACTION(chat_format)
   WLLAMA_ACTION(kv_remove)
   WLLAMA_ACTION(kv_clear)
+  WLLAMA_ACTION(img_cache_update)
   WLLAMA_ACTION(current_status)
   // WLLAMA_ACTION(session_save)
   // WLLAMA_ACTION(session_load)
diff --git a/examples/multimodal/bliss.png b/examples/multimodal/bliss.png
new file mode 100644
index 0000000..22ceab5
Binary files /dev/null and b/examples/multimodal/bliss.png differ
diff --git a/examples/multimodal/index.html b/examples/multimodal/index.html
new file mode 100644
index 0000000..466fce7
--- /dev/null
+++ b/examples/multimodal/index.html
@@ -0,0 +1,106 @@
[The 106 lines of markup for this new demo page were lost during extraction; what survives is its title ("wllama.cpp multimodal demo"), a "Completions" button, and a "Completion:" output area.]
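Since the demo page itself cannot be recovered, the following is only an illustrative sketch (not the original file) of how a page like this could drive the API added in this PR. The import path, model URLs, element ID, and prompt are placeholders; `loadModel({ mmproj })`, `createCompletion({ images })`, and the bundled bliss.png come from the rest of the diff.

// hypothetical demo driver — paths and URLs are assumptions
import { Wllama } from '../../esm/index.js';

const WASM_PATHS = {
  'single-thread/wllama.wasm': '../../esm/single-thread/wllama.wasm',
  'multi-thread/wllama.wasm': '../../esm/multi-thread/wllama.wasm',
};

async function main() {
  const wllama = new Wllama(WASM_PATHS);
  // fetch the text model and the multimodal projector as blobs (URLs are placeholders)
  const [modelBlob, mmprojBlob] = await Promise.all([
    fetch('model.gguf').then((r) => r.blob()),
    fetch('mmproj.gguf').then((r) => r.blob()),
  ]);
  await wllama.loadModel([modelBlob], { mmproj: [mmprojBlob] });
  // `images` accepts URLs (converted to RGB bitmaps internally) or raw WllamaBitmap data
  const completion = await wllama.createCompletion('Describe this image in one sentence:', {
    images: ['./bliss.png'],
    nPredict: 128,
  });
  document.getElementById('output')!.textContent = completion; // 'output' is a placeholder ID
}

main();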
diff --git a/llama.cpp b/llama.cpp
index b9154ec..2d451c8 160000
--- a/llama.cpp
+++ b/llama.cpp
@@ -1 +1 @@
-Subproject commit b9154ecff93ff54dc554411eb844a2a654be49f2
+Subproject commit 2d451c80590b9ac250322769ac13d3b4870dbcf7
diff --git a/src/glue/messages.ts b/src/glue/messages.ts
index 77dcff9..ed16cb6 100644
--- a/src/glue/messages.ts
+++ b/src/glue/messages.ts
@@ -28,6 +28,11 @@ export const GLUE_MESSAGE_PROTOTYPES: { [name: string]: GlueMessageProto } = {
         "name": "model_paths",
         "isNullable": false
       },
+      {
+        "type": "str",
+        "name": "mmproj_path",
+        "isNullable": false
+      },
       {
         "type": "bool",
         "name": "n_ctx_auto",
@@ -150,6 +155,11 @@ export const GLUE_MESSAGE_PROTOTYPES: { [name: string]: GlueMessageProto } = {
         "name": "success",
         "isNullable": false
       },
+      {
+        "type": "int",
+        "name": "has_mtmd",
+        "isNullable": false
+      },
       {
         "type": "int",
         "name": "n_ctx",
@@ -456,6 +466,21 @@ export const GLUE_MESSAGE_PROTOTYPES: { [name: string]: GlueMessageProto } = {
         "type": "bool",
         "name": "special",
         "isNullable": false
+      },
+      {
+        "type": "arr_raw",
+        "name": "bitmaps",
+        "isNullable": false
+      },
+      {
+        "type": "arr_int",
+        "name": "bitmaps_x",
+        "isNullable": false
+      },
+      {
+        "type": "arr_int",
+        "name": "bitmaps_y",
+        "isNullable": false
       }
     ]
   },
@@ -469,6 +494,11 @@ export const GLUE_MESSAGE_PROTOTYPES: { [name: string]: GlueMessageProto } = {
         "name": "success",
         "isNullable": false
       },
+      {
+        "type": "str",
+        "name": "message",
+        "isNullable": false
+      },
       {
         "type": "arr_int",
         "name": "tokens",
@@ -578,6 +608,40 @@ export const GLUE_MESSAGE_PROTOTYPES: { [name: string]: GlueMessageProto } = {
       }
     ]
   },
+  "eimg_req": {
+    "name": "eimg_req",
+    "structName": "glue_msg_eval_image_req",
+    "className": "GlueMsgEvalImageReq",
+    "fields": [
+      {
+        "type": "int",
+        "name": "cached_image_id",
+        "isNullable": false
+      }
+    ]
+  },
+  "eimg_res": {
+    "name": "eimg_res",
+    "structName": "glue_msg_eval_image_res",
+    "className": "GlueMsgEvalImageRes",
+    "fields": [
+      {
+        "type": "bool",
+        "name": "success",
+        "isNullable": false
+      },
+      {
+        "type": "str",
+        "name": "message",
+        "isNullable": false
+      },
+      {
+        "type": "int",
+        "name": "n_past",
+        "isNullable": false
+      }
+    ]
+  },
   "ssam_req": {
     "name": "ssam_req",
     "structName": "glue_msg_sampling_sample_req",
@@ -729,6 +793,11 @@ export const GLUE_MESSAGE_PROTOTYPES: { [name: string]: GlueMessageProto } = {
         "type": "bool",
         "name": "success",
         "isNullable": false
+      },
+      {
+        "type": "str",
+        "name": "message",
+        "isNullable": false
       }
     ]
   },
@@ -755,6 +824,24 @@ export const GLUE_MESSAGE_PROTOTYPES: { [name: string]: GlueMessageProto } = {
       }
     ]
   },
+  "icud_req": {
+    "name": "icud_req",
+    "structName": "glue_msg_img_cache_update_req",
+    "className": "GlueMsgImgCacheUpdateReq",
+    "fields": []
+  },
+  "icud_res": {
+    "name": "icud_res",
+    "structName": "glue_msg_img_cache_update_res",
+    "className": "GlueMsgImgCacheUpdateRes",
+    "fields": [
+      {
+        "type": "bool",
+        "name": "success",
+        "isNullable": false
+      }
+    ]
+  },
   "sesa_req": {
     "name": "sesa_req",
     "structName": "glue_msg_session_save_req",
@@ -990,6 +1077,7 @@ export interface GlueMsgError {
 export interface GlueMsgLoadReq {
   _name: "load_req";
   model_paths: string[];
+  mmproj_path: string;
   n_ctx_auto: boolean;
   use_mmap: boolean;
   use_mlock: boolean;
@@ -1018,6 +1106,7 @@ export interface GlueMsgLoadReq {
 export interface GlueMsgLoadRes {
   _name: "load_res";
   success: boolean;
+  has_mtmd: number;
   n_ctx: number;
   n_batch: number;
   n_ubatch: number;
@@ -1112,12 +1201,16 @@ export interface GlueMsgTokenizeReq {
   _name: "tokn_req";
   text: string;
   special: boolean;
+  bitmaps: Uint8Array[];
+  bitmaps_x: number[];
+  bitmaps_y: number[];
 }

 // struct glue_msg_tokenize_res
 export interface GlueMsgTokenizeRes {
   _name: "tokn_res";
   success: boolean;
+  message: string;
   tokens: number[];
 }

@@ -1163,6 +1256,20 @@ export interface GlueMsgEncodeRes {
   n_past: number;
 }

+// struct glue_msg_eval_image_req
+export interface GlueMsgEvalImageReq {
+  _name: "eimg_req";
+  cached_image_id: number;
+}
+
+// struct glue_msg_eval_image_res
+export interface GlueMsgEvalImageRes {
+  _name: "eimg_res";
+  success: boolean;
+  message: string;
+  n_past: number;
+}
+
 // struct glue_msg_sampling_sample_req
 export interface GlueMsgSamplingSampleReq {
   _name: "ssam_req";
@@ -1228,6 +1335,7 @@ export interface GlueMsgGetKvRemoveRes {
   _name: "kvcr_res";
   n_past: number;
   success: boolean;
+  message: string;
 }

 // struct glue_msg_get_kv_clear_req
@@ -1242,6 +1350,17 @@ export interface GlueMsgGetKvClearRes {
   success: boolean;
 }

+// struct glue_msg_img_cache_update_req
+export interface GlueMsgImgCacheUpdateReq {
+  _name: "icud_req";
+}
+
+// struct glue_msg_img_cache_update_res
+export interface GlueMsgImgCacheUpdateRes {
+  _name: "icud_res";
+  success: boolean;
+}
+
 // struct glue_msg_session_save_req
 export interface GlueMsgSessionSaveReq {
   _name: "sesa_req";
@@ -1331,4 +1450,4 @@ export interface GlueMsgChatFormatRes {
 }

-export type GlueMsg = GlueMsgError | GlueMsgLoadReq | GlueMsgLoadRes | GlueMsgSetOptionsReq | GlueMsgSetOptionsRes | GlueMsgSamplingInitReq | GlueMsgSamplingInitRes | GlueMsgGetVocabReq | GlueMsgGetVocabRes | GlueMsgLookupTokenReq | GlueMsgLookupTokenRes | GlueMsgTokenizeReq | GlueMsgTokenizeRes | GlueMsgDetokenizeReq | GlueMsgDetokenizeRes | GlueMsgDecodeReq | GlueMsgDecodeRes | GlueMsgEncodeReq | GlueMsgEncodeRes | GlueMsgSamplingSampleReq | GlueMsgSamplingSampleRes | GlueMsgSamplingAcceptReq | GlueMsgSamplingAcceptRes | GlueMsgGetLogitsReq | GlueMsgGetLogitsRes | GlueMsgGetEmbeddingsReq | GlueMsgGetEmbeddingsRes | GlueMsgGetKvRemoveReq | GlueMsgGetKvRemoveRes | GlueMsgGetKvClearReq | GlueMsgGetKvClearRes | GlueMsgSessionSaveReq | GlueMsgSessionSaveRes | GlueMsgSessionLoadReq | GlueMsgSessionLoadRes | GlueMsgStatusReq | GlueMsgStatusRes | GlueMsgTestBenchmarkReq | GlueMsgTestBenchmarkRes | GlueMsgTestPerplexityReq | GlueMsgTestPerplexityRes | GlueMsgChatFormatReq | GlueMsgChatFormatRes;
+export type GlueMsg = GlueMsgError | GlueMsgLoadReq | GlueMsgLoadRes | GlueMsgSetOptionsReq | GlueMsgSetOptionsRes | GlueMsgSamplingInitReq | GlueMsgSamplingInitRes | GlueMsgGetVocabReq | GlueMsgGetVocabRes | GlueMsgLookupTokenReq | GlueMsgLookupTokenRes | GlueMsgTokenizeReq | GlueMsgTokenizeRes | GlueMsgDetokenizeReq | GlueMsgDetokenizeRes | GlueMsgDecodeReq | GlueMsgDecodeRes | GlueMsgEncodeReq | GlueMsgEncodeRes | GlueMsgEvalImageReq | GlueMsgEvalImageRes | GlueMsgSamplingSampleReq | GlueMsgSamplingSampleRes | GlueMsgSamplingAcceptReq | GlueMsgSamplingAcceptRes | GlueMsgGetLogitsReq | GlueMsgGetLogitsRes | GlueMsgGetEmbeddingsReq | GlueMsgGetEmbeddingsRes | GlueMsgGetKvRemoveReq | GlueMsgGetKvRemoveRes | GlueMsgGetKvClearReq | GlueMsgGetKvClearRes | GlueMsgImgCacheUpdateReq | GlueMsgImgCacheUpdateRes | GlueMsgSessionSaveReq | GlueMsgSessionSaveRes | GlueMsgSessionLoadReq | GlueMsgSessionLoadRes | GlueMsgStatusReq | GlueMsgStatusRes | GlueMsgTestBenchmarkReq | GlueMsgTestBenchmarkRes | GlueMsgTestPerplexityReq | GlueMsgTestPerplexityRes | GlueMsgChatFormatReq | GlueMsgChatFormatRes;
diff --git a/src/multi-thread/wllama.wasm b/src/multi-thread/wllama.wasm
index ebe1094..b0304b8 100755
Binary files a/src/multi-thread/wllama.wasm and b/src/multi-thread/wllama.wasm differ
diff --git a/src/single-thread/wllama.wasm b/src/single-thread/wllama.wasm
index c5071fd..63ccb32 100755
Binary files a/src/single-thread/wllama.wasm and b/src/single-thread/wllama.wasm differ
diff --git a/src/utils.ts b/src/utils.ts
index 3fb2868..c63da19 100644
--- a/src/utils.ts
+++ b/src/utils.ts
@@ -1,3 +1,5 @@
+import { WllamaBitmap } from "./wllama";
+
 export const joinBuffers = (buffers: Uint8Array[]): Uint8Array => {
   const totalSize = buffers.reduce((acc, buf) => acc + buf.length, 0);
   const output = new Uint8Array(totalSize);
@@ -252,3 +254,44 @@ export const cbToAsyncIter =
       }
     })();
   };
+
+export const getBitmapFromUrl = async (url: string): Promise<WllamaBitmap> => {
+  const response = await fetch(url);
+  if (!response.ok) {
+    throw new Error(`Failed to fetch image: ${response.statusText}`);
+  }
+  const blob = await response.blob();
+  const imageBitmap = await createImageBitmap(blob);
+  const { width, height } = imageBitmap;
+
+  // Use OffscreenCanvas for better performance and worker compatibility
+  const canvas = new OffscreenCanvas(width, height);
+  const ctx = canvas.getContext('2d');
+
+  if (!ctx) {
+    throw new Error('Could not get 2D context from canvas');
+  }
+
+  // Draw the image onto the canvas
+  ctx.drawImage(imageBitmap, 0, 0);
+
+  // Get the pixel data (RGBA)
+  const imageData = ctx.getImageData(0, 0, width, height);
+  const rgbaData = imageData.data;
+
+  // Convert RGBA to RGB
+  const rgbData = new Uint8Array(width * height * 3);
+  for (let i = 0, j = 0; i < rgbaData.length; i += 4, j += 3) {
+    rgbData[j] = rgbaData[i]; // R
+    rgbData[j + 1] = rgbaData[i + 1]; // G
+    rgbData[j + 2] = rgbaData[i + 2]; // B
+    // Skip alpha channel (rgbaData[i + 3])
+  }
+  imageBitmap.close();
+
+  return {
+    width,
+    height,
+    data: rgbData,
+  };
+}
diff --git a/src/wllama.ts b/src/wllama.ts
index 73b4962..042c220 100644
--- a/src/wllama.ts
+++ b/src/wllama.ts
@@ -4,6 +4,7 @@ import {
   bufToText,
   cbToAsyncIter,
   checkEnvironmentCompatible,
+  getBitmapFromUrl,
   isString,
   isSupportMultiThread,
   joinBuffers,
@@ -15,11 +16,13 @@ import {
   GlueMsgChatFormatRes,
   GlueMsgDecodeRes,
   GlueMsgDetokenizeRes,
+  GlueMsgEvalImageRes,
   GlueMsgGetEmbeddingsRes,
   GlueMsgGetKvClearRes,
   GlueMsgGetKvRemoveRes,
   GlueMsgGetLogitsRes,
   GlueMsgGetVocabRes,
+  GlueMsgImgCacheUpdateRes,
   GlueMsgLoadRes,
   GlueMsgLookupTokenRes,
   GlueMsgSamplingAcceptRes,
@@ -114,6 +117,8 @@ export interface LoadModelConfig {
   // optimizations
   cache_type_k?: 'f32' | 'f16' | 'q8_0' | 'q5_1' | 'q5_0' | 'q4_1' | 'q4_0';
   cache_type_v?: 'f32' | 'f16' | 'q8_0' | 'q5_1' | 'q5_0' | 'q4_1' | 'q4_0';
+  // multimodal
+  mmproj?: Blob[] | Model;
 }

 export interface SamplingConfig {
@@ -173,6 +178,7 @@ export interface ChatCompletionOptions {
       abortSignal: () => any;
     }
   ): any;
+  images?: WllamaInputImage[];
   sampling?: SamplingConfig;
   /**
    * List of custom token IDs for stopping the generation.
@@ -231,6 +237,18 @@ export interface LoadedContextInfo {
   add_eos_token: boolean;
 }

+export interface WllamaBitmap {
+  data: Uint8Array;
+  width: number;
+  height: number;
+}
+
+export type WllamaInputImage = string | WllamaBitmap;
+
+export interface WllamaInputExtra {
+  images?: WllamaInputImage[];
+}
+
 /**
  * Logger preset with debug messages suppressed
  */
@@ -266,6 +284,11 @@ export class WllamaAbortError extends Error {
   }
 }

+interface InternalBatch {
+  type: 'text' | 'image';
+  tokens: number[];
+};
+
 export class Wllama {
   // The CacheManager and ModelManager are singleton, can be accessed by user
   public cacheManager: CacheManager;
@@ -579,6 +602,14 @@ export class Wllama {
       name: `model-${i}.gguf`,
       blob,
     }));
+    // multimodal
+    if (config.mmproj) {
+      const mmprojBlobs = config.mmproj instanceof Model ? await config.mmproj.open() : config.mmproj;
+      modelFiles.push({
+        name: 'mmproj.gguf',
+        blob: mmprojBlobs[0],
+      });
+    }
     await this.proxy.moduleInit(modelFiles);
     // run it
     const startResult: any = await this.proxy.wllamaStart();
@@ -590,6 +621,8 @@ export class Wllama {
     // load the model
     const loadResult: GlueMsgLoadRes = await this.proxy.wllamaAction('load', {
       _name: 'load_req',
+      model_paths: modelFiles.map((f) => `models/${f.name}`),
+      mmproj_path: config.mmproj ? 'models/mmproj.gguf' : '',
       use_mmap: true,
       use_mlock: true,
       n_gpu_layers: 0, // not supported for now
@@ -597,7 +630,6 @@ export class Wllama {
       n_ctx: config.n_ctx || 1024,
       n_threads: this.useMultiThread ? nbThreads : 1,
       n_ctx_auto: false, // not supported for now
-      model_paths: modelFiles.map((f) => `models/${f.name}`),
       embeddings: config.embeddings,
       offload_kqv: config.offload_kqv,
       n_batch: config.n_batch,
@@ -750,7 +782,9 @@ export class Wllama {
     await this.samplingInit(this.samplingConfig);
     const stopTokens = new Set(options.stopTokens ?? []);
     // process prompt
-    let tokens = await this.tokenize(prompt, true);
+    let tokens = await this.tokenize(prompt, true, {
+      images: options.images,
+    });
     if (this.addBosToken && tokens[0] !== this.bosToken) {
       tokens.unshift(this.bosToken);
     }
@@ -890,19 +924,43 @@ export class Wllama {
   }

   /**
-   * Convert a given text to list of tokens
+   * Convert a given text to a list of tokens.
+   *
+   * If the input contains images, the preprocessed images are stored temporarily in an image cache; entries that are never decoded are dropped by the next call to `imageCacheUpdate()`.
+   *
    * @param text
    * @param special Should split special tokens?
    * @returns List of token IDs
    */
-  async tokenize(text: string, special: boolean = true): Promise<number[]> {
+  async tokenize(text: string, special: boolean = true, extra: WllamaInputExtra = {}): Promise<number[]> {
     this.checkModelLoaded();
+    const bitmaps: Uint8Array[] = [];
+    const bitmaps_x: number[] = [];
+    const bitmaps_y: number[] = [];
+
+    for (const image of extra.images || []) {
+      let bmp: WllamaBitmap;
+      if (isString(image)) {
+        bmp = await getBitmapFromUrl(image as string);
+      } else if ((image as WllamaBitmap).data) {
+        bmp = image as WllamaBitmap;
+      } else {
+        throw new WllamaError('Invalid bitmap');
+      }
+      bitmaps.push(bmp.data);
+      bitmaps_x.push(bmp.width);
+      bitmaps_y.push(bmp.height);
+    }
+
     const result = await this.proxy.wllamaAction<GlueMsgTokenizeRes>(
       'tokenize',
       {
         _name: 'tokn_req',
         text,
         special: !!special,
+        bitmaps,
+        bitmaps_x,
+        bitmaps_y,
       }
     );
     return result.tokens;
@@ -959,29 +1017,37 @@
         'kv_cache_full'
       );
     }
+    let nPast = this.nCachedTokens;
     const batches = this.breakTokensIntoBatches(
       tokens,
       this.loadedContextInfo.n_batch
     );
-    let result: any;
     for (let i = 0; i < batches.length; i++) {
       if (options?.abortSignal?.aborted) {
         throw new WllamaAbortError();
       }
-      const isNotLast = batches.length > 1 && i < batches.length - 1;
-      result = await this.proxy.wllamaAction('decode', {
-        _name: 'deco_req',
-        tokens: batches[i],
-        skip_logits: options.skipLogits || isNotLast,
-      });
-      if (result.error) {
-        throw new WllamaError(result.error);
-      } else if (!result.success) {
-        throw new WllamaError('Cannot encode, unknown error');
+      const batch = batches[i];
+      const isLast = i === batches.length - 1;
+      if (batch.type === 'text') {
+        const result = await this.proxy.wllamaAction('decode', {
+          _name: 'deco_req',
+          tokens: batch.tokens,
+          skip_logits: options.skipLogits || !isLast,
+        });
+        if (!result.success) {
+          throw new WllamaError('Cannot decode, unknown error');
+        }
+        nPast = result.n_past;
+      } else if (batch.type === 'image') {
+        const result = await this.evalImage(batch.tokens[0]);
+        nPast = result.nPast;
+      } else {
+        throw new Error('Invalid batch type');
+      }
     }
-    this.nCachedTokens = result.n_past;
-    return { nPast: result.n_past };
+    await this.imageCacheUpdate(); // make sure to clean non-decoded images
+    this.nCachedTokens = nPast;
+    return { nPast };
   }

   /**
@@ -1023,36 +1089,66 @@
       tokens,
       this.loadedContextInfo.n_batch
     );
-    let result: any;
+    let nPast = 0;
     for (let i = 0; i < batches.length; i++) {
       if (options?.abortSignal?.aborted) {
         throw new WllamaAbortError();
       }
-      result = await this.proxy.wllamaAction('encode', {
+      const result = await this.proxy.wllamaAction('encode', {
         _name: 'enco_req',
-        tokens: batches[i],
+        tokens: batches[i].tokens,
       });
-      if (result.error) {
-        throw new WllamaError(result.error);
-      } else if (!result.success) {
+      if (!result.success) {
         throw new WllamaError('Cannot encode, unknown error');
       }
+      nPast = result.n_past;
     }
-    this.nCachedTokens = result.n_past;
-    return { nPast: result.n_past };
+    return { nPast };
   }

   private breakTokensIntoBatches(
     tokens: number[],
     maxBatchSize: number
-  ): number[][] {
-    const batches: number[][] = [];
-    for (let i = 0; i < tokens.length; i += maxBatchSize) {
-      batches.push(tokens.slice(i, i + maxBatchSize));
+  ): InternalBatch[] {
+    const batches: InternalBatch[] = [];
+    if (tokens.length === 0) return batches;
+    for (let i = 0; i < tokens.length;) {
+      const firstTok = tokens[i];
+      const isImage = firstTok < 0;
+      const cur: InternalBatch = {
+        type: isImage ? 'image' : 'text',
+        tokens: [],
+      };
+      const until = Math.min(i + maxBatchSize, tokens.length);
+      for (; i < until; i++) {
+        if (isImage && tokens[i] !== firstTok) break;
+        if (!isImage && tokens[i] < 0) break;
+        cur.tokens.push(tokens[i]);
+      }
+      if (cur.tokens.length) batches.push(cur);
     }
     return batches;
   }

+  /**
+   * Evaluate (encode-decode) an image by its ID
+   */
+  private async evalImage(cachedImageId: number): Promise<{ nPast: number }> {
+    this.checkModelLoaded();
+    const result = await this.proxy.wllamaAction<GlueMsgEvalImageRes>(
+      'eval_image',
+      {
+        _name: 'eimg_req',
+        cached_image_id: cachedImageId,
+      }
+    );
+    if (!result.success) {
+      throw new WllamaError(result.message);
+    }
+    this.nCachedTokens = result.n_past;
+    return { nPast: result.n_past };
+  }
+
   /**
    * Sample a new token (remember to samplingInit() at least once before calling this function)
    * @returns the token ID and its detokenized value (which maybe an unfinished unicode)
@@ -1202,6 +1298,24 @@
     this.nCachedTokens = 0;
   }

+  /**
+   * Remove dangling images from the image cache.
+   *
+   * If an image is not part of the current context (i.e. it was never processed by `llama_decode()`), it will be removed from the cache.
+   */
+  async imageCacheUpdate(): Promise<void> {
+    this.checkModelLoaded();
+    const result = await this.proxy.wllamaAction<GlueMsgImgCacheUpdateRes>(
+      'img_cache_update',
+      {
+        _name: 'icud_req',
+      }
+    );
+    if (!result.success) {
+      throw new WllamaError('imageCacheUpdate unknown error');
+    }
+  }
+
   /**
    * Save session to file (virtual file system)
    * TODO: add ability to download the file
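For reference, a minimal sketch of how the tokenize/image-cache surface introduced above behaves from the caller's side (assuming an already-loaded `wllama` instance with a multimodal projector and a local `./bliss.png`; the prompt text is illustrative):

// Images passed to tokenize() are preprocessed, cached on the WASM side, and
// represented in the returned list as repeated negative placeholder IDs
// (counting down from -1001), one run per image.
const tokens = await wllama.tokenize('Describe this image:', true, {
  images: ['./bliss.png'],
});
console.log(tokens.filter((t) => t < 0).length, 'image embedding positions');

// If these tokens are never decoded, drop the unused preprocessed image from
// the cache explicitly (decoding performs this cleanup automatically).
await wllama.imageCacheUpdate();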