3636#include < ctime>
3737#include < cinttypes>
3838#include < fstream>
39- #include < functional>
4039#include < random>
4140#include < map>
4241#include < unordered_map>
@@ -2876,7 +2875,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
28762875 ctx->t_sample_us += ggml_time_us () - t_start_sample_us;
28772876}
28782877
2879- struct beam {
2878+ struct llama_beam {
28802879 std::vector<llama_token> tokens;
28812880 float p; // Cumulative beam probability (renormalized relative to all beams)
28822881 // end-of-sentence
@@ -2939,16 +2938,16 @@ struct beam_search {
29392938 int n_past;
29402939 int n_predict;
29412940 int n_threads;
2942- std::vector<beam > beams;
2943- std::vector<beam > next_beams;
2941+ std::vector<llama_beam > beams;
2942+ std::vector<llama_beam > next_beams;
29442943
29452944 // Re-calculated on each loop iteration
29462945 size_t common_prefix_length;
29472946 // true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
29482947 bool common_prefix_evaluated;
29492948
2950- // Temporary memory used by beams_state to pass back via callback.
2951- std::vector<beam_view > beam_views;
2949+ // Temporary memory used by llama_beams_state to pass back via callback.
2950+ std::vector<llama_beam_view > beam_views;
29522951
29532952 beam_search (llama_context * ctx, size_t beam_width, int n_past, int n_predict, int n_threads)
29542953 : ctx(ctx)
@@ -2974,32 +2973,32 @@ struct beam_search {
29742973 // * Gather elements until the vector is full, then call std::make_heap() on it.
29752974 // * If the heap is full and a new element is found that should be included, pop the
29762975 // least element to the back(), replace it with the new, then push it into the heap.
2977- void fill_next_beams_by_top_probabilities (beam& b ) {
2976+ void fill_next_beams_by_top_probabilities (llama_beam& beam ) {
29782977 // Min-heaps use a greater-than comparator.
2979- auto const comp = [](beam const & a, beam const & b) { return a.p > b.p ; };
2978+ auto const comp = [](llama_beam const & a, llama_beam const & b) { return a.p > b.p ; };
29802979 if (common_prefix_evaluated) {
29812980 // llama_eval was already called during this iteration
29822981 // with the common token prefix, so shift it off this beam.
2983- b .shift_tokens (common_prefix_length);
2982+ beam .shift_tokens (common_prefix_length);
29842983 }
2985- if (b .eos ()) {
2984+ if (beam .eos ()) {
29862985 // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
29872986 if (next_beams.size () < beam_width) {
2988- next_beams.push_back (std::move (b ));
2987+ next_beams.push_back (std::move (beam ));
29892988 if (next_beams.size () == beam_width) {
29902989 std::make_heap (next_beams.begin (), next_beams.end (), comp);
29912990 }
2992- } else if (next_beams.front ().p < b .p ) {
2991+ } else if (next_beams.front ().p < beam .p ) {
29932992 std::pop_heap (next_beams.begin (), next_beams.end (), comp);
2994- next_beams.back () = std::move (b );
2993+ next_beams.back () = std::move (beam );
29952994 std::push_heap (next_beams.begin (), next_beams.end (), comp);
29962995 }
29972996 } else {
29982997 // beam is not at end-of-sentence, so branch with next top_k tokens.
2999- if (!b .tokens .empty ()) {
3000- llama_eval (ctx, b .tokens .data (), b .tokens .size (), n_past, n_threads);
2998+ if (!beam .tokens .empty ()) {
2999+ llama_eval (ctx, beam .tokens .data (), beam .tokens .size (), n_past, n_threads);
30013000 if (!common_prefix_evaluated && common_prefix_length) {
3002- b .shift_tokens (common_prefix_length);
3001+ beam .shift_tokens (common_prefix_length);
30033002 n_past += common_prefix_length;
30043003 common_prefix_evaluated = true ;
30053004 }
@@ -3009,7 +3008,7 @@ struct beam_search {
30093008 size_t i=0 ;
30103009 if (next_beams.size () < beam_width) {
30113010 for (; next_beams.size () < beam_width ; ++i) {
3012- beam next_beam = b ;
3011+ llama_beam next_beam = beam ;
30133012 next_beam.tokens .push_back (next_tokens[i].id );
30143013 next_beam.p *= logit_info.probability_from_logit (next_tokens[i].logit );
30153014 next_beams.push_back (std::move (next_beam));
@@ -3018,17 +3017,17 @@ struct beam_search {
30183017 } else {
30193018 for (; next_beams.front ().p == 0 .0f ; ++i) {
30203019 std::pop_heap (next_beams.begin (), next_beams.end (), comp);
3021- next_beams.back () = b ;
3020+ next_beams.back () = beam ;
30223021 next_beams.back ().tokens .push_back (next_tokens[i].id );
30233022 next_beams.back ().p *= logit_info.probability_from_logit (next_tokens[i].logit );
30243023 std::push_heap (next_beams.begin (), next_beams.end (), comp);
30253024 }
30263025 }
30273026 for (; i < beam_width ; ++i) {
3028- float const next_p = b .p * logit_info.probability_from_logit (next_tokens[i].logit );
3027+ float const next_p = beam .p * logit_info.probability_from_logit (next_tokens[i].logit );
30293028 if (next_beams.front ().p < next_p) {
30303029 std::pop_heap (next_beams.begin (), next_beams.end (), comp);
3031- next_beams.back () = b ;
3030+ next_beams.back () = beam ;
30323031 next_beams.back ().tokens .push_back (next_tokens[i].id );
30333032 next_beams.back ().p = next_p;
30343033 std::push_heap (next_beams.begin (), next_beams.end (), comp);
@@ -3055,9 +3054,9 @@ struct beam_search {
30553054
30563055 // Construct beams_state to send back to caller via the callback function.
30573056 // Side effect: set common_prefix_length = find_common_prefix_length();
3058- beams_state get_beams_state (bool const last_call) {
3057+ llama_beams_state get_beams_state (bool const last_call) {
30593058 for (size_t i=0 ; i<beams.size () ; ++i) {
3060- beam_views[i] = beam_view {beams[i].tokens .data (), beams[i].tokens .size (), beams[i].p };
3059+ beam_views[i] = llama_beam_view {beams[i].tokens .data (), beams[i].tokens .size (), beams[i].p };
30613060 }
30623061 common_prefix_length = find_common_prefix_length ();
30633062 return {beam_views.data (), beams.size (), common_prefix_length, last_call};
@@ -3070,10 +3069,10 @@ struct beam_search {
30703069 // (since all other beam probabilities can only decrease)
30713070 void loop (llama_beam_search_callback_fn_t const callback, void * const callback_state) {
30723071 beams.push_back ({{}, 1 .0f }); // Start with one empty beam w/ probability = 1.0.
3073- auto const not_eos = [](beam const & beam) { return !beam.eos (); };
3072+ auto const not_eos = [](llama_beam const & beam) { return !beam.eos (); };
30743073 for (int i=0 ; i<n_predict && std::any_of (beams.begin (),beams.end (),not_eos) &&
30753074 !beams[top_beam_index ()].eos () ; ++i) {
3076- beam_search_control const control = callback (callback_state, get_beams_state (false ));
3075+ llama_beam_search_control const control = callback (callback_state, get_beams_state (false ));
30773076 if (control.collapse_to < beams.size ()) {
30783077 // Caller has manually selected a specific beam. Collapse beams into it.
30793078 collapse_beams (control.collapse_to );
@@ -3082,30 +3081,30 @@ struct beam_search {
30823081 break ;
30833082 }
30843083 common_prefix_evaluated = false ;
3085- for (beam & beam : beams) {
3084+ for (llama_beam & beam : beams) {
30863085 fill_next_beams_by_top_probabilities (beam);
30873086 }
30883087 beams.swap (next_beams);
30893088 renormalize_beam_probabilities (beams);
3090- std::for_each (next_beams.begin (), next_beams.end (), [](beam & beam) { beam.p = 0 .0f ; });
3089+ std::for_each (next_beams.begin (), next_beams.end (), [](llama_beam & beam) { beam.p = 0 .0f ; });
30913090 }
30923091 collapse_beams (top_beam_index ());
30933092 callback (callback_state, get_beams_state (true ));
30943093 }
30953094
30963095 // As beams grow, the cumulative probabilities decrease.
30973096 // Renormalize them to avoid floating point underflow.
3098- static void renormalize_beam_probabilities (std::vector<beam >& beams) {
3099- auto const sum_p = [](float sum, beam & beam) { return sum + beam.p ; };
3097+ static void renormalize_beam_probabilities (std::vector<llama_beam >& beams) {
3098+ auto const sum_p = [](float sum, llama_beam & beam) { return sum + beam.p ; };
31003099 float const inv_sum = 1 .0f / std::accumulate (beams.begin (), beams.end (), 0 .0f , sum_p);
3101- std::for_each (beams.begin (), beams.end (), [=](beam & beam) { beam.p *= inv_sum; });
3100+ std::for_each (beams.begin (), beams.end (), [=](llama_beam & beam) { beam.p *= inv_sum; });
31023101 }
31033102
31043103 // Return index of highest ranking beam by (probability,eos()).
31053104 // In other words choose most probable beam. In case of ties, choose beam at end-of-sentence.
31063105 // Assumes beams is non-empty.
31073106 size_t top_beam_index () {
3108- auto const by_p_and_eos = [](beam const & a, beam const & b) {
3107+ auto const by_p_and_eos = [](llama_beam const & a, llama_beam const & b) {
31093108 return a.p < b.p || (a.p == b.p && a.eos () < b.eos ()); };
31103109 return std::max_element (beams.begin (), beams.end (), by_p_and_eos) - beams.begin ();
31113110 }
0 commit comments