@@ -776,6 +776,24 @@ static void ggml_compute_forward_dup_f32(
776
776
id += ne00 * (ne01 - ir1);
777
777
}
778
778
}
779
+ } else if (dst->type == GGML_TYPE_I32) {
780
+ size_t id = 0 ;
781
+ int32_t * dst_ptr = (int32_t *) dst->data ;
782
+
783
+ for (int i03 = 0 ; i03 < ne03; i03++) {
784
+ for (int i02 = 0 ; i02 < ne02; i02++) {
785
+ id += ne00 * ir0;
786
+ for (int i01 = ir0; i01 < ir1; i01++) {
787
+ for (int i00 = 0 ; i00 < ne00; i00++) {
788
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
789
+
790
+ dst_ptr[id] = *src0_ptr;
791
+ id++;
792
+ }
793
+ }
794
+ id += ne00 * (ne01 - ir1);
795
+ }
796
+ }
779
797
} else {
780
798
GGML_ABORT (" fatal error" ); // TODO: implement
781
799
}
@@ -947,6 +965,144 @@ static void ggml_compute_forward_dup_f32(
947
965
}
948
966
}
949
967
}
968
+ } else if (dst->type == GGML_TYPE_I32) {
969
+ for (int64_t i03 = 0 ; i03 < ne03; i03++) {
970
+ for (int64_t i02 = 0 ; i02 < ne02; i02++) {
971
+ i10 += ne00 * ir0;
972
+ while (i10 >= ne0) {
973
+ i10 -= ne0;
974
+ if (++i11 == ne1) {
975
+ i11 = 0 ;
976
+ if (++i12 == ne2) {
977
+ i12 = 0 ;
978
+ if (++i13 == ne3) {
979
+ i13 = 0 ;
980
+ }
981
+ }
982
+ }
983
+ }
984
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
985
+ for (int64_t i00 = 0 ; i00 < ne00; i00++) {
986
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
987
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
988
+
989
+ *(int32_t *) dst_ptr = *(const float *) src0_ptr;
990
+
991
+ if (++i10 == ne0) {
992
+ i10 = 0 ;
993
+ if (++i11 == ne1) {
994
+ i11 = 0 ;
995
+ if (++i12 == ne2) {
996
+ i12 = 0 ;
997
+ if (++i13 == ne3) {
998
+ i13 = 0 ;
999
+ }
1000
+ }
1001
+ }
1002
+ }
1003
+ }
1004
+ }
1005
+ i10 += ne00 * (ne01 - ir1);
1006
+ while (i10 >= ne0) {
1007
+ i10 -= ne0;
1008
+ if (++i11 == ne1) {
1009
+ i11 = 0 ;
1010
+ if (++i12 == ne2) {
1011
+ i12 = 0 ;
1012
+ if (++i13 == ne3) {
1013
+ i13 = 0 ;
1014
+ }
1015
+ }
1016
+ }
1017
+ }
1018
+ }
1019
+ }
1020
+ } else {
1021
+ GGML_ABORT (" fatal error" ); // TODO: implement
1022
+ }
1023
+ }
1024
+
1025
+ static void ggml_compute_forward_dup_i32 (
1026
+ const ggml_compute_params * params,
1027
+ ggml_tensor * dst) {
1028
+
1029
+ const ggml_tensor * src0 = dst->src [0 ];
1030
+
1031
+ GGML_ASSERT (ggml_nelements (dst) == ggml_nelements (src0));
1032
+
1033
+ GGML_TENSOR_UNARY_OP_LOCALS
1034
+
1035
+ const int ith = params->ith ; // thread index
1036
+ const int nth = params->nth ; // number of threads
1037
+
1038
+ // parallelize by rows
1039
+ const int nr = ne01;
1040
+ // number of rows per thread
1041
+ const int dr = (nr + nth - 1 ) / nth;
1042
+ // row range for this thread
1043
+ const int ir0 = dr * ith;
1044
+ const int ir1 = MIN (ir0 + dr, nr);
1045
+
1046
+ // dst counters
1047
+
1048
+ int64_t i10 = 0 ;
1049
+ int64_t i11 = 0 ;
1050
+ int64_t i12 = 0 ;
1051
+ int64_t i13 = 0 ;
1052
+
1053
+ // TODO: not optimal, but works
1054
+ if (dst->type == GGML_TYPE_F32) {
1055
+ for (int64_t i03 = 0 ; i03 < ne03; i03++) {
1056
+ for (int64_t i02 = 0 ; i02 < ne02; i02++) {
1057
+ i10 += ne00 * ir0;
1058
+ while (i10 >= ne0) {
1059
+ i10 -= ne0;
1060
+ if (++i11 == ne1) {
1061
+ i11 = 0 ;
1062
+ if (++i12 == ne2) {
1063
+ i12 = 0 ;
1064
+ if (++i13 == ne3) {
1065
+ i13 = 0 ;
1066
+ }
1067
+ }
1068
+ }
1069
+ }
1070
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
1071
+ for (int64_t i00 = 0 ; i00 < ne00; i00++) {
1072
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
1073
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
1074
+
1075
+ *(float *) dst_ptr = *(const int32_t *) src0_ptr;
1076
+
1077
+ if (++i10 == ne0) {
1078
+ i10 = 0 ;
1079
+ if (++i11 == ne1) {
1080
+ i11 = 0 ;
1081
+ if (++i12 == ne2) {
1082
+ i12 = 0 ;
1083
+ if (++i13 == ne3) {
1084
+ i13 = 0 ;
1085
+ }
1086
+ }
1087
+ }
1088
+ }
1089
+ }
1090
+ }
1091
+ i10 += ne00 * (ne01 - ir1);
1092
+ while (i10 >= ne0) {
1093
+ i10 -= ne0;
1094
+ if (++i11 == ne1) {
1095
+ i11 = 0 ;
1096
+ if (++i12 == ne2) {
1097
+ i12 = 0 ;
1098
+ if (++i13 == ne3) {
1099
+ i13 = 0 ;
1100
+ }
1101
+ }
1102
+ }
1103
+ }
1104
+ }
1105
+ }
950
1106
} else {
951
1107
GGML_ABORT (" fatal error" ); // TODO: implement
952
1108
}
@@ -1177,6 +1333,10 @@ void ggml_compute_forward_dup(
1177
1333
{
1178
1334
ggml_compute_forward_dup_f32 (params, dst);
1179
1335
} break ;
1336
+ case GGML_TYPE_I32:
1337
+ {
1338
+ ggml_compute_forward_dup_i32 (params, dst);
1339
+ } break ;
1180
1340
default :
1181
1341
{
1182
1342
if (ggml_is_quantized (src0->type ) && dst->type == GGML_TYPE_F32) {
0 commit comments