Commit 97c5663

Enable SDPA Unit Tests and Adjust Fudge Factors (ROCm#34)
* Adjust fudge factors for gfx950
* Enable SDPA UTs for gfx950

1 parent: f2faad5

2 files changed: +8 -6 lines
test/test_transformers.py (5 additions & 3 deletions)

@@ -3154,6 +3154,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
                 'grad_value': 8.5,
             }
             if TEST_WITH_ROCM:
+                fudge_factors['grad_value'] = 16.0
                 fudge_factors['grad_key'] = 45.0
                 fudge_factors['grad_query'] = 360.0
                 if seq_len_k >= 1024:
@@ -3273,6 +3274,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
                 "grad_attn_mask": 45.0,
             }
             if TEST_WITH_ROCM:
+                fudge_factors['grad_value'] = 16.0
                 fudge_factors['grad_key'] = 45.0
                 fudge_factors['grad_query'] = 360.0
                 if seq_len_k >= 1024:
@@ -3528,7 +3530,7 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d
             g.replay()
             out = output_tuple[0]
             if dropout_p == 0.0:
-                self.assertEqual(out_first, out, atol=0, rtol=0)
+                self.assertEqual(out_first, out, atol=0, rtol=0, msg='Two passes of non-dropout graph mismatches')
             else:
                 # replays produce different results
                 self.assertNotEqual(out_first, out)
@@ -3569,8 +3571,8 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d
             fudge_factors={
                 'out': 3.0,
                 'grad_query': 100.0,
-                'grad_key': 8.0,
-                'grad_value': 3.0,
+                'grad_key': 8.0 if not TEST_WITH_ROCM else 16.0,
+                'grad_value': 3.0 if not TEST_WITH_ROCM else 6.0,
             }
         )

torch/testing/_internal/common_cuda.py (3 additions & 3 deletions)

@@ -39,7 +39,7 @@
 def CDNA2OrLater():
     if TEST_WITH_ROCM:
         gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
-        return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx940", "gfx941", "gfx942"})
+        return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx940", "gfx941", "gfx942", "gfx950"})
     return False

 def evaluate_gfx_arch_exact(matching_arch):
@@ -54,14 +54,14 @@ def evaluate_gfx_arch_exact(matching_arch):

 def evaluate_platform_supports_flash_attention():
     if TEST_WITH_ROCM:
-        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
+        return CDNA2OrLater()
     if TEST_CUDA:
         return not IS_WINDOWS and SM80OrLater
     return False

 def evaluate_platform_supports_efficient_attention():
     if TEST_WITH_ROCM:
-        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
+        return CDNA2OrLater()
     if TEST_CUDA:
         return True
     return False
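A quick way to see which branch these helpers take on a given machine (a sketch, assuming a ROCm build of PyTorch with at least one visible device; torch.version.hip is None on CUDA builds):

import torch

if torch.cuda.is_available() and torch.version.hip is not None:
    # gcnArchName can carry feature suffixes, e.g. 'gfx90a:sramecc+:xnack-';
    # CDNA2OrLater() does a substring match against base architecture names.
    arch = torch.cuda.get_device_properties('cuda').gcnArchName
    print(arch)
    print(any(a in arch for a in {"gfx90a", "gfx940", "gfx941", "gfx942", "gfx950"}))

This also shows why the old exact-match checks were brittle: they tied test enablement to one specific sramecc/xnack configuration, whereas the substring check in CDNA2OrLater() only cares about the base gfx architecture.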
