Custom Sampling Pipelines #348

Merged
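This pull request threads an optional `ISamplingPipeline` through LLamaSharp's inference stack: `IInferenceParams` and its implementations gain a `SamplingPipeline` property, the instruct, interactive, and stateless executors dispatch to the pipeline when one is set, `LLamaContext` gains a convenience `Sample` overload, and `LLamaTokenDataArray` picks up helpers (`OverwriteLogits`, a null-tolerant `ApplyGrammar`, a span-based `RepetitionPenalty`) that custom pipelines can build on.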
3 changes: 1 addition & 2 deletions LLama.Unittest/GrammarParserTest.cs
```diff
@@ -1,5 +1,4 @@
-using System.Text;
-using LLama.Exceptions;
+using LLama.Exceptions;
 using LLama.Native;
 using LLama.Grammars;
```
6 changes: 5 additions & 1 deletion LLama.Unittest/StatelessExecutorTest.cs
```diff
@@ -1,5 +1,6 @@
 using System.Diagnostics;
 using LLama.Common;
+using LLama.Sampling;
 using Xunit.Abstractions;

 namespace LLama.Unittest
@@ -30,10 +31,13 @@ public void Dispose()
         [Fact]
         public async Task Stateless()
         {
+            // Create a custom pipeline that mimics the default pipeline
+            var pipeline = new DefaultSamplingPipeline();
+
             var executor = new StatelessExecutor(_weights, _params);

             const string question = "Question. what is a cat?\nAnswer: ";
-            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
+            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };

             var timer = new Stopwatch();
             timer.Start();
```
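Outside the test harness, the same wiring applies to any executor: build a pipeline, attach it to the `InferenceParams`, and infer as usual. A minimal sketch, assuming a model has already been loaded (`weights` and `modelParams` are placeholders):

```csharp
using LLama;
using LLama.Common;
using LLama.Sampling;

// Sketch: run a StatelessExecutor with an explicit sampling pipeline.
var executor = new StatelessExecutor(weights, modelParams); // assumes a loaded LLamaWeights + ModelParams

var inferenceParams = new InferenceParams
{
    MaxTokens = 32,
    AntiPrompts = new[] { "." },
    SamplingPipeline = new DefaultSamplingPipeline(), // replaces the parameter-driven sampling path
};

await foreach (var text in executor.InferAsync("Question. what is a cat?\nAnswer: ", inferenceParams))
    Console.Write(text);
```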
10 changes: 8 additions & 2 deletions LLama.Web/Common/InferenceOptions.cs
```diff
@@ -1,6 +1,9 @@
-using LLama.Common;
+#nullable enable
+
+using LLama.Common;
 using LLama.Abstractions;
 using LLama.Native;
+using LLama.Sampling;

 namespace LLama.Web.Common
 {
@@ -64,6 +67,9 @@ public class InferenceOptions
         /// <summary>
         /// A grammar to constrain possible tokens
         /// </summary>
-        public SafeLLamaGrammarHandle Grammar { get; set; } = null;
+        public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <inheritdoc />
+        public ISamplingPipeline? SamplingPipeline { get; set; }
     }
 }
```
6 changes: 6 additions & 0 deletions LLama/Abstractions/IInferenceParams.cs
```diff
@@ -1,6 +1,7 @@
 using System.Collections.Generic;
 using LLama.Common;
 using LLama.Native;
+using LLama.Sampling;

 namespace LLama.Abstractions
 {
@@ -108,5 +109,10 @@ public interface IInferenceParams
         /// Grammar to constrain possible tokens
         /// </summary>
         SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <summary>
+        /// Set a custom sampling pipeline to use. <b>If this is set, all other sampling parameters are ignored!</b>
+        /// </summary>
+        ISamplingPipeline? SamplingPipeline { get; set; }
     }
 }
```
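That doc comment is the key behavioural contract of this PR: as soon as `SamplingPipeline` is non-null, the executors route token selection through it and the classic knobs on the same parameter object are never consulted. A small illustration:

```csharp
// With a pipeline set, the parameter-driven knobs below have no effect.
var inferenceParams = new InferenceParams
{
    Temperature = 0.2f,                               // ignored: pipeline takes over
    TopK = 40,                                        // ignored: pipeline takes over
    SamplingPipeline = new DefaultSamplingPipeline(), // this alone selects each token
};
```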
4 changes: 4 additions & 0 deletions LLama/Common/InferenceParams.cs
```diff
@@ -2,6 +2,7 @@
 using System;
 using System.Collections.Generic;
 using LLama.Native;
+using LLama.Sampling;

 namespace LLama.Common
 {
@@ -76,6 +77,9 @@ public record InferenceParams

         /// <inheritdoc />
         public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <inheritdoc />
+        public ISamplingPipeline? SamplingPipeline { get; set; }
     }

     /// <summary>
```
12 changes: 12 additions & 0 deletions LLama/LLamaContext.cs
```diff
@@ -10,6 +10,7 @@
 using System.Runtime.InteropServices;
 using LLama.Extensions;
 using LLama.Abstractions;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;

 namespace LLama
@@ -212,6 +213,17 @@ public void LoadState(State state)
             }
         }

+        /// <summary>
+        /// Sample a single token from this context, using the given sampling pipeline
+        /// </summary>
+        /// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
+        /// <param name="lastTokens">The tokens recently returned from the model</param>
+        /// <returns>The selected token</returns>
+        public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
+        {
+            return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
+        }
+
         /// <summary>
         /// Perform the sampling. Please don't use it unless you fully know what it does.
         /// </summary>
```
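The new overload is a thin convenience wrapper: it reads the current logits off the native handle and defers token selection entirely to the pipeline. A hedged sketch of one manual decode step (`GetRecentTokens` is a hypothetical history helper, not part of this PR):

```csharp
// Sketch: one manual decode step via the new overload.
// `context` is a LLamaContext that has just evaluated a batch,
// `pipeline` is any ISamplingPipeline implementation.
llama_token[] recent = GetRecentTokens();           // hypothetical: your own token history
llama_token next = context.Sample(pipeline, recent);
```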
26 changes: 17 additions & 9 deletions LLama/LLamaInstructExecutor.cs
```diff
@@ -210,16 +210,24 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 SaveSessionFile(_pathSession);
             }

-            var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-                inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
-            var mu = MirostatMu;
-            var id = Context.Sample(
-                tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                inferenceParams.MinP
-            );
-            MirostatMu = mu;
+            llama_token id;
+            if (inferenceParams.SamplingPipeline is not null)
+            {
+                id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+            }
+            else
+            {
+                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+                var mu = MirostatMu;
+                id = Context.Sample(
+                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                    inferenceParams.MinP
+                );
+                MirostatMu = mu;
+            }

             _last_n_tokens.Enqueue(id);
```
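These call sites pin down the contract a custom pipeline has to satisfy: it receives the native context handle, the raw logits for the last evaluated position, and the recent token history, and returns the selected token. A hedged reconstruction of that shape, inferred purely from the calls above (the real interface in `LLama.Sampling` may declare further members); the interactive and stateless executors below dispatch identically:

```csharp
// Inferred from the executor call sites; not a verbatim copy of the real interface.
public interface ISamplingPipeline
{
    /// <summary>Process the logits for the newest position and select the next token.</summary>
    llama_token Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<llama_token> lastTokens);
}
```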
28 changes: 18 additions & 10 deletions LLama/LLamaInteractExecutor.cs
```diff
@@ -189,16 +189,24 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
                 SaveSessionFile(_pathSession);
             }

-            var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-                inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
-            var mu = MirostatMu;
-            var id = Context.Sample(
-                tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                inferenceParams.MinP
-            );
-            MirostatMu = mu;
+            llama_token id;
+            if (inferenceParams.SamplingPipeline is not null)
+            {
+                id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+            }
+            else
+            {
+                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+                var mu = MirostatMu;
+                id = Context.Sample(
+                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                    inferenceParams.MinP
+                );
+                MirostatMu = mu;
+            }

             _last_n_tokens.Enqueue(id);
```
29 changes: 19 additions & 10 deletions LLama/LLamaStatelessExecutor.cs
```diff
@@ -7,6 +7,7 @@
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Native;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;

 namespace LLama
@@ -85,16 +86,24 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
             var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
             for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
             {
-                // Penalize the generated tokens by various penalties
-                var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
-
-                // Sample a single token
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                }
+                else
+                {
+                    // Penalize the generated tokens by various penalties
+                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+
+                    // Sample a single token
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                }

                 // Decode this token into text
                 decoder.Add(id);
```
39 changes: 34 additions & 5 deletions LLama/Native/LLamaTokenDataArray.cs
```diff
@@ -46,14 +46,41 @@ public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits)
             return new LLamaTokenDataArray(candidates);
         }

+        /// <summary>
+        /// Overwrite the logit values for all given tokens
+        /// </summary>
+        /// <param name="values">tuples of token and logit value to overwrite</param>
+        public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
+        {
+            if (values.Length == 0)
+                return;
+
+            var dataSpan = data.Span;
+            foreach (var (token, value) in values)
+            {
+                for (var i = 0; i < data.Length; i++)
+                {
+                    if (dataSpan[i].id == token)
+                    {
+                        dataSpan[i].logit = value;
+                        break;
+                    }
+                }
+            }
+            sorted = false;
+        }
+
         #region sampling
         /// <summary>
         /// Apply grammar rules to candidate tokens
         /// </summary>
         /// <param name="ctx"></param>
         /// <param name="grammar"></param>
-        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
+        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar)
         {
+            if (grammar == null)
+                return;
+
             using (LLamaTokenDataArrayNative.Create(this, out var st))
             {
                 NativeApi.llama_sample_grammar(ctx, ref st, grammar);
@@ -145,15 +172,17 @@ public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_ke
         /// <param name="penalty_repeat"></param>
         /// <param name="penalty_freq"></param>
         /// <param name="penalty_present"></param>
-        public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
+        public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
         {
             unsafe
             {
                 using (LLamaTokenDataArrayNative.Create(this, out var st))
-                using (var last_tokens_handle = last_tokens.Pin())
                 {
-                    NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
-                    sorted = st.sorted;
+                    fixed (int* last_tokens_handle = last_tokens)
+                    {
+                        NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
+                        sorted = st.sorted;
+                    }
                 }
             }
         }
```
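Taken together these helpers are the raw material for hand-rolled pipelines: `Create` snapshots logits into a candidate array, the new `OverwriteLogits` patches individual token scores (and marks the array unsorted), the null-tolerant `ApplyGrammar` makes grammar optional, and `RepetitionPenalty` now accepts the history as a `ReadOnlySpan<llama_token>` so callers need not allocate. A hedged sketch of one sampling step built on them; the final `SampleTokenGreedy` call assumes the greedy helper `LLamaTokenDataArray` provides elsewhere (it is not part of this diff), and `bannedToken` is a caller-chosen id:

```csharp
// Sketch: compose the helpers above into one sampling step.
// llama_token is the int alias used throughout this repo.
llama_token SampleStep(SafeLLamaContextHandle ctx, ReadOnlySpan<float> logits,
                       ReadOnlySpan<llama_token> lastTokens, llama_token bannedToken)
{
    // Snapshot the raw logits into a mutable candidate array
    var candidates = LLamaTokenDataArray.Create(logits);

    // Effectively ban one token by forcing its logit to -infinity
    candidates.OverwriteLogits(stackalloc[] { (bannedToken, float.NegativeInfinity) });

    // Penalize recently generated tokens (illustrative values)
    candidates.RepetitionPenalty(ctx, lastTokens, penalty_repeat: 1.1f, penalty_freq: 0.0f, penalty_present: 0.0f);

    // Pick the most likely remaining candidate (assumed greedy helper)
    return candidates.SampleTokenGreedy(ctx);
}
```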