
Commit 2b59ab6

Remove default nature of attn_bias

1 parent 211fe7e

4 files changed (+8, -6 lines)

aten/src/ATen/native/native_functions.yaml

Lines changed: 4 additions & 4 deletions
@@ -14858,7 +14858,7 @@
     MPS: _scaled_dot_product_attention_math_mps
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   dispatch:
     CUDA: _scaled_dot_product_flash_attention_cuda
     NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
@@ -14874,7 +14874,7 @@
     CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
   device_check: NoCheck
   variants: function
   dispatch:
@@ -14915,13 +14915,13 @@
     CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
   tags: nondeterministic_seeded
 
-- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   variants: function
   dispatch:
     CUDA: _flash_attention_forward
   tags: nondeterministic_seeded
 
-- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)
+- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)
   device_check: NoCheck
   variants: function
   dispatch:
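Dropping the `=None` default is a source-breaking change for anything that called these ops without an explicit attn_bias: the YAML default is what lets Python callers and the generated C++ headers omit the argument. Below is a minimal caller sketch, assuming this branch's schema (stock PyTorch's _scaled_dot_product_flash_attention has no attn_bias parameter at all; shapes are hypothetical and a CUDA build with flash attention is required):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Hypothetical shapes: (batch, heads, seq_len, head_dim); flash
  // attention expects half/bfloat16 CUDA tensors.
  auto opts = torch::TensorOptions().device(torch::kCUDA).dtype(torch::kHalf);
  auto q = torch::randn({2, 8, 128, 64}, opts);
  auto k = torch::randn({2, 8, 128, 64}, opts);
  auto v = torch::randn({2, 8, 128, 64}, opts);

  // After this commit the empty optional must be spelled out; with the
  // old schema (attn_bias=None) this argument could simply be omitted.
  auto result = at::_scaled_dot_product_flash_attention(
      q, k, v,
      /*attn_bias=*/std::nullopt,
      /*dropout_p=*/0.0,
      /*is_causal=*/false,
      /*return_debug_mask=*/false);

  std::cout << std::get<0>(result).sizes() << '\n';  // attention output
  return 0;
}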

aten/src/ATen/native/transformers/cuda/attention.cu

Lines changed: 1 addition & 0 deletions
@@ -719,6 +719,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
       v_t,
       std::nullopt,
       std::nullopt,
+      std::nullopt,
       max_seqlen_batch_q,
       max_seqlen_batch_k,
       dropout_p,
aten/src/ATen/native/transformers/cuda/attention_backward.cu

Lines changed: 1 addition & 0 deletions
@@ -791,6 +791,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attenti
       q_t,
       k_t,
       v_t,
+      c10::nullopt,
       out_t,
       logsumexp,
       cumulative_sequence_length_q,
tools/autograd/derivatives.yaml

Lines changed: 2 additions & 2 deletions
@@ -2873,15 +2873,15 @@
   output_differentiability: [True, False, False, False]
   query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale)
 
-- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   output_differentiability: [True, False, False, False, False, False, False, False, False]
   query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, attn_bias, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
 
 - name: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)
   output_differentiability: [True, False]
   query, key, value: _scaled_dot_product_flash_attention_for_cpu_backward(grad, query, key, value, output, logsumexp, dropout_p, is_causal, attn_mask, scale)
 
-- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   output_differentiability: [True, False, False, False, False]
   query, key, value: _flash_attention_backward_symint(grad, query, key, value, attn_bias, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale, window_size_left, window_size_right)
