Skip to content

Commit b0583f7

Browse files
committed
Merge branch 'master' into phi-1
2 parents 348d565 + b9f4795 commit b0583f7

File tree

13 files changed

+636
-399
lines changed

13 files changed

+636
-399
lines changed

.github/ISSUE_TEMPLATE/bug.md

Lines changed: 1 addition & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -6,179 +6,4 @@ assignees: ''
66

77
---
88

9-
# Prerequisites
10-
11-
Please answer the following questions for yourself before submitting an issue.
12-
13-
- [ ] I am running the latest code. Development is very rapid so there are no tagged versions as of now.
14-
- [ ] I carefully followed the [README.md](https://github.com/ggerganov/llama.cpp/blob/master/README.md).
15-
- [ ] I [searched using keywords relevant to my issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/filtering-and-searching-issues-and-pull-requests) to make sure that I am creating a new issue that is not already open (or closed).
16-
- [ ] I reviewed the [Discussions](https://github.com/ggerganov/llama.cpp/discussions), and have a new bug or useful enhancement to share.
17-
18-
# Expected Behavior
19-
20-
Please provide a detailed written description of what you were trying to do, and what you expected `llama.cpp` to do.
21-
22-
# Current Behavior
23-
24-
Please provide a detailed written description of what `llama.cpp` did, instead.
25-
26-
# Environment and Context
27-
28-
Please provide detailed information about your computer setup. This is important in case the issue is not reproducible except for under certain specific conditions.
29-
30-
* Physical (or virtual) hardware you are using, e.g. for Linux:
31-
32-
`$ lscpu`
33-
34-
* Operating System, e.g. for Linux:
35-
36-
`$ uname -a`
37-
38-
* SDK version, e.g. for Linux:
39-
40-
```
41-
$ python3 --version
42-
$ make --version
43-
$ g++ --version
44-
```
45-
46-
# Failure Information (for bugs)
47-
48-
Please help provide information about the failure / bug.
49-
50-
# Steps to Reproduce
51-
52-
Please provide detailed steps for reproducing the issue. We are not sitting in front of your screen, so the more detail the better.
53-
54-
1. step 1
55-
2. step 2
56-
3. step 3
57-
4. etc.
58-
59-
# Failure Logs
60-
61-
Please include any relevant log snippets or files. If it works under one configuration but not under another, please provide logs for both configurations and their corresponding outputs so it is easy to see where behavior changes.
62-
63-
Also, please try to **avoid using screenshots** if at all possible. Instead, copy/paste the console output and use [Github's markdown](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax) to cleanly format your logs for easy readability.
64-
65-
Example environment info:
66-
```
67-
llama.cpp$ git log | head -1
68-
commit 2af23d30434a677c6416812eea52ccc0af65119c
69-
70-
llama.cpp$ lscpu | egrep "AMD|Flags"
71-
Vendor ID: AuthenticAMD
72-
Model name: AMD Ryzen Threadripper 1950X 16-Core Processor
73-
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid amd_dcm aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb hw_pstate ssbd ibpb vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt sha_ni xsaveopt xsavec xgetbv1 xsaves clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif overflow_recov succor smca sme sev
74-
Virtualization: AMD-V
75-
76-
llama.cpp$ python3 --version
77-
Python 3.10.9
78-
79-
llama.cpp$ pip list | egrep "torch|numpy|sentencepiece"
80-
numpy 1.24.2
81-
numpydoc 1.5.0
82-
sentencepiece 0.1.97
83-
torch 1.13.1
84-
torchvision 0.14.1
85-
86-
llama.cpp$ make --version | head -1
87-
GNU Make 4.3
88-
89-
$ md5sum ./models/65B/ggml-model-q4_0.bin
90-
dbdd682cce80e2d6e93cefc7449df487 ./models/65B/ggml-model-q4_0.bin
91-
```
92-
93-
Example run with the Linux command [perf](https://www.brendangregg.com/perf.html)
94-
```
95-
llama.cpp$ perf stat ./main -m ./models/65B/ggml-model-q4_0.bin -t 16 -n 1024 -p "Please close your issue when it has been answered."
96-
main: seed = 1679149377
97-
llama_model_load: loading model from './models/65B/ggml-model-q4_0.bin' - please wait ...
98-
llama_model_load: n_vocab = 32000
99-
llama_model_load: n_ctx = 512
100-
llama_model_load: n_embd = 8192
101-
llama_model_load: n_mult = 256
102-
llama_model_load: n_head = 64
103-
llama_model_load: n_layer = 80
104-
llama_model_load: n_rot = 128
105-
llama_model_load: f16 = 2
106-
llama_model_load: n_ff = 22016
107-
llama_model_load: n_parts = 8
108-
llama_model_load: ggml ctx size = 41477.73 MB
109-
llama_model_load: memory_size = 2560.00 MB, n_mem = 40960
110-
llama_model_load: loading model part 1/8 from './models/65B/ggml-model-q4_0.bin'
111-
llama_model_load: .......................................................................................... done
112-
llama_model_load: model size = 4869.09 MB / num tensors = 723
113-
llama_model_load: loading model part 2/8 from './models/65B/ggml-model-q4_0.bin.1'
114-
llama_model_load: .......................................................................................... done
115-
llama_model_load: model size = 4869.09 MB / num tensors = 723
116-
llama_model_load: loading model part 3/8 from './models/65B/ggml-model-q4_0.bin.2'
117-
llama_model_load: .......................................................................................... done
118-
llama_model_load: model size = 4869.09 MB / num tensors = 723
119-
llama_model_load: loading model part 4/8 from './models/65B/ggml-model-q4_0.bin.3'
120-
llama_model_load: .......................................................................................... done
121-
llama_model_load: model size = 4869.09 MB / num tensors = 723
122-
llama_model_load: loading model part 5/8 from './models/65B/ggml-model-q4_0.bin.4'
123-
llama_model_load: .......................................................................................... done
124-
llama_model_load: model size = 4869.09 MB / num tensors = 723
125-
llama_model_load: loading model part 6/8 from './models/65B/ggml-model-q4_0.bin.5'
126-
llama_model_load: .......................................................................................... done
127-
llama_model_load: model size = 4869.09 MB / num tensors = 723
128-
llama_model_load: loading model part 7/8 from './models/65B/ggml-model-q4_0.bin.6'
129-
llama_model_load: .......................................................................................... done
130-
llama_model_load: model size = 4869.09 MB / num tensors = 723
131-
llama_model_load: loading model part 8/8 from './models/65B/ggml-model-q4_0.bin.7'
132-
llama_model_load: .......................................................................................... done
133-
llama_model_load: model size = 4869.09 MB / num tensors = 723
134-
135-
system_info: n_threads = 16 / 32 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | VSX = 0 |
136-
137-
main: prompt: 'Please close your issue when it has been answered.'
138-
main: number of tokens in prompt = 11
139-
1 -> ''
140-
12148 -> 'Please'
141-
3802 -> ' close'
142-
596 -> ' your'
143-
2228 -> ' issue'
144-
746 -> ' when'
145-
372 -> ' it'
146-
756 -> ' has'
147-
1063 -> ' been'
148-
7699 -> ' answered'
149-
29889 -> '.'
150-
151-
sampling parameters: temp = 0.800000, top_k = 40, top_p = 0.950000, repeat_last_n = 64, repeat_penalty = 1.300000
152-
153-
154-
Please close your issue when it has been answered.
155-
@duncan-donut: I'm trying to figure out what kind of "support" you need for this script and why, exactly? Is there a question about how the code works that hasn't already been addressed in one or more comments below this ticket, or are we talking something else entirely like some sorta bugfixing job because your server setup is different from mine??
156-
I can understand if your site needs to be running smoothly and you need help with a fix of sorts but there should really be nothing wrong here that the code itself could not handle. And given that I'm getting reports about how it works perfectly well on some other servers, what exactly are we talking? A detailed report will do wonders in helping us get this resolved for ya quickly so please take your time and describe the issue(s) you see as clearly & concisely as possible!!
157-
@duncan-donut: I'm not sure if you have access to cPanel but you could try these instructions. It is worth a shot! Let me know how it goes (or what error message, exactly!) when/if ya give that code a go? [end of text]
158-
159-
160-
main: mem per token = 71159620 bytes
161-
main: load time = 19309.95 ms
162-
main: sample time = 168.62 ms
163-
main: predict time = 223895.61 ms / 888.47 ms per token
164-
main: total time = 246406.42 ms
165-
166-
Performance counter stats for './main -m ./models/65B/ggml-model-q4_0.bin -t 16 -n 1024 -p Please close your issue when it has been answered.':
167-
168-
3636882.89 msec task-clock # 14.677 CPUs utilized
169-
13509 context-switches # 3.714 /sec
170-
2436 cpu-migrations # 0.670 /sec
171-
10476679 page-faults # 2.881 K/sec
172-
13133115082869 cycles # 3.611 GHz (16.77%)
173-
29314462753 stalled-cycles-frontend # 0.22% frontend cycles idle (16.76%)
174-
10294402631459 stalled-cycles-backend # 78.39% backend cycles idle (16.74%)
175-
23479217109614 instructions # 1.79 insn per cycle
176-
# 0.44 stalled cycles per insn (16.76%)
177-
2353072268027 branches # 647.002 M/sec (16.77%)
178-
1998682780 branch-misses # 0.08% of all branches (16.76%)
179-
180-
247.802177522 seconds time elapsed
181-
182-
3618.573072000 seconds user
183-
18.491698000 seconds sys
184-
```
9+
Please include information about your system, the steps to reproduce the bug, and the version of llama.cpp that you are using. If possible, please provide a minimal code example that reproduces the bug.

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,8 @@ if (LLAMA_CUBLAS)
302302
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
303303
endif()
304304

305+
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver)
306+
305307
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
306308
# 52 == lowest CUDA 12 standard
307309
# 60 == f16 CUDA intrinsics

Makefile

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,17 +367,15 @@ endif # LLAMA_BLIS
367367

368368
ifdef LLAMA_CUBLAS
369369
MK_CPPFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include -I/usr/local/cuda/targets/aarch64-linux/include
370-
MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib
370+
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/lib/wsl/lib
371371
OBJS += ggml-cuda.o
372372
MK_NVCCFLAGS = -use_fast_math
373373
ifndef JETSON_EOL_MODULE_DETECT
374374
MK_NVCCFLAGS += --forward-unknown-to-host-compiler
375375
endif # JETSON_EOL_MODULE_DETECT
376-
377376
ifdef LLAMA_DEBUG
378377
MK_NVCCFLAGS += -lineinfo
379-
endif
380-
378+
endif # LLAMA_DEBUG
381379
ifdef LLAMA_CUDA_NVCC
382380
NVCC = $(LLAMA_CUDA_NVCC)
383381
else

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ as the main playground for developing new features for the [ggml](https://github
102102
- [x] [Deepseek models](https://huggingface.co/models?search=deepseek-ai/deepseek)
103103
- [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
104104
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
105+
- [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557)
105106

106107
**Multimodal models:**
107108

convert-hf-to-gguf.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ def from_model_architecture(model_architecture):
183183
return MixtralModel
184184
if model_architecture == "PhiForCausalLM":
185185
return PhiModel
186+
if model_architecture == "PlamoForCausalLM":
187+
return PlamoModel
186188
return Model
187189

188190
def _is_model_safetensors(self) -> bool:
@@ -224,6 +226,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
224226
return gguf.MODEL_ARCH.LLAMA
225227
if arch == "PhiForCausalLM":
226228
return gguf.MODEL_ARCH.PHI
229+
if arch == "PlamoForCausalLM":
230+
return gguf.MODEL_ARCH.PLAMO
227231

228232
raise NotImplementedError(f'Architecture "{arch}" not supported!')
229233

@@ -1001,11 +1005,91 @@ def set_gguf_parameters(self):
10011005
self.gguf_writer.add_add_bos_token(False)
10021006

10031007

1008+
class PlamoModel(Model):
1009+
def set_vocab(self):
1010+
self._set_vocab_sentencepiece()
1011+
1012+
def set_gguf_parameters(self):
1013+
hparams = self.hparams
1014+
block_count = hparams["num_hidden_layers"]
1015+
1016+
self.gguf_writer.add_name("PLaMo")
1017+
self.gguf_writer.add_context_length(4096) # not in config.json
1018+
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
1019+
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
1020+
self.gguf_writer.add_block_count(block_count)
1021+
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
1022+
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
1023+
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
1024+
1025+
def shuffle_attn_q_weight(self, data_torch):
1026+
assert data_torch.size() == (5120, 5120)
1027+
data_torch = data_torch.reshape(8, 5, 128, 5120)
1028+
data_torch = torch.permute(data_torch, (1, 0, 2, 3))
1029+
data_torch = torch.reshape(data_torch, (5120, 5120))
1030+
return data_torch
1031+
1032+
def shuffle_attn_output_weight(self, data_torch):
1033+
assert data_torch.size() == (5120, 5120)
1034+
data_torch = data_torch.reshape(5120, 8, 5, 128)
1035+
data_torch = torch.permute(data_torch, (0, 2, 1, 3))
1036+
data_torch = torch.reshape(data_torch, (5120, 5120))
1037+
return data_torch
1038+
1039+
def write_tensors(self):
1040+
block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
1041+
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
1042+
1043+
for name, data_torch in self.get_tensors():
1044+
if "self_attn.rotary_emb.inv_freq" in name:
1045+
continue
1046+
1047+
# map tensor names
1048+
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
1049+
if new_name is None:
1050+
print(f"Can not map tensor {name!r}")
1051+
sys.exit()
1052+
1053+
# shuffle for broadcasting of gqa in ggml_mul_mat
1054+
if new_name.endswith("attn_q.weight"):
1055+
data_torch = self.shuffle_attn_q_weight(data_torch)
1056+
elif new_name.endswith("attn_output.weight"):
1057+
data_torch = self.shuffle_attn_output_weight(data_torch)
1058+
1059+
old_dtype = data_torch.dtype
1060+
1061+
# convert any unsupported data types to float32
1062+
if data_torch.dtype not in (torch.float16, torch.float32):
1063+
data_torch = data_torch.to(torch.float32)
1064+
1065+
data = data_torch.squeeze().numpy()
1066+
1067+
n_dims = len(data.shape)
1068+
data_dtype = data.dtype
1069+
1070+
# if f32 desired, convert any float16 to float32
1071+
if self.ftype == 0 and data_dtype == np.float16:
1072+
data = data.astype(np.float32)
1073+
1074+
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
1075+
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
1076+
data = data.astype(np.float32)
1077+
1078+
# if f16 desired, convert any float32 2-dim weight tensors to float16
1079+
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
1080+
data = data.astype(np.float16)
1081+
1082+
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
1083+
1084+
self.gguf_writer.add_tensor(new_name, data)
1085+
1086+
10041087
###### CONVERSION LOGIC ######
10051088

10061089

10071090
def parse_args() -> argparse.Namespace:
1008-
parser = argparse.ArgumentParser(description="Convert a huggingface model to a GGML compatible file")
1091+
parser = argparse.ArgumentParser(
1092+
description="Convert a huggingface model to a GGML compatible file")
10091093
parser.add_argument(
10101094
"--vocab-only", action="store_true",
10111095
help="extract only the vocab",

ggml-backend.c

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ static void ggml_backend_registry_init(void) {
297297
void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
298298
GGML_ASSERT(ggml_backend_registry_count < GGML_MAX_BACKENDS_REG);
299299

300-
int id = ggml_backend_registry_count;
300+
size_t id = ggml_backend_registry_count;
301301

302302
ggml_backend_registry[id] = (struct ggml_backend_reg) {
303303
/* .name = */ {0},
@@ -330,6 +330,8 @@ size_t ggml_backend_reg_find_by_name(const char * name) {
330330
return i;
331331
}
332332
}
333+
334+
// not found
333335
return SIZE_MAX;
334336
}
335337

@@ -340,15 +342,15 @@ ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str)
340342
const char * params = strchr(backend_str, ':');
341343
char backend_name[128];
342344
if (params == NULL) {
343-
strcpy(backend_name, backend_str);
345+
snprintf(backend_name, sizeof(backend_name), "%s", backend_str);
344346
params = "";
345347
} else {
346-
strncpy(backend_name, backend_str, params - backend_str);
347-
backend_name[params - backend_str] = '\0';
348+
snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str);
348349
params++;
349350
}
350351

351352
size_t backend_i = ggml_backend_reg_find_by_name(backend_name);
353+
352354
if (backend_i == SIZE_MAX) {
353355
fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
354356
return NULL;
@@ -396,18 +398,12 @@ static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
396398
}
397399

398400
static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
399-
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
400-
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
401-
402401
memcpy((char *)tensor->data + offset, data, size);
403402

404403
GGML_UNUSED(buffer);
405404
}
406405

407406
static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
408-
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
409-
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
410-
411407
memcpy(data, (const char *)tensor->data + offset, size);
412408

413409
GGML_UNUSED(buffer);

0 commit comments

Comments
 (0)