
Commit 1be2491

feat: partial LyCORIS support (tucker decomposition for LoCon + LoHa + LoKr) (#577)
1 parent 3753223 commit 1be2491

File tree

2 files changed: +573 -345 lines


ggml_extend.hpp

+68 -3
@@ -52,6 +52,71 @@
 #define __STATIC_INLINE__ static inline
 #endif

+// n-mode tensor-matrix product
+// example: 2-mode product
+// A: [ne03, k, ne01, ne00]
+// B: k rows, m columns => [k, m]
+// result is [ne03, m, ne01, ne00]
+__STATIC_INLINE__ struct ggml_tensor* ggml_mul_n_mode(struct ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b, int mode = 0) {
+    // reshape A
+    // swap 0th and nth axis
+    a = ggml_cont(ctx, ggml_permute(ctx, a, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0));
+    int ne1 = a->ne[1];
+    int ne2 = a->ne[2];
+    int ne3 = a->ne[3];
+    // make 2D
+    a = ggml_cont(ctx, ggml_reshape_2d(ctx, a, a->ne[0], (ne3 * ne2 * ne1)));
+
+    struct ggml_tensor* result = ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, a, b)));
+
+    // reshape output (same shape as a after permutation except first dim)
+    result = ggml_reshape_4d(ctx, result, result->ne[0], ne1, ne2, ne3);
+    // swap back 0th and nth axis
+    result = ggml_permute(ctx, result, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0);
+    return result;
+}
+
+
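Note: the following sketch is editorial and not part of the diff. It shows how ggml_mul_n_mode could be exercised in isolation, contracting the k-sized axis (mode 2) of a 4-D tensor against a [k, m] matrix as the comment above describes. The sizes, variable names, and the 16 MB context are illustrative assumptions; tensor contents are left uninitialized because only the shapes matter here.

#include "ggml.h"

// assumes ggml_mul_n_mode (defined in the hunk above) is in scope
static void mul_n_mode_demo() {
    struct ggml_init_params params = {/*mem_size*/ 16 * 1024 * 1024, /*mem_buffer*/ NULL, /*no_alloc*/ false};
    struct ggml_context* ctx       = ggml_init(params);

    // A: ne = [8, 4, 16, 2] with k = 16 on axis 2; B: ne = [16, 32], i.e. k rows, m = 32 columns
    struct ggml_tensor* A = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8, 4, 16, 2);
    struct ggml_tensor* B = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 32);
    struct ggml_tensor* C = ggml_mul_n_mode(ctx, A, B, 2);  // expected ne = [8, 4, 32, 2]

    // the calls above only record ops; evaluating them requires a graph
    struct ggml_cgraph* gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, C);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);
    ggml_free(ctx);
}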
+__STATIC_INLINE__ struct ggml_tensor* ggml_merge_lora(ggml_context* ctx, struct ggml_tensor* lora_down, struct ggml_tensor* lora_up, struct ggml_tensor* lora_mid = NULL) {
+    struct ggml_tensor* updown;
+    // flatten lora tensors so they can be multiplied
+    int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1];
+    lora_up              = ggml_reshape_2d(ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows);
+    auto lora_down_n_dims = ggml_n_dims(lora_down);
+    // round n_dims up to a multiple of 2 (otherwise rank 1 doesn't work)
+    lora_down_n_dims       = (lora_down_n_dims + lora_down_n_dims % 2);
+    int64_t lora_down_rows = lora_down->ne[lora_down_n_dims - 1];
+    lora_down              = ggml_reshape_2d(ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows);
+
+    // ggml_mul_mat requires tensor b transposed
+    lora_down = ggml_cont(ctx, ggml_transpose(ctx, lora_down));
+    if (lora_mid == NULL) {
+        updown = ggml_mul_mat(ctx, lora_up, lora_down);
+        updown = ggml_cont(ctx, ggml_transpose(ctx, updown));
+    } else {
+        // undoing the Tucker decomposition for conv layers
+        // lora_mid has shape (3, 3, Rank, Rank)
+        // lora_down has shape (Rank, In, 1, 1)
+        // lora_up has shape (Rank, Out, 1, 1)
+        // conv layer shape is (3, 3, Out, In)
+        updown = ggml_mul_n_mode(ctx, ggml_mul_n_mode(ctx, lora_mid, lora_down, 3), lora_up, 2);
+        updown = ggml_cont(ctx, updown);
+    }
+    return updown;
+}
+
+
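Note: another editorial sketch, not part of the diff, to sanity-check the shapes on the plain (lora_mid == NULL) path. The layer sizes (320 -> 640 at rank 8) and names are made up, and a ggml_context* ctx set up as in the earlier sketch is assumed. In actual use the resulting delta would still be scaled (typically by alpha / rank) and added to the base weight by the calling code, which is not part of this hunk.

// hypothetical rank-8 LoRA for a 320 -> 640 linear layer, in ggml's ne order:
// lora_down: [in = 320, rank = 8], lora_up: [rank = 8, out = 640]
struct ggml_tensor* lora_down = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 320, 8);
struct ggml_tensor* lora_up   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 640);

// flattened, multiplied and transposed inside the helper; expected ne = [320, 640],
// i.e. the same shape as the base weight the delta will be merged into
struct ggml_tensor* updown = ggml_merge_lora(ctx, lora_down, lora_up);

// with lora_mid != NULL, the else-branch instead rebuilds the 3x3 conv kernel from the
// Tucker core via two ggml_mul_n_mode calls, per the shape comments above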
+// Kronecker product
+// [ne03, ne02, ne01, ne00] x [ne13, ne12, ne11, ne10] => [ne03*ne13, ne02*ne12, ne01*ne11, ne00*ne10]
+__STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b) {
+    return ggml_mul(ctx,
+                    ggml_upscale_ext(ctx,
+                                     a,
+                                     a->ne[0] * b->ne[0],
+                                     a->ne[1] * b->ne[1],
+                                     a->ne[2] * b->ne[2],
+                                     a->ne[3] * b->ne[3]),
+                    b);
+}
+
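Note: a last editorial sketch, not part of the diff. LoKr factors a weight update roughly as delta_W ≈ w1 ⊗ w2, so two small tensors stand in for one large matrix. ggml_kronecker realizes the ⊗ by upscaling a to the product shape (effectively repeating each element of a across a b-sized block) and then broadcast-multiplying by b. The sizes and names below are made up, and the same assumed ctx is reused.

// 1-D intuition: [a0, a1] ⊗ [b0, b1, b2] = [a0*b0, a0*b1, a0*b2, a1*b0, a1*b1, a1*b2]
struct ggml_tensor* lokr_w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 4);
struct ggml_tensor* lokr_w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 5);
struct ggml_tensor* delta_w = ggml_kronecker(ctx, lokr_w1, lokr_w2);  // expected ne = [2*3, 4*5] = [6, 20]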
 __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) {
     (void)level;
     (void)user_data;
@@ -319,7 +384,7 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
     for (int iy = 0; iy < height; iy++) {
         float m = ggml_tensor_get_f32(mask, ix, iy);
         for (int k = 0; k < channels; k++) {
-            float value = ((float)(m < 254.5/255)) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
+            float value = ((float)(m < 254.5 / 255)) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
             ggml_tensor_set_f32(output, value, ix, iy, k);
         }
     }
@@ -987,8 +1052,8 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
 }

 /* SDXL with LoRA requires more space */
-#define MAX_PARAMS_TENSOR_NUM 15360
-#define MAX_GRAPH_SIZE 15360
+#define MAX_PARAMS_TENSOR_NUM 32768
+#define MAX_GRAPH_SIZE 32768

 struct GGMLRunner {
 protected:
