diff --git a/clip.hpp b/clip.hpp
index ec2e1733..f92c9c2f 100644
--- a/clip.hpp
+++ b/clip.hpp
@@ -488,14 +488,14 @@ struct CLIPLayer : public GGMLBlock {
         blocks["mlp"] = std::shared_ptr(new CLIPMLP(d_model, intermediate_size));
     }
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, bool mask = true) {
+    struct ggml_tensor* forward(struct ggml_context* ctx, ggml_backend_t backend, struct ggml_tensor* x, bool mask = true) {
         // x: [N, n_token, d_model]
         auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]);
         auto layer_norm1 = std::dynamic_pointer_cast(blocks["layer_norm1"]);
         auto layer_norm2 = std::dynamic_pointer_cast(blocks["layer_norm2"]);
         auto mlp = std::dynamic_pointer_cast(blocks["mlp"]);
-        x = ggml_add(ctx, x, self_attn->forward(ctx, layer_norm1->forward(ctx, x), mask));
+        x = ggml_add(ctx, x, self_attn->forward(ctx, backend, layer_norm1->forward(ctx, x), mask));
         x = ggml_add(ctx, x, mlp->forward(ctx, layer_norm2->forward(ctx, x)));
         return x;
     }
@@ -517,7 +517,11 @@ struct CLIPEncoder : public GGMLBlock {
         }
     }
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int clip_skip = -1, bool mask = true) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x,
+                                int clip_skip = -1,
+                                bool mask = true) {
         // x: [N, n_token, d_model]
         int layer_idx = n_layer - 1;
         // LOG_DEBUG("clip_skip %d", clip_skip);
@@ -532,7 +536,7 @@ struct CLIPEncoder : public GGMLBlock {
             }
             std::string name = "layers." + std::to_string(i);
             auto layer = std::dynamic_pointer_cast(blocks[name]);
-            x = layer->forward(ctx, x, mask);  // [N, n_token, d_model]
+            x = layer->forward(ctx, backend, x, mask);  // [N, n_token, d_model]
             // LOG_DEBUG("layer %d", i);
         }
         return x;
@@ -712,6 +716,7 @@ class CLIPTextModel : public GGMLBlock {
     }
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* tkn_embeddings,
                                 size_t max_token_idx = 0,
@@ -722,7 +727,7 @@ class CLIPTextModel : public GGMLBlock {
         auto final_layer_norm = std::dynamic_pointer_cast(blocks["final_layer_norm"]);
         auto x = embeddings->forward(ctx, input_ids, tkn_embeddings);  // [N, n_token, hidden_size]
-        x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
+        x = encoder->forward(ctx, backend, x, return_pooled ? -1 : clip_skip, true);
         if (return_pooled || with_final_ln) {
             x = final_layer_norm->forward(ctx, x);
         }
@@ -775,6 +780,7 @@ class CLIPVisionModel : public GGMLBlock {
     }
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* pixel_values,
                                 bool return_pooled = true,
                                 int clip_skip = -1) {
@@ -786,7 +792,7 @@ class CLIPVisionModel : public GGMLBlock {
         auto x = embeddings->forward(ctx, pixel_values);  // [N, num_positions, embed_dim]
         x = pre_layernorm->forward(ctx, x);
-        x = encoder->forward(ctx, x, clip_skip, false);
+        x = encoder->forward(ctx, backend, x, clip_skip, false);
         // print_ggml_tensor(x, true, "ClipVisionModel x: ");
         auto last_hidden_state = x;
         x = post_layernorm->forward(ctx, x);  // [N, n_token, hidden_size]
@@ -855,6 +861,7 @@ class CLIPVisionModelProjection : public GGMLBlock {
     }
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* pixel_values,
                                 bool return_pooled = true,
                                 int clip_skip = -1) {
@@ -863,7 +870,7 @@ class CLIPVisionModelProjection : public GGMLBlock {
         auto vision_model = std::dynamic_pointer_cast(blocks["vision_model"]);
         auto visual_projection = std::dynamic_pointer_cast(blocks["visual_projection"]);
-        auto x = vision_model->forward(ctx, pixel_values, return_pooled, clip_skip);  // [N, hidden_size] or [N, n_token, hidden_size]
+        auto x = vision_model->forward(ctx, backend, pixel_values, return_pooled, clip_skip);  // [N, hidden_size] or [N, n_token, hidden_size]
         if (return_pooled) {
             x = visual_projection->forward(ctx, x);  // [N, projection_dim]
@@ -900,6 +907,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
     }
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* embeddings,
                                 size_t max_token_idx = 0,
@@ -911,7 +919,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
             input_ids = ggml_reshape_2d(ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
         }
-        return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled);
+        return model.forward(ctx, backend, input_ids, embeddings, max_token_idx, return_pooled);
     }
     struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
@@ -937,7 +945,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
             embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
         }
-        struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, embeddings, max_token_idx, return_pooled);
+        struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, embeddings, max_token_idx, return_pooled);
         ggml_build_forward_expand(gf, hidden_states);
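The clip.hpp changes above establish the pattern every other file in this patch follows: a ggml_backend_t is threaded from the graph-building code down through each block's forward() until it reaches ggml_nn_attention_ext, which uses it to decide whether flash attention can actually run. A minimal sketch of that plumbing, with a hypothetical MyAttnBlock standing in for any of the blocks touched here (not part of the patch; the call shape mirrors CrossAttention::forward in common.hpp below):

    // Hypothetical block: `backend` is only carried along and finally handed to the
    // attention helper; no other behaviour changes.
    struct MyAttnBlock : public GGMLBlock {
        int n_head = 8;

        struct ggml_tensor* forward(struct ggml_context* ctx,
                                    ggml_backend_t backend,  // new parameter
                                    struct ggml_tensor* q,
                                    struct ggml_tensor* k,
                                    struct ggml_tensor* v) {
            // helper needs the backend to test flash-attention support per op
            return ggml_nn_attention_ext(ctx, backend, q, k, v, n_head, NULL, false, false, true);
        }
    };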
diff --git a/common.hpp b/common.hpp
index 3a130776..bf4da24e 100644
--- a/common.hpp
+++ b/common.hpp
@@ -270,7 +270,10 @@ class CrossAttention : public GGMLBlock {
         // to_out_1 is nn.Dropout(), skip for inference
     }
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x,
+                                struct ggml_tensor* context) {
         // x: [N, n_token, query_dim]
         // context: [N, n_context, context_dim]
         // return: [N, n_token, query_dim]
@@ -288,7 +291,7 @@ class CrossAttention : public GGMLBlock {
         auto k = to_k->forward(ctx, context);  // [N, n_context, inner_dim]
         auto v = to_v->forward(ctx, context);  // [N, n_context, inner_dim]
-        x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn);  // [N, n_token, inner_dim]
+        x = ggml_nn_attention_ext(ctx, backend, q, k, v, n_head, NULL, false, false, flash_attn);  // [N, n_token, inner_dim]
         x = to_out_0->forward(ctx, x);  // [N, n_token, query_dim]
         return x;
@@ -327,7 +330,10 @@ class BasicTransformerBlock : public GGMLBlock {
         }
     }
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x,
+                                struct ggml_tensor* context) {
         // x: [N, n_token, query_dim]
         // context: [N, n_context, context_dim]
         // return: [N, n_token, query_dim]
@@ -352,11 +358,11 @@ class BasicTransformerBlock : public GGMLBlock {
         auto r = x;
         x = norm1->forward(ctx, x);
-        x = attn1->forward(ctx, x, x);  // self-attention
+        x = attn1->forward(ctx, backend, x, x);  // self-attention
         x = ggml_add(ctx, x, r);
         r = x;
         x = norm2->forward(ctx, x);
-        x = attn2->forward(ctx, x, context);  // cross-attention
+        x = attn2->forward(ctx, backend, x, context);  // cross-attention
         x = ggml_add(ctx, x, r);
         r = x;
         x = norm3->forward(ctx, x);
@@ -401,7 +407,10 @@ class SpatialTransformer : public GGMLBlock {
         blocks["proj_out"] = std::shared_ptr(new Conv2d(inner_dim, in_channels, {1, 1}));
     }
-    virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
+    virtual struct ggml_tensor* forward(struct ggml_context* ctx,
+                                        ggml_backend_t backend,
+                                        struct ggml_tensor* x,
+                                        struct ggml_tensor* context) {
         // x: [N, in_channels, h, w]
         // context: [N, max_position(aka n_token), hidden_size(aka context_dim)]
         auto norm = std::dynamic_pointer_cast(blocks["norm"]);
@@ -424,7 +433,7 @@ class SpatialTransformer : public GGMLBlock {
             std::string name = "transformer_blocks." + std::to_string(i);
             auto transformer_block = std::dynamic_pointer_cast(blocks[name]);
-            x = transformer_block->forward(ctx, x, context);
+            x = transformer_block->forward(ctx, backend, x, context);
         }
         x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));  // [N, inner_dim, h * w]
diff --git a/conditioner.hpp b/conditioner.hpp
index d01b1c6e..cfd2b4ca 100644
--- a/conditioner.hpp
+++ b/conditioner.hpp
@@ -639,7 +639,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
         pixel_values = to_backend(pixel_values);
-        struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values, return_pooled, clip_skip);
+        struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, runtime_backend, pixel_values, return_pooled, clip_skip);
         ggml_build_forward_expand(gf, hidden_states);
diff --git a/control.hpp b/control.hpp
index 094dd124..f9a49235 100644
--- a/control.hpp
+++ b/control.hpp
@@ -174,10 +174,11 @@ class ControlNetBlock : public GGMLBlock {
     struct ggml_tensor* attention_layer_forward(std::string name,
                                                 struct ggml_context* ctx,
+                                                ggml_backend_t backend,
                                                 struct ggml_tensor* x,
                                                 struct ggml_tensor* context) {
         auto block = std::dynamic_pointer_cast(blocks[name]);
-        return block->forward(ctx, x, context);
+        return block->forward(ctx, backend, x, context);
     }
     struct ggml_tensor* input_hint_block_forward(struct ggml_context* ctx,
@@ -199,6 +200,7 @@ class ControlNetBlock : public GGMLBlock {
     }
     std::vector forward(struct ggml_context* ctx,
+                        ggml_backend_t backend,
                         struct ggml_tensor* x,
                         struct ggml_tensor* hint,
                         struct ggml_tensor* guided_hint,
@@ -272,7 +274,7 @@ class ControlNetBlock : public GGMLBlock {
             h = resblock_forward(name, ctx, h, emb);  // [N, mult*model_channels, h, w]
             if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
                 std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
-                h = attention_layer_forward(name, ctx, h, context);  // [N, mult*model_channels, h, w]
+                h = attention_layer_forward(name, ctx, backend, h, context);  // [N, mult*model_channels, h, w]
             }
             auto zero_conv = std::dynamic_pointer_cast(blocks["zero_convs." + std::to_string(input_block_idx) + ".0"]);
@@ -296,9 +298,9 @@ class ControlNetBlock : public GGMLBlock {
         // [N, 4*model_channels, h/8, w/8]
         // middle_block
-        h = resblock_forward("middle_block.0", ctx, h, emb);             // [N, 4*model_channels, h/8, w/8]
-        h = attention_layer_forward("middle_block.1", ctx, h, context);  // [N, 4*model_channels, h/8, w/8]
-        h = resblock_forward("middle_block.2", ctx, h, emb);             // [N, 4*model_channels, h/8, w/8]
+        h = resblock_forward("middle_block.0", ctx, h, emb);                      // [N, 4*model_channels, h/8, w/8]
+        h = attention_layer_forward("middle_block.1", ctx, backend, h, context);  // [N, 4*model_channels, h/8, w/8]
+        h = resblock_forward("middle_block.2", ctx, h, emb);                      // [N, 4*model_channels, h/8, w/8]
         // out
         outs.push_back(middle_block_out->forward(ctx, h));
@@ -403,6 +405,7 @@ struct ControlNet : public GGMLRunner {
         timesteps = to_backend(timesteps);
         auto outs = control_net.forward(compute_ctx,
+                                        runtime_backend,
                                         x,
                                         hint,
                                         guided_hint_cached ? guided_hint : NULL,
diff --git a/flux.hpp b/flux.hpp
index 044ea82a..ae0cd375 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -114,6 +114,7 @@ namespace Flux {
     }
     __STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
+                                                    ggml_backend_t backend,
                                                     struct ggml_tensor* q,
                                                     struct ggml_tensor* k,
                                                     struct ggml_tensor* v,
@@ -126,7 +127,7 @@ namespace Flux {
         q = apply_rope(ctx, q, pe);  // [N*n_head, L, d_head]
         k = apply_rope(ctx, k, pe);  // [N*n_head, L, d_head]
-        auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], mask, false, true, flash_attn);  // [N, L, n_head*d_head]
+        auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn);  // [N, L, n_head*d_head]
         return x;
     }
@@ -169,13 +170,17 @@ namespace Flux {
            return x;
        }
-       struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask) {
+       struct ggml_tensor* forward(struct ggml_context* ctx,
+                                   ggml_backend_t backend,
+                                   struct ggml_tensor* x,
+                                   struct ggml_tensor* pe,
+                                   struct ggml_tensor* mask) {
            // x: [N, n_token, dim]
            // pe: [n_token, d_head/2, 2, 2]
            // return [N, n_token, dim]
-           auto qkv = pre_attention(ctx, x);                                         // q,k,v: [N, n_token, n_head, d_head]
-           x        = attention(ctx, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn);  // [N, n_token, dim]
-           x        = post_attention(ctx, x);                                        // [N, n_token, dim]
+           auto qkv = pre_attention(ctx, x);                                                  // q,k,v: [N, n_token, n_head, d_head]
+           x        = attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn);  // [N, n_token, dim]
+           x        = post_attention(ctx, x);                                                 // [N, n_token, dim]
            return x;
        }
    };
@@ -299,6 +304,7 @@ namespace Flux {
        }
        std::pair forward(struct ggml_context* ctx,
+                         ggml_backend_t backend,
                          struct ggml_tensor* img,
                          struct ggml_tensor* txt,
                          struct ggml_tensor* vec,
@@ -362,8 +368,8 @@ namespace Flux {
            auto k = ggml_concat(ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
            auto v = ggml_concat(ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
-           auto attn = attention(ctx, q, k, v, pe, mask, flash_attn);        // [N, n_txt_token + n_img_token, n_head*d_head]
-           attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
+           auto attn = attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_txt_token + n_img_token, n_head*d_head]
+           attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));     // [n_txt_token + n_img_token, N, hidden_size]
            auto txt_attn_out = ggml_view_3d(ctx,
                                             attn,
                                             attn->ne[0],
@@ -446,6 +452,7 @@ namespace Flux {
        }
        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                   ggml_backend_t backend,
                                    struct ggml_tensor* x,
                                    struct ggml_tensor* vec,
                                    struct ggml_tensor* pe,
@@ -496,7 +503,7 @@ namespace Flux {
            auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);  // [N, n_token, n_head, d_head]
            q = norm->query_norm(ctx, q);
            k = norm->key_norm(ctx, k);
-           auto attn = attention(ctx, q, k, v, pe, mask, flash_attn);  // [N, n_token, hidden_size]
+           auto attn = attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_token, hidden_size]
            auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0);  // [N, n_token, hidden_size + mlp_hidden_dim]
            auto output = linear2->forward(ctx, attn_mlp);  // [N, n_token, hidden_size]
@@ -699,6 +706,7 @@ namespace Flux {
        }
        struct ggml_tensor* forward_orig(struct ggml_context* ctx,
+                                        ggml_backend_t backend,
                                         struct ggml_tensor* img,
                                         struct ggml_tensor* txt,
                                         struct ggml_tensor* timesteps,
@@ -763,7 +771,7 @@ namespace Flux {
            auto block = std::dynamic_pointer_cast(blocks["double_blocks." + std::to_string(i)]);
-           auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask);
+           auto img_txt = block->forward(ctx, backend, img, txt, vec, pe, txt_img_mask);
            img = img_txt.first;   // [N, n_img_token, hidden_size]
            txt = img_txt.second;  // [N, n_txt_token, hidden_size]
        }
@@ -775,7 +783,7 @@ namespace Flux {
            }
            auto block = std::dynamic_pointer_cast(blocks["single_blocks." + std::to_string(i)]);
-           txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask);
+           txt_img = block->forward(ctx, backend, txt_img, vec, pe, txt_img_mask);
        }
        txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
@@ -808,6 +816,7 @@ namespace Flux {
        }
        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                   ggml_backend_t backend,
                                    struct ggml_tensor* x,
                                    struct ggml_tensor* timestep,
                                    struct ggml_tensor* context,
@@ -857,7 +866,7 @@ namespace Flux {
                }
            }
-           auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers);  // [N, num_tokens, C * patch_size * patch_size]
+           auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers);  // [N, num_tokens, C * patch_size * patch_size]
            if (out->ne[1] > img_tokens) {
                out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3));  // [num_tokens, N, C * patch_size * patch_size]
                out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
@@ -1001,6 +1010,7 @@ namespace Flux {
        set_backend_tensor_data(pe, pe_vec.data());
        struct ggml_tensor* out = flux.forward(compute_ctx,
+                                              runtime_backend,
                                               x,
                                               timesteps,
                                               context,
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 09657847..560d2861 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1000,6 +1000,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
 // mask: [N, L_q, L_k]
 // return: [N, L_q, C]
 __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
+                                                            ggml_backend_t backend,
                                                             struct ggml_tensor* q,
                                                             struct ggml_tensor* k,
                                                             struct ggml_tensor* v,
@@ -1038,69 +1039,74 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     float scale = (1.0f / sqrt((float)d_head));
-    int kv_pad = 0;
-    if (flash_attn) {
-        // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-        bool can_use_flash_attn = true;
-        if (can_use_flash_attn && L_k % 256 != 0) {
-            kv_pad = GGML_PAD(L_k, 256) - L_k;
-        }
-
-        if (mask != nullptr) {
-            // TODO(Green-Sky): figure out if we can bend t5 to work too
-            can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
-        }
-
-        if (!can_use_flash_attn) {
-            flash_attn = false;
-        }
-    }
-
+    int kv_pad = 0;
     ggml_tensor* kqv = nullptr;
-    if (flash_attn) {
-        // LOG_DEBUG(" uses flash attention");
+
+    auto build_kqv = [&](ggml_tensor* q_in, ggml_tensor* k_in, ggml_tensor* v_in, ggml_tensor* mask_in) -> ggml_tensor* {
         if (kv_pad != 0) {
-            // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
-            k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
+            k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
         }
-        k = ggml_cast(ctx, k, GGML_TYPE_F16);
+        k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
-        v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
-        v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);     // [N * n_head, L_k, d_head]
+        v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
+        v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_head * N);
         if (kv_pad != 0) {
-            v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
+            v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
         }
-        v = ggml_cast(ctx, v, GGML_TYPE_F16);
+        v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
-        if (mask != nullptr) {
-            mask = ggml_transpose(ctx, mask);
+        if (mask_in != nullptr) {
+            mask_in = ggml_transpose(ctx, mask_in);
         } else {
             if (kv_pad > 0) {
-                mask            = ggml_zeros(ctx, L_k, L_q, 1, 1);               // [L_q, L_k]
-                auto pad_tensor = ggml_full(ctx, -INFINITY, kv_pad, L_q, 1, 1);  // [L_q, kv_pad]
-                mask            = ggml_concat(ctx, mask, pad_tensor, 0);         // [L_q, L_k + kv_pad]
+                mask_in         = ggml_zeros(ctx, L_k, L_q, 1, 1);
+                auto pad_tensor = ggml_full(ctx, -INFINITY, kv_pad, L_q, 1, 1);
+                mask_in         = ggml_concat(ctx, mask_in, pad_tensor, 0);
             }
         }
-        // mask pad
-        if (mask != nullptr) {
+        if (mask_in != nullptr) {
             int mask_pad = 0;
-            if (mask->ne[1] % GGML_KQ_MASK_PAD != 0) {
-                mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask->ne[1];
+            if (mask_in->ne[1] % GGML_KQ_MASK_PAD != 0) {
+                mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask_in->ne[1];
            }
            if (mask_pad > 0) {
-                mask = ggml_pad(ctx, mask, 0, mask_pad, 0, 0);  // [L_q + mask_pad, L_k + kv_pad]
+                mask_in = ggml_pad(ctx, mask_in, 0, mask_pad, 0, 0);
            }
-            mask = ggml_cast(ctx, mask, GGML_TYPE_F16);
-            // LOG_DEBUG("L_k: %ld, L_q: %ld, mask->ne[1]: %ld, mask_pad: %d, kv_pad: %d", L_k, L_q, mask->ne[1], mask_pad, kv_pad);
+            mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
        }
-        kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
-        ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
+        auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale, 0, 0);
+        ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
+        return out;
+    };
-        // kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
-        kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0);
-    } else {
+    if (flash_attn) {
+        // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+        bool can_use_flash_attn = true;
+        if (can_use_flash_attn && L_k % 256 != 0) {
+            kv_pad = GGML_PAD(L_k, 256) - L_k;
+        }
+
+        if (mask != nullptr) {
+            // TODO: figure out if we can bend t5 to work too
+            can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
+        }
+
+        if (can_use_flash_attn) {
+            kqv = build_kqv(q, k, v, mask);
+            if (!ggml_backend_supports_op(backend, kqv)) {
+                kqv = nullptr;
+            } else {
+                kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0);
+            }
+        }
+    }
+
+    if (kqv == nullptr) {
+        // if (flash_attn) {
+        //     LOG_DEBUG("fallback to default attention, L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+        // }
         v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, n_head, d_head, L_k]
         v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N);     // [N * n_head, d_head, L_k]
@@ -2164,7 +2170,10 @@ class MultiheadAttention : public GGMLBlock {
     }
     // x: [N, n_token, embed_dim]
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, bool mask = false) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x,
+                                bool mask = false) {
         auto q_proj = std::dynamic_pointer_cast(blocks[q_proj_name]);
         auto k_proj = std::dynamic_pointer_cast(blocks[k_proj_name]);
         auto v_proj = std::dynamic_pointer_cast(blocks[v_proj_name]);
@@ -2174,7 +2183,7 @@ class MultiheadAttention : public GGMLBlock {
         struct ggml_tensor* k = k_proj->forward(ctx, x);
         struct ggml_tensor* v = v_proj->forward(ctx, x);
-        x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, mask);  // [N, n_token, embed_dim]
+        x = ggml_nn_attention_ext(ctx, backend, q, k, v, n_head, NULL, mask);  // [N, n_token, embed_dim]
         x = out_proj->forward(ctx, x);  // [N, n_token, embed_dim]
         return x;
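The ggml_extend.hpp hunk above is the decisive piece of the patch: instead of deciding from tensor shapes alone, ggml_nn_attention_ext now builds the flash-attention node first and keeps it only if the backend reports support for it; otherwise it rebuilds the ordinary soft-max attention path. A condensed sketch of that decision (not the patch's code; try_flash_attn is a hypothetical name, and the K/V padding and mask preparation that build_kqv performs are omitted):

    // Simplified illustration of the support check. q is F32 [d_head, L_q, n_head];
    // k/v are [d_head, L_k, n_head]; mask is either NULL or already padded and cast
    // to F16 the way build_kqv does it.
    static bool try_flash_attn(struct ggml_context* ctx,
                               ggml_backend_t backend,
                               struct ggml_tensor* q,
                               struct ggml_tensor* k,
                               struct ggml_tensor* v,
                               struct ggml_tensor* mask,
                               float scale,
                               struct ggml_tensor** out) {
        // flash attention wants F16 K/V
        k = ggml_cast(ctx, k, GGML_TYPE_F16);
        v = ggml_cast(ctx, v, GGML_TYPE_F16);

        // build the op first ...
        struct ggml_tensor* kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
        ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);

        // ... then ask the runtime backend whether it can execute it; if not, the
        // caller falls back to the mul_mat + soft_max attention path instead.
        if (!ggml_backend_supports_op(backend, kqv)) {
            return false;  // the unused node is simply never added to the graph
        }
        *out = kqv;
        return true;
    }

Because the check happens per op at graph-build time, flash attention stays enabled wherever the backend and head sizes allow it and falls back transparently elsewhere, which is exactly what the kqv == nullptr branch above does.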
diff --git a/mmdit.hpp b/mmdit.hpp
index 904cda47..acb55e60 100644
--- a/mmdit.hpp
+++ b/mmdit.hpp
@@ -202,10 +202,12 @@ class SelfAttention : public GGMLBlock {
     }
     // x: [N, n_token, dim]
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x) {
         auto qkv = pre_attention(ctx, x);
-        x        = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
-        x        = post_attention(ctx, x);                                         // [N, n_token, dim]
+        x        = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
+        x        = post_attention(ctx, x);                                                  // [N, n_token, dim]
         return x;
     }
 };
@@ -415,7 +417,10 @@ struct DismantledBlock : public GGMLBlock {
         return x;
     }
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* c) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x,
+                                struct ggml_tensor* c) {
         // x: [N, n_token, hidden_size]
         // c: [N, hidden_size]
         // return: [N, n_token, hidden_size]
@@ -430,8 +435,8 @@ struct DismantledBlock : public GGMLBlock {
             auto qkv2 = std::get<1>(qkv_intermediates);
             auto intermediates = std::get<2>(qkv_intermediates);
-            auto attn_out  = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);     // [N, n_token, dim]
-            auto attn2_out = ggml_nn_attention_ext(ctx, qkv2[0], qkv2[1], qkv2[2], num_heads);  // [N, n_token, dim]
+            auto attn_out  = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);     // [N, n_token, dim]
+            auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads);  // [N, n_token, dim]
             x = post_attention_x(ctx,
                                  attn_out,
                                  attn2_out,
@@ -447,7 +452,7 @@ struct DismantledBlock : public GGMLBlock {
             auto qkv = qkv_intermediates.first;
             auto intermediates = qkv_intermediates.second;
-            auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
+            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
             x = post_attention(ctx,
                                attn_out,
                                intermediates[0],
@@ -462,6 +467,7 @@ struct DismantledBlock : public GGMLBlock {
 __STATIC_INLINE__ std::pair
 block_mixing(struct ggml_context* ctx,
+             ggml_backend_t backend,
              struct ggml_tensor* context,
              struct ggml_tensor* x,
              struct ggml_tensor* c,
@@ -491,8 +497,8 @@ block_mixing(struct ggml_context* ctx,
         qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
     }
-    auto attn = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], x_block->num_heads);  // [N, n_context + n_token, hidden_size]
-    attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));                     // [n_context + n_token, N, hidden_size]
+    auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads);  // [N, n_context + n_token, hidden_size]
+    attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));                              // [n_context + n_token, N, hidden_size]
     auto context_attn = ggml_view_3d(ctx,
                                      attn,
                                      attn->ne[0],
@@ -525,7 +531,7 @@ block_mixing(struct ggml_context* ctx,
     }
     if (x_block->self_attn) {
-        auto attn2 = ggml_nn_attention_ext(ctx, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads);  // [N, n_token, hidden_size]
+        auto attn2 = ggml_nn_attention_ext(ctx, backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads);  // [N, n_token, hidden_size]
         x = x_block->post_attention_x(ctx,
                                       x_attn,
@@ -563,13 +569,14 @@ struct JointBlock : public GGMLBlock {
     }
     std::pair forward(struct ggml_context* ctx,
+                      ggml_backend_t backend,
                       struct ggml_tensor* context,
                       struct ggml_tensor* x,
                       struct ggml_tensor* c) {
         auto context_block = std::dynamic_pointer_cast(blocks["context_block"]);
         auto x_block = std::dynamic_pointer_cast(blocks["x_block"]);
-        return block_mixing(ctx, context, x, c, context_block, x_block);
+        return block_mixing(ctx, backend, context, x, c, context_block, x_block);
     }
 };
@@ -771,6 +778,7 @@ struct MMDiT : public GGMLBlock {
     }
     struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx,
+                                                 ggml_backend_t backend,
                                                  struct ggml_tensor* x,
                                                  struct ggml_tensor* c_mod,
                                                  struct ggml_tensor* context,
@@ -789,7 +797,7 @@ struct MMDiT : public GGMLBlock {
             auto block = std::dynamic_pointer_cast(blocks["joint_blocks." + std::to_string(i)]);
-            auto context_x = block->forward(ctx, context, x, c_mod);
+            auto context_x = block->forward(ctx, backend, context, x, c_mod);
             context = context_x.first;
             x = context_x.second;
         }
@@ -800,6 +808,7 @@ struct MMDiT : public GGMLBlock {
     }
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* x,
                                 struct ggml_tensor* t,
                                 struct ggml_tensor* y = NULL,
@@ -835,7 +844,7 @@ struct MMDiT : public GGMLBlock {
             context = context_embedder->forward(ctx, context);  // [N, L, D] aka [N, L, 1536]
         }
-        x = forward_core_with_concat(ctx, x, c, context, skip_layers);  // (N, H*W, patch_size ** 2 * out_channels)
+        x = forward_core_with_concat(ctx, backend, x, c, context, skip_layers);  // (N, H*W, patch_size ** 2 * out_channels)
         x = unpatchify(ctx, x, h, w);  // [N, C, H, W]
@@ -874,6 +883,7 @@ struct MMDiTRunner : public GGMLRunner {
         timesteps = to_backend(timesteps);
         struct ggml_tensor* out = mmdit.forward(compute_ctx,
+                                                runtime_backend,
                                                 x,
                                                 timesteps,
                                                 y,
diff --git a/pmid.hpp b/pmid.hpp
index 9b725ded..5e9b0d5b 100644
--- a/pmid.hpp
+++ b/pmid.hpp
@@ -508,6 +508,7 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection {
     }
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* id_pixel_values,
                                 struct ggml_tensor* prompt_embeds,
                                 struct ggml_tensor* class_tokens_mask,
@@ -520,9 +521,9 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection {
         auto visual_projection_2 = std::dynamic_pointer_cast(blocks["visual_projection_2"]);
         auto fuse_module = std::dynamic_pointer_cast(blocks["fuse_module"]);
-        struct ggml_tensor* shared_id_embeds = vision_model->forward(ctx, id_pixel_values);           // [N, hidden_size]
-        struct ggml_tensor* id_embeds        = visual_projection->forward(ctx, shared_id_embeds);     // [N, proj_dim(768)]
-        struct ggml_tensor* id_embeds_2      = visual_projection_2->forward(ctx, shared_id_embeds);   // [N, 1280]
+        struct ggml_tensor* shared_id_embeds = vision_model->forward(ctx, backend, id_pixel_values);  // [N, hidden_size]
+        struct ggml_tensor* id_embeds        = visual_projection->forward(ctx, shared_id_embeds);     // [N, proj_dim(768)]
+        struct ggml_tensor* id_embeds_2      = visual_projection_2->forward(ctx, shared_id_embeds);   // [N, 1280]
         id_embeds = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3));
         id_embeds_2 = ggml_cont(ctx, ggml_permute(ctx, id_embeds_2, 2, 0, 1, 3));
@@ -579,6 +580,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
     */
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* id_pixel_values,
                                 struct ggml_tensor* prompt_embeds,
                                 struct ggml_tensor* class_tokens_mask,
@@ -592,7 +594,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
         auto qformer_perceiver = std::dynamic_pointer_cast(blocks["qformer_perceiver"]);
         // struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values);  // [N, hidden_size]
-        struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values, false);  // [N, hidden_size]
+        struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, backend, id_pixel_values, false);  // [N, hidden_size]
         id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state);
         struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx,
@@ -742,6 +744,7 @@ struct PhotoMakerIDEncoder : public GGMLRunner {
         struct ggml_tensor* updated_prompt_embeds = NULL;
         if (pm_version == PM_VERSION_1)
             updated_prompt_embeds = id_encoder.forward(ctx0,
+                                                       runtime_backend,
                                                        id_pixel_values_d,
                                                        prompt_embeds_d,
                                                        class_tokens_mask_d,
@@ -749,6 +752,7 @@ struct PhotoMakerIDEncoder : public GGMLRunner {
                                                        left, right);
         else if (pm_version == PM_VERSION_2)
             updated_prompt_embeds = id_encoder2.forward(ctx0,
+                                                        runtime_backend,
                                                         id_pixel_values_d,
                                                         prompt_embeds_d,
                                                         class_tokens_mask_d,
diff --git a/t5.hpp b/t5.hpp
index f149dade..062e37bb 100644
--- a/t5.hpp
+++ b/t5.hpp
@@ -578,6 +578,7 @@ class T5Attention : public GGMLBlock {
     // x: [N, n_token, model_dim]
     std::pair forward(struct ggml_context* ctx,
+                      ggml_backend_t backend,
                      struct ggml_tensor* x,
                      struct ggml_tensor* past_bias = NULL,
                      struct ggml_tensor* mask = NULL,
@@ -608,7 +609,7 @@ class T5Attention : public GGMLBlock {
         k = ggml_scale_inplace(ctx, k, sqrt(d_head));
-        x = ggml_nn_attention_ext(ctx, q, k, v, num_heads, mask);  // [N, n_token, d_head * n_head]
+        x = ggml_nn_attention_ext(ctx, backend, q, k, v, num_heads, mask);  // [N, n_token, d_head * n_head]
         x = out_proj->forward(ctx, x);  // [N, n_token, model_dim]
         return {x, past_bias};
@@ -627,6 +628,7 @@ struct T5LayerSelfAttention : public GGMLBlock {
     }
     std::pair forward(struct ggml_context* ctx,
+                      ggml_backend_t backend,
                      struct ggml_tensor* x,
                      struct ggml_tensor* past_bias = NULL,
                      struct ggml_tensor* mask = NULL,
@@ -636,7 +638,7 @@ struct T5LayerSelfAttention : public GGMLBlock {
         auto layer_norm = std::dynamic_pointer_cast(blocks["layer_norm"]);
         auto normed_hidden_state = layer_norm->forward(ctx, x);
-        auto ret = SelfAttention->forward(ctx, normed_hidden_state, past_bias, mask, relative_position_bucket);
+        auto ret = SelfAttention->forward(ctx, backend, normed_hidden_state, past_bias, mask, relative_position_bucket);
         auto output = ret.first;
         past_bias = ret.second;
@@ -653,6 +655,7 @@ struct T5Block : public GGMLBlock {
     }
     std::pair forward(struct ggml_context* ctx,
+                      ggml_backend_t backend,
                      struct ggml_tensor* x,
                      struct ggml_tensor* past_bias = NULL,
                      struct ggml_tensor* mask = NULL,
@@ -661,7 +664,7 @@ struct T5Block : public GGMLBlock {
         auto layer_0 = std::dynamic_pointer_cast(blocks["layer.0"]);
         auto layer_1 = std::dynamic_pointer_cast(blocks["layer.1"]);
-        auto ret = layer_0->forward(ctx, x, past_bias, mask, relative_position_bucket);
+        auto ret = layer_0->forward(ctx, backend, x, past_bias, mask, relative_position_bucket);
         x = ret.first;
         past_bias = ret.second;
         x = layer_1->forward(ctx, x);
@@ -688,6 +691,7 @@ struct T5Stack : public GGMLBlock {
     }
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* x,
                                 struct ggml_tensor* past_bias = NULL,
                                 struct ggml_tensor* attention_mask = NULL,
@@ -696,7 +700,7 @@ struct T5Stack : public GGMLBlock {
         for (int i = 0; i < num_layers; i++) {
             auto block = std::dynamic_pointer_cast(blocks["block." + std::to_string(i)]);
-            auto ret = block->forward(ctx, x, past_bias, attention_mask, relative_position_bucket);
+            auto ret = block->forward(ctx, backend, x, past_bias, attention_mask, relative_position_bucket);
             x = ret.first;
             past_bias = ret.second;
         }
@@ -735,6 +739,7 @@ struct T5 : public GGMLBlock {
     }
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* past_bias = NULL,
                                 struct ggml_tensor* attention_mask = NULL,
@@ -745,7 +750,7 @@ struct T5 : public GGMLBlock {
         auto encoder = std::dynamic_pointer_cast(blocks["encoder"]);
         auto x = shared->forward(ctx, input_ids);
-        x = encoder->forward(ctx, x, past_bias, attention_mask, relative_position_bucket);
+        x = encoder->forward(ctx, backend, x, past_bias, attention_mask, relative_position_bucket);
         return x;
     }
 };
@@ -778,13 +783,14 @@ struct T5Runner : public GGMLRunner {
     }
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* relative_position_bucket,
                                 struct ggml_tensor* attention_mask = NULL) {
         size_t N = input_ids->ne[1];
         size_t n_token = input_ids->ne[0];
-        auto hidden_states = model.forward(ctx, input_ids, NULL, attention_mask, relative_position_bucket);  // [N, n_token, model_dim]
+        auto hidden_states = model.forward(ctx, backend, input_ids, NULL, attention_mask, relative_position_bucket);  // [N, n_token, model_dim]
         return hidden_states;
     }
@@ -810,7 +816,7 @@ struct T5Runner : public GGMLRunner {
                                               input_ids->ne[0]);
         set_backend_tensor_data(relative_position_bucket, relative_position_bucket_vec.data());
-        struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, relative_position_bucket, attention_mask);
+        struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, relative_position_bucket, attention_mask);
         ggml_build_forward_expand(gf, hidden_states);
diff --git a/unet.hpp b/unet.hpp
index 7e7b2277..19bedb32 100644
--- a/unet.hpp
+++ b/unet.hpp
@@ -61,6 +61,7 @@ class SpatialVideoTransformer : public SpatialTransformer {
     }
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* x,
                                 struct ggml_tensor* context,
                                 int timesteps) {
@@ -127,7 +128,7 @@ class SpatialVideoTransformer : public SpatialTransformer {
         auto block = std::dynamic_pointer_cast(blocks[transformer_name]);
         auto mix_block = std::dynamic_pointer_cast(blocks[time_stack_name]);
-        x = block->forward(ctx, x, spatial_context);  // [N, h * w, inner_dim]
+        x = block->forward(ctx, backend, x, spatial_context);  // [N, h * w, inner_dim]
         // in_channels == inner_dim
         auto x_mix = x;
@@ -143,7 +144,7 @@ class SpatialVideoTransformer : public SpatialTransformer {
         x_mix = ggml_cont(ctx, ggml_permute(ctx, x_mix, 0, 2, 1, 3));  // b t s c -> b s t c
         x_mix = ggml_reshape_3d(ctx, x_mix, C, T, S * B);              // b s t c -> (b s) t c
-        x_mix = mix_block->forward(ctx, x_mix, time_context);  // [B * h * w, T, inner_dim]
+        x_mix = mix_block->forward(ctx, backend, x_mix, time_context);  // [B * h * w, T, inner_dim]
         x_mix = ggml_reshape_4d(ctx, x_mix, C, T, S, B);               // (b s) t c -> b s t c
         x_mix = ggml_cont(ctx, ggml_permute(ctx, x_mix, 0, 2, 1, 3));  // b s t c -> b t s c
@@ -363,21 +364,23 @@ class UnetModelBlock : public GGMLBlock {
     struct ggml_tensor* attention_layer_forward(std::string name,
                                                 struct ggml_context* ctx,
+                                                ggml_backend_t backend,
                                                 struct ggml_tensor* x,
                                                 struct ggml_tensor* context,
                                                 int timesteps) {
         if (version == VERSION_SVD) {
             auto block = std::dynamic_pointer_cast(blocks[name]);
-            return block->forward(ctx, x, context, timesteps);
+            return block->forward(ctx, backend, x, context, timesteps);
         } else {
             auto block = std::dynamic_pointer_cast(blocks[name]);
-            return block->forward(ctx, x, context);
+            return block->forward(ctx, backend, x, context);
         }
     }
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* x,
                                 struct ggml_tensor* timesteps,
                                 struct ggml_tensor* context,
@@ -456,7 +459,7 @@ class UnetModelBlock : public GGMLBlock {
             h = resblock_forward(name, ctx, h, emb, num_video_frames);  // [N, mult*model_channels, h, w]
             if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
                 std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
-                h = attention_layer_forward(name, ctx, h, context, num_video_frames);  // [N, mult*model_channels, h, w]
+                h = attention_layer_forward(name, ctx, backend, h, context, num_video_frames);  // [N, mult*model_channels, h, w]
             }
             hs.push_back(h);
         }
@@ -474,9 +477,9 @@ class UnetModelBlock : public GGMLBlock {
         // [N, 4*model_channels, h/8, w/8]
         // middle_block
-        h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames);             // [N, 4*model_channels, h/8, w/8]
-        h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames);  // [N, 4*model_channels, h/8, w/8]
-        h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames);             // [N, 4*model_channels, h/8, w/8]
+        h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames);                      // [N, 4*model_channels, h/8, w/8]
+        h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames);  // [N, 4*model_channels, h/8, w/8]
+        h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames);                      // [N, 4*model_channels, h/8, w/8]
         if (controls.size() > 0) {
             auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength);
@@ -507,7 +510,7 @@ class UnetModelBlock : public GGMLBlock {
             if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
                 std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1";
-                h = attention_layer_forward(name, ctx, h, context, num_video_frames);
+                h = attention_layer_forward(name, ctx, backend, h, context, num_video_frames);
                 up_sample_idx++;
             }
@@ -592,6 +595,7 @@ struct UNetModelRunner : public GGMLRunner {
         }
         struct ggml_tensor* out = unet.forward(compute_ctx,
+                                               runtime_backend,
                                                x,
                                                timesteps,
                                                context,
diff --git a/wan.hpp b/wan.hpp
index d385cac7..48603a95 100644
--- a/wan.hpp
+++ b/wan.hpp
@@ -1306,6 +1306,7 @@ namespace WAN {
        }
        virtual struct ggml_tensor* forward(struct ggml_context* ctx,
+                                           ggml_backend_t backend,
                                            struct ggml_tensor* x,
                                            struct ggml_tensor* pe,
                                            struct ggml_tensor* mask = NULL) {
@@ -1332,7 +1333,7 @@ namespace WAN {
            k = ggml_reshape_4d(ctx, k, head_dim, num_heads, n_token, N);  // [N, n_token, n_head, d_head]
            v = ggml_reshape_4d(ctx, v, head_dim, num_heads, n_token, N);  // [N, n_token, n_head, d_head]
-           x = Flux::attention(ctx, q, k, v, pe, mask, flash_attn);  // [N, n_token, dim]
+           x = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn);  // [N, n_token, dim]
            x = o_proj->forward(ctx, x);  // [N, n_token, dim]
            return x;
@@ -1348,6 +1349,7 @@ namespace WAN {
                           bool flash_attn = false)
            : WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn) {}
        virtual struct ggml_tensor* forward(struct ggml_context* ctx,
+                                           ggml_backend_t backend,
                                            struct ggml_tensor* x,
                                            struct ggml_tensor* context,
                                            int64_t context_img_len) = 0;
@@ -1362,6 +1364,7 @@ namespace WAN {
                           bool flash_attn = false)
            : WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {}
        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                   ggml_backend_t backend,
                                    struct ggml_tensor* x,
                                    struct ggml_tensor* context,
                                    int64_t context_img_len) {
@@ -1385,7 +1388,7 @@ namespace WAN {
            k = norm_k->forward(ctx, k);
            auto v = v_proj->forward(ctx, context);  // [N, n_context, dim]
-           x = ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
+           x = ggml_nn_attention_ext(ctx, backend, q, k, v, num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
            x = o_proj->forward(ctx, x);  // [N, n_token, dim]
            return x;
@@ -1411,6 +1414,7 @@ namespace WAN {
        }
        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                   ggml_backend_t backend,
                                    struct ggml_tensor* x,
                                    struct ggml_tensor* context,
                                    int64_t context_img_len) {
@@ -1451,8 +1455,8 @@ namespace WAN {
            k_img = norm_k_img->forward(ctx, k_img);
            auto v_img = v_img_proj->forward(ctx, context_img);  // [N, context_img_len, dim]
-           auto img_x = ggml_nn_attention_ext(ctx, q, k_img, v_img, num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
-           x          = ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn);          // [N, n_token, dim]
+           auto img_x = ggml_nn_attention_ext(ctx, backend, q, k_img, v_img, num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
+           x          = ggml_nn_attention_ext(ctx, backend, q, k, v, num_heads, NULL, false, false, flash_attn);          // [N, n_token, dim]
            x = ggml_add(ctx, x, img_x);
@@ -1529,6 +1533,7 @@ namespace WAN {
        }
        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                   ggml_backend_t backend,
                                    struct ggml_tensor* x,
                                    struct ggml_tensor* e,
                                    struct ggml_tensor* pe,
@@ -1555,14 +1560,14 @@ namespace WAN {
            auto y = norm1->forward(ctx, x);
            y = ggml_add(ctx, y, modulate_mul(ctx, y, es[1]));
            y = modulate_add(ctx, y, es[0]);
-           y = self_attn->forward(ctx, y, pe);
+           y = self_attn->forward(ctx, backend, y, pe);
            x = ggml_add(ctx, x, modulate_mul(ctx, y, es[2]));
            // cross-attention
            x = ggml_add(ctx,
                         x,
-                        cross_attn->forward(ctx, norm3->forward(ctx, x), context, context_img_len));
+                        cross_attn->forward(ctx, backend, norm3->forward(ctx, x), context, context_img_len));
            // ffn
            y = norm2->forward(ctx, x);
@@ -1785,6 +1790,7 @@ namespace WAN {
        }
        struct ggml_tensor* forward_orig(struct ggml_context* ctx,
+                                        ggml_backend_t backend,
                                         struct ggml_tensor* x,
                                         struct ggml_tensor* timestep,
                                         struct ggml_tensor* context,
@@ -1842,7 +1848,7 @@ namespace WAN {
            for (int i = 0; i < params.num_layers; i++) {
                auto block = std::dynamic_pointer_cast(blocks["blocks." + std::to_string(i)]);
-               x = block->forward(ctx, x, e0, pe, context, context_img_len);
+               x = block->forward(ctx, backend, x, e0, pe, context, context_img_len);
            }
            x = head->forward(ctx, x, e);  // [N, t_len*h_len*w_len, pt*ph*pw*out_dim]
@@ -1851,6 +1857,7 @@ namespace WAN {
        }
        struct ggml_tensor* forward(struct ggml_context* ctx,
+                                   ggml_backend_t backend,
                                    struct ggml_tensor* x,
                                    struct ggml_tensor* timestep,
                                    struct ggml_tensor* context,
@@ -1885,7 +1892,7 @@ namespace WAN {
                t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
            }
-           auto out = forward_orig(ctx, x, timestep, context, pe, clip_fea, N);  // [N, t_len*h_len*w_len, pt*ph*pw*C]
+           auto out = forward_orig(ctx, backend, x, timestep, context, pe, clip_fea, N);  // [N, t_len*h_len*w_len, pt*ph*pw*C]
            out = unpatchify(ctx, out, t_len, h_len, w_len);  // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
@@ -2040,6 +2047,7 @@ namespace WAN {
        }
        struct ggml_tensor* out = wan.forward(compute_ctx,
+                                             runtime_backend,
                                              x,
                                              timesteps,
                                              context,
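All of the runner-side call sites in this patch have the same shape: the graph-building code running on compute_ctx hands the runner's runtime_backend to the model's forward so the attention helper can query it. A condensed, hypothetical sketch of that pattern (mirroring UNetModelRunner, MMDiTRunner, T5Runner, FluxRunner and the WAN runner above; `block` stands for any block whose forward now takes a ggml_backend_t, and constructor/caching details are omitted):

    // Inside a hypothetical GGMLRunner subclass; compute_ctx, to_backend and
    // runtime_backend are the GGMLRunner members already used by the patch.
    struct ggml_cgraph* build_graph(struct ggml_tensor* x,
                                    struct ggml_tensor* timesteps,
                                    struct ggml_tensor* context) {
        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);

        x         = to_backend(x);
        timesteps = to_backend(timesteps);
        context   = to_backend(context);

        // forwarding runtime_backend here is what lets ggml_nn_attention_ext ask,
        // per op, whether flash attention is supported on this backend
        struct ggml_tensor* out = block.forward(compute_ctx, runtime_backend, x, timesteps, context);

        ggml_build_forward_expand(gf, out);
        return gf;
    }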