29 changes: 28 additions & 1 deletion CMakeLists.txt
@@ -18,6 +18,33 @@ set(WLLAMA_SRC cpp/wllama.cpp
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/cpp)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/cpp/helpers)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/llama.cpp/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/llama.cpp/examples/llava)

#
# multimodal
#
add_library(mtmd OBJECT
llama.cpp/examples/llava/mtmd.cpp
llama.cpp/examples/llava/mtmd.h
llama.cpp/examples/llava/clip.cpp
llama.cpp/examples/llava/clip.h
llama.cpp/examples/llava/clip-impl.h
)

target_link_libraries(mtmd PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})

target_include_directories(mtmd PUBLIC llama.cpp/examples/llava)
target_include_directories(mtmd PRIVATE llama.cpp)
target_include_directories(mtmd PRIVATE llama.cpp/common) # for stb_image.h

target_compile_features(mtmd PRIVATE cxx_std_17)
target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h
target_compile_options(mtmd PRIVATE -Wno-c++11-narrowing) # clip.cpp

add_library(mtmd_static STATIC $<TARGET_OBJECTS:mtmd>)

#
# wllama target definition
#
add_executable(wllama ${WLLAMA_SRC})
-target_link_libraries(wllama PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(wllama PRIVATE ggml llama mtmd_static ${CMAKE_THREAD_LIBS_INIT})
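
Reviewer note on the OBJECT/STATIC pattern above: an OBJECT library compiles the llava sources once without producing an archive, and the `$<TARGET_OBJECTS:...>` generator expression wraps those objects into a linkable static library. A minimal generic sketch of the same pattern (target and file names hypothetical, not part of this PR):

```cmake
# Compile once, then link as a normal archive elsewhere.
add_library(parts OBJECT a.cpp b.cpp)                    # objects only, no .a yet
add_library(parts_static STATIC $<TARGET_OBJECTS:parts>) # archive the objects
add_executable(tool main.cpp)
target_link_libraries(tool PRIVATE parts_static)
```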
222 changes: 209 additions & 13 deletions cpp/actions.hpp
@@ -6,8 +6,12 @@
#include <sstream>
#include <stdio.h>
#include <cmath>
#include <unordered_map>
#include <set>

#include "llama.h"
#include "mtmd.h"
#include "utils.hpp"
#include "helpers/wcommon.h"
#include "helpers/wsampling.h"

@@ -20,13 +24,47 @@

struct app_t
{
-llama_model *model;
-llama_context *ctx;
+llama_model *model = nullptr;
+llama_context *ctx = nullptr;
const llama_vocab *vocab;
wcommon_sampler *ctx_sampling = nullptr;
llama_batch batch = llama_batch_init(512, 0, 1);
llama_tokens tokens;
int32_t n_batch;
int32_t seed = LLAMA_DEFAULT_SEED;

// multimodal
mtmd_context *mctx = nullptr;
llama_token next_img_id = -1000; // count downward: -1000, -1001, -1002, ... (post-decrement hands out -1000 first)
std::unordered_map<llama_token, mtmd_image_tokens_ptr> img_cache;

llama_token find_cached_image_tokens_by_id(const std::string &id)
{
for (auto it = img_cache.begin(); it != img_cache.end(); ++it)
{
if (mtmd_image_tokens_get_id(it->second.get()) == id)
return it->first;
}
return LLAMA_TOKEN_NULL;
}

// remove images that are no longer referenced by any token in app.tokens
void update_img_cache()
{
std::set<llama_token> active_img_ids;
for (auto &t : tokens)
{
if (t < 0)
active_img_ids.insert(t);
}
for (auto it = img_cache.begin(); it != img_cache.end();)
{
if (active_img_ids.find(it->first) == active_img_ids.end())
it = img_cache.erase(it);
else
++it;
}
}
};
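
Reviewer note: the fields above flatten images into the existing `llama_tokens` history by repeating one negative pseudo-id per image token, which keeps `app.tokens` a plain vector. A sketch of a possible resulting stream (all values illustrative):

```cpp
// Illustrative only: app.tokens after tokenizing text containing one image
// that expands to 4 image tokens; real text ids depend on the vocab, and
// -1000 is the first pseudo-id produced by app.next_img_id--.
llama_tokens tokens = {
    1, 3148, 9047,               // text tokens
    -1000, -1000, -1000, -1000,  // one image: img_id repeated n_tokens times
    5649, 445, 13,               // more text tokens
};
// img_cache maps -1000 -> mtmd_image_tokens_ptr so that action_eval_image
// can later replay the whole run in a single evaluation.
```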

inline std::vector<char> convert_string_to_buf(std::string &input)
@@ -211,6 +249,7 @@ glue_msg_load_res action_load(app_t &app, const char *req_raw)
}

// load model
app.n_batch = cparams.n_batch;
app.model = llama_model_load_from_splits(
model_paths_ptrs.data(), model_paths_ptrs.size(), mparams);
if (app.model == nullptr)
@@ -260,8 +299,22 @@ glue_msg_load_res action_load(app_t &app, const char *req_raw)
}
kv_dump metadata = dump_metadata(app);

if (req.mmproj_path.not_null() && !req.mmproj_path.value.empty())
{
mtmd_context_params mparams;
mparams.use_gpu = false; // not yet supported
mparams.n_threads = cparams.n_threads;
app.mctx = mtmd_init_from_file(req.mmproj_path.value.c_str(), app.model, mparams);
if (app.mctx == nullptr)
{
free_all(app);
throw app_exception("Error while loading multimodal projector");
}
}

glue_msg_load_res res;
res.success.value = true;
res.has_mtmd.value = app.mctx != nullptr;
res.n_ctx.value = cparams.n_ctx;
res.n_batch.value = llama_n_batch(app.ctx);
res.n_ubatch.value = llama_n_ubatch(app.ctx);
@@ -420,12 +473,66 @@ glue_msg_lookup_token_res action_lookup_token(app_t &app, const char *req_raw)
}

// tokenize an input string
// if the input contains images, the preprocessed image tokens are cached until action_kv_remove/clear is called
// images that are never decoded will be cleared from the cache (see action_img_cache_update)
glue_msg_tokenize_res action_tokenize(app_t &app, const char *req_raw)
{
PARSE_REQ(glue_msg_tokenize_req);
std::string &text = req.text.value;
bool special = req.special.value;
-llama_tokens tokens_list = wcommon_tokenize(app.vocab, text, false, special);
+llama_tokens tokens_list;
if (app.mctx)
{
// multimodal
std::vector<mtmd_bitmap> mbitmaps;
for (size_t i = 0; i < req.bitmaps.arr.size(); i++)
{
mtmd_bitmap bitmap;
std::vector<unsigned char> tmp(req.bitmaps.arr[i].begin(), req.bitmaps.arr[i].end());
bitmap.id = std::to_string(fnv_hash(tmp.data(), tmp.size()));
bitmap.data = std::move(tmp);
bitmap.nx = req.bitmaps_x.arr[i];
bitmap.ny = req.bitmaps_y.arr[i];
mbitmaps.push_back(bitmap);
}
mtmd_input_text mtext{
/* text */ text,
/* add_special */ false,
/* parse_special */ special,
};
mtmd_input_chunks chunks;
int32_t res = mtmd_tokenize(app.mctx, chunks, mtext, mbitmaps);
if (res != 0)
{
glue_msg_tokenize_res res;
res.success.value = false;
res.message.value = "mtmd_tokenize failed";
return res;
}
for (auto &chunk : chunks)
{
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT)
{
tokens_list.insert(tokens_list.end(), chunk.tokens_text.begin(), chunk.tokens_text.end());
}
else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE)
{
// if the image is already in the cache, use the cached image id
std::string id = mtmd_image_tokens_get_id(chunk.tokens_image.get());
llama_token cached_id = app.find_cached_image_tokens_by_id(id);
llama_token img_id = cached_id == LLAMA_TOKEN_NULL ? app.next_img_id-- : cached_id;
for (size_t i = 0; i < mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get()); i++)
tokens_list.push_back(img_id);
if (cached_id == LLAMA_TOKEN_NULL)
app.img_cache[img_id] = std::move(chunk.tokens_image); // add to cache
}
}
}
else
{
// text only
tokens_list = wcommon_tokenize(app.vocab, text, false, special);
}

glue_msg_tokenize_res res;
res.success.value = true;
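
Reviewer note: the bitmap id above is a hash of the raw image bytes, so resending identical data reuses one cache entry. The `fnv_hash` helper comes from utils.hpp; presumably it follows the standard 64-bit FNV-1a shape sketched here (an assumption, the real implementation may differ):

```cpp
#include <cstdint>
#include <cstddef>

// Sketch only: textbook 64-bit FNV-1a over a byte buffer.
static uint64_t fnv_hash(const uint8_t *data, size_t len) {
    uint64_t hash = 0xcbf29ce484222325ULL;  // FNV offset basis
    for (size_t i = 0; i < len; i++) {
        hash ^= data[i];
        hash *= 0x100000001b3ULL;           // FNV prime
    }
    return hash;
}
```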
@@ -441,7 +548,9 @@ glue_msg_detokenize_res action_detokenize(app_t &app, const char *req_raw)
std::stringstream output;
for (auto id : tokens)
{
-output << wcommon_token_to_piece(app.ctx, id);
+// TODO: if we encounter an image token, we should replace it with the image marker
+if (id >= 0)
+output << wcommon_token_to_piece(app.ctx, id);
}
std::string parsed_str = output.str();

@@ -459,9 +568,16 @@ glue_msg_decode_res action_decode(app_t &app, const char *req_raw)
bool skip_logits = req.skip_logits.value;
size_t i = 0;
wcommon_batch_clear(app.batch);
+glue_msg_decode_res res;
for (auto id : tokens_list)
{
-bool grp_attn_enabled = false; // TODO: maybe remove grp_attn
+if (id < 0)
+{
+res.success.value = false;
+res.message.value = "image token not supported, use action_eval_image instead";
+res.n_past.value = app.tokens.size();
+return res;
+}
int32_t n_past = app.tokens.size();
wcommon_batch_add(app.batch, id, n_past, {0}, false);
app.tokens.push_back(id);
@@ -472,7 +588,6 @@ glue_msg_decode_res action_decode(app_t &app, const char *req_raw)
{
app.batch.logits[app.batch.n_tokens - 1] = true;
}
-glue_msg_decode_res res;
if (llama_decode(app.ctx, app.batch) != 0)
{
res.success.value = false;
@@ -521,6 +636,66 @@ glue_msg_encode_res action_encode(app_t &app, const char *req_raw)
return res;
}

glue_msg_eval_image_res action_eval_image(app_t &app, const char *req_raw)
{
PARSE_REQ(glue_msg_eval_image_req);
glue_msg_eval_image_res res;
llama_pos n_past = app.tokens.size();
res.n_past.value = n_past;
if (app.mctx == nullptr)
{
res.success.value = false;
res.message.value = "this model does not have a multimodal projector";
return res;
}
llama_token img_id = req.cached_image_id.value;
auto it = app.img_cache.find(img_id);
if (it != app.img_cache.end())
{
size_t n_tokens = mtmd_image_tokens_get_n_tokens(it->second.get());
// TODO: we should get rid of this ugly code by making mtmd_helper_eval accept a single chunk
mtmd_input_chunks chunks;
{
mtmd_input_chunk chunk0{
/* type */ MTMD_INPUT_CHUNK_TYPE_IMAGE,
/* tokens_text */ {},
/* tokens_image */ std::move(it->second),
};
mtmd_input_chunk chunk1{
/* type */ MTMD_INPUT_CHUNK_TYPE_TEXT,
/* tokens_text */ {},
/* tokens_image */ nullptr,
};
chunks.emplace_back(std::move(chunk0));
chunks.emplace_back(std::move(chunk1));
}
int32_t result = mtmd_helper_eval(app.mctx, app.ctx, chunks, n_past, 0, app.n_batch);
// remember to move the image tokens back into the cache
// TODO: maybe we just need to keep the image id in the cache, not the whole image?
app.img_cache[img_id] = std::move(chunks[0].tokens_image);
if (result != 0)
{
res.success.value = false;
res.message.value = "mtmd_helper_eval() failed with status " + std::to_string(result);
}
else
{
for (size_t i = 0; i < n_tokens; i++)
{
app.tokens.push_back(img_id);
}
res.n_past.value = app.tokens.size();
res.success.value = true;
}
}
else
{
res.success.value = false;
res.message.value = "image not found in cache (maybe already removed by action_kv_remove/clear?)";
}
return res;
}
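
Reviewer note: with action_decode rejecting negative ids and action_eval_image consuming them, the expected caller behaviour is to split the token stream into text runs and image runs. An assumed driver sketch (not code from this PR; `decode_text`/`eval_image` are hypothetical stand-ins for the corresponding glue calls):

```cpp
void decode_text(const llama_tokens &tokens, size_t begin, size_t end); // -> action_decode
void eval_image(llama_token img_id);                                    // -> action_eval_image

void run_prompt(const llama_tokens &tokens) {
    size_t i = 0;
    while (i < tokens.size()) {
        if (tokens[i] < 0) {
            llama_token img_id = tokens[i];
            while (i < tokens.size() && tokens[i] == img_id) i++; // consume the whole run
            eval_image(img_id); // cached_image_id = img_id
        } else {
            size_t j = i;
            while (j < tokens.size() && tokens[j] >= 0) j++;
            decode_text(tokens, i, j); // tokens[i, j) are ordinary text tokens
            i = j;
        }
    }
}
```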

// decode the current logits and sample the new token
glue_msg_sampling_sample_res action_sampling_sample(app_t &app, const char *req_raw)
{
@@ -650,15 +825,22 @@ glue_msg_get_kv_remove_res action_kv_remove(app_t &app, const char *req_raw)
const int n_keep = req.n_keep.value;
const int n_discard = req.n_discard.value;

+glue_msg_get_kv_remove_res res;
+
+// TODO: make sure n_keep is not in the middle of an image
+
if (n_discard > 0)
{
+// TODO: this code branch is kinda broken, to be fixed later
-const int n_past = app.tokens.size();
-llama_kv_self_seq_rm(app.ctx, 0, n_keep, n_keep + n_discard);
-llama_kv_self_seq_add(app.ctx, 0, n_keep + n_discard, n_past, -n_discard);
-app.tokens.erase(
-app.tokens.begin() + n_keep,
-app.tokens.begin() + n_keep + n_discard);
+// const int n_past = app.tokens.size();
+// llama_kv_self_seq_rm(app.ctx, 0, n_keep, n_keep + n_discard);
+// llama_kv_self_seq_add(app.ctx, 0, n_keep + n_discard, n_past, -n_discard);
+// app.tokens.erase(
+//     app.tokens.begin() + n_keep,
+//     app.tokens.begin() + n_keep + n_discard);
+res.success.value = false;
+res.message.value = "n_discard > 0 is not supported yet";
+return res;
}
else if (n_discard < 0)
{
@@ -675,7 +857,6 @@ glue_msg_get_kv_remove_res action_kv_remove(app_t &app, const char *req_raw)
}
}

-glue_msg_get_kv_remove_res res;
res.success.value = true;
res.n_past.value = app.tokens.size();
return res;
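
Reviewer note: one possible shape for the `n_keep` TODO above (hypothetical helper, not in this PR): a cut point is unsafe exactly when it splits a run of identical image pseudo-tokens.

```cpp
// Hypothetical check: false if n_keep falls inside an image's token run,
// i.e. both sides of the cut hold the same negative pseudo-id.
static bool is_safe_cut_point(const llama_tokens &tokens, int n_keep) {
    if (n_keep <= 0 || n_keep >= (int) tokens.size()) return true;
    const llama_token prev = tokens[n_keep - 1];
    const llama_token next = tokens[n_keep];
    return !(prev < 0 && prev == next);
}
```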
@@ -694,6 +875,17 @@ glue_msg_get_kv_clear_res action_kv_clear(app_t &app, const char *req_raw)
return res;
}

// remove dangling images from the cache
glue_msg_img_cache_update_res action_img_cache_update(app_t &app, const char *req_raw)
{
PARSE_REQ(glue_msg_img_cache_update_req);
app.update_img_cache();

glue_msg_img_cache_update_res res;
res.success.value = true;
return res;
}
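
Reviewer note: this action is meant to run after history edits, once the pseudo-tokens referencing an image may have been erased. An assumed sequencing sketch (request buffers hypothetical):

```cpp
// Trim history first, then free images that no live pseudo-token references.
void trim_history(app_t &app, const char *req_remove, const char *req_update) {
    auto removed = action_kv_remove(app, req_remove);        // may erase image runs
    auto updated = action_img_cache_update(app, req_update); // drops dangling entries
    (void) removed; (void) updated;
}
```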

/*
// save current session
json action_session_save(app_t &app, json &body)
@@ -756,6 +948,10 @@ glue_msg_status_res action_current_status(app_t &app, const char *req_raw)
return res;
}

//
// multimodal
//

//
// benchmark & perplexity
//