
Commit 9fcb29f

ggml: allow casting between f32 and i32 (#15783)
* ggml: allow casting between f32 and i32
* fix cuda
* add vulkan
* fix CPU non-cont
* add non-cont test case
* add note
* extend test number range
* correct note
* add cont version for vulkan
1 parent 5ef22d2 commit 9fcb29f
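
The commit extends GGML_OP_CPY so that ggml_cast can convert between F32 and I32 on the CPU, CUDA, Vulkan, and Metal backends. As orientation, here is a minimal sketch of what this enables through the public API; it assumes a recent ggml tree with the CPU backend linked in, and the tensor sizes, values, and use of ggml_graph_compute_with_ctx are illustrative rather than taken from this commit or its tests:

#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // F32 input with fractional values
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ((float *) x->data)[0] =  1.9f;
    ((float *) x->data)[1] = -2.5f;
    ((float *) x->data)[2] =  0.0f;
    ((float *) x->data)[3] =  7.1f;

    // cast to I32; per the new note in ggml.h, the fractional part is discarded
    struct ggml_tensor * y = ggml_cast(ctx, x, GGML_TYPE_I32);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    // y->data now holds {1, -2, 0, 7} as int32_t (truncation toward zero)
    ggml_free(ctx);
    return 0;
}

The reverse direction (I32 -> F32) works the same way with GGML_TYPE_F32 as the target type.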

12 files changed: 247 additions & 3 deletions

ggml/include/ggml-cpu.h

Lines changed: 1 addition & 0 deletions
@@ -134,6 +134,7 @@ extern "C" {
     GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
 
     GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
+    GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);

ggml/include/ggml.h

Lines changed: 1 addition & 0 deletions
@@ -1404,6 +1404,7 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // note: casting from f32 to i32 will discard the fractional part
     GGML_API struct ggml_tensor * ggml_cast(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 14 additions & 1 deletion
@@ -373,6 +373,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
     },
+    [GGML_TYPE_I32] = {
+        .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
+    },
 };
 
 const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
@@ -2696,7 +2699,10 @@ struct ggml_cplan ggml_graph_plan(
                     if (ggml_is_quantized(node->type) ||
                         // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
                         (node->src[0]->type == GGML_TYPE_F16  && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
-                        (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
+                        (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16) ||
+                        // conversion between F32 and I32
+                        (node->src[0]->type == GGML_TYPE_F32  && node->src[1] && node->src[1]->type == GGML_TYPE_I32) ||
+                        (node->src[0]->type == GGML_TYPE_I32  && node->src[1] && node->src[1]->type == GGML_TYPE_F32)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                     }
                 } break;
@@ -3258,6 +3264,13 @@ void ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) {
     }
 }
 
+void ggml_cpu_fp32_to_i32(const float * x, int32_t * y, int64_t n) {
+    int64_t i = 0;
+    for (; i < n; ++i) {
+        y[i] = x[i];
+    }
+}
+
 void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) {
     int64_t i = 0;
 #if defined(__AVX2__)
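
The new helper is just an element-wise C conversion, so the semantics follow the usual float-to-int truncation toward zero. A minimal sketch of calling it directly (assumes linking against the ggml CPU backend; the values are only illustrative):

#include <stdint.h>
#include <stdio.h>
#include "ggml-cpu.h"

int main(void) {
    const float src[3] = { 1.9f, -2.7f, 42.0f };
    int32_t     dst[3] = { 0, 0, 0 };

    // each element is converted with a plain C cast, so the fractional
    // part is dropped: 1.9f -> 1, -2.7f -> -2, 42.0f -> 42
    ggml_cpu_fp32_to_i32(src, dst, 3);

    printf("%d %d %d\n", dst[0], dst[1], dst[2]); // prints: 1 -2 42
    return 0;
}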

ggml/src/ggml-cpu/ops.cpp

Lines changed: 160 additions & 0 deletions
@@ -776,6 +776,24 @@ static void ggml_compute_forward_dup_f32(
                         id += ne00 * (ne01 - ir1);
                     }
                 }
+            } else if (dst->type == GGML_TYPE_I32) {
+                size_t id = 0;
+                int32_t * dst_ptr = (int32_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
             } else {
                 GGML_ABORT("fatal error"); // TODO: implement
             }
@@ -947,6 +965,144 @@ static void ggml_compute_forward_dup_f32(
                 }
             }
         }
+    } else if (dst->type == GGML_TYPE_I32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+                        *(int32_t *) dst_ptr = *(const float *) src0_ptr;
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        GGML_ABORT("fatal error"); // TODO: implement
+    }
+}
+
+static void ggml_compute_forward_dup_i32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // dst counters
+
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    // TODO: not optimal, but works
+    if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                        char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+                        *(float *) dst_ptr = *(const int32_t *) src0_ptr;
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
     } else {
         GGML_ABORT("fatal error"); // TODO: implement
     }
@@ -1177,6 +1333,10 @@ void ggml_compute_forward_dup(
             {
                 ggml_compute_forward_dup_f32(params, dst);
             } break;
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_dup_i32(params, dst);
+            } break;
         default:
             {
                 if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
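
Both strided paths above use the same pattern: the destination is walked element by element with the counters i10..i13, and a chain of nested ifs carries the increment into the next dimension whenever one wraps around. An isolated sketch of that carry logic (not ggml code; the names and sizes are illustrative):

#include <cstdint>
#include <cstdio>

// Advance a 4-d row-major index (i[0] fastest) by one element,
// propagating the carry into higher dimensions, as the dup kernels
// do for their dst counters (i10, i11, i12, i13).
static void advance(const int64_t ne[4], int64_t i[4]) {
    if (++i[0] == ne[0]) {
        i[0] = 0;
        if (++i[1] == ne[1]) {
            i[1] = 0;
            if (++i[2] == ne[2]) {
                i[2] = 0;
                if (++i[3] == ne[3]) {
                    i[3] = 0; // wrapped around the whole tensor
                }
            }
        }
    }
}

int main() {
    const int64_t ne[4] = {2, 2, 1, 1}; // a 2x2 destination
    int64_t       i[4]  = {0, 0, 0, 0};
    for (int step = 0; step < 4; ++step) {
        printf("(%lld, %lld)\n", (long long) i[0], (long long) i[1]);
        advance(ne, i);
    }
    // prints (0,0) (1,0) (0,1) (1,1): row-major order over the dst tensor
    return 0;
}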

ggml/src/ggml-cuda/convert.cuh

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,8 @@ template<typename dst_t, typename src_t>
         return __float2bfloat16(float(x));
     } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
         return __bfloat162float(x);
+    } else if constexpr(std::is_same_v<dst_t, int32_t>) {
+        return int32_t(x);
     } else {
         return float(x);
     }
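
The CUDA side reuses a single templated conversion helper and lets the destination type select the behavior at compile time; the new branch makes an int32_t destination truncate instead of going through float. A host-side analogue of that dispatch pattern (plain C++ for illustration, simplified and not the actual device code in convert.cuh):

#include <cstdint>
#include <type_traits>

// The destination type picks the conversion at compile time; an int32_t
// destination truncates the value, everything else goes through float.
template <typename dst_t, typename src_t>
static dst_t convert_value(src_t x) {
    if constexpr (std::is_same_v<dst_t, int32_t>) {
        return int32_t(x);        // 1.9f -> 1, -2.5f -> -2
    } else {
        return dst_t(float(x));
    }
}

int main() {
    int32_t a = convert_value<int32_t>(3.7f);     // 3
    float   b = convert_value<float>(int32_t(5)); // 5.0f
    return (a == 3 && b == 5.0f) ? 0 : 1;
}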

ggml/src/ggml-cuda/cpy.cu

Lines changed: 4 additions & 0 deletions
@@ -374,6 +374,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
+        ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                    ggml_type_name(src0->type), ggml_type_name(src1->type));

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 6 additions & 0 deletions
@@ -3461,6 +3461,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
         return true;
     }
+    if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) {
+        return true;
+    }
+    if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
+        return true;
+    }
     if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
         return true;
     }

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 15 additions & 0 deletions
@@ -583,6 +583,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
     GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
     GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_I32,
+    GGML_METAL_KERNEL_TYPE_CPY_I32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@@ -1616,6 +1618,8 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_I32, cpy_f32_i32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_I32_F32, cpy_i32_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
@@ -1945,6 +1949,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                     case GGML_TYPE_Q5_0:
                     case GGML_TYPE_Q5_1:
                     case GGML_TYPE_IQ4_NL:
+                    case GGML_TYPE_I32:
                         return true;
                     default:
                         return false;
@@ -1977,6 +1982,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                     default:
                         return false;
                 }
+            case GGML_TYPE_I32:
+                return op->type == GGML_TYPE_F32;
             default:
                 return false;
         };
@@ -5680,6 +5687,7 @@ static int ggml_metal_encode_node(
 
                 switch (dstt) {
                     case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+                    case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_I32].pipeline; break;
                     case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
                     case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
                     case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
@@ -5691,6 +5699,13 @@ static int ggml_metal_encode_node(
                     default: GGML_ABORT("not implemented");
                 };
             } break;
+        case GGML_TYPE_I32:
+            {
+                switch (dstt) {
+                    case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_I32_F32].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                };
+            } break;
         case GGML_TYPE_F16:
             {
                 switch (dstt) {

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 2 additions & 0 deletions
@@ -5338,6 +5338,8 @@ typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
 
 template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
 template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
+template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy<float, int32_t>;
+template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy<int32_t, float>;
 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
 #endif
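
The commit message also mentions a non-contiguous test case; the strided CPU paths above are what such a case exercises. A rough sketch of a non-contiguous F32 -> I32 copy, loosely modeled on the CPY cases in test-backend-ops (the sizes, values, and helper usage here are illustrative, not the actual test added by this commit):

#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    struct ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    // 4x3 F32 source filled with values that have fractional parts
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
    for (int i = 0; i < 12; ++i) {
        ((float *) x->data)[i] = i + 0.5f;
    }

    // transposing gives a non-contiguous view; copying it into an I32 tensor
    // exercises the strided F32 -> I32 dup path added in ops.cpp
    struct ggml_tensor * xt  = ggml_transpose(ctx, x);                        // 3x4 view, non-contiguous
    struct ggml_tensor * y   = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 3, 4);
    struct ggml_tensor * out = ggml_cpy(ctx, xt, y);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    // y now holds the transposed values truncated to integers: 0, 4, 8, 1, 5, 9, ...
    ggml_free(ctx);
    return 0;
}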
