diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index c9184398b422c..1471ad72a3cc1 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -9,19 +9,9 @@ __global__ void __launch_bounds__(splitD, 2) const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t d_inner, const int64_t L) { - - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); const int bidx = blockIdx.x; // split along B (sequences) const int bidy = blockIdx.y; // split along D (d_inner) const int tid = threadIdx.x; - const int wid = tid / 32; - const int wtid = tid % 32; - - extern __shared__ float smem[]; - const int stride_sA = N + 1; - const int stride_ss0 = N + 1; - float * smem_A = smem; - float * smem_s0 = smem_A + splitD * stride_sA; const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2); const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float)); @@ -35,52 +25,40 @@ __global__ void __launch_bounds__(splitD, 2) const int stride_s0 = src0_nb2 / sizeof(float); const int stride_x = src1_nb2 / sizeof(float); const int stride_dt = src2_nb1 / sizeof(float); - const int stride_A = src3_nb1 / sizeof(float); const int stride_B = src4_nb2 / sizeof(float); const int stride_C = src5_nb2 / sizeof(float); const int stride_s = stride_s0; const int stride_y = d_inner; - // can N not be 16? for example 32? - if (N == 16) { -#pragma unroll - for (size_t i = 0; i < splitD / 4; i += 2) { - float value = A_block[(wid * warp_size + i) * stride_A + wtid]; - // todo: bank conflict - // I am always confused with how to use the swizzling method to solve - // bank conflit. Hoping somebody can tell me. - smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; - } + float A[N]; + float s[N]; + #pragma unroll - for (size_t i = 0; i < splitD / 4; i += 2) { - float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid]; - smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; - } + for (int j = 0; j < N; j++) { + A[j] = A_block[tid * N + j]; + s[j] = s0_block[tid * stride_s0 + j]; } - __syncthreads(); - for (int64_t i = 0; i < L; i++) { float dt_soft_plus = dt_block[i * stride_dt + tid]; - if (dt_soft_plus <= 20.0f) { - dt_soft_plus = log1pf(exp(dt_soft_plus)); - } - float x_dt = x_block[i * stride_x + tid] * dt_soft_plus; + dt_soft_plus = (dt_soft_plus > 20.0f) ? dt_soft_plus : log1pf(expf(dt_soft_plus)); + + const float x_dt = x_block[i * stride_x + tid] * dt_soft_plus; float sumf = 0.0f; + #pragma unroll for (size_t j = 0; j < N; j++) { - float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) + - (B_block[i * stride_B + j] * x_dt); - sumf += state * C_block[i * stride_C + j]; - if (i == L - 1) { - s_block[tid * stride_s + j] = state; - } else { - smem_s0[tid * stride_ss0 + j] = state; - } + const float exp_term = expf(dt_soft_plus * A[j]); + s[j] = fmaf(s[j], exp_term, B_block[i * stride_B + j] * x_dt); + sumf = fmaf(s[j], C_block[i * stride_C + j], sumf); } - __syncthreads(); y_block[i * stride_y + tid] = sumf; } + +#pragma unroll + for (int j = 0; j < N; j++) { + s_block[tid * stride_s + j] = s[j]; + } } // assumes as many threads as d_state