
Commit 0a1b398
leejet and jeffbolznv authored

ggml: add ops for WAN video model (cuda && cpu) (#15669)

* add conv3d support
* add ggml_pad_ext for cpu & cuda backend
* cuda/cpu: add im2col_3d support
* cuda: make im2col a little faster
* fix cuda pad/scale/im2col3d
* make im2col_3d faster
* gguf: support loading tensors which n_dims > GGML_MAX_DIMS
* fix cuda get_rows
* avoid ggml_conv_3d conflict
* correct GGML_OP_COUNT assertion
* avoid build failure
* avoid build failure on MacOS
* cuda: remove unnecessary MIN define
* fix cpu im2col_3d
* adjust the code style
* cuda: use simpler loop in get_rows
* add test_im2col_3d to test-backend-ops
* test-backend-ops.cpp: remove trailing whitespace
* cpu: im2col_3d support non continuous src

  Co-authored-by: Jeff Bolz <[email protected]>

* fix test_im2col_3d
* remove unused variables
* cuda: get_rows: dfloat2 -> float2
* add test_pad_ext to test-backend-ops.cpp
* add gguf_init_from_file_ext impl
* Revert "gguf: support loading tensors which n_dims > GGML_MAX_DIMS"

  This reverts commit d8377a0.

* Revert "add gguf_init_from_file_ext impl"

  This reverts commit d9f1d13.

* update ggml_backend_vk_device_supports_op
* fix ggml_backend_vk_device_supports_op
* update other backend supports op for ggml_pad_ext
* metal/opencl/sycl/vulkan: fix GGML_OP_PAD check in supports_op

---------

Co-authored-by: Jeff Bolz <[email protected]>
1 parent: 5421f63 · commit: 0a1b398

17 files changed: +754 −85 lines

ggml/include/ggml.h
Lines changed: 50 additions & 1 deletion

```diff
@@ -511,6 +511,7 @@ extern "C" {
         GGML_OP_CONV_TRANSPOSE_1D,
         GGML_OP_IM2COL,
         GGML_OP_IM2COL_BACK,
+        GGML_OP_IM2COL_3D,
         GGML_OP_CONV_2D,
         GGML_OP_CONV_3D,
         GGML_OP_CONV_2D_DW,
@@ -1870,6 +1871,41 @@ extern "C" {
             int                   d0, // dilation dimension 0
             int                   d1); // dilation dimension 1
 
+    GGML_API struct ggml_tensor * ggml_im2col_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int64_t               IC,
+            int                   s0, // stride width
+            int                   s1, // stride height
+            int                   s2, // stride depth
+            int                   p0, // padding width
+            int                   p1, // padding height
+            int                   p2, // padding depth
+            int                   d0, // dilation width
+            int                   d1, // dilation height
+            int                   d2, // dilation depth
+            enum ggml_type        dst_type);
+
+    // a: [OC*IC, KD, KH, KW]
+    // b: [N*IC, ID, IH, IW]
+    // result: [N*OC, OD, OH, OW]
+    GGML_API struct ggml_tensor * ggml_conv_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int64_t               IC,
+            int                   s0, // stride width
+            int                   s1, // stride height
+            int                   s2, // stride depth
+            int                   p0, // padding width
+            int                   p1, // padding height
+            int                   p2, // padding depth
+            int                   d0, // dilation width
+            int                   d1, // dilation height
+            int                   d2  // dilation depth
+        );
+
     // kernel size is a->ne[0] x a->ne[1]
     // stride is equal to kernel size
     // padding is zero
@@ -1941,7 +1977,7 @@ extern "C" {
             int                   d0, // dilation dimension 0
             int                   d1); // dilation dimension 1
 
-    GGML_API struct ggml_tensor * ggml_conv_3d(
+    GGML_API struct ggml_tensor * ggml_conv_3d_direct(
             struct ggml_context * ctx,
             struct ggml_tensor  * a, // kernel [KW, KH, KD, IC * OC]
             struct ggml_tensor  * b, // input  [W, H, D, C * N]
@@ -2048,6 +2084,19 @@ extern "C" {
             int                   p2,
             int                   p3);
 
+    GGML_API struct ggml_tensor * ggml_pad_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   lp0,
+            int                   rp0,
+            int                   lp1,
+            int                   rp1,
+            int                   lp2,
+            int                   rp2,
+            int                   lp3,
+            int                   rp3
+        );
+
     // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
     GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
             struct ggml_context * ctx,
```
ggml/src/ggml-cann/aclnn_ops.cpp
Lines changed: 10 additions & 3 deletions

```diff
@@ -589,9 +589,16 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // the position of elements in the array means which direction to pad;
     // each position means: [dim0.front, dim0.behind, dim1.front, dim1.behind,
     //                       dim2.front, dim2.behind, dim3.front, dim3.behind]
-    int64_t paddings[] = {
-        0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1],
-        0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]};
+    const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
+    const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
+    const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
+    const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
+    const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
+    const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
+    const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
+    const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
+
+    int64_t paddings[] = {lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3};
     aclnn_pad(ctx, acl_src, acl_dst, paddings);
     ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
```
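The eight reads above define the op-param layout every backend implementing the extended pad must agree on: index 2*d holds the leading ("front") pad of dimension d and index 2*d + 1 the trailing ("behind") pad. A compact restatement as a hypothetical helper (the loop form is not from the commit):

```c
// Hypothetical helper: gather ggml_pad_ext's per-side pads, which the
// backend code above reads as eight consecutive int32 op params.
static void get_pad_ext_params(const struct ggml_tensor * dst,
                               int32_t lp[4], int32_t rp[4]) {
    for (int d = 0; d < 4; ++d) {
        lp[d] = ggml_get_op_params_i32(dst, 2*d);     // leading pad of dim d
        rp[d] = ggml_get_op_params_i32(dst, 2*d + 1); // trailing pad of dim d
    }
}
```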

ggml/src/ggml-cpu/ggml-cpu.c
Lines changed: 5 additions & 0 deletions

```diff
@@ -1876,6 +1876,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_im2col_back_f32(params, tensor);
             } break;
+        case GGML_OP_IM2COL_3D:
+            {
+                ggml_compute_forward_im2col_3d(params, tensor);
+            } break;
         case GGML_OP_CONV_2D:
             {
                 ggml_compute_forward_conv_2d(params, tensor);
@@ -2255,6 +2259,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_IM2COL:
         case GGML_OP_IM2COL_BACK:
+        case GGML_OP_IM2COL_3D:
         case GGML_OP_CONV_2D:
         case GGML_OP_CONV_3D:
         case GGML_OP_CONV_2D_DW:
```

ggml/src/ggml-cpu/ops.cpp
Lines changed: 218 additions & 4 deletions

```diff
@@ -7027,6 +7027,209 @@ void ggml_compute_forward_im2col_back_f32(
     }
 }
 
+
+// ggml_compute_forward_im2col_3d_f16
+// src0: kernel [OC*IC, KD, KH, KW]
+// src1: image  [N*IC, ID, IH, IW]
+// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
+static void ggml_compute_forward_im2col_3d_f16(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = ne13 / IC;
+    const int64_t ID = ne12;
+    const int64_t IH = ne11;
+    const int64_t IW = ne10;
+
+    const int64_t OC = ne03 / IC;
+    GGML_UNUSED(OC);
+    const int64_t KD = ne02;
+    const int64_t KH = ne01;
+    const int64_t KW = ne00;
+
+    const int64_t OD = ne3 / N;
+    const int64_t OH = ne2;
+    const int64_t OW = ne1;
+
+    const int64_t OH_OW       = OH*OW;
+    const int64_t KD_KH_KW    = KD*KH*KW;
+    const int64_t KH_KW       = KH*KW;
+    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+    {
+        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t iod = 0; iod < OD; iod++) {
+                for (int64_t ioh = 0; ioh < OH; ioh++) {
+                    for (int64_t iow = 0; iow < OW; iow++) {
+                        for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                            // micro kernel
+                            ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+                            const float * const src_data = (const float *) ((const char *) src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+                            for (int64_t ikd = 0; ikd < KD; ikd++) {
+                                for (int64_t ikh = 0; ikh < KH; ikh++) {
+                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
+                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+                                            dst_data[iic*KD_KH_KW + ikd*KH_KW + ikh*KW + ikw] = 0;
+                                        } else {
+                                            const float * const s = (const float *) ((const char *) src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+                                            dst_data[iic*KD_KH_KW + ikd*KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+// ggml_compute_forward_im2col_3d_f32
+// src0: kernel [OC*IC, KD, KH, KW]
+// src1: image  [N*IC, ID, IH, IW]
+// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
+static void ggml_compute_forward_im2col_3d_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = ne13 / IC;
+    const int64_t ID = ne12;
+    const int64_t IH = ne11;
+    const int64_t IW = ne10;
+
+    const int64_t OC = ne03 / IC;
+    GGML_UNUSED(OC);
+    const int64_t KD = ne02;
+    const int64_t KH = ne01;
+    const int64_t KW = ne00;
+
+    const int64_t OD = ne3 / N;
+    const int64_t OH = ne2;
+    const int64_t OW = ne1;
+
+    const int64_t OH_OW       = OH*OW;
+    const int64_t KD_KH_KW    = KD*KH*KW;
+    const int64_t KH_KW       = KH*KW;
+    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+    {
+        float * const wdata = (float *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t iod = 0; iod < OD; iod++) {
+                for (int64_t ioh = 0; ioh < OH; ioh++) {
+                    for (int64_t iow = 0; iow < OW; iow++) {
+                        for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                            // micro kernel
+                            float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+                            const float * const src_data = (const float *) ((const char *) src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+                            for (int64_t ikd = 0; ikd < KD; ikd++) {
+                                for (int64_t ikh = 0; ikh < KH; ikh++) {
+                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
+                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+                                            dst_data[iic*KD_KH_KW + ikd*KH_KW + ikh*KW + ikw] = 0;
+                                        } else {
+                                            const float * const s = (const float *) ((const char *) src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+                                            dst_data[iic*KD_KH_KW + ikd*KH_KW + ikh*KW + ikw] = *s;
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_im2col_3d(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+    switch (dst->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_im2col_3d_f16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_im2col_3d_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
                               void * a, void * b, float * c) {
     const ggml_type_traits * traits = ggml_get_type_traits(type);
```
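The index arithmetic above (iiw = iow*s0 + ikw*d0 - p0, and likewise for height and depth) implies the standard convolution output extents, with out-of-range taps zero-filled. Here is a one-line helper restating that relation (illustrative; presumably this is also what ggml_im2col_3d uses when sizing dst, though the shape-computation code is not part of this diff):

```c
// Output extent along one axis for stride s, padding p, dilation d
// (standard convolution geometry implied by the kernels above).
static int64_t conv3d_out_size(int64_t in, int64_t k, int s, int p, int d) {
    return (in + 2*p - d*(k - 1) - 1) / s + 1;
}
// e.g. OW = conv3d_out_size(IW, KW, s0, p0, d0);
//      OH = conv3d_out_size(IH, KH, s1, p1, d1);
//      OD = conv3d_out_size(ID, KD, s2, p2, d2);
```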
```diff
@@ -8014,6 +8217,15 @@ static void ggml_compute_forward_pad_f32(
     GGML_TENSOR_UNARY_OP_LOCALS
 
     float * dst_ptr = (float *) dst->data;
+    const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
+    const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
+    const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
+    const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
+    const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
+    const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
+    const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
+    const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
+
 
     // TODO: optimize
 
@@ -8022,10 +8234,12 @@ static void ggml_compute_forward_pad_f32(
             for (int64_t i0 = 0; i0 < ne0; ++i0) {
                 for (int64_t i3 = 0; i3 < ne3; ++i3) {
                     const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
-                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-
-                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                    if ((i0 >= lp0 && i0 < ne0 - rp0) \
+                        && (i1 >= lp1 && i1 < ne1 - rp1) \
+                        && (i2 >= lp2 && i2 < ne2 - rp2) \
+                        && (i3 >= lp3 && i3 < ne3 - rp3)) {
+                        const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+                        const float * src_ptr = (const float *)((char *) src0->data + src_idx);
                         dst_ptr[dst_idx] = *src_ptr;
                     } else {
                         dst_ptr[dst_idx] = 0;
```
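In other words, the rewritten condition treats dst as src shifted by the leading pads: dst element i copies src element i - lp whenever lp <= i < ne - rp in every dimension, and is zero otherwise. A one-dimensional restatement (illustrative, not from the commit):

```c
// 1-D model of the pad_ext copy rule used above.
static void pad_ext_1d(const float * src, int64_t ne_src,
                       float * dst, int32_t lp, int32_t rp) {
    const int64_t ne_dst = lp + ne_src + rp;
    for (int64_t i = 0; i < ne_dst; ++i) {
        dst[i] = (i >= lp && i < ne_dst - rp) ? src[i - lp] : 0.0f;
    }
}
```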

ggml/src/ggml-cpu/ops.h
Lines changed: 1 addition & 0 deletions

```diff
@@ -69,6 +69,7 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc
 void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
```
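The commit message mentions a new test_im2col_3d case in test-backend-ops, but the harness itself is not shown in this diff. As a rough stand-in, here is a self-contained sketch that drives the new op through the public CPU API; the tensor sizes, thread count, and memory budget are arbitrary assumptions, and it assumes ggml-cpu.h's ggml_graph_compute_with_ctx convenience entry point.

```c
#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 64*1024*1024, // arbitrary budget for this sketch
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // image:  ne = [IW=8, IH=8, ID=4, N*IC] with N=1, IC=2
    struct ggml_tensor * img = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8, 8, 4, 2);
    // kernel: ne = [KW=3, KH=3, KD=3, OC*IC] with OC=1, IC=2
    struct ggml_tensor * ker = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, 3, 2);
    // (a real test would fill both tensors with data here)

    struct ggml_tensor * cols = ggml_im2col_3d(ctx, ker, img, /*IC=*/2,
            /*s0..s2=*/1, 1, 1, /*p0..p2=*/1, 1, 1, /*d0..d2=*/1, 1, 1,
            GGML_TYPE_F32);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, cols);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/4);

    ggml_free(ctx);
    return 0;
}
```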
