#include "common.h" #include "llama.h" #include enum batch_method { add, get_one, single, }; gpt_params mistal_instruct_params(std::string_view model) { gpt_params params = gpt_params(); params = gpt_params(); params.n_ctx = 8192; params.n_keep = 1; params.n_predict = 2048; params.n_gpu_layers = 33; params.n_threads = 6; params.model = model; params.n_gpu_layers = 33; params.input_prefix = "[INST]"; params.input_suffix = "[/INST]"; params.n_threads = 6; params.n_batch = 896; params.instruct = true; params.sparams.penalty_repeat = 1.15f; params.sparams.penalty_last_n = 128; params.sparams.penalize_nl = false; params.seed = static_cast(time( 0 )); return params; } bool parse_params( int argc, char** argv, batch_method &method, gpt_params ¶ms ) { if( argc < 3 ) { std::cout << "Need at least two parameters:" << std::endl; std::cout << " Method: 'old'|'new'|'single'" << std::endl; std::cout << " Mode;: 'path/to/mistral-78-instruct-v0.2.Q6_K.gguf" << std::endl; return false; } if( _strcmpi( argv[1], "old" ) == 0 ) { method = batch_method::get_one; } else if( _strcmpi( argv[1], "new" ) == 0 ) { method = batch_method::add; } else if( _strcmpi( argv[1], "single" ) == 0 ) { method = batch_method::single; } else { return false; } params = mistal_instruct_params(argv[2]); if( method == batch_method::single ) { params.n_batch = 1; } return true; } std::string get_prompt_string() { return "[INST] You are a world famous fantasy writer. " "Your books are known for extremely detailed descriptions of settings and characters. " "They are usually about thrilling adventure in the midst of a fantasy world filled with humans, elves, dwarves, and other fantasy races that you invented. " "Together, we are describing scenes from your books. " "I will prompt you about various objects and characters in your books. " "You will describe them in great detail. " "You should respond as if it is a paragraph written in your book." 
"[/ INST] Understood.< / s>[INST] Describe the three major kingdoms of the fantasy world.[/ INST] "; } std::vector get_prompt_tokens(llama_model *model) { std::vector result; const std::string text = get_prompt_string(); const size_t max_tokens = text.size() + 1; if( max_tokens > INT_MAX ) { return result; } result.resize( max_tokens ); int32_t token_count = llama_tokenize( model, text.data(), static_cast(text.size()), result.data(), static_cast(result.size()), true, true ); if( token_count < 0 ) { result.resize( -token_count ); size_t check_value = llama_tokenize( model, text.data(), static_cast(text.size()), result.data(), static_cast(result.size()), true, true ); if( check_value != -token_count ) { result.clear(); return result; } } else { result.resize( token_count ); // shrink to fit } return result; } int main(int argc, char** argv) { // ------------------------------------------------------------------------ // Parse Params batch_method method; gpt_params params; if( !parse_params( argc, argv, method, params ) ) { return 1; } // ------------------------------------------------------------------------ // Init Llama and relevant data llama_backend_init(); llama_numa_init( params.numa ); llama_model_params model_params = llama_model_params_from_gpt_params( params ); llama_model* model = llama_load_model_from_file( params.model.c_str(), model_params ); if( !model ) { std::cout << "Error: Failed to load model: " << params.model << std::endl; return 2; } llama_context_params ctx_params = llama_context_params_from_gpt_params( params ); llama_context *ctx = llama_new_context_with_model( model, ctx_params ); llama_sampling_context *ctx_sampling = llama_sampling_init( params.sparams ); std::vector tokens = get_prompt_tokens(model); const int32_t prompt_token_count = static_cast(tokens.size()); const int32_t batch_size = params.n_batch; llama_batch batch = llama_batch_init( batch_size, 0, 1 ); // ------------------------------------------------------------------------ // Consume the prompt for( int32_t batch_first_token = 0; batch_first_token < prompt_token_count; batch_first_token += batch_size ) { const int32_t remaining = prompt_token_count - batch_first_token; int32_t count_to_eval = std::min( remaining, batch_size ); int32_t decode_result = 0; if( method == batch_method::get_one ) { llama_batch temp_batch = llama_batch_get_one( &(tokens[batch_first_token]), static_cast(count_to_eval), static_cast(batch_first_token), 0 ); decode_result = llama_decode( ctx, temp_batch ); } else { llama_batch_clear( batch ); for( int32_t batch_index = 0; batch_index < count_to_eval; ++batch_index ) { int32_t token_index = batch_first_token + batch_index; llama_batch_add( batch, tokens[token_index], static_cast(token_index), { 0 }, false ); } batch.logits[0] = true; // sampling later asserts unless first value of the logits array is valid const bool is_last_batch = batch_first_token + count_to_eval == tokens.size(); if( is_last_batch ) { batch.logits[count_to_eval - 1] = true; // Mark for sampling later } decode_result = llama_decode( ctx, batch ); } if( decode_result ) { std::cout << "Failed to eval, return code: " << decode_result; return 2; } for( int i = 0; i < count_to_eval; ++i ) { size_t token_index = batch_first_token + i; llama_sampling_accept( ctx_sampling, ctx, tokens[token_index], false ); std::cout << llama_token_to_piece( ctx, tokens[token_index] ); } } // ------------------------------------------------------------------------ // Pump some output to demonstrate differences const int32_t 
to_generate = 100; for( int32_t i = 0; i < to_generate; ++i ) { const llama_token token = llama_sampling_sample( ctx_sampling, ctx, nullptr ); llama_sampling_accept( ctx_sampling, ctx, token, true ); if( token == llama_token_eos( model ) ) { std::cout << std::endl; break; } std::cout << llama_token_to_piece( ctx, token ); int32_t token_index = prompt_token_count + i; llama_batch_clear( batch ); llama_batch_add( batch, token, static_cast(token_index), { 0 }, true ); llama_decode( ctx, batch ); } return 0; }
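
// Example invocations, matching the argument handling in parse_params() above.
// The executable name "batch_compare" is only an assumed placeholder; substitute
// whatever name this file is built under, and your own model path:
//
//   batch_compare old    path/to/mistral-7b-instruct-v0.2.Q6_K.gguf   (prompt via llama_batch_get_one)
//   batch_compare new    path/to/mistral-7b-instruct-v0.2.Q6_K.gguf   (prompt via llama_batch_clear/llama_batch_add)
//   batch_compare single path/to/mistral-7b-instruct-v0.2.Q6_K.gguf   (llama_batch_add with n_batch forced to 1)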