Skip to content

Commit 6d2a38c

Browse files
committed
sampling : clarify purpose of partial sort helpers
ggml-ci
1 parent de2902d commit 6d2a38c

File tree

1 file changed

+26
-22
lines changed

1 file changed

+26
-22
lines changed

src/llama-sampling.cpp

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ struct ring_buffer {
129129
};
130130

131131
// writes result in res, does not mutate cur
132-
static void llama_token_data_array_sort(const llama_token_data_array & cur, int k, std::vector<llama_token_data> & data) {
132+
// the top npartial elements come first in res, sorted by descending logit; res may hold extra, unsorted elements after them
133+
static void llama_token_data_array_partial_sort(const llama_token_data_array & cur, int npartial, std::vector<llama_token_data> & res) {
133134
static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
134135
return a.logit > b.logit;
135136
};
@@ -158,12 +159,12 @@ static void llama_token_data_array_sort(const llama_token_data_array & cur, int
158159
int ib = nbuckets - 1;
159160
for ( ; ib >= 0; --ib) {
160161
nhave += histo[ib];
161-
if (nhave >= k) {
162+
if (nhave >= npartial) {
162163
break;
163164
}
164165
}
165-
data.resize(nhave);
166-
auto * ptr = data.data();
166+
res.resize(nhave);
167+
auto * ptr = res.data();
167168
bucket_ptrs.reserve(nbuckets - ib);
168169
for (int j = nbuckets - 1; j >= ib; --j) {
169170
bucket_ptrs.push_back(ptr);
@@ -176,32 +177,39 @@ static void llama_token_data_array_sort(const llama_token_data_array & cur, int
176177
}
177178
}
178179

179-
ptr = data.data();
180+
ptr = res.data();
180181
int ndone = 0;
181182
for (int j = nbuckets - 1; j > ib; --j) {
182183
std::sort(ptr, ptr + histo[j], comp);
183184
ptr += histo[j];
184185
ndone += histo[j];
185186
}
186-
std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
187+
std::partial_sort(ptr, ptr + npartial - ndone, ptr + histo[ib], comp);
187188
}
188189

189-
// buf is a helper buffer that can optionally be utilized
190-
static void llama_token_data_array_sort_inplace(llama_token_data_array * cur_p, int k) {
190+
// reduces the size of cur_p to npartial, keeping only the top npartial elements
191+
static void llama_token_data_array_partial_sort_inplace(llama_token_data_array * cur_p, int npartial) {
191192
static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
192193
return a.logit > b.logit;
193194
};
194195

195-
if (k <= 128) {
196-
std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
196+
if (npartial <= 128) {
197+
std::partial_sort(cur_p->data, cur_p->data + npartial, cur_p->data + cur_p->size, comp);
198+
199+
cur_p->size = npartial;
200+
cur_p->sorted = true;
201+
197202
return;
198203
}
199204

200205
std::vector<llama_token_data> tmp;
201206

202-
llama_token_data_array_sort(*cur_p, k, tmp);
207+
llama_token_data_array_partial_sort(*cur_p, npartial, tmp);
208+
209+
std::copy(tmp.data(), tmp.data() + npartial, cur_p->data);
203210

204-
std::copy(tmp.data(), tmp.data() + k, cur_p->data);
211+
cur_p->size = npartial;
212+
cur_p->sorted = true;
205213
}
206214

207215
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
@@ -281,8 +289,7 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_s
281289

282290
// Sort the logits in descending order if requested
283291
if (do_sort && !cur_p->sorted) {
284-
llama_token_data_array_sort_inplace(cur_p, cur_p->size);
285-
cur_p->sorted = true;
292+
llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
286293
}
287294

288295
float max_l = cur_p->data[0].logit;
@@ -318,8 +325,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
318325

319326
// Sort scores in descending order
320327
if (!cur_p->sorted) {
321-
llama_token_data_array_sort_inplace(cur_p, k);
322-
cur_p->sorted = true;
328+
llama_token_data_array_partial_sort_inplace(cur_p, k);
323329
}
324330

325331
cur_p->size = k;
@@ -722,12 +728,11 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d
722728
// if not sorted, try adaptive top-k sorting
723729
if (!cur_p->sorted && cur_p->size > 1024) {
724730
k = std::min<size_t>(256, cur_p->size);
725-
llama_token_data_array_sort(*cur_p, k, buf_sort);
731+
llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
726732
pdata = buf_sort.data();
727733
} else if (!cur_p->sorted) {
728734
// small candidates -> sort inplace
729-
llama_token_data_array_sort_inplace(cur_p, k);
730-
cur_p->sorted = true;
735+
llama_token_data_array_partial_sort_inplace(cur_p, k);
731736
}
732737

733738
// Compute the cumulative probabilities
@@ -747,7 +752,7 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d
747752
// we exceeded the current top-k heuristic -> increase k and continue
748753
if (!cur_p->sorted && i == k - 1) {
749754
k = cur_p->size;
750-
llama_token_data_array_sort(*cur_p, k, buf_sort);
755+
llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
751756
pdata = buf_sort.data();
752757
}
753758
}
@@ -838,8 +843,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
838843
if (!min_p_applied) {
839844
// Sort the logits in descending order
840845
if (!cur_p->sorted) {
841-
llama_token_data_array_sort_inplace(cur_p, cur_p->size);
842-
cur_p->sorted = true;
846+
llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
843847
}
844848

845849
const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max

0 commit comments

Comments
 (0)