@@ -139,6 +139,7 @@ struct slot_params {

     json input_prefix;
     json input_suffix;
+    json extra_context;
 };

 struct server_slot {
@@ -170,6 +171,7 @@ struct server_slot {

     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
+    std::vector<llama_token> extra_tokens;

     std::string generated_text;
     std::vector<llama_token> cache_tokens;
@@ -906,8 +908,26 @@ struct server_context {
         }

         // infill
-        slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
-        slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
+        slot.params.input_prefix  = json_value(data, "input_prefix",  default_params.input_prefix);
+        slot.params.input_suffix  = json_value(data, "input_suffix",  default_params.input_suffix);
+        slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
+
+        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
+        for (const auto & chunk : slot.params.extra_context) {
+            // { "text": string, "filename": string }
+            if (!chunk.contains("text") || !chunk["text"].is_string()) {
+                send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
+                return false;
+            }
+
+            // filename is optional
+            if (chunk.contains("filename") && !chunk["filename"].is_string()) {
+                send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
+                return false;
+            }
+
+            SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
+        }

         // get prompt
         if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
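For reference, a request body that passes the validation above might look like the sketch below, built with the same nlohmann json type the server already uses. The { "text", "filename" } chunk shape comes from the code; the concrete field values and filenames are invented for illustration.

    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    // hypothetical /infill request body: prefix/suffix plus two extra_context chunks
    static const json example_body = {
        {"input_prefix",  "int main() {\n    int r = "},
        {"input_suffix",  ";\n    return r;\n}\n"},
        {"extra_context", json::array({
            {{"text", "int helper(int x);\n"},                {"filename", "utils.h"}},
            {{"text", "int helper(int x) { return 2*x; }\n"}, {"filename", "utils.cpp"}},
        })},
    };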
@@ -1934,13 +1954,66 @@ struct server_context {
                    } break;
                case SERVER_TASK_CMPL_TYPE_INFILL:
                    {
+                        // use FIM repo-level pattern:
+                        // ref: https://arxiv.org/pdf/2409.12186
+                        //
+                        // [FIM_REP]myproject
+                        // [FIM_SEP]filename0
+                        // extra chunk 0
+                        // [FIM_SEP]filename1
+                        // extra chunk 1
+                        // ...
+                        // [FIM_SEP]filename
+                        // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
+                        //
                         auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
                         auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);

-                        // for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                        const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
+                        slot.extra_tokens.clear();
+                        if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
+                            static const auto k_fim_repo = tokenize("myproject\n", false, false);
+
+                            slot.extra_tokens.push_back(llama_token_fim_rep(model));
+                            slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
+                        }
+
+                        for (const auto & chunk : slot.params.extra_context) {
+                            // { "text": string, "filename": string }
+                            const std::string text     = chunk.value("text", "");
+                            const std::string filename = chunk.value("filename", "tmp");
+
+                            if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+                                const auto k_fim_file = tokenize(filename + "\n", false, false);
+
+                                slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
+                                slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+                            } else {
+                                // chunk separator in binary form to avoid confusing the AI
+                                static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
+                                static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
+
+                                slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
+                            }
+
+                            const auto chunk_tokens = tokenize(text, false, false);
+                            slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
+                        }
+
+                        if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+                            // TODO: current filename
+                            static const auto k_fim_file = tokenize("filename\n", false, false);
+
+                            slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
+                            slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+                        }
+
+                        // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+                        const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
                         const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);

+                        // fill the rest of the context with extra chunks
+                        const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
+
                         prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
                         suffix_tokens.resize(n_suffix_take);
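To make the pattern above concrete: for a model that defines the FIM repo/sep tokens (the referenced paper is the Qwen2.5-Coder report, whose special tokens are spelled roughly as shown below; treat the exact spellings as an assumption), the two chunks from the earlier example request would be rendered in front of the regular infill prompt along these lines:

    <|repo_name|>myproject
    <|file_sep|>utils.h
    int helper(int x);
    <|file_sep|>utils.cpp
    int helper(int x) { return 2*x; }
    <|file_sep|>filename
    <|fim_prefix|>int main() {
        int r = <|fim_suffix|>;
        return r;
    }
    <|fim_middle|>

For models without these special tokens, each chunk is instead prefixed with the plain-text separator "\n\n--- snippet ---\n\n", which is what the k_chunk_prefix_str bytes above decode to.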
@@ -1954,6 +2027,11 @@ struct server_context {
                            embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
                        }

+                        SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
+
+                        // put the extra context before the FIM prefix
+                        embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
+
                         embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
                         embd_inp.push_back(llama_token_fim_mid(model));

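As a rough sanity check on the budget (hypothetical numbers): with slot.n_ctx = 8192, n_batch = 2048 and slot.n_predict = 1024, n_extra_take = min(max(0, 8192 - 2048 - 2*1024), extra_tokens.size()), so at most 4096 extra-context tokens are prepended here, leaving a full batch for the prefix/suffix and headroom for generation.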
@@ -2058,11 +2136,15 @@ struct server_context {

                    while (head_c < slot.cache_tokens.size() &&
                           head_p < prompt_tokens.size()) {
-                        if (llama_token_is_control(model, slot.cache_tokens[head_c])) {
+                        if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
+                            slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
+                            slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
                             break;
                         }

-                        if (llama_token_is_control(model, prompt_tokens[head_p])) {
+                        if (llama_token_is_control(model, prompt_tokens[head_p]) &&
+                            prompt_tokens[head_p] != llama_token_fim_rep(model) &&
+                            prompt_tokens[head_p] != llama_token_fim_sep(model)) {
                             break;
                         }

@@ -2071,11 +2153,15 @@ struct server_context {
                    while (head_c + n_match < slot.cache_tokens.size() &&
                           head_p + n_match < prompt_tokens.size() &&
                           slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
-                        if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match])) {
+                        if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
+                            slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
+                            slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
                             break;
                         }

-                        if (llama_token_is_control(model, prompt_tokens[head_p + n_match])) {
+                        if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
+                            prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
+                            prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
                             break;
                         }
