
Commit 5c8c352 (committed Apr 9, 2025)
[Executorch][SDPA] Refactor + Make quantized sdpa handle sequence at dim 1 or 2
Pull Request resolved: #9943

For quantized SDPA we want to evaluate the performance impact of having the sequence dimension at dim 1 as well as at dim 2. This diff refactors the code to enable that. The same should also be done for float SDPA, but that is left for future work.

ghstack-source-id: 277160631
@exported-using-ghexport
Differential Revision: [D71833060](https://our.internmc.facebook.com/intern/diff/D71833060/)
Parent: 46ce593

File tree: 5 files changed, +189 −62 lines

- extension/llm/custom_ops/op_sdpa.cpp
- extension/llm/custom_ops/op_sdpa.h
- extension/llm/custom_ops/op_sdpa_aot.cpp
- extension/llm/custom_ops/op_sdpa_impl.h
- extension/llm/custom_ops/test_quantized_sdpa.py
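The refactor distinguishes two input layouts for quantized SDPA: sequence at dim 2, i.e. [Batch, Num_heads, Seq_len, Head_dim], and sequence at dim 1, i.e. [Batch, Seq_len, Num_heads, Head_dim]. A minimal PyTorch sketch (shapes illustrative, not from the commit) of how a seq-at-dim-1 result can be checked against the standard layout, mirroring the transpose trick the updated test's `_sdpa_ref` uses:

```python
import torch
import torch.nn.functional as F

B, H, S, D = 1, 8, 16, 64

# Sequence at dim 2: the usual [B, H, S, D] layout.
q2 = torch.randn(B, H, S, D)
k2 = torch.randn(B, H, S, D)
v2 = torch.randn(B, H, S, D)
out_dim2 = F.scaled_dot_product_attention(q2, k2, v2)

# Sequence at dim 1: [B, S, H, D]; same data, heads and sequence swapped.
q1, k1, v1 = (t.transpose(1, 2).contiguous() for t in (q2, k2, v2))

# Reference for the dim-1 layout: transpose to [B, H, S, D], run SDPA,
# then transpose the output back to [B, S, H, D].
out_dim1 = F.scaled_dot_product_attention(
    q1.transpose(1, 2).contiguous(),
    k1.transpose(1, 2).contiguous(),
    v1.transpose(1, 2).contiguous(),
).transpose(1, 2)

assert torch.allclose(out_dim1, out_dim2.transpose(1, 2), atol=1e-6)
```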
 

‎extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 20 additions & 15 deletions
```diff
@@ -264,14 +264,14 @@ Tensor& flash_attention_kernel_out(
       InvalidArgument,
       output);
 
-  auto q_seq_len = query.size(2);
+  auto seq_len = query.size(2);
 
   ET_SWITCH_FLOAT_TYPES(
       query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
         // TODO we need to re-evaluate this for ARM CPUs
         // And there can be many so instead of templatizing
         // we might consider another appraoch
-        if (q_seq_len >= 768) {
+        if (seq_len >= 768) {
           sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
               output,
               query,
@@ -287,7 +287,7 @@ Tensor& flash_attention_kernel_out(
               nullopt,
               nullopt,
               nullopt);
-        } else if (q_seq_len >= 192) {
+        } else if (seq_len >= 192) {
           sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
               output,
               query,
@@ -341,7 +341,8 @@ Tensor& custom_sdpa_out_impl(
     const optional<Tensor>& k_zero_points = nullopt,
     const optional<Tensor>& k_scales = nullopt,
     const optional<Tensor>& v_zero_points = nullopt,
-    const optional<Tensor>& v_scales = nullopt) {
+    const optional<Tensor>& v_scales = nullopt,
+    bool is_seq_at_dim_2 = false) {
   ET_KERNEL_CHECK_MSG(
       ctx,
       !attn_mask.has_value() || !is_causal,
@@ -357,13 +358,15 @@ Tensor& custom_sdpa_out_impl(
       "Invalid arguments");
 
   int64_t seq_len = q.size(1);
-  auto q_seq_len = q.size(1);
+  SeqDim seq_dim{SeqDim::TWO};
+  if (!is_seq_at_dim_2) {
+    seq_dim = SeqDim::ONE;
+  }
 
-  bool is_seq_at_dim_1{true};
   if (q.scalar_type() == ScalarType::Char) {
-    is_seq_at_dim_1 = false;
-    seq_len = q.size(2);
-    q_seq_len = q.size(2);
+    if (seq_dim == SeqDim::TWO) {
+      seq_len = q.size(2);
+    }
     ET_KERNEL_CHECK_MSG(
         ctx,
         q_scales.has_value() && q_zero_points.has_value() &&
@@ -412,7 +415,7 @@ Tensor& custom_sdpa_out_impl(
         // TODO we need to re-evaluate this for ARM CPUs
         // And there can be many so instead of templatizing
         // we might consider another appraoch
-        if (q_seq_len >= 768) {
+        if (seq_len >= 768) {
          sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
              output,
              q,
@@ -428,10 +431,10 @@ Tensor& custom_sdpa_out_impl(
              k_scales, // k_scales
              v_zero_points, // v_zero_points
              v_scales, // v_scales
-             is_seq_at_dim_1, /* is_seq_at_dim_1 */
+             seq_dim, /* seq_dim */
              start_pos,
              num_keys_for_causal_attention);
-        } else if (q_seq_len >= 192) {
+        } else if (seq_len >= 192) {
          sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
              output,
              q,
@@ -447,7 +450,7 @@ Tensor& custom_sdpa_out_impl(
              k_scales, // k_scales
              v_zero_points, // v_zero_points
              v_scales, // v_scales
-             is_seq_at_dim_1, /* is_seq_at_dim_1 */
+             seq_dim, /* seq_dim */
              start_pos,
              num_keys_for_causal_attention);
        } else {
@@ -466,7 +469,7 @@ Tensor& custom_sdpa_out_impl(
              k_scales, // k_scales
              v_zero_points, // v_zero_points
              v_scales, // v_scales
-             is_seq_at_dim_1, /* is_seq_at_dim_1 */
+             seq_dim, /* seq_dim */
              start_pos,
              num_keys_for_causal_attention);
        }
@@ -492,6 +495,7 @@ Tensor& custom_quantized_sdpa_out(
     const optional<Tensor>& k_scales,
     const optional<Tensor>& v_zero_points,
     const optional<Tensor>& v_scales,
+    const bool is_seq_at_dim_2,
     Tensor& output) {
   return custom_sdpa_out_impl(
       ctx,
@@ -509,7 +513,8 @@ Tensor& custom_quantized_sdpa_out(
       k_zero_points,
       k_scales,
       v_zero_points,
-      v_scales);
+      v_scales,
+      is_seq_at_dim_2);
 }
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
```
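In the quantized (`ScalarType::Char`) path, `custom_sdpa_out_impl` now derives the sequence dimension from the new `is_seq_at_dim_2` flag instead of hard-coding dim 2, and the tiled kernel's block sizes are still chosen from that sequence length. A rough Python restatement of the selection logic, for illustration only (the helper names below are not part of the source):

```python
import torch

def select_seq_len(q: torch.Tensor, is_seq_at_dim_2: bool = False) -> int:
    """Mirrors the seq_len selection in custom_sdpa_out_impl."""
    seq_len = q.shape[1]  # default: sequence at dim 1
    if q.dtype == torch.int8 and is_seq_at_dim_2:
        seq_len = q.shape[2]  # quantized inputs may keep sequence at dim 2
    return seq_len

def select_split_sizes(seq_len: int):
    """Thresholds used to pick (q_split_size, kv_split_size) for the kernel."""
    if seq_len >= 768:
        return 256, 512
    if seq_len >= 192:
        return 64, 512
    # A smaller configuration is used below this point; its split sizes are
    # not visible in the hunks above, so they are left unspecified here.
    return None
```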

‎extension/llm/custom_ops/op_sdpa.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -74,6 +74,7 @@ Tensor& custom_quantized_sdpa_out(
     const optional<Tensor>& k_scales,
     const optional<Tensor>& v_zero_points,
     const optional<Tensor>& v_scales,
+    const bool is_seq_at_dim_1,
     Tensor& output);
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 } // namespace native
```

‎extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 12 additions & 6 deletions
```diff
@@ -96,6 +96,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
     const optional<Tensor> k_scales,
     const optional<Tensor> v_zero_points,
     const optional<Tensor> v_scales,
+    const bool is_seq_at_dim_2,
     Tensor& output);
 
 at::Tensor custom_quantized_sdpa_aten(
@@ -115,7 +116,8 @@ at::Tensor custom_quantized_sdpa_aten(
     const std::optional<at::Tensor>& k_zero_points,
     const std::optional<at::Tensor>& k_scales,
     const std::optional<at::Tensor>& v_zero_points,
-    const std::optional<at::Tensor>& v_scales);
+    const std::optional<at::Tensor>& v_scales,
+    const bool is_seq_at_dim_2);
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 
 Tensor& update_cache_out_no_context(
@@ -258,6 +260,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
     const optional<Tensor> k_scales,
     const optional<Tensor> v_zero_points,
     const optional<Tensor> v_scales,
+    const bool is_seq_at_dim_2,
     Tensor& output) {
   executorch::aten::RuntimeContext context{};
   return torch::executor::native::custom_quantized_sdpa_out(
@@ -276,6 +279,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
       k_scales,
       v_zero_points,
       v_scales,
+      is_seq_at_dim_2,
       output);
 }
 
@@ -296,9 +300,10 @@ at::Tensor custom_quantized_sdpa_aten(
     const std::optional<at::Tensor>& k_zero_points,
     const std::optional<at::Tensor>& k_scales,
     const std::optional<at::Tensor>& v_zero_points,
-    const std::optional<at::Tensor>& v_scales) {
+    const std::optional<at::Tensor>& v_scales,
+    const bool is_seq_at_dim_2) {
   auto output = at::empty(q.sizes());
-  WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 14)
+  WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 15)
   (q,
    k,
    v,
@@ -313,6 +318,7 @@ at::Tensor custom_quantized_sdpa_aten(
    k_scales,
    v_zero_points,
    v_scales,
+   is_seq_at_dim_2,
    output);
   return output;
 }
@@ -371,13 +377,13 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
      "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
      "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
-      "Tensor? v_scales=None) -> Tensor");
+      "Tensor? v_scales=None, bool is_seq_at_dim_2=False) -> Tensor");
   m.def(
      "custom_quantized_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
      "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
      "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
-      "Tensor? v_scales=None, *, Tensor(a!) out) -> Tensor(a!)");
+      "Tensor? v_scales=None, bool is_seq_at_dim_2=False, *, Tensor(a!) out) -> Tensor(a!)");
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 }
 
@@ -404,6 +410,6 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl(
       "custom_quantized_sdpa.out",
       WRAP_TO_ATEN(
-          torch::executor::native::custom_quantized_sdpa_out_no_context, 14));
+          torch::executor::native::custom_quantized_sdpa_out_no_context, 15));
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 }
```
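With the schema change, `llama::custom_quantized_sdpa` now accepts a trailing `bool is_seq_at_dim_2=False`. A hedged sketch of calling the functional variant from Python, assuming the ExecuTorch llama custom-op library has already been built and loaded into the process; the shapes, dtypes, and dummy quantization parameters below are illustrative only:

```python
import torch

# Assumes torch.ops.llama.custom_quantized_sdpa is available, i.e. the
# ExecuTorch custom SDPA ops have been built and loaded into this process.
B, S, H, D = 1, 16, 8, 64

# int8 activations with per-token (zero_point, scale) pairs, sequence at dim 1.
q = torch.randint(-128, 127, (B, S, H, D), dtype=torch.int8)
k = torch.randint(-128, 127, (B, S, H, D), dtype=torch.int8)
v = torch.randint(-128, 127, (B, S, H, D), dtype=torch.int8)
zp = torch.zeros(B, S, H, 1, dtype=torch.int8)    # dummy zero points
sc = torch.ones(B, S, H, 1, dtype=torch.float32)  # dummy scales

out = torch.ops.llama.custom_quantized_sdpa(
    q, k, v, 0,              # query, key, value, start_pos
    None, 0.0, True,         # attn_mask, dropout_p, is_causal
    None,                    # scale (None -> 1/sqrt(head_dim))
    zp, sc, zp, sc, zp, sc,  # q/k/v zero_points and scales
    is_seq_at_dim_2=False,   # flag added by this commit; False means seq at dim 1
)
```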

‎extension/llm/custom_ops/op_sdpa_impl.h

Lines changed: 84 additions & 35 deletions
```diff
@@ -32,13 +32,17 @@ namespace executor {
 
 namespace native {
 
+enum class SeqDim { ONE = 1, TWO };
+
 namespace sdpa::impl {
 
 struct MaybeQuantizedMatrixData {
   const void* data{nullptr};
   const int8_t* zero_points{nullptr};
   const float* scales{nullptr};
   int64_t m = 0, n = 0;
+  const int64_t zero_points_stride{1};
+  const int64_t scales_stride{1};
   ScalarType dtype{ScalarType::Float};
   MaybeQuantizedMatrixData() = default;
   MaybeQuantizedMatrixData(
@@ -47,12 +51,15 @@ struct MaybeQuantizedMatrixData {
       const float* scales_,
       int64_t m_,
       int64_t n_,
+      int64_t qparams_stride,
       ScalarType dtype_)
       : data(data_),
         zero_points(zero_points_),
         scales(scales_),
         m(m_),
         n(n_),
+        zero_points_stride(qparams_stride),
+        scales_stride(qparams_stride),
         dtype(dtype_) {}
 };
 
@@ -91,8 +98,9 @@ void _q_at_k_gemm(
         static_cast<const int8_t*>(k_data.zero_points),
         static_cast<const float*>(q_data.scales),
         static_cast<const float*>(k_data.scales),
-        1,
-        1);
+        // LHS and RHS are assumed to have same stride for qparams
+        q_data.zero_points_stride,
+        k_data.zero_points_stride);
   } else {
     ET_CHECK_MSG(
         false, "Accumulation in dtype other than float not supported yet");
@@ -152,7 +160,7 @@ void _qk_at_v_gemm(
         static_cast<const int8_t*>(v_data.zero_points),
         static_cast<const float*>(v_data.scales),
         beta,
-        1);
+        v_data.zero_points_stride);
   } else {
     ET_CHECK_MSG(
         false, "Accumulation in dtype other than float not supported yet");
@@ -351,6 +359,40 @@ sdpa_with_kv_cache does not use attn_mask.
 
 TODO: Just handle conversion of bool mask to float
 */
+/**
+ * @brief Implements Flash Attention algorithm on CPU
+ *
+ * This function computes scaled dot-product attention with optimizations for
+ CPU.
+ * It supports both regular and quantized attention computation.
+ *
+ * @tparam scalar_t The data type for computation (e.g., float)
+ * @tparam q_split_size Block size for query matrix in tiling algorithm
+ * @tparam kv_split_size Block size for key/value matrices in tiling algorithm
+ *
+ * @param output Output tensor to store attention results
+ * @param query Query tensor [Batch x Num_heads x Q_seq_len x Dim_per_head]
+ * @param key Key tensor [Batch x Num_heads_kv x KV_seq_len x Dim_per_head]
+ * @param value Value tensor [Batch x Num_heads_kv x KV_seq_len x Dim_per_head]
+ * @param dropout_p Dropout probability (not used in current implementation)
+ * @param is_causal Whether to apply causal mask (lower triangular)
+ * @param attn_mask Optional explicit attention mask
+ * @param scale Optional custom scaling factor (default: 1/sqrt(head_dim))
+ * @param q_zero_points Optional zero points for quantized query
+ * @param q_scales Optional scales for quantized query
+ * @param k_zero_points Optional zero points for quantized key
+ * @param k_scales Optional scales for quantized key
+ * @param v_zero_points Optional zero points for quantized value
+ * @param v_scales Optional scales for quantized value
+ * @param seq_dim Which dimension is sequence dimension.
+ If SeqDim::One, then query, key, value are
+ expected to be in shape [Batch x Q_seq_len x Dim_per_head x Num_heads] and
+ output is expected to be in shape [Batch x Q_seq_len x Dim_per_head x
+ Num_heads]
+ * @param start_pos Starting position for causal masking in generation
+ * @param num_keys_for_causal_attention Number of keys to consider for causal
+ attention (-1 for all)
+ */
 template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
 void cpu_flash_attention(
     Tensor& output,
@@ -367,22 +409,10 @@ void cpu_flash_attention(
     const optional<Tensor>& k_scales,
     const optional<Tensor>& v_zero_points,
     const optional<Tensor>& v_scales,
-    bool is_seq_at_dim_1 = false,
+    const SeqDim seq_dim = SeqDim::TWO,
     const int64_t start_pos = 0,
     const int64_t num_keys_for_causal_attention = -1) {
   (void)dropout_p;
-  // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
-  // Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
-  // Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
-
-  /*
-  // -> (Batch x Q_seq_len x Num_heads x Dim_per_head)
-  at::Tensor query = q.transpose(1, 2);
-  // -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
-  at::Tensor key = k.transpose(1, 2);
-  // -> (Batch x KV_seq_len x Num_heads x Dim_per_head)
-  at::Tensor value = v.transpose(1, 2);
-  */
 
   // Without this we have out-of-bounds writes for
   // causal masking
@@ -408,7 +438,7 @@ void cpu_flash_attention(
   int64_t kvSize = value.size(2);
   int64_t num_heads_kv = key.size(1);
 
-  if (is_seq_at_dim_1) {
+  if (seq_dim == SeqDim::ONE) {
     num_head = query.size(2);
     num_heads_kv = key.size(2);
     qSize = query.size(1);
@@ -466,7 +496,7 @@ void cpu_flash_attention(
   int64_t qStrideH = strides[1];
   int64_t qStrideM = strides[2];
 
-  if (is_seq_at_dim_1) {
+  if (seq_dim == SeqDim::ONE) {
     qStrideH = strides[2];
     qStrideM = strides[1];
   }
@@ -476,7 +506,7 @@ void cpu_flash_attention(
   int64_t kStrideH = strides[1];
   int64_t kStrideN = strides[2];
 
-  if (is_seq_at_dim_1) {
+  if (seq_dim == SeqDim::ONE) {
     kStrideH = strides[2];
     kStrideN = strides[1];
   }
@@ -486,7 +516,7 @@ void cpu_flash_attention(
   int64_t vStrideH = strides[1];
   int64_t vStrideN = strides[2];
 
-  if (is_seq_at_dim_1) {
+  if (seq_dim == SeqDim::ONE) {
     vStrideH = strides[2];
     vStrideN = strides[1];
   }
@@ -502,28 +532,44 @@ void cpu_flash_attention(
   int64_t v_quant_params_StrideN = 0;
 
   if (is_quantized_sdpa) {
-    strides = q_zero_points.value().strides();
-    q_quant_params_StrideB = strides[0];
-    q_quant_params_StrideH = strides[1];
-    q_quant_params_StrideM = strides[2];
-
-    strides = k_zero_points.value().strides();
-    k_quant_params_StrideB = strides[0];
-    k_quant_params_StrideH = strides[1];
-    k_quant_params_StrideN = strides[2];
-
-    strides = v_zero_points.value().strides();
-    v_quant_params_StrideB = strides[0];
-    v_quant_params_StrideH = strides[1];
-    v_quant_params_StrideN = strides[2];
+    auto q_strides = q_zero_points.value().strides();
+    q_quant_params_StrideB = q_strides[0];
+    q_quant_params_StrideH = q_strides[1];
+    q_quant_params_StrideM = q_strides[2];
+
+    auto k_strides = k_zero_points.value().strides();
+    k_quant_params_StrideB = k_strides[0];
+    k_quant_params_StrideH = k_strides[1];
+    k_quant_params_StrideN = k_strides[2];
+
+    auto v_strides = v_zero_points.value().strides();
+    v_quant_params_StrideB = v_strides[0];
+    v_quant_params_StrideH = v_strides[1];
+    v_quant_params_StrideN = v_strides[2];
+
+    ET_CHECK_MSG(
+        (v_quant_params_StrideN == k_quant_params_StrideN) &&
+            (v_quant_params_StrideN == q_quant_params_StrideM),
+        "Quant params strides must be same for seq dim");
+
+    if (seq_dim == SeqDim::ONE) {
+      q_quant_params_StrideH = q_strides[2];
+      q_quant_params_StrideM = q_strides[1];
+
+      k_quant_params_StrideH = k_strides[2];
+      k_quant_params_StrideN = k_strides[1];
+
+      v_quant_params_StrideH = v_strides[2];
+      v_quant_params_StrideN = v_strides[1];
+    }
   }
 
   strides = output.strides();
   int64_t oStrideB = strides[0];
   int64_t oStrideH = strides[1];
   int64_t oStrideM = strides[2];
 
-  if (is_seq_at_dim_1) {
+  if (seq_dim == SeqDim::ONE) {
     oStrideH = strides[2];
     oStrideM = strides[1];
   }
@@ -679,13 +725,15 @@ void cpu_flash_attention(
             q_scales_ptr,
             qBlockSize,
             headSize,
+            q_quant_params_StrideM,
             query.scalar_type());
         MaybeQuantizedMatrixData k_sub_matrix_data = MaybeQuantizedMatrixData(
             static_cast<const void*>(k_sub_matrix_data_ptr),
             k_zero_points_ptr,
             k_scales_ptr,
             kvBlockSize,
             headSize,
+            k_quant_params_StrideN,
             key.scalar_type());
         _q_at_k_gemm<accum_t>(
             qBlockSize,
@@ -835,6 +883,7 @@ void cpu_flash_attention(
             v_scales_ptr,
             kvBlockSize,
             headSize,
+            v_quant_params_StrideN,
             value.scalar_type());
         // Calculate Softmax(q @ k.T) @ v
         _qk_at_v_gemm<accum_t>(
```
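`MaybeQuantizedMatrixData` now records a per-row stride for the zero points and scales, and `cpu_flash_attention` swaps the head and sequence strides of the quant-param tensors when the sequence sits at dim 1. A small PyTorch sketch of that stride bookkeeping (shapes illustrative, not from the commit):

```python
import torch

B, H, S = 1, 8, 16

def qparam_strides(zero_points: torch.Tensor, seq_at_dim_1: bool):
    """Pick (stride_B, stride_H, stride_seq) for a [B, H, S, 1] or [B, S, H, 1]
    quant-param tensor, mirroring the seq_dim == SeqDim::ONE remapping."""
    s = zero_points.stride()
    stride_b, stride_h, stride_seq = s[0], s[1], s[2]
    if seq_at_dim_1:
        stride_h, stride_seq = s[2], s[1]
    return stride_b, stride_h, stride_seq

zp_dim2 = torch.zeros(B, H, S, 1, dtype=torch.int8)  # sequence at dim 2
zp_dim1 = torch.zeros(B, S, H, 1, dtype=torch.int8)  # sequence at dim 1

print(qparam_strides(zp_dim2, seq_at_dim_1=False))  # seq stride is 1: rows contiguous
print(qparam_strides(zp_dim1, seq_at_dim_1=True))   # seq stride is H: rows H elements apart
```

The resulting per-row stride is what gets passed into `MaybeQuantizedMatrixData` and forwarded to `_q_at_k_gemm` / `_qk_at_v_gemm` in place of the previously hard-coded `1`.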

‎extension/llm/custom_ops/test_quantized_sdpa.py

Lines changed: 72 additions & 6 deletions
```diff
@@ -35,6 +35,7 @@ def setUp(self):
         self.float_dtype = torch.float32
         self.q_shape = None
         self.kv_shape = None
+        self.is_seq_at_dim_2 = True
 
     def _scale_tensor(self, tensor, min_value, max_value, scale=True):
         normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
@@ -105,6 +106,10 @@ def _sdpa_ref(
             self.float_dtype,
         )
 
+        if not self.is_seq_at_dim_2:
+            q = q.transpose(1, 2).contiguous()
+            k = k.transpose(1, 2).contiguous()
+            v = v.transpose(1, 2).contiguous()
         num_heads_q = q.size(1)
         num_heads_kv = k.size(1)
         seq_len = q.size(2)
@@ -119,6 +124,8 @@ def _sdpa_ref(
             k = k.repeat_interleave(n_reps, dim=1)
             v = v.repeat_interleave(n_reps, dim=1)
         out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+        if not self.is_seq_at_dim_2:
+            out = out.transpose(1, 2).contiguous()
         return out
 
     def _int_matmul(
@@ -212,7 +219,7 @@ def _test_sdpa_common(
         seq_len,
         scale_tensors=False,
         atol=1e-5,
-        is_seq_at_dim_2=True,
+        is_seq_at_dim_2=False,
     ):
         # Range arbitrarily chosen to reproduce a numerical error on x86 in some of the long context tests
         tensor_scale_max = 15
@@ -221,9 +228,10 @@ def _test_sdpa_common(
         self.n_heads_q = n_heads_q
         self.head_dim = head_dim
         self.max_seq_len = max_seq_len
+        self.is_seq_at_dim_2 = is_seq_at_dim_2
         seq_dim = 2
         self.q_shape = (self.n_batch, self.n_heads_q, seq_len, self.head_dim)
-        self.kv_shape = (self.n_batch, self.n_heads_q, self.max_seq_len, self.head_dim)
+        self.kv_shape = (self.n_batch, self.n_heads_kv, self.max_seq_len, self.head_dim)
         if not is_seq_at_dim_2:
             seq_dim = 1
             self.q_shape = (self.n_batch, seq_len, self.n_heads_q, self.head_dim)
@@ -286,7 +294,6 @@ def _test_sdpa_common(
             quantized_dtype,
         )
 
-        start_pos = 0
         seq_len = q.size(seq_dim)
         attn_mask = self.mask[start_pos : start_pos + seq_len, :]
         attn_mask = attn_mask[:, : start_pos + seq_len]
@@ -334,6 +341,7 @@ def _test_sdpa_common(
             k_scale_fp32,
             v_zero_point_int8,
             v_scale_fp32,
+            is_seq_at_dim_2,
         )
         self.assertTrue(torch.allclose(ref_output, op_output, atol=atol))
         # Following line crashes due to some weird issues in mkldnn with crash in mkl_sgemm with `wild jump`
@@ -374,6 +382,7 @@ def _test_sdpa_common(
             k_scale_fp32,
             v_zero_point_int8,
             v_scale_fp32,
+            is_seq_at_dim_2,
         )
         self.assertTrue(torch.allclose(ref_output, op_output, atol=atol))
 
@@ -393,6 +402,18 @@ def test_sdpa_with_custom_quantized(self):
             seq_len,
             True,
             atol=1e-4,
+            is_seq_at_dim_2=True,
+        )
+        self._test_sdpa_common(
+            n_heads_kv,
+            n_heads_q,
+            head_dim,
+            max_seq_len,
+            start_pos,
+            seq_len,
+            True,
+            atol=1e-4,
+            is_seq_at_dim_2=False,
         )
 
     def test_sdpa_with_custom_quantized_seq_len_1(self):
@@ -403,7 +424,22 @@ def test_sdpa_with_custom_quantized_seq_len_1(self):
         seq_len = 1
         start_pos = 0
         self._test_sdpa_common(
-            n_heads_kv, n_heads_q, head_dim, max_seq_len, start_pos, seq_len
+            n_heads_kv,
+            n_heads_q,
+            head_dim,
+            max_seq_len,
+            start_pos,
+            seq_len,
+            is_seq_at_dim_2=True,
+        )
+        self._test_sdpa_common(
+            n_heads_kv,
+            n_heads_q,
+            head_dim,
+            max_seq_len,
+            start_pos,
+            seq_len,
+            is_seq_at_dim_2=False,
         )
 
     def test_sdpa_with_custom_quantized_seq_len_small(self):
@@ -414,7 +450,22 @@ def test_sdpa_with_custom_quantized_seq_len_small(self):
         seq_len = 4
         start_pos = 0
         self._test_sdpa_common(
-            n_heads_kv, n_heads_q, head_dim, max_seq_len, start_pos, seq_len
+            n_heads_kv,
+            n_heads_q,
+            head_dim,
+            max_seq_len,
+            start_pos,
+            seq_len,
+            is_seq_at_dim_2=True,
+        )
+        self._test_sdpa_common(
+            n_heads_kv,
+            n_heads_q,
+            head_dim,
+            max_seq_len,
+            start_pos,
+            seq_len,
+            is_seq_at_dim_2=False,
        )
 
     def test_sdpa_with_custom_quantized_seq_len_llava_example(self):
@@ -466,5 +517,20 @@ def test_sdpa_with_cache_mqa(self):
         seq_len = 24
         start_pos = 0
         self._test_sdpa_common(
-            n_heads_kv, n_heads_q, head_dim, max_seq_len, start_pos, seq_len
+            n_heads_kv,
+            n_heads_q,
+            head_dim,
+            max_seq_len,
+            start_pos,
+            seq_len,
+            is_seq_at_dim_2=True,
+        )
+        self._test_sdpa_common(
+            n_heads_kv,
+            n_heads_q,
+            head_dim,
+            max_seq_len,
+            start_pos,
+            seq_len,
+            is_seq_at_dim_2=False,
         )
```
