@@ -7027,6 +7027,209 @@ void ggml_compute_forward_im2col_back_f32(
}
}

+
+// ggml_compute_forward_im2col_3d_f16
+// src0: kernel [OC*IC, KD, KH, KW]
+// src1: image  [N*IC, ID, IH, IW]
+// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
+static void ggml_compute_forward_im2col_3d_f16(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = ne13 / IC;
+    const int64_t ID = ne12;
+    const int64_t IH = ne11;
+    const int64_t IW = ne10;
+
+    const int64_t OC = ne03 / IC;
+    GGML_UNUSED(OC);
+    const int64_t KD = ne02;
+    const int64_t KH = ne01;
+    const int64_t KW = ne00;
+
+    const int64_t OD = ne3 / N;
+    const int64_t OH = ne2;
+    const int64_t OW = ne1;
+
+    const int64_t OH_OW       = OH*OW;
+    const int64_t KD_KH_KW    = KD*KH*KW;
+    const int64_t KH_KW       = KH*KW;
+    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+    {
+        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t iod = 0; iod < OD; iod++) {
+                for (int64_t ioh = 0; ioh < OH; ioh++) {
+                    for (int64_t iow = 0; iow < OW; iow++) {
+                        for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                            // micro kernel
+                            ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+                            const float * const src_data = (const float *) ((const char *) src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+                            for (int64_t ikd = 0; ikd < KD; ikd++) {
+                                for (int64_t ikh = 0; ikh < KH; ikh++) {
+                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
+                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+                                            dst_data[iic*KD_KH_KW + ikd*KH_KW + ikh*KW + ikw] = 0;
+                                        } else {
+                                            const float * const s = (const float *) ((const char *) src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+                                            dst_data[iic*KD_KH_KW + ikd*KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
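For orientation, the OD, OH, and OW extents that the kernel divides out of dst's shape follow the standard convolution output-size arithmetic. A minimal sketch of that formula in C (the helper name im2col_3d_out_size is illustrative, not part of ggml):

    // standard convolution arithmetic: floor((in + 2*p - d*(k - 1) - 1)/s) + 1
    static int64_t im2col_3d_out_size(int64_t in, int64_t k, int64_t s, int64_t p, int64_t d) {
        return (in + 2*p - d*(k - 1) - 1)/s + 1;
    }

Under that assumption, OD == im2col_3d_out_size(ID, KD, s2, p2, d2), and likewise for OH and OW with the (s1, p1, d1) and (s0, p0, d0) parameters.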
+
+// ggml_compute_forward_im2col_3d_f32
+// src0: kernel [OC*IC, KD, KH, KW]
+// src1: image  [N*IC, ID, IH, IW]
+// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
+static void ggml_compute_forward_im2col_3d_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = ne13 / IC;
+    const int64_t ID = ne12;
+    const int64_t IH = ne11;
+    const int64_t IW = ne10;
+
+    const int64_t OC = ne03 / IC;
+    GGML_UNUSED(OC);
+    const int64_t KD = ne02;
+    const int64_t KH = ne01;
+    const int64_t KW = ne00;
+
+    const int64_t OD = ne3 / N;
+    const int64_t OH = ne2;
+    const int64_t OW = ne1;
+
+    const int64_t OH_OW       = OH*OW;
+    const int64_t KD_KH_KW    = KD*KH*KW;
+    const int64_t KH_KW       = KH*KW;
+    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+    {
+        float * const wdata = (float *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t iod = 0; iod < OD; iod++) {
+                for (int64_t ioh = 0; ioh < OH; ioh++) {
+                    for (int64_t iow = 0; iow < OW; iow++) {
+                        for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                            // micro kernel
+                            float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+                            const float * const src_data = (const float *) ((const char *) src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+                            for (int64_t ikd = 0; ikd < KD; ikd++) {
+                                for (int64_t ikh = 0; ikh < KH; ikh++) {
+                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
+                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+                                            dst_data[iic*KD_KH_KW + ikd*KH_KW + ikh*KW + ikw] = 0;
+                                        } else {
+                                            const float * const s = (const float *) ((const char *) src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+                                            dst_data[iic*KD_KH_KW + ikd*KH_KW + ikh*KW + ikw] = *s;
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
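As the inline comments note, each output row packs one receptive-field patch in [IC, KD, KH, KW] order. A tiny sketch of the flat offset both variants compute for dst_data (the helper name patch_offset is illustrative only):

    // flat offset of element (iic, ikd, ikh, ikw) within one [IC, KD, KH, KW] row
    static int64_t patch_offset(int64_t iic, int64_t ikd, int64_t ikh, int64_t ikw,
                                int64_t KD, int64_t KH, int64_t KW) {
        return iic*KD*KH*KW + ikd*KH*KW + ikh*KW + ikw;
    }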
+
+void ggml_compute_forward_im2col_3d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    switch (dst->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_im2col_3d_f16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_im2col_3d_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
                              void * a, void * b, float * c) {
    const ggml_type_traits * traits = ggml_get_type_traits(type);
@@ -8014,6 +8217,15 @@ static void ggml_compute_forward_pad_f32(
    GGML_TENSOR_UNARY_OP_LOCALS

    float * dst_ptr = (float *) dst->data;

+    const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
+    const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
+    const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
+    const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
+    const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
+    const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
+    const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
+    const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
+

    // TODO: optimize

@@ -8022,10 +8234,12 @@ static void ggml_compute_forward_pad_f32(
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                for (int64_t i3 = 0; i3 < ne3; ++i3) {
                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
-                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-
-                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                    if ((i0 >= lp0 && i0 < ne0 - rp0)
+                        && (i1 >= lp1 && i1 < ne1 - rp1)
+                        && (i2 >= lp2 && i2 < ne2 - rp2)
+                        && (i3 >= lp3 && i3 < ne3 - rp3)) {
+                        const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+                        const float * src_ptr = (const float *)((char *) src0->data + src_idx);
                        dst_ptr[dst_idx] = *src_ptr;
                    } else {
                        dst_ptr[dst_idx] = 0;
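The new condition generalizes the old bounds check: a destination element is copied from the source only when it falls inside the region left after carving off lp*/rp* padding elements on each side of each dimension, and is zero-filled otherwise. A minimal 1-D sketch of the same rule (standalone and illustrative, not ggml code):

    // dst has ne = ne_src + lp + rp elements; dst[i] = src[i - lp] inside
    // the unpadded window [lp, ne - rp), and 0 in the padded border.
    static void pad_1d(const float * src, int64_t ne_src,
                       float * dst, int64_t lp, int64_t rp) {
        const int64_t ne = ne_src + lp + rp;
        for (int64_t i = 0; i < ne; ++i) {
            dst[i] = (i >= lp && i < ne - rp) ? src[i - lp] : 0.0f;
        }
    }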