This repository was archived by the owner on Sep 17, 2024. It is now read-only.

Prevent instruction prompt and response from leaking into output #6

Open · wants to merge 3 commits into master
chat.cpp: 71 changes (65 additions, 6 deletions)
@@ -10,6 +10,7 @@
#include <map>
#include <string>
#include <vector>
#include <deque>

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
@@ -320,7 +321,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
    fin.close();

    std::vector<uint8_t> tmp;

    for (int i = 0; i < n_parts; ++i) {
        const int part_id = i;
        //const int part_id = n_parts - i - 1;
@@ -825,7 +826,7 @@ int main(int argc, char ** argv) {
    // load the model
    {
        const int64_t t_start_us = ggml_time_us();
        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
            return 1;
        }
@@ -946,7 +947,20 @@ int main(int argc, char ** argv) {
        printf(ANSI_COLOR_YELLOW);
    }

    // drop the first token from prompt_inp, and the last 2 tokens (line breaks)
    // from both prompt_inp and response_inp
    std::vector<gpt_vocab::id> prompt_inp_sequence(prompt_inp.begin() + 1, prompt_inp.end() - 2);
    std::vector<gpt_vocab::id> response_inp_sequence(response_inp.begin(), response_inp.end() - 2);

    // token sequences to look for in the output
    std::vector<std::vector<gpt_vocab::id>> instruct_sequences = {
        prompt_inp_sequence,
        response_inp_sequence,
    };

    // track how far into each target sequence the recent output has matched
    std::vector<int> instruct_indices(instruct_sequences.size(), 0);
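
    // the scan in the print loop below is a running prefix check: output tokens
    // are withheld only while they can still extend a match of one of these
    // sequences, and a completed match stops printing and returns to the user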

    while (remaining_tokens > 0) {
        // predict
@@ -1016,10 +1030,51 @@ int main(int argc, char ** argv) {

        // display text
        if (!input_noecho) {
            std::deque<std::string> output_buffer;
            bool sequence_found = false;

            size_t i = 0;
            while (i < embd.size() && !sequence_found) {
                gpt_vocab::id id = embd[i];
                bool handled = false;

                size_t j = 0;
                while (j < instruct_sequences.size() && !handled) {
                    if (id == instruct_sequences[j][instruct_indices[j]]) {
                        instruct_indices[j]++;
                        handled = true;
                        if (instruct_indices[j] == (int) instruct_sequences[j].size()) {
                            // the full instruct_sequence appeared in the output:
                            // stop printing and hand control back to the user
                            sequence_found = true;
                            is_interacting = true;
                        }
                    } else if (instruct_indices[j] > 0) {
                        // partial match broken: re-emit the token that started it
                        // and re-scan from the token after it (the i++ below);
                        // if the match began in a previous batch, just reset
                        if (i >= (size_t) instruct_indices[j]) {
                            i -= instruct_indices[j];
                            output_buffer.push_back(vocab.id_to_token[embd[i]]);
                        }
                        instruct_indices[j] = 0;
                        handled = true;
                    } else {
                        j++; // no match in progress for this sequence; try the next
                    }
                }

                if (!handled) {
                    output_buffer.push_back(vocab.id_to_token[id]);
                }
                i++;
            }

            // flush everything that was not part of an instruction sequence
            while (!output_buffer.empty()) {
                printf("%s", output_buffer.front().c_str());
                output_buffer.pop_front();
            }
            fflush(stdout);
        }

        // in interactive mode, and not currently processing queued inputs;
@@ -1035,7 +1090,11 @@ int main(int argc, char ** argv) {
            // embd_inp.erase(embd_inp.begin());
            input_consumed = embd_inp.size();
            embd_inp.insert(embd_inp.end(), prompt_inp.begin(), prompt_inp.end());

            // reset the sequence-match progress before reading the next user input
            for (size_t j = 0; j < instruct_indices.size(); ++j) {
                instruct_indices[j] = 0;
            }

            printf("\n> ");

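
The core of the change is the streaming scan above: suppress a known token subsequence while printing everything else. Below is a minimal, self-contained sketch of that idea, using plain ints as stand-in token ids; the names (`suppress_sequence`, `token_id`) and the single-target simplification are illustrative assumptions, not code from this PR.

```cpp
#include <cstdio>
#include <deque>
#include <vector>

using token_id = int; // stand-in for gpt_vocab::id

// Scan `stream` for `target`: buffer the tokens that should be printed, and
// stop at (and swallow) the first complete occurrence of `target`.
std::deque<token_id> suppress_sequence(const std::vector<token_id> & stream,
                                       const std::vector<token_id> & target,
                                       bool & found) {
    std::deque<token_id> out;
    size_t matched = 0; // how many leading tokens of `target` are matched so far
    found = false;

    for (size_t i = 0; i < stream.size(); ++i) {
        if (stream[i] == target[matched]) {
            if (++matched == target.size()) { // full match: stop emitting
                found = true;
                break;
            }
        } else if (matched > 0) {
            // partial match broken: emit the token that started it and
            // re-scan from the token after it (the loop's i++ does the step)
            i -= matched;
            out.push_back(stream[i]);
            matched = 0;
        } else {
            out.push_back(stream[i]);
        }
    }
    return out;
}

int main() {
    // the stream contains a decoy prefix (7 8) before the real target (7 8 9)
    const std::vector<token_id> stream = {1, 2, 7, 8, 3, 7, 8, 9, 4, 5};
    const std::vector<token_id> target = {7, 8, 9};

    bool found = false;
    for (token_id id : suppress_sequence(stream, target, found)) {
        printf("%d ", id); // prints: 1 2 7 8 3
    }
    printf("\n%s\n", found ? "target found: hand control back to the user"
                           : "no full match in this batch");
    return 0;
}
```

The rollback step (`i -= matched`) is what keeps a decoy prefix, like the first `7 8` in the example, from being silently swallowed; the in-place version in the diff tracks the same counter per target in `instruct_indices`.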