Commit a62d7cb

add q8_0 q4_0 tests
1 parent 5c7e9c4 commit a62d7cb

2 files changed: +16 −8 lines changed

ggml-cuda.cu

Lines changed: 6 additions & 2 deletions
@@ -2888,10 +2888,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
             return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
 #else
-            if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
+            if (op->src[0]->ne[0] == 128) {
                 return true;
             }
-            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
+            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
+                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
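For context, here is a minimal standalone sketch of the gating logic this hunk introduces, using simplified stand-in types rather than ggml's real structs, and assuming CC_VOLTA corresponds to compute capability 7.0 (i.e. 700); it is illustrative only, not the actual ggml code:

// Hedged sketch of the new support predicate; Src stands in for ggml_tensor.
#include <cstdint>

enum class Type { F16, F32, Q8_0, Q4_0 };

struct Src {
    int64_t ne0;   // head size (op->src[i]->ne[0])
    Type    type;  // storage type (op->src[i]->type)
};

// Head size 128 is always supported; head size 64 only with an F16 K tensor;
// anything else requires Volta or newer plus F16 K and V.
static bool supports_flash_attn_ext(const Src & q, const Src & k, const Src & v,
                                    int cc, int cc_volta = 700) {
    if (q.ne0 == 128) {
        return true;
    }
    if (q.ne0 == 64 && k.type == Type::F16) {
        return true;
    }
    return cc >= cc_volta && k.type == Type::F16 && v.type == Type::F16;
}

int main() {
    Src q{128, Type::F32}, k{128, Type::Q8_0}, v{128, Type::Q8_0};
    // With head size 128, quantized K/V is accepted even on a pre-Volta device (cc 6.1).
    return supports_flash_attn_ext(q, k, v, 610) ? 0 : 1;
}

This mirrors why the commit can add Q8_0 and Q4_0 K/V test cases: for head size 128 the CUDA backend now reports support regardless of the K/V storage type.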

tests/test-backend-ops.cpp

Lines changed: 10 additions & 6 deletions
@@ -1540,21 +1540,23 @@ struct test_flash_attn_ext : public test_case {
 
     const float max_bias; // ALiBi
 
+    const ggml_type type_KV;
+
     std::string vars() override {
-        return VARS_TO_STR6(hs, nh, kv, nb, mask, max_bias);
+        return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV);
     }
 
     double max_nmse_err() override {
         return 5e-4;
     }
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f)
-        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
+        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
-        ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
-        ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
+        ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs, kv, nh, 1);
+        ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs, kv, nh, 1);
         ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
         return out;
@@ -2238,7 +2240,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
                 for (int nh : { 32, }) {
                     for (int kv : { 512, 1024, }) {
                         for (int nb : { 1, 2, 4, 8, }) {
-                            test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias));
+                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV));
+                            }
                         }
                     }
                 }
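With this loop, each existing (hs, nh, kv, nb, mask, max_bias) combination now fans out into three flash-attention cases, one per K/V storage type, and the type shows up in the case name via VARS_TO_STR7. A hypothetical one-off registration of a single quantized-KV case, reusing the constructor from the hunk above and assuming the surrounding test_cases vector in test_backend, would look like:

    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8, true, 0.0f, GGML_TYPE_Q4_0));

Assuming the usual test-backend-ops CLI with its -o op-name filter, the new cases can be exercised in isolation with something like `test-backend-ops test -o FLASH_ATTN_EXT`.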
