Skip to content

Commit ecbc23d

Browse files
committed
2 parents 5cd7185 + ace9536 commit ecbc23d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+658
-324
lines changed

.github/workflows/scorecard.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,6 @@ jobs:
5656

5757
# Upload the results to GitHub's code scanning dashboard.
5858
- name: "Upload to code-scanning"
59-
uses: github/codeql-action/upload-sarif@1b549b9259bda1cb5ddde3b41741a82a2d15a841 # v3.28.13
59+
uses: github/codeql-action/upload-sarif@28deaeda66b76a05916b6923827895f2b14ab387 # v3.28.16
6060
with:
6161
sarif_file: results.sarif

conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def pytest_collection_modifyitems(config, items):
3838
]
3939

4040
requires_trainable_backend = pytest.mark.skipif(
41-
backend() == "numpy" or backend() == "openvino",
41+
backend() in ["numpy", "openvino"],
4242
reason="Trainer not implemented for NumPy and OpenVINO backend.",
4343
)
4444
for item in items:

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
from keras.src.ops.numpy import argsort as argsort
136136
from keras.src.ops.numpy import array as array
137137
from keras.src.ops.numpy import average as average
138+
from keras.src.ops.numpy import bartlett as bartlett
138139
from keras.src.ops.numpy import bincount as bincount
139140
from keras.src.ops.numpy import bitwise_and as bitwise_and
140141
from keras.src.ops.numpy import bitwise_invert as bitwise_invert
@@ -143,6 +144,7 @@
143144
from keras.src.ops.numpy import bitwise_or as bitwise_or
144145
from keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift
145146
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
147+
from keras.src.ops.numpy import blackman as blackman
146148
from keras.src.ops.numpy import broadcast_to as broadcast_to
147149
from keras.src.ops.numpy import ceil as ceil
148150
from keras.src.ops.numpy import clip as clip

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from keras.src.ops.numpy import argsort as argsort
2828
from keras.src.ops.numpy import array as array
2929
from keras.src.ops.numpy import average as average
30+
from keras.src.ops.numpy import bartlett as bartlett
3031
from keras.src.ops.numpy import bincount as bincount
3132
from keras.src.ops.numpy import bitwise_and as bitwise_and
3233
from keras.src.ops.numpy import bitwise_invert as bitwise_invert
@@ -35,6 +36,7 @@
3536
from keras.src.ops.numpy import bitwise_or as bitwise_or
3637
from keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift
3738
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
39+
from keras.src.ops.numpy import blackman as blackman
3840
from keras.src.ops.numpy import broadcast_to as broadcast_to
3941
from keras.src.ops.numpy import ceil as ceil
4042
from keras.src.ops.numpy import clip as clip

keras/api/ops/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
from keras.src.ops.numpy import argsort as argsort
136136
from keras.src.ops.numpy import array as array
137137
from keras.src.ops.numpy import average as average
138+
from keras.src.ops.numpy import bartlett as bartlett
138139
from keras.src.ops.numpy import bincount as bincount
139140
from keras.src.ops.numpy import bitwise_and as bitwise_and
140141
from keras.src.ops.numpy import bitwise_invert as bitwise_invert
@@ -143,6 +144,7 @@
143144
from keras.src.ops.numpy import bitwise_or as bitwise_or
144145
from keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift
145146
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
147+
from keras.src.ops.numpy import blackman as blackman
146148
from keras.src.ops.numpy import broadcast_to as broadcast_to
147149
from keras.src.ops.numpy import ceil as ceil
148150
from keras.src.ops.numpy import clip as clip

keras/api/ops/numpy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from keras.src.ops.numpy import argsort as argsort
2828
from keras.src.ops.numpy import array as array
2929
from keras.src.ops.numpy import average as average
30+
from keras.src.ops.numpy import bartlett as bartlett
3031
from keras.src.ops.numpy import bincount as bincount
3132
from keras.src.ops.numpy import bitwise_and as bitwise_and
3233
from keras.src.ops.numpy import bitwise_invert as bitwise_invert
@@ -35,6 +36,7 @@
3536
from keras.src.ops.numpy import bitwise_or as bitwise_or
3637
from keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift
3738
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
39+
from keras.src.ops.numpy import blackman as blackman
3840
from keras.src.ops.numpy import broadcast_to as broadcast_to
3941
from keras.src.ops.numpy import ceil as ceil
4042
from keras.src.ops.numpy import clip as clip

keras/src/backend/jax/nn.py

Lines changed: 32 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,7 +1127,7 @@ def wrap_flash_attention(
11271127
custom_mask=None,
11281128
attn_logits_soft_cap=None,
11291129
head_shards=1,
1130-
q_seq_shards=1
1130+
q_seq_shards=1,
11311131

11321132
):
11331133
if decoder_segment_ids is not None:
@@ -1149,10 +1149,8 @@ def wrap_flash_attention(
11491149
)
11501150
splash_kernel = splash_attention_kernel.make_splash_mha(
11511151
mask=multi_head_mask,
1152-
head_shards=head_shards,
1153-
q_seq_shards=q_seq_shards,
1154-
head_shards=head_shards,
1155-
q_seq_shards=q_seq_shards,
1152+
head_shards=1,
1153+
q_seq_shards=1,
11561154
attn_logits_soft_cap=attn_logits_soft_cap,
11571155
)
11581156

@@ -1185,92 +1183,34 @@ def dot_product_attention(
11851183
f"Received: query.shape={query.shape}, key.shape={key.shape}, "
11861184
f"value.shape={value.shape}."
11871185
)
1188-
1189-
# Check platform
1190-
platform = jax.devices()[0].platform
1191-
is_tpu = platform == "tpu"
1192-
1193-
# Check if inputs use partial sharding (not fully replicated)
1194-
# Flash attention works well with fully replicated tensors on all platforms
1195-
# but may have issues with certain partial sharding patterns on non-TPU platforms
1196-
partially_sharded_inputs = any(
1197-
hasattr(t, "sharding") and not t.sharding.is_fully_replicated
1198-
for t in (query, key, value)
1199-
)
1200-
1201-
# Determine flash attention compatibility
12021186
if flash_attention is None:
1203-
# Auto-detect flash attention availability
1204-
if is_tpu:
1205-
# TPUs have specialized hardware for attention that works with any sharding pattern
1206-
flash_attention = True
1207-
else:
1208-
# For GPU/CPU with partially sharded inputs, we need multiple devices
1209-
# to efficiently handle the sharding
1210-
if partially_sharded_inputs and len(jax.devices()) <= 1:
1211-
flash_attention = False
1212-
else:
1213-
flash_attention = _can_use_flash_attention(query, key, value, bias)
1214-
elif flash_attention is True and not is_tpu:
1215-
# If flash attention is explicitly requested, validate compatibility
1216-
# Skip validation for TPU as it has specialized hardware support
1217-
try:
1218-
_can_use_flash_attention(query, key, value, bias, raise_error=True)
1219-
except Exception:
1220-
# Only disable flash attention on non-TPU platforms if validation fails
1221-
flash_attention = False
1222-
1223-
# TPU-specific flash attention path
1224-
if is_tpu and flash_attention:
1225-
# Transpose to ('batch', 'heads', 'length', 'head_dim')
1226-
query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3))
1227-
key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3))
1228-
value_tpu_layout = jnp.transpose(value, axes=(0, 2, 1, 3))
1229-
1230-
bs, num_heads, q_len, head_dim = query_tpu_layout.shape
1231-
1232-
# Apply scale to query if provided
1233-
if scale is not None:
1234-
# TPU kernel applies 1/sqrt(head_dim) internally, to achieve
1235-
# overall QK^T * scale, scale query by (scale * sqrt(head_dim))
1236-
query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim))
1237-
1238-
# Create segment IDs for Splash Attention (for packing/batching)
1239-
segment_ids = jnp.zeros([bs, q_len], dtype=jnp.int32)
1240-
decoder_segment_ids = splash_attention_kernel.SegmentIds(
1241-
q=segment_ids, kv=segment_ids
1242-
)
1243-
1244-
# Process mask for Splash Attention
1245-
custom_mask = None
1246-
if mask is not None:
1247-
mask_bool = mask.astype("bool") if mask.dtype != jnp.bool_ else mask
1248-
1249-
if mask_bool.ndim == 3 and mask_bool.shape[0] == bs:
1250-
custom_mask = mask_bool[0]
1251-
elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs:
1252-
custom_mask = mask_bool[0, 0]
1253-
1254-
if is_causal and custom_mask is not None:
1255-
causal_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
1256-
custom_mask = jnp.logical_and(custom_mask, causal_mask)
1257-
1258-
if custom_mask is None and is_causal:
1259-
custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
1187+
flash_attention = _can_use_flash_attention(query, key, value, bias)
1188+
elif flash_attention is True:
1189+
# Use `raise_error=True` to provide more details if the inputs failed to
1190+
# use flash attention
1191+
_can_use_flash_attention(query, key, value, bias, raise_error=True)
12601192

1261-
try:
1262-
output = wrap_flash_attention(
1263-
query_tpu_layout,
1264-
key_tpu_layout,
1265-
value_tpu_layout,
1266-
decoder_segment_ids=decoder_segment_ids,
1267-
custom_mask=custom_mask,
1268-
attn_logits_soft_cap=attn_logits_soft_cap,
1269-
)
1270-
# Transpose output back to Keras layout
1271-
return jnp.transpose(output, axes=(0, 2, 1, 3))
1272-
except Exception:
1273-
flash_attention = False
1193+
if jax.devices()[0].platform == "tpu":
1194+
# Transpose to ('batch', 'heads', 'length', 'kv')
1195+
query = jnp.transpose(query, axes=(0, 2, 1, 3))
1196+
key = jnp.transpose(key, axes=(0, 2, 1, 3))
1197+
value = jnp.transpose(value, axes=(0, 2, 1, 3))
1198+
B, H, S, KV = query.shape
1199+
1200+
segment_ids = jnp.ones([B, S])
1201+
# {token_ids, padding_mask, segment_ids} enable packing
1202+
out = wrap_flash_attention(
1203+
query,
1204+
key,
1205+
value,
1206+
decoder_segment_ids=splash_attention_kernel.SegmentIds(
1207+
segment_ids, segment_ids
1208+
),
1209+
custom_mask=mask,
1210+
attn_logits_soft_cap=attn_logits_soft_cap,
1211+
)
1212+
out = jnp.transpose(out, axes=(0, 2, 1, 3))
1213+
return out
12741214

12751215
# JAX native dot_product_attention for GPU or fallback for TPU
12761216
if hasattr(jax.nn, "dot_product_attention"):
@@ -1306,9 +1246,9 @@ def dot_product_attention(
13061246
"Please update it by following the official guide: "
13071247
"https://jax.readthedocs.io/en/latest/installation.html"
13081248
)
1309-
1310-
# Fallback to custom XLA implementation
1311-
# This is the reference implementation from jax.nn.dot_product_attention
1249+
# Ref: jax.nn.dot_product_attention
1250+
# https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
1251+
# Not support `query_seq_lengths` and `key_value_seq_lengths` args
13121252
output_shape = query.shape
13131253
_, _, K, H = key.shape
13141254
scale = (1.0 / jnp.sqrt(H)) if scale is None else scale

keras/src/backend/jax/numpy.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ def add(x1, x2):
3737
return jnp.add(x1, x2)
3838

3939

40+
def bartlett(x):
41+
x = convert_to_tensor(x)
42+
return jnp.bartlett(x)
43+
44+
4045
def bincount(x, weights=None, minlength=0, sparse=False):
4146
# Note: bincount is never tracable / jittable because the output shape
4247
# depends on the values in x.
@@ -469,6 +474,11 @@ def right_shift(x, y):
469474
return bitwise_right_shift(x, y)
470475

471476

477+
def blackman(x):
478+
x = convert_to_tensor(x)
479+
return jnp.blackman(x)
480+
481+
472482
def broadcast_to(x, shape):
473483
x = convert_to_tensor(x)
474484
return jnp.broadcast_to(x, shape)

keras/src/backend/jax/rnn.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,16 @@ def _step(states, current_input):
164164
else:
165165
# Assume the first state is the previous output.
166166
output_tm1 = states[0]
167+
if tree.is_nested(output_tm1):
168+
# Stacked RNN case: assume first state of last cell.
169+
output_tm1 = states[-1][0]
167170
masked_outs = jnp.where(is_masked, output_tm1, output_t)
168171

169-
new_states = [
170-
jnp.where(is_masked, s, ns)
171-
for s, ns in zip(states, new_states)
172-
]
172+
new_states = tree.map_structure(
173+
lambda s, ns: jnp.where(is_masked, s, ns),
174+
states,
175+
new_states,
176+
)
173177
return (new_states, masked_outs)
174178

175179
scan_xs = (inputs, mask)

keras/src/backend/numpy/numpy.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,11 @@ def average(x, axis=None, weights=None):
305305
return np.average(x, weights=weights, axis=axis)
306306

307307

308+
def bartlett(x):
309+
x = convert_to_tensor(x)
310+
return np.bartlett(x).astype(config.floatx())
311+
312+
308313
def bincount(x, weights=None, minlength=0, sparse=False):
309314
if sparse:
310315
raise ValueError("Unsupported value `sparse=True` with numpy backend")
@@ -385,6 +390,11 @@ def right_shift(x, y):
385390
return bitwise_right_shift(x, y)
386391

387392

393+
def blackman(x):
394+
x = convert_to_tensor(x)
395+
return np.blackman(x).astype(config.floatx())
396+
397+
388398
def broadcast_to(x, shape):
389399
return np.broadcast_to(x, shape)
390400

keras/src/backend/numpy/rnn.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,16 @@ def _step(states, current_input):
160160
else:
161161
# Assume the first state is the previous output.
162162
output_tm1 = states[0]
163+
if tree.is_nested(output_tm1):
164+
# Stacked RNN case: assume first state of last cell.
165+
output_tm1 = states[-1][0]
163166
masked_outs = np.where(is_masked, output_tm1, output_t)
164167

165-
new_states = [
166-
np.where(is_masked, s, ns)
167-
for s, ns in zip(states, new_states)
168-
]
168+
new_states = tree.map_structure(
169+
lambda s, ns: np.where(is_masked, s, ns),
170+
states,
171+
new_states,
172+
)
169173
return (new_states, masked_outs)
170174

171175
scan_xs = (inputs, mask)

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ NumpyDtypeTest::test_angle
88
NumpyDtypeTest::test_any
99
NumpyDtypeTest::test_argpartition
1010
NumpyDtypeTest::test_array
11+
NumpyDtypeTest::test_bartlett
12+
NumpyDtypeTest::test_blackman
1113
NumpyDtypeTest::test_bitwise
1214
NumpyDtypeTest::test_ceil
1315
NumpyDtypeTest::test_concatenate
@@ -42,7 +44,6 @@ NumpyDtypeTest::test_outer_
4244
NumpyDtypeTest::test_power
4345
NumpyDtypeTest::test_prod
4446
NumpyDtypeTest::test_quantile
45-
NumpyDtypeTest::test_ravel
4647
NumpyDtypeTest::test_repeat
4748
NumpyDtypeTest::test_roll
4849
NumpyDtypeTest::test_round
@@ -75,6 +76,8 @@ NumpyOneInputOpsCorrectnessTest::test_angle
7576
NumpyOneInputOpsCorrectnessTest::test_any
7677
NumpyOneInputOpsCorrectnessTest::test_argpartition
7778
NumpyOneInputOpsCorrectnessTest::test_array
79+
NumpyOneInputOpsCorrectnessTest::test_bartlett
80+
NumpyOneInputOpsCorrectnessTest::test_blackman
7881
NumpyOneInputOpsCorrectnessTest::test_bitwise_invert
7982
NumpyOneInputOpsCorrectnessTest::test_conj
8083
NumpyOneInputOpsCorrectnessTest::test_correlate
@@ -102,7 +105,6 @@ NumpyOneInputOpsCorrectnessTest::test_pad_int8_constant_2
102105
NumpyOneInputOpsCorrectnessTest::test_pad_uint8_constant_2
103106
NumpyOneInputOpsCorrectnessTest::test_pad_int32_constant_2
104107
NumpyOneInputOpsCorrectnessTest::test_prod
105-
NumpyOneInputOpsCorrectnessTest::test_ravel
106108
NumpyOneInputOpsCorrectnessTest::test_real
107109
NumpyOneInputOpsCorrectnessTest::test_reciprocal
108110
NumpyOneInputOpsCorrectnessTest::test_repeat
@@ -151,4 +153,6 @@ NumpyTwoInputOpsCorrectnessTest::test_tensordot
151153
NumpyTwoInputOpsCorrectnessTest::test_vdot
152154
NumpyTwoInputOpsCorrectnessTest::test_where
153155
NumpyOneInputOpsDynamicShapeTest::test_angle
156+
NumpyOneInputOpsDynamicShapeTest::test_bartlett
157+
NumpyOneInputOpsDynamicShapeTest::test_blackman
154158
NumpyOneInputOpsStaticShapeTest::test_angle

0 commit comments

Comments
 (0)