
Commit 348d565

Merge branch 'master' into phi-1
2 parents af9cd93 + 708e179

38 files changed: +1595 additions, -1067 deletions

.github/workflows/docker.yml

Lines changed: 32 additions & 2 deletions
@@ -52,14 +52,44 @@ jobs:
          username: ${{ github.repository_owner }}
          password: ${{ secrets.GITHUB_TOKEN }}

+      # https://github.com/jlumbroso/free-disk-space/tree/54081f138730dfa15788a46383842cd2f914a1be#example
+      - name: Free Disk Space (Ubuntu)
+        uses: jlumbroso/free-disk-space@main
+        with:
+          # this might remove tools that are actually needed,
+          # if set to "true" but frees about 6 GB
+          tool-cache: false
+
+          # all of these default to true, but feel free to set to
+          # "false" if necessary for your workflow
+          android: true
+          dotnet: true
+          haskell: true
+          large-packages: true
+          docker-images: true
+          swap-storage: true
+
+      - name: Determine tag name
+        id: tag
+        shell: bash
+        run: |
+          BUILD_NUMBER="$(git rev-list --count HEAD)"
+          SHORT_HASH="$(git rev-parse --short=7 HEAD)"
+          if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
+            echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT
+          else
+            SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
+            echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT
+          fi
+
      - name: Build and push Docker image (versioned)
        if: github.event_name == 'push'
        uses: docker/build-push-action@v4
        with:
          context: .
          push: true
          platforms: ${{ matrix.config.platforms }}
-          tags: "ghcr.io/ggerganov/llama.cpp:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}"
+          tags: "ghcr.io/${{ github.repository_owner }}/llama.cpp:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}"
          file: ${{ matrix.config.dockerfile }}

      - name: Build and push Docker image (tagged)
@@ -68,5 +98,5 @@ jobs:
          context: .
          push: ${{ github.event_name == 'push' }}
          platforms: ${{ matrix.config.platforms }}
-          tags: "ghcr.io/ggerganov/llama.cpp:${{ matrix.config.tag }}"
+          tags: "ghcr.io/${{ github.repository_owner }}/llama.cpp:${{ matrix.config.tag }},ghcr.io/${{ github.repository_owner }}/llama.cpp:${{ matrix.config.tag }}-${{ steps.tag.outputs.name }}"
          file: ${{ matrix.config.dockerfile }}
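
For illustration (the build number here is hypothetical): on `master` the new "Determine tag name" step would emit `name=b1742`, so the tagged image additionally gets a tag like `ghcr.io/<owner>/llama.cpp:<matrix-tag>-b1742`; on a branch named `feature/phi-1` at short hash `348d565` it would emit `name=feature-phi-1-b1742-348d565`, with `/` mapped to `-`.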

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ models-mnt
 /llama-bench
 /llava-cli
 /lookahead
+/lookup
 /main
 /metal
 /perplexity

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -91,6 +91,7 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for
 set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
     "llama: max. batch size for using peer access")
 option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
+option(LLAMA_HIP_UMA "llama: use HIP unified memory architecture" OFF)
 option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
 option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
 option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF)
@@ -377,6 +378,9 @@ if (LLAMA_HIPBLAS)
     if (${hipblas_FOUND} AND ${hip_FOUND})
         message(STATUS "HIP and hipBLAS found")
         add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+        if (LLAMA_HIP_UMA)
+            add_compile_definitions(GGML_HIP_UMA)
+        endif()
         add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
         if (BUILD_SHARED_LIBS)
             set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON)

Makefile

Lines changed: 28 additions & 6 deletions
@@ -2,7 +2,7 @@
 BUILD_TARGETS = \
 	main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
 	simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
-	speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead tests/test-c.o
+	speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup tests/test-c.o

 # Binaries only useful for tests
 TEST_TARGETS = \
@@ -65,7 +65,7 @@ test: $(TEST_TARGETS)
 			./$$test_target; \
 		fi; \
 		if [ $$? -ne 0 ]; then \
-			printf 'Test $$test_target FAILED!\n\n' $$test_target; \
+			printf 'Test %s FAILED!\n\n' $$test_target; \
 			failures=$$(( failures + 1 )); \
 		else \
 			printf 'Test %s passed.\n\n' $$test_target; \
@@ -282,8 +282,17 @@ endif
 ifneq ($(filter aarch64%,$(UNAME_M)),)
     # Apple M1, M2, etc.
    # Raspberry Pi 3, 4, Zero 2 (64-bit)
+    # Nvidia Jetson
    MK_CFLAGS += -mcpu=native
    MK_CXXFLAGS += -mcpu=native
+    JETSON_RELEASE_INFO = $(shell jetson_release)
+    ifdef JETSON_RELEASE_INFO
+        ifneq ($(filter TX2%,$(JETSON_RELEASE_INFO)),)
+            JETSON_EOL_MODULE_DETECT = 1
+            CC = aarch64-unknown-linux-gnu-gcc
+            cxx = aarch64-unknown-linux-gnu-g++
+        endif
+    endif
 endif

 ifneq ($(filter armv6%,$(UNAME_M)),)
@@ -357,10 +366,13 @@ ifdef LLAMA_BLIS
 endif # LLAMA_BLIS

 ifdef LLAMA_CUBLAS
-    MK_CPPFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
-    MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
+    MK_CPPFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include -I/usr/local/cuda/targets/aarch64-linux/include
+    MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib
    OBJS += ggml-cuda.o
-    MK_NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
+    MK_NVCCFLAGS = -use_fast_math
+ifndef JETSON_EOL_MODULE_DETECT
+    MK_NVCCFLAGS += --forward-unknown-to-host-compiler
+endif # JETSON_EOL_MODULE_DETECT

 ifdef LLAMA_DEBUG
    MK_NVCCFLAGS += -lineinfo
@@ -417,7 +429,11 @@ ifdef LLAMA_CUDA_CCBIN
    MK_NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
 endif
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
+ifdef JETSON_EOL_MODULE_DETECT
+	$(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
+else
 	$(NVCC) $(BASE_CXXFLAGS) $(NVCCFLAGS) -Wno-pedantic -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
+endif # JETSON_EOL_MODULE_DETECT
 endif # LLAMA_CUBLAS

 ifdef LLAMA_CLBLAST
@@ -452,6 +468,9 @@ ifdef LLAMA_HIPBLAS
    LLAMA_CUDA_MMV_Y ?= 1
    LLAMA_CUDA_KQUANTS_ITER ?= 2
    MK_CPPFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
+ifdef LLAMA_HIP_UMA
+    MK_CPPFLAGS += -DGGML_HIP_UMA
+endif # LLAMA_HIP_UMA
    MK_LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
    MK_LDFLAGS += -lhipblas -lamdhip64 -lrocblas
    HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))
@@ -606,7 +625,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
 server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2) -Wno-cast-qual

-gguf: examples/gguf/gguf.cpp ggml.o llama.o $(OBJS)
+gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

 train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS)
@@ -645,6 +664,9 @@ parallel: examples/parallel/parallel.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

+lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
 ifdef LLAMA_METAL
 metal: examples/metal/metal.cpp ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

README.md

Lines changed: 21 additions & 7 deletions
@@ -123,6 +123,7 @@ as the main playground for developing new features for the [ggml](https://github
 - Clojure: [phronmophobic/llama.clj](https://github.com/phronmophobic/llama.clj)
 - React Native: [mybigday/llama.rn](https://github.com/mybigday/llama.rn)
 - Java: [kherud/java-llama.cpp](https://github.com/kherud/java-llama.cpp)
+- Zig: [deins/llama.cpp.zig](https://github.com/Deins/llama.cpp.zig)

 **UI:**

@@ -395,6 +396,9 @@ Building the program with BLAS support may lead to some performance improvements
 - #### cuBLAS

   This provides BLAS acceleration using the CUDA cores of your Nvidia GPU. Make sure to have the CUDA toolkit installed. You can download it from your Linux distro's package manager (e.g. `apt install nvidia-cuda-toolkit`) or from here: [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads).
+
+  For Jetson users: a Jetson Orin can use the [official support](https://www.jetson-ai-lab.com/tutorial_text-generation.html); older models (Nano/TX2) need some additional steps before compiling.
+
   - Using `make`:
     ```bash
     make LLAMA_CUBLAS=1
@@ -432,14 +436,21 @@ Building the program with BLAS support may lead to some performance improvements
     ```bash
     make LLAMA_HIPBLAS=1
     ```
-  - Using `CMake` for Linux:
+  - Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU):
     ```bash
-    mkdir build
-    cd build
-    CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DLLAMA_HIPBLAS=ON
-    cmake --build .
+    CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ \
+    cmake -H. -Bbuild -DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
+    && cmake --build build -- -j 16
     ```
+    On Linux it is also possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting `-DLLAMA_HIP_UMA=ON`.
+    However, this hurts performance for non-integrated GPUs (but enables working with integrated GPUs).
+
+  - Using `make` (example for target gfx1030, build with 16 CPU threads):
+    ```bash
+    make -j16 LLAMA_HIPBLAS=1 LLAMA_HIP_UMA=1 AMDGPU_TARGETS=gfx1030
+    ```
+
+  - Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU):
-  - Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS):
     ```bash
     set PATH=%HIP_PATH%\bin;%PATH%
     mkdir build
@@ -448,10 +459,11 @@ Building the program with BLAS support may lead to some performance improvements
     cmake --build .
     ```
   Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors)
+  Find your GPU version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`.


   The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used.
-  If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 or 11.0.0 on RDNA3.
+  If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3.
   The following compilation options are also available to tweak performance (yes, they refer to CUDA, not HIP, because it uses the same code as the cuBLAS version above):

 | Option | Legal values | Default | Description |
@@ -982,6 +994,8 @@ docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /
 - There are no strict rules for the code style, but try to follow the patterns in the code (indentation, spaces, etc.). Vertical alignment makes things more readable and easier to batch edit
 - Clean-up any trailing whitespaces, use 4 spaces for indentation, brackets on the same line, `void * ptr`, `int & a`
 - See [good first issues](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) for tasks suitable for first contributions
+- Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices
+- Matrix multiplication is unconventional: [`z = ggml_mul_mat(ctx, x, y)`](https://github.com/ggerganov/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means `zT = x @ yT`

 ### Docs
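
The two new contributing notes about tensor layout are easier to see in code. Below is a shape-only sketch (not part of the commit; the sizes are illustrative) that only creates the `ggml_mul_mat` node to show the resulting dimensions, without evaluating a compute graph:

```cpp
#include "ggml.h"

int main(void) {
    // small scratch context; the size is arbitrary for this illustration
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // dimension 0 is "columns", dimension 1 is "rows" (row-major storage):
    // x has ne = {4, 3} -> a 3x4 matrix in conventional notation
    // y has ne = {4, 2} -> a 2x4 matrix
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
    struct ggml_tensor * y = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2);

    // both operands must agree on dimension 0 (the shared inner dimension);
    // the result has ne = {x->ne[1], y->ne[1]} = {3, 2}, i.e. a 2x3 matrix,
    // which is y @ xT -- equivalently zT = x @ yT, as the note above states
    struct ggml_tensor * z = ggml_mul_mat(ctx, x, y);
    (void) z;

    ggml_free(ctx);
    return 0;
}
```

In other words, both operands share dimension 0, and the result takes its dimension 0 from `x` and its dimension 1 from `y`.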

common/common.cpp

Lines changed: 1 addition & 1 deletion
@@ -920,7 +920,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -m FNAME, --model FNAME\n");
     printf("                        model path (default: %s)\n", params.model.c_str());
     printf("  -md FNAME, --model-draft FNAME\n");
-    printf("                        draft model for speculative decoding (default: %s)\n", params.model.c_str());
+    printf("                        draft model for speculative decoding\n");
     printf("  -ld LOGDIR, --logdir LOGDIR\n");
     printf("                        path under which to save YAML logs (no logging if unset)\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");

common/common.h

Lines changed: 2 additions & 1 deletion
@@ -51,7 +51,7 @@ struct gpt_params {
     int32_t n_ctx = 512; // context size
     int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
-    int32_t n_draft = 16; // number of tokens to draft during speculative decoding
+    int32_t n_draft = 8; // number of tokens to draft during speculative decoding
     int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel = 1; // number of parallel sequences to decode
     int32_t n_sequences = 1; // number of sequences to decode
@@ -240,3 +240,4 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);

 // Dump the KV cache view showing individual sequences in each cell (long output).
 void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
+

common/sampling.cpp

Lines changed: 50 additions & 6 deletions
@@ -149,11 +149,12 @@ static void sampler_queue(
     }
 }

-llama_token llama_sampling_sample(
+static llama_token llama_sampling_sample_impl(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
+                  const int idx,
+                  bool is_resampling) {  // Add a parameter to indicate if we are resampling
     const llama_sampling_params & params = ctx_sampling->params;

     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -173,8 +174,17 @@ llama_token llama_sampling_sample(

     llama_token id = 0;

+    // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);

+    // Declare original_logits at the beginning of the function scope
+    std::vector<float> original_logits;
+
+    if (!is_resampling) {
+        // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
+        original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
+    }
+
     // apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
         logits[it->first] += it->second;
@@ -193,12 +203,14 @@ llama_token llama_sampling_sample(
     }

     // apply penalties
-    if (!prev.empty()) {
+    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+    if (penalty_tokens_used_size) {
         const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];

         llama_sample_repetition_penalties(ctx_main, &cur_p,
-                prev.data() + prev.size() - penalty_last_n,
-                penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
+                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);

         if (!penalize_nl) {
             for (size_t idx = 0; idx < cur_p.size; idx++) {
@@ -210,7 +222,8 @@ llama_token llama_sampling_sample(
         }
     }

-    if (ctx_sampling->grammar != NULL) {
+    // If we are in the resampling phase, apply grammar checks before sampling logic
+    if (is_resampling && ctx_sampling->grammar != NULL) {
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }

@@ -252,9 +265,40 @@ llama_token llama_sampling_sample(
         }
     }

+    if (ctx_sampling->grammar != NULL && !is_resampling) {
+        // Create an array with a single token data element for the sampled id
+        llama_token_data single_token_data = {id, logits[id], 0.0f};
+        llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+
+        // Apply grammar constraints to the single token
+        llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
+
+        // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
+        bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+
+        // If the token is not valid according to the grammar, perform resampling
+        if (!is_valid) {
+            LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
+
+            // Restore logits from the copy
+            std::copy(original_logits.begin(), original_logits.end(), logits);
+
+            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true);  // Pass true for is_resampling
+        }
+    }
+
     return id;
 }

+llama_token llama_sampling_sample(
+                  struct llama_sampling_context * ctx_sampling,
+                  struct llama_context * ctx_main,
+                  struct llama_context * ctx_cfg,
+                  const int idx) {
+    // Call the implementation function with is_resampling set to false by default
+    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
+}
+
 void llama_sampling_accept(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,

common/sampling.h

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,9 @@ typedef struct llama_sampling_params {
     float cfg_scale = 1.f; // how strong is guidance

     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+
+    std::vector<llama_token> penalty_prompt_tokens;
+    bool use_penalty_prompt_tokens = false;
 } llama_sampling_params;

 // general sampler context
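
The new `penalty_prompt_tokens` / `use_penalty_prompt_tokens` fields let a caller apply the repetition penalties against a fixed token list (for example, the original prompt) instead of the rolling window of previously sampled tokens. A minimal caller-side sketch, assuming the existing `llama_sampling_init`, `llama_sampling_accept`, and `llama_sampling_free` helpers from `common/sampling.h` keep their current signatures; `ctx_main`, `idx`, and `prompt_tokens` are hypothetical and would come from the surrounding program:

```cpp
#include <vector>

#include "sampling.h" // common/sampling.h

// Sketch: sample one token while penalizing repetition against the prompt
// itself rather than against the tokens generated so far.
static llama_token sample_with_prompt_penalty(
        llama_context * ctx_main,
        const std::vector<llama_token> & prompt_tokens,
        int idx) {
    llama_sampling_params sparams;
    sparams.penalty_prompt_tokens     = prompt_tokens; // fixed list to penalize against
    sparams.use_penalty_prompt_tokens = true;          // opt in to the new behaviour

    llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);

    // in a real generation loop ctx_sampling would be created once, reused,
    // and llama_sampling_accept called for every accepted token
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx_main, /*ctx_cfg=*/nullptr, idx);
    llama_sampling_accept(ctx_sampling, ctx_main, id, /*apply_grammar=*/true);

    llama_sampling_free(ctx_sampling);
    return id;
}
```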

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ else()
     add_subdirectory(simple)
     add_subdirectory(speculative)
     add_subdirectory(lookahead)
+    add_subdirectory(lookup)
     add_subdirectory(train-text-from-scratch)
     if (LLAMA_METAL)
         add_subdirectory(metal)
