Skip to content

Commit ea14743

Browse files
committed
add compression for plain messages
1 parent edb97fe commit ea14743

File tree

1 file changed

+69
-7
lines changed

1 file changed

+69
-7
lines changed

refact-agent/engine/src/scratchpads/chat_utils_limit_history.rs

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,44 @@
1+
use std::cmp::min;
12
use crate::scratchpad_abstract::HasTokenizerAndEot;
2-
use crate::call_validation::ChatMessage;
3+
use crate::call_validation::{ChatContent, ChatMessage};
34
use std::collections::HashSet;
5+
use std::sync::{Arc, RwLock};
6+
use tokenizers::Tokenizer;
7+
use crate::scratchpads::multimodality::MultimodalElement;
48

9+
const MESSAGE_TOKEN_LIMIT: i32 = 12_000;
10+
11+
fn compress_string(text: &String, tokenizer: Arc<RwLock<Tokenizer>>) -> Result<String, String> {
12+
let tokenizer_lock = tokenizer.read().unwrap();
13+
let tokens = tokenizer_lock.encode(&**text, false).map_err(|e| e.to_string())?;
14+
let first_tokens = &tokens.get_ids()[0..(MESSAGE_TOKEN_LIMIT / 2) as usize];
15+
let last_tokens = &tokens.get_ids()[tokens.len() - (MESSAGE_TOKEN_LIMIT / 2) as usize ..];
16+
let mut text = tokenizer_lock.decode(first_tokens, false).map_err(|e| e.to_string())?;
17+
text.push_str("\n...\n");
18+
text.push_str(&tokenizer_lock.decode(last_tokens, false).map_err(|e| e.to_string())?);
19+
Ok(text)
20+
}
21+
22+
fn compress_message(msg: &ChatMessage, tokenizer: Arc<RwLock<Tokenizer>>) -> Result<ChatMessage, String> {
23+
let mut message = msg.clone();
24+
match message.content.clone() {
25+
ChatContent::SimpleText(simple_text) => {
26+
message.content = ChatContent::SimpleText(compress_string(&simple_text, tokenizer.clone())?);
27+
}
28+
ChatContent::Multimodal(elements) => {
29+
let mut new_elements: Vec<MultimodalElement> = vec![];
30+
for element in elements {
31+
if element.is_text() {
32+
new_elements.push(MultimodalElement::new("text".to_string(), compress_string(&element.m_content, tokenizer.clone())?)?);
33+
} else {
34+
new_elements.push(element.clone());
35+
}
36+
}
37+
message.content = ChatContent::Multimodal(new_elements);
38+
}
39+
};
40+
Ok(message)
41+
}
542

643
pub fn limit_messages_history(
744
t: &HasTokenizerAndEot,
@@ -16,28 +53,43 @@ pub fn limit_messages_history(
1653
let mut tokens_used: i32 = 0;
1754
let mut message_token_count: Vec<i32> = vec![0; messages.len()];
1855
let mut message_take: Vec<bool> = vec![false; messages.len()];
56+
let mut message_can_be_compressed: Vec<bool> = vec![false; messages.len()];
57+
let message_roles: Vec<String> = messages.iter().map(|x| x.role.clone()).collect();
58+
1959
for (i, msg) in messages.iter().enumerate() {
2060
let tcnt = 3 + msg.content.count_tokens(t.tokenizer.clone(), &None)?;
2161
message_token_count[i] = tcnt;
2262
if i==0 && msg.role == "system" {
2363
message_take[i] = true;
2464
tokens_used += tcnt;
2565
} else if i==1 && msg.role == "user" {
26-
// we cannot drop the user message which comes right after the system message according to Antropic API
66+
// we cannot drop the user message which comes right after the system message according to Anthropic API
2767
message_take[i] = true;
28-
tokens_used += tcnt;
68+
tokens_used += min(tcnt, MESSAGE_TOKEN_LIMIT + 3);
2969
} else if i >= last_user_msg_starts {
3070
message_take[i] = true;
31-
tokens_used += tcnt;
71+
tokens_used += min(tcnt, MESSAGE_TOKEN_LIMIT + 3);
3272
}
3373
}
74+
75+
// Need to save uncompressed last messages of assistant, tool_calls and user between assistant. It could be patch tool calls
76+
for i in (0..message_roles.len()).rev() {
77+
if message_roles[i] == "user" {
78+
message_can_be_compressed[i] = true;
79+
}
80+
}
81+
3482
let mut log_buffer = Vec::new();
3583
let mut dropped = false;
3684

3785
for i in (0..messages.len()).rev() {
3886
let tcnt = 3 + message_token_count[i];
3987
if !message_take[i] {
40-
if tokens_used + tcnt < tokens_limit {
88+
if message_can_be_compressed[i] && tcnt > MESSAGE_TOKEN_LIMIT + 3 && tokens_used + MESSAGE_TOKEN_LIMIT + 3 < tokens_limit {
89+
message_take[i] = true;
90+
tokens_used += MESSAGE_TOKEN_LIMIT + 3;
91+
log_buffer.push(format!("take compressed {:?}, tokens_used={} < {}", crate::nicer_logs::first_n_chars(&messages[i].content.content_text_only(), 30), tokens_used, tokens_limit));
92+
} else if tokens_used + tcnt < tokens_limit {
4193
message_take[i] = true;
4294
tokens_used += tcnt;
4395
log_buffer.push(format!("take {:?}, tokens_used={} < {}", crate::nicer_logs::first_n_chars(&messages[i].content.content_text_only(), 30), tokens_used, tokens_limit));
@@ -46,6 +98,7 @@ pub fn limit_messages_history(
4698
dropped = true;
4799
break;
48100
}
101+
49102
} else {
50103
message_take[i] = true;
51104
log_buffer.push(format!("not allowed to drop {:?}, tokens_used={} < {}", crate::nicer_logs::first_n_chars(&messages[i].content.content_text_only(), 30), tokens_used, tokens_limit));
@@ -77,7 +130,16 @@ pub fn limit_messages_history(
77130
tracing::info!("drop {:?} because of drop tool result rule", crate::nicer_logs::first_n_chars(&messages[i].content.content_text_only(), 30));
78131
}
79132
}
80-
81-
let messages_out: Vec<ChatMessage> = messages.iter().enumerate().filter(|(i, _)| message_take[*i]).map(|(_, x)| x.clone()).collect();
133+
let mut messages_out: Vec<ChatMessage> = Vec::new();
134+
for i in 0..messages.len() {
135+
if message_take[i] {
136+
if message_can_be_compressed[i] && message_token_count[i] > MESSAGE_TOKEN_LIMIT {
137+
messages_out.push(compress_message(&messages[i], t.tokenizer.clone())?);
138+
} else {
139+
messages_out.push(messages[i].clone());
140+
}
141+
}
142+
}
143+
82144
Ok(messages_out)
83145
}

0 commit comments

Comments
 (0)