@@ -158,7 +158,6 @@ def __init__(
         self.seed = seed
 
         self._inverse_sqrt_key_dim = 1.0 / math.sqrt(float(self._key_dim))
-        self._return_attention_scores = False
 
         # Check for flash attention constraints
         if self._flash_attention and self._dropout > 0.0:
@@ -419,6 +418,7 @@ def _compute_attention(
         value,
         attention_mask=None,
         training=None,
+        return_attention_scores=False,
     ):
         """Applies Dot-product attention with query, key, value tensors.
 
@@ -442,7 +442,7 @@ def _compute_attention(
             attention_scores: Multi-headed attention weights.
         """
         # Check for flash attention constraints
-        if self._flash_attention and self._return_attention_scores:
+        if self._flash_attention and return_attention_scores:
             raise ValueError(
                 "Returning attention scores is not supported when flash "
                 "attention is enabled. Please disable flash attention to access"
@@ -452,7 +452,7 @@ def _compute_attention(
         # Determine whether to use dot-product attention
         use_dot_product_attention = not (
             self._dropout > 0.0
-            or self._return_attention_scores
+            or return_attention_scores
             or (len(query.shape) != 4)
         )
 
@@ -525,7 +525,6 @@ def call(
         training=None,
         use_causal_mask=False,
     ):
-        self._return_attention_scores = return_attention_scores
         if key is None:
             key = value
 
@@ -562,6 +561,7 @@ def call(
             value,
             attention_mask,
             training,
+            return_attention_scores,
         )
         attention_output = self._output_dense(attention_output)
 
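For context, a minimal usage sketch follows (assuming the standard Keras 3 MultiHeadAttention API; the tensor shapes are illustrative only, not taken from this PR). The public call signature is unchanged by the diff: return_attention_scores is now threaded through _compute_attention as an argument instead of being stored on the layer as self._return_attention_scores, and requesting scores while flash attention is enabled still raises the ValueError shown above.

# Minimal sketch, assuming the Keras 3 MultiHeadAttention API.
# Shapes are made up for illustration.
import numpy as np
import keras

layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)
query = np.random.normal(size=(1, 8, 32)).astype("float32")
value = np.random.normal(size=(1, 8, 32)).astype("float32")

# Requesting the scores returns an (output, scores) tuple; the flag is
# forwarded to _compute_attention per call rather than kept as layer state.
output, scores = layer(query, value, return_attention_scores=True)
print(output.shape)  # (1, 8, 32)
print(scores.shape)  # (batch, num_heads, query_len, key_len) -> (1, 2, 8, 8)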