@@ -604,10 +604,73 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
 
-    // sorting is not necessary here
-    llama_sampler_softmax_impl(cur_p, false);
+    // edge cases
+    if (cur_p->size == 0) {
+        cur_p->selected = -1;
+        return;
+    }
+
+    cur_p->selected = 0;
+
+    if (cur_p->size == 1) {
+        cur_p->data[0].p = 1.0f;
+        return;
+    }
+
+    // max logit for numerical stability
+    float max_l = cur_p->data[0].logit;
+    if (!cur_p->sorted) {
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            max_l = std::max(max_l, cur_p->data[i].logit);
+        }
+    }
+
+    // apply softmax to obtain the probabilities
+    double sum_cum = 0.0;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float p = expf(cur_p->data[i].logit - max_l);
+        cur_p->data[i].p = p;
+        sum_cum += p;
+    }
+
+#if 1
+    // sample from the obtained probabilities and normalize the probs in a single pass
+    // this is ~3x faster on Mac with full gpt-oss vocab than the version below
+    //
+    std::uniform_real_distribution<double> dist(0.0, 1.0);
+    const double rnd = dist(ctx->rng);
+
+    double sum_run = 0.0;
+    const double sum_tgt = sum_cum*rnd;
+
+    bool found = false;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (!found) {
+            // accumulate probs until we reach the target sum
+            sum_run += cur_p->data[i].p;
+            if (sum_run >= sum_tgt) {
+                cur_p->selected = i;
+                found = true;
+            }
+        }
+
+        // normalize probs
+        cur_p->data[i].p /= sum_cum;
+    }
+
+    // fallback to the last token (should not be reachable, but guard against rounding)
+    assert(found);
+    if (!found) {
+        cur_p->selected = cur_p->size - 1;
+    }
+#else
+    // for clarity: same result as above, but with one pass for normalization and an extra pass for sampling
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= sum_cum;
+    }
 
     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
+#endif
 }
 
 static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {