|
27 | 27 | #include <signal.h>
|
28 | 28 | #endif
|
29 | 29 |
|
| 30 | +// Used for debugging to print out beam tokens. |
| 31 | +struct ostream_beam_view { |
| 32 | + llama_context* ctx; |
| 33 | + beam_view bv; |
| 34 | +}; |
| 35 | +std::ostream& operator<<(std::ostream& os, ostream_beam_view const& obv) { |
| 36 | + os << "p(" << obv.bv.p << ") eos(" << std::boolalpha << obv.bv.eos() << ") tokens("; |
| 37 | + for (size_t i=0 ; i<obv.bv.n_tokens ; ++i) { |
| 38 | + os << llama_token_to_str(obv.ctx, obv.bv.tokens[i]); |
| 39 | + } |
| 40 | + return os << ')'; |
| 41 | +} |
| 42 | + |
// Put here anything you want back in beam_search_callback().
// Passed through llama_beam_search() as an opaque void* and recovered
// by the callback with a static_cast.
struct beam_search_callback_state {
    llama_context* ctx;                      // context used to decode tokens for debug printing
    std::vector<llama_token>* response;      // accumulates the common-prefix tokens of the beams
};
30 | 48 |
|
31 | 49 | // Function matching type llama_beam_search_callback_fn_t.
|
32 | 50 | // Custom callback example is called each time the beams lengths increase:
|
|
35 | 53 | // This is also called when the stop condition is met.
|
36 | 54 | // Collect tokens into std::vector<llama_token> response which is pointed to by callback_state.
|
37 | 55 | beam_search_control beam_search_callback(void* callback_state, beams_state const beams_state) {
|
| 56 | + auto const state = *static_cast<beam_search_callback_state*>(callback_state); |
38 | 57 | printf(","); // Show progress
|
39 | 58 | if (size_t const n = beams_state.common_prefix_length) {
|
40 |
| - auto* response = static_cast<std::vector<llama_token>*>(callback_state); |
41 |
| - response->resize(response->size() + n); |
| 59 | + state.response->resize(state.response->size() + n); |
42 | 60 | assert(0u < beams_state.n_beams);
|
43 |
| - std::copy(beams_state.beams[0], beams_state.beams[0] + n, response->end() - n); |
| 61 | + llama_token const* tokens = beams_state.beam_views[0].tokens; |
| 62 | + std::copy(tokens, tokens + n, state.response->end() - n); |
44 | 63 | printf("%lu", n);
|
45 | 64 | }
|
46 | 65 | fflush(stdout);
|
47 |
| -#if 0 // DEBUG: print current beams for this iteration |
48 |
| - std::cout << "\n\nCurrent beams:\n"; |
49 |
| - for (size_t j=0 ; j < beams.size() ; ++j) { |
50 |
| - std::cout << "beams["<<j<<"]: " << ostream_beam{ctx,beams[j]} << std::endl; |
51 |
| - } |
| 66 | +#if 1 // DEBUG: print current beams for this iteration |
| 67 | + std::cout << "\n\nCurrent beams:\n"; |
| 68 | + for (size_t i=0 ; i < beams_state.n_beams ; ++i) { |
| 69 | + std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl; |
| 70 | + } |
52 | 71 | #endif
|
53 |
| - return { beams_state.n_beams, false }; // Continue beam search. |
| 72 | + beam_search_control control { |
| 73 | + beams_state.n_beams, // = collapse_to. Any index out of range means do not collapse beams. |
| 74 | + false // = stop. Don't stop beam search. |
| 75 | + }; |
| 76 | + return control; |
54 | 77 | }
|
55 | 78 |
|
56 | 79 | int main(int argc, char ** argv)
|
@@ -140,9 +163,10 @@ int main(int argc, char ** argv)
|
140 | 163 | n_past += tokens_list.size();
|
141 | 164 |
|
142 | 165 | std::vector<llama_token> response;
|
| 166 | + beam_search_callback_state callback_state{ctx, &response}; |
143 | 167 | size_t const beam_width = static_cast<size_t>(params.n_beams);
|
144 | 168 | int const n_predict = 256;
|
145 |
| - llama_beam_search(ctx, beam_search_callback, &response, beam_width, n_past, n_predict, params.n_threads); |
| 169 | + llama_beam_search(ctx, beam_search_callback, &callback_state, beam_width, n_past, n_predict, params.n_threads); |
146 | 170 |
|
147 | 171 | printf("\n\n");
|
148 | 172 | for (llama_token const token_id : response) {
|
|
0 commit comments