Skip to content

Commit 9fa0876

Browse files
sampling: separate rng per sampling context
1 parent b1a1891 commit 9fa0876

File tree

5 files changed

+29
-6
lines changed

5 files changed

+29
-6
lines changed

common/sampling.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "sampling.h"
2+
#include <random>
23

34
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
45
struct llama_sampling_context * result = new llama_sampling_context();
@@ -62,6 +63,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
6263
ctx->cur.clear();
6364
}
6465

66+
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
67+
if (seed == LLAMA_DEFAULT_SEED) {
68+
seed = time(NULL);
69+
}
70+
ctx->rng.seed(seed);
71+
}
72+
6573
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
6674
if (dst->grammar) {
6775
llama_grammar_free(dst->grammar);
@@ -203,7 +211,7 @@ static llama_token llama_sampling_sample_impl(
203211

204212
sampler_queue(ctx_main, params, cur_p, min_keep);
205213

206-
id = llama_sample_token(ctx_main, &cur_p);
214+
id = llama_sample_token_with_rng(ctx_main, &cur_p, &ctx_sampling->rng);
207215

208216
//{
209217
// const int n_top = 10;

common/sampling.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include "grammar-parser.h"
66

7+
#include <random>
78
#include <string>
89
#include <vector>
910
#include <unordered_map>
@@ -79,6 +80,8 @@ struct llama_sampling_context {
7980
// TODO: replace with ring-buffer
8081
std::vector<llama_token> prev;
8182
std::vector<llama_token_data> cur;
83+
84+
std::mt19937 rng;
8285
};
8386

8487
#include "common.h"
@@ -93,6 +96,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
9396
// - reset grammar
9497
void llama_sampling_reset(llama_sampling_context * ctx);
9598

99+
// Set the sampler seed
100+
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
101+
96102
// Copy the sampler context
97103
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
98104

examples/server/server.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1028,7 +1028,7 @@ struct server_context {
10281028
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
10291029
return false;
10301030
}
1031-
llama_set_rng_seed(ctx, slot.params.seed);
1031+
llama_sampling_set_rng_seed(slot.ctx_sampling, slot.params.seed);
10321032
}
10331033

10341034
slot.command = SLOT_COMMAND_LOAD_PROMPT;

llama.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13478,7 +13478,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
1347813478
return result;
1347913479
}
1348013480

13481-
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
13481+
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, void * rng) {
1348213482
GGML_ASSERT(ctx);
1348313483

1348413484
const int64_t t_start_sample_us = ggml_time_us();
@@ -13491,8 +13491,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
1349113491
}
1349213492

1349313493
std::discrete_distribution<> dist(probs.begin(), probs.end());
13494-
auto & rng = ctx->rng;
13495-
int idx = dist(rng);
13494+
int idx = dist(*((std::mt19937 *) rng));
1349613495

1349713496
llama_token result = candidates->data[idx].id;
1349813497

@@ -13501,6 +13500,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
1350113500
return result;
1350213501
}
1350313502

13503+
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
13504+
return llama_sample_token_with_rng(ctx, candidates, &ctx->rng);
13505+
}
13506+
1350413507
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
1350513508
const int64_t t_start_sample_us = ggml_time_us();
1350613509

llama.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -987,7 +987,13 @@ extern "C" {
987987
struct llama_context * ctx,
988988
llama_token_data_array * candidates);
989989

990-
/// @details Randomly selects a token from the candidates based on their probabilities.
990+
/// @details Randomly selects a token from the candidates based on their probabilities using a given pointer to a std::mt19937.
991+
LLAMA_API llama_token llama_sample_token_with_rng(
992+
struct llama_context * ctx,
993+
llama_token_data_array * candidates,
994+
void * rng);
995+
996+
/// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
991997
LLAMA_API llama_token llama_sample_token(
992998
struct llama_context * ctx,
993999
llama_token_data_array * candidates);

0 commit comments

Comments
 (0)