
[Proposal] Support Multiple Prefill + Decode in a loop #9466

Closed · wants to merge 1 commit
7 changes: 6 additions & 1 deletion examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
@@ -76,7 +76,11 @@ int main(int argc, char** argv) {
std::vector<char> buf;
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
std::ofstream fout(FLAGS_output_path.c_str());
auto callback = [&](const std::string& piece) {

int32_t num_total_tokens = 0;

auto callback = [&](const std::string& piece, int32_t tokens_generated) {
num_total_tokens += tokens_generated;
for (const char c : piece) {
buf.push_back(c);
}
@@ -85,6 +89,7 @@ int main(int argc, char** argv) {
for (int i = 0; i < FLAGS_num_iters; i++) {
runner.generate(
FLAGS_seq_len,
num_total_tokens,
FLAGS_prompt.c_str(),
FLAGS_system_prompt.c_str(),
callback);
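Read together, the two hunks above let the driver accumulate a running token count and feed it back into each call; stitched into one piece, the updated loop looks roughly like this (a sketch assembled from the diff fragments, not verbatim merged code):

```cpp
// Sketch of the updated driver loop in qnn_llama_runner.cpp after this PR.
int32_t num_total_tokens = 0;

auto callback = [&](const std::string& piece, int32_t tokens_generated) {
  // Track how many tokens all previous rounds produced, so the next
  // generate() call can resume at the correct KV-cache position.
  num_total_tokens += tokens_generated;
  for (const char c : piece) {
    buf.push_back(c);
  }
};

for (int i = 0; i < FLAGS_num_iters; i++) {
  runner.generate(
      FLAGS_seq_len,
      num_total_tokens,  // tokens already held in the cache from prior rounds
      FLAGS_prompt.c_str(),
      FLAGS_system_prompt.c_str(),
      callback);
}
```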
88 changes: 70 additions & 18 deletions examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp
@@ -494,10 +494,56 @@ void ShiftPointerIoMgr::prepare_prefill_io(
}
}

void ShiftPointerIoMgr::update_kv_to_prefill_io(
int64_t pos,
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) {
// update v_cache
assert(pos <= 512);

ET_LOG(Info, "update kv to prefill io, pos: %ld, last prefill pos: %ld", pos, last_pos_);

int64_t pos_diff = pos - last_pos_;
std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& v_cache_in =
v_cache_in_[prefill_forward_name_];
for (int i = 0, v_cache_stride = head_dim_ * pos_diff; i < v_cache_in.size();
i++) {
v_cache_in[i]->set_data(
v_cache_in[i]->mutable_data<uint8_t>() + v_cache_stride);
Review comment (Collaborator): v_cache_out needs to be updated as well, please refer to the resolved comment, thank you.

}

// update k_cache
std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& k_cache_in =
k_cache_in_[prefill_forward_name_];

size_t copied_size = pos_diff * sizeof(uint8_t);
Review comment (Collaborator): copied_size should be pos * sizeof(uint8_t);


for (int i = 0, k_cache_stride = pos_diff * sizeof(uint8_t); i < k_cache_in.size();
i++) {
k_cache_in[i]->set_data(
k_cache_in[i]->mutable_data<uint8_t>() + k_cache_stride);
uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>() - pos_diff;
Review comment (Collaborator): Here we need to get the origin for deep copy: uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>() - pos;

for (int j = 0; j < head_dim_; ++j) {
Review comment (Collaborator): Sorry, this probably needs a small change: for (int j = 0; j <= head_dim_; ++j) { (I forgot we preserve extra space to prevent shifting beyond the boundary.)

memcpy(
ptr_in + j * prefill_cache_len_,
ptr_in + j * kv_cache_len_,
copied_size);
}
}

// Set the attention mask from context_len_ - prefill_ar_len_ - i to context_len_
IO* ptr = static_cast<IO*>(data_ptr_.get());
for (int i = prefill_ar_len_; i < pos; i++) {
for (int j = 0; j < prefill_ar_len_; j++) {
ptr->prefill_attention_mask[j * context_len_ + context_len_ - prefill_ar_len_ - i] = 65535;
}
}
}

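Folding the three review comments above into the function gives the following corrected k_cache block (a consolidated sketch assuming all three suggestions are adopted as written; because set_data advances the pointer by pos_diff on every round, the cumulative offset from the origin is pos):

```cpp
// k_cache update with the reviewers' fixes applied (sketch, not merged code).
size_t copied_size = pos * sizeof(uint8_t);  // fix: copy all pos cached entries
for (int i = 0, k_cache_stride = pos_diff * sizeof(uint8_t);
     i < k_cache_in.size(); i++) {
  // Advance the shift pointer by this round's delta, as before.
  k_cache_in[i]->set_data(
      k_cache_in[i]->mutable_data<uint8_t>() + k_cache_stride);
  // Fix: rewind by the absolute position pos to recover the cache origin,
  // since the pointer has been advanced by pos in total across rounds.
  uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>() - pos;
  // Fix: j <= head_dim_, because one extra row is reserved as a guard
  // against shifting beyond the buffer boundary.
  for (int j = 0; j <= head_dim_; ++j) {
    memcpy(
        ptr_in + j * prefill_cache_len_,
        ptr_in + j * kv_cache_len_,
        copied_size);
  }
}
```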
void ShiftPointerIoMgr::update_prefill_to_kv_io(
int64_t cur_token,
int64_t pos,
std::vector<std::vector<Tensor>>& output_tensors) {
last_pos_ = pos;
ET_CHECK_MSG(kv_cache_len_ != 0, "kv_cache_len_ should not be 0");
IO* ptr = static_cast<IO*>(data_ptr_.get());

@@ -664,33 +710,32 @@ void ShiftPointerIoMgr::update_prefill_io(
}

void ShiftPointerIoMgr::fill_prefill_toks(
int64_t start_pos,
int64_t num_prev_tokens,
int64_t prompt_pos,
std::vector<uint64_t>& prompt_tokens) {
IO* ptr = static_cast<IO*>(get_mutable_ptr());
for (int i = 0; i < prefill_ar_len_; i++) {
if (!is_bert_) {
ptr->prefill_input_pos[i] = start_pos + i;
ptr->prefill_input_pos[i] = num_prev_tokens + prompt_pos + i;
}

if (start_pos + i < prompt_tokens.size()) {
if (prompt_pos + i < prompt_tokens.size()) {
// Support CPU 4-bit embedding, which requires int64 input.
// However, for QNN embedding, only int32 input is needed.
// Therefore, we need to cast to the correct type to write the data.
if (use_int64_token_) {
ptr->prefill_input_toks[i] = prompt_tokens[start_pos + i];
ptr->prefill_input_toks[i] = prompt_tokens[prompt_pos + i];
} else {
int32_t* prefill_input_toks_ptr =
reinterpret_cast<int32_t*>(ptr->prefill_input_toks.data());
prefill_input_toks_ptr[i] =
static_cast<int32_t>(prompt_tokens[start_pos + i]);
static_cast<int32_t>(prompt_tokens[prompt_pos + i]);
}
}
if (start_pos >= prefill_ar_len_) {
for (int j = 0,
offset = i * context_len_ +
(context_len_ - prefill_ar_len_ - start_pos);
j < prefill_ar_len_;
++j) {
if (num_prev_tokens + prompt_pos >= prefill_ar_len_) {
int64_t start_offset = i * context_len_ +
(context_len_ - num_prev_tokens - prompt_pos - prefill_ar_len_);
for (int j = 0, offset = start_offset; j < prefill_ar_len_; ++j) {
ptr->prefill_attention_mask[offset + j] = 65535;
}
}
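The only change to the mask logic here is that the unmasked window shifts left by num_prev_tokens as well as prompt_pos. A tiny standalone check of the offset arithmetic (all values below are assumptions for illustration, not taken from the PR):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed example values; the real ones come from the exported model.
  const int64_t context_len = 512;
  const int64_t prefill_ar_len = 32;   // tokens processed per prefill step
  const int64_t num_prev_tokens = 64;  // tokens cached by earlier rounds
  const int64_t prompt_pos = 0;        // offset into the current prompt chunk

  // Start column of the unmasked window, per the new start_offset formula.
  const int64_t col =
      context_len - num_prev_tokens - prompt_pos - prefill_ar_len;
  std::printf("unmask columns [%lld, %lld) in each of the %lld AR rows\n",
              static_cast<long long>(col),
              static_cast<long long>(col + prefill_ar_len),
              static_cast<long long>(prefill_ar_len));
  return 0;
}
```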
@@ -1305,6 +1350,12 @@ void SmartMaskIoMgr::prepare_prefill_io(
}
}

void SmartMaskIoMgr::update_kv_to_prefill_io(
int64_t pos,
std::vector<std::vector<Tensor>>& output_tensors) {
// TODO: Fill in
}

void SmartMaskIoMgr::update_prefill_to_kv_io(
int64_t cur_token,
int64_t pos,
@@ -1396,29 +1447,30 @@ void SmartMaskIoMgr::update_prefill_io(
}

void SmartMaskIoMgr::fill_prefill_toks(
int64_t start_pos,
int64_t num_prev_tokens,
int64_t prompt_pos,
std::vector<uint64_t>& prompt_tokens) {
IO* ptr = static_cast<IO*>(get_mutable_ptr());
for (int i = 0; i < prefill_ar_len_; i++) {
if (!is_bert_) {
ptr->prefill_input_pos[i] = start_pos + i;
ptr->prefill_input_pos[i] = prompt_pos + i;
}

if (start_pos + i < prompt_tokens.size()) {
if (prompt_pos + i < prompt_tokens.size()) {
// Support CPU 4-bit embedding, which requires int64 input.
// However, for QNN embedding, only int32 input is needed.
// Therefore, we need to cast to the correct type to write the data.
if (use_int64_token_) {
ptr->prefill_input_toks[i] = prompt_tokens[start_pos + i];
ptr->prefill_input_toks[i] = prompt_tokens[prompt_pos + i];
} else {
int32_t* prefill_input_toks_ptr =
reinterpret_cast<int32_t*>(ptr->prefill_input_toks);
prefill_input_toks_ptr[i] =
static_cast<int32_t>(prompt_tokens[start_pos + i]);
static_cast<int32_t>(prompt_tokens[prompt_pos + i]);
}
}
if (start_pos >= prefill_ar_len_) {
for (int j = 0, offset = i * context_len_ + (start_pos - prefill_ar_len_);
if (prompt_pos >= prefill_ar_len_) {
for (int j = 0, offset = i * context_len_ + (prompt_pos - prefill_ar_len_);
j < prefill_ar_len_;
++j) {
ptr->prefill_attention_mask[offset + j] = 65535;
20 changes: 17 additions & 3 deletions examples/qualcomm/oss_scripts/llama/runner/io_manager.h
@@ -48,9 +48,13 @@ class IoMgrBase {
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
methods_meta) = 0;
virtual void fill_prefill_toks(
int64_t start_pos,
int64_t num_prev_tokens,
int64_t prompt_pos,
std::vector<uint64_t>& prompt_tokens) = 0;
virtual void fill_kv_tok_mask(int64_t pos, int64_t cur_token) = 0;
virtual void update_kv_to_prefill_io(
int64_t pos,
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) = 0;
virtual void update_prefill_to_kv_io(
int64_t cur_token,
int64_t pos,
@@ -118,9 +122,13 @@ class ShiftPointerIoMgr : public IoMgrBase {
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
methods_meta) override;
void fill_prefill_toks(
int64_t start_pos,
int64_t num_prev_tokens,
int64_t prompt_pos,
std::vector<uint64_t>& prompt_tokens) override;
void fill_kv_tok_mask(int64_t pos, int64_t cur_token) override;
void update_kv_to_prefill_io(
int64_t pos,
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) override;
void update_prefill_to_kv_io(
int64_t cur_token,
int64_t pos,
@@ -190,6 +198,8 @@ class ShiftPointerIoMgr : public IoMgrBase {
std::string kv_forward_name_;
const bool use_int64_token_{false};
const bool is_bert_{false};

int64_t last_pos_{0};
};

class SmartMaskIoMgr : public IoMgrBase {
@@ -226,9 +236,13 @@ class SmartMaskIoMgr : public IoMgrBase {
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
methods_meta) override;
void fill_prefill_toks(
int64_t start_pos,
int64_t num_prev_tokens,
int64_t prompt_pos,
std::vector<uint64_t>& prompt_tokens) override;
void fill_kv_tok_mask(int64_t pos, int64_t cur_token) override;
void update_kv_to_prefill_io(
int64_t pos,
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) override;
void update_prefill_to_kv_io(
int64_t cur_token,
int64_t pos,
32 changes: 19 additions & 13 deletions examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -276,9 +276,10 @@ void Runner::run_model_step(

Error Runner::generate(
int32_t seq_len,
int32_t num_prev_tokens,
const std::string& prompt,
const std::string& system_prompt,
std::function<void(const std::string&)> token_callback,
std::function<void(const std::string&, int32_t)> token_callback,
std::function<void(const Stats&)> stats_callback) {
std::unordered_map<std::string, std::vector<std::vector<Tensor>>>
input_tensors, output_tensors;
@@ -327,14 +328,16 @@ Error Runner::generate(
prompt_.append(
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
if (token_callback) {
token_callback("<|begin_of_text|>");
token_callback("<|begin_of_text|>", 0);
}
break;
default:
ET_CHECK_MSG(false, "unsupported llama version");
break;
}

ET_LOG(Info, "Number of Previous Tokens Prefill + Decode: %d", num_prev_tokens);

seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
tokenizers::Result<std::vector<uint64_t>> encode_res =
tokenizer_->encode(prompt_, n_bos_, 0);
@@ -349,7 +352,7 @@

int64_t pos = 0, prev_token, cur_token = prompt_tokens[0];
if (token_callback) {
token_callback(prompt_);
token_callback(prompt_, num_prompt_tokens);
}
auto prefill_execute = [&](const std::string& method_name) {
int num_iters = 1 + ((num_prompt_tokens - 1) / prefill_ar_len_);
@@ -361,7 +364,7 @@
num_iters);

for (int i = 0; i < num_iters; i++) {
io_mgr_->fill_prefill_toks(pos, prompt_tokens);
io_mgr_->fill_prefill_toks(num_prev_tokens, pos, prompt_tokens);
run_model_step(method_name, inputs[method_name]);
io_mgr_->update_prefill_io(cur_token, pos, output_tensors[method_name]);
pos += prefill_ar_len_;
@@ -377,10 +380,12 @@
auto piece_res = tokenizer_->decode(prev_token, cur_token);
ET_CHECK(piece_res.ok());
if (token_callback) {
token_callback(piece_res.get().c_str());
ET_LOG(Info, "Prefill: %s", piece_res.get().c_str());
token_callback(piece_res.get().c_str(), 1);
}

pos = num_prompt_tokens;
pos = num_prev_tokens + num_prompt_tokens;
ET_LOG(Info, "Pos: %ld, Prompt Tokens: %ld", pos, num_prompt_tokens);
stats_.first_token_ms = time_in_ms();
stats_.prompt_eval_end_ms = time_in_ms();
};
@@ -394,9 +399,9 @@

// hybrid mode will check these stats_ at prefill(prefill)
if (eval_mode_ == EvalMode::kKVCached) {
if (pos == num_prompt_tokens) {
if (pos == num_prev_tokens + num_prompt_tokens) {
stats_.first_token_ms = time_in_ms();
} else if (pos == num_prompt_tokens - 1) {
} else if (pos == num_prev_tokens + num_prompt_tokens - 1) {
stats_.prompt_eval_end_ms = time_in_ms();
}
}
@@ -405,15 +410,15 @@
cur_token = logitsToToken(logits_tensor, pos);
stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms;

if (pos < num_prompt_tokens - 1) {
if (pos < num_prev_tokens + num_prompt_tokens - 1) {
cur_token = prompt_tokens[pos + 1];
}
io_mgr_->update_kv_io(cur_token, ++pos, output_tensors[method_name]);
auto piece_res = tokenizer_->decode(prev_token, cur_token);
ET_CHECK(piece_res.ok());

if (token_callback && pos >= num_prompt_tokens) {
token_callback(piece_res.get().c_str());
token_callback(piece_res.get().c_str(), 1);
}

if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) {
@@ -432,6 +437,7 @@
io_mgr_->update_prefill_to_kv_io(
cur_token, pos, output_tensors[kv_forward_name_]);
kv_execute(kv_forward_name_);
io_mgr_->update_kv_to_prefill_io(pos, output_tensors[prefill_forward_name_]);
break;
default:
ET_CHECK_MSG(false, "Unsupported eval mode");
@@ -448,9 +454,9 @@
if (stats_callback) {
stats_callback(stats_);
}
io_mgr_->reset_io(
get_methods_meta(prefill_forward_name_),
get_methods_meta(kv_forward_name_));
// io_mgr_->reset_io(
// get_methods_meta(prefill_forward_name_),
// get_methods_meta(kv_forward_name_));
prompt_.clear();
return Error::Ok;
}
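In hybrid mode the round now ends by switching the IO layout back to prefill after decoding finishes, so a later generate() call can prefill on top of the populated cache; this is also why reset_io is commented out above. A condensed view of the control flow, paraphrasing the switch case rather than quoting it:

```cpp
// EvalMode::kHybrid round, condensed from Runner::generate() in this PR.
prefill_execute(prefill_forward_name_);   // chunked prefill of the new prompt
io_mgr_->update_prefill_to_kv_io(
    cur_token, pos, output_tensors[kv_forward_name_]);   // prefill -> decode
kv_execute(kv_forward_name_);             // autoregressive decode until EOS
io_mgr_->update_kv_to_prefill_io(
    pos, output_tensors[prefill_forward_name_]);         // decode -> prefill
// No reset_io here: the KV cache must survive into the next generate() call.
```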
3 changes: 2 additions & 1 deletion examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -67,9 +67,10 @@ class Runner {
executorch::runtime::Error load();
executorch::runtime::Error generate(
int32_t seq_len,
int32_t num_prev_tokens,
const std::string& prompt,
const std::string& system_prompt,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const std::string&, int32_t)> token_callback = {},
std::function<void(const Stats&)> stats_callback = {});
void stop();
std::vector<executorch::runtime::Result<executorch::runtime::MethodMeta>>