Skip to content

Commit d924e61

Browse files
sampling: separate rng per sampling context
1 parent b1a1891 commit d924e61

File tree

8 files changed

+33
-9
lines changed

8 files changed

+33
-9
lines changed

common/sampling.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "sampling.h"
2+
#include <random>
23

34
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
45
struct llama_sampling_context * result = new llama_sampling_context();
@@ -33,6 +34,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
3334

3435
result->prev.resize(params.n_prev);
3536

37+
llama_sampling_set_rng_seed(result, LLAMA_DEFAULT_SEED);
38+
3639
return result;
3740
}
3841

@@ -62,6 +65,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
6265
ctx->cur.clear();
6366
}
6467

68+
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
69+
if (seed == LLAMA_DEFAULT_SEED) {
70+
seed = time(NULL);
71+
}
72+
ctx->rng.seed(seed);
73+
}
74+
6575
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
6676
if (dst->grammar) {
6777
llama_grammar_free(dst->grammar);
@@ -203,7 +213,7 @@ static llama_token llama_sampling_sample_impl(
203213

204214
sampler_queue(ctx_main, params, cur_p, min_keep);
205215

206-
id = llama_sample_token(ctx_main, &cur_p);
216+
id = llama_sample_token_with_rng(ctx_main, &cur_p, &ctx_sampling->rng);
207217

208218
//{
209219
// const int n_top = 10;

common/sampling.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include "grammar-parser.h"
66

7+
#include <random>
78
#include <string>
89
#include <vector>
910
#include <unordered_map>
@@ -79,6 +80,8 @@ struct llama_sampling_context {
7980
// TODO: replace with ring-buffer
8081
std::vector<llama_token> prev;
8182
std::vector<llama_token_data> cur;
83+
84+
std::mt19937 rng;
8285
};
8386

8487
#include "common.h"
@@ -93,6 +96,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
9396
// - reset grammar
9497
void llama_sampling_reset(llama_sampling_context * ctx);
9598

99+
// Set the sampler seed
100+
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
101+
96102
// Copy the sampler context
97103
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
98104

examples/lookup/lookup-stats.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ int main(int argc, char ** argv){
3030

3131
// load the model
3232
std::tie(model, ctx) = llama_init_from_gpt_params(params);
33-
llama_set_rng_seed(ctx, params.seed);
3433
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
3534

3635
// tokenize the prompt

examples/lookup/lookup.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ int main(int argc, char ** argv){
3838

3939
// load the model
4040
std::tie(model, ctx) = llama_init_from_gpt_params(params);
41-
llama_set_rng_seed(ctx, params.seed);
4241
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
4342

4443
// tokenize the prompt
@@ -108,6 +107,7 @@ int main(int argc, char ** argv){
108107
bool has_eos = false;
109108

110109
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
110+
llama_sampling_set_rng_seed(ctx_sampling, params.seed);
111111

112112
std::vector<llama_token> draft;
113113

examples/main/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,6 @@ int main(int argc, char ** argv) {
240240
return 1;
241241
}
242242
session_tokens.resize(n_token_count_out);
243-
llama_set_rng_seed(ctx, params.seed);
244243
LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
245244
}
246245
}
@@ -521,6 +520,7 @@ int main(int argc, char ** argv) {
521520
}
522521

523522
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
523+
llama_sampling_set_rng_seed(ctx_sampling, params.seed);
524524

525525
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
526526
// predict

examples/server/server.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1028,7 +1028,7 @@ struct server_context {
10281028
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
10291029
return false;
10301030
}
1031-
llama_set_rng_seed(ctx, slot.params.seed);
1031+
llama_sampling_set_rng_seed(slot.ctx_sampling, slot.params.seed);
10321032
}
10331033

10341034
slot.command = SLOT_COMMAND_LOAD_PROMPT;

llama.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13478,7 +13478,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
1347813478
return result;
1347913479
}
1348013480

13481-
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
13481+
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, void * rng) {
1348213482
GGML_ASSERT(ctx);
1348313483

1348413484
const int64_t t_start_sample_us = ggml_time_us();
@@ -13491,8 +13491,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
1349113491
}
1349213492

1349313493
std::discrete_distribution<> dist(probs.begin(), probs.end());
13494-
auto & rng = ctx->rng;
13495-
int idx = dist(rng);
13494+
int idx = dist(*((std::mt19937 *) rng));
1349613495

1349713496
llama_token result = candidates->data[idx].id;
1349813497

@@ -13501,6 +13500,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
1350113500
return result;
1350213501
}
1350313502

13503+
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
13504+
return llama_sample_token_with_rng(ctx, candidates, &ctx->rng);
13505+
}
13506+
1350413507
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
1350513508
const int64_t t_start_sample_us = ggml_time_us();
1350613509

llama.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -987,7 +987,13 @@ extern "C" {
987987
struct llama_context * ctx,
988988
llama_token_data_array * candidates);
989989

990-
/// @details Randomly selects a token from the candidates based on their probabilities.
990+
/// @details Randomly selects a token from the candidates based on their probabilities using a given pointer to a std::mt19937.
991+
LLAMA_API llama_token llama_sample_token_with_rng(
992+
struct llama_context * ctx,
993+
llama_token_data_array * candidates,
994+
void * rng);
995+
996+
/// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
991997
LLAMA_API llama_token llama_sample_token(
992998
struct llama_context * ctx,
993999
llama_token_data_array * candidates);

0 commit comments

Comments
 (0)