diff --git a/.gitignore b/.gitignore
index 741c6b4ea3447..1c75d38d11d1a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,6 +22,7 @@ models/*
 /result
 /perplexity
 /embedding
+/Pipfile
 arm_neon.h
 compile_commands.json
diff --git a/convert-ggml-to-pth.py b/convert-ggml-to-pth.py
index 20158c9ca8650..8ab17410d7929 100644
--- a/convert-ggml-to-pth.py
+++ b/convert-ggml-to-pth.py
@@ -84,6 +84,11 @@ def read_variables(fin):
         shape = shape[::-1]
         name = fin.read(name_length).decode("utf-8")
 
+        # ensure tensor data is aligned
+        tensor_data_offset = fin.tell()
+        tensor_data_offset = (tensor_data_offset + 31) & -32
+        fin.seek(tensor_data_offset)
+
         if ftype_cur == 2:
             # 4-bit quantized weights
             dtype = np.uint8
diff --git a/convert-gptq-to-ggml.py b/convert-gptq-to-ggml.py
index 6c77808fcd186..860eb148b2a86 100644
--- a/convert-gptq-to-ggml.py
+++ b/convert-gptq-to-ggml.py
@@ -72,6 +72,11 @@ def write_header(shape, dst_name, ftype_cur):
     fout.write(struct.pack("i" * len(shape), *shape[::-1]))
     fout.write(sname)
 
+    # ensure tensor data is aligned
+    tensor_data_offset = fout.tell()
+    tensor_data_offset = (tensor_data_offset + 31) & -32
+    fout.seek(tensor_data_offset)
+
 def convert_non_q4(src_name, dst_name):
     v = model[src_name]
     shape = v.shape
diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py
index d83f8a1373251..df42e76bdd0d2 100644
--- a/convert-pth-to-ggml.py
+++ b/convert-pth-to-ggml.py
@@ -1,4 +1,4 @@
-# Convert a LLaMA model checkpoint to a ggml compatible file
+# Convert a LLaMA model checkpoint to a ggjt compatible file
 #
 # Load the model using Torch
 # Iterate over all variables and write them to a binary file.
@@ -24,8 +24,57 @@
 from sentencepiece import SentencePieceProcessor
 
-def parse_args():
+QK = 32
 
+GGML_TYPE_Q4_0 = 0
+GGML_TYPE_Q4_1 = 1
+GGML_TYPE_I8 = 2
+GGML_TYPE_I16 = 3
+GGML_TYPE_I32 = 4
+GGML_TYPE_F16 = 5
+GGML_TYPE_F32 = 6
+
+WTYPES = {
+    0: GGML_TYPE_F32,
+    1: GGML_TYPE_F16,
+    2: GGML_TYPE_Q4_0,
+    3: GGML_TYPE_Q4_1,
+}
+
+GGML_BLCK_SIZE = {
+    GGML_TYPE_Q4_0: QK,
+    GGML_TYPE_Q4_1: QK,
+    GGML_TYPE_I8: 1,
+    GGML_TYPE_I16: 1,
+    GGML_TYPE_I32: 1,
+    GGML_TYPE_F16: 1,
+    GGML_TYPE_F32: 1,
+}
+
+GGML_TYPE_SIZE = {
+    GGML_TYPE_Q4_0: 4 + QK//2,
+    GGML_TYPE_Q4_1: 4*2 + QK//2,
+    GGML_TYPE_I8: 1,
+    GGML_TYPE_I16: 2,
+    GGML_TYPE_I32: 4,
+    GGML_TYPE_F16: 2,
+    GGML_TYPE_F32: 4,
+}
+
+def ggml_nelements(shape):
+    r = 1
+    for i in shape:
+        r *= i
+    return r
+
+def ggml_nbytes(shape, ftype):
+    x = ggml_nelements(shape)
+    t = WTYPES[ftype]
+    x *= GGML_TYPE_SIZE[t]
+    x //= GGML_BLCK_SIZE[t]
+    return x
+def parse_args():
     parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
     parser.add_argument('dir_model', help='directory containing the model checkpoint')
     parser.add_argument('ftype', help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1)
@@ -33,7 +82,6 @@ def parse_args():
     return parser.parse_args()
 
 def get_n_parts(dim):
-
     mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
     n_parts = mappings.get(dim)
     if n_parts is None:
@@ -44,30 +92,24 @@ def get_n_parts(dim):
     return n_parts
 
 def load_hparams_and_tokenizer(dir_model):
-
     # `dir_model` is something like `models/7B` or `models/7B/`.
     # "tokenizer.model" is expected under model's parent dir.
     # When `dir_model` is a symlink, f"{dir_model}/../tokenizer.model" would not be found.
     # Let's use the model's parent dir directly.
     model_parent_dir = os.path.dirname(os.path.normpath(dir_model))
-
     fname_hparams = f"{dir_model}/params.json"
     fname_tokenizer = f"{model_parent_dir}/tokenizer.model"
-
     with open(fname_hparams, "r") as f:
         hparams = json.load(f)
         print(hparams)
-
     tokenizer = SentencePieceProcessor(fname_tokenizer)
     hparams.update({"vocab_size": tokenizer.vocab_size()})
-
     return hparams, tokenizer
 
 def write_header(fout, hparams, ftype):
-
     keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
     values = [
-        0x67676d66,  # magic: ggmf in hex
+        0x67676a74,  # magic: ggjt in hex
         1,  # file version
         *[hparams[key] for key in keys],
         hparams["dim"] // hparams["n_heads"],  # rot (obsolete)
@@ -76,7 +118,6 @@ def write_header(fout, hparams, ftype):
     fout.write(struct.pack("i" * len(values), *values))
 
 def write_tokens(fout, tokenizer):
-
     for i in range(tokenizer.vocab_size()):
         if tokenizer.is_unknown(i):
             text = " \u2047 ".encode("utf-8")
@@ -95,85 +136,139 @@ def write_tokens(fout, tokenizer):
         fout.write(text)
         fout.write(struct.pack("f", tokenizer.get_score(i)))
 
-def process_and_write_variables(fout, model, ftype):
-
+def process_and_write_variables(fout, model, ftype, part_id, n_parts):
     for name, datao in model.items():
-
         if name.endswith("freqs"):
             continue
-
-        shape = datao.shape
-
-        print(f"Processing variable: {name} with shape: {shape} and type: {datao.dtype}")
-
+        # remove dimensions with a single element
         data = datao.numpy().squeeze()
-        n_dims = len(shape)
+        partshape = data.shape
+        n_dims = len(data.shape)
+        assert n_dims in (1, 2)
 
-        # default type is fp16
+        print(f"Processing variable: {name} with shape: {partshape} and type: {datao.dtype}")
+
+        # coerce single-dimensional tensors from float16 to float32
        ftype_cur = 1
        if ftype == 0 or n_dims == 1:
            print("  Converting to float32")
            data = data.astype(np.float32)
            ftype_cur = 0
-
-        # header
+        blck_size = GGML_BLCK_SIZE[WTYPES[ftype_cur]]
+        type_size = GGML_TYPE_SIZE[WTYPES[ftype_cur]]
+
+        # determine dimension along which multipart tensor is sharded
+        #
+        # split_dim 0 regex:
+        #   - output.*
+        #   - layers.*.attention.wq.weight
+        #   - layers.*.attention.wk.weight
+        #   - layers.*.attention.wv.weight
+        #   - layers.*.feed_forward.w1.weight
+        #   - layers.*.feed_forward.w3.weight
+        #
+        # split_dim 1 regex:
+        #   - tok_embeddings.*
+        #   - layers.*.attention.wo.weight
+        #   - layers.*.feed_forward.w2.weight
+        #
+        if n_dims > 1:
+            split_dim = 1
+            if "tok_embeddings" in name:
+                split_dim = 1
+            elif "layers" in name:
+                if "attention.wo.weight" in name:
+                    split_dim = 1
+                elif "feed_forward.w2.weight" in name:
+                    split_dim = 1
+                else:
+                    split_dim = 0
+            elif "output" in name:
+                split_dim = 0
+
+        # output tensor header
+        fullshape = list(partshape)
+        if n_dims > 1:
+            fullshape[split_dim] *= n_parts
        sname = name.encode('utf-8')
-        fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur))
-        for dim in reversed(data.shape):
+        fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
+        for dim in reversed(fullshape):
            fout.write(struct.pack("i", dim))
        fout.write(sname)
 
-        # data output to file
-        data.tofile(fout)
+        # ensure tensor data is aligned
+        tensor_data_offset = fout.tell()
+        while tensor_data_offset % QK != 0:
+            fout.write(struct.pack("B", 0))
+            tensor_data_offset += 1
+
+        # output unified mappable tensor data
+        if n_dims == 1 or n_parts == 1:
+            # copy tensor which we thankfully received in one piece
+            if part_id == 0:
+                data.tofile(fout)
+        elif split_dim == 0:
+            # reassemble multifile tensor containing some of the rows
+            rows_per_chunk = partshape[0]
+            current_row = part_id * rows_per_chunk
+            bytes_per_row = fullshape[1] // blck_size * type_size
+            offset = current_row * bytes_per_row
+            fout.seek(tensor_data_offset + offset)
+            data.tofile(fout)
+        elif split_dim == 1:
+            # reassemble multifile tensor containing some of the cols
+            cols_per_chunk = partshape[1]
+            current_col = part_id * cols_per_chunk
+            bytes_per_row = fullshape[1] // blck_size * type_size
+            offset_current_col = current_col // blck_size * type_size
+            for row in range(partshape[0]):
+                offset_row = row * bytes_per_row
+                offset = offset_row + offset_current_col
+                fout.seek(tensor_data_offset + offset)
+                data[row].tofile(fout)
+
+        # advance file position to next tensor
+        fout.seek(tensor_data_offset + ggml_nbytes(fullshape, ftype_cur))
 
 def main():
-
     args = parse_args()
     dir_model = args.dir_model
     ftype = args.ftype
     ftype_str = ["f32", "f16"]
-
     hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
     print(args)
 
     # if only writing vocab to file
     if args.vocab_only:
-
         fname_model = f"{dir_model}/consolidated.00.pth"
         fname_out = f"{dir_model}/ggml-vocab.bin"
-
         print(f"Extracting only the vocab from '{fname_model}'\n")
-
-
         with open(fname_out, "wb") as fout:
             write_header(fout, hparams, ftype)
             write_tokens(fout, tokenizer)
-
-
         print(f"Done. Output file: {fname_out}\n")
-
-
         return
 
     n_parts = get_n_parts(hparams["dim"])
-
-    for p in range(n_parts):
-
-        print(f"Processing part {p+1} of {n_parts}\n")
-
-        fname_model = f"{dir_model}/consolidated.0{p}.pth"
-        fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}"
-
-        model = torch.load(fname_model, map_location="cpu")
-
-        with open(fname_out, "wb") as fout:
-            write_header(fout, hparams, ftype)
-            write_tokens(fout, tokenizer)
-            process_and_write_variables(fout, model, ftype)
-
-        del model
-
-        print(f"Done. Output file: {fname_out}, (part {p})\n")
+    fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin"
+
+    # we output a single file for ggml
+    with open(fname_out, "wb") as fout:
+        write_header(fout, hparams, ftype)
+        write_tokens(fout, tokenizer)
+        offset_of_tensors = fout.tell()
+        # the tensors we load could be split across multiple files
+        for part_id in range(n_parts):
+            fout.seek(offset_of_tensors)
+            print(f"Processing part {part_id+1} of {n_parts}\n")
+            fname_model = f"{dir_model}/consolidated.0{part_id}.pth"
+            model = torch.load(fname_model, map_location="cpu")
+            process_and_write_variables(fout, model, ftype, part_id, n_parts)
+            del model
+
+    print(f"Done. Output file: {fname_out}\n")
 
 if __name__ == "__main__":
     main()
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index b444328acc6aa..680757c6bf356 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -19,7 +19,7 @@ int main(int argc, char ** argv) {
 
     // needed to initialize f16 tables
     {
-        struct ggml_init_params params = { 0, NULL };
+        struct ggml_init_params params = { 0, NULL, false };
         struct ggml_context * ctx = ggml_init(params);
         ggml_free(ctx);
     }
diff --git a/ggml.c b/ggml.c
index 02675ee67072d..3d474a1a1cbf5 100644
--- a/ggml.c
+++ b/ggml.c
@@ -2530,8 +2530,9 @@ struct ggml_context {
     void * mem_buffer;
     bool   mem_buffer_owned;
     bool   mem_buffer_mlocked;
+    bool   no_alloc;
 
-    int n_objects;
+    int    n_objects;
 
     struct ggml_object * objects_begin;
     struct ggml_object * objects_end;
@@ -2816,6 +2817,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         /*.mem_buffer         =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
         /*.mem_buffer_owned   =*/ params.mem_buffer ? false : true,
         /*.mem_buffer_mlocked =*/ false,
+        /*.no_alloc           =*/ params.no_alloc,
         /*.n_objects          =*/ 0,
         /*.objects_begin      =*/ NULL,
         /*.objects_end        =*/ NULL,
@@ -2883,36 +2885,47 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
     return result;
 }
 
+#ifdef __APPLE__
+#define MLOCK_SUGGESTION \
+    "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
+    "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
+#else
+#define MLOCK_SUGGESTION \
+    "Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
+#endif
+
 bool ggml_mlock_supported(void) {
     return GGML_MLOCK_SUPPORT;
 }
 
+bool ggml_mlock(
+        struct ggml_context * ctx,
+        const void *opt_extra_addr,
+        size_t opt_extra_len,
+        char **err_p) {
+    // TODO: Use SetProcessWorkingSetSize() + VirtualLock() on WIN32
 #if GGML_MLOCK_SUPPORT
-#ifdef __APPLE__
-    #define MLOCK_SUGGESTION "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or\n" \
-                             "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l)."
-#else
-    #define MLOCK_SUGGESTION "Try increasing RLIMIT_MLOCK (ulimit -l)."
-#endif
-bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
     if (ctx->mem_buffer_mlocked) {
         return true;
     }
-    if (mlock(ctx->mem_buffer, ctx->mem_size)) {
-        int ret = asprintf(err_p, "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
-                           ctx->mem_size, strerror(errno));
-        GGML_ASSERT(ret >= 0);
+    if (mlock(ctx->mem_buffer, ctx->mem_size) ||
+        (opt_extra_len &&
+         mlock(opt_extra_addr, opt_extra_len))) {
+        if ((*err_p = malloc(1024))) {
+            snprintf(*err_p, 1024,
+                     "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
+                     ctx->mem_size + opt_extra_len,
+                     strerror(errno));
+        }
        return false;
    }
    ctx->mem_buffer_mlocked = true;
    return true;
-}
 #else // GGML_MLOCK_SUPPORT
-bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
    *err_p = strdup("can't mlock because it's not supported on this system");
    return false;
-}
 #endif // GGML_MLOCK_SUPPORT
+}
 
 ////////////////////////////////////////////////////////////////////////////////
 
@@ -2931,7 +2944,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
 
     size_t size_needed = 0;
 
-    if (data == NULL) {
+    if (data == NULL && !ctx->no_alloc) {
         size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
         for (int i = 1; i < n_dims; i++) {
             size_needed *= ne[i];
         }
@@ -3015,7 +3028,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
-        /*.data         =*/ data == NULL ? (void *)(result + 1) : data,
+        /*.data         =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
         /*.pad          =*/ { 0 },
     };
 
@@ -10278,6 +10291,7 @@ enum ggml_opt_result ggml_opt(
         struct ggml_init_params params_ctx = {
             .mem_size   = 16*1024*1024,
             .mem_buffer = NULL,
+            .no_alloc   = false,
         };
 
         ctx = ggml_init(params_ctx);
diff --git a/ggml.h b/ggml.h
index 335230f9f0bb2..f7791ed11f084 100644
--- a/ggml.h
+++ b/ggml.h
@@ -316,6 +316,7 @@ struct ggml_init_params {
     // memory pool
     size_t mem_size;   // bytes
     void * mem_buffer; // if NULL, memory will be allocated internally
+    bool   no_alloc;   // don't allocate memory for the tensor data
 };
 
 void ggml_time_init(void); // call this once at the beginning of the program
@@ -344,7 +345,11 @@ size_t ggml_used_mem(const struct ggml_context * ctx);
 size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
 
 bool ggml_mlock_supported(void);
-bool ggml_mlock(struct ggml_context * ctx, char ** err_p);
+bool ggml_mlock(
+        struct ggml_context * ctx,
+        const void *opt_extra_addr,
+        size_t opt_extra_len,
+        char **err_p);
 
 struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
diff --git a/llama.cpp b/llama.cpp
index e4998efa2bb20..bed24207db776 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -12,6 +12,19 @@
 #include
 #include
 
+#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES)
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#else
+#include <sys/types.h>
+#include <sys/mman.h>
+#include <unistd.h>
+#include <fcntl.h>
+#endif
+
+#define Min(X, Y) ((Y) > (X) ? (X) : (Y))
+#define Max(X, Y) ((Y) < (X) ? (X) : (Y))
+
 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
 
@@ -142,6 +155,10 @@ struct llama_model {
     // the model memory buffer
     std::vector<uint8_t> buf;
 
+    // model memory mapped file
+    void * mm_addr = NULL;
+    uint64_t mm_length = 0;
+
     // tensors
     int n_loaded;
     std::unordered_map<std::string, struct ggml_tensor *> tensors;
@@ -165,6 +182,7 @@ struct llama_context {
     int64_t t_load_us = 0;
     int64_t t_start_us = 0;
+    bool has_evaluated_once = false;
 
     int64_t t_sample_us = 0;
     int64_t t_eval_us   = 0;
@@ -206,7 +224,7 @@ struct llama_context {
         }
 
         if (buf_last >= 0) {
-            buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
+            buf_max_size[buf_last] = Max(buf_max_size[buf_last], last_size);
         }
 
         buf_last = i;
@@ -246,6 +264,7 @@ static bool kv_cache_init(
     struct ggml_init_params params;
     params.mem_size   = cache.buf.size();
     params.mem_buffer = cache.buf.data();
+    params.no_alloc   = false;
 
     cache.ctx = ggml_init(params);
 
@@ -288,6 +307,58 @@ struct llama_context_params llama_context_default_params() {
 // model loading
 //
 
+static void *mmap_file(const char *fname, uint64_t *mm_length) {
+#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES)
+    HANDLE hFile = CreateFileA(fname,
+                               GENERIC_READ,
+                               FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
+                               NULL,
+                               OPEN_EXISTING,
+                               FILE_ATTRIBUTE_NORMAL | FILE_ATTRIBUTE_NOT_CONTENT_INDEXED,
+                               NULL);
+    if (hFile == INVALID_HANDLE_VALUE) return 0;
+    LARGE_INTEGER fileSize;
+    fileSize.QuadPart = -1;
+    GetFileSizeEx(hFile, &fileSize);
+    int64_t length = fileSize.QuadPart;
+    HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
+    CloseHandle(hFile);
+    if (!hMapping) return 0;
+    void *addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
+    CloseHandle(hMapping);
+    if (!addr) return 0;
+#else
+    int fd = open(fname, O_RDONLY);
+    if (fd == -1) return 0;
+    int64_t length = lseek(fd, 0, SEEK_END);
+    void *addr = mmap(NULL, length, PROT_READ, MAP_SHARED, fd, 0);
+    close(fd);
+    if (addr == MAP_FAILED) return 0;
+#endif
+    *mm_length = length;
+    return addr;
+}
+
+static void munmap_file(void * addr, size_t length) {
+#if defined(_WIN32) && !defined(_POSIX_MAPPED_FILES)
+    UnmapViewOfFile(addr);
+#else
+    munmap(addr, length);
+#endif
+}
+
+static bool report_bad_magic(const char *path, uint32_t got, uint32_t want) {
+    fprintf(stderr,
+            "%s: invalid model file (bad magic [got %#x want %#x])\n"
+            "\tyou most likely need to regenerate your ggml files\n"
+            "\tthe benefit is you'll get 10-100x faster load times\n"
+            "\tsee https://github.com/ggerganov/llama.cpp/issues/91\n"
+            "\tuse convert-pth-to-ggml.py to regenerate from original pth\n"
+            "\tuse migrate-ggml-2023-03-30-pr613.py if you deleted originals\n",
+            path, got, want);
+    return false;
+}
+
 static bool llama_model_load(
         const std::string & fname,
         llama_context & lctx,
@@ -299,22 +370,24 @@ static bool llama_model_load(
         void *progress_callback_user_data) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
-    const int64_t t_start_us = ggml_time_us();
-
-    lctx.t_start_us = t_start_us;
-
-    std::vector<char> f_buf(1024*1024);
+    lctx.t_start_us = ggml_time_us();
 
     auto & model = lctx.model;
     auto & vocab = lctx.vocab;
 
     auto fin = std::ifstream(fname, std::ios::binary);
-    fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
     if (!fin) {
         fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
         return false;
     }
 
+    std::vector<char> f_buf(1024*1024);
+    fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
+
+    fin.seekg(0, fin.end);
+    const size_t file_size = fin.tellg();
+    fin.seekg(0);
+
     // verify magic
     {
         uint32_t magic;
@@ -325,8 +398,7 @@ static bool llama_model_load(
             return false;
         }
         if (magic != LLAMA_FILE_MAGIC) {
-            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
-            return false;
+            return report_bad_magic(fname.c_str(), magic, LLAMA_FILE_MAGIC);
         }
 
         uint32_t format_version;
@@ -449,43 +521,24 @@ static bool llama_model_load(
         }
     }
 
+    // map model into memory
+    char *mm_addr = NULL;
+    model.mm_addr = mmap_file(fname.c_str(), &model.mm_length);
+    if (model.mm_addr == NULL) {
+        fprintf(stderr, "%s: failed to mmap '%s'\n", __func__, fname.c_str());
+        return false;
+    }
+    mm_addr = (char *)model.mm_addr;
+    fprintf(stderr, "%s: ggml map size = %6.2f MB\n", __func__, model.mm_length/(1024.0*1024.0));
+
     auto & ctx = model.ctx;
 
     size_t ctx_size = 0;
-
     {
-        const auto & hparams = model.hparams;
-
-        const int n_embd  = hparams.n_embd;
+        const auto &hparams = model.hparams;
         const int n_layer = hparams.n_layer;
-        const int n_ctx   = hparams.n_ctx;
-        const int n_vocab = hparams.n_vocab;
-
-        ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // tok_embeddings
-
-        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
-
-        ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // output
-
-        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
-
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
-
-        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
-
-        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
-        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
-        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
-
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_k
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_v
-
         ctx_size += (5 + 10*n_layer)*256; // object overhead
-
fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0); } // print memory requirements @@ -495,6 +548,7 @@ static bool llama_model_load( // this is the total memory required to run the inference const size_t mem_required = ctx_size + + model.mm_length + MEM_REQ_SCRATCH0.at(model.type) + MEM_REQ_SCRATCH1.at(model.type) + MEM_REQ_EVAL.at (model.type); @@ -514,6 +568,7 @@ static bool llama_model_load( struct ggml_init_params params = { /*.mem_size =*/ lctx.model.buf.size(), /*.mem_buffer =*/ lctx.model.buf.data(), + /*.no_alloc =*/ true, }; model.ctx = ggml_init(params); @@ -576,234 +631,106 @@ static bool llama_model_load( } } - const size_t file_offset = fin.tellg(); - - fin.close(); - std::vector tmp; if (progress_callback) { progress_callback(0.0, progress_callback_user_data); } - for (int i = 0; i < n_parts; ++i) { - const int part_id = i; - //const int part_id = n_parts - i - 1; - - std::string fname_part = fname; - if (i > 0) { - fname_part += "." + std::to_string(i); - } + fprintf(stderr, "%s: loading tensors from '%s'\n", __func__, fname.c_str()); - fprintf(stderr, "%s: loading model part %d/%d from '%s'\n", __func__, i+1, n_parts, fname_part.c_str()); + // load weights + { + size_t total_size = 0; + model.n_loaded = 0; - fin = std::ifstream(fname_part, std::ios::binary); - fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size()); + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; - fin.seekg(0, fin.end); - const size_t file_size = fin.tellg(); + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ftype), sizeof(ftype)); - fin.seekg(file_offset); + if (fin.eof()) { + break; + } - // load weights - { - size_t total_size = 0; + int32_t nelements = 1; + int32_t ne[2] = { 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } - model.n_loaded = 0; + std::string name(length, 0); + fin.read(&name[0], length); - fprintf(stderr, "%s: ", __func__); + if (model.tensors.find(name.data()) == model.tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } - while (true) { - int32_t n_dims; - int32_t length; - int32_t ftype; + auto tensor = model.tensors[name.data()]; - fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); - fin.read(reinterpret_cast(&length), sizeof(length)); - fin.read(reinterpret_cast(&ftype), sizeof(ftype)); + if (ggml_nelements(tensor) != nelements) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + return false; + } + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", + __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]); + return false; + } + if (0) { + static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; + fprintf(stderr, "%24s - [%5d, %5d], type = %6s\n", name.data(), ne[0], ne[1], ftype_str[ftype]); + } - if (fin.eof()) { + switch (ftype) { + case 0: // f32 + case 1: // f16 break; - } - - int32_t nelements = 1; - int32_t ne[2] = { 1, 1 }; - for (int i = 0; i < n_dims; ++i) { - fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); - nelements *= ne[i]; - } - - std::string name(length, 0); - fin.read(&name[0], length); - - if (model.tensors.find(name.data()) == 
-                    fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                case 2:  // q4_0
+                case 3:  // q4_1
+                    assert(ne[0] % 64 == 0);
+                    break;
+                default:
+                    fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
                     return false;
-                }
-
-                // split_type = 0: split by columns
-                // split_type = 1: split by rows
-                int split_type = 0;
-
-                // split_type = 0:
-                // regex:
-                //   - tok_embeddings.*
-                //   - layers.*.attention.wo.weight
-                //   - layers.*.feed_forward.w2.weight
-
-                // split_type = 1:
-                // regex:
-                //   - output.*
-                //   - layers.*.attention.wq.weight
-                //   - layers.*.attention.wk.weight
-                //   - layers.*.attention.wv.weight
-                //   - layers.*.feed_forward.w1.weight
-                //   - layers.*.feed_forward.w3.weight
-                if (name.find("tok_embeddings") != std::string::npos) {
-                    split_type = 0;
-                } else if (name.find("layers") != std::string::npos) {
-                    if (name.find("attention.wo.weight") != std::string::npos) {
-                        split_type = 0;
-                    } else if (name.find("feed_forward.w2.weight") != std::string::npos) {
-                        split_type = 0;
-                    } else {
-                        split_type = 1;
-                    }
-                } else if (name.find("output") != std::string::npos) {
-                    split_type = 1;
-                }
-
-                auto tensor = model.tensors[name.data()];
-
-                if (n_dims == 1) {
-                    if (ggml_nelements(tensor) != nelements) {
-                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                        return false;
-                    }
-                } else {
-                    if (ggml_nelements(tensor)/n_parts != nelements) {
-                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                        return false;
-                    }
-                }
-
-                if (n_dims == 1) {
-                    if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
-                        fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
-                                __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
-                        return false;
-                    }
-                } else {
-                    if (split_type == 0) {
-                        if (tensor->ne[0]/n_parts != ne[0] || tensor->ne[1] != ne[1]) {
-                            fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
-                                    __func__, name.data(), tensor->ne[0]/n_parts, tensor->ne[1], ne[0], ne[1]);
-                            return false;
-                        }
-                    } else {
-                        if (tensor->ne[0] != ne[0] || tensor->ne[1]/n_parts != ne[1]) {
-                            fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
-                                    __func__, name.data(), tensor->ne[0], tensor->ne[1]/n_parts, ne[0], ne[1]);
-                            return false;
-                        }
-                    }
-                }
-
-                if (0) {
-                    static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
-                    fprintf(stderr, "%24s - [%5d, %5d], type = %6s, split = %d\n", name.data(), ne[0], ne[1], ftype_str[ftype], split_type);
-                }
-
-                size_t bpe = 0;
-
-                switch (ftype) {
-                    case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
-                    case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
-                    case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
-                    case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
-                    default:
-                            {
-                                fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
-                                return false;
-                            }
-                };
-
-                if (n_dims == 1 || n_parts == 1) {
-                    if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
-                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
-                                __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
-                        return false;
-                    }
-
-                    if (part_id == 0) {
-                        fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
-                    } else {
-                        fin.seekg(ggml_nbytes(tensor), std::ios::cur);
-                    }
-
-                    total_size += ggml_nbytes(tensor);
-                } else {
-                    if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)/n_parts) {
-                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
-                                __func__, name.data(), ggml_nbytes(tensor)/n_parts, nelements*bpe);
-                        return false;
-                    }
-
-                    if (split_type == 0) {
-                        const int np0 = ne[0];
-
-                        const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
-                        assert(row_size == tensor->nb[1]);
-
-                        for (int i1 = 0; i1 < ne[1]; ++i1) {
-                            const size_t offset_row = i1*row_size;
-                            const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
-                            fin.read(reinterpret_cast<char *>(tensor->data) + offset, row_size/n_parts);
-                        }
-                    } else {
-                        const int np1 = ne[1];
-
-                        const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
-
-                        for (int i1 = 0; i1 < ne[1]; ++i1) {
-                            const size_t offset_row = (i1 + part_id*np1)*row_size;
-                            fin.read(reinterpret_cast<char *>(tensor->data) + offset_row, row_size);
-                        }
-                    }
-
-                    total_size += ggml_nbytes(tensor)/n_parts;
-                }
-
-                //fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
-                model.n_loaded++;
-
-                // progress
-                if (progress_callback) {
-                    float current_file_progress = float(size_t(fin.tellg()) - file_offset) / float(file_size - file_offset);
-                    float current_progress = (float(i) + current_file_progress) / float(n_parts);
-                    progress_callback(current_progress, progress_callback_user_data);
-                }
-                if (model.n_loaded % 8 == 0) {
-                    fprintf(stderr, ".");
-                    fflush(stderr);
-                }
-            }
-
-            fprintf(stderr, " done\n");
+            };
 
-            fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
-            if (model.n_loaded == 0) {
-                fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
-            } else if (model.n_loaded != (int) model.tensors.size()) {
-                fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
-                return false;
+            // load the tensor data into memory without copying or reading it
+            size_t offset = fin.tellg();
+            size_t tensor_data_size = ggml_nbytes(tensor);
+            offset = (offset + 31) & -32;
+            tensor->data = mm_addr + offset;
+            fin.seekg(offset + tensor_data_size);
+            total_size += tensor_data_size;
+            model.n_loaded++;
+
+            // progress
+            if (progress_callback) {
+                double current_progress = size_t(fin.tellg()) / double(file_size);
+                progress_callback(current_progress, progress_callback_user_data);
            }
        }
 
        fin.close();
+
+        fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
+        if (model.n_loaded == 0) {
+            fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+        } else if (model.n_loaded != (int) model.tensors.size()) {
+            fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+            return false;
+        }
     }
 
-    lctx.t_load_us = ggml_time_us() - t_start_us;
+    // loading time will be recalculated after the first eval, so
+    // we take page faults deferred by mmap() into consideration
+    lctx.t_load_us = ggml_time_us() - lctx.t_start_us;
 
     if (progress_callback) {
         progress_callback(1.0, progress_callback_user_data);
@@ -849,6 +776,7 @@ static bool llama_eval_internal(
     struct ggml_init_params params = {
         /*.mem_size   =*/ buf_compute.size(),
         /*.mem_buffer =*/ buf_compute.data(),
+        /*.no_alloc   =*/ false,
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
@@ -1126,7 +1054,7 @@ struct llama_tokenizer {
         size_t offs = 0;
         while (offs < text.size()) {
             llama_sp_symbol sym;
-            size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
+            size_t char_len = Min(text.size() - offs, utf8_len(text[offs]));
             sym.text = text.c_str() + offs;
             sym.n = char_len;
             offs += char_len;
@@ -1291,7 +1219,7 @@ static llama_vocab::id llama_sample_top_p_top_k(
 
     float maxl = -std::numeric_limits<float>::infinity();
     for (const auto & kv : logits_id) {
-        maxl = std::max(maxl, kv.first);
+        maxl = Max(maxl, kv.first);
     }
 
     // compute probs for the top k tokens
@@ -1385,8 +1313,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
             return false;
         }
         if (magic != LLAMA_FILE_MAGIC) {
-            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
-            return false;
+            return report_bad_magic(fname_inp.c_str(), magic, LLAMA_FILE_MAGIC);
         }
 
         fout.write((char *) &magic, sizeof(magic));
@@ -1452,8 +1379,8 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
             fout.write((char *) &len, sizeof(len));
 
             word.resize(len);
-            finp.read ((char *) word.data(), len);
-            fout.write((char *) word.data(), len);
+            finp.read ((char *) &word[0], len);
+            fout.write((char *) &word[0], len);
 
             float score;
             finp.read ((char *) &score, sizeof(score));
@@ -1503,6 +1430,13 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
             std::string name(length, 0);
             finp.read (&name[0], length);
 
+            {
+                // ensure tensor data is aligned
+                uint64_t offset = finp.tellg();
+                offset = (offset + 31) & -32;
+                finp.seekg(offset);
+            }
+
             {
                 static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
                 printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
@@ -1558,6 +1492,13 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
             }
 
             fout.write(&name[0], length);
 
+            {
+                // ensure tensor data is aligned
+                uint64_t offset = fout.tellp();
+                offset = (offset + 31) & -32;
+                fout.seekp(offset);
+            }
+
             if (quantize) {
                 printf("quantizing .. ");
"); work.resize(nelements); // for quantization @@ -1655,7 +1596,10 @@ struct llama_context * llama_init_from_file( if (params.use_mlock) { char *err; - if (!ggml_mlock(ctx->model.ctx, &err)) { + if (!ggml_mlock(ctx->model.ctx, + ctx->model.mm_addr, + ctx->model.mm_length, + &err)) { fprintf(stderr, "%s\n", err); free(err); llama_free(ctx); @@ -1705,6 +1649,10 @@ void llama_free(struct llama_context * ctx) { ggml_free(ctx->model.ctx); } + if (ctx->model.mm_addr) { + munmap_file(ctx->model.mm_addr, ctx->model.mm_length); + } + delete ctx; } @@ -1730,7 +1678,11 @@ int llama_eval( fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } - + // get a more accurate load time, upon first eval + if (!ctx->has_evaluated_once) { + ctx->t_load_us = ggml_time_us() - ctx->t_start_us; + ctx->has_evaluated_once = true; + } return 0; } @@ -1823,9 +1775,9 @@ llama_token llama_sample_top_p_top_k( void llama_print_timings(struct llama_context * ctx) { const int64_t t_end_us = ggml_time_us(); - const int32_t n_sample = std::max(1, ctx->n_sample); - const int32_t n_eval = std::max(1, ctx->n_eval); - const int32_t n_p_eval = std::max(1, ctx->n_p_eval); + const int32_t n_sample = Max(1, ctx->n_sample); + const int32_t n_eval = Max(1, ctx->n_eval); + const int32_t n_p_eval = Max(1, ctx->n_p_eval); fprintf(stderr, "\n"); fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0); @@ -1837,7 +1789,6 @@ void llama_print_timings(struct llama_context * ctx) { void llama_reset_timings(struct llama_context * ctx) { ctx->t_start_us = ggml_time_us(); - ctx->t_sample_us = ctx->n_sample = 0; ctx->t_eval_us = ctx->n_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0; diff --git a/llama.h b/llama.h index 3368de3e00f58..258de5a944976 100644 --- a/llama.h +++ b/llama.h @@ -20,7 +20,7 @@ #endif #define LLAMA_FILE_VERSION 1 -#define LLAMA_FILE_MAGIC 0x67676d66 // 'ggmf' in hex +#define LLAMA_FILE_MAGIC 0x67676a74 // 'ggjt' in hex #define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files #ifdef __cplusplus diff --git a/migrate-ggml-2023-03-30-pr613.py b/migrate-ggml-2023-03-30-pr613.py new file mode 100644 index 0000000000000..5596f6c5513a5 --- /dev/null +++ b/migrate-ggml-2023-03-30-pr613.py @@ -0,0 +1,313 @@ +# Migrate ggml file(s) with ggmf magic to ggml file with ggjt magic +# +# We caused a breaking change to the file format on 2023-03-30 in: +# https://github.com/ggerganov/llama.cpp/pull/613 +# +# (1) If you still have the Meta LLaMA .pth files, then close this +# file now; you can just run `convert-pth-to-ggml.py` again to +# migrate to the new format. The tool is easier to use too. It +# isn't necessary anymore to manage split output files because +# the new format always combines things into a single file. +# +# (2) If you deleted the Meta LLaMA .pth files due to save on disk +# space, then this tool is intended to help you. Please check +# out the instructions below. +# +# USAGE +# +# python migrate-ggml-2023-03-30-pr613.py INPUT OUTPUT +# +# PREREQUISITES +# +# pip install numpy +# cd llama.cpp +# make -j4 +# +# EXAMPLE (7B MODEL) +# +# # you can replace all the 'f16' with 'q4_0' if you're using quantized weights +# python migrate-ggml-2023-03-30-pr613.py models/7B/ggml-model-f16.bin models/7B/ggml-model-f16-ggjt.bin +# +# # check that it works +# ./main -m models/7B/ggml-model-f16-ggjt.bin -p 'Question: Do you love me?' 
+#
+#   # you can delete the old files
+#   rm -f models/7B/ggml-model-f16.bin
+#   mv models/7B/ggml-model-f16-ggjt.bin models/7B/ggml-model-f16.bin
+#
+# EXAMPLE (13B MODEL)
+#
+#   # you can replace all the 'f16' with 'q4_0' if you're using quantized weights
+#   python migrate-ggml-2023-03-30-pr613.py models/13B/ggml-model-f16.bin models/13B/ggml-model-f16-ggjt.bin
+#
+#   # check that it works
+#   ./main -m models/13B/ggml-model-f16-ggjt.bin -p 'Question: Do you love me?'
+#
+#   # you can delete the old files
+#   rm -f models/13B/ggml-model-f16.bin*
+#   mv models/13B/ggml-model-f16-ggjt.bin models/13B/ggml-model-f16.bin
+#
+
+import argparse
+import os
+import sys
+import json
+import struct
+import numpy as np
+
+QK = 32
+
+GGML_TYPE_Q4_0 = 0
+GGML_TYPE_Q4_1 = 1
+GGML_TYPE_I8 = 2
+GGML_TYPE_I16 = 3
+GGML_TYPE_I32 = 4
+GGML_TYPE_F16 = 5
+GGML_TYPE_F32 = 6
+
+WTYPE_NAMES = {
+    0: "F32",
+    1: "F16",
+    2: "Q4_0",
+    3: "Q4_1",
+}
+
+WTYPES = {
+    0: GGML_TYPE_F32,
+    1: GGML_TYPE_F16,
+    2: GGML_TYPE_Q4_0,
+    3: GGML_TYPE_Q4_1,
+}
+
+GGML_BLCK_SIZE = {
+    GGML_TYPE_Q4_0: QK,
+    GGML_TYPE_Q4_1: QK,
+    GGML_TYPE_I8: 1,
+    GGML_TYPE_I16: 1,
+    GGML_TYPE_I32: 1,
+    GGML_TYPE_F16: 1,
+    GGML_TYPE_F32: 1,
+}
+
+GGML_TYPE_SIZE = {
+    GGML_TYPE_Q4_0: 4 + QK//2,
+    GGML_TYPE_Q4_1: 4*2 + QK//2,
+    GGML_TYPE_I8: 1,
+    GGML_TYPE_I16: 2,
+    GGML_TYPE_I32: 4,
+    GGML_TYPE_F16: 2,
+    GGML_TYPE_F32: 4,
+}
+
+HPARAMS = [
+    'magic',    # int32
+    'version',  # int32
+    'n_vocab',  # int32
+    'n_embd',   # int32
+    'n_mult',   # int32
+    'n_head',   # int32
+    'n_layer',  # int32
+    'n_rot',    # int32
+    'f16',      # int32
+]
+
+def read_hparams(fin):
+    struct_fmt = "i" * len(HPARAMS)
+    struct_size = struct.calcsize(struct_fmt)
+    buf = fin.read(struct_size)
+    ints = struct.unpack(struct_fmt, buf)
+    hparams = dict(zip(HPARAMS, ints))
+    return hparams
+
+def write_hparams(fout, hparams):
+    struct_fmt = "i" * len(HPARAMS)
+    struct_size = struct.calcsize(struct_fmt)
+    ints = [hparams[h] for h in HPARAMS]
+    fout.write(struct.pack(struct_fmt, *ints))
+
+def read_tokens(fin, hparams):
+    tokens = []
+    for i in range(hparams['n_vocab']):
+        len_b = fin.read(4)
+        (length,) = struct.unpack("i", len_b)
+        word = fin.read(length)
+        score_b = fin.read(4)
+        (score,) = struct.unpack("f", score_b)
+        tokens.append((word, score))
+    return tokens
+
+def write_tokens(fout, tokens):
+    for word, score in tokens:
+        fout.write(struct.pack("i", len(word)))
+        fout.write(word)
+        fout.write(struct.pack("f", score))
+
+def ggml_nelements(shape):
+    r = 1
+    for i in shape:
+        r *= i
+    return r
+
+def ggml_nbytes(shape, ftype):
+    x = ggml_nelements(shape)
+    t = WTYPES[ftype]
+    x *= GGML_TYPE_SIZE[t]
+    x //= GGML_BLCK_SIZE[t]
+    return x
+
+def copy_tensors(fin, fout, part_id, n_parts):
+    while True:
+
+        b = fin.read(4)
+        if not b: break
+        (n_dims,) = struct.unpack("i", b)
+        b = fin.read(4)
+        (length,) = struct.unpack("i", b)
+        b = fin.read(4)
+        (ftype,) = struct.unpack("i", b)
+
+        assert n_dims in (1, 2)
+
+        partshape = list(range(n_dims))
+        for i in range(n_dims):
+            b = fin.read(4)
+            partshape[i] = struct.unpack("i", b)[0]
+        partshape = list(reversed(partshape))
+
+        name = fin.read(length)
+        data = fin.read(ggml_nbytes(partshape, ftype))
+
+        blck_size = GGML_BLCK_SIZE[WTYPES[ftype]]
+        type_size = GGML_TYPE_SIZE[WTYPES[ftype]]
+
+        print(f"Processing tensor {name} with shape: {partshape} and type: {WTYPE_NAMES[ftype]}")
+
+        # determine dimension along which multipart tensor is sharded
+        #
+        # split_dim 0 regex:
+        #   - output.*
+        #   - layers.*.attention.wq.weight
+        #   - layers.*.attention.wk.weight
+        #   - layers.*.attention.wv.weight
+        #   - layers.*.feed_forward.w1.weight
+        #   - layers.*.feed_forward.w3.weight
+        #
+        # split_dim 1 regex:
+        #   - tok_embeddings.*
+        #   - layers.*.attention.wo.weight
+        #   - layers.*.feed_forward.w2.weight
+        #
+        if n_dims > 1:
+            split_dim = 1
+            if b"tok_embeddings" in name:
+                split_dim = 1
+            elif b"layers" in name:
+                if b"attention.wo.weight" in name:
+                    split_dim = 1
+                elif b"feed_forward.w2.weight" in name:
+                    split_dim = 1
+                else:
+                    split_dim = 0
+            elif b"output" in name:
+                split_dim = 0
+
+        # output tensor header
+        fullshape = list(partshape)
+        if n_dims > 1:
+            fullshape[split_dim] *= n_parts
+        fout.write(struct.pack("iii", n_dims, len(name), ftype))
+        for dim in reversed(fullshape):
+            fout.write(struct.pack("i", dim))
+        fout.write(name)
+
+        # ensure tensor data is aligned
+        tensor_data_offset = fout.tell()
+        while tensor_data_offset % QK != 0:
+            fout.write(struct.pack("B", 0))
+            tensor_data_offset += 1
+
+        # output unified mappable tensor data
+        if n_dims == 1 or n_parts == 1:
+            # copy tensor which we thankfully received in one piece
+            if part_id == 0:
+                fout.write(data)
+        elif split_dim == 0:
+            # reassemble multifile tensor containing some of the rows
+            rows_per_chunk = partshape[0]
+            current_row = part_id * rows_per_chunk
+            bytes_per_row = fullshape[1] // blck_size * type_size
+            offset = current_row * bytes_per_row
+            fout.seek(tensor_data_offset + offset)
+            fout.write(data)
+        elif split_dim == 1:
+            # reassemble multifile tensor containing some of the cols
+            cols_per_chunk = partshape[1]
+            current_col = part_id * cols_per_chunk
+            bpr = partshape[1] // blck_size * type_size
+            bytes_per_row = fullshape[1] // blck_size * type_size
+            offset_current_col = current_col // blck_size * type_size
+            for row in range(partshape[0]):
+                offset_row = row * bytes_per_row
+                offset = offset_row + offset_current_col
+                fout.seek(tensor_data_offset + offset)
+                fout.write(data[row * bpr:row * bpr + bpr])
+
+        # advance file position to next tensor
+        fout.seek(tensor_data_offset + ggml_nbytes(fullshape, ftype))
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Migrate from GGML to new GGJT file format')
+    parser.add_argument('fin_path', help='your old ggml file (leave out the .1 .2 etc.)')
+    parser.add_argument('fout_path', help='your new ggjt file name')
+    return parser.parse_args()
+
+def main():
+    args = parse_args()
+    assert args.fin_path
+    assert args.fout_path
+    assert args.fin_path != args.fout_path
+
+    with open(args.fin_path, "rb") as fin:
+        hparams = read_hparams(fin)
+        tokens = read_tokens(fin, hparams)
+
+    if hparams['magic'] == 0x67676a74:  # ggjt
+        print("%s: input ggml has already been converted to 'ggjt' magic\n" %
+              (args.fin_path))
+        sys.exit(1)
+
+    if hparams['magic'] != 0x67676d66:  # ggmf
+        print("%s: input ggml file doesn't have expected 'ggmf' magic: %#x\n" %
+              (args.fin_path, hparams['magic']))
+        sys.exit(1)
+
+    hparams['magic'] = 0x67676a74  # ggjt
+
+    # count number of multipart files by convention
+    n_parts = 1
+    while True:
+        if os.path.exists("%s.%d" % (args.fin_path, n_parts)):
+            n_parts += 1
+        else:
+            break
+
+    # we output a single file for ggml
+    with open(args.fout_path, "wb") as fout:
+        write_hparams(fout, hparams)
+        write_tokens(fout, tokens)
+        offset_of_tensors = fout.tell()
+        # the tensors we load could be split across multiple files
+        for part_id in range(n_parts):
+            fout.seek(offset_of_tensors)
+            print(f"Processing part {part_id+1} of {n_parts}\n")
+            fin_path = args.fin_path
+            if part_id > 0:
+                fin_path += ".%d" % (part_id)
+            with open(fin_path, "rb") as fin:
+                read_tokens(fin, read_hparams(fin))
+                copy_tensors(fin, fout, part_id, n_parts)
+
+    print(f"Done. Output file: {args.fout_path}\n")
+
+if __name__ == "__main__":
+    main()
diff --git a/models/ggml-vocab.bin b/models/ggml-vocab.bin
index 3651f708e80ea..38f63493a97a7 100644
Binary files a/models/ggml-vocab.bin and b/models/ggml-vocab.bin differ
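
For reference, every converter and loader touched by this change applies the same ggjt alignment rule: each tensor's data must start on a 32-byte boundary, computed either arithmetically as `(offset + 31) & -32` (readers) or by writing zero padding bytes until the file position is a multiple of QK = 32 (writers). The following is a minimal standalone Python sketch of those two forms, taken from the expressions above; the helper names `align_offset` and `pad_to_alignment` are illustrative, not part of the patch.

import struct

QK = 32  # ggjt aligns tensor data to 32-byte boundaries

def align_offset(offset):
    # arithmetic form used when reading/mapping: round up to the next multiple of 32
    return (offset + 31) & -32

def pad_to_alignment(fout):
    # writer-side form used by the conversion scripts: emit zero bytes until aligned
    while fout.tell() % QK != 0:
        fout.write(struct.pack("B", 0))

# sanity checks on the arithmetic form
assert align_offset(0) == 0
assert align_offset(1) == 32
assert align_offset(32) == 32
assert align_offset(33) == 64

Both forms produce the same layout, which is what lets llama.cpp point `tensor->data` directly into the mmap()'d file instead of copying each tensor.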