Skip to content

Commit ff252ea

Browse files
wwoodsTMl3utterflypi6am
authored
llama : add DRY sampler (#9702)
* sampling : add DRY sampler (post-refactor) * DRY: Trying to fix coauthors, removed unneeded line * DRY: Fixed redundant code * DRY: Fixed crash issue due to DRY being in chain but uninitialized --------- Co-authored-by: l3utterfly <[email protected]> Co-authored-by: pi6am <[email protected]>
1 parent d80fb71 commit ff252ea

17 files changed

+713
-63
lines changed

common/arg.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
251251
for (auto & antiprompt : params.antiprompt) {
252252
string_process_escapes(antiprompt);
253253
}
254+
for (auto & seq_breaker : params.sparams.dry_sequence_breakers) {
255+
string_process_escapes(seq_breaker);
256+
}
254257
}
255258

256259
if (!params.kv_overrides.empty()) {
@@ -997,6 +1000,64 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
9971000
params.sparams.penalty_freq = std::stof(value);
9981001
}
9991002
).set_sparam());
1003+
add_opt(common_arg(
1004+
{"--dry-multiplier"}, "N",
1005+
string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sparams.dry_multiplier),
1006+
[](common_params & params, const std::string & value) {
1007+
params.sparams.dry_multiplier = std::stof(value);
1008+
}
1009+
).set_sparam());
1010+
add_opt(common_arg(
1011+
{"--dry-base"}, "N",
1012+
string_format("set DRY sampling base value (default: %.2f)", (double)params.sparams.dry_base),
1013+
[](common_params & params, const std::string & value) {
1014+
float potential_base = std::stof(value);
1015+
if (potential_base >= 1.0f)
1016+
{
1017+
params.sparams.dry_base = potential_base;
1018+
}
1019+
}
1020+
).set_sparam());
1021+
add_opt(common_arg(
1022+
{"--dry-allowed-length"}, "N",
1023+
string_format("set allowed length for DRY sampling (default: %d)", params.sparams.dry_allowed_length),
1024+
[](common_params & params, int value) {
1025+
params.sparams.dry_allowed_length = value;
1026+
}
1027+
).set_sparam());
1028+
add_opt(common_arg(
1029+
{"--dry-penalty-last-n"}, "N",
1030+
string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sparams.dry_penalty_last_n),
1031+
[](common_params & params, int value) {
1032+
params.sparams.dry_penalty_last_n = value;
1033+
}
1034+
).set_sparam());
1035+
add_opt(common_arg(
1036+
{"--dry-sequence-breaker"}, "STRING",
1037+
string_format("add sequence breaker for DRY sampling, clearing out default breakers (%s) in the process; use \"none\" to not use any sequence breakers\n",
1038+
params.sparams.dry_sequence_breakers.empty() ? "none" :
1039+
std::accumulate(std::next(params.sparams.dry_sequence_breakers.begin()),
1040+
params.sparams.dry_sequence_breakers.end(),
1041+
std::string("'") + (params.sparams.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sparams.dry_sequence_breakers[0]) + "'",
1042+
[](const std::string& a, const std::string& b) {
1043+
std::string formatted_b = (b == "\n") ? "\\n" : b;
1044+
return a + ", '" + formatted_b + "'";
1045+
}).c_str()),
1046+
[](common_params & params, const std::string & value) {
1047+
static bool defaults_cleared = false;
1048+
1049+
if (!defaults_cleared) {
1050+
params.sparams.dry_sequence_breakers.clear();
1051+
defaults_cleared = true;
1052+
}
1053+
1054+
if (value == "none") {
1055+
params.sparams.dry_sequence_breakers.clear();
1056+
} else {
1057+
params.sparams.dry_sequence_breakers.emplace_back(value);
1058+
}
1059+
}
1060+
).set_sparam());
10001061
add_opt(common_arg(
10011062
{"--dynatemp-range"}, "N",
10021063
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range),

common/common.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2006,6 +2006,10 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
20062006
fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
20072007
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
20082008
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
2009+
fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length);
2010+
fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base);
2011+
fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier);
2012+
fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n);
20092013
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
20102014
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
20112015
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);

common/common.h

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,15 @@ enum llama_example {
8484

8585
enum common_sampler_type {
8686
COMMON_SAMPLER_TYPE_NONE = 0,
87-
COMMON_SAMPLER_TYPE_TOP_K = 1,
88-
COMMON_SAMPLER_TYPE_TOP_P = 2,
89-
COMMON_SAMPLER_TYPE_MIN_P = 3,
90-
COMMON_SAMPLER_TYPE_TFS_Z = 4,
91-
COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
92-
COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
93-
COMMON_SAMPLER_TYPE_XTC = 7,
94-
COMMON_SAMPLER_TYPE_INFILL = 8,
87+
COMMON_SAMPLER_TYPE_DRY = 1,
88+
COMMON_SAMPLER_TYPE_TOP_K = 2,
89+
COMMON_SAMPLER_TYPE_TOP_P = 3,
90+
COMMON_SAMPLER_TYPE_MIN_P = 4,
91+
COMMON_SAMPLER_TYPE_TFS_Z = 5,
92+
COMMON_SAMPLER_TYPE_TYPICAL_P = 6,
93+
COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
94+
COMMON_SAMPLER_TYPE_XTC = 8,
95+
COMMON_SAMPLER_TYPE_INFILL = 9,
9596
};
9697

9798
// dimensionality reduction methods, used by cvector-generator
@@ -104,32 +105,39 @@ enum dimre_method {
104105
struct common_sampler_params {
105106
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
106107

107-
int32_t n_prev = 64; // number of previous tokens to remember
108-
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
109-
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
110-
int32_t top_k = 40; // <= 0 to use vocab size
111-
float top_p = 0.95f; // 1.0 = disabled
112-
float min_p = 0.05f; // 0.0 = disabled
113-
float xtc_probability = 0.00f; // 0.0 = disabled
114-
float xtc_threshold = 0.10f; // > 0.5 disables XTC
115-
float tfs_z = 1.00f; // 1.0 = disabled
116-
float typ_p = 1.00f; // typical_p, 1.0 = disabled
117-
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
118-
float dynatemp_range = 0.00f; // 0.0 = disabled
119-
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
120-
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
121-
float penalty_repeat = 1.00f; // 1.0 = disabled
122-
float penalty_freq = 0.00f; // 0.0 = disabled
123-
float penalty_present = 0.00f; // 0.0 = disabled
124-
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
125-
float mirostat_tau = 5.00f; // target entropy
126-
float mirostat_eta = 0.10f; // learning rate
127-
bool penalize_nl = false; // consider newlines as a repeatable token
128-
bool ignore_eos = false;
129-
bool no_perf = false; // disable performance metrics
108+
int32_t n_prev = 64; // number of previous tokens to remember
109+
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
110+
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
111+
int32_t top_k = 40; // <= 0 to use vocab size
112+
float top_p = 0.95f; // 1.0 = disabled
113+
float min_p = 0.05f; // 0.0 = disabled
114+
float xtc_probability = 0.00f; // 0.0 = disabled
115+
float xtc_threshold = 0.10f; // > 0.5 disables XTC
116+
float tfs_z = 1.00f; // 1.0 = disabled
117+
float typ_p = 1.00f; // typical_p, 1.0 = disabled
118+
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
119+
float dynatemp_range = 0.00f; // 0.0 = disabled
120+
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
121+
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
122+
float penalty_repeat = 1.00f; // 1.0 = disabled
123+
float penalty_freq = 0.00f; // 0.0 = disabled
124+
float penalty_present = 0.00f; // 0.0 = disabled
125+
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
126+
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
127+
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
128+
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
129+
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
130+
float mirostat_tau = 5.00f; // target entropy
131+
float mirostat_eta = 0.10f; // learning rate
132+
bool penalize_nl = false; // consider newlines as a repeatable token
133+
bool ignore_eos = false;
134+
bool no_perf = false; // disable performance metrics
135+
136+
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
130137

131138

132139
std::vector<enum common_sampler_type> samplers = {
140+
COMMON_SAMPLER_TYPE_DRY,
133141
COMMON_SAMPLER_TYPE_TOP_K,
134142
COMMON_SAMPLER_TYPE_TFS_Z,
135143
COMMON_SAMPLER_TYPE_TYPICAL_P,

common/sampling.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,11 @@ std::string common_sampler_params::print() const {
130130

131131
snprintf(result, sizeof(result),
132132
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
133+
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
133134
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
134135
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
135136
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
137+
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
136138
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
137139
mirostat, mirostat_eta, mirostat_tau);
138140

@@ -174,6 +176,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
174176
if (params.mirostat == 0) {
175177
for (const auto & cnstr : params.samplers) {
176178
switch (cnstr) {
179+
case COMMON_SAMPLER_TYPE_DRY:
180+
{
181+
std::vector<const char*> c_breakers;
182+
c_breakers.reserve(params.dry_sequence_breakers.size());
183+
for (const auto& str : params.dry_sequence_breakers) {
184+
c_breakers.push_back(str.c_str());
185+
}
186+
187+
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
188+
}
189+
break;
177190
case COMMON_SAMPLER_TYPE_TOP_K:
178191
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
179192
break;
@@ -358,6 +371,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
358371

359372
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
360373
switch (cnstr) {
374+
case COMMON_SAMPLER_TYPE_DRY: return 'd';
361375
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
362376
case COMMON_SAMPLER_TYPE_TFS_Z: return 'f';
363377
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
@@ -372,6 +386,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
372386

373387
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
374388
switch (cnstr) {
389+
case COMMON_SAMPLER_TYPE_DRY: return "dry";
375390
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
376391
case COMMON_SAMPLER_TYPE_TFS_Z: return "tfs_z";
377392
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
@@ -386,6 +401,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
386401

387402
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
388403
std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
404+
{ "dry", COMMON_SAMPLER_TYPE_DRY },
389405
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
390406
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
391407
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
@@ -434,6 +450,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
434450

435451
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
436452
std::unordered_map<char, common_sampler_type> sampler_name_map = {
453+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY },
437454
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
438455
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TFS_Z), COMMON_SAMPLER_TYPE_TFS_Z },
439456
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },

examples/main/README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,30 @@ Use the `--no-penalize-nl` option to disable newline penalization when applying
187187

188188
Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl`
189189

190+
### DRY Repetition Penalty
191+
192+
DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)).
193+
194+
- `--dry-multiplier N`: Set the DRY sampling multiplier (default: 0.0, 0.0 = disabled).
195+
- `--dry-base N`: Set the DRY sampling base value (default: 1.75).
196+
- `--dry-allowed-length N`: Set the allowed length for DRY sampling (default: 2).
197+
- `--dry-penalty-last-n N`: Set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size).
198+
- `--dry-sequence-breaker STRING`: Add a sequence breaker for DRY sampling. Can be used more than once to add multiple sequence breakers. Using this clears out the default breakers, which consist of: `['\n', ':', '"', '*']`. If the string `"none"` is supplied, no sequence breakers are used.
199+
200+
The `dry-multiplier` option controls the strength of the DRY sampling effect. A value of 0.0 disables DRY sampling, while higher values increase its influence. A typical recommended value is 0.8.
201+
202+
The `dry-base` option sets the base value for the exponential penalty calculation in DRY sampling. Higher values lead to more aggressive penalization of repetitions.
203+
204+
The `dry-allowed-length` option sets the maximum length of repeated sequences that will not be penalized. Repetitions shorter than or equal to this length are not penalized, allowing for natural repetitions of short phrases or common words.
205+
206+
The `dry-penalty-last-n` option controls how many recent tokens to consider when applying the DRY penalty. A value of -1 considers the entire context. Use a positive value to limit the consideration to a specific number of recent tokens.
207+
208+
The `dry-sequence-breaker` option adds a single sequence breaker and can be used more than once to specify multiple sequence breakers. Sequence breakers interrupt sequence matching and break the input into parts where matching can be applied.
209+
210+
DRY sampling provides more nuanced control over text generation, particularly for reducing long-range repetitions and maintaining global coherence.
211+
212+
Example usage: `--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2 --dry-penalty-last-n -1 --dry-sequence-breaker "—" --dry-sequence-breaker "##"`
213+
190214
### Top-K Sampling
191215

192216
- `--top-k N`: Limit the next token selection to the K most probable tokens (default: 40).

examples/server/README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ The project is under active development, and we are [looking for feedback and co
114114
| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) |
115115
| `--presence-penalty N` | repeat alpha presence penalty (default: 0.0, 0.0 = disabled) |
116116
| `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.0, 0.0 = disabled) |
117+
| `--dry-multiplier N` | DRY sampling multiplier (default: 0.0, 0.0 = disabled) |
118+
| `--dry-base N` | DRY sampling base value (default: 1.75) |
119+
| `--dry-allowed-length N` | allowed length for DRY sampling (default: 2) |
120+
| `--dry-penalty-last-n N` | DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size) |
121+
| `--dry-sequence-breaker STRING` | add sequence breaker for DRY sampling, clearing out default breakers (`['\n', ':', '"', '*']`) in the process; use `"none"` to not use any sequence breakers
117122
| `--dynatemp-range N` | dynamic temperature range (default: 0.0, 0.0 = disabled) |
118123
| `--dynatemp-exp N` | dynamic temperature exponent (default: 1.0) |
119124
| `--mirostat N` | use Mirostat sampling.<br/>Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.<br/>(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) |
@@ -369,6 +374,16 @@ node index.js
369374

370375
`frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
371376

377+
`dry_multiplier`: Set the DRY (Don't Repeat Yourself) repetition penalty multiplier. Default: `0.0`, which is disabled.
378+
379+
`dry_base`: Set the DRY repetition penalty base value. Default: `1.75`
380+
381+
`dry_allowed_length`: Tokens that extend repetition beyond this receive exponentially increasing penalty: multiplier * base ^ (length of repeating sequence before token - allowed length). Default: `2`
382+
383+
`dry_penalty_last_n`: How many tokens to scan for repetitions. Default: `-1`, where `0` is disabled and `-1` is context size.
384+
385+
`dry_sequence_breakers`: Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n', ':', '"', '*']`
386+
372387
`mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.
373388

374389
`mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0`

0 commit comments

Comments
 (0)