Skip to content

Commit 02a6a82

Browse files
authored
metal : restore im2col perf (#16219)
1 parent c498fc8 commit 02a6a82

File tree

3 files changed

+93
-81
lines changed

3 files changed

+93
-81
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1237,7 +1237,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_
12371237
char base[256];
12381238
char name[256];
12391239

1240-
snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type));
1240+
snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
12411241
snprintf(name, 256, "%s", base);
12421242

12431243
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2768,7 +2768,6 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
27682768
const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4;
27692769
const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4;
27702770

2771-
27722771
ggml_metal_kargs_im2col args = {
27732772
/*.ofs0 =*/ ofs0,
27742773
/*.ofs1 =*/ ofs1,
@@ -2789,15 +2788,16 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
27892788

27902789
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
27912790

2792-
const uint64_t n_threads = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), N);
2793-
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
2791+
GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2792+
2793+
const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
27942794

27952795
ggml_metal_encoder_set_pipeline(enc, pipeline);
27962796
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
27972797
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
27982798
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
27992799

2800-
ggml_metal_encoder_dispatch_threadgroups(enc, quotient * CHW, OH, OW, n_threads, 1, 1);
2800+
ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
28012801

28022802
return 1;
28032803
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 88 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -3987,8 +3987,72 @@ template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kerne
39873987
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
39883988
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
39893989

3990+
typedef void (im2col_t)(
3991+
constant ggml_metal_kargs_im2col & args,
3992+
device const float * x,
3993+
device char * dst,
3994+
uint3 tgpig[[threadgroup_position_in_grid]],
3995+
uint3 tgpg[[threadgroups_per_grid]],
3996+
uint3 tpitg[[thread_position_in_threadgroup]],
3997+
uint3 ntg[[threads_per_threadgroup]]);
3998+
3999+
template <typename T>
4000+
kernel void kernel_im2col(
4001+
constant ggml_metal_kargs_im2col & args,
4002+
device const float * x,
4003+
device char * dst,
4004+
uint3 tgpig[[threadgroup_position_in_grid]],
4005+
uint3 tgpg[[threadgroups_per_grid]],
4006+
uint3 tpitg[[thread_position_in_threadgroup]],
4007+
uint3 ntg[[threads_per_threadgroup]]) {
4008+
// const int64_t IC = tgpg[0];
4009+
const int64_t OH = tgpg[1];
4010+
const int64_t OW = tgpg[2];
4011+
4012+
const int64_t KH = ntg[1];
4013+
const int64_t KW = ntg[2];
4014+
4015+
int64_t in = tpitg[0];
4016+
const int64_t ikh = tpitg[1];
4017+
const int64_t ikw = tpitg[2];
4018+
4019+
const int64_t iic = tgpig[0];
4020+
const int64_t ioh = tgpig[1];
4021+
const int64_t iow = tgpig[2];
4022+
4023+
const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
4024+
const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
4025+
4026+
int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
4027+
4028+
device T * pdst = (device T *) (dst);
4029+
4030+
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
4031+
while (in < args.N) {
4032+
pdst[offset_dst] = 0.0f;
4033+
offset_dst += ntg[0]*args.CHW*OH*OW;
4034+
4035+
in += ntg[0];
4036+
}
4037+
} else {
4038+
int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
4039+
4040+
while (in < args.N) {
4041+
pdst[offset_dst] = x[offset_src];
4042+
4043+
offset_dst += ntg[0]*args.CHW*OH*OW;
4044+
offset_src += ntg[0]*args.ofs0;
4045+
4046+
in += ntg[0];
4047+
}
4048+
}
4049+
}
4050+
4051+
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
4052+
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
4053+
39904054
// TODO: obolete -- remove
3991-
//typedef void (im2col_t)(
4055+
//typedef void (im2col_ext_t)(
39924056
// constant ggml_metal_kargs_im2col & args,
39934057
// device const float * x,
39944058
// device char * dst,
@@ -3998,100 +4062,48 @@ template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t ker
39984062
// uint3 ntg[[threads_per_threadgroup]]);
39994063
//
40004064
//template <typename T>
4001-
//kernel void kernel_im2col(
4065+
//kernel void kernel_im2col_ext(
40024066
// constant ggml_metal_kargs_im2col & args,
40034067
// device const float * x,
40044068
// device char * dst,
40054069
// uint3 tgpig[[threadgroup_position_in_grid]],
4006-
// uint3 tgpg[[threadgroups_per_grid]],
4070+
// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
40074071
// uint3 tpitg[[thread_position_in_threadgroup]],
4008-
// uint3 ntg[[threads_per_threadgroup]]) {
4009-
//// const int64_t IC = tgpg[0];
4010-
// const int64_t OH = tgpg[1];
4011-
// const int64_t OW = tgpg[2];
4072+
// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
4073+
// const int64_t KHW = (int64_t)args.KHW;
40124074
//
4013-
//// const int64_t N = ntg[0];
4014-
// const int64_t KH = ntg[1];
4015-
// const int64_t KW = ntg[2];
4075+
// const int64_t d = tgpig[0] / args.CHW;
4076+
// const int64_t chw = tgpig[0] % args.CHW;
4077+
// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
4078+
// const int64_t HW = tgpig[0] % KHW;
40164079
//
4017-
// const int64_t in = tpitg[0];
4018-
// const int64_t ikh = tpitg[1];
4019-
// const int64_t ikw = tpitg[2];
4080+
// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
4081+
// if (tpitg_0 >= args.N) {
4082+
// return;
4083+
// }
40204084
//
4021-
// const int64_t iic = tgpig[0];
4022-
// const int64_t ioh = tgpig[1];
4023-
// const int64_t iow = tgpig[2];
4085+
// const int64_t tpitg_1 = HW / args.KW;
4086+
// const int64_t tpitg_2 = HW % args.KW;
40244087
//
4025-
// const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
4026-
// const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
4088+
// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
4089+
// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
40274090
//
4028-
// const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
4091+
// const int64_t offset_dst =
4092+
// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
4093+
// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
40294094
//
40304095
// device T * pdst = (device T *) (dst);
40314096
//
40324097
// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
40334098
// pdst[offset_dst] = 0.0f;
40344099
// } else {
4035-
// const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
4036-
// pdst[offset_dst] = x[offset_src];
4100+
// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
4101+
// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
40374102
// }
40384103
//}
40394104
//
4040-
//template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
4041-
//template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
4042-
4043-
typedef void (im2col_ext_t)(
4044-
constant ggml_metal_kargs_im2col & args,
4045-
device const float * x,
4046-
device char * dst,
4047-
uint3 tgpig[[threadgroup_position_in_grid]],
4048-
uint3 tgpg[[threadgroups_per_grid]],
4049-
uint3 tpitg[[thread_position_in_threadgroup]],
4050-
uint3 ntg[[threads_per_threadgroup]]);
4051-
4052-
template <typename T>
4053-
kernel void kernel_im2col_ext(
4054-
constant ggml_metal_kargs_im2col & args,
4055-
device const float * x,
4056-
device char * dst,
4057-
uint3 tgpig[[threadgroup_position_in_grid]],
4058-
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
4059-
uint3 tpitg[[thread_position_in_threadgroup]],
4060-
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
4061-
const int64_t KHW = (int64_t)args.KHW;
4062-
4063-
const int64_t d = tgpig[0] / args.CHW;
4064-
const int64_t chw = tgpig[0] % args.CHW;
4065-
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
4066-
const int64_t HW = tgpig[0] % KHW;
4067-
4068-
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
4069-
if (tpitg_0 >= args.N) {
4070-
return;
4071-
}
4072-
4073-
const int64_t tpitg_1 = HW / args.KW;
4074-
const int64_t tpitg_2 = HW % args.KW;
4075-
4076-
const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
4077-
const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
4078-
4079-
const int64_t offset_dst =
4080-
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
4081-
(tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
4082-
4083-
device T * pdst = (device T *) (dst);
4084-
4085-
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
4086-
pdst[offset_dst] = 0.0f;
4087-
} else {
4088-
const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
4089-
pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
4090-
}
4091-
}
4092-
4093-
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4094-
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
4105+
//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4106+
//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
40954107

40964108
typedef void (conv_transpose_1d_t)(
40974109
constant ggml_metal_kargs_conv_transpose_1d & args,

0 commit comments

Comments
 (0)