Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 185 additions & 28 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,8 @@ struct vk_device_struct {
vk_pipeline pipeline_opt_step_sgd_f32;
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;

Expand Down Expand Up @@ -1117,6 +1119,56 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
}

struct vk_op_conv_transpose_2d_push_constants {
uint32_t Cout;
uint32_t Cin;
uint32_t N;

uint32_t KW;
uint32_t KH;
uint32_t W;
uint32_t H;
uint32_t OW;
uint32_t OH;

uint32_t s0;
uint32_t s1;
uint32_t p0;
uint32_t p1;
uint32_t d0;
uint32_t d1;

uint32_t nb01;
uint32_t nb02;
uint32_t nb03;

uint32_t nb11;
uint32_t nb12;
uint32_t nb13;

uint32_t nb1;
uint32_t nb2;
uint32_t nb3;

// init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1
uint32_t KWmp; uint32_t KWL;
uint32_t KWKHmp; uint32_t KWKHL;
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
uint32_t s0mp; uint32_t s0L;
uint32_t s1mp; uint32_t s1L;
};

template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) {
// Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1
init_fastdiv_values(p.KW, p.KWmp, p.KWL);
init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
init_fastdiv_values(p.s0, p.s0mp, p.s0L);
init_fastdiv_values(p.s1, p.s1mp, p.s1L);
}

struct vk_op_conv2d_dw_push_constants {
uint32_t ne;
uint32_t batches;
Expand Down Expand Up @@ -1313,7 +1365,7 @@ class vk_perf_logger {
flops[name].push_back(m * n * (k + (k - 1)) * batch);
return;
}
if (node->op == GGML_OP_CONV_2D) {
if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
std::string name = ggml_op_name(node->op);
ggml_tensor * knl = node->src[0];
uint64_t OW = node->ne[0];
Expand All @@ -1322,7 +1374,7 @@ class vk_perf_logger {
uint64_t Cout = node->ne[2];
uint64_t KW = knl->ne[0];
uint64_t KH = knl->ne[1];
uint64_t Cin = knl->ne[2];
uint64_t Cin = node->src[1]->ne[2];
// KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
uint64_t size_M = Cout;
uint64_t size_K = Cin * KW * KH;
Expand Down Expand Up @@ -3471,7 +3523,7 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

// conv2d
// conv2d, conv_transpose_2d
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
uint32_t conv2d_WG_SIZE = 256;
uint32_t conv2d_BS_K = 128;
Expand Down Expand Up @@ -3546,31 +3598,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };

#define CREATE_CONV(name, type_suffix, spv_suffix) \
ggml_vk_create_pipeline( \
device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \
name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
#define CREATE_CONVS(spv_suffix) \
CREATE_CONV(conv2d, _f32, spv_suffix) \
CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
if (device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_conv_transpose_2d_push_constants)) { \
CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \
CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) \
}
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
CREATE_CONVS(_cm2)
} else
#endif
if (conv2d_UNROLL) {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
CREATE_CONVS(_unroll)
} else {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
CREATE_CONVS( )
}
#undef CREATE_CONV
#undef CREATE_CONVS
}

ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
Expand Down Expand Up @@ -7502,6 +7553,33 @@ static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst)
return elements;
}

static std::array<uint32_t, 3> ggml_vk_get_conv_transpose_2d_elements(const ggml_tensor *dst) {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];

// src0 - kernel: [KW, KH, Cout, Cin]
// src1 - input: [W, H, Cin, N]
// dst - result: [OW, OH, Cout, N]

auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
return (ins - 1) * s - 2 * p + (ks - 1) * d + 1;
};
// parallelize in {OW/BS_K, OH/BS_NPQ, 1}
int64_t W = src1->ne[0];
int64_t H = src1->ne[1];
int64_t KW = src0->ne[0];
int64_t KH = src0->ne[1];
int64_t Cout = src0->ne[2];
int64_t N = src1->ne[3];
int64_t OH = calc_conv_output_size(H, KH, dst->op_params[0], 0, 1);
int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], 0, 1);
int64_t NPQ = N * OW * OH;

// Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
return elements;
}

static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
switch (op) {
case GGML_OP_GET_ROWS:
Expand Down Expand Up @@ -7879,9 +7957,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_CONV_2D:
case GGML_OP_CONV_TRANSPOSE_2D:
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
auto elements = ggml_vk_get_conv_elements(dst);
std::array<uint32_t, 3> elements;
if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst);
else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst);
vk_conv_shapes shape;

uint32_t tiles[CONV_SHAPE_COUNT];
Expand All @@ -7901,10 +7982,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
shape = CONV_SHAPE_64x32;
}

if (src0->type == GGML_TYPE_F32) {
return ctx->device->pipeline_conv2d_f32[shape];
} else if (src0->type == GGML_TYPE_F16) {
return ctx->device->pipeline_conv2d_f16_f32[shape];
if (op == GGML_OP_CONV_2D) {
if (src0->type == GGML_TYPE_F32) {
return ctx->device->pipeline_conv2d_f32[shape];
} else if (src0->type == GGML_TYPE_F16) {
return ctx->device->pipeline_conv2d_f16_f32[shape];
}
} else if (op == GGML_OP_CONV_TRANSPOSE_2D) {
if (src0->type == GGML_TYPE_F32) {
return ctx->device->pipeline_conv_transpose_2d_f32[shape];
} else if (src0->type == GGML_TYPE_F16) {
return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
}
}
}
return nullptr;
Expand Down Expand Up @@ -8304,6 +8393,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
{
elements = ggml_vk_get_conv_elements(dst);
} break;
case GGML_OP_CONV_TRANSPOSE_2D:
{
elements = ggml_vk_get_conv_transpose_2d_elements(dst);
} break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
Expand Down Expand Up @@ -9477,6 +9570,55 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
}

static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);

GGML_TENSOR_BINARY_OP_LOCALS

GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb0 == sizeof(float));

vk_op_conv_transpose_2d_push_constants p{};
p.Cout = static_cast<uint32_t>(ne02);
p.Cin = static_cast<uint32_t>(ne03);
p.N = static_cast<uint32_t>(ne13);

p.KW = static_cast<uint32_t>(ne00);
p.KH = static_cast<uint32_t>(ne01);
p.W = static_cast<uint32_t>(ne10);
p.H = static_cast<uint32_t>(ne11);
p.OW = static_cast<uint32_t>(ne0);
p.OH = static_cast<uint32_t>(ne1);

p.s0 = static_cast<uint32_t>(dst->op_params[0]);
p.s1 = static_cast<uint32_t>(dst->op_params[0]);
p.p0 = 0;
p.p1 = 0;
p.d0 = 1;
p.d1 = 1;

p.nb01 = static_cast<uint32_t>(nb01 / nb00);
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
p.nb03 = static_cast<uint32_t>(nb03 / nb00);

p.nb11 = static_cast<uint32_t>(nb11 / nb10);
p.nb12 = static_cast<uint32_t>(nb12 / nb10);
p.nb13 = static_cast<uint32_t>(nb13 / nb10);

p.nb1 = static_cast<uint32_t>(nb1 / nb0);
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
p.nb3 = static_cast<uint32_t>(nb3 / nb0);

GGML_ASSERT(ne02 == ne2);
GGML_ASSERT(ne03 == ne12);

ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun);
}

static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
vk_op_conv2d_dw_push_constants p{};
p.ne = ggml_nelements(dst);
Expand Down Expand Up @@ -10569,6 +10711,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
Expand Down Expand Up @@ -10640,6 +10783,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_SGD:
Expand Down Expand Up @@ -10951,6 +11095,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_2D:
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun);

break;
case GGML_OP_CONV_TRANSPOSE_2D:
ggml_vk_conv_transpose_2d(ctx, compute_ctx, src0, src1, node, dryrun);

break;
case GGML_OP_CONV_2D_DW:
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
Expand Down Expand Up @@ -11091,6 +11239,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
Expand Down Expand Up @@ -11743,10 +11892,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
} else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) {
} else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D || cgraph->nodes[i]->op == GGML_OP_CONV_TRANSPOSE_2D) {
// Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode.
auto CRS_size =
cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2];
cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[1]->ne[2];
auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3];
total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type);
}
Expand Down Expand Up @@ -12567,10 +12716,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D:
case GGML_OP_CONV_TRANSPOSE_2D:
{
// Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
if (op->op == GGML_OP_CONV_TRANSPOSE_2D &&
device->properties.limits.maxPushConstantsSize < sizeof(vk_op_conv_transpose_2d_push_constants)) {
return false;
}
// Channel-contiguous format is not supported yet.
return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->src[1]->type == GGML_TYPE_F32 &&
Expand Down Expand Up @@ -13175,6 +13329,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
const int32_t d0 = tensor->op_params[4];
const int32_t d1 = tensor->op_params[5];
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
} else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) {
const int32_t s = tensor->op_params[0];
tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s);
} else if (tensor->op == GGML_OP_LEAKY_RELU) {
const float * op_params = (const float *)tensor->op_params;
tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
Expand Down
Loading
Loading