@@ -3987,8 +3987,72 @@ template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kerne
3987
3987
// Explicit instantiations of kernel_rope_vision. The host_name string is the
// name under which each instantiation is exposed to the host API, so it must
// not contain any padding.
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
3989
3989
3990
// Function type shared by every kernel_im2col<T> instantiation; used by the
// explicit `template [[host_name(...)]] kernel im2col_t ...` declarations.
typedef void (im2col_t)(
        constant ggml_metal_kargs_im2col & args,
        device const float * x,
        device        char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]);
3998
+
3999
// im2col: copies input samples into patch rows of the destination, writing
// zeros for positions that fall outside the input (padding).
//
// Index sources, as read by the code below:
//   tgpig = (iic, ioh, iow)   threadgroup position in the grid
//   tgpg  = (IC,  OH,  OW)    threadgroups per grid (tgpg[0] is not read)
//   tpitg = (in,  ikh, ikw)   thread position in the threadgroup
//   ntg   = (*,   KH,  KW)    threads per threadgroup; ntg[0] is the step used
//                             to walk the batch index `in` up to args.N
//
// Destination element index (in units of T):
//   (in*OH*OW + ioh*OW + iow)*args.CHW + iic*KH*KW + ikh*KW + ikw
// Source element index (in units of float):
//   in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw
//
// NOTE(review): args.CHW is presumably IC*KH*KW -- confirm against the host side.
template <typename T>
kernel void kernel_im2col(
        constant ggml_metal_kargs_im2col & args,
        device const float * x,
        device        char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]) {
    // output extent comes from the grid shape, kernel extent from the threadgroup shape
    const int64_t OH = tgpg[1];
    const int64_t OW = tgpg[2];

    const int64_t KH = ntg[1];
    const int64_t KW = ntg[2];

    const int64_t ikh = tpitg[1];
    const int64_t ikw = tpitg[2];

    const int64_t iic = tgpig[0];
    const int64_t ioh = tgpig[1];
    const int64_t iow = tgpig[2];

    // input coordinates sampled by this (output position, kernel tap) pair
    const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
    const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;

    // each pass moves the batch index forward by ntg[0]; the matching
    // destination stride is loop-invariant, so compute it once
    const int64_t dst_step = ntg[0]*args.CHW*OH*OW;

    int64_t in         = tpitg[0];
    int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);

    device T * pdst = (device T *) (dst);

    // the sampled coordinate does not depend on `in`, so the padding decision
    // is made once, outside the batch loop
    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
        for (; in < args.N; in += ntg[0]) {
            pdst[offset_dst] = 0.0f;
            offset_dst += dst_step;
        }
        return;
    }

    const int64_t src_step = ntg[0]*args.ofs0;

    int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;

    for (; in < args.N; in += ntg[0]) {
        pdst[offset_dst] = x[offset_src];

        offset_dst += dst_step;
        offset_src += src_step;
    }
}
4050
+
4051
// Explicit instantiations of kernel_im2col for f32 and f16 destinations. The
// host_name string is the name under which each instantiation is exposed to
// the host API, so it must not contain any padding.
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
4053
+
3990
4054
// TODO: obsolete -- remove
3991
- // typedef void (im2col_t )(
4055
+ // typedef void (im2col_ext_t )(
3992
4056
// constant ggml_metal_kargs_im2col & args,
3993
4057
// device const float * x,
3994
4058
// device char * dst,
@@ -3998,100 +4062,48 @@ template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t ker
3998
4062
// uint3 ntg[[threads_per_threadgroup]]);
3999
4063
//
4000
4064
// template <typename T>
4001
- // kernel void kernel_im2col (
4065
+ // kernel void kernel_im2col_ext (
4002
4066
// constant ggml_metal_kargs_im2col & args,
4003
4067
// device const float * x,
4004
4068
// device char * dst,
4005
4069
// uint3 tgpig[[threadgroup_position_in_grid]],
4006
- // uint3 tgpg[[threadgroups_per_grid]],
4070
+ // uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
4007
4071
// uint3 tpitg[[thread_position_in_threadgroup]],
4008
- // uint3 ntg[[threads_per_threadgroup]]) {
4009
- // // const int64_t IC = tgpg[0];
4010
- // const int64_t OH = tgpg[1];
4011
- // const int64_t OW = tgpg[2];
4072
+ // uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
4073
+ // const int64_t KHW = (int64_t)args.KHW;
4012
4074
//
4013
- // // const int64_t N = ntg[0];
4014
- // const int64_t KH = ntg[1];
4015
- // const int64_t KW = ntg[2];
4075
+ // const int64_t d = tgpig[0] / args.CHW;
4076
+ // const int64_t chw = tgpig[0] % args.CHW;
4077
+ // const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
4078
+ // const int64_t HW = tgpig[0] % KHW;
4016
4079
//
4017
- // const int64_t in = tpitg[0];
4018
- // const int64_t ikh = tpitg[1];
4019
- // const int64_t ikw = tpitg[2];
4080
+ // const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
4081
+ // if (tpitg_0 >= args.N) {
4082
+ // return;
4083
+ // }
4020
4084
//
4021
- // const int64_t iic = tgpig[0];
4022
- // const int64_t ioh = tgpig[1];
4023
- // const int64_t iow = tgpig[2];
4085
+ // const int64_t tpitg_1 = HW / args.KW;
4086
+ // const int64_t tpitg_2 = HW % args.KW;
4024
4087
//
4025
- // const int64_t iiw = iow* args.s0 + ikw* args.d0 - args.p0;
4026
- // const int64_t iih = ioh* args.s1 + ikh* args.d1 - args.p1;
4088
+ // const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
4089
+ // const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
4027
4090
//
4028
- // const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
4091
+ // const int64_t offset_dst =
4092
+ // (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
4093
+ // (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
4029
4094
//
4030
4095
// device T * pdst = (device T *) (dst);
4031
4096
//
4032
4097
// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
4033
4098
// pdst[offset_dst] = 0.0f;
4034
4099
// } else {
4035
- // const int64_t offset_src = in* args.ofs0 + iic*args.ofs1 + iih* args.IW + iiw ;
4036
- // pdst[offset_dst] = x[offset_src];
4100
+ // const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1 ;
4101
+ // pdst[offset_dst] = x[offset_src + iih * args.IW + iiw ];
4037
4102
// }
4038
4103
// }
4039
4104
//
4040
- // template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
4041
- // template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
4042
-
4043
- typedef void (im2col_ext_t )(
4044
- constant ggml_metal_kargs_im2col & args,
4045
- device const float * x,
4046
- device char * dst,
4047
- uint3 tgpig[[threadgroup_position_in_grid]],
4048
- uint3 tgpg[[threadgroups_per_grid]],
4049
- uint3 tpitg[[thread_position_in_threadgroup]],
4050
- uint3 ntg[[threads_per_threadgroup]]);
4051
-
4052
- template <typename T>
4053
- kernel void kernel_im2col_ext (
4054
- constant ggml_metal_kargs_im2col & args,
4055
- device const float * x,
4056
- device char * dst,
4057
- uint3 tgpig[[threadgroup_position_in_grid]],
4058
- uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
4059
- uint3 tpitg[[thread_position_in_threadgroup]],
4060
- uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
4061
- const int64_t KHW = (int64_t )args.KHW ;
4062
-
4063
- const int64_t d = tgpig[0 ] / args.CHW ;
4064
- const int64_t chw = tgpig[0 ] % args.CHW ;
4065
- const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
4066
- const int64_t HW = tgpig[0 ] % KHW;
4067
-
4068
- const int64_t tpitg_0 = (d * ntg[0 ]) + tpitg[0 ];
4069
- if (tpitg_0 >= args.N ) {
4070
- return ;
4071
- }
4072
-
4073
- const int64_t tpitg_1 = HW / args.KW ;
4074
- const int64_t tpitg_2 = HW % args.KW ;
4075
-
4076
- const int64_t iiw = tgpig[2 ] * args.s0 + tpitg_2 * args.d0 - args.p0 ;
4077
- const int64_t iih = tgpig[1 ] * args.s1 + tpitg_1 * args.d1 - args.p1 ;
4078
-
4079
- const int64_t offset_dst =
4080
- (tpitg_0 * tgpg[1 ] * tgpg[2 ] + tgpig[1 ] * tgpg[2 ] + tgpig[2 ]) * args.CHW +
4081
- (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
4082
-
4083
- device T * pdst = (device T *) (dst);
4084
-
4085
- if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW ) {
4086
- pdst[offset_dst] = 0 .0f ;
4087
- } else {
4088
- const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1 ;
4089
- pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
4090
- }
4091
- }
4092
-
4093
- template [[host_name(" kernel_im2col_ext_f32" )]] kernel im2col_ext_t kernel_im2col_ext<float >;
4094
- template [[host_name(" kernel_im2col_ext_f16" )]] kernel im2col_ext_t kernel_im2col_ext<half>;
4105
+ // template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4106
+ // template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
4095
4107
4096
4108
typedef void (conv_transpose_1d_t )(
4097
4109
constant ggml_metal_kargs_conv_transpose_1d & args,
0 commit comments