
Commit a8ba9f0

March Binary Update (#565)
* Updated binaries to llama.cpp `3ab8b3a92ede46df88bc5a2dfca3777de4a2b2b6` (build run: https://github.com/SciSharp/LLamaSharp/actions/runs/8118890586)
* Added abort callback
* Added properties to get/set thread count on `LLamaContext`
* Fixed LLamaLogLevel numbering
1 parent 6f03d5a commit a8ba9f0


49 files changed: +2773 -915 lines (large commit; only a subset of the file diffs is shown below)

LLama.Web/Common/ModelOptions.cs

Lines changed: 6 additions & 0 deletions
@@ -103,5 +103,11 @@ public class ModelOptions
 
         /// <inheritdoc />
         public bool VocabOnly { get; set; }
+
+        /// <inheritdoc />
+        public float DefragThreshold { get; set; }
+
+        /// <inheritdoc />
+        public bool DoPooling { get; set; }
     }
 }

LLama/Abstractions/IContextParams.cs

Lines changed: 10 additions & 0 deletions
@@ -98,4 +98,14 @@ public interface IContextParams
     /// Whether to disable offloading the KQV cache to the GPU
     /// </summary>
     bool NoKqvOffload { get; }
+
+    /// <summary>
+    /// defragment the KV cache if holes/size &gt; defrag_threshold, Set to &lt; 0 to disable (default)
+    /// </summary>
+    float DefragThreshold { get; }
+
+    /// <summary>
+    /// Whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
+    /// </summary>
+    bool DoPooling { get; }
 }
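Both new properties map straight onto llama.cpp context parameters (see `IContextParamsExtensions` below). A minimal sketch of configuring them through `ModelParams`, which implements this interface; the model path is a placeholder:

using LLama;
using LLama.Common;

// A minimal sketch, assuming the usual ModelParams/LLamaWeights flow;
// "model.gguf" is a placeholder path.
var parameters = new ModelParams("model.gguf")
{
    // Defragment the KV cache once holes/size exceeds 0.2;
    // any value < 0 (the default) leaves defragmentation disabled.
    DefragThreshold = 0.2f,

    // Sum embedding results per sequence id (ignored when the
    // model has no pooling layer).
    DoPooling = true,
};

using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);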

LLama/Abstractions/IModelParams.cs

Lines changed: 12 additions & 12 deletions
@@ -251,7 +251,7 @@ public MetadataOverride(string key, int value)
     {
         Key = key;
         _valueInt = value;
-        Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
+        Type = LLamaModelKvOverrideType.Int;
     }
 
     /// <summary>
@@ -263,7 +263,7 @@ public MetadataOverride(string key, float value)
     {
         Key = key;
         _valueFloat = value;
-        Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
+        Type = LLamaModelKvOverrideType.Float;
     }
 
     /// <summary>
@@ -275,20 +275,20 @@ public MetadataOverride(string key, bool value)
     {
         Key = key;
         _valueBool = value;
-        Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
+        Type = LLamaModelKvOverrideType.Bool;
     }
 
     internal void WriteValue(ref LLamaModelMetadataOverride dest)
     {
         switch (Type)
         {
-            case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT:
+            case LLamaModelKvOverrideType.Int:
                 dest.IntValue = _valueInt;
                 break;
-            case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT:
+            case LLamaModelKvOverrideType.Float:
                 dest.FloatValue = _valueFloat;
                 break;
-            case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
+            case LLamaModelKvOverrideType.Bool:
                 dest.BoolValue = _valueBool ? -1L : 0;
                 break;
             default:
@@ -300,13 +300,13 @@ internal void WriteValue(Utf8JsonWriter writer)
     {
         switch (Type)
         {
-            case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT:
+            case LLamaModelKvOverrideType.Int:
                 writer.WriteNumberValue(_valueInt);
                 break;
-            case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT:
+            case LLamaModelKvOverrideType.Float:
                 writer.WriteNumberValue(_valueFloat);
                 break;
-            case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
+            case LLamaModelKvOverrideType.Bool:
                 writer.WriteBooleanValue(_valueBool);
                 break;
             default:
@@ -328,9 +328,9 @@ public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConv
 
         return ((LLamaModelKvOverrideType)ktv.Type) switch
         {
-            LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT => new MetadataOverride(ktv.Key, ktv.Value.GetInt32()),
-            LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT => new MetadataOverride(ktv.Key, ktv.Value.GetSingle()),
-            LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL => new MetadataOverride(ktv.Key, ktv.Value.GetBoolean()),
+            LLamaModelKvOverrideType.Int => new MetadataOverride(ktv.Key, ktv.Value.GetInt32()),
+            LLamaModelKvOverrideType.Float => new MetadataOverride(ktv.Key, ktv.Value.GetSingle()),
+            LLamaModelKvOverrideType.Bool => new MetadataOverride(ktv.Key, ktv.Value.GetBoolean()),
             _ => throw new JsonException(),
         };
     }
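The rename swaps the native-style `LLAMA_KV_OVERRIDE_*` names for idiomatic .NET members (`Int`, `Float`, `Bool`). Callers are unaffected because the constructors choose the tag from the argument type. A brief sketch, assuming `ModelParams` exposes a `MetadataOverrides` collection for these entries (the keys are hypothetical):

using LLama.Abstractions;
using LLama.Common;

var parameters = new ModelParams("model.gguf");

// Each constructor stamps the override with the matching
// LLamaModelKvOverrideType value: Int, Float or Bool.
parameters.MetadataOverrides.Add(new MetadataOverride("example.int.key", 42));
parameters.MetadataOverrides.Add(new MetadataOverride("example.float.key", 0.5f));
parameters.MetadataOverrides.Add(new MetadataOverride("example.bool.key", true));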

LLama/Batched/Conversation.cs

Lines changed: 2 additions & 2 deletions
@@ -262,9 +262,9 @@ public void Remove(LLamaPos start, int count)
     /// <param name="start">Start position (inclusive)</param>
     /// <param name="end">End position (exclusive)</param>
     /// <param name="delta">Amount to add on to each token position</param>
-    public void Shift(LLamaPos start, LLamaPos end, int delta)
+    public void Add(LLamaPos start, LLamaPos end, int delta)
     {
-        _conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta);
+        _conversation.Executor.Context.NativeHandle.KvCacheSequenceAdd(_conversation.ConversationId, start, end, delta);
     }
     #endregion
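`Shift` becomes `Add`, tracking the same rename in llama.cpp (`llama_kv_cache_seq_shift` to `llama_kv_cache_seq_add`); behaviour is unchanged, with `delta` added to every token position in `[start, end)`. A hedged sketch, assuming `Conversation.Modify` hands its callback the current end position and this KV accessor (the pattern `ShiftLeft` below uses):

conversation.Modify((end, kv) =>
{
    kv.Remove(10, 10);     // drop cache entries at positions [10, 20)
    kv.Add(20, end, -10);  // renumber [20, end) back by ten positions
    return end.Value - 10; // report the conversation's new end position
});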

LLama/Batched/ConversationExtensions.cs

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ public static void ShiftLeft(this Conversation conversation, int count, int keep
         kv.Remove(keep, count);
 
         // Shift the C's
-        kv.Shift(keep + count, end, -count);
+        kv.Add(keep + count, end, -count);
 
         // Update total count
         return end.Value - count;
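`ShiftLeft` is the caller-facing wrapper for the two primitives above: it keeps the first `keep` tokens, removes the next `count`, and renumbers the tail so cache positions stay contiguous. A usage sketch with hypothetical values:

// Keep a 4-token preamble, discard the 32 tokens after it, and slide
// the rest of the conversation 32 positions to the left.
conversation.ShiftLeft(count: 32, keep: 4);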

LLama/Common/ModelParams.cs

Lines changed: 6 additions & 0 deletions
@@ -93,6 +93,12 @@ public record ModelParams
     /// <inheritdoc />
     public bool NoKqvOffload { get; set; }
 
+    /// <inheritdoc />
+    public float DefragThreshold { get; set; }
+
+    /// <inheritdoc />
+    public bool DoPooling { get; set; }
+
     /// <inheritdoc />
     public bool VocabOnly { get; set; }

LLama/Extensions/IContextParamsExtensions.cs

Lines changed: 4 additions & 1 deletion
@@ -34,14 +34,17 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
         result.yarn_beta_fast = @params.YarnBetaFast ?? 32f;
         result.yarn_beta_slow = @params.YarnBetaSlow ?? 1f;
         result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0;
-        result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.LLAMA_ROPE_SCALING_UNSPECIFIED;
+        result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.Unspecified;
+
+        result.defrag_threshold = @params.DefragThreshold;
 
         result.cb_eval = IntPtr.Zero;
         result.cb_eval_user_data = IntPtr.Zero;
 
         result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
         result.type_k = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
         result.offload_kqv = !@params.NoKqvOffload;
+        result.do_pooling = @params.DoPooling;
 
         result.n_threads = Threads(@params.Threads);
         result.n_threads_batch = Threads(@params.BatchThreads);

LLama/LLamaContext.cs

Lines changed: 33 additions & 0 deletions
@@ -56,6 +56,35 @@ public sealed class LLamaContext
     /// </summary>
     public Encoding Encoding { get; }
 
+    private uint _generationThreads;
+    private uint _batchThreads;
+
+    /// <summary>
+    /// Get or set the number of threads to use for generation
+    /// </summary>
+    public uint GenerationThreads
+    {
+        get => _generationThreads;
+        set
+        {
+            _generationThreads = value;
+            NativeHandle.SetThreads(_generationThreads, _batchThreads);
+        }
+    }
+
+    /// <summary>
+    /// Get or set the number of threads to use for batch processing
+    /// </summary>
+    public uint BatchThreads
+    {
+        get => _batchThreads;
+        set
+        {
+            _batchThreads = value;
+            NativeHandle.SetThreads(_generationThreads, _batchThreads);
+        }
+    }
+
     /// <summary>
     /// Create a new LLamaContext for the given LLamaWeights
     /// </summary>
@@ -75,6 +104,10 @@ public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger
 
         @params.ToLlamaContextParams(out var lparams);
         NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
+
+        // It's not possible to get these values from llama.cpp, store a copy of them here.
+        _generationThreads = lparams.n_threads;
+        _batchThreads = lparams.n_threads_batch;
     }
 
     /// <summary>
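The cached fields exist because llama.cpp offers a setter but no getter for thread counts, so the managed side keeps the only copy. A short usage sketch, assuming `weights` and `parameters` were loaded as in the earlier examples:

var context = new LLamaContext(weights, parameters);

// Both setters call NativeHandle.SetThreads, so a change takes
// effect on the live llama.cpp context immediately.
context.GenerationThreads = 4; // threads used for single-token decoding
context.BatchThreads = 8;      // threads used for prompt/batch processing

Console.WriteLine($"{context.GenerationThreads} / {context.BatchThreads}");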

LLama/LLamaQuantizer.cs

Lines changed: 12 additions & 2 deletions
@@ -59,7 +59,7 @@ public static bool Quantize(string srcFileName, string dstFilename, string ftype
     private static bool ValidateFtype(LLamaFtype ftype)
     {
         // Validation copies from here:
-        // https://github.com/ggerganov/llama.cpp/blob/d71ac90985854b0905e1abba778e407e17f9f887/llama.cpp#L9613
+        // https://github.com/ggerganov/llama.cpp/blob/3ab8b3a92ede46df88bc5a2dfca3777de4a2b2b6/llama.cpp#L10965
 
         switch (ftype)
         {
@@ -74,7 +74,7 @@ private static bool ValidateFtype(LLamaFtype ftype)
             case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q2_K_S:
             case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q2_K:
 
-            case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_XS:
+            case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_K_XS:
             case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_S:
             case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_M:
             case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_L:
@@ -89,8 +89,18 @@ private static bool ValidateFtype(LLamaFtype ftype)
 
             case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_XXS:
             case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_XS:
+            case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_S:
+            case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_M:
 
             case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_XXS:
+
+            case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ1_S:
+
+            case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_NL:
+            case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_XS:
+
+            case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_S:
+            case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_M:
                 return true;
 
             case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
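The whitelist now admits llama.cpp's newer importance-matrix (IQ) quantization formats. A hedged sketch of quantizing into one of them, assuming a `Quantize` overload that accepts the `LLamaFtype` enum directly (the file names are placeholders):

using LLama;
using LLama.Native;

// Quantize an F16 GGUF file down to IQ3_XXS, one of the newly
// accepted formats; Quantize returns false on failure.
bool ok = LLamaQuantizer.Quantize(
    "model-f16.gguf",
    "model-iq3_xxs.gguf",
    LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_XXS);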

LLama/LLamaStatelessExecutor.cs

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
         var n_discard = n_left / 2;
 
         NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
-        NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+        NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
 
         n_past -= n_discard;
     }
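This is the context-shift path: when the window fills, the executor keeps the first `TokensKeep` tokens, discards half of what remains, and renumbers the rest with the renamed `llama_kv_cache_seq_add`. A worked sketch of the arithmetic with hypothetical numbers, assuming `n_left` counts the tokens past the kept prefix:

int nPast = 100, tokensKeep = 10;
int nLeft = nPast - tokensKeep; // assumption: 90 shiftable tokens
int nDiscard = nLeft / 2;       // half of them (45) are discarded

// llama_kv_cache_seq_rm drops positions [11, 56); then
// llama_kv_cache_seq_add adds -45 to positions [56, 100),
// leaving the cache contiguous with room to keep generating.
nPast -= nDiscard;              // generation resumes at position 55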
