Skip to content

Commit 91d65a8

Browse files
committed
Improve beam_state by adding+using struct beam_view.
1 parent 4e702e6 commit 91d65a8

File tree

3 files changed

+53
-39
lines changed

3 files changed

+53
-39
lines changed

examples/beam_search/beam_search.cpp

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,24 @@
2727
#include <signal.h>
2828
#endif
2929

30+
// Used for debugging to print out beam tokens.
31+
struct ostream_beam_view {
32+
llama_context* ctx;
33+
beam_view bv;
34+
};
35+
std::ostream& operator<<(std::ostream& os, ostream_beam_view const& obv) {
36+
os << "p(" << obv.bv.p << ") eos(" << std::boolalpha << obv.bv.eos() << ") tokens(";
37+
for (size_t i=0 ; i<obv.bv.n_tokens ; ++i) {
38+
os << llama_token_to_str(obv.ctx, obv.bv.tokens[i]);
39+
}
40+
return os << ')';
41+
}
42+
43+
// Put here anything you want back in beam_search_callback().
44+
struct beam_search_callback_state {
45+
llama_context* ctx;
46+
std::vector<llama_token>* response;
47+
};
3048

3149
// Function matching type llama_beam_search_callback_fn_t.
3250
// This example callback is called each time the beam lengths increase:
@@ -35,22 +53,27 @@
3553
// This is also called when the stop condition is met.
3654
// Collect tokens into the std::vector<llama_token> referenced by the response member of the
// beam_search_callback_state that callback_state points to.
3755
// callback_state : must point to a beam_search_callback_state.
// beams_state    : views of the current beams plus the length of the prefix they all share.
// Returns a beam_search_control that neither collapses the beams nor stops the search.
beam_search_control beam_search_callback(void* callback_state, beams_state const beams_state) {
    auto const& state = *static_cast<beam_search_callback_state*>(callback_state);
    printf(",");  // Show progress
    if (size_t const n = beams_state.common_prefix_length) {
        assert(0u < beams_state.n_beams);
        // The first n tokens are shared by all beams, so read them from beam 0 and append them.
        llama_token const* tokens = beams_state.beam_views[0].tokens;
        state.response->insert(state.response->end(), tokens, tokens + n);
        printf("%zu", n);  // %zu is the conversion for size_t; %lu is wrong where unsigned long is narrower.
    }
    fflush(stdout);
#if 1 // DEBUG: print current beams for this iteration
    std::cout << "\n\nCurrent beams:\n";
    for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
        std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
    }
#endif
    beam_search_control control {
        beams_state.n_beams,  // = collapse_to. Any index out of range means do not collapse beams.
        false                 // = stop. Don't stop beam search.
    };
    return control;
}
5578

5679
int main(int argc, char ** argv)
@@ -140,9 +163,10 @@ int main(int argc, char ** argv)
140163
n_past += tokens_list.size();
141164

142165
std::vector<llama_token> response;
166+
beam_search_callback_state callback_state{ctx, &response};
143167
size_t const beam_width = static_cast<size_t>(params.n_beams);
144168
int const n_predict = 256;
145-
llama_beam_search(ctx, beam_search_callback, &response, beam_width, n_past, n_predict, params.n_threads);
169+
llama_beam_search(ctx, beam_search_callback, &callback_state, beam_width, n_past, n_predict, params.n_threads);
146170

147171
printf("\n\n");
148172
for (llama_token const token_id : response) {

llama.cpp

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,6 @@
5353
#include <sstream>
5454
#include <numeric>
5555

56-
#include <iostream>
57-
5856
#if defined(_MSC_VER)
5957
#pragma warning(disable: 4244 4267) // possible loss of data
6058
#endif
@@ -2880,7 +2878,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
28802878

28812879
struct beam {
28822880
std::vector<llama_token> tokens;
2883-
float p; // Cumulative beam probability (renormalized with each token)
2881+
float p; // Cumulative beam probability (renormalized relative to all beams)
28842882
// end-of-sentence
28852883
bool eos() const { return !tokens.empty() && tokens.back() == llama_token_eos(); }
28862884
// Shift off first n tokens and discard them.
@@ -2890,19 +2888,6 @@ struct beam {
28902888
}
28912889
};
28922890

2893-
// Used for debugging to print out beam tokens.
2894-
struct ostream_beam {
2895-
llama_context* ctx;
2896-
beam& b;
2897-
};
2898-
std::ostream& operator<<(std::ostream& os, ostream_beam const& osb) {
2899-
os << "p(" << osb.b.p << ") eos(" << std::boolalpha << osb.b.eos() << ") tokens(";
2900-
for (llama_token const token_id : osb.b.tokens) {
2901-
os << llama_token_to_str(osb.ctx, token_id);
2902-
}
2903-
return os << ')';
2904-
}
2905-
29062891
// A struct for calculating logit-related info.
29072892
struct logit_info {
29082893
float const* const logits;
@@ -2962,18 +2947,16 @@ struct beam_search {
29622947
// true iff llama_eval() has been called with non-empty common prefix in current loop iteration.
29632948
bool common_prefix_evaluated;
29642949

2965-
// Memory used by beam_state
2966-
std::vector<size_t> beam_lengths;
2967-
std::vector<llama_token const*> beam_ptrs;
2950+
// Temporary memory used by beams_state to pass back via callback.
2951+
std::vector<beam_view> beam_views;
29682952

29692953
// Stores the search parameters, creates one beam_view slot per beam, and reserves
// capacity for beam_width entries in both beams and next_beams.
beam_search(llama_context * ctx, size_t beam_width, int n_past, int n_predict, int n_threads) :
    ctx(ctx),
    beam_width(beam_width),
    n_past(n_past),
    n_predict(n_predict),
    n_threads(n_threads),
    beam_views(beam_width)
{
    beams.reserve(beam_width);
    next_beams.reserve(beam_width);
}
@@ -3074,11 +3057,10 @@ struct beam_search {
30743057
// Side effect: set common_prefix_length = find_common_prefix_length();
30753058
// Rebuilds beam_views from the current beams, updates common_prefix_length, and returns
// a beams_state that points at beam_views. last_call is forwarded unchanged.
beams_state get_beams_state(bool const last_call) {
    for (size_t idx = 0; idx < beams.size(); ++idx) {
        auto const & src = beams[idx];
        beam_views[idx] = beam_view{src.tokens.data(), src.tokens.size(), src.p};
    }
    common_prefix_length = find_common_prefix_length();
    return {beam_views.data(), beams.size(), common_prefix_length, last_call};
}
30833065

30843066
// Loop:

llama.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -443,16 +443,24 @@ extern "C" {
443443
/// @details Accepts the sampled token into the grammar
444444
LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
445445

446+
// Lightweight view of a beam
447+
struct beam_view {
448+
llama_token const* tokens;
449+
size_t n_tokens;
450+
float p; // Cumulative beam probability (renormalized relative to all beams)
451+
// end-of-sentence
452+
bool eos() const { return n_tokens && tokens[n_tokens-1u] == llama_token_eos(); }
453+
};
454+
446455
// Passed to beam_search_callback function.
447456
// Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
448457
// (e.g. beam_views[0].tokens) as they will be removed (shifted) from all beams in all subsequent callbacks.
449458
// These pointers are valid only during the synchronous callback, so should not be saved.
450459
struct beams_state {
451-
size_t n_beams; // Number of elements in beam_lengths[] and beams[].
452-
size_t const* beam_lengths; // Length of each beam.
453-
llama_token const* const* beams; // Current tokens in each beam.
454-
size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
455-
bool last_call; // True iff this is the last callback invocation.
460+
beam_view* beam_views; // View of each beam.
461+
size_t n_beams; // Number of elements in beam_views[].
462+
size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
463+
bool last_call; // True iff this is the last callback invocation.
456464
};
457465
// Must be returned by beam_search_callback function.
458466
struct beam_search_control {

0 commit comments

Comments
 (0)