Skip to content

fix MUSA compiler warning #12704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 28 additions & 31 deletions ggml/src/ggml-cuda/ssm-conv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ template <size_t split_d_inner, size_t d_conv>
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
const int64_t n_t) {
GGML_UNUSED(src0_nb0);
const int tid = threadIdx.x;
const int bidx = blockIdx.x;
const int bidy = blockIdx.y;

const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);

const int stride_x = src0_nb1 / sizeof(float);
Expand All @@ -21,43 +22,42 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
float w[d_conv] = { 0.0f };

#pragma unroll
for (int j = 0; j < d_conv; j++) {
for (size_t j = 0; j < d_conv; j++) {
w[j] = w_block[tid * stride_w + j];
}

for (int i = 0; i < n_t; i++) {
for (int64_t i = 0; i < n_t; i++) {
float sumf = 0.0f;

if (i == 0) {
for (int j = 0; j < d_conv; j++) {
for (size_t j = 0; j < d_conv; j++) {
x[j] = x_block[tid * stride_x + j];
}
} else {
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
}

#pragma unroll
for (int j = 0; j < d_conv; j++) {
for (size_t j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j];
}
y_block[i * stride_y + tid] = sumf;
}
}

template <size_t split_d_inner, size_t d_conv, size_t split_n_t>
template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
const int dst_nb1, const int dst_nb2, const int nc, const int ncs,
const int nr, const int n_t, const int n_s) {
const int dst_nb1, const int dst_nb2, const int64_t n_t) {
const int tid = threadIdx.x;
const int bidx = blockIdx.x;
const int bidy = blockIdx.y;
const int bidz = blockIdx.z;

const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
bidz * split_n_t * src0_nb0);
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
float * y_block =
(float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);

Expand All @@ -69,25 +69,25 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
float w[d_conv] = { 0.0f };

#pragma unroll
for (int j = 0; j < d_conv; j++) {
for (size_t j = 0; j < d_conv; j++) {
w[j] = w_block[tid * stride_w + j];
}

#pragma unroll
for (int i = 0; i < split_n_t; i++) {
for (int64_t i = 0; i < split_n_t; i++) {
if (bidz * split_n_t + i < n_t) {
float sumf = 0.0f;

if (i == 0) {
for (int j = 0; j < d_conv; j++) {
for (size_t j = 0; j < d_conv; j++) {
x[j] = x_block[tid * stride_x + j];
}
} else {
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
}

#pragma unroll
for (int j = 0; j < d_conv; j++) {
for (size_t j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j];
}
y_block[i * stride_y + tid] = sumf;
Expand All @@ -97,27 +97,25 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,

static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t,
const int n_s, cudaStream_t stream) {
const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
const int64_t n_s, cudaStream_t stream) {
const int threads = 128;
GGML_ASSERT(nr % threads == 0);

if (n_t <= 32) {
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
if (nc == 4) {
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t,
n_s);
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
GGML_ABORT("Only support kernel size = 4 now.");
}
} else {
if (nc == 4) {
const int split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, 4, split_n_t>
<<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0,
dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
const int64_t split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
GGML_ABORT("Only support kernel size = 4 right now.");
}
Expand All @@ -128,11 +126,10 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight

const int nc = src1->ne[0]; // d_conv
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
const int nr = src0->ne[1]; // d_inner
const int n_t = dst->ne[1]; // tokens per sequence
const int n_s = dst->ne[2]; // number of sequences in the batch
const int64_t nc = src1->ne[0]; // d_conv
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = dst->ne[1]; // tokens per sequence
const int64_t n_s = dst->ne[2]; // number of sequences in the batch

GGML_ASSERT(dst->ne[0] == nr);
GGML_ASSERT(src0->nb[0] == sizeof(float));
Expand All @@ -147,5 +144,5 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
dst->nb[2], nc, ncs, nr, n_t, n_s, stream);
dst->nb[2], nc, nr, n_t, n_s, stream);
}
34 changes: 16 additions & 18 deletions ggml/src/ggml-cuda/ssm-scan.cu
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
#include "ssm-scan.cuh"

// #include <cuda_runtime.h>
// static __device__ void global_to_shared(const float *src, float *dst) {
// asm volatile("cp.async.");
// }

template <size_t splitD, size_t N>
__global__ void __launch_bounds__(splitD, 2)
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
float * __restrict__ dst, const int D, const int L, const int B) {
float * __restrict__ dst, const int64_t L) {
GGML_UNUSED(src1_nb0);
GGML_UNUSED(src2_nb0);
const int bidx = blockIdx.x; // split along B
const int bidy = blockIdx.y; // split along D
const int tid = threadIdx.x;
Expand All @@ -25,12 +22,12 @@ __global__ void __launch_bounds__(splitD, 2)
float * smem_A = smem;
float * smem_s0 = smem_A + splitD * stride_sA;

const float * s0_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
const float * x_block = (const float *) ((char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((char *) src3 + bidy * splitD * src3_nb1);
const float * B_block = (const float *) ((char *) src4 + (bidx * src4_nb2));
const float * C_block = (const float *) ((char *) src5 + (bidx * src5_nb2));
const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);

Expand All @@ -46,31 +43,31 @@ __global__ void __launch_bounds__(splitD, 2)
// can N not be 16? for example 32?
if (N == 16) {
#pragma unroll
for (int i = 0; i < splitD / 4; i += 2) {
for (size_t i = 0; i < splitD / 4; i += 2) {
float value = A_block[(wid * warpSize + i) * stride_A + wtid];
// todo: bank conflict
// I am always confused with how to use the swizzling method to solve
// bank conflit. Hoping somebody can tell me.
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
}
#pragma unroll
for (int i = 0; i < splitD / 4; i += 2) {
for (size_t i = 0; i < splitD / 4; i += 2) {
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
}
}

__syncthreads();

for (int i = 0; i < L; i++) {
for (int64_t i = 0; i < L; i++) {
float dt_soft_plus = dt_block[i * stride_dt + tid];
if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(exp(dt_soft_plus));
}
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
float sumf = 0.0f;
#pragma unroll
for (int j = 0; j < N; j++) {
for (size_t j = 0; j < N; j++) {
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
(B_block[i * stride_B + j] * x_dt);
sumf += state * C_block[i * stride_C + j];
Expand All @@ -90,7 +87,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
float * dst, const int N, const int D, const int L, const int B, cudaStream_t stream) {
float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
cudaStream_t stream) {
const int threads = 128;
// todo: consider D cannot be divided,does this situation exist?
GGML_ASSERT(D % threads == 0);
Expand All @@ -99,7 +97,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
if (N == 16) {
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, D, L, B);
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
} else {
GGML_ABORT("doesn't support N!=16.");
}
Expand Down
Loading