Skip to content

Commit d7d2e43

Browse files
Authored commit: Turn the attribute _return_attention_scores into an argument (#20803)
1 parent 8d292da commit d7d2e43

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

keras/src/layers/attention/multi_head_attention.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ def __init__(
158158
self.seed = seed
159159

160160
self._inverse_sqrt_key_dim = 1.0 / math.sqrt(float(self._key_dim))
161-
self._return_attention_scores = False
162161

163162
# Check for flash attention constraints
164163
if self._flash_attention and self._dropout > 0.0:
@@ -419,6 +418,7 @@ def _compute_attention(
419418
value,
420419
attention_mask=None,
421420
training=None,
421+
return_attention_scores=False,
422422
):
423423
"""Applies Dot-product attention with query, key, value tensors.
424424
@@ -442,7 +442,7 @@ def _compute_attention(
442442
attention_scores: Multi-headed attention weights.
443443
"""
444444
# Check for flash attention constraints
445-
if self._flash_attention and self._return_attention_scores:
445+
if self._flash_attention and return_attention_scores:
446446
raise ValueError(
447447
"Returning attention scores is not supported when flash "
448448
"attention is enabled. Please disable flash attention to access"
@@ -452,7 +452,7 @@ def _compute_attention(
452452
# Determine whether to use dot-product attention
453453
use_dot_product_attention = not (
454454
self._dropout > 0.0
455-
or self._return_attention_scores
455+
or return_attention_scores
456456
or (len(query.shape) != 4)
457457
)
458458

@@ -525,7 +525,6 @@ def call(
525525
training=None,
526526
use_causal_mask=False,
527527
):
528-
self._return_attention_scores = return_attention_scores
529528
if key is None:
530529
key = value
531530

@@ -562,6 +561,7 @@ def call(
562561
value,
563562
attention_mask,
564563
training,
564+
return_attention_scores,
565565
)
566566
attention_output = self._output_dense(attention_output)
567567

0 commit comments

Comments (0)