Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions clip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,14 +488,14 @@ struct CLIPLayer : public GGMLBlock {
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
}

struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, bool mask = true) {
struct ggml_tensor* forward(struct ggml_context* ctx, ggml_backend_t backend, struct ggml_tensor* x, bool mask = true) {
// x: [N, n_token, d_model]
auto self_attn = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
auto layer_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm2"]);
auto mlp = std::dynamic_pointer_cast<CLIPMLP>(blocks["mlp"]);

x = ggml_add(ctx, x, self_attn->forward(ctx, layer_norm1->forward(ctx, x), mask));
x = ggml_add(ctx, x, self_attn->forward(ctx, backend, layer_norm1->forward(ctx, x), mask));
x = ggml_add(ctx, x, mlp->forward(ctx, layer_norm2->forward(ctx, x)));
return x;
}
Expand All @@ -517,7 +517,11 @@ struct CLIPEncoder : public GGMLBlock {
}
}

struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int clip_skip = -1, bool mask = true) {
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
int clip_skip = -1,
bool mask = true) {
// x: [N, n_token, d_model]
int layer_idx = n_layer - 1;
// LOG_DEBUG("clip_skip %d", clip_skip);
Expand All @@ -532,7 +536,7 @@ struct CLIPEncoder : public GGMLBlock {
}
std::string name = "layers." + std::to_string(i);
auto layer = std::dynamic_pointer_cast<CLIPLayer>(blocks[name]);
x = layer->forward(ctx, x, mask); // [N, n_token, d_model]
x = layer->forward(ctx, backend, x, mask); // [N, n_token, d_model]
// LOG_DEBUG("layer %d", i);
}
return x;
Expand Down Expand Up @@ -712,6 +716,7 @@ class CLIPTextModel : public GGMLBlock {
}

struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* input_ids,
struct ggml_tensor* tkn_embeddings,
size_t max_token_idx = 0,
Expand All @@ -722,7 +727,7 @@ class CLIPTextModel : public GGMLBlock {
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);

auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
x = encoder->forward(ctx, backend, x, return_pooled ? -1 : clip_skip, true);
if (return_pooled || with_final_ln) {
x = final_layer_norm->forward(ctx, x);
}
Expand Down Expand Up @@ -775,6 +780,7 @@ class CLIPVisionModel : public GGMLBlock {
}

struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* pixel_values,
bool return_pooled = true,
int clip_skip = -1) {
Expand All @@ -786,7 +792,7 @@ class CLIPVisionModel : public GGMLBlock {

auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
x = pre_layernorm->forward(ctx, x);
x = encoder->forward(ctx, x, clip_skip, false);
x = encoder->forward(ctx, backend, x, clip_skip, false);
// print_ggml_tensor(x, true, "ClipVisionModel x: ");
auto last_hidden_state = x;
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
Expand Down Expand Up @@ -855,6 +861,7 @@ class CLIPVisionModelProjection : public GGMLBlock {
}

struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* pixel_values,
bool return_pooled = true,
int clip_skip = -1) {
Expand All @@ -863,7 +870,7 @@ class CLIPVisionModelProjection : public GGMLBlock {
auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
auto visual_projection = std::dynamic_pointer_cast<CLIPProjection>(blocks["visual_projection"]);

auto x = vision_model->forward(ctx, pixel_values, return_pooled, clip_skip); // [N, hidden_size] or [N, n_token, hidden_size]
auto x = vision_model->forward(ctx, backend, pixel_values, return_pooled, clip_skip); // [N, hidden_size] or [N, n_token, hidden_size]

if (return_pooled) {
x = visual_projection->forward(ctx, x); // [N, projection_dim]
Expand Down Expand Up @@ -900,6 +907,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
}

struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* input_ids,
struct ggml_tensor* embeddings,
size_t max_token_idx = 0,
Expand All @@ -911,7 +919,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
input_ids = ggml_reshape_2d(ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
}

return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled);
return model.forward(ctx, backend, input_ids, embeddings, max_token_idx, return_pooled);
}

struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
Expand All @@ -937,7 +945,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
}

struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, embeddings, max_token_idx, return_pooled);
struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, embeddings, max_token_idx, return_pooled);

ggml_build_forward_expand(gf, hidden_states);

Expand Down
23 changes: 16 additions & 7 deletions common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,10 @@ class CrossAttention : public GGMLBlock {
// to_out_1 is nn.Dropout(), skip for inference
}

struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* context) {
// x: [N, n_token, query_dim]
// context: [N, n_context, context_dim]
// return: [N, n_token, query_dim]
Expand All @@ -288,7 +291,7 @@ class CrossAttention : public GGMLBlock {
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]

x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn); // [N, n_token, inner_dim]
x = ggml_nn_attention_ext(ctx, backend, q, k, v, n_head, NULL, false, false, flash_attn); // [N, n_token, inner_dim]

x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
return x;
Expand Down Expand Up @@ -327,7 +330,10 @@ class BasicTransformerBlock : public GGMLBlock {
}
}

struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* context) {
// x: [N, n_token, query_dim]
// context: [N, n_context, context_dim]
// return: [N, n_token, query_dim]
Expand All @@ -352,11 +358,11 @@ class BasicTransformerBlock : public GGMLBlock {

auto r = x;
x = norm1->forward(ctx, x);
x = attn1->forward(ctx, x, x); // self-attention
x = attn1->forward(ctx, backend, x, x); // self-attention
x = ggml_add(ctx, x, r);
r = x;
x = norm2->forward(ctx, x);
x = attn2->forward(ctx, x, context); // cross-attention
x = attn2->forward(ctx, backend, x, context); // cross-attention
x = ggml_add(ctx, x, r);
r = x;
x = norm3->forward(ctx, x);
Expand Down Expand Up @@ -401,7 +407,10 @@ class SpatialTransformer : public GGMLBlock {
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
}

virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* context) {
// x: [N, in_channels, h, w]
// context: [N, max_position(aka n_token), hidden_size(aka context_dim)]
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
Expand All @@ -424,7 +433,7 @@ class SpatialTransformer : public GGMLBlock {
std::string name = "transformer_blocks." + std::to_string(i);
auto transformer_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[name]);

x = transformer_block->forward(ctx, x, context);
x = transformer_block->forward(ctx, backend, x, context);
}

x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
Expand Down
2 changes: 1 addition & 1 deletion conditioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {

pixel_values = to_backend(pixel_values);

struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values, return_pooled, clip_skip);
struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, runtime_backend, pixel_values, return_pooled, clip_skip);

ggml_build_forward_expand(gf, hidden_states);

Expand Down
13 changes: 8 additions & 5 deletions control.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,11 @@ class ControlNetBlock : public GGMLBlock {

struct ggml_tensor* attention_layer_forward(std::string name,
struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* context) {
auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
return block->forward(ctx, x, context);
return block->forward(ctx, backend, x, context);
}

struct ggml_tensor* input_hint_block_forward(struct ggml_context* ctx,
Expand All @@ -199,6 +200,7 @@ class ControlNetBlock : public GGMLBlock {
}

std::vector<struct ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* hint,
struct ggml_tensor* guided_hint,
Expand Down Expand Up @@ -272,7 +274,7 @@ class ControlNetBlock : public GGMLBlock {
h = resblock_forward(name, ctx, h, emb); // [N, mult*model_channels, h, w]
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
h = attention_layer_forward(name, ctx, h, context); // [N, mult*model_channels, h, w]
h = attention_layer_forward(name, ctx, backend, h, context); // [N, mult*model_channels, h, w]
}

auto zero_conv = std::dynamic_pointer_cast<Conv2d>(blocks["zero_convs." + std::to_string(input_block_idx) + ".0"]);
Expand All @@ -296,9 +298,9 @@ class ControlNetBlock : public GGMLBlock {
// [N, 4*model_channels, h/8, w/8]

// middle_block
h = resblock_forward("middle_block.0", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
h = attention_layer_forward("middle_block.1", ctx, h, context); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.2", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.0", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
h = attention_layer_forward("middle_block.1", ctx, backend, h, context); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.2", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]

// out
outs.push_back(middle_block_out->forward(ctx, h));
Expand Down Expand Up @@ -403,6 +405,7 @@ struct ControlNet : public GGMLRunner {
timesteps = to_backend(timesteps);

auto outs = control_net.forward(compute_ctx,
runtime_backend,
x,
hint,
guided_hint_cached ? guided_hint : NULL,
Expand Down
Loading
Loading