Skip to content

Commit da424b6

Browse files
committed
llama : gguf_file_saver write I32
1 parent 9574f41 commit da424b6

File tree

3 files changed

+30
-11
lines changed

3 files changed

+30
-11
lines changed

ggml.c

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19039,16 +19039,20 @@ enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i) {
1903919039
return ctx->header.kv[i].value.arr.type;
1904019040
}
1904119041

19042-
const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) {
19043-
struct gguf_kv * kv = &ctx->header.kv[key_id];
19044-
struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
19045-
return str->data;
19042+
// Return element `i` of the array value held by key-value entry `key_id`,
// reading the array payload as int32_t.
// No check of the stored element type or of `i` against the array length
// is done here; presumably the caller verifies both first (e.g. via
// gguf_get_arr_type / gguf_get_arr_n) - TODO confirm.
int32_t gguf_get_arr_i32(struct gguf_context * ctx, int key_id, int i) {
19043+
return ((int32_t *) ctx->header.kv[key_id].value.arr.data)[i];
1904619044
}
1904719045

1904819046
// Return element `i` of the array value held by key-value entry `key_id`,
// reading the array payload as float.
// No check of the stored element type or of `i` against the array length
// is done here.
float gguf_get_arr_f32(struct gguf_context * ctx, int key_id, int i) {
1904919047
return ((float *) ctx->header.kv[key_id].value.arr.data)[i];
1905019048
}
1905119049

19050+
// Return the character data of element `i` of the string array held by
// key-value entry `key_id`. The array payload is read as an array of
// struct gguf_str and the element's `data` pointer is returned as-is
// (no copy is made here).
// No check of the stored element type or of `i` against the array length
// is done here.
const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) {
19051+
struct gguf_kv * kv = &ctx->header.kv[key_id];
19052+
struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
19053+
return str->data;
19054+
}
19055+
1905219056
// Return the `n` field of the array value held by key-value entry `i`
// (presumably the number of elements in that array - TODO confirm).
// The entry is not checked to actually be an array here.
int gguf_get_arr_n(struct gguf_context * ctx, int i) {
1905319057
return ctx->header.kv[i].value.arr.n;
1905419058
}

ggml.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1751,8 +1751,9 @@ extern "C" {
17511751
GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);
17521752
GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i);
17531753

1754-
GGML_API const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i);
17551754
GGML_API float gguf_get_arr_f32(struct gguf_context * ctx, int key_id, int i);
1755+
GGML_API int32_t gguf_get_arr_i32(struct gguf_context * ctx, int key_id, int i);
1756+
GGML_API const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i);
17561757

17571758
GGML_API uint8_t gguf_get_val_u8 (struct gguf_context * ctx, int i);
17581759
GGML_API int8_t gguf_get_val_i8 (struct gguf_context * ctx, int i);

gguf-llama.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -737,11 +737,24 @@ struct gguf_file_saver {
737737
file.write_arr<float>(key, type, data);
738738
}
739739

740+
// Write an int32 array key-value pair to the output file.
//   key   - name written for the entry
//   type  - array element type tag forwarded unchanged to file.write_arr
//   i     - index of the source key-value entry in `ctx`
//   n_arr - number of elements to read from the source array
// The elements are fetched one at a time with gguf_get_arr_i32 into a
// temporary vector, which is then handed to file.write_arr<int32_t>.
// NOTE(review): `ctx` and `file` are not declared in this block; presumably
// they are members of the enclosing gguf_file_saver - confirm.
void write_kv_arr_i32(const std::string & key, enum gguf_type type, int i, int n_arr) {
741+
std::vector<int32_t> data(n_arr);
742+
743+
for (int j = 0; j < n_arr; ++j) {
744+
int32_t val = gguf_get_arr_i32(ctx, i, j);
745+
data[j] = val;
746+
}
747+
748+
file.write_arr<int32_t>(key, type, data);
749+
}
750+
740751
// re-write the key-value section from the loaded file
741752
void write_kv() {
742753
const int32_t n_kv = gguf_get_n_kv(ctx);
743754
for (int i = 0; i < n_kv; ++i) {
744755
const char * key = gguf_get_key(ctx, i);
756+
LLAMA_LOG_INFO("%s: writing key '%s'\n", __func__, key);
757+
745758
if (strcmp(key, "general.quantization_version") == 0) {
746759
file.write_val<uint32_t>("general.quantization_version", GGUF_TYPE_UINT32, GGML_QNT_VERSION);
747760
} else {
@@ -761,12 +774,13 @@ struct gguf_file_saver {
761774
{
762775
const gguf_type arr_type = gguf_get_arr_type(ctx, i);
763776
const int n_arr = gguf_get_arr_n (ctx, i);
764-
if (arr_type == GGUF_TYPE_FLOAT32) {
765-
write_kv_arr_f32(key, arr_type, i, n_arr);
766-
} else if (arr_type == GGUF_TYPE_STRING) {
767-
write_kv_arr_str(key, arr_type, i, n_arr);
768-
} else {
769-
throw std::runtime_error("not implemented");
777+
778+
switch (arr_type) {
779+
case GGUF_TYPE_FLOAT32: write_kv_arr_f32(key, arr_type, i, n_arr); break;
780+
case GGUF_TYPE_INT32: write_kv_arr_i32(key, arr_type, i, n_arr); break;
781+
case GGUF_TYPE_STRING: write_kv_arr_str(key, arr_type, i, n_arr); break;
782+
default:
783+
throw std::runtime_error(format("cannot recognize array type for key %s\n", key));
770784
}
771785
} break;
772786
default:

0 commit comments

Comments
 (0)