Skip to content

Commit 0a85ae7

Browse files
committed
metal : fix GELU kernel numerical stability by using precise::tanh
1 parent b693000 commit 0a85ae7

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

ggml-metal.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -539,8 +539,8 @@ void ggml_metal_graph_compute(
539539

540540
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
541541

542-
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
543-
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
542+
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
543+
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
544544

545545
for (int ind = node_start; ind < node_end; ++ind) {
546546
const int i = has_concur ? ctx->concur_list[ind] : ind;

ggml-metal.metal

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,12 @@ kernel void kernel_gelu(
8787
device float * dst,
8888
uint tpig[[thread_position_in_grid]]) {
8989
float x = src0[tpig];
90-
dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
90+
91+
// BEWARE !!!
92+
// Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
93+
// This was observed with Falcon 7B and 40B models
94+
//
95+
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
9196
}
9297

9398
kernel void kernel_soft_max(

0 commit comments

Comments
 (0)