1
+ use std:: cmp:: min;
1
2
use crate :: scratchpad_abstract:: HasTokenizerAndEot ;
2
- use crate :: call_validation:: ChatMessage ;
3
+ use crate :: call_validation:: { ChatContent , ChatMessage } ;
3
4
use std:: collections:: HashSet ;
5
+ use std:: sync:: { Arc , RwLock } ;
6
+ use tokenizers:: Tokenizer ;
7
+ use crate :: scratchpads:: multimodality:: MultimodalElement ;
4
8
9
// Per-message token budget: compress_string() keeps the first and last
// MESSAGE_TOKEN_LIMIT/2 tokens of an over-long message and drops the middle.
const MESSAGE_TOKEN_LIMIT: i32 = 12_000;
10
+
11
+ fn compress_string ( text : & String , tokenizer : Arc < RwLock < Tokenizer > > ) -> Result < String , String > {
12
+ let tokenizer_lock = tokenizer. read ( ) . unwrap ( ) ;
13
+ let tokens = tokenizer_lock. encode ( & * * text, false ) . map_err ( |e| e. to_string ( ) ) ?;
14
+ let first_tokens = & tokens. get_ids ( ) [ 0 ..( MESSAGE_TOKEN_LIMIT / 2 ) as usize ] ;
15
+ let last_tokens = & tokens. get_ids ( ) [ tokens. len ( ) - ( MESSAGE_TOKEN_LIMIT / 2 ) as usize ..] ;
16
+ let mut text = tokenizer_lock. decode ( first_tokens, false ) . map_err ( |e| e. to_string ( ) ) ?;
17
+ text. push_str ( "\n ...\n " ) ;
18
+ text. push_str ( & tokenizer_lock. decode ( last_tokens, false ) . map_err ( |e| e. to_string ( ) ) ?) ;
19
+ Ok ( text)
20
+ }
21
+
22
+ fn compress_message ( msg : & ChatMessage , tokenizer : Arc < RwLock < Tokenizer > > ) -> Result < ChatMessage , String > {
23
+ let mut message = msg. clone ( ) ;
24
+ match message. content . clone ( ) {
25
+ ChatContent :: SimpleText ( simple_text) => {
26
+ message. content = ChatContent :: SimpleText ( compress_string ( & simple_text, tokenizer. clone ( ) ) ?) ;
27
+ }
28
+ ChatContent :: Multimodal ( elements) => {
29
+ let mut new_elements: Vec < MultimodalElement > = vec ! [ ] ;
30
+ for element in elements {
31
+ if element. is_text ( ) {
32
+ new_elements. push ( MultimodalElement :: new ( "text" . to_string ( ) , compress_string ( & element. m_content , tokenizer. clone ( ) ) ?) ?) ;
33
+ } else {
34
+ new_elements. push ( element. clone ( ) ) ;
35
+ }
36
+ }
37
+ message. content = ChatContent :: Multimodal ( new_elements) ;
38
+ }
39
+ } ;
40
+ Ok ( message)
41
+ }
5
42
6
43
pub fn limit_messages_history (
7
44
t : & HasTokenizerAndEot ,
@@ -16,28 +53,43 @@ pub fn limit_messages_history(
16
53
let mut tokens_used: i32 = 0 ;
17
54
let mut message_token_count: Vec < i32 > = vec ! [ 0 ; messages. len( ) ] ;
18
55
let mut message_take: Vec < bool > = vec ! [ false ; messages. len( ) ] ;
56
+ let mut message_can_be_compressed: Vec < bool > = vec ! [ false ; messages. len( ) ] ;
57
+ let message_roles: Vec < String > = messages. iter ( ) . map ( |x| x. role . clone ( ) ) . collect ( ) ;
58
+
19
59
for ( i, msg) in messages. iter ( ) . enumerate ( ) {
20
60
let tcnt = 3 + msg. content . count_tokens ( t. tokenizer . clone ( ) , & None ) ?;
21
61
message_token_count[ i] = tcnt;
22
62
if i==0 && msg. role == "system" {
23
63
message_take[ i] = true ;
24
64
tokens_used += tcnt;
25
65
} else if i==1 && msg. role == "user" {
26
- // we cannot drop the user message which comes right after the system message according to Antropic API
66
+ // we cannot drop the user message which comes right after the system message according to Anthropic API
27
67
message_take[ i] = true ;
28
- tokens_used += tcnt;
68
+ tokens_used += min ( tcnt, MESSAGE_TOKEN_LIMIT + 3 ) ;
29
69
} else if i >= last_user_msg_starts {
30
70
message_take[ i] = true ;
31
- tokens_used += tcnt;
71
+ tokens_used += min ( tcnt, MESSAGE_TOKEN_LIMIT + 3 ) ;
32
72
}
33
73
}
74
+
75
+ // Need to save uncompressed last messages of assistant, tool_calls and user between assistant. It could be patch tool calls
76
+ for i in ( 0 ..message_roles. len ( ) ) . rev ( ) {
77
+ if message_roles[ i] == "user" {
78
+ message_can_be_compressed[ i] = true ;
79
+ }
80
+ }
81
+
34
82
let mut log_buffer = Vec :: new ( ) ;
35
83
let mut dropped = false ;
36
84
37
85
for i in ( 0 ..messages. len ( ) ) . rev ( ) {
38
86
let tcnt = 3 + message_token_count[ i] ;
39
87
if !message_take[ i] {
40
- if tokens_used + tcnt < tokens_limit {
88
+ if message_can_be_compressed[ i] && tcnt > MESSAGE_TOKEN_LIMIT + 3 && tokens_used + MESSAGE_TOKEN_LIMIT + 3 < tokens_limit {
89
+ message_take[ i] = true ;
90
+ tokens_used += MESSAGE_TOKEN_LIMIT + 3 ;
91
+ log_buffer. push ( format ! ( "take compressed {:?}, tokens_used={} < {}" , crate :: nicer_logs:: first_n_chars( & messages[ i] . content. content_text_only( ) , 30 ) , tokens_used, tokens_limit) ) ;
92
+ } else if tokens_used + tcnt < tokens_limit {
41
93
message_take[ i] = true ;
42
94
tokens_used += tcnt;
43
95
log_buffer. push ( format ! ( "take {:?}, tokens_used={} < {}" , crate :: nicer_logs:: first_n_chars( & messages[ i] . content. content_text_only( ) , 30 ) , tokens_used, tokens_limit) ) ;
@@ -46,6 +98,7 @@ pub fn limit_messages_history(
46
98
dropped = true ;
47
99
break ;
48
100
}
101
+
49
102
} else {
50
103
message_take[ i] = true ;
51
104
log_buffer. push ( format ! ( "not allowed to drop {:?}, tokens_used={} < {}" , crate :: nicer_logs:: first_n_chars( & messages[ i] . content. content_text_only( ) , 30 ) , tokens_used, tokens_limit) ) ;
@@ -77,7 +130,16 @@ pub fn limit_messages_history(
77
130
tracing:: info!( "drop {:?} because of drop tool result rule" , crate :: nicer_logs:: first_n_chars( & messages[ i] . content. content_text_only( ) , 30 ) ) ;
78
131
}
79
132
}
80
-
81
- let messages_out: Vec < ChatMessage > = messages. iter ( ) . enumerate ( ) . filter ( |( i, _) | message_take[ * i] ) . map ( |( _, x) | x. clone ( ) ) . collect ( ) ;
133
+ let mut messages_out: Vec < ChatMessage > = Vec :: new ( ) ;
134
+ for i in 0 ..messages. len ( ) {
135
+ if message_take[ i] {
136
+ if message_can_be_compressed[ i] && message_token_count[ i] > MESSAGE_TOKEN_LIMIT {
137
+ messages_out. push ( compress_message ( & messages[ i] , t. tokenizer . clone ( ) ) ?) ;
138
+ } else {
139
+ messages_out. push ( messages[ i] . clone ( ) ) ;
140
+ }
141
+ }
142
+ }
143
+
82
144
Ok ( messages_out)
83
145
}
0 commit comments