Commit 22a648f

Merge branch 'master' into pr/7359

2 parents 9971c38 + f8c4c07
88 files changed: +1895 −1444 lines changed

Note: this is a large commit, so some of the changed files are hidden by default and are not shown below.

.devops/nix/package.nix

Lines changed: 7 additions & 10 deletions

@@ -17,19 +17,18 @@
   rocmPackages,
   vulkan-headers,
   vulkan-loader,
-  clblast,
+  curl,
   useBlas ? builtins.all (x: !x) [
     useCuda
     useMetalKit
-    useOpenCL
     useRocm
     useVulkan
   ] && blas.meta.available,
   useCuda ? config.cudaSupport,
-  useMetalKit ? stdenv.isAarch64 && stdenv.isDarwin && !useOpenCL,
+  useMetalKit ? stdenv.isAarch64 && stdenv.isDarwin,
   useMpi ? false, # Increases the runtime closure size by ~700M
-  useOpenCL ? false,
   useRocm ? config.rocmSupport,
+  enableCurl ? true,
   useVulkan ? false,
   llamaVersion ? "0.0.0", # Arbitrary version, substituted by the flake

@@ -56,7 +55,6 @@ let
     ++ lib.optionals useCuda [ "CUDA" ]
     ++ lib.optionals useMetalKit [ "MetalKit" ]
     ++ lib.optionals useMpi [ "MPI" ]
-    ++ lib.optionals useOpenCL [ "OpenCL" ]
     ++ lib.optionals useRocm [ "ROCm" ]
     ++ lib.optionals useVulkan [ "Vulkan" ];

@@ -198,19 +196,19 @@ effectiveStdenv.mkDerivation (
        optionals effectiveStdenv.isDarwin darwinBuildInputs
        ++ optionals useCuda cudaBuildInputs
        ++ optionals useMpi [ mpi ]
-       ++ optionals useOpenCL [ clblast ]
        ++ optionals useRocm rocmBuildInputs
        ++ optionals useBlas [ blas ]
-       ++ optionals useVulkan vulkanBuildInputs;
+       ++ optionals useVulkan vulkanBuildInputs
+       ++ optionals enableCurl [ curl ];

      cmakeFlags =
        [
          (cmakeBool "LLAMA_BUILD_SERVER" true)
          (cmakeBool "BUILD_SHARED_LIBS" (!enableStatic))
          (cmakeBool "CMAKE_SKIP_BUILD_RPATH" true)
+         (cmakeBool "LLAMA_CURL" enableCurl)
          (cmakeBool "GGML_NATIVE" false)
          (cmakeBool "GGML_BLAS" useBlas)
-         (cmakeBool "GGML_CLBLAST" useOpenCL)
          (cmakeBool "GGML_CUDA" useCuda)
          (cmakeBool "GGML_HIPBLAS" useRocm)
          (cmakeBool "GGML_METAL" useMetalKit)

@@ -254,7 +252,6 @@ effectiveStdenv.mkDerivation (
         useCuda
         useMetalKit
         useMpi
-        useOpenCL
         useRocm
         useVulkan
         ;

@@ -281,7 +278,7 @@ effectiveStdenv.mkDerivation (
      # Configurations we don't want even the CI to evaluate. Results in the
      # "unsupported platform" messages. This is mostly a no-op, because
      # cudaPackages would've refused to evaluate anyway.
-     badPlatforms = optionals (useCuda || useOpenCL) lib.platforms.darwin;
+     badPlatforms = optionals useCuda lib.platforms.darwin;

      # Configurations that are known to result in build failures. Can be
      # overridden by importing Nixpkgs with `allowBroken = true`.

.github/ISSUE_TEMPLATE/config.yml

Lines changed: 0 additions & 2 deletions

@@ -9,5 +9,3 @@ contact_links:
   - name: Want to contribute?
     url: https://github.com/ggerganov/llama.cpp/wiki/contribute
     about: Head to the contribution guide page of the wiki for areas you can help with
-
-

CMakeLists.txt

Lines changed: 4 additions & 0 deletions

@@ -42,6 +42,10 @@ endif()

 option(BUILD_SHARED_LIBS "build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})

+if (WIN32)
+    add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
+endif()
+
 #
 # option list
 #

CMakePresets.json

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@
       "cacheVariables": {
         "CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
         "CMAKE_CXX_COMPILER": "icx",
+        "CMAKE_C_COMPILER": "cl",
         "GGML_SYCL": "ON",
         "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.."
       }

Makefile

Lines changed: 6 additions & 0 deletions

@@ -62,6 +62,11 @@ TEST_TARGETS = \
 	tests/test-tokenizer-1-bpe \
 	tests/test-tokenizer-1-spm

+# Legacy build targets that were renamed in #7809, but should still be removed when the project is cleaned
+LEGACY_TARGETS = main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
+	simple batched batched-bench save-load-state server gguf gguf-split eval-callback llama-bench libllava.a llava-cli baby-llama \
+	retrieval speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm
+
 # Deprecation aliases
 ifdef LLAMA_CUBLAS
 $(error LLAMA_CUBLAS is removed. Use GGML_CUDA instead.)

@@ -1086,6 +1091,7 @@ clean:
 	rm -vrf ggml/src/ggml-cuda/template-instances/*.o
 	rm -rvf $(BUILD_TARGETS)
 	rm -rvf $(TEST_TARGETS)
+	rm -rvf $(LEGACY_TARGETS)
 	find examples pocs -type f -name "*.o" -delete

 #

README.md

Lines changed: 6 additions & 0 deletions

@@ -108,6 +108,7 @@ Typically finetunes of the base models below are supported as well.
 - [X] [Falcon](https://huggingface.co/models?search=tiiuae/falcon)
 - [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)
 - [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)
+- [X] [BERT](https://github.com/ggerganov/llama.cpp/pull/5423)
 - [X] [Koala](https://bair.berkeley.edu/blog/2023/04/03/koala/)
 - [X] [Baichuan 1 & 2](https://huggingface.co/models?search=baichuan-inc/Baichuan) + [derivations](https://huggingface.co/hiyouga/baichuan-7b-sft)
 - [X] [Aquila 1 & 2](https://huggingface.co/models?search=BAAI/Aquila)

@@ -217,6 +218,11 @@ Unless otherwise noted these projects are open-source with permissive licensing:
 **Tools:**

 - [akx/ggify](https://github.com/akx/ggify) – download PyTorch models from HuggingFace Hub and convert them to GGML
+- [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
+
+**Infrastructure:**
+
+- [Paddler](https://github.com/distantmagic/paddler) - Stateful load balancer custom-tailored for llama.cpp

 ---

common/common.cpp

Lines changed: 1 addition & 1 deletion

@@ -757,7 +757,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.cache_type_v = argv[++i];
         return true;
     }
-    if (arg == "--multiline-input") {
+    if (arg == "-mli" || arg == "--multiline-input") {
         params.multiline_input = true;
         return true;
     }

common/common.h

Lines changed: 0 additions & 1 deletion

@@ -459,4 +459,3 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha
 void yaml_dump_non_result_info(
     FILE * stream, const gpt_params & params, const llama_context * lctx,
     const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
-

convert-hf-to-gguf-update.py

Lines changed: 5 additions & 1 deletion

@@ -86,6 +86,9 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
     {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
     {"name": "viking", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B
+    {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
+    {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
+    {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
 ]


@@ -272,7 +275,8 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         "3333333",
         "33333333",
         "333333333",
-        # "Cửa Việt", # llama-bpe fails on this
+        "Cửa Việt", # llama-bpe fails on this
+        " discards",
         chktxt,
     ]

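The new tokenizer entries registered above (gemma, gemma-2, jais) feed the fingerprinting step that convert-hf-to-gguf.py later matches against in get_vocab_base_pre(): the tokenizer encodes a fixed check string and the resulting token-id list is hashed. Below is a minimal sketch of that fingerprinting, assuming the Hugging Face transformers AutoTokenizer API and using the newly registered JAIS repo as the example; the real check string (chktxt) is the long multi-script constant defined in the converter and is elided here.

    from hashlib import sha256
    from transformers import AutoTokenizer

    chktxt = "..."  # placeholder for the converter's multi-script check string

    # e.g. the JAIS tokenizer registered in the list above; trust_remote_code
    # may be needed for repos that ship custom code
    tokenizer = AutoTokenizer.from_pretrained("core42/jais-13b", trust_remote_code=True)

    # the fingerprint is a sha256 over the stringified token-id list
    chktok = tokenizer.encode(chktxt)
    chkhsh = sha256(str(chktok).encode()).hexdigest()
    print(chkhsh)  # compared against the hashes listed in get_vocab_base_pre()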

convert-hf-to-gguf.py

Lines changed: 135 additions & 14 deletions

@@ -490,6 +490,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
             # ref: https://huggingface.co/LumiOpen/Viking-7B
             res = "viking"
+        if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901":
+            # ref: https://huggingface.co/core42/jais-13b
+            res = "jais"

         if res is None:
             logger.warning("\n")

@@ -576,7 +579,19 @@ def _set_vocab_qwen(self):
         special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
         special_vocab.add_to_gguf(self.gguf_writer)

-    def _set_vocab_sentencepiece(self):
+    def _set_vocab_sentencepiece(self, add_to_gguf=True):
+        tokens, scores, toktypes = self._create_vocab_sentencepiece()
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def _create_vocab_sentencepiece(self):
         from sentencepiece import SentencePieceProcessor

         tokenizer_path = self.dir_model / 'tokenizer.model'

@@ -638,14 +653,7 @@ def _set_vocab_sentencepiece(self):
                 scores.append(-1000.0)
                 toktypes.append(SentencePieceTokenTypes.UNUSED)

-        self.gguf_writer.add_tokenizer_model("llama")
-        self.gguf_writer.add_tokenizer_pre("default")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_scores(scores)
-        self.gguf_writer.add_token_types(toktypes)
-
-        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
-        special_vocab.add_to_gguf(self.gguf_writer)
+        return tokens, scores, toktypes

     def _set_vocab_llama_hf(self):
         vocab = gguf.LlamaHfVocab(self.dir_model)

@@ -1979,7 +1987,7 @@ def set_gguf_parameters(self):
             if len(rope_scaling_type) == 0:
                 raise KeyError('Missing the required key rope_scaling.type')

-            if rope_scaling_type == 'su':
+            if rope_scaling_type == 'su' or rope_scaling_type == 'longrope':
                 attn_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos_embds)) if scale > 1.0 else 1.0
             elif rope_scaling_type == 'yarn':
                 attn_factor = 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0

@@ -2353,6 +2361,8 @@ def set_vocab(self):
         special_vocab._set_special_token("eot", 107)
         special_vocab.add_to_gguf(self.gguf_writer)

+        self.gguf_writer.add_add_space_prefix(False)
+
     def set_gguf_parameters(self):
         hparams = self.hparams
         block_count = hparams["num_hidden_layers"]

@@ -2390,7 +2400,20 @@ class Gemma2Model(Model):
     model_arch = gguf.MODEL_ARCH.GEMMA2

     def set_vocab(self):
-        self._set_vocab_llama_hf()
+        tokens, scores, toktypes = self._create_vocab_sentencepiece()
+        # hack: This is required so that we can properly use start/end-of-turn for chat template
+        for i in range(108):
+            # including <unusedX>, <start_of_turn>, <end_of_turn>
+            toktypes[i] = SentencePieceTokenTypes.CONTROL
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
         self.gguf_writer.add_add_space_prefix(False)

     def set_gguf_parameters(self):

@@ -2414,6 +2437,12 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_final_logit_softcapping(
             self.hparams["final_logit_softcapping"]
         )
+        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
+
+        # sanity check
+        attn_scalar = self.hparams["query_pre_attn_scalar"]
+        if attn_scalar != hparams["hidden_size"] / hparams["num_attention_heads"]:
+            raise ValueError("query_pre_attn_scalar must be equal to n_embd / n_head")

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid # unusem

@@ -3031,6 +3060,96 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]


+@Model.register("JAISLMHeadModel")
+class JaisModel(Model):
+    model_arch = gguf.MODEL_ARCH.JAIS
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # SwigLU activation
+        assert self.hparams["activation_function"] == "swiglu"
+        # ALiBi position embedding
+        assert self.hparams["position_embedding_type"] == "alibi"
+
+        # Embeddings scale
+        self.embeddings_scale = 1.0
+        # note: For some JAIS flavors, output is tied to (same as) wte in original model
+        self.output_is_wte = False
+        if 'mup_embeddings_scale' in self.hparams:
+            self.output_is_wte = True # Hack (?)
+            self.embeddings_scale = self.hparams['mup_embeddings_scale']
+        elif 'embeddings_scale' in self.hparams:
+            self.embeddings_scale = self.hparams['embeddings_scale']
+        else:
+            assert False
+
+        self.width_scale = 1.0
+        if 'mup_output_alpha' in self.hparams:
+            assert 'mup_width_scale' in self.hparams
+            self.width_scale = self.hparams['mup_output_alpha'] * self.hparams['mup_width_scale']
+        elif 'width_scale' in self.hparams:
+            self.width_scale = self.hparams['width_scale']
+        else:
+            assert False
+
+        self.max_alibi_bias = 8.0
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"])
+        self.gguf_writer.add_head_count(self.hparams["n_head"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid # unused
+
+        tensors: list[tuple[str, Tensor]] = []
+
+        # we don't need these
+        if name.endswith((".attn.bias")):
+            return tensors
+
+        if name.endswith(("relative_pe.slopes")):
+            # Calculate max ALiBi bias (this is the inverse of the ALiBi calculation)
+            # Some other models has max_alibi_bias spelled out explicitly in the hyperparams,
+            # but Jais's PyTorch model simply precalculates the slope values and places them
+            # in relative_pes.slopes
+            n_head_closest_log2 = 2 ** math.floor(math.log2(self.hparams["n_head"]))
+            first_val = float(data_torch._data[0])
+            self.max_alibi_bias = -round(math.log2(first_val) * n_head_closest_log2)
+
+            return tensors
+
+        if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_fc2.weight")):
+            data_torch = data_torch.transpose(1, 0)
+
+        new_name = self.map_tensor_name(name)
+
+        if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
+            tensors.append((new_name, data_torch * self.embeddings_scale))
+            if self.output_is_wte:
+                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch * self.width_scale))
+        elif new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
+            assert not self.output_is_wte
+            tensors.append((new_name, data_torch * self.width_scale))
+        else:
+            tensors.append((new_name, data_torch))
+
+        return tensors
+
+    def write_tensors(self):
+        super().write_tensors()
+        self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
+
+
 ###### CONVERSION LOGIC ######


@@ -3186,7 +3305,8 @@ def main() -> None:
         "auto": gguf.LlamaFileType.GUESSED,
     }

-    if args.use_temp_file and (args.split_max_tensors > 0 or args.split_max_size != "0"):
+    is_split = args.split_max_tensors > 0 or args.split_max_size != "0"
+    if args.use_temp_file and is_split:
         logger.error("Error: Cannot use temp file when splitting")
         sys.exit(1)

@@ -3223,11 +3343,12 @@ def main() -> None:
     if args.vocab_only:
         logger.info("Exporting model vocab...")
         model_instance.write_vocab()
-        logger.info("Model vocab successfully exported.")
+        logger.info(f"Model vocab successfully exported to {model_instance.fname_out}")
     else:
         logger.info("Exporting model...")
         model_instance.write()
-        logger.info("Model successfully exported.")
+        out_path = f"{model_instance.fname_out.parent}{os.sep}" if is_split else model_instance.fname_out
+        logger.info(f"Model successfully exported to {out_path}")


 if __name__ == '__main__':
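One detail of the JaisModel hunk above that is easy to miss: max_alibi_bias is not read from the hyperparameters but recovered from the first precomputed slope in relative_pe.slopes. The sketch below shows why that inverse works, assuming the standard ALiBi slope formula in which the first head's slope is 2^(-max_bias / n) with n the head count rounded down to a power of two; the helper names are illustrative, not part of the converter.

    import math

    def alibi_first_slope(n_head: int, max_alibi_bias: float = 8.0) -> float:
        # First ALiBi slope, assuming slopes of the form 2 ** (-max_bias * i / n)
        # where n is the head count rounded down to the nearest power of two.
        n = 2 ** math.floor(math.log2(n_head))
        return 2.0 ** (-max_alibi_bias / n)

    def recover_max_alibi_bias(first_slope: float, n_head: int) -> int:
        # Mirrors the inverse computed in JaisModel.modify_tensors above.
        n = 2 ** math.floor(math.log2(n_head))
        return -round(math.log2(first_slope) * n)

    # Round trip for a 32-head model with the default bias of 8.0:
    slope = alibi_first_slope(32)             # 2 ** (-8 / 32) ≈ 0.8409
    print(recover_max_alibi_bias(slope, 32))  # 8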

examples/embedding/README.md

Lines changed: 0 additions & 1 deletion

@@ -58,4 +58,3 @@ The above command will output space-separated float values.
 ```powershell
 embedding.exe -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null
 ```
-
