Commit b66df9d
CUDA: fix build error from ambiguous __half conversions in conv2d (#15690)
* CUDA: fix build error from ambiguous __half conversions in conv2d

  Building conv2d with half precision failed because `__half` defines multiple implicit conversion operators (to float, int, short, etc.), causing ambiguous overload resolution when multiplying with float. Introduce a templated `to_float` helper that explicitly converts `__half` via `__half2float` while passing float through unchanged. Use this helper in the conv2d accumulation to ensure unambiguous and correct promotion to float. Fixes some build errors with half-precision kernels on CUDA.

  ggml-ci

* CUDA: Replace custom to_float helper with unified ggml_cuda_cast and add half->float conversion

* CUDA: Add missing convert.cuh header

* CUDA: remove unnecessary extension in ggml_cuda_cast

* CUDA: Address review comment, remove second type template argument
1 parent b9382c3 commit b66df9d

File tree

1 file changed, +3 -2 lines changed
ggml/src/ggml-cuda/conv2d.cu

Lines changed: 3 additions & 2 deletions
@@ -1,4 +1,5 @@
 #include "conv2d.cuh"
+#include "convert.cuh"
 
 struct conv_params {
     const int64_t IW, IH;
@@ -94,8 +95,8 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
                 const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
 
                 const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
-                const float kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
-                acc += (input_val * kernel_val);
+                const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
+                acc += (input_val * ggml_cuda_cast<float>(kernel_val));
             }
         }
     }
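
For context, `__half` (from CUDA's `cuda_fp16.h`) defines implicit conversion operators to several arithmetic types, so an expression like `input_val * kernel_val` with a `__half` kernel value can have more than one viable overload and fail to compile. Below is a minimal sketch of the conversion idea the commit message describes; the plain `to_float` overloads and the `mul_accumulate` kernel are hypothetical stand-ins for illustration, not the actual templated helper or the `ggml_cuda_cast` from `convert.cuh`.

```cuda
// Illustrative sketch only (assumed names, not the real ggml code):
// convert __half to float explicitly so float math never relies on
// __half's ambiguous implicit conversions.
#include <cuda_fp16.h>

// float passes through unchanged.
static __device__ __forceinline__ float to_float(float x) { return x; }

// __half is converted explicitly via __half2float; __half also converts
// implicitly to int, short, etc., which is what makes plain multiplication ambiguous.
static __device__ __forceinline__ float to_float(__half x) { return __half2float(x); }

// Hypothetical kernel templated on the weight type, mirroring how the conv2d
// kernel accumulates input * weight products in float.
template <typename T>
static __global__ void mul_accumulate(const float * input, const T * weights, float * acc, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        // With T == __half, writing input[i] * weights[i] directly can hit
        // ambiguous overload resolution; converting explicitly keeps the math in float.
        acc[i] += input[i] * to_float(weights[i]);
    }
}
```

The diff above applies the same idea via `ggml_cuda_cast<float>(kernel_val)`, which centralizes the half-to-float conversion in `convert.cuh` instead of a local helper.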
