diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index d7f0d85156c..65c67c49f01 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -76,7 +76,11 @@ int main(int argc, char** argv) { std::vector buf; buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char std::ofstream fout(FLAGS_output_path.c_str()); - auto callback = [&](const std::string& piece) { + + int32_t num_total_tokens = 0; + + auto callback = [&](const std::string& piece, int32_t tokens_generated) { + num_total_tokens += tokens_generated; for (const char c : piece) { buf.push_back(c); } @@ -85,6 +89,7 @@ int main(int argc, char** argv) { for (int i = 0; i < FLAGS_num_iters; i++) { runner.generate( FLAGS_seq_len, + num_total_tokens, FLAGS_prompt.c_str(), FLAGS_system_prompt.c_str(), callback); diff --git a/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp index ce7baefa080..d645c7476f5 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp @@ -494,10 +494,56 @@ void ShiftPointerIoMgr::prepare_prefill_io( } } +void ShiftPointerIoMgr::update_kv_to_prefill_io( + int64_t pos, + std::vector>& output_tensors) { + // update v_cache + assert(pos <= 512); + + ET_LOG(Info, "update kv to prefill io, pos: %ld, last prefill pos: %ld", pos, last_pos_); + + int64_t pos_diff = pos - last_pos_; + std::vector>& v_cache_in = + v_cache_in_[prefill_forward_name_]; + for (int i = 0, v_cache_stride = head_dim_ * pos_diff; i < v_cache_in.size(); + i++) { + v_cache_in[i]->set_data( + v_cache_in[i]->mutable_data() + v_cache_stride); + } + + // update k_cache + std::vector>& k_cache_in = + k_cache_in_[prefill_forward_name_]; + + size_t copied_size = pos_diff * sizeof(uint8_t); + + for (int i = 0, k_cache_stride = 
pos_diff * sizeof(uint8_t); i < k_cache_in.size(); + i++) { + k_cache_in[i]->set_data( + k_cache_in[i]->mutable_data() + k_cache_stride); + uint8_t* ptr_in = k_cache_in[i]->mutable_data() - pos_diff; + for (int j = 0; j < head_dim_; ++j) { + memcpy( + ptr_in + j * prefill_cache_len_, + ptr_in + j * kv_cache_len_, + copied_size); + } + } + + // Setting attention mask from context_len - prefill_ar_len - i to context_len + IO* ptr = static_cast(data_ptr_.get()); + for (int i = prefill_ar_len_; i < pos; i++) { + for (int j = 0; j < prefill_ar_len_; j++) { + ptr->prefill_attention_mask[j * context_len_ + context_len_ - prefill_ar_len_ - i] = 65535; + } + } +} + void ShiftPointerIoMgr::update_prefill_to_kv_io( int64_t cur_token, int64_t pos, std::vector>& output_tensors) { + last_pos_ = pos; ET_CHECK_MSG(kv_cache_len_ != 0, "k_cache_len_ should not equal to 0"); IO* ptr = static_cast(data_ptr_.get()); @@ -664,33 +710,32 @@ void ShiftPointerIoMgr::update_prefill_io( } void ShiftPointerIoMgr::fill_prefill_toks( - int64_t start_pos, + int64_t num_prev_tokens, + int64_t prompt_pos, std::vector& prompt_tokens) { IO* ptr = static_cast(get_mutable_ptr()); for (int i = 0; i < prefill_ar_len_; i++) { if (!is_bert_) { - ptr->prefill_input_pos[i] = start_pos + i; + ptr->prefill_input_pos[i] = num_prev_tokens + prompt_pos + i; } - if (start_pos + i < prompt_tokens.size()) { + if (prompt_pos + i < prompt_tokens.size()) { // Support CPU 4-bit embedding, which requires int64 input. // However, for QNN embedding, only int32 input is needed. // Therefore, we need to cast to the correct type to write the data. 
if (use_int64_token_) { - ptr->prefill_input_toks[i] = prompt_tokens[start_pos + i]; + ptr->prefill_input_toks[i] = prompt_tokens[prompt_pos + i]; } else { int32_t* prefill_input_toks_ptr = reinterpret_cast(ptr->prefill_input_toks.data()); prefill_input_toks_ptr[i] = - static_cast(prompt_tokens[start_pos + i]); + static_cast(prompt_tokens[prompt_pos + i]); } } - if (start_pos >= prefill_ar_len_) { - for (int j = 0, - offset = i * context_len_ + - (context_len_ - prefill_ar_len_ - start_pos); - j < prefill_ar_len_; - ++j) { + if (num_prev_tokens + prompt_pos >= prefill_ar_len_) { + int64_t start_offset = i * context_len_ + + (context_len_ - num_prev_tokens - prompt_pos - prefill_ar_len_); + for (int j = 0, offset = start_offset; j < prefill_ar_len_; ++j) { ptr->prefill_attention_mask[offset + j] = 65535; } } @@ -1305,6 +1350,12 @@ void SmartMaskIoMgr::prepare_prefill_io( } } +void SmartMaskIoMgr::update_kv_to_prefill_io( + int64_t pos, + std::vector>& output_tensors) { + //TODO: Fill In + } + void SmartMaskIoMgr::update_prefill_to_kv_io( int64_t cur_token, int64_t pos, @@ -1396,29 +1447,30 @@ void SmartMaskIoMgr::update_prefill_io( } void SmartMaskIoMgr::fill_prefill_toks( - int64_t start_pos, + int64_t num_prev_tokens, + int64_t prompt_pos, std::vector& prompt_tokens) { IO* ptr = static_cast(get_mutable_ptr()); for (int i = 0; i < prefill_ar_len_; i++) { if (!is_bert_) { - ptr->prefill_input_pos[i] = start_pos + i; + ptr->prefill_input_pos[i] = prompt_pos + i; } - if (start_pos + i < prompt_tokens.size()) { + if (prompt_pos + i < prompt_tokens.size()) { // Support CPU 4-bit embedding, which requires int64 input. // However, for QNN embedding, only int32 input is needed. // Therefore, we need to cast to the correct type to write the data. 
if (use_int64_token_) { - ptr->prefill_input_toks[i] = prompt_tokens[start_pos + i]; + ptr->prefill_input_toks[i] = prompt_tokens[prompt_pos + i]; } else { int32_t* prefill_input_toks_ptr = reinterpret_cast(ptr->prefill_input_toks); prefill_input_toks_ptr[i] = - static_cast(prompt_tokens[start_pos + i]); + static_cast(prompt_tokens[prompt_pos + i]); } } - if (start_pos >= prefill_ar_len_) { - for (int j = 0, offset = i * context_len_ + (start_pos - prefill_ar_len_); + if (prompt_pos >= prefill_ar_len_) { + for (int j = 0, offset = i * context_len_ + (prompt_pos - prefill_ar_len_); j < prefill_ar_len_; ++j) { ptr->prefill_attention_mask[offset + j] = 65535; diff --git a/examples/qualcomm/oss_scripts/llama/runner/io_manager.h b/examples/qualcomm/oss_scripts/llama/runner/io_manager.h index 03808ede3bf..defbd10a70a 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/io_manager.h +++ b/examples/qualcomm/oss_scripts/llama/runner/io_manager.h @@ -48,9 +48,13 @@ class IoMgrBase { executorch::runtime::Result>& methods_meta) = 0; virtual void fill_prefill_toks( - int64_t start_pos, + int64_t num_prev_tokens, + int64_t prompt_pos, std::vector& prompt_tokens) = 0; virtual void fill_kv_tok_mask(int64_t pos, int64_t cur_token) = 0; + virtual void update_kv_to_prefill_io( + int64_t pos, + std::vector>& output_tensors) = 0; virtual void update_prefill_to_kv_io( int64_t cur_token, int64_t pos, @@ -118,9 +122,13 @@ class ShiftPointerIoMgr : public IoMgrBase { executorch::runtime::Result>& methods_meta) override; void fill_prefill_toks( - int64_t start_pos, + int64_t num_prev_tokens, + int64_t prompt_pos, std::vector& prompt_tokens) override; void fill_kv_tok_mask(int64_t pos, int64_t cur_token) override; + void update_kv_to_prefill_io( + int64_t pos, + std::vector>& output_tensors) override; void update_prefill_to_kv_io( int64_t cur_token, int64_t pos, @@ -190,6 +198,8 @@ class ShiftPointerIoMgr : public IoMgrBase { std::string kv_forward_name_; const bool 
use_int64_token_{false}; const bool is_bert_{false}; + + int64_t last_pos_{0}; }; class SmartMaskIoMgr : public IoMgrBase { @@ -226,9 +236,13 @@ class SmartMaskIoMgr : public IoMgrBase { executorch::runtime::Result>& methods_meta) override; void fill_prefill_toks( - int64_t start_pos, + int64_t num_prev_tokens, + int64_t prompt_pos, std::vector& prompt_tokens) override; void fill_kv_tok_mask(int64_t pos, int64_t cur_token) override; + void update_kv_to_prefill_io( + int64_t pos, + std::vector>& output_tensors) override; void update_prefill_to_kv_io( int64_t cur_token, int64_t pos, diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index dafc911a172..28d539681c3 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -276,9 +276,10 @@ void Runner::run_model_step( Error Runner::generate( int32_t seq_len, + int32_t num_prev_tokens, const std::string& prompt, const std::string& system_prompt, - std::function token_callback, + std::function token_callback, std::function stats_callback) { std::unordered_map>> input_tensors, output_tensors; @@ -327,7 +328,7 @@ Error Runner::generate( prompt_.append( "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); if (token_callback) { - token_callback("<|begin_of_text|>"); + token_callback("<|begin_of_text|>", 0); } break; default: @@ -335,6 +336,8 @@ Error Runner::generate( break; } + ET_LOG(Info, "Number of Previous Tokens Prefill + Decode: %d", num_prev_tokens); + seq_len = (seq_len > 0 && seq_len <= context_len_) ? 
seq_len : context_len_; tokenizers::Result> encode_res = tokenizer_->encode(prompt_, n_bos_, 0); @@ -349,7 +352,7 @@ Error Runner::generate( int64_t pos = 0, prev_token, cur_token = prompt_tokens[0]; if (token_callback) { - token_callback(prompt_); + token_callback(prompt_, num_prompt_tokens); } auto prefill_execute = [&](const std::string& method_name) { int num_iters = 1 + ((num_prompt_tokens - 1) / prefill_ar_len_); @@ -361,7 +364,7 @@ Error Runner::generate( num_iters); for (int i = 0; i < num_iters; i++) { - io_mgr_->fill_prefill_toks(pos, prompt_tokens); + io_mgr_->fill_prefill_toks(num_prev_tokens, pos, prompt_tokens); run_model_step(method_name, inputs[method_name]); io_mgr_->update_prefill_io(cur_token, pos, output_tensors[method_name]); pos += prefill_ar_len_; @@ -377,10 +380,12 @@ Error Runner::generate( auto piece_res = tokenizer_->decode(prev_token, cur_token); ET_CHECK(piece_res.ok()); if (token_callback) { - token_callback(piece_res.get().c_str()); + ET_LOG(Info, "Prefill: %s", piece_res.get().c_str()); + token_callback(piece_res.get().c_str(), 1); } - pos = num_prompt_tokens; + pos = num_prev_tokens + num_prompt_tokens; + ET_LOG(Info, "Pos: %ld, Prompt Tokens: %ld", pos, num_prompt_tokens); stats_.first_token_ms = time_in_ms(); stats_.prompt_eval_end_ms = time_in_ms(); }; @@ -394,9 +399,9 @@ Error Runner::generate( // hybrid mode will check these stats_ at prefill(prefill) if (eval_mode_ == EvalMode::kKVCached) { - if (pos == num_prompt_tokens) { + if (pos == num_prev_tokens + num_prompt_tokens) { stats_.first_token_ms = time_in_ms(); - } else if (pos == num_prompt_tokens - 1) { + } else if (pos == num_prev_tokens + num_prompt_tokens - 1) { stats_.prompt_eval_end_ms = time_in_ms(); } } @@ -405,7 +410,7 @@ Error Runner::generate( cur_token = logitsToToken(logits_tensor, pos); stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; - if (pos < num_prompt_tokens - 1) { + if (pos < num_prev_tokens + num_prompt_tokens - 1) { cur_token = 
prompt_tokens[pos + 1]; } io_mgr_->update_kv_io(cur_token, ++pos, output_tensors[method_name]); @@ -413,7 +418,7 @@ Error Runner::generate( ET_CHECK(piece_res.ok()); if (token_callback && pos >= num_prompt_tokens) { - token_callback(piece_res.get().c_str()); + token_callback(piece_res.get().c_str(), 1); } if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) { @@ -432,6 +437,7 @@ Error Runner::generate( io_mgr_->update_prefill_to_kv_io( cur_token, pos, output_tensors[kv_forward_name_]); kv_execute(kv_forward_name_); + io_mgr_->update_kv_to_prefill_io(pos, output_tensors[prefill_forward_name_]); break; default: ET_CHECK_MSG(false, "Unsupported eval mode"); @@ -448,9 +454,9 @@ Error Runner::generate( if (stats_callback) { stats_callback(stats_); } - io_mgr_->reset_io( - get_methods_meta(prefill_forward_name_), - get_methods_meta(kv_forward_name_)); + // io_mgr_->reset_io( + // get_methods_meta(prefill_forward_name_), + // get_methods_meta(kv_forward_name_)); prompt_.clear(); return Error::Ok; } diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index e693bcd7077..79c92957bf6 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -67,9 +67,10 @@ class Runner { executorch::runtime::Error load(); executorch::runtime::Error generate( int32_t seq_len, + int32_t num_prev_tokens, const std::string& prompt, const std::string& system_prompt, - std::function token_callback = {}, + std::function token_callback = {}, std::function stats_callback = {}); void stop(); std::vector>