Skip to content

Commit cdedb70

Browse files
authored
sampling : optimize dist sampler (#15704)
ggml-ci
1 parent 2c8dac7 commit cdedb70

File tree

1 file changed

+65
-2
lines changed

1 file changed

+65
-2
lines changed

src/llama-sampling.cpp

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -604,10 +604,73 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
604604
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
605605
auto * ctx = (llama_sampler_dist *) smpl->ctx;
606606

607-
// sorting is not necessary here
608-
llama_sampler_softmax_impl(cur_p, false);
607+
// edge cases
608+
if (cur_p->size == 0) {
609+
cur_p->selected = -1;
610+
return;
611+
}
612+
613+
cur_p->selected = 0;
614+
615+
if (cur_p->size == 1) {
616+
cur_p->data[0].p = 1.0f;
617+
return;
618+
}
619+
620+
// max logit for numerical stability
621+
float max_l = cur_p->data[0].logit;
622+
if (!cur_p->sorted) {
623+
for (size_t i = 1; i < cur_p->size; ++i) {
624+
max_l = std::max(max_l, cur_p->data[i].logit);
625+
}
626+
}
627+
628+
// apply softmax to obtain the probabilities
629+
double sum_cum = 0.0f;
630+
for (size_t i = 0; i < cur_p->size; ++i) {
631+
float p = expf(cur_p->data[i].logit - max_l);
632+
cur_p->data[i].p = p;
633+
sum_cum += p;
634+
}
635+
636+
#if 1
637+
// sample from the obtained probabilities and normalize the probs in a single pass
638+
// this is ~3x faster on Mac with full gpt-oss vocab than the version below
639+
//
640+
std::uniform_real_distribution<double> dist(0.0f, 1.0f);
641+
const double rnd = dist(ctx->rng);
642+
643+
double sum_run = 0.0f;
644+
const double sum_tgt = sum_cum*rnd;
645+
646+
bool found = false;
647+
for (size_t i = 0; i < cur_p->size; ++i) {
648+
if (!found) {
649+
// accumulate probs until we reach the target sum
650+
sum_run += cur_p->data[i].p;
651+
if (sum_run >= sum_tgt) {
652+
cur_p->selected = i;
653+
found = true;
654+
}
655+
}
656+
657+
// normalize probs
658+
cur_p->data[i].p /= sum_cum;
659+
}
660+
661+
// fallback to the last token (don't think this can happen)
662+
assert(found);
663+
if (!found) {
664+
cur_p->selected = cur_p->size - 1;
665+
}
666+
#else
667+
// for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
668+
for (size_t i = 0; i < cur_p->size; ++i) {
669+
cur_p->data[i].p /= sum_cum;
670+
}
609671
610672
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
673+
#endif
611674
}
612675

613676
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {

0 commit comments

Comments
 (0)