@@ -1584,9 +1584,11 @@ struct test_flash_attn_ext : public test_case {
         : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
-        ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV,       hs, kv, nh, 1);
-        ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV,       hs, kv, nh, 1);
+        const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
+
+        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs_padded, nb, nh, 1);
+        ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV,       hs_padded, kv, nh, 1);
+        ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV,       hs_padded, kv, nh, 1);

         ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;

         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);

         return out;
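The change pads the head size hs up to a multiple of the block size of the KV type before creating the Q/K/V tensors: a row of a quantized ggml tensor must consist of whole blocks, so a head size that is not block-aligned cannot be stored directly in a quantized type. Note that the attention scale 1.0f/sqrtf(hs) is still computed from the unpadded head size. Below is a minimal sketch, not part of the patch, of the padding arithmetic; the PAD macro mirrors the rounding of GGML_PAD from ggml.h, the block sizes are the ones ggml_blck_size() reports for these types, and the head size 80 is a hypothetical value chosen for illustration.

    // Sketch of the head-size padding used above (assumed values, for illustration).
    #include <cassert>
    #include <cstdint>

    #define PAD(x, b) (((x) + (b) - 1)/(b)*(b)) // round x up to a multiple of b, as GGML_PAD does

    int main() {
        const int64_t hs = 80; // hypothetical head size

        // F16 has a block size of 1, so the head size is unchanged.
        assert(PAD(hs, 1) == 80);

        // Q8_0 packs 32 values per block, so 80 is padded up to 96.
        assert(PAD(hs, 32) == 96);

        // K-quants such as Q4_K use 256-value super-blocks, so 80 becomes 256.
        assert(PAD(hs, 256) == 256);

        return 0;
    }

Since ggml_blck_size() is 1 for the default F16 KV type, the padding is a no-op there, and existing test cases with F16 K/V are unaffected.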