@@ -129,7 +129,8 @@ struct ring_buffer {
129
129
};
130
130
131
131
// writes result in res, does not mutate cur
132
- static void llama_token_data_array_sort (const llama_token_data_array & cur, int k, std::vector<llama_token_data> & data) {
132
+ // reduces the size of cur_p to npartial, keeping only the top npartial elements
133
+ static void llama_token_data_array_partial_sort (const llama_token_data_array & cur, int npartial, std::vector<llama_token_data> & res) {
133
134
static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
134
135
return a.logit > b.logit ;
135
136
};
@@ -158,12 +159,12 @@ static void llama_token_data_array_sort(const llama_token_data_array & cur, int
158
159
int ib = nbuckets - 1 ;
159
160
for ( ; ib >= 0 ; --ib) {
160
161
nhave += histo[ib];
161
- if (nhave >= k ) {
162
+ if (nhave >= npartial ) {
162
163
break ;
163
164
}
164
165
}
165
- data .resize (nhave);
166
- auto * ptr = data .data ();
166
+ res .resize (nhave);
167
+ auto * ptr = res .data ();
167
168
bucket_ptrs.reserve (nbuckets - ib);
168
169
for (int j = nbuckets - 1 ; j >= ib; --j) {
169
170
bucket_ptrs.push_back (ptr);
@@ -176,32 +177,39 @@ static void llama_token_data_array_sort(const llama_token_data_array & cur, int
176
177
}
177
178
}
178
179
179
- ptr = data .data ();
180
+ ptr = res .data ();
180
181
int ndone = 0 ;
181
182
for (int j = nbuckets - 1 ; j > ib; --j) {
182
183
std::sort (ptr, ptr + histo[j], comp);
183
184
ptr += histo[j];
184
185
ndone += histo[j];
185
186
}
186
- std::partial_sort (ptr, ptr + k - ndone, ptr + histo[ib], comp);
187
+ std::partial_sort (ptr, ptr + npartial - ndone, ptr + histo[ib], comp);
187
188
}
188
189
189
- // buf is a helper buffer that can optionally be utilized
190
- static void llama_token_data_array_sort_inplace (llama_token_data_array * cur_p, int k ) {
190
+ // reduces the size of cur_p to npartial, keeping only the top npartial elements
191
+ static void llama_token_data_array_partial_sort_inplace (llama_token_data_array * cur_p, int npartial ) {
191
192
static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
192
193
return a.logit > b.logit ;
193
194
};
194
195
195
- if (k <= 128 ) {
196
- std::partial_sort (cur_p->data , cur_p->data + k, cur_p->data + cur_p->size , comp);
196
+ if (npartial <= 128 ) {
197
+ std::partial_sort (cur_p->data , cur_p->data + npartial, cur_p->data + cur_p->size , comp);
198
+
199
+ cur_p->size = npartial;
200
+ cur_p->sorted = true ;
201
+
197
202
return ;
198
203
}
199
204
200
205
std::vector<llama_token_data> tmp;
201
206
202
- llama_token_data_array_sort (*cur_p, k, tmp);
207
+ llama_token_data_array_partial_sort (*cur_p, npartial, tmp);
208
+
209
+ std::copy (tmp.data (), tmp.data () + npartial, cur_p->data );
203
210
204
- std::copy (tmp.data (), tmp.data () + k, cur_p->data );
211
+ cur_p->size = npartial;
212
+ cur_p->sorted = true ;
205
213
}
206
214
207
215
static int llama_sample_dist (llama_token_data_array * cur_p, std::mt19937 & rng) {
@@ -281,8 +289,7 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_s
281
289
282
290
// Sort the logits in descending order if requested
283
291
if (do_sort && !cur_p->sorted ) {
284
- llama_token_data_array_sort_inplace (cur_p, cur_p->size );
285
- cur_p->sorted = true ;
292
+ llama_token_data_array_partial_sort_inplace (cur_p, cur_p->size );
286
293
}
287
294
288
295
float max_l = cur_p->data [0 ].logit ;
@@ -318,8 +325,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
318
325
319
326
// Sort scores in descending order
320
327
if (!cur_p->sorted ) {
321
- llama_token_data_array_sort_inplace (cur_p, k);
322
- cur_p->sorted = true ;
328
+ llama_token_data_array_partial_sort_inplace (cur_p, k);
323
329
}
324
330
325
331
cur_p->size = k;
@@ -722,12 +728,11 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d
722
728
// if not sorted, try adaptive top-k sorting
723
729
if (!cur_p->sorted && cur_p->size > 1024 ) {
724
730
k = std::min<size_t >(256 , cur_p->size );
725
- llama_token_data_array_sort (*cur_p, k, buf_sort);
731
+ llama_token_data_array_partial_sort (*cur_p, k, buf_sort);
726
732
pdata = buf_sort.data ();
727
733
} else if (!cur_p->sorted ) {
728
734
// small candidates -> sort inplace
729
- llama_token_data_array_sort_inplace (cur_p, k);
730
- cur_p->sorted = true ;
735
+ llama_token_data_array_partial_sort_inplace (cur_p, k);
731
736
}
732
737
733
738
// Compute the cumulative probabilities
@@ -747,7 +752,7 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d
747
752
// we exceeded the current top-k heuristic -> increase k and continue
748
753
if (!cur_p->sorted && i == k - 1 ) {
749
754
k = cur_p->size ;
750
- llama_token_data_array_sort (*cur_p, k, buf_sort);
755
+ llama_token_data_array_partial_sort (*cur_p, k, buf_sort);
751
756
pdata = buf_sort.data ();
752
757
}
753
758
}
@@ -838,8 +843,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
838
843
if (!min_p_applied) {
839
844
// Sort the logits in descending order
840
845
if (!cur_p->sorted ) {
841
- llama_token_data_array_sort_inplace (cur_p, cur_p->size );
842
- cur_p->sorted = true ;
846
+ llama_token_data_array_partial_sort_inplace (cur_p, cur_p->size );
843
847
}
844
848
845
849
const float min_logit = cur_p->data [0 ].logit + logf (ctx->p ); // min logit for p_i >= p * p_max
0 commit comments