Skip to content

metal : bug with ggml_cont #3012

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions ggml-metal.m
Original file line number Diff line number Diff line change
@@ -541,10 +541,7 @@ void ggml_metal_graph_find_concurrency(
int64_t data_start = (int64_t) gf->nodes[i]->data;
int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
for (int j = n_start; j < i; j++) {
if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
&& gf->nodes[j]->op != GGML_OP_VIEW \
&& gf->nodes[j]->op != GGML_OP_TRANSPOSE \
&& gf->nodes[j]->op != GGML_OP_PERMUTE) {
if (nodes_unused[j] && gf->nodes[j]->view_src == NULL) {
Comment on lines 543 to +544
Copy link
Member Author

@ggerganov ggerganov Sep 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need confirmation if this is correct.

Btw, I see that check_mem is always set to false.
Are there cases where we would still need it?
If it's not needed, should we remove it?

if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
continue;
8 changes: 4 additions & 4 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
@@ -783,11 +783,11 @@ kernel void kernel_cpy_f16_f16(
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0];
device const half * src = (device half *) ((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i00*nb0);

*dst_data = *src;
}
}

4 changes: 3 additions & 1 deletion ggml.c
Original file line number Diff line number Diff line change
@@ -4285,7 +4285,7 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
}

size_t ggml_nbytes(const struct ggml_tensor * tensor) {
size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
size_t nbytes = (tensor->ne[0]*tensor->nb[0])/ggml_blck_size(tensor->type);
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
@@ -5213,6 +5213,8 @@ struct ggml_tensor * ggml_view_tensor(
result->nb[i] = src->nb[i];
}

result->op = GGML_OP_VIEW;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added your proposal since it better expresses the KV cache dependency, until we figure out a better solution.

For this to work though, we need to set the ggml_view_tensor op to GGML_OP_VIEW as you have recently noted. I guess this change is fine

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, we should also set src[0] to src in ggml_view_tensor, and add the additional dependency as src[1] instead in llm_build_graph. Otherwise, it would break old code that assumes that the source is in src[0], such as this in ggml-cuda:
https://github.com/ggerganov/llama.cpp/blob/bd33e5ab92e7f214205792fc1cd9ca28e810f897/ggml-cuda.cu#L6539-L6544
ggml_graph_import may also be an issue, and I am not sure that it will import the second dependency in src[1] correctly, since it just creates a new tensor calling ggml_view_4d. Newer code should use view_src instead.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah good points. I guess I'll come back to this later and make a better solution.
Made an issue so we don't forget: ggml-org/ggml#502


return result;
}

70 changes: 39 additions & 31 deletions llama.cpp
Original file line number Diff line number Diff line change
@@ -2341,57 +2341,65 @@ static struct ggml_cgraph * llm_build_llama(
// compute Q and K and RoPE them
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
offload_func_kq(tmpk);
ggml_set_name(tmpk, "tmpk");
ggml_set_name (tmpk, "tmpk");

struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq");
ggml_set_name (tmpq, "tmpq");

struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
offload_func_v(tmpv);
ggml_set_name (tmpv, "tmpv");

struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Kcur);
ggml_set_name(Kcur, "Kcur");
ggml_set_name (Kcur, "Kcur");

struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur");

// store key and value to memory
{
// compute the transposed [N, n_embd] V matrix
ggml_set_name (Qcur, "Qcur");

struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
offload_func_v(tmpv);
ggml_set_name(tmpv, "tmpv");
// compute the transposed [N, n_embd] V matrix
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
offload_func_v(Vcur);
ggml_set_name (Vcur, "Vcur");

struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
offload_func_v(Vcur);
ggml_set_name(Vcur, "Vcur");
struct ggml_tensor * k;
struct ggml_tensor * v;

struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
offload_func_kq(k);
ggml_set_name(k, "k");
// store key and value to memory
{
struct ggml_tensor * k_view = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
offload_func_kq(k_view);
ggml_set_name (k_view, "k_view");

struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
struct ggml_tensor * v_view = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
offload_func_v(v);
ggml_set_name(v, "v");
offload_func_v(v_view);
ggml_set_name (v_view, "v_view");

// important: storing RoPE-ed version of K in the KV cache!
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
struct ggml_tensor * k_cpy = ggml_cpy(ctx0, Kcur, k_view);
struct ggml_tensor * v_cpy = ggml_cpy(ctx0, Vcur, v_view);

// TODO: replace with ggml_dependency / ggml_depends_on
k = ggml_view_tensor(ctx0, kv_self.k);
v = ggml_view_tensor(ctx0, kv_self.v);
k->src[0] = k_cpy;
v->src[0] = v_cpy;
}

struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
offload_func_kq(Q);
ggml_set_name(Q, "Q");

struct ggml_tensor * K =
ggml_view_3d(ctx0, kv_self.k,
ggml_view_3d(ctx0, k,
n_embd_head, n_past + N, n_head_kv,
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
ggml_element_size(k)*n_embd_gqa,
ggml_element_size(k)*n_embd_head,
ggml_element_size(k)*n_embd_gqa*n_ctx*il);
offload_func_kq(K);
ggml_set_name(K, "K");

@@ -2418,11 +2426,11 @@ static struct ggml_cgraph * llm_build_llama(

// split cached V into n_head heads
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
ggml_view_3d(ctx0, v,
n_past + N, n_embd_head, n_head_kv,
ggml_element_size(kv_self.v)*n_ctx,
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
ggml_element_size(v)*n_ctx,
ggml_element_size(v)*n_ctx*n_embd_head,
ggml_element_size(v)*n_ctx*n_embd_gqa*il);
offload_func_v(V);
ggml_set_name(V, "V");

@@ -2434,7 +2442,7 @@ static struct ggml_cgraph * llm_build_llama(
// make V contiguous in memory to speed up the matmul, however we waste time on the copy
// on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
// is there a better way?
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, v->type, n_past + N, n_embd_head, n_head));
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
#endif