
decode to prefill #9662


Draft · wants to merge 1 commit into base: main
105 changes: 103 additions & 2 deletions examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp
@@ -82,7 +82,13 @@ ShiftPointerIoMgr::ShiftPointerIoMgr(
prefill_forward_name_(prefill_forward_name),
kv_forward_name_(kv_forward_name),
use_int64_token_(use_int64_token),
is_bert_(prefill_cache_len_ == 0) {
is_bert_(prefill_cache_len_ == 0),
current_stage_(Stage::kPrefill),
current_pos_(0) {

ET_LOG(Info, "ShiftPointerIoMgr is created");
ET_LOG(
    Info,
    "  kv_ar_len: %d, kv_cache_len: %d, prefill_ar_len: %d, "
    "prefill_cache_len: %d, vocab_size: %d, num_layers: %d, head_dim: %d, "
    "num_heads: %d, eval_mode: %d, prefill_forward_name: %s, "
    "kv_forward_name: %s",
    kv_ar_len_,
    kv_cache_len_,
    prefill_ar_len_,
    prefill_cache_len_,
    vocab_size,
    num_layers,
    head_dim,
    num_heads,
    eval_mode,
    prefill_forward_name.c_str(),
    kv_forward_name.c_str());

if (!prefill_forward_name_.empty()) {
input_tensors_[prefill_forward_name_] =
std::vector<std::vector<executorch::aten::TensorImpl*>>(modules.size());
@@ -498,6 +504,9 @@ void ShiftPointerIoMgr::update_prefill_to_kv_io(
int64_t cur_token,
int64_t pos,
std::vector<std::vector<Tensor>>& output_tensors) {

ET_CHECK_MSG(current_stage_ == Stage::kPrefill,
"Transition only allowed from prefill stage");
ET_CHECK_MSG(kv_cache_len_ != 0, "kv_cache_len_ should not be 0");
IO* ptr = static_cast<IO*>(data_ptr_.get());

@@ -562,12 +571,88 @@ void ShiftPointerIoMgr::update_prefill_to_kv_io(
}
k_cache_in[i]->set_data(ptr_in + pos);
}

current_stage_ = Stage::kDecode;
current_pos_ = pos; // Track global position
}

void ShiftPointerIoMgr::update_kv_to_prefill_io(
int64_t pos,
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) {
ET_LOG(Info, "update_kv_to_prefill_io, current pos is %ld", pos);
IO* ptr = static_cast<IO*>(data_ptr_.get());

ET_CHECK_MSG(current_stage_ == Stage::kDecode,
"Transition only allowed from decode stage");

// synchronize kv cache pointers with decode stage positions
// v cache sync
auto& decode_v_cache_in = v_cache_in_[kv_forward_name_];
auto& prefill_v_cache_in = v_cache_in_[prefill_forward_name_];
for (int i = 0; i < prefill_v_cache_in.size(); i++) {
uint8_t* decode_v_ptr = decode_v_cache_in[i]->mutable_data<uint8_t>();
// reuse decode's position
prefill_v_cache_in[i]->set_data(decode_v_ptr);
}

// k cache sync
auto& decode_k_cache_in = k_cache_in_[kv_forward_name_];
auto& prefill_k_cache_in = k_cache_in_[prefill_forward_name_];
for (int i = 0; i < prefill_k_cache_in.size(); i++) {
uint8_t* decode_k_ptr = decode_k_cache_in[i]->mutable_data<uint8_t>();
// reuse decode's position
prefill_k_cache_in[i]->set_data(decode_k_ptr);
}

// update prefill input positions to continue from current pos
for (int i = 0; i < prefill_ar_len_; i++) {
// continue sequence
ptr->prefill_input_pos[i] = pos + i;
ET_LOG(Debug, "prefill_input_pos[%d] = %d", i, ptr->prefill_input_pos[i]);
}

// update attention mask for continued sequence
std::fill(ptr->prefill_attention_mask.begin(),
ptr->prefill_attention_mask.end(), 0);

// build mask accounting for previous tokens (pos offset)
for (int i = 0; i < prefill_ar_len_; ++i) {
const int global_step = pos + i;
for (int j = 0;
j <= global_step && j < context_len_;
++j) {
const int offset = i * context_len_ + (context_len_ - global_step - 1);
if (offset + j < ptr->prefill_attention_mask.size()) {
ptr->prefill_attention_mask[offset + j] = 65535;
}
}
}

if (!output_tensors.empty()) {
for (int shard = 0; shard < output_tensors.size(); shard++) {
for (int index = 0; index < output_tensors[shard].size(); index++) {
ET_CHECK_MSG(
modules_[shard]->set_output(
prefill_forward_name_,
output_tensors[shard][index],
index) == Error::Ok,
"Failed to set output tensor during transition");
}
}
}

// update system state
current_stage_ = Stage::kPrefill;
current_pos_ = pos; // Track global position

ET_LOG(Info, "Transitioned to prefill at position %ld. Cache preserved.", pos);
}

void ShiftPointerIoMgr::update_kv_io(
int64_t cur_token,
int64_t pos,
std::vector<std::vector<Tensor>>& output_tensors) {
ET_CHECK_MSG(current_stage_ == Stage::kDecode, "update_kv_io only allowed from decode stage");
IO* ptr = static_cast<IO*>(data_ptr_.get());
// update input_tok
ptr->kv_input_toks =
@@ -611,12 +696,16 @@ void ShiftPointerIoMgr::update_kv_io(
}
k_cache_in[i]->set_data(ptr_in + 1);
}
current_stage_ = Stage::kDecode;
current_pos_ = pos;
}

void ShiftPointerIoMgr::update_prefill_io(
int64_t cur_token,
int64_t pos,
std::vector<std::vector<Tensor>>& output_tensors) {
ET_CHECK_MSG(current_stage_ == Stage::kPrefill,
    "update_prefill_io only allowed during prefill stage");
(void)cur_token;
(void)output_tensors;

Expand Down Expand Up @@ -661,12 +750,16 @@ void ShiftPointerIoMgr::update_prefill_io(
k_cache_in[i]->set_data(ptr_in + prefill_ar_len_);
}
}
current_stage_ = Stage::kPrefill;
current_pos_ = pos;
}

void ShiftPointerIoMgr::fill_prefill_toks(
int64_t start_pos,
std::vector<uint64_t>& prompt_tokens) {
ET_CHECK_MSG(current_stage_ == Stage::kPrefill,
    "fill_prefill_toks only allowed during prefill stage");
IO* ptr = static_cast<IO*>(get_mutable_ptr());
for (int i = 0; i < prefill_ar_len_; i++) {
if (!is_bert_) {
ptr->prefill_input_pos[i] = start_pos + i;
@@ -698,6 +791,8 @@ }
}

void ShiftPointerIoMgr::fill_kv_tok_mask(int64_t pos, int64_t cur_token) {
ET_CHECK_MSG(current_stage_ == Stage::kDecode,
"fill_kv_tok_mask only allowed during decode stage");
IO* ptr = static_cast<IO*>(get_mutable_ptr());
ptr->kv_input_toks =
use_int64_token_ ? cur_token : static_cast<int32_t>(cur_token);
@@ -1360,6 +1455,12 @@ void SmartMaskIoMgr::update_prefill_to_kv_io(
}
}

void SmartMaskIoMgr::update_kv_to_prefill_io(
    int64_t pos,
    std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) {
  (void)pos;
  (void)output_tensors;
  ET_LOG(Info, "update_kv_to_prefill_io is not implemented for SmartMaskIoMgr");
}

void SmartMaskIoMgr::update_prefill_io(
int64_t cur_token,
int64_t pos,
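The heart of `update_kv_to_prefill_io` above is the mask rebuild: after decoding up to `pos`, the prefill graph must see a causal mask whose rows start attending from the global position rather than from zero. Below is a minimal standalone sketch of that indexing, assuming (as in the `IO` struct) a row-major `uint16_t` mask of shape `[prefill_ar_len, context_len]` where `65535` marks an attendable slot; the `rebuild_prefill_mask` name is illustrative, not part of the PR.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative re-derivation of the mask indexing in update_kv_to_prefill_io:
// row i holds the token at global position pos + i and may attend to the last
// (pos + i + 1) slots of its row, since the shift-pointer layout keeps valid
// history right-aligned.
void rebuild_prefill_mask(std::vector<uint16_t>& mask,
                          int ar_len, int context_len, int pos) {
  std::fill(mask.begin(), mask.end(), 0);
  for (int i = 0; i < ar_len; ++i) {
    const int global_step = pos + i;
    const int offset = i * context_len + (context_len - global_step - 1);
    for (int j = 0; j <= global_step && j < context_len; ++j) {
      if (offset + j < static_cast<int>(mask.size())) {
        mask[offset + j] = 65535;  // attendable; 0 means masked out
      }
    }
  }
}

int main() {
  const int ar_len = 4, context_len = 8, pos = 3;
  std::vector<uint16_t> mask(ar_len * context_len);
  rebuild_prefill_mask(mask, ar_len, context_len, pos);
  for (int i = 0; i < ar_len; ++i) {  // print the causal staircase
    for (int j = 0; j < context_len; ++j) {
      std::printf("%c", mask[i * context_len + j] ? '1' : '.');
    }
    std::printf("\n");
  }
  return 0;
}
```

For `ar_len = 4`, `context_len = 8`, `pos = 3`, each row i enables the rightmost `pos + i + 1` slots, one more per row, which is exactly the continuation-of-sequence behavior the PR's loop produces.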
16 changes: 16 additions & 0 deletions examples/qualcomm/oss_scripts/llama/runner/io_manager.h
@@ -22,6 +22,11 @@

namespace example {

enum class Stage {
kPrefill,
kDecode,
};

enum EvalMode {
kKVCached = 0,
kHybrid,
@@ -63,6 +68,9 @@ class IoMgrBase {
int64_t cur_token,
int64_t pos,
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) = 0;
virtual void update_kv_to_prefill_io(
int64_t pos,
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) = 0;
void* get_mutable_ptr();
std::vector<executorch::aten::Tensor> get_input_tensors(
int shard_index,
@@ -136,6 +144,9 @@ class ShiftPointerIoMgr : public IoMgrBase {
int64_t pos,
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors)
override;
void update_kv_to_prefill_io(
int64_t pos,
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) override;
struct IO {
int64_t kv_input_toks;
int32_t kv_input_pos;
@@ -190,6 +201,8 @@ class ShiftPointerIoMgr : public IoMgrBase {
std::string kv_forward_name_;
const bool use_int64_token_{false};
const bool is_bert_{false};
Stage current_stage_ = Stage::kPrefill; // Track current stage
int64_t current_pos_ = 0; // Track current position
};

class SmartMaskIoMgr : public IoMgrBase {
@@ -244,6 +257,9 @@ class SmartMaskIoMgr : public IoMgrBase {
int64_t pos,
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors)
override;
void update_kv_to_prefill_io(
int64_t pos,
std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) override;

std::unordered_map<std::string, size_t> get_io_elements();
std::unordered_map<std::string, size_t> get_io_bytes();
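The new `Stage` enum turns the previously implicit prefill/decode ordering into an explicit two-state machine, enforced by the `ET_CHECK_MSG` guards added in io_manager.cpp. A minimal sketch of the invariant those guards encode; the `required_stage` helper is hypothetical, written only to summarize the checks:

```cpp
#include <cassert>
#include <string>

enum class Stage { kPrefill, kDecode };

// Hypothetical summary of the ET_CHECK_MSG guards: each ShiftPointerIoMgr
// entry point requires a specific current stage.
Stage required_stage(const std::string& method) {
  if (method == "fill_prefill_toks" || method == "update_prefill_io" ||
      method == "update_prefill_to_kv_io") {
    return Stage::kPrefill;
  }
  // fill_kv_tok_mask, update_kv_io, update_kv_to_prefill_io
  return Stage::kDecode;
}

int main() {
  Stage current = Stage::kPrefill;  // the constructor's default
  assert(required_stage("fill_prefill_toks") == current);
  current = Stage::kDecode;         // after update_prefill_to_kv_io
  assert(required_stage("update_kv_io") == current);
  current = Stage::kPrefill;        // after update_kv_to_prefill_io
  assert(required_stage("update_prefill_io") == current);
  return 0;
}
```

`update_prefill_to_kv_io` and `update_kv_to_prefill_io` are the only transitions between the two states; the other methods advance positions without leaving their stage.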
8 changes: 8 additions & 0 deletions examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -266,12 +266,17 @@ int32_t Runner::logitsToToken(const Tensor& logits_tensor, int64_t pos) {
void Runner::run_model_step(
const std::string& method_name,
std::vector<std::vector<EValue>>& inputs) {
ET_LOG(Info, "Running model step %s: inputs len: %zu", method_name.c_str(), inputs.size());
for(size_t i = 0; i < inputs.size(); i++) {
ET_LOG(Info, " input[%zd] size: %zu", i, inputs[i].size());
}
for (size_t i = 0, num_modules = modules_.size(); i < num_modules; ++i) {
Result<std::vector<EValue>> outputs_res =
modules_[i]->execute(method_name, inputs[i]);
ET_CHECK_MSG(
outputs_res.error() == Error::Ok, "shard %zu inference failed", i);
}
ET_LOG(Info, "Finish running model step %s", method_name.c_str());
}

Error Runner::generate(
@@ -352,6 +357,7 @@ Error Runner::generate(
token_callback(prompt_);
}
auto prefill_execute = [&](const std::string& method_name) {
ET_LOG(Info, "Executing prefill step %s", method_name.c_str());
int num_iters = 1 + ((num_prompt_tokens - 1) / prefill_ar_len_);
ET_LOG(
Info,
@@ -386,6 +392,7 @@ Error Runner::generate(
};

auto kv_execute = [&](const std::string& method_name) {
ET_LOG(Info, "Executing kv step %s", method_name.c_str());
io_mgr_->fill_kv_tok_mask(pos, cur_token);
while (pos < seq_len - 1) {
// inference
@@ -432,6 +439,7 @@ io_mgr_->update_prefill_to_kv_io(
io_mgr_->update_prefill_to_kv_io(
cur_token, pos, output_tensors[kv_forward_name_]);
kv_execute(kv_forward_name_);
io_mgr_->update_kv_to_prefill_io(pos, output_tensors[prefill_forward_name_]);
break;
default:
ET_CHECK_MSG(false, "Unsupported eval mode");
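With `update_kv_to_prefill_io` wired into the hybrid path, the runner can now cycle prefill → decode → prefill without discarding the KV cache. A minimal runnable sketch of that loop, where `FakeIoMgr` is a hypothetical stand-in for the real `IoMgrBase` and the positions are made up (the real runner derives them from prompt and generation lengths):

```cpp
#include <cstdint>
#include <cstdio>

struct FakeIoMgr {
  enum class Stage { kPrefill, kDecode };
  Stage stage = Stage::kPrefill;
  int64_t pos = 0;
  void update_prefill_to_kv_io(int64_t p) { stage = Stage::kDecode;  pos = p; }
  void update_kv_to_prefill_io(int64_t p) { stage = Stage::kPrefill; pos = p; }
};

int main() {
  FakeIoMgr io;
  int64_t pos = 0;
  for (int turn = 0; turn < 2; ++turn) {
    pos += 16;                         // pretend a 16-token prompt chunk was prefilled
    io.update_prefill_to_kv_io(pos);   // switch IO pointers to the decode layout
    pos += 8;                          // pretend 8 tokens were decoded
    io.update_kv_to_prefill_io(pos);   // new in this PR: back to prefill, cache intact
    std::printf("turn %d finished at pos %lld\n", turn, static_cast<long long>(pos));
  }
  return 0;
}
```

The second iteration is the point of the change: before this PR, hybrid mode could only go prefill → decode once, because nothing restored the prefill IO bindings and position-aware mask afterwards.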