
Add phi3 128K model support #7225


Merged · 22 commits · May 21, 2024

Commits (22)
8fa413d
add phi3 128k support in convert-hf-to-gguf
liuwei-git May 1, 2024
56d9fa7
add phi3 128k support in cuda
liuwei-git May 1, 2024
cc19780
address build warnings on llama.cpp
liuwei-git May 1, 2024
9f87129
adjust index value in cuda long rope freq factors
liuwei-git May 11, 2024
c556931
add long rope support in ggml cpu backend
liuwei-git May 11, 2024
6333ed1
make freq factors only depend on ctx size
liuwei-git May 11, 2024
5683db3
remove unused rope scaling type 'su' from gguf converter
liuwei-git May 11, 2024
b1f491a
fix lint warnings on convert-hf-to-gguf.py
liuwei-git May 11, 2024
d05ae12
set to the short freq factor when context size is smaller than trained …
liuwei-git May 11, 2024
8a9c897
add one line of comments
liuwei-git May 11, 2024
2d473a4
metal : support rope freq_factors
ggerganov May 16, 2024
471d817
ggml : update ggml_rope_ext API to support freq. factors
ggerganov May 16, 2024
352c385
backends : add dev messages to support rope freq. factors
ggerganov May 16, 2024
f4cb482
minor : style
ggerganov May 16, 2024
e7c7d8c
tests : update to use new rope API
ggerganov May 16, 2024
4f787ea
backends : fix pragma semicolons
ggerganov May 16, 2024
d93b5ca
minor : cleanup
ggerganov May 16, 2024
600896b
llama : move rope factors from KV header to tensors
ggerganov May 21, 2024
23b72b8
llama : remove tmp assert
ggerganov May 21, 2024
e9acbce
cuda : fix compile warning
ggerganov May 21, 2024
9271113
convert : read/write n_head_kv
ggerganov May 21, 2024
7528c70
llama : fix uninitialized tensors
ggerganov May 21, 2024
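
The sequence of fixes above (6333ed1, d05ae12) converges on a simple rule: which of the two factor sets is used depends only on the configured context size, with the short factors applying whenever the context stays at or below the original training length. A minimal sketch of that rule, using hypothetical names rather than the actual llama.cpp code:

    def pick_rope_freq_factors(n_ctx: int, n_ctx_orig: int,
                               long_factors: list[float],
                               short_factors: list[float]) -> list[float]:
        # long factors only once the requested context exceeds the original
        # (trained) context; otherwise fall back to the short factors
        return long_factors if n_ctx > n_ctx_orig else short_factors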
49 changes: 43 additions & 6 deletions convert-hf-to-gguf.py
@@ -14,6 +14,7 @@
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast

import math
import numpy as np
import torch

@@ -1784,23 +1785,59 @@ def set_vocab(self):
def set_gguf_parameters(self):
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])

rot_pct = 1.0
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
rms_eps = self.find_hparam(["rms_norm_eps"])
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
rope_dims = n_embd // n_head

self.gguf_writer.add_name("Phi3")
self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"]))

self.gguf_writer.add_context_length(max_pos_embds)
self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
self.gguf_writer.add_embedding_length(n_embd)
self.gguf_writer.add_feed_forward_length(8192)
self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
self.gguf_writer.add_rope_dimension_count(rope_dims)
self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
self.gguf_writer.add_file_type(self.ftype)

# write rope scaling for long context (128k) model
rope_scaling = self.find_hparam(['rope_scaling'], True)
if (rope_scaling is None):
return

scale = max_pos_embds / orig_max_pos_embds

rope_scaling_type = rope_scaling.get('type', '').lower()
if len(rope_scaling_type) == 0:
raise KeyError('Missing the required key rope_scaling.type')

if rope_scaling_type == 'su':
attn_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos_embds)) if scale > 1.0 else 1.0
elif rope_scaling_type == 'yarn':
attn_factor = 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0
else:
raise NotImplementedError(f'The rope scaling type {rope_scaling_type} is not supported yet')

self.gguf_writer.add_rope_scaling_attn_factors(attn_factor)

long_factors = rope_scaling.get('long_factor', None)
short_factors = rope_scaling.get('short_factor', None)

if long_factors is None or short_factors is None:
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')

if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')

self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))


@Model.register("PlamoForCausalLM")
class PlamoModel(Model):
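For reference, the attention factor written by add_rope_scaling_attn_factors above boils down to a small amount of arithmetic over two hyperparameters. A standalone sketch of that computation, assuming typical Phi-3-mini-128k values (131072 / 4096) rather than numbers read from a real config.json:

    import math

    max_pos_embds = 131072       # assumed max_position_embeddings
    orig_max_pos_embds = 4096    # assumed original_max_position_embeddings
    scale = max_pos_embds / orig_max_pos_embds

    rope_scaling_type = 'su'     # the Phi-3 long-rope variant handled here

    if rope_scaling_type == 'su':
        attn_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos_embds)) if scale > 1.0 else 1.0
    elif rope_scaling_type == 'yarn':
        attn_factor = 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0

    print(f"attn_factor = {attn_factor:.4f}")   # ~1.19 for the values above

The long and short factor arrays themselves are written as two extra f32 tensors, each holding rope_dims / 2 entries, under the gguf.MODEL_TENSOR.ROPE_FACTORS_LONG and ROPE_FACTORS_SHORT names with a ".weight" suffix, as shown in the diff above.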
4 changes: 2 additions & 2 deletions examples/finetune/finetune.cpp
@@ -563,8 +563,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
// not capturing these, to silcence warnings
const int rope_mode = 0;

return ggml_rope_custom(ctx,
t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
return ggml_rope_ext(ctx,
t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0,
rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
);
};
4 changes: 2 additions & 2 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -301,8 +301,8 @@ static struct ggml_tensor * llama_build_train_graphs(
// not capturing these, to silcence warnings
const int rope_mode = 0;

return ggml_rope_custom(
ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
return ggml_rope_ext(
ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
);
};

72 changes: 48 additions & 24 deletions ggml-cuda/rope.cu
@@ -58,10 +58,10 @@ static __global__ void rope(
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}

template<typename T, bool has_pos>
template<typename T, bool has_pos, bool has_freq_facs>
static __global__ void rope_neox(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

@@ -88,7 +88,9 @@ static __global__ void rope_neox(
float cur_rot = inv_ndims * ic - ib;

const int p = has_pos ? pos[i2] : 0;
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;

const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;

float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -164,7 +166,7 @@ static void rope_cuda(
template<typename T>
static void rope_neox_cuda(
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
@@ -175,15 +177,29 @@
const float inv_ndims = -1.0f / n_dims;

if (pos == nullptr) {
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims
);
if (freq_factors == nullptr) {
rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
} else {
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
}
} else {
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims
);
if (freq_factors == nullptr) {
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
} else {
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
}
}
}

@@ -214,24 +230,27 @@ static void rope_cuda_f32(

static void rope_neox_cuda_f16(
const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {

rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

static void rope_neox_cuda_f32(
const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {

rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];

const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;

float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

@@ -241,7 +260,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t nrows = ggml_nrows(src0);

//const int n_past = ((int32_t *) dst->op_params)[0];
@@ -259,16 +277,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

const float * freq_factors = nullptr;
const int32_t * pos = nullptr;
if ((mode & 1) == 0) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(src1->ne[0] == ne2);
pos = (const int32_t *) src1_d;
}

const bool is_neox = mode & 2;
const bool is_glm = mode & 4;

if (is_neox) {
pos = (const int32_t *) src1_d;

if (src2 != nullptr) {
freq_factors = (const float *) src2->data;
}
} else {
GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
}

rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);

@@ -280,12 +304,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, stream
attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, stream
attn_factor, corr_dims, freq_factors, stream
);
} else {
GGML_ASSERT(false);
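The effect of the new freq_factors pointer is easiest to see outside the kernel: each per-dimension rotation angle is simply divided by the corresponding factor before the YaRN correction. A minimal NumPy sketch of the NeoX path, with the rope_yarn() extrapolation/attention corrections omitted for brevity; the helper below is illustrative only, not the ggml API:

    import numpy as np

    def rope_neox_ref(x, pos, freq_base=10000.0, freq_scale=1.0, freq_factors=None):
        n_dims = x.shape[-1]
        half = n_dims // 2
        if freq_factors is None:                              # the nullptr case
            freq_factors = np.ones(half, dtype=np.float32)
        # theta_scale^(col/2) in the kernel == freq_base^(-2*i/n_dims) here
        inv_freq = freq_base ** (-np.arange(half, dtype=np.float32) * 2.0 / n_dims)
        theta = pos * freq_scale * inv_freq / freq_factors    # per-dimension angles
        cos, sin = np.cos(theta), np.sin(theta)
        x0, x1 = x[..., :half], x[..., half:]                 # NeoX pairs (i, i + n_dims/2)
        return np.concatenate([x0 * cos - x1 * sin, x0 * sin + x1 * cos], axis=-1)

    x = np.random.randn(8).astype(np.float32)
    print(rope_neox_ref(x, pos=42))                                                  # plain RoPE
    print(rope_neox_ref(x, pos=42, freq_factors=np.full(4, 2.0, dtype=np.float32)))  # scaled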
4 changes: 4 additions & 0 deletions ggml-kompute.cpp
@@ -1677,6 +1677,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
} break;
case GGML_OP_ROPE:
{
#pragma message("TODO: implement phi3 frequency factors support")
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");

GGML_ASSERT(ne10 == ne02);
GGML_ASSERT(src0t == dstt);
// const int n_past = ((int32_t *) dst->op_params)[0];