
Fixed issue with dot_product_attention when using TPU. #21254


Merged · 6 commits · May 7, 2025
228 changes: 186 additions & 42 deletions keras/src/backend/jax/nn.py
@@ -1126,16 +1126,17 @@ def wrap_flash_attention(
decoder_segment_ids,
custom_mask=None,
attn_logits_soft_cap=None,
head_shards=1,
q_seq_shards=1,
):
if decoder_segment_ids is not None:
assert query.shape[2] == decoder_segment_ids.q.shape[1], (
"Sharding along sequence dimension not allowed in tpu kernel "
"attention"
"Sharding along sequence dimension not allowed"
" in TPU kernel attention"
)

if custom_mask is not None:
mask = splash_attention_mask.NumpyMask(array=custom_mask)

else:
mask = splash_attention_mask.CausalMask(
shape=(query.shape[2], query.shape[2])
@@ -1147,8 +1148,8 @@ def wrap_flash_attention(
)
splash_kernel = splash_attention_kernel.make_splash_mha(
mask=multi_head_mask,
-        head_shards=1,
-        q_seq_shards=1,
+        head_shards=head_shards,
+        q_seq_shards=q_seq_shards,
attn_logits_soft_cap=attn_logits_soft_cap,
)
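For context, here is a minimal sketch of how the new `head_shards`/`q_seq_shards` arguments reach the Pallas splash-attention kernel. This is TPU-only and illustrative; the sequence length, head count, and shard counts are assumptions, not values from this PR.

from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)

seq_len, num_heads = 1024, 8
# One causal mask, shared across all heads.
causal = splash_attention_mask.CausalMask(shape=(seq_len, seq_len))
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=[causal] * num_heads)

kernel = splash_attention_kernel.make_splash_mha(
    mask=multi_head_mask,
    head_shards=4,   # e.g. the mesh size along the "model" axis
    q_seq_shards=1,  # kept at 1 for best performance, as in this PR
)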

@@ -1168,6 +1169,38 @@ def dot_product_attention(
flash_attention=None,
attn_logits_soft_cap=None,
):
"""Computes dot-product attention given query, key, and value.

This is the core computation of attention that is used in transformers.
For TPU platforms, flash attention optimizations are automatically applied
when possible, and sharding parameters are inferred from the layout map
in the current distribution context.

Args:
query: Queries with shape `[batch, time, heads, depth_k]`.
key: Keys with shape `[batch, time, heads, depth_k]`.
value: Values with shape `[batch, time, heads, depth_v]`.
bias: Optional bias with shape broadcastable to
`[batch, heads, dest_time, source_time]`.
mask: Optional mask with shape broadcastable to
`[batch, heads, dest_time, source_time]`.
scale: Float. Optional scale that is applied to the attention
computation.
is_causal: Boolean. Whether to apply causal masking.
flash_attention: Boolean. Whether to use the flash attention optimization
for increased performance. Defaults to `None`, in which case it is
auto-determined based on the platform, input shapes, and
compatibility.
attn_logits_soft_cap: Float. Optional float to softly cap attention
logits to avoid numerical instability. Applied as:
`logits = attn_logits_soft_cap * tanh(logits / attn_logits_soft_cap)`.

Returns:
JAX Array of shape `[batch, time, heads, depth_v]`.
"""
query = convert_to_tensor(query)
key = convert_to_tensor(key)
value = convert_to_tensor(value)
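A hedged usage sketch for the docstring above, assuming the JAX backend is active. Shapes follow the documented `[batch, time, heads, depth]` convention; this is illustrative, not a test from the PR.

import numpy as np
from keras.src.backend.jax.nn import dot_product_attention

batch, seq_len, num_heads, head_dim = 2, 16, 4, 32
rng = np.random.default_rng(0)
query = rng.normal(size=(batch, seq_len, num_heads, head_dim)).astype("float32")
key = rng.normal(size=(batch, seq_len, num_heads, head_dim)).astype("float32")
value = rng.normal(size=(batch, seq_len, num_heads, head_dim)).astype("float32")

out = dot_product_attention(query, key, value, is_causal=True)
print(out.shape)  # (2, 16, 4, 32)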
@@ -1177,47 +1210,155 @@ def dot_product_attention(
f"Received: query.shape={query.shape}, key.shape={key.shape}, "
f"value.shape={value.shape}."
)
if flash_attention is None:
flash_attention = _can_use_flash_attention(query, key, value, bias)
elif flash_attention is True:
# Use `raise_error=True` to provide more details if the inputs failed to
# use flash attention
_can_use_flash_attention(query, key, value, bias, raise_error=True)

if jax.devices()[0].platform == "tpu":
# Transpose to ('batch', 'heads', 'length', 'kv')
query = jnp.transpose(query, axes=(0, 2, 1, 3))
key = jnp.transpose(key, axes=(0, 2, 1, 3))
value = jnp.transpose(value, axes=(0, 2, 1, 3))
B, H, S, KV = query.shape

segment_ids = jnp.ones([B, S])
# {token_ids, padding_mask, segment_ids} enable packing
out = wrap_flash_attention(
query,
key,
value,
decoder_segment_ids=splash_attention_kernel.SegmentIds(
segment_ids, segment_ids
),
custom_mask=mask,
attn_logits_soft_cap=attn_logits_soft_cap,
# Check platform
platform = jax.devices()[0].platform
is_tpu = platform == "tpu"

# Get sharding parameters from distribution context
head_shards = 1
q_seq_shards = 1

if is_tpu:
try:
from keras.src.distribution.distribution_lib import ModelParallel
from keras.src.distribution.distribution_lib import (
distribution as get_dist,
)

# Get current distribution if available
dist = get_dist()
if dist and isinstance(dist, ModelParallel):
mesh = dist.device_mesh
if "model" in mesh.axis_names:
model_dim_index = mesh.axis_names.index("model")
# Set head_shards based on the model dimension of the mesh
head_shards = mesh.shape[model_dim_index]
# Typically keep q_seq_shards=1 for best performance
q_seq_shards = 1
except (ImportError, ValueError, AttributeError):
# Use default values if detection fails
head_shards = 1
q_seq_shards = 1
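For reference, this is the kind of setup the detection above targets. A sketch only: it assumes four accelerators are available, and the mesh shape and axis names are illustrative choices.

import keras
from keras.distribution import DeviceMesh, LayoutMap, ModelParallel

# Assumes 4 accelerators; "batch"/"model" axis names are illustrative.
mesh = DeviceMesh(shape=(1, 4), axis_names=("batch", "model"))
layout_map = LayoutMap(mesh)
keras.distribution.set_distribution(ModelParallel(layout_map=layout_map))
# With this in place, the detection above would infer head_shards = 4
# from the "model" axis of the mesh.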

# Check if inputs use partial sharding (not fully replicated)
# Flash attention works well with fully replicated tensors on all platforms
# but may have issues with certain partial sharding patterns on non-TPU
# platforms
partially_sharded_inputs = any(
hasattr(t, "sharding") and not t.sharding.is_fully_replicated
for t in (query, key, value)
)
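As an aside, the `is_fully_replicated` probe used above can be seen in isolation with a sketch like the following. It assumes more than one device; the mesh axis name is an illustrative choice.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
x = jax.device_put(
    jnp.ones((8, 128)), NamedSharding(mesh, PartitionSpec("model", None))
)
# False when the array is actually split across more than one device.
print(x.sharding.is_fully_replicated)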

# Determine flash attention compatibility
if flash_attention is None:
# Auto-detect flash attention availability
if is_tpu:
# TPUs have specialized hardware for attention that works with any
# sharding pattern
flash_attention = True
else:
# For GPU/CPU with partially sharded inputs, we need
# multiple devices to efficiently handle the sharding
if partially_sharded_inputs and len(jax.devices()) <= 1:
flash_attention = False
else:
flash_attention = _can_use_flash_attention(
query, key, value, bias
)
elif flash_attention is True and not is_tpu:
# If flash attention is explicitly requested, validate compatibility
# Skip validation for TPU as it has specialized hardware support
try:
_can_use_flash_attention(query, key, value, bias, raise_error=True)
except Exception:
# Only disable flash attention on non-TPU platforms
# if validation fails
flash_attention = False

# TPU-specific flash attention path
if is_tpu and flash_attention:
# Transpose to ('batch', 'heads', 'length', 'head_dim')
query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3))
key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3))
value_tpu_layout = jnp.transpose(value, axes=(0, 2, 1, 3))

bs, num_heads, q_len, head_dim = query_tpu_layout.shape

# Apply scale to query if provided
if scale is not None:
# The TPU kernel applies 1/sqrt(head_dim) internally; to get an
# overall QK^T * scale, multiply the query by (scale * sqrt(head_dim))
query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim))
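A quick numeric check of this pre-scaling trick (illustrative, with made-up sizes): because the kernel divides logits by sqrt(head_dim), multiplying the query by `scale * sqrt(head_dim)` reproduces `scale * (q @ k)` exactly.

import math
import numpy as np

head_dim, scale = 64, 0.25
rng = np.random.default_rng(0)
q = rng.normal(size=(head_dim,)).astype("float32")
k = rng.normal(size=(head_dim,)).astype("float32")

# What the kernel computes after pre-scaling the query:
kernel_logit = (q * (scale * math.sqrt(head_dim))) @ k / math.sqrt(head_dim)
np.testing.assert_allclose(kernel_logit, scale * (q @ k), rtol=1e-5)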

# Create segment IDs for Splash Attention (for packing/batching)
segment_ids = jnp.zeros([bs, q_len], dtype=jnp.int32)
decoder_segment_ids = splash_attention_kernel.SegmentIds(
q=segment_ids, kv=segment_ids
)
out = jnp.transpose(out, axes=(0, 2, 1, 3))
return out

# `dot_product_attention` is only available in jax>=0.4.31
# Process mask for Splash Attention
custom_mask = None
if mask is not None:
mask_bool = mask.astype("bool") if mask.dtype != jnp.bool_ else mask

if mask_bool.ndim == 3 and mask_bool.shape[0] == bs:
custom_mask = mask_bool[0]
elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs:
custom_mask = mask_bool[0, 0]

if is_causal and custom_mask is not None:
causal_mask = jnp.tril(
jnp.ones((q_len, q_len), dtype=jnp.bool_)
)
custom_mask = jnp.logical_and(custom_mask, causal_mask)

if custom_mask is None and is_causal:
custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
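The collapse above keeps only a single `[q_len, q_len]` slice, which assumes the mask is identical across batch entries and heads. A minimal sketch of that reduction:

import jax.numpy as jnp

bs, q_len = 2, 8
# A [batch, 1, q, kv] boolean mask shared across batch and heads.
mask4d = jnp.tril(jnp.ones((bs, 1, q_len, q_len), dtype=jnp.bool_))
custom_mask = mask4d[0, 0]  # -> shape (8, 8)
causal = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
assert bool(jnp.all(custom_mask == causal))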

try:
output = wrap_flash_attention(
query_tpu_layout,
key_tpu_layout,
value_tpu_layout,
decoder_segment_ids=decoder_segment_ids,
custom_mask=custom_mask,
attn_logits_soft_cap=attn_logits_soft_cap,
head_shards=head_shards,
q_seq_shards=q_seq_shards,
)
# Transpose output back to Keras layout
return jnp.transpose(output, axes=(0, 2, 1, 3))
except Exception:
flash_attention = False

# JAX native dot_product_attention for GPU or fallback for TPU
if hasattr(jax.nn, "dot_product_attention"):
return jax.nn.dot_product_attention(
query,
key,
value,
bias=bias,
mask=mask,
scale=scale,
is_causal=is_causal,
implementation="cudnn" if flash_attention else "xla",
)
try:
return jax.nn.dot_product_attention(
query,
key,
value,
bias=bias,
mask=mask,
scale=scale,
is_causal=is_causal,
implementation="cudnn" if flash_attention else "xla",
)
except Exception:
# If flash attention fails, fall back to XLA implementation
if flash_attention:
return jax.nn.dot_product_attention(
query,
key,
value,
bias=bias,
mask=mask,
scale=scale,
is_causal=is_causal,
implementation="xla",
)
raise

if flash_attention:
raise RuntimeError(
@@ -1228,6 +1369,9 @@ def dot_product_attention(
# Ref: jax.nn.dot_product_attention
# https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
# The `query_seq_lengths` and `key_value_seq_lengths` args are not supported

# Fallback to custom XLA implementation
# This is the reference implementation from jax.nn.dot_product_attention
output_shape = query.shape
_, _, K, H = key.shape
scale = (1.0 / jnp.sqrt(H)) if scale is None else scale
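The remainder of the hunk (truncated here) is the reference implementation adapted from `jax.nn.dot_product_attention`. A condensed sketch of the math that fallback computes, not the verbatim PR code:

import jax
import jax.numpy as jnp

def xla_attention(query, key, value, bias=None, mask=None, scale=None):
    # query: [B, T, N, H]; key/value: [B, S, N, H]
    head_dim = query.shape[-1]
    scale = (1.0 / jnp.sqrt(head_dim)) if scale is None else scale
    # logits: [B, N, T, S]
    logits = jnp.einsum("btnh,bsnh->bnts", query, key) * scale
    if bias is not None:
        logits = logits + bias
    if mask is not None:
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bnts,bsnh->btnh", probs, value)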