Merged
28 changes: 22 additions & 6 deletions src/transformers/configuration_utils.py
@@ -1239,14 +1239,30 @@ def get_text_config(self, decoder=None, encoder=None) -> "PretrainedConfig":
if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder:
config_to_return = copy.deepcopy(config_to_return)
prefix_to_discard = "encoder" if decoder else "decoder"
prefix_to_keep = "decoder" if decoder else "encoder"
for key in config_to_return.to_dict():
if key.startswith(prefix_to_discard):
# NOTE: We don't want to discard the key if it is mapped from a different attribute name at read time
if key.startswith(prefix_to_discard) and key not in config_to_return.attribute_map.values():
delattr(config_to_return, key)
# old encoder/decoder models may use "encoder_layers"/"decoder_layers" instead of "num_hidden_layers"
if decoder and hasattr(config_to_return, "decoder_layers"):
config_to_return.num_hidden_layers = config_to_return.decoder_layers
elif encoder and hasattr(config_to_return, "encoder_layers"):
config_to_return.num_hidden_layers = config_to_return.encoder_layers
if key.startswith(prefix_to_keep):
# [encoder/decoder]_layers -> num_hidden_layers
if key == prefix_to_keep + "_layers":
new_key = "num_hidden_layers"
# [encoder/decoder]_attention_heads -> num_attention_heads
elif key == prefix_to_keep + "_attention_heads":
new_key = "num_attention_heads"
# e.g. encoder_hidden_act -> hidden_act
else:
new_key = key[len(prefix_to_keep) + 1 :]

# Does the class map the new key into a different attribute name at read time? if so, let's write
# into that attribute instead
if new_key in config_to_return.attribute_map:
new_key = config_to_return.attribute_map[new_key]

Comment on lines +1258 to +1262
Member

Not sure I got this. So if we map the new key back through the attribute map, in models like BART we will do num_attention_heads -> encoder_attention_heads. This doesn't look quite right if we asked for a decoder config.

Member Author (@gante, Sep 5, 2025)

It does not look right indeed 😢 But attribute_map is a class-level attribute, so we can't update it for the new configuration either (i.e. for the config instance returned by get_text_config).

Note that these encoder/decoder attributes in attribute_map come from old models, and that these inconsistencies only show up if someone decides to print internal variables 👀

This means we are limited to two options, to maintain BC:

  1. [This PR] We use the same mapping all over the code (e.g. config.get_text_config(decoder=True).num_attention_heads to get the number of attention heads in the decoder), but accept that some old configs will have an odd representation because of their attribute map;
  2. [main] Have several if/else scattered across our codebase, like
num_decoder_layers = (
    getattr(config, "decoder_layers", None)  # flat configs case 1
    or getattr(config, "num_decoder_layers", None)  # flat configs case 2
    or decoder_config.num_hidden_layers  # modern default for decoders
)

(If we double down on direction 2, we need to add more if/else cases; our current logic is not robust across all tests.)

Option 1 seems much more reliable in the long run, and also nudges everyone into using the same names everywhere (as opposed to relying on attribute maps) 🤗

Member Author (@gante, Sep 5, 2025)

Alternatively, we may be able to update the attribute_map logic to read/write into the target variable, as opposed to mapping the read/writes 👀

Example:
If we have the {"a": "b"} mapping, atm all reads to config.a actually read config.b without checking if a exists in config. Same for writes.

We could instead make reads of config.a read config.a first and, if it doesn't exist, fall back to config.b. All writes would write into config.a.

WDYT?
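
To make the comparison concrete, a minimal toy sketch of the two read/write behaviors being discussed (hypothetical classes, not the actual PretrainedConfig implementation):

# Toy illustration of the two attribute_map behaviors under discussion.
class CurrentBehavior:
    attribute_map = {"a": "b"}

    def __init__(self, b):
        self.b = b

    def __getattr__(self, name):
        # only called when normal lookup fails: reads of `a` are redirected to `b`
        if name in self.attribute_map:
            return getattr(self, self.attribute_map[name])
        raise AttributeError(name)

    def __setattr__(self, name, value):
        # writes to `a` are unconditionally redirected to `b`
        super().__setattr__(self.attribute_map.get(name, name), value)


class ProposedBehavior:
    attribute_map = {"a": "b"}

    def __init__(self, b):
        self.b = b

    def __getattr__(self, name):
        # normal lookup tries `a` itself first; fall back to `b` only if `a` is missing
        if name in self.attribute_map:
            return getattr(self, self.attribute_map[name])
        raise AttributeError(name)

    def __setattr__(self, name, value):
        # writes always land on the requested attribute, no redirection
        super().__setattr__(name, value)


current = CurrentBehavior(b=2)
current.a = 5
print(current.a, current.b)    # 5 5 -> `a` and `b` share the same storage (`b`)

proposed = ProposedBehavior(b=2)
print(proposed.a)              # 2   -> falls back to `b` while `a` is missing
proposed.a = 5
print(proposed.a, proposed.b)  # 5 2 -> two keys now coexist (the concern raised in the reply below)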

Member

We could instead make reads of config.a read config.a first and, if it doesn't exist, fall back to config.b. All writes would write into config.a.

This sounds interesting, and slightly breaking, because we will end up with two keys for the same concept. It might raise questions such as which value is correct when inspecting visually or serializing configs. For example, we might have both image_token_id and image_token_index in some VLMs.

Coming back to "Option 1", I see we always check the attribute mapping now. I was expecting that get_text_config() would return a different config only if the config structure is nested, tbh. Otherwise the whole config is a text config and has no other modalities.

In this case I think the current approach is the best we can do, because it helps reduce LOC and is not very breaking. We can ignore the weird naming, as no one would serialize/print the text config, I hope. Let's either keep it as is, or consider the other option below. Feel free to ignore it if it doesn't work.

I looked through the attribute maps in the repo, and they always map to the encoder when encoder-decoder is used. We could gradually deprecate this pattern from the mapping and nudge users to fetch attributes explicitly with config.encoder_attention_heads. We would need to use consistent naming in encoder-decoder models and promote it for future models. Though this option will take a long time to deprecate, maybe even until v5 🙃

Member Author

@zucchini-nlp

What I'm reading is "let's go with this PR, and try to nudge users away from attribute_map". Is this correct? :)

Member

Yep, the second one is more of a longer-term option to make our lives better.

Member Author

Cool!

(Approval please then 💛 )

Member

Indeed, in general we never want attribute_map for new models anyway! We should always be explicit, and use it as a "last resort" to keep BC when we change attribute names or similar. But it's hard to read/follow, and makes our lives harder everywhere, so we definitely want to stay away from it as much as possible!

value = getattr(config_to_return, key)
delattr(config_to_return, key)
setattr(config_to_return, new_key, value)

return config_to_return

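For context, a rough sketch of how the renaming above is expected to play out for a legacy flat encoder-decoder config (BART-style attribute names; the values are illustrative and the exact behavior is inferred from this diff):

from transformers import BartConfig

# BART uses legacy flat names: encoder_layers/decoder_layers,
# encoder_attention_heads/decoder_attention_heads, d_model, ...
config = BartConfig(
    encoder_layers=6,
    decoder_layers=3,
    encoder_attention_heads=8,
    decoder_attention_heads=4,
)

decoder_config = config.get_text_config(decoder=True)

# decoder_* attributes are renamed to the standardized names, so callers can
# rely on a single naming scheme instead of per-model fallback chains.
print(decoder_config.num_hidden_layers)    # 3 (was decoder_layers)
print(decoder_config.num_attention_heads)  # 4 (was decoder_attention_heads; stored under
                                           #    encoder_attention_heads because of BART's
                                           #    class-level attribute_map)

This is the "odd representation" trade-off discussed in the thread above: the standardized read works, even though the value may end up stored under a legacy attribute name because of the class-level attribute_map.
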
45 changes: 15 additions & 30 deletions tests/generation/test_utils.py
@@ -1034,7 +1034,8 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

# 1. If it doesn't support cache, skip the test
if not hasattr(config.get_text_config(), "use_cache"):
decoder_config = config.get_text_config(decoder=True)
if not hasattr(decoder_config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

model = model_class(config).to(torch_device)
@@ -1050,39 +1051,19 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None):
past_kv = outputs["past_key_values"]
is_legacy_cache = not isinstance(past_kv, Cache)

text_config = config.get_text_config()
num_decoder_layers = (
getattr(text_config, "decoder_layers", None)
or getattr(text_config, "num_decoder_layers", None)
or text_config.num_hidden_layers
)

num_decoder_layers = decoder_config.num_hidden_layers
Comment on lines -1053 to +1054
Member

love this pattern

Member Author

See my comment above 😅

if custom_all_cache_shapes is None:
num_query_attention_heads = getattr(
text_config, "decoder_attention_heads", text_config.num_attention_heads
)
embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
per_head_embed_dim = embed_dim // num_query_attention_heads
num_key_value_heads = (
text_config.num_key_value_heads
if getattr(text_config, "num_key_value_heads", None) is not None
else num_query_attention_heads
num_query_attention_heads = decoder_config.num_attention_heads
embed_dim = getattr(decoder_config, "d_model", decoder_config.hidden_size)
per_head_embed_dim = (
getattr(decoder_config, "head_dim", None) or embed_dim // num_query_attention_heads
)
num_key_value_heads = getattr(decoder_config, "num_key_value_heads", None) or num_query_attention_heads
Comment on lines +1057 to +1061
Member

Let's directly use the fallback value, it's much easier!

Suggested change
embed_dim = getattr(decoder_config, "d_model", decoder_config.hidden_size)
per_head_embed_dim = (
getattr(decoder_config, "head_dim", None) or embed_dim // num_query_attention_heads
)
num_key_value_heads = getattr(decoder_config, "num_key_value_heads", None) or num_query_attention_heads
embed_dim = getattr(decoder_config, "d_model", decoder_config.hidden_size)
per_head_embed_dim = getattr(decoder_config, "head_dim", embed_dim // num_query_attention_heads)
num_key_value_heads = getattr(decoder_config, "num_key_value_heads", num_query_attention_heads)

Member Author (@gante, Sep 9, 2025)

The two expressions are not equivalent :o

let's consider the following:

getattr(decoder_config, "head_dim", None) or embed_dim // num_query_attention_heads

vs

getattr(decoder_config, "head_dim", embed_dim // num_query_attention_heads)

When we have decoder_config.head_dim == None, the first one evaluates to embed_dim // num_query_attention_heads, while the second one evaluates to None


I implemented it as you suggested at first, but it caused failures because it doesn't fetch the fallback when the original attribute exists and is set to None 💔
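
A quick self-contained illustration of that difference (toy object, not transformers code):

from types import SimpleNamespace

decoder_config = SimpleNamespace(head_dim=None)  # attribute exists but is None
embed_dim, num_query_attention_heads = 64, 8

# `or`-based fallback: also kicks in when the attribute is present but None
print(getattr(decoder_config, "head_dim", None) or embed_dim // num_query_attention_heads)  # 8

# default-argument fallback: only kicks in when the attribute is missing entirely
print(getattr(decoder_config, "head_dim", embed_dim // num_query_attention_heads))          # None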

Member

Ha damn, I don't know why it is set to None sometimes in the config 🥲 Thanks for checking!

if config.is_encoder_decoder:
encoder_num_attention_heads = (
text_config.encoder_attention_heads
if hasattr(text_config, "encoder_attention_heads")
else text_config.num_attention_heads
)
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
batch_size, seq_length = inputs["decoder_input_ids"].shape[:2]
# The sequence length for the encoder K V depends on the model. Since it is not manipulated in
# autoregressive generation, we're keeping the test general and not checking the 3rd dim
default_cross_attention_shape = (
batch_size,
encoder_num_attention_heads,
encoder_per_head_embed_dim,
)
default_cross_attention_shape = (batch_size, num_key_value_heads, per_head_embed_dim)
default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
all_cache_shapes = [
[
@@ -1138,9 +1119,13 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None):
# 3.2. Decoder-only checks
else:
num_cache_decoder_layers = len(past_kv)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
self.assertEqual(
# we may have skipped layers
num_cache_decoder_layers + getattr(decoder_config, "num_kv_shared_layers", 0),
num_decoder_layers,
)

for i in range(num_decoder_layers):
for i in range(num_cache_decoder_layers):
if is_legacy_cache:
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple

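As a standalone sketch of the simplified pattern used in the updated test (a tiny hypothetical decoder-only model; the config values are made up and the Cache .layers API follows the test code above, so it may differ across transformers versions):

import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny decoder-only model built locally; sizes are illustrative.
config = LlamaConfig(
    vocab_size=100, hidden_size=32, intermediate_size=64,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
)
model = LlamaForCausalLM(config).eval()

# Standardized names via the decoder sub-config, as in the test above.
decoder_config = config.get_text_config(decoder=True)
num_query_attention_heads = decoder_config.num_attention_heads
embed_dim = getattr(decoder_config, "d_model", decoder_config.hidden_size)
per_head_embed_dim = getattr(decoder_config, "head_dim", None) or embed_dim // num_query_attention_heads
num_key_value_heads = getattr(decoder_config, "num_key_value_heads", None) or num_query_attention_heads

batch_size, seq_length = 1, 5
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length))
outputs = model(input_ids, use_cache=True)

past_kv = outputs.past_key_values
expected_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
assert past_kv.layers[0].keys.shape == expected_shape
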
4 changes: 0 additions & 4 deletions tests/models/dia/test_modeling_dia.py
@@ -517,10 +517,6 @@ def test_generate_continue_from_past_key_values(self):
)
)

@unittest.skip(reason="Indirectly checked in Dia through the generate methods.")
def test_past_key_values_format(self, custom_all_cache_shapes=None):
pass

@unittest.skip(reason="Indirectly checked in Dia through the generate methods.")
def test_hidden_states_output(self):
pass
136 changes: 0 additions & 136 deletions tests/models/gemma3n/test_modeling_gemma3n.py
@@ -28,7 +28,6 @@
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
Cache,
Gemma3nAudioConfig,
Gemma3nAudioFeatureExtractor,
Gemma3nConfig,
@@ -572,141 +571,6 @@ def test_generate_with_static_cache(self):
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
self.assertTrue(has_similar_generate_outputs(dynamic_cache_generation, static_cache_generation))

@pytest.mark.generate
def test_past_key_values_format(self, custom_all_cache_shapes=None):
"""
Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test, or pass the
expected cache shapes.
Having a standard KV cache format is important for a consistent API (and for advanced generation methods).
"""
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

# 1. If it doesn't support cache, skip the test
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

model = model_class(config).to(torch_device)
model = model.eval()
if "use_cache" not in inputs:
inputs["use_cache"] = True
outputs = model(**inputs)

if "past_key_values" not in outputs:
self.skipTest(reason="This model doesn't return `past_key_values`")

# 2. retrieve the KV cache and compute its default expected shapes (if no custom shapes are provided)
past_kv = outputs["past_key_values"]
is_legacy_cache = not isinstance(past_kv, Cache)

text_config = config.get_text_config()
num_decoder_layers = (
getattr(text_config, "decoder_layers", None)
or getattr(text_config, "num_decoder_layers", None)
or text_config.num_hidden_layers
)

if custom_all_cache_shapes is None:
num_query_attention_heads = getattr(
text_config, "decoder_attention_heads", text_config.num_attention_heads
)
embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
per_head_embed_dim = embed_dim // num_query_attention_heads
num_key_value_heads = (
text_config.num_key_value_heads
if getattr(text_config, "num_key_value_heads", None) is not None
else num_query_attention_heads
)
if config.is_encoder_decoder:
encoder_num_attention_heads = (
text_config.encoder_attention_heads
if hasattr(text_config, "encoder_attention_heads")
else text_config.num_attention_heads
)
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
batch_size, seq_length = inputs["decoder_input_ids"].shape[:2]
# The sequence length for the encoder K V depends on the model. Since it is not manipulated in
# autoregressive generation, we're keeping the test general and not checking the 3rd dim
default_cross_attention_shape = (
batch_size,
encoder_num_attention_heads,
encoder_per_head_embed_dim,
)
default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
all_cache_shapes = [
[
default_self_attention_shape,
default_self_attention_shape,
default_cross_attention_shape,
default_cross_attention_shape,
]
for _ in range(num_decoder_layers)
]
else:
batch_size, seq_length = inputs["input_ids"].shape[:2]
default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
all_cache_shapes = [
[default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers)
]

else:
all_cache_shapes = custom_all_cache_shapes

# 3. Check cache shapes
# 3.1. Encoder-Decoder checks
if config.is_encoder_decoder:
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)

for i in range(num_decoder_layers):
if is_legacy_cache:
self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple

# Self attention
self_attention_layer_keys = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
)
self_attention_layer_values = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
)
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])

# Cross attention (ignore 3rd dim, see default shape preparation)
cross_attention_layer_keys = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
)
cross_attention_layer_values = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
)
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])

# 3.2. Decoder-only checks
else:
num_cache_decoder_layers = len(past_kv)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers - text_config.num_kv_shared_layers)

for i in range(num_decoder_layers - text_config.num_kv_shared_layers):
if is_legacy_cache:
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple

# Self attention
if is_legacy_cache:
self_attention_layer_keys = past_kv[i][0]
self_attention_layer_values = past_kv[i][1]
elif getattr(past_kv, "layers", None) is None:
# Cache is not layered (i.e., Mamba derivatives)
self_attention_layer_keys = past_kv.key_cache[i]
self_attention_layer_values = past_kv.value_cache[i]
else:
self_attention_layer_keys = past_kv.layers[i].keys
self_attention_layer_values = past_kv.layers[i].values
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])


class Gemma3nVision2TextModelTester:
text_config = {"activation_sparsity_pattern": None}
6 changes: 0 additions & 6 deletions tests/models/got_ocr2/test_modeling_got_ocr2.py
@@ -177,12 +177,6 @@ def test_initialization(self):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)

@unittest.skip(
reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format"
)
def test_past_key_values_format(self):
pass


@require_torch
class GotOcr2IntegrationTest(unittest.TestCase):
5 changes: 0 additions & 5 deletions tests/models/speecht5/test_modeling_speecht5.py
@@ -20,7 +20,6 @@

from transformers import SpeechT5Config, SpeechT5HifiGanConfig
from transformers.testing_utils import (
is_flaky,
is_torch_available,
require_deterministic_for_xpu,
require_sentencepiece,
@@ -728,10 +727,6 @@ def test_training_gradient_checkpointing_use_reentrant(self):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass

@is_flaky(max_attempts=5, description="Flaky for some input configurations.")
Member Author

(double-checked with flake-finder -- it is no longer flaky)

def test_past_key_values_format(self):
super().test_past_key_values_format()

# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None: