fix gqa rotary dim 1 #19874

Merged · 1 commit · Mar 13, 2024
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -98,6 +98,7 @@ struct GroupQueryAttentionParameters {
   int kv_hidden_size;
   int kv_num_heads;
   int num_splits;          // number of splits for splitkv
+  int rotary_dim;          // rotary embedding dimension
   bool is_unidirectional;  // causal
   int local_window_size;
   bool kv_share_buffer;
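What the new field enables: with `rotary_dim` carried on the parameters, rotary embedding can be applied to only the first `rotary_dim` channels of each head instead of the whole head. A minimal sketch of that idea (illustrative names, non-interleaved layout; not the kernel code in this PR):

```cpp
#include <cassert>
#include <vector>

// Apply rotary embedding to only the first rotary_dim channels of one head
// vector, leaving channels [rotary_dim, head_size) untouched. Non-interleaved
// ("half") layout: channel i is paired with channel i + rotary_dim / 2.
// The op also supports an interleaved layout selected by rotary_interleaved.
void apply_partial_rope(std::vector<float>& head,      // [head_size]
                        const std::vector<float>& cos,  // [rotary_dim / 2] for one position
                        const std::vector<float>& sin,  // [rotary_dim / 2] for one position
                        int rotary_dim) {
  const int half = rotary_dim / 2;
  assert(rotary_dim <= static_cast<int>(head.size()));
  assert(static_cast<int>(cos.size()) == half);
  assert(static_cast<int>(sin.size()) == half);
  for (int i = 0; i < half; ++i) {
    const float x = head[i];
    const float y = head[i + half];
    head[i] = x * cos[i] - y * sin[i];  // rotate the pair by the cached angle
    head[i + half] = x * sin[i] + y * cos[i];
  }
  // Channels rotary_dim..head_size-1 pass through unchanged (partial rotary).
}
```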
@@ -371,6 +371,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
                        int seqlen_q,
                        int seqlen_k,
                        int seqlen_k_new,
+                       int rotary_dim,
                        const float softmax_scale,
                        bool is_causal,
                        bool is_bf16,
@@ -448,7 +449,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
     params.rotary_cos_ptr = rotary_cos;
     params.rotary_sin_ptr = rotary_sin;
     params.is_rotary_interleaved = is_rotary_interleaved;
-    params.rotary_dim = (head_size / 16) * 16;
+    params.rotary_dim = rotary_dim;
   }
 
   params.num_splits = num_splits;
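The core fix in this file: `params.rotary_dim` was previously derived from `head_size` as `(head_size / 16) * 16`, which equals `head_size` for the usual multiple-of-16 head sizes and therefore rotates every channel. Models that rotate only part of the head (for example an 80-dim head with 32 rotary channels) need the value taken from the cos/sin cache shape instead, which is what the new `rotary_dim` argument carries. A tiny illustration of the difference, assuming those example sizes:

```cpp
#include <cstdio>

int main() {
  const int head_size = 80;   // e.g. an 80-dim head
  const int cache_dim1 = 16;  // cos_cache / sin_cache dimension 1 == rotary_dim / 2
  const int old_rotary_dim = (head_size / 16) * 16;  // previous hardcoded rule -> 80
  const int new_rotary_dim = cache_dim1 * 2;         // taken from the cache shape -> 32
  std::printf("old=%d new=%d\n", old_rotary_dim, new_rotary_dim);
  return 0;
}
```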
@@ -96,6 +96,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
                        int seqlen_q,
                        int seqlen_k,
                        int seqlen_k_new,
+                       int rotary_dim,
                        const float softmax_scale,
                        bool is_causal,
                        bool is_bf16,
11 changes: 9 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
@@ -205,6 +205,7 @@ Status CheckInputs(const Tensor* query,
   int total_sequence_length = *((*total_seqlen).template Data<int32_t>());
   int present_sequence_length = std::max(total_sequence_length, past_sequence_length);
 
+  int rotary_dim = 0;
   if (cos_cache != nullptr && sin_cache != nullptr) {
     const auto& cos_dims = cos_cache->Shape().GetDims();
     const auto& sin_dims = sin_cache->Shape().GetDims();
@@ -222,14 +223,19 @@ Status CheckInputs(const Tensor* query,
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                              "sin_cache dimension 0 should be of max_sequence_length.");
     }
-    if (cos_dims[1] != (head_size / 16) * 8) {
+    if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                              "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
     }
-    if (sin_dims[1] != (head_size / 16) * 8) {
+    if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                              "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
     }
+    if (cos_dims[1] != sin_dims[1]) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "cos_cache and sin_cache dimension 1 must be the same.");
+    }
+    rotary_dim = static_cast<int>(cos_dims[1] * 2);
   } else if (cos_cache != nullptr || sin_cache != nullptr) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
@@ -248,6 +254,7 @@ Status CheckInputs(const Tensor* query,
   output_parameters->head_size = head_size;
   output_parameters->kv_hidden_size = kv_hidden_size;
   output_parameters->kv_num_heads = kv_num_heads;
+  output_parameters->rotary_dim = rotary_dim;
   output_parameters->is_packed_qkv = is_packed_qkv;
   output_parameters->is_unidirectional = true;
   output_parameters->is_prompt = is_prompt;
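The validation above replaces the old equality check (`cos_dims[1] == (head_size / 16) * 8`) with an upper bound plus a multiple-of-8 constraint, and derives `rotary_dim = cos_dims[1] * 2` from the cache shape. A standalone sketch of that logic, with illustrative names rather than the ORT helpers:

```cpp
#include <stdexcept>

// cos_cache / sin_cache have shape [max_sequence_length, rotary_dim / 2].
int derive_rotary_dim(long cos_dim1, long sin_dim1, int head_size) {
  const long max_half = (head_size / 16) * 8;  // head_size / 2, rounded down to a multiple of 8
  if (cos_dim1 > max_half || cos_dim1 % 8 != 0) {
    throw std::invalid_argument("cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8");
  }
  if (sin_dim1 > max_half || sin_dim1 % 8 != 0) {
    throw std::invalid_argument("sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8");
  }
  if (cos_dim1 != sin_dim1) {
    throw std::invalid_argument("cos_cache and sin_cache dimension 1 must match");
  }
  return static_cast<int>(cos_dim1 * 2);  // rotary_dim
}
```

For example, with head_size = 80 this accepts cache widths 8, 16, 24, 32, or 40 (rotary_dim 16 through 80); before the change only width 40, i.e. full rotary, passed the check.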
@@ -530,7 +530,7 @@ Status FlashAttention(
       device_prop, stream, query, present_key, present_value, key, value, data.output,
       reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache,
       batch_size, num_heads, kv_num_heads, head_size, sequence_length,
-      parameters.seqlen_present_kv_cache, kv_sequence_length,
+      parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
       scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
       reinterpret_cast<void*>(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved,
       parameters.is_packed_qkv));
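With this call site forwarding `parameters.rotary_dim` into the flash kernel, the rotary dimension is controlled entirely by the shapes of the `cos_cache` / `sin_cache` inputs (`[max_sequence_length, rotary_dim / 2]`). A hedged sketch of building such caches for a partial-rotary configuration; the base of 10000 is the common RoPE default and a model-specific assumption, not something defined by this PR:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

struct RotaryCaches {
  std::vector<float> cos_cache;  // flattened [max_seq_len, rotary_dim / 2]
  std::vector<float> sin_cache;  // flattened [max_seq_len, rotary_dim / 2]
};

RotaryCaches build_rotary_caches(int max_seq_len, int rotary_dim, float theta = 10000.0f) {
  const int half = rotary_dim / 2;
  RotaryCaches c;
  c.cos_cache.resize(static_cast<std::size_t>(max_seq_len) * half);
  c.sin_cache.resize(static_cast<std::size_t>(max_seq_len) * half);
  for (int pos = 0; pos < max_seq_len; ++pos) {
    for (int i = 0; i < half; ++i) {
      const float inv_freq = std::pow(theta, -2.0f * i / rotary_dim);  // standard RoPE frequencies
      const float angle = pos * inv_freq;
      const std::size_t idx = static_cast<std::size_t>(pos) * half + i;
      c.cos_cache[idx] = std::cos(angle);
      c.sin_cache[idx] = std::sin(angle);
    }
  }
  return c;
}
```

With head_size = 80 and rotary_dim = 32, this yields caches of shape [max_seq_len, 16], which the relaxed check now accepts and which the kernel uses to rotate only the first 32 channels of each head.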