From 8fa413d8b5dbb7985c8ea59e71b3cf00720fd419 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Wed, 1 May 2024 15:15:12 +0800 Subject: [PATCH 01/22] add phi3 128k support in convert-hf-to-gguf --- convert-hf-to-gguf.py | 45 ++++++++++++++++++++++++++++++++----- gguf-py/gguf/constants.py | 16 ++++++++----- gguf-py/gguf/gguf_writer.py | 9 ++++++++ 3 files changed, 59 insertions(+), 11 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 1acf45bf2f48e..8cb21d2238c3e 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -14,6 +14,7 @@ from hashlib import sha256 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast +import math import numpy as np import torch @@ -1784,23 +1785,57 @@ def set_vocab(self): def set_gguf_parameters(self): block_count = self.find_hparam(["num_hidden_layers", "n_layer"]) - rot_pct = 1.0 n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) rms_eps = self.find_hparam(["rms_norm_eps"]) + max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) + orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) + rope_dims = n_embd // n_head self.gguf_writer.add_name("Phi3") - self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"])) - + self.gguf_writer.add_context_length(max_pos_embds) + self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds) self.gguf_writer.add_embedding_length(n_embd) - self.gguf_writer.add_feed_forward_length(8192) + self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"])) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head) self.gguf_writer.add_layer_norm_rms_eps(rms_eps) - self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) + self.gguf_writer.add_rope_dimension_count(rope_dims) + self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"])) self.gguf_writer.add_file_type(self.ftype) + # write rope scaling for long context (128k) model + rope_scaling = self.find_hparam(['rope_scaling'], True) + if (rope_scaling is None): + return + + scale = max_pos_embds / orig_max_pos_embds + + rope_scaling_type = rope_scaling.get('type', '').lower() + if len(rope_scaling_type) == 0: + raise KeyError(f'Missing the required key rope_scaling.type') + + if rope_scaling_type == 'su': + attn_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos_embds)) if scale > 1.0 else 1.0 + elif rope_scaling_type == 'yarn': + attn_factor = 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0 + else: + raise NotImplementedError(f'The rope scaling type {rope_scaling_type} is not supported yet') + + self.gguf_writer.add_rope_scaling_attn_factors(attn_factor) + + long_factors = rope_scaling.get('long_factor', None) + short_factors = rope_scaling.get('short_factor', None) + + if long_factors is None or short_factors is None: + raise KeyError(f'Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') + + if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: + raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') + + self.gguf_writer.add_rope_scaling_freq_long_factors(long_factors) + self.gguf_writer.add_rope_scaling_freq_short_factors(short_factors) @Model.register("PlamoForCausalLM") class PlamoModel(Model): diff --git 
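A note on the converter math above: the 'su' branch matches the LongRoPE-style attention scaling used by Phi-3, and the factor arrays must carry one entry per rotated dimension pair. The following is a minimal Python sketch that restates what set_gguf_parameters computes; the function and argument names are invented for illustration, only the formulas and the length check come from the hunk above.

import math

def phi3_rope_scaling_params(max_pos, orig_max_pos, rope_scaling, rope_dims):
    # Restate the values the converter writes to GGUF (sketch only).
    scale = max_pos / orig_max_pos
    stype = rope_scaling.get('type', '').lower()

    if stype == 'su':    # Phi-3 "su" (LongRoPE-style) scaling
        attn_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos)) if scale > 1.0 else 1.0
    elif stype == 'yarn':
        attn_factor = 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0
    else:
        raise NotImplementedError(f'rope scaling type {stype!r} is not supported')

    long_factors = rope_scaling['long_factor']
    short_factors = rope_scaling['short_factor']
    # one frequency factor per rotated dimension pair (rope_dims == n_embd // n_head)
    assert len(long_factors) == len(short_factors) == rope_dims // 2
    return attn_factor, long_factors, short_factors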
a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 692120f4d64b0..79d3033ef0189 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -57,12 +57,15 @@ class Attention: CAUSAL = "{arch}.attention.causal" class Rope: - DIMENSION_COUNT = "{arch}.rope.dimension_count" - FREQ_BASE = "{arch}.rope.freq_base" - SCALING_TYPE = "{arch}.rope.scaling.type" - SCALING_FACTOR = "{arch}.rope.scaling.factor" - SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" - SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + DIMENSION_COUNT = "{arch}.rope.dimension_count" + FREQ_BASE = "{arch}.rope.freq_base" + SCALING_TYPE = "{arch}.rope.scaling.type" + SCALING_FACTOR = "{arch}.rope.scaling.factor" + SCALING_LONG_FACTORS = "{arch}.rope.scaling.freq_long_factors" + SCALING_SHORT_FACTORS = "{arch}.rope.scaling.freq_short_factors" + SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" + SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" + SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" class SSM: CONV_KERNEL = "{arch}.ssm.conv_kernel" @@ -780,6 +783,7 @@ class RopeScalingType(Enum): NONE = 'none' LINEAR = 'linear' YARN = 'yarn' + SU = 'su' class PoolingType(IntEnum): diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index d5e323a52ef14..e583fc9b50853 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -433,6 +433,15 @@ def add_rope_scaling_type(self, value: RopeScalingType) -> None: def add_rope_scaling_factor(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value) + def add_rope_scaling_freq_long_factors(self, value: Sequence[float]) -> None: + self.add_array(Keys.Rope.SCALING_LONG_FACTORS.format(arch=self.arch), value) + + def add_rope_scaling_freq_short_factors(self, value: Sequence[float]) -> None: + self.add_array(Keys.Rope.SCALING_SHORT_FACTORS.format(arch=self.arch), value) + + def add_rope_scaling_attn_factors(self, value: Sequence[float]) -> None: + self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value) + def add_rope_scaling_orig_ctx_len(self, value: int) -> None: self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value) From 56d9fa72de5ce22c4ecefa43a8b04996d8c9f1bc Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Thu, 2 May 2024 02:06:20 +0800 Subject: [PATCH 02/22] add phi3 128k support in cuda --- ggml-cuda/rope.cu | 64 ++++++++++++++++++++++--------- ggml.c | 73 ++++++++++++++++++++--------------- llama.cpp | 98 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 185 insertions(+), 50 deletions(-) diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu index 4b0d2e5adbbc5..cba5050fabd78 100644 --- a/ggml-cuda/rope.cu +++ b/ggml-cuda/rope.cu @@ -58,10 +58,10 @@ static __global__ void rope( dst[i + 1] = x0*sin_theta + x1*cos_theta; } -template +template static __global__ void rope_neox( const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors ) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -88,7 +88,9 @@ static __global__ void rope_neox( float cur_rot = inv_ndims * ic - ib; const int p = has_pos ? 
pos[i2] : 0; - const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f); + const float freq_factor = has_freq_facs ? freq_factors[col/2] : 1.0f; + + const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor; float cos_theta, sin_theta; rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); @@ -164,7 +166,7 @@ static void rope_cuda( template static void rope_neox_cuda( const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float* freq_factors, cudaStream_t stream ) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); @@ -175,15 +177,32 @@ static void rope_neox_cuda( const float inv_ndims = -1.0f / n_dims; if (pos == nullptr) { - rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, inv_ndims - ); - } else { - rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, inv_ndims - ); + if (freq_factors == nullptr) { + rope_neox<<>>( + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, inv_ndims, freq_factors + ); + } + else { + rope_neox<<>>( + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, inv_ndims, freq_factors + ); + } + } + else { + if (freq_factors == nullptr) { + rope_neox<<>>( + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, inv_ndims, freq_factors + ); + } + else { + rope_neox<<>>( + x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, inv_ndims, freq_factors + ); + } } } @@ -214,17 +233,17 @@ static void rope_cuda_f32( static void rope_neox_cuda_f16( const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) { + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float* freq_factors, cudaStream_t stream) { - rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream); + rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } static void rope_neox_cuda_f32( const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float* freq_factors, cudaStream_t stream ) { - rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream); + rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -259,11 +278,18 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, 
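The kernel change above boils down to dividing each pair's base angle by its frequency factor before the YaRN correction is applied. Below is a simplified NumPy sketch of NeoX-style RoPE with per-pair factors; it deliberately ignores the ext_factor/attn_factor/corr_dims handling done by rope_yarn and uses per-pair indexing (the ic/2 form adopted later in the series), so it illustrates where the factor enters rather than being a drop-in equivalent of the CUDA code.

import numpy as np

def rope_neox_with_freq_factors(x, pos, n_dims, freq_base=10000.0, freq_scale=1.0, freq_factors=None):
    # x: [n_tokens, head_dim]; rotate only the first n_dims features, NeoX-style
    # (feature i is paired with feature i + n_dims/2), leaving the rest untouched.
    x = np.asarray(x, dtype=np.float32)
    pos = np.asarray(pos, dtype=np.float32)
    half = n_dims // 2
    factors = np.ones(half, dtype=np.float32) if freq_factors is None else np.asarray(freq_factors, dtype=np.float32)

    # theta_i = pos * freq_scale * freq_base**(-2*i/n_dims) / factors[i]
    inv_freq = freq_base ** (-2.0 * np.arange(half) / n_dims) / factors
    theta = np.outer(pos * freq_scale, inv_freq)          # [n_tokens, half]
    cos, sin = np.cos(theta), np.sin(theta)

    out = x.copy()
    x0, x1 = x[:, :half], x[:, half:n_dims]
    out[:, :half]       = x0 * cos - x1 * sin
    out[:, half:n_dims] = x0 * sin + x1 * cos
    return out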
ggml_tensor * dst) { memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + const float* freq_factors = nullptr; const int32_t * pos = nullptr; if ((mode & 1) == 0) { GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(src1->ne[0] == ne2); pos = (const int32_t *) src1_d; + + if (dst->src[2] != nullptr) { + GGML_ASSERT(dst->src[2]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[2]->ne[0] >= n_dims / 2); + freq_factors = (const float*) dst->src[2]->data; + } } const bool is_neox = mode & 2; @@ -280,12 +306,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (src0->type == GGML_TYPE_F32) { rope_neox_cuda_f32( (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, stream + attn_factor, corr_dims, freq_factors, stream ); } else if (src0->type == GGML_TYPE_F16) { rope_neox_cuda_f16( (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, stream + attn_factor, corr_dims, freq_factors, stream ); } else { GGML_ASSERT(false); diff --git a/ggml.c b/ggml.c index 4bd911528586b..ecbb0db80ca48 100644 --- a/ggml.c +++ b/ggml.c @@ -6275,6 +6275,13 @@ static struct ggml_tensor * ggml_rope_impl( return result; } +struct ggml_tensor * ggml_rope_with_freq_factors( + struct ggml_tensor* rope_tensor, + struct ggml_tensor* freq_factors) { + rope_tensor->src[2] = freq_factors; + return rope_tensor; +} + struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, @@ -18915,21 +18922,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, - ggml_rope_back(ctx, - tensor->grad, - src1, - n_dims, - mode, - n_ctx, - n_orig_ctx, - freq_base, - freq_scale, - ext_factor, - attn_factor, - beta_fast, - beta_slow, - xpos_base, - xpos_down), + ggml_rope_with_freq_factors( + ggml_rope_back(ctx, + tensor->grad, + src1, + n_dims, + mode, + n_ctx, + n_orig_ctx, + freq_base, + freq_scale, + ext_factor, + attn_factor, + beta_fast, + beta_slow, + xpos_base, + xpos_down), + tensor->src[2]), zero_table); } } break; @@ -18954,22 +18963,24 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, - ggml_rope_impl(ctx, - tensor->grad, - src1, - n_dims, - mode, - n_ctx, - n_orig_ctx, - freq_base, - freq_scale, - ext_factor, - attn_factor, - beta_fast, - beta_slow, - xpos_base, - xpos_down, - false), + ggml_rope_with_freq_factors( + ggml_rope_impl(ctx, + tensor->grad, + src1, + n_dims, + mode, + n_ctx, + n_orig_ctx, + freq_base, + freq_scale, + ext_factor, + attn_factor, + beta_fast, + beta_slow, + xpos_base, + xpos_down, + false), + tensor->src[2]), zero_table); } } break; diff --git a/llama.cpp b/llama.cpp index d26fe559a2051..de2b91ac4c868 100644 --- a/llama.cpp +++ b/llama.cpp @@ -304,6 +304,9 @@ enum llm_kv { LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_LONG_FACTORS, + LLM_KV_ROPE_SCALING_SHORT_FACTORS, + LLM_KV_ROPE_SCALING_ATTN_FACTOR, LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, @@ -381,6 +384,9 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { 
LLM_KV_ROPE_SCALING_LONG_FACTORS, "%s.rope.scaling.freq_long_factors" }, + { LLM_KV_ROPE_SCALING_SHORT_FACTORS, "%s.rope.scaling.freq_short_factors" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, @@ -1754,6 +1760,10 @@ struct llama_hparams { float rope_freq_scale_train; uint32_t n_yarn_orig_ctx; + std::vector rope_long_factors; + std::vector rope_short_factors; + float rope_attn_factor = 1.0f; + // for State Space Models uint32_t ssm_d_conv = 0; uint32_t ssm_d_inner = 0; @@ -1789,6 +1799,10 @@ struct llama_hparams { if (this->rope_finetuned != other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; + if (this->rope_long_factors != other.rope_long_factors) return true; + if (this->rope_short_factors != other.rope_short_factors) return true; + if (this->rope_attn_factor != other.rope_attn_factor) return true; + if (this->ssm_d_conv != other.ssm_d_conv) return true; if (this->ssm_d_inner != other.ssm_d_inner) return true; if (this->ssm_d_state != other.ssm_d_state) return true; @@ -2246,6 +2260,8 @@ struct llama_context { struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct ggml_tensor * freq_factors = nullptr; // F32 [kv_size / 2] + // control vectors struct llama_control_vector cvec; }; @@ -3306,6 +3322,39 @@ struct llama_model_loader { return get_arr_n(llm_kv(kid), result, required); } + template + bool get_arr(const std::string& key, std::vector& result, const bool required = true) { + const int kid = gguf_find_key(meta, key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta, kid); + + if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) { + throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str())); + } + + // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T)); + GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same::value)); + GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same::value)); + + result.resize(arr_info.length); + result.assign((T*)arr_info.data, (T*)arr_info.data + arr_info.length); + + return true; + } + + template + bool get_arr(const enum llm_kv kid, T& result, const bool required = true) { + return get_arr(llm_kv(kid), result, required); + } + template bool get_key(const std::string & key, T & result, const bool required = true) { auto it = kv_overrides.find(key); @@ -3849,6 +3898,14 @@ static void llm_load_hparams( } hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; + ml.get_arr(LLM_KV_ROPE_SCALING_LONG_FACTORS, hparams.rope_long_factors, false); + ml.get_arr(LLM_KV_ROPE_SCALING_SHORT_FACTORS, hparams.rope_short_factors, false); + + GGML_ASSERT(hparams.rope_long_factors.size() == 0 || hparams.rope_long_factors.size() == hparams.n_embd / hparams.n_head / 2); + GGML_ASSERT(hparams.rope_long_factors.size() == hparams.rope_short_factors.size()); + + ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); + // sanity check for n_rot (optional) { hparams.n_rot = (hparams.n_head == 0) ? 
0 : hparams.n_embd / hparams.n_head; @@ -6821,6 +6878,8 @@ struct llm_build_context { cb(lctx.inp_K_shift, "K_shift", -1); ggml_set_input(lctx.inp_K_shift); + lctx.freq_factors = build_freq_factors(); + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * tmp = // we rotate only the first n_rot dimensions @@ -6832,6 +6891,9 @@ struct llm_build_context { 0), lctx.inp_K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + + tmp = ggml_rope_with_freq_factors(tmp, lctx.freq_factors); + cb(tmp, "K_shifted", il); ggml_build_forward_expand(gf, tmp); } @@ -6934,6 +6996,20 @@ struct llm_build_context { return lctx.inp_pos; } + struct ggml_tensor* build_freq_factors() { + + if (hparams.rope_long_factors.empty() || hparams.rope_short_factors.empty()) { + lctx.freq_factors = nullptr; + return nullptr; + } + + lctx.freq_factors = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_embd_head_k / 2); + cb(lctx.freq_factors, "freq_factors", -1); + ggml_set_input(lctx.freq_factors); + + return lctx.freq_factors; + } + struct ggml_tensor * build_inp_out_ids() { lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs); cb(lctx.inp_out_ids, "inp_out_ids", -1); @@ -9052,6 +9128,9 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + // rope freq factors for 128k context + struct ggml_tensor* freq_factors = build_freq_factors(); + for (int il = 0; il < n_layer; ++il) { auto residual = inpL; @@ -9092,6 +9171,7 @@ struct llm_build_context { ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); + Qcur = ggml_rope_with_freq_factors(Qcur, freq_factors); cb(Qcur, "Qcur", il); Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); @@ -9101,6 +9181,7 @@ struct llm_build_context { ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); + Kcur = ggml_rope_with_freq_factors(Kcur, freq_factors); cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, @@ -10890,6 +10971,22 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } + if (lctx.freq_factors) { + auto freq_dim = hparams.n_embd_head_k / 2; + + GGML_ASSERT(lctx.freq_factors->ne[0] == freq_dim); + GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim); + GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim); + + auto max_pos = batch.n_tokens > 0 && batch.pos != nullptr ? *std::max_element(batch.pos, batch.pos + batch.n_tokens) : batch.n_tokens - 1; + if (max_pos + 1 > hparams.n_yarn_orig_ctx) { + ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); + } + else { + ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); + } + } + if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { const int64_t n_tokens = batch.n_tokens; @@ -15417,6 +15514,7 @@ struct llama_context * llama_new_context_with_model( cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 
1.0f : 0.0f; } + cparams.yarn_attn_factor *= hparams.rope_attn_factor; cparams.causal_attn = hparams.causal_attn; if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { From cc19780a557020a0988a6da847b7f243555f1a3d Mon Sep 17 00:00:00 2001 From: liuwei Date: Wed, 1 May 2024 18:50:10 +0000 Subject: [PATCH 03/22] address build warnings on llama.cpp --- llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index de2b91ac4c868..dd42bd9059d5e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3345,7 +3345,7 @@ struct llama_model_loader { GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same::value)); result.resize(arr_info.length); - result.assign((T*)arr_info.data, (T*)arr_info.data + arr_info.length); + result.assign((const T*)arr_info.data, (const T*)arr_info.data + arr_info.length); return true; } @@ -10979,7 +10979,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim); auto max_pos = batch.n_tokens > 0 && batch.pos != nullptr ? *std::max_element(batch.pos, batch.pos + batch.n_tokens) : batch.n_tokens - 1; - if (max_pos + 1 > hparams.n_yarn_orig_ctx) { + if ((uint32_t)(max_pos + 1) > hparams.n_yarn_orig_ctx) { ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); } else { From 9f871298b63c02b45c0d36e0ca88e14f105235a0 Mon Sep 17 00:00:00 2001 From: liuwei Date: Sat, 11 May 2024 10:38:22 +0000 Subject: [PATCH 04/22] adjust index value in cuda long rope freq factors --- ggml-cuda/rope.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu index cba5050fabd78..fee56a1f27e69 100644 --- a/ggml-cuda/rope.cu +++ b/ggml-cuda/rope.cu @@ -88,7 +88,7 @@ static __global__ void rope_neox( float cur_rot = inv_ndims * ic - ib; const int p = has_pos ? pos[i2] : 0; - const float freq_factor = has_freq_facs ? freq_factors[col/2] : 1.0f; + const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f; const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor; From c5569311a4307ffd4d017a6595822d55e77fb777 Mon Sep 17 00:00:00 2001 From: liuwei Date: Sat, 11 May 2024 10:44:31 +0000 Subject: [PATCH 05/22] add long rope support in ggml cpu backend --- ggml.c | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index ecbb0db80ca48..169f98b2318db 100644 --- a/ggml.c +++ b/ggml.c @@ -14370,6 +14370,15 @@ static void ggml_compute_forward_rope_f32( const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const float* freq_factors = NULL; + if (is_neox) { + if (dst->src[2] != NULL) { + GGML_ASSERT(dst->src[2]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[2]->ne[0] >= n_dims / 2); + freq_factors = (const float*) dst->src[2]->data; + } + } + // backward process uses inverse rotation by cos and sin. // cos and sin build a rotation matrix, where the inverse is the transpose. // this essentially just switches the sign of sin. @@ -14446,10 +14455,11 @@ static void ggml_compute_forward_rope_f32( // simplified from `(ib * n_dims + ic) * inv_ndims` float cur_rot = inv_ndims * ic - ib; + float freq_factor = freq_factors ? 
freq_factors[ic/2] : 1.0f; float cos_theta, sin_theta; rope_yarn( - theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, + theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta ); sin_theta *= sin_sign; From 6333ed1a3078ac59240b3dc5338f1043ecafba97 Mon Sep 17 00:00:00 2001 From: liuwei Date: Sat, 11 May 2024 17:25:08 +0000 Subject: [PATCH 06/22] make freq factors only depend on ctx size --- llama.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index dd42bd9059d5e..3bfe2910b1a51 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10978,12 +10978,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim); GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim); - auto max_pos = batch.n_tokens > 0 && batch.pos != nullptr ? *std::max_element(batch.pos, batch.pos + batch.n_tokens) : batch.n_tokens - 1; - if ((uint32_t)(max_pos + 1) > hparams.n_yarn_orig_ctx) { + auto n_ctx = llama_n_ctx(&lctx); + if (n_ctx > hparams.n_yarn_orig_ctx) { ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); } else { - ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); + ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); } } From 5683db3bf70a9b56562c8b3cfd4dcd6994bc7647 Mon Sep 17 00:00:00 2001 From: liuwei Date: Sat, 11 May 2024 17:55:41 +0000 Subject: [PATCH 07/22] remove unused rope scaling type 'su' frin gguf converter --- gguf-py/gguf/constants.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 79d3033ef0189..75a129c1ab802 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -783,7 +783,6 @@ class RopeScalingType(Enum): NONE = 'none' LINEAR = 'linear' YARN = 'yarn' - SU = 'su' class PoolingType(IntEnum): From b1f491a2978cbc60931538fd8bac5b9f788460d9 Mon Sep 17 00:00:00 2001 From: liuwei Date: Sat, 11 May 2024 19:16:45 +0000 Subject: [PATCH 08/22] fix flint warnings on convert-hf-to-gguf.py --- convert-hf-to-gguf.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 8cb21d2238c3e..2c4764c5ec921 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1790,7 +1790,7 @@ def set_gguf_parameters(self): rms_eps = self.find_hparam(["rms_norm_eps"]) max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) - rope_dims = n_embd // n_head + rope_dims = n_embd // n_head self.gguf_writer.add_name("Phi3") self.gguf_writer.add_context_length(max_pos_embds) @@ -1814,7 +1814,7 @@ def set_gguf_parameters(self): rope_scaling_type = rope_scaling.get('type', '').lower() if len(rope_scaling_type) == 0: - raise KeyError(f'Missing the required key rope_scaling.type') + raise KeyError('Missing the required key rope_scaling.type') if rope_scaling_type == 'su': attn_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos_embds)) if scale > 1.0 else 1.0 @@ -1829,7 +1829,7 @@ def set_gguf_parameters(self): short_factors = rope_scaling.get('short_factor', None) if long_factors is None or short_factors is None: - raise KeyError(f'Missing the required key 
rope_scaling.long_factor or rope_scaling_short_factor') + raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') @@ -1837,6 +1837,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_freq_long_factors(long_factors) self.gguf_writer.add_rope_scaling_freq_short_factors(short_factors) + @Model.register("PlamoForCausalLM") class PlamoModel(Model): model_arch = gguf.MODEL_ARCH.PLAMO From d05ae12e93d6b400e533f165d673a14aef0e12b3 Mon Sep 17 00:00:00 2001 From: liuwei Date: Sat, 11 May 2024 20:30:32 +0000 Subject: [PATCH 09/22] set to the short freq factor when context size is small than trained context size --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 3bfe2910b1a51..ef9809ecd2347 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10983,7 +10983,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); } else { - ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); + ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); } } From 8a9c897fd0e169dccec3b173885d0b843bb121ba Mon Sep 17 00:00:00 2001 From: liuwei Date: Sat, 11 May 2024 21:42:32 +0000 Subject: [PATCH 10/22] add one line of comments --- llama.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/llama.cpp b/llama.cpp index ef9809ecd2347..a305930878e96 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10978,6 +10978,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim); GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim); + // choose long/short freq factors based on the context size auto n_ctx = llama_n_ctx(&lctx); if (n_ctx > hparams.n_yarn_orig_ctx) { ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); From 2d473a4a9a62c3c11841a8f2274376e78ee727a6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 16 May 2024 12:03:53 +0300 Subject: [PATCH 11/22] metal : support rope freq_factors --- ggml-cuda/rope.cu | 26 +++++---- ggml-metal.m | 131 +++++++++++++++++++++++++++------------------- ggml-metal.metal | 6 ++- ggml.c | 13 +++-- 4 files changed, 108 insertions(+), 68 deletions(-) diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu index fee56a1f27e69..8d11d9d14b86f 100644 --- a/ggml-cuda/rope.cu +++ b/ggml-cuda/rope.cu @@ -249,8 +249,11 @@ static void rope_neox_cuda_f32( void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const float * src0_d = (const float *)src0->data; const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); @@ -278,23 +281,28 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - const float* 
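Patches 06, 09 and 10 together settle how the factors are chosen at runtime: llama_set_inputs uploads the long factors only when the configured context exceeds the trained context, and the short factors otherwise. As a hedged Python restatement (not code from the tree):

def pick_rope_freq_factors(n_ctx, n_orig_ctx, long_factors, short_factors):
    # long factors only when the runtime context exceeds the trained context
    return long_factors if n_ctx > n_orig_ctx else short_factors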
freq_factors = nullptr; + const float * freq_factors = nullptr; const int32_t * pos = nullptr; - if ((mode & 1) == 0) { + + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; + + if (is_neox) { + // TODO: move these asserts to ggml.c GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(src1->ne[0] == ne2); pos = (const int32_t *) src1_d; - if (dst->src[2] != nullptr) { - GGML_ASSERT(dst->src[2]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->src[2]->ne[0] >= n_dims / 2); - freq_factors = (const float*) dst->src[2]->data; + if (src2 != nullptr) { + // TODO: move these asserts to ggml.c + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float*) src2->data; } + } else { + GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for mode 1"); } - const bool is_neox = mode & 2; - const bool is_glm = mode & 4; - rope_corr_dims corr_dims; ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v); diff --git a/ggml-metal.m b/ggml-metal.m index b0b16dbf77160..7bc75f39b4bc8 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -927,22 +927,32 @@ static enum ggml_status ggml_metal_graph_compute( const int64_t ne10 = src1 ? src1->ne[0] : 0; const int64_t ne11 = src1 ? src1->ne[1] : 0; const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + const int64_t ne13 = src1 ? src1->ne[3] : 0; const uint64_t nb10 = src1 ? src1->nb[0] : 0; const uint64_t nb11 = src1 ? src1->nb[1] : 0; const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + const uint64_t nb13 = src1 ? src1->nb[3] : 0; - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; + const int64_t ne20 = src2 ? src2->ne[0] : 0; + const int64_t ne21 = src2 ? src2->ne[1] : 0; + const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); + const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; + const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; + + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; + + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; const enum ggml_type src1t = src1 ? 
src1->type : GGML_TYPE_COUNT; @@ -1785,16 +1795,6 @@ static enum ggml_status ggml_metal_graph_compute( const int n_as = src0->ne[2]; // src2 = ids - const int64_t ne20 = src2->ne[0]; - const int64_t ne21 = src2->ne[1]; - const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22); - const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23); - - const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20); - const uint64_t nb21 = src2->nb[1]; - const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22); - const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23); - const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); GGML_ASSERT(src2t == GGML_TYPE_I32); @@ -2244,7 +2244,13 @@ static enum ggml_status ggml_metal_graph_compute( // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -2252,6 +2258,25 @@ static enum ggml_status ggml_metal_graph_compute( memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; + + GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal"); + + if (is_neox) { + // TODO: move these asserts to ggml.c + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(src1->ne[0] == ne2); + + if (id_src2 != nil) { + // TODO: move these asserts to ggml.c + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + } + } else { + GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for mode 1"); + } + id pipeline = nil; switch (src0->type) { @@ -2263,33 +2288,38 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:19]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:20]; - [encoder setBytes:&mode length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base 
length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + if (id_src2 != nil) { + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; + [encoder setBytes:&mode length:sizeof( int) atIndex:22]; + [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:24]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -2535,11 +2565,6 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); - const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; - const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); //const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? 
src3->ne[2] : 0; GGML_UNUSED(ne32); diff --git a/ggml-metal.metal b/ggml-metal.metal index 386e9195fcffa..44e54ced2a545 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1640,6 +1640,7 @@ static void rope_yarn_corr_dims( typedef void (rope_t)( device const void * src0, device const int32_t * src1, + device const float * src2, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1675,6 +1676,7 @@ template kernel void kernel_rope( device const void * src0, device const int32_t * src1, + device const float * src2, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1744,8 +1746,10 @@ kernel void kernel_rope( // simplified from `(ib * n_dims + ic) * inv_ndims` const float cur_rot = inv_ndims*ic - ib; + const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f; + + const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor; - const float theta = theta_0 * pow(freq_base, cur_rot); float cos_theta, sin_theta; rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); diff --git a/ggml.c b/ggml.c index 169f98b2318db..fa4efccee5f4b 100644 --- a/ggml.c +++ b/ggml.c @@ -14311,6 +14311,7 @@ static void ggml_compute_forward_rope_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; @@ -14370,13 +14371,15 @@ static void ggml_compute_forward_rope_f32( const bool is_neox = mode & 2; const bool is_glm = mode & 4; - const float* freq_factors = NULL; + const float * freq_factors = NULL; if (is_neox) { - if (dst->src[2] != NULL) { - GGML_ASSERT(dst->src[2]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->src[2]->ne[0] >= n_dims / 2); - freq_factors = (const float*) dst->src[2]->data; + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; } + } else { + GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1"); } // backward process uses inverse rotation by cos and sin. From 471d8170bcea85db5ed87579cada83e077c79b44 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 16 May 2024 13:23:04 +0300 Subject: [PATCH 12/22] ggml : update ggml_rope_ext API to support freq. 
factors --- examples/finetune/finetune.cpp | 4 +- .../train-text-from-scratch.cpp | 4 +- ggml-cuda/rope.cu | 14 +- ggml.c | 138 ++++++++----- ggml.h | 45 ++++- llama.cpp | 184 +++++++++--------- 6 files changed, 219 insertions(+), 170 deletions(-) diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 22743b1bf02fd..992426c1b69e2 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -563,8 +563,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs( // not capturing these, to silcence warnings const int rope_mode = 0; - return ggml_rope_custom(ctx, - t, KQ_pos, n_rot, rope_mode, n_ctx, 0, + return ggml_rope_ext(ctx, + t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f ); }; diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 587418cc73964..45bdfa8f5d80c 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -301,8 +301,8 @@ static struct ggml_tensor * llama_build_train_graphs( // not capturing these, to silcence warnings const int rope_mode = 0; - return ggml_rope_custom( - ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f + return ggml_rope_ext( + ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f ); }; diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu index 8d11d9d14b86f..e7968817f15e2 100644 --- a/ggml-cuda/rope.cu +++ b/ggml-cuda/rope.cu @@ -166,7 +166,7 @@ static void rope_cuda( template static void rope_neox_cuda( const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float* freq_factors, cudaStream_t stream + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream ) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); @@ -233,14 +233,14 @@ static void rope_cuda_f32( static void rope_neox_cuda_f16( const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float* freq_factors, cudaStream_t stream) { + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } static void rope_neox_cuda_f32( const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float* freq_factors, cudaStream_t stream + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream ) { rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); @@ -288,16 +288,10 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const bool is_glm = mode & 4; if (is_neox) { - // TODO: move these 
asserts to ggml.c - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(src1->ne[0] == ne2); pos = (const int32_t *) src1_d; if (src2 != nullptr) { - // TODO: move these asserts to ggml.c - GGML_ASSERT(src2->type == GGML_TYPE_F32); - GGML_ASSERT(src2->ne[0] >= n_dims / 2); - freq_factors = (const float*) src2->data; + freq_factors = (const float *) src2->data; } } else { GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for mode 1"); diff --git a/ggml.c b/ggml.c index fa4efccee5f4b..37b16b7a9ce7f 100644 --- a/ggml.c +++ b/ggml.c @@ -6231,6 +6231,7 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -6248,6 +6249,11 @@ static struct ggml_tensor * ggml_rope_impl( GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(a->ne[2] == b->ne[0]); + if (c) { + GGML_ASSERT(c->type == GGML_TYPE_F32); + GGML_ASSERT(c->ne[0] >= n_dims / 2); + } + bool is_node = false; if (a->grad) { @@ -6271,17 +6277,11 @@ static struct ggml_tensor * ggml_rope_impl( result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; + result->src[2] = c; return result; } -struct ggml_tensor * ggml_rope_with_freq_factors( - struct ggml_tensor* rope_tensor, - struct ggml_tensor* freq_factors) { - rope_tensor->src[2] = freq_factors; - return rope_tensor; -} - struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, @@ -6290,7 +6290,7 @@ struct ggml_tensor * ggml_rope( int mode, int n_ctx) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false + ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false ); } @@ -6302,14 +6302,15 @@ struct ggml_tensor * ggml_rope_inplace( int mode, int n_ctx) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true + ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true ); } -struct ggml_tensor * ggml_rope_custom( +struct ggml_tensor * ggml_rope_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -6321,15 +6322,16 @@ struct ggml_tensor * ggml_rope_custom( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false ); } -struct ggml_tensor * ggml_rope_custom_inplace( +struct ggml_tensor * ggml_rope_ext_inplace( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -6341,19 +6343,49 @@ struct ggml_tensor * ggml_rope_custom_inplace( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true ); } -struct ggml_tensor * ggml_rope_xpos_inplace( +struct ggml_tensor * ggml_rope_custom( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - float base, - bool down) { - return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true); + 
int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false + ); +} + +struct ggml_tensor * ggml_rope_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true + ); } // ggml_rope_back @@ -6362,6 +6394,7 @@ struct ggml_tensor * ggml_rope_back( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -6377,6 +6410,7 @@ struct ggml_tensor * ggml_rope_back( GGML_ASSERT(ggml_is_vector(b)); GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(a->ne[2] == b->ne[0]); + GGML_ASSERT(c == NULL && "freq factors not implemented yet"); GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); @@ -18407,6 +18441,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) { struct ggml_tensor * src0 = tensor->src[0]; struct ggml_tensor * src1 = tensor->src[1]; + struct ggml_tensor * src2 = tensor->src[2]; switch (tensor->op) { case GGML_OP_DUP: @@ -18935,23 +18970,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, - ggml_rope_with_freq_factors( - ggml_rope_back(ctx, - tensor->grad, - src1, - n_dims, - mode, - n_ctx, - n_orig_ctx, - freq_base, - freq_scale, - ext_factor, - attn_factor, - beta_fast, - beta_slow, - xpos_base, - xpos_down), - tensor->src[2]), + ggml_rope_back(ctx, + tensor->grad, + src1, + src2, + n_dims, + mode, + n_ctx, + n_orig_ctx, + freq_base, + freq_scale, + ext_factor, + attn_factor, + beta_fast, + beta_slow, + xpos_base, + xpos_down), zero_table); } } break; @@ -18976,24 +19010,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, - ggml_rope_with_freq_factors( - ggml_rope_impl(ctx, - tensor->grad, - src1, - n_dims, - mode, - n_ctx, - n_orig_ctx, - freq_base, - freq_scale, - ext_factor, - attn_factor, - beta_fast, - beta_slow, - xpos_base, - xpos_down, - false), - tensor->src[2]), + ggml_rope_impl(ctx, + tensor->grad, + src1, + src2, + n_dims, + mode, + n_ctx, + n_orig_ctx, + freq_base, + freq_scale, + ext_factor, + attn_factor, + beta_fast, + beta_slow, + xpos_base, + xpos_down, + false), zero_table); } } break; @@ -19062,7 +19095,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor masked); } - struct ggml_tensor * src2 = tensor->src[2]; const int64_t elem_q = ggml_nelements(src0); const int64_t elem_k = ggml_nelements(src1); const int64_t elem_v = ggml_nelements(src2); diff --git a/ggml.h b/ggml.h index 77475710129d7..35ac9110ceb17 100644 --- a/ggml.h +++ b/ggml.h @@ -1465,6 +1465,7 @@ extern "C" { // if mode & 4 == 1, ChatGLM style // // b is an int32 vector with size a->ne[2], it contains the positions + 
// c is freq factors (e.g. phi3-128k), (optional) GGML_API struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1483,10 +1484,11 @@ extern "C" { int n_ctx); // custom RoPE - GGML_API struct ggml_tensor * ggml_rope_custom( + GGML_API struct ggml_tensor * ggml_rope_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -1499,10 +1501,11 @@ extern "C" { float beta_slow); // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_rope_custom_inplace( + GGML_API struct ggml_tensor * ggml_rope_ext_inplace( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, @@ -1514,18 +1517,41 @@ extern "C" { float beta_fast, float beta_slow); - // compute correction dims for YaRN RoPE scaling - GGML_CALL void ggml_rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow), + "use ggml_rope_ext instead"); - // xPos RoPE, in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - float base, - bool down); + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow), + "use ggml_rope_ext_inplace instead"); + + // compute correction dims for YaRN RoPE scaling + GGML_CALL void ggml_rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); // rotary position embedding backward, i.e compute dx from dy // a - dy @@ -1533,6 +1559,7 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, int n_ctx, diff --git a/llama.cpp b/llama.cpp index a305930878e96..a08941af10cf2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6883,17 +6883,15 @@ struct llm_build_context { for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * tmp = // we rotate only the first n_rot dimensions - ggml_rope_custom_inplace(ctx0, + ggml_rope_ext_inplace(ctx0, ggml_view_3d(ctx0, kv_self.k_l[il], n_embd_head_k, n_head_kv, n_ctx, ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), 0), - lctx.inp_K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + lctx.inp_K_shift, lctx.freq_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); - tmp = ggml_rope_with_freq_factors(tmp, lctx.freq_factors); - cb(tmp, "K_shifted", il); ggml_build_forward_expand(gf, tmp); } @@ -7117,15 +7115,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, 
beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -7247,13 +7245,13 @@ struct llm_build_context { switch (model.type) { case MODEL_7B: - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -7359,15 +7357,15 @@ struct llm_build_context { struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -7480,14 +7478,14 @@ struct llm_build_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -7603,15 +7601,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -7755,15 +7753,15 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( 
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8108,15 +8106,15 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8548,15 +8546,15 @@ struct llm_build_context { } - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8668,14 +8666,14 @@ struct llm_build_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -8779,15 +8777,15 @@ struct llm_build_context { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8893,15 +8891,15 @@ struct llm_build_context { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, 
Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -9045,8 +9043,8 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); @@ -9056,8 +9054,8 @@ struct llm_build_context { Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9167,21 +9165,19 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, freq_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Qcur = ggml_rope_with_freq_factors(Qcur, freq_factors); cb(Qcur, "Qcur", il); Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, freq_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_with_freq_factors(Kcur, freq_factors); cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, @@ -9285,14 +9281,14 @@ struct llm_build_context { struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr, n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr, n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); @@ -9493,15 +9489,15 @@ struct llm_build_context { cb(tmpk, "tmpk", il); cb(Vcur, "Vcur", il); - struct ggml_tensor * Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, + struct ggml_tensor * Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, 
n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, + struct ggml_tensor * Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -9609,15 +9605,15 @@ struct llm_build_context { // cb(Vcur, "Vcur", il); // } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -9726,15 +9722,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -9856,15 +9852,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -9976,8 +9972,8 @@ struct llm_build_context { struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr, n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); @@ -9985,8 +9981,8 @@ struct llm_build_context { Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); cb(Qcur, "Qcur_scaled", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), 
inp_pos, nullptr, n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); @@ -10096,15 +10092,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -10386,15 +10382,15 @@ struct llm_build_context { cb(Kcur, "Kcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -10517,15 +10513,15 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); From 352c3859a7e5351e5b99e7fecd0946d1eeab86d9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 16 May 2024 13:23:30 +0300 Subject: [PATCH 13/22] backends : add dev messages to support rope freq. 
factors --- ggml-kompute.cpp | 4 ++++ ggml-sycl.cpp | 3 +++ ggml-vulkan.cpp | 4 ++++ 3 files changed, 11 insertions(+) diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 3f033d58be481..f03f27bea8e9f 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1677,6 +1677,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml } break; case GGML_OP_ROPE: { +#pragma message("TODO: implement phi3 frequency factors support"); +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225"); + GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); + GGML_ASSERT(ne10 == ne02); GGML_ASSERT(src0t == dstt); // const int n_past = ((int32_t *) dst->op_params)[0]; diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index eac8f55796735..2fa5e18c41cc4 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -14454,6 +14454,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const dpct::queue_ptr &main_stream) { +#pragma message("TODO: implement phi3 frequency factors support"); +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225"); + GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index aff451b6354e5..fce397bfaae80 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -4238,6 +4238,10 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#pragma message("TODO: implement phi3 frequency factors support"); +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225"); + GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); + const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; // const int n_ctx = ((int32_t *) dst->op_params)[3]; From f4cb482c6261c69b559baec4dafc0e842715bf41 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 16 May 2024 13:33:01 +0300 Subject: [PATCH 14/22] minor : style --- ggml-cuda/rope.cu | 9 +++------ llama.cpp | 25 ++++++++++++------------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu index e7968817f15e2..8aba089f4346b 100644 --- a/ggml-cuda/rope.cu +++ b/ggml-cuda/rope.cu @@ -182,22 +182,19 @@ static void rope_neox_cuda( x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, theta_scale, inv_ndims, freq_factors ); - } - else { + } else { rope_neox<<>>( x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, theta_scale, inv_ndims, freq_factors ); } - } - else { + } else { if (freq_factors == nullptr) { rope_neox<<>>( x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, theta_scale, inv_ndims, freq_factors ); - } - else { + } else { rope_neox<<>>( x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, theta_scale, inv_ndims, freq_factors diff --git a/llama.cpp b/llama.cpp index a08941af10cf2..80db7375c85e7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1799,9 +1799,9 @@ struct llama_hparams { if (this->rope_finetuned != 
other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; - if (this->rope_long_factors != other.rope_long_factors) return true; + if (this->rope_long_factors != other.rope_long_factors) return true; if (this->rope_short_factors != other.rope_short_factors) return true; - if (this->rope_attn_factor != other.rope_attn_factor) return true; + if (this->rope_attn_factor != other.rope_attn_factor) return true; if (this->ssm_d_conv != other.ssm_d_conv) return true; if (this->ssm_d_inner != other.ssm_d_inner) return true; @@ -3323,7 +3323,7 @@ struct llama_model_loader { } template - bool get_arr(const std::string& key, std::vector& result, const bool required = true) { + bool get_arr(const std::string & key, std::vector & result, const bool required = true) { const int kid = gguf_find_key(meta, key.c_str()); if (kid < 0) { @@ -3342,10 +3342,10 @@ struct llama_model_loader { // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T)); GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same::value)); - GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same::value)); + GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same::value)); result.resize(arr_info.length); - result.assign((const T*)arr_info.data, (const T*)arr_info.data + arr_info.length); + result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); return true; } @@ -3898,7 +3898,7 @@ static void llm_load_hparams( } hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; - ml.get_arr(LLM_KV_ROPE_SCALING_LONG_FACTORS, hparams.rope_long_factors, false); + ml.get_arr(LLM_KV_ROPE_SCALING_LONG_FACTORS, hparams.rope_long_factors, false); ml.get_arr(LLM_KV_ROPE_SCALING_SHORT_FACTORS, hparams.rope_short_factors, false); GGML_ASSERT(hparams.rope_long_factors.size() == 0 || hparams.rope_long_factors.size() == hparams.n_embd / hparams.n_head / 2); @@ -6994,8 +6994,7 @@ struct llm_build_context { return lctx.inp_pos; } - struct ggml_tensor* build_freq_factors() { - + struct ggml_tensor * build_freq_factors() { if (hparams.rope_long_factors.empty() || hparams.rope_short_factors.empty()) { lctx.freq_factors = nullptr; return nullptr; @@ -10968,18 +10967,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } if (lctx.freq_factors) { - auto freq_dim = hparams.n_embd_head_k / 2; + // TODO: this might have to be hparams.n_rot instead of hparams.n_embd_head_k, but maybe it does not matter + const auto freq_dim = hparams.n_embd_head_k / 2; GGML_ASSERT(lctx.freq_factors->ne[0] == freq_dim); - GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim); + GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim); GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim); // choose long/short freq factors based on the context size - auto n_ctx = llama_n_ctx(&lctx); + const auto n_ctx = llama_n_ctx(&lctx); if (n_ctx > hparams.n_yarn_orig_ctx) { ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); - } - else { + } else { ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); } } From e7c7d8ca42b7e3011c1f74604a341e989485243d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 16 May 2024 13:34:57 +0300 Subject: [PATCH 15/22] tests : update to use new rope API --- tests/test-backend-ops.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git 
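Condensed, the host-side selection in llama_set_inputs() above is a single comparison against the original training context; a sketch using the same names as the patch:

    // pick the factor set that matches the effective context size
    const bool use_long = llama_n_ctx(&lctx) > hparams.n_yarn_orig_ctx;
    const auto & factors = use_long ? hparams.rope_long_factors   // running past the original window
                                    : hparams.rope_short_factors; // still within the original window
    ggml_backend_tensor_set(lctx.freq_factors, factors.data(), 0,
                            freq_dim * ggml_element_size(lctx.freq_factors));

Patch 18 later in the series replaces this host-side copy with per-model tensors returned by build_rope_factors(), but the selection criterion (compare the effective context against n_yarn_orig_ctx) stays the same.
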
a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c74e253db4b3b..1493a7ca7c405 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1763,14 +1763,14 @@ struct test_llama : public test_llm { struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur); struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur); - Qcur = ggml_rope_custom( - ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr, hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_custom( - ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr, hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -1889,13 +1889,13 @@ struct test_falcon : public test_llm { Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens); // using mode = 2 for neox mode - Qcur = ggml_rope_custom( - ctx, Qcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx, + Qcur = ggml_rope_ext( + ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_custom( - ctx, Kcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx, + Kcur = ggml_rope_ext( + ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); From 4f787ead1451a2d58dad8d799a259991bfb6f631 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 16 May 2024 13:42:00 +0300 Subject: [PATCH 16/22] backends : fix pragma semicolons --- ggml-kompute.cpp | 4 ++-- ggml-sycl.cpp | 4 ++-- ggml-vulkan.cpp | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index f03f27bea8e9f..6c6058b2a95b1 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1677,8 +1677,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml } break; case GGML_OP_ROPE: { -#pragma message("TODO: implement phi3 frequency factors support"); -#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225"); +#pragma message("TODO: implement phi3 frequency factors support") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225") GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); GGML_ASSERT(ne10 == ne02); diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 2fa5e18c41cc4..f486b6c0a5a3b 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -14454,8 +14454,8 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd, const dpct::queue_ptr &main_stream) { -#pragma message("TODO: implement phi3 frequency factors support"); -#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225"); +#pragma message("TODO: implement phi3 frequency factors support") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225") GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index fce397bfaae80..16287a28089a0 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -4238,8 
+4238,8 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { -#pragma message("TODO: implement phi3 frequency factors support"); -#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225"); +#pragma message("TODO: implement phi3 frequency factors support") +#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225") GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); const int n_dims = ((int32_t *) dst->op_params)[1]; From d93b5cad0aecabc3487ba9e6cae3c61595db720c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 16 May 2024 13:53:22 +0300 Subject: [PATCH 17/22] minor : cleanup --- ggml-cuda/rope.cu | 2 +- ggml-metal.m | 14 ++------------ 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu index 8aba089f4346b..bf0342e3218ac 100644 --- a/ggml-cuda/rope.cu +++ b/ggml-cuda/rope.cu @@ -291,7 +291,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { freq_factors = (const float *) src2->data; } } else { - GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for mode 1"); + GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox"); } rope_corr_dims corr_dims; diff --git a/ggml-metal.m b/ggml-metal.m index 7bc75f39b4bc8..5d5ad20ada788 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2263,18 +2263,8 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal"); - if (is_neox) { - // TODO: move these asserts to ggml.c - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(src1->ne[0] == ne2); - - if (id_src2 != nil) { - // TODO: move these asserts to ggml.c - GGML_ASSERT(src2->type == GGML_TYPE_F32); - GGML_ASSERT(src2->ne[0] >= n_dims / 2); - } - } else { - GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for mode 1"); + if (!is_neox) { + GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox"); } id pipeline = nil; From 600896b8827606b325537fddfb256c2247d51f4d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 21 May 2024 18:26:55 +0300 Subject: [PATCH 18/22] llama : move rope factors from KV header to tensors --- convert-hf-to-gguf.py | 4 +- gguf-py/gguf/constants.py | 6 ++- gguf-py/gguf/gguf_writer.py | 6 --- llama.cpp | 103 ++++++++++++++---------------------- 4 files changed, 46 insertions(+), 73 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 2c4764c5ec921..cbdb39e9509a2 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1834,8 +1834,8 @@ def set_gguf_parameters(self): if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') - self.gguf_writer.add_rope_scaling_freq_long_factors(long_factors) - self.gguf_writer.add_rope_scaling_freq_short_factors(short_factors) + self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32)) + self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32)) @Model.register("PlamoForCausalLM") diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 75a129c1ab802..42df2e4d00604 100644 --- 
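Because patch 18 stores the factors as ordinary tensors ("rope_factors_long.weight" / "rope_factors_short.weight") instead of KV metadata arrays, a converted 128k model can be sanity-checked with the stock gguf API from ggml.h. A small sketch; the file name is a placeholder, and both tensors are optional (only long-context variants carry them):

    #include "ggml.h"
    #include <stdio.h>

    int main(void) {
        struct ggml_context * meta = NULL;
        struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ &meta };

        struct gguf_context * gctx = gguf_init_from_file("phi-3-mini-128k.Q8_0.gguf", params);
        if (gctx == NULL) {
            return 1;
        }

        // gguf_find_tensor() returns the tensor index, or -1 if the tensor is absent
        printf("rope_factors_long.weight  : %d\n", gguf_find_tensor(gctx, "rope_factors_long.weight"));
        printf("rope_factors_short.weight : %d\n", gguf_find_tensor(gctx, "rope_factors_short.weight"));

        gguf_free(gctx);
        ggml_free(meta);
        return 0;
    }
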
a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -61,8 +61,6 @@ class Rope: FREQ_BASE = "{arch}.rope.freq_base" SCALING_TYPE = "{arch}.rope.scaling.type" SCALING_FACTOR = "{arch}.rope.scaling.factor" - SCALING_LONG_FACTORS = "{arch}.rope.scaling.freq_long_factors" - SCALING_SHORT_FACTORS = "{arch}.rope.scaling.freq_short_factors" SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" @@ -151,6 +149,8 @@ class MODEL_TENSOR(IntEnum): OUTPUT = auto() OUTPUT_NORM = auto() ROPE_FREQS = auto() + ROPE_FACTORS_LONG = auto() + ROPE_FACTORS_SHORT = auto() ATTN_Q = auto() ATTN_K = auto() ATTN_V = auto() @@ -228,6 +228,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.ROPE_FREQS: "rope_freqs", + MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", + MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index e583fc9b50853..8b41b54eaa5a6 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -433,12 +433,6 @@ def add_rope_scaling_type(self, value: RopeScalingType) -> None: def add_rope_scaling_factor(self, value: float) -> None: self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value) - def add_rope_scaling_freq_long_factors(self, value: Sequence[float]) -> None: - self.add_array(Keys.Rope.SCALING_LONG_FACTORS.format(arch=self.arch), value) - - def add_rope_scaling_freq_short_factors(self, value: Sequence[float]) -> None: - self.add_array(Keys.Rope.SCALING_SHORT_FACTORS.format(arch=self.arch), value) - def add_rope_scaling_attn_factors(self, value: Sequence[float]) -> None: self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value) diff --git a/llama.cpp b/llama.cpp index 80db7375c85e7..4122229ffd83b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -304,8 +304,6 @@ enum llm_kv { LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, - LLM_KV_ROPE_SCALING_LONG_FACTORS, - LLM_KV_ROPE_SCALING_SHORT_FACTORS, LLM_KV_ROPE_SCALING_ATTN_FACTOR, LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, @@ -384,8 +382,6 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, - { LLM_KV_ROPE_SCALING_LONG_FACTORS, "%s.rope.scaling.freq_long_factors" }, - { LLM_KV_ROPE_SCALING_SHORT_FACTORS, "%s.rope.scaling.freq_short_factors" }, { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, @@ -442,6 +438,8 @@ enum llm_tensor { LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, LLM_TENSOR_ATTN_Q, LLM_TENSOR_ATTN_K, LLM_TENSOR_ATTN_V, @@ -809,18 +807,20 @@ static const std::map> LLM_TENSOR_NA { LLM_ARCH_PHI3, { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_Q, 
"blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, { @@ -1756,14 +1756,11 @@ struct llama_hparams { float f_norm_eps; float f_norm_rms_eps; + float rope_attn_factor = 1.0f; float rope_freq_base_train; float rope_freq_scale_train; uint32_t n_yarn_orig_ctx; - std::vector rope_long_factors; - std::vector rope_short_factors; - float rope_attn_factor = 1.0f; - // for State Space Models uint32_t ssm_d_conv = 0; uint32_t ssm_d_inner = 0; @@ -1799,10 +1796,6 @@ struct llama_hparams { if (this->rope_finetuned != other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; - if (this->rope_long_factors != other.rope_long_factors) return true; - if (this->rope_short_factors != other.rope_short_factors) return true; - if (this->rope_attn_factor != other.rope_attn_factor) return true; - if (this->ssm_d_conv != other.ssm_d_conv) return true; if (this->ssm_d_inner != other.ssm_d_inner) return true; if (this->ssm_d_state != other.ssm_d_state) return true; @@ -1812,6 +1805,7 @@ struct llama_hparams { if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; + if (!is_float_close(this->rope_attn_factor, other.rope_attn_factor, EPSILON)) return true; if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; @@ -2117,6 +2111,10 @@ struct llama_model { struct ggml_tensor * output; struct ggml_tensor * output_b; + // long rope factors + struct ggml_tensor * rope_long; + struct ggml_tensor * rope_short; + std::vector layers; llama_split_mode split_mode; @@ -2260,8 +2258,6 @@ struct llama_context { struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] - struct ggml_tensor * freq_factors = nullptr; // F32 [kv_size / 2] - // control vectors struct llama_control_vector cvec; }; @@ -3898,12 +3894,6 @@ static void llm_load_hparams( } hparams.rope_freq_scale_train = ropescale == 0.0f ? 
1.0f : 1.0f/ropescale; - ml.get_arr(LLM_KV_ROPE_SCALING_LONG_FACTORS, hparams.rope_long_factors, false); - ml.get_arr(LLM_KV_ROPE_SCALING_SHORT_FACTORS, hparams.rope_short_factors, false); - - GGML_ASSERT(hparams.rope_long_factors.size() == 0 || hparams.rope_long_factors.size() == hparams.n_embd / hparams.n_head / 2); - GGML_ASSERT(hparams.rope_long_factors.size() == hparams.rope_short_factors.size()); - ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); // sanity check for n_rot (optional) @@ -4937,6 +4927,7 @@ static bool llm_load_tensors( // create tensors for the weights { const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_head = n_embd / hparams.n_head; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = n_embd_v_gqa; @@ -5648,6 +5639,9 @@ static bool llm_load_tensors( { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }); + model.rope_long = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_embd_head/2 }, false); + model.rope_short = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, false); + // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }); @@ -6878,7 +6872,7 @@ struct llm_build_context { cb(lctx.inp_K_shift, "K_shift", -1); ggml_set_input(lctx.inp_K_shift); - lctx.freq_factors = build_freq_factors(); + struct ggml_tensor * rope_factors = build_rope_factors(); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * tmp = @@ -6889,7 +6883,7 @@ struct llm_build_context { ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), 0), - lctx.inp_K_shift, lctx.freq_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(tmp, "K_shifted", il); @@ -6994,17 +6988,15 @@ struct llm_build_context { return lctx.inp_pos; } - struct ggml_tensor * build_freq_factors() { - if (hparams.rope_long_factors.empty() || hparams.rope_short_factors.empty()) { - lctx.freq_factors = nullptr; - return nullptr; - } + struct ggml_tensor * build_rope_factors() { + // choose long/short freq factors based on the context size + const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max; - lctx.freq_factors = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_embd_head_k / 2); - cb(lctx.freq_factors, "freq_factors", -1); - ggml_set_input(lctx.freq_factors); + if (n_ctx_pre_seq > hparams.n_yarn_orig_ctx) { + return model.rope_long; + } - return lctx.freq_factors; + return model.rope_short; } struct ggml_tensor * build_inp_out_ids() { @@ -9126,7 +9118,9 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); // rope freq factors for 128k context - struct ggml_tensor* freq_factors = build_freq_factors(); + struct ggml_tensor * rope_factors = build_rope_factors(); + + GGML_ASSERT(rope_factors != nullptr && "rope_factors is required for phi3"); // TMP: remove me for (int il = 0; il < n_layer; ++il) { auto residual = inpL; @@ -9165,7 +9159,7 @@ struct llm_build_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, freq_factors, n_rot, rope_type, 0, n_orig_ctx, + ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, 
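The { n_embd_head/2 } shape of the two optional tensors mirrors the check the convert script performs on the HF side (rope_dims / 2 factors): one factor per rotated dimension pair. As a purely illustrative example, assuming a hypothetical Phi-3-mini-style config with hidden_size 3072 and 32 attention heads (numbers not taken from this diff):

    const int64_t n_embd      = 3072;
    const int64_t n_head      = 32;
    const int64_t n_embd_head = n_embd / n_head; // 96
    const int64_t n_freqs     = n_embd_head / 2; // 48 entries expected in long_factor / short_factor

    GGML_ASSERT(model.rope_long  == NULL || model.rope_long->ne[0]  == n_freqs);
    GGML_ASSERT(model.rope_short == NULL || model.rope_short->ne[0] == n_freqs);
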
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); @@ -9174,7 +9168,7 @@ struct llm_build_context { cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, freq_factors, n_rot, rope_type, 0, n_orig_ctx, + ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -10966,23 +10960,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (lctx.freq_factors) { - // TODO: this might have to be hparams.n_rot instead of hparams.n_embd_head_k, but maybe it does not matter - const auto freq_dim = hparams.n_embd_head_k / 2; - - GGML_ASSERT(lctx.freq_factors->ne[0] == freq_dim); - GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim); - GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim); - - // choose long/short freq factors based on the context size - const auto n_ctx = llama_n_ctx(&lctx); - if (n_ctx > hparams.n_yarn_orig_ctx) { - ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); - } else { - ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); - } - } - if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { const int64_t n_tokens = batch.n_tokens; From 23b72b871c5477fecc23baaececf2c7ce269e853 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 21 May 2024 18:29:12 +0300 Subject: [PATCH 19/22] llama : remove tmp assert --- llama.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index 4122229ffd83b..7bb3718ee4496 100644 --- a/llama.cpp +++ b/llama.cpp @@ -9120,8 +9120,6 @@ struct llm_build_context { // rope freq factors for 128k context struct ggml_tensor * rope_factors = build_rope_factors(); - GGML_ASSERT(rope_factors != nullptr && "rope_factors is required for phi3"); // TMP: remove me - for (int il = 0; il < n_layer; ++il) { auto residual = inpL; From e9acbce624e623bdd9cb5e4977f113dc1a2b2220 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 21 May 2024 19:08:12 +0300 Subject: [PATCH 20/22] cuda : fix compile warning --- ggml-cuda/rope.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu index bf0342e3218ac..4a558f4b3757e 100644 --- a/ggml-cuda/rope.cu +++ b/ggml-cuda/rope.cu @@ -260,7 +260,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; - const int64_t ne2 = dst->ne[2]; const int64_t nrows = ggml_nrows(src0); //const int n_past = ((int32_t *) dst->op_params)[0]; From 92711138f9fe6c8e2c62c0dd618598b63e8bc7c8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 21 May 2024 19:40:01 +0300 Subject: [PATCH 21/22] convert : read/write n_head_kv --- convert-hf-to-gguf.py | 3 ++- llama.cpp | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index cbdb39e9509a2..06c89e23f4333 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1787,6 +1787,7 @@ def set_gguf_parameters(self): n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) + n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"]) rms_eps = self.find_hparam(["rms_norm_eps"]) max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) orig_max_pos_embds = 
self.find_hparam(["original_max_position_embeddings"]) @@ -1799,7 +1800,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"])) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_head_count(n_head) - self.gguf_writer.add_head_count_kv(n_head) + self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_layer_norm_rms_eps(rms_eps) self.gguf_writer.add_rope_dimension_count(rope_dims) self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"])) diff --git a/llama.cpp b/llama.cpp index 7bb3718ee4496..b64e14b61dd70 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5652,12 +5652,12 @@ static bool llm_load_tensors( ggml_context* ctx_layer = ctx_for_layer(i); ggml_context* ctx_split = ctx_for_layer_split(i); - auto& layer = model.layers[i]; + auto & layer = model.layers[i]; layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }); layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, false); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }); From 7528c705b0c741a68a1d85a523d827374c258195 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 21 May 2024 22:02:00 +0300 Subject: [PATCH 22/22] llama : fix uninitialized tensors --- llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index b64e14b61dd70..abff8c1c03e7a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2112,8 +2112,8 @@ struct llama_model { struct ggml_tensor * output_b; // long rope factors - struct ggml_tensor * rope_long; - struct ggml_tensor * rope_short; + struct ggml_tensor * rope_long = nullptr; + struct ggml_tensor * rope_short = nullptr; std::vector<llama_layer> layers;
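A useful property of the new argument is that factors of 1.0 must reproduce the old behaviour, since the kernels divide the rotation angle by the per-frequency factor. The standalone CPU sketch below (not part of the repository's test suite; all hyperparameter values are arbitrary) checks exactly that through the public ggml API:

    #include "ggml.h"
    #include <math.h>
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params ip = { /*.mem_size =*/ 64*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        const int n_dims = 96, n_head = 4, n_tokens = 8;

        struct ggml_tensor * x   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_dims, n_head, n_tokens);
        struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
        struct ggml_tensor * ff  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2);

        ggml_set_f32(x, 0.5f);
        ggml_set_f32(ff, 1.0f); // identity frequency factors
        for (int i = 0; i < n_tokens; ++i) {
            ggml_set_i32_1d(pos, i, i);
        }

        // mode = 2 selects the neox rotation, the only path that consumes freq factors so far
        struct ggml_tensor * a = ggml_rope_ext(ctx, x, pos, NULL, n_dims, 2, 0, 4096,
                                               10000.0f, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f);
        struct ggml_tensor * b = ggml_rope_ext(ctx, x, pos, ff,   n_dims, 2, 0, 4096,
                                               10000.0f, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, a);
        ggml_build_forward_expand(gf, b);
        ggml_graph_compute_with_ctx(ctx, gf, 1);

        float max_diff = 0.0f;
        for (int i = 0; i < (int) ggml_nelements(a); ++i) {
            max_diff = fmaxf(max_diff, fabsf(ggml_get_f32_1d(a, i) - ggml_get_f32_1d(b, i)));
        }
        printf("max |rope(NULL) - rope(ones)| = %g\n", max_diff); // expected: 0

        ggml_free(ctx);
        return 0;
    }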