diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index d450193dbc2d..37bbf18b4e1d 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -1239,14 +1239,30 @@ def get_text_config(self, decoder=None, encoder=None) -> "PretrainedConfig": if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder: config_to_return = copy.deepcopy(config_to_return) prefix_to_discard = "encoder" if decoder else "decoder" + prefix_to_keep = "decoder" if decoder else "encoder" for key in config_to_return.to_dict(): - if key.startswith(prefix_to_discard): + # NOTE: We don't want to discard the key if it is mapped from a different attribute name at read time + if key.startswith(prefix_to_discard) and key not in config_to_return.attribute_map.values(): delattr(config_to_return, key) - # old encoder/decoder models may use "encoder_layers"/"decoder_layers" instead of "num_hidden_layers" - if decoder and hasattr(config_to_return, "decoder_layers"): - config_to_return.num_hidden_layers = config_to_return.decoder_layers - elif encoder and hasattr(config_to_return, "encoder_layers"): - config_to_return.num_hidden_layers = config_to_return.encoder_layers + if key.startswith(prefix_to_keep): + # [encoder/decoder]_layers -> num_hidden_layers + if key == prefix_to_keep + "_layers": + new_key = "num_hidden_layers" + # [encoder/decoder]_attention_heads -> num_attention_heads + elif key == prefix_to_keep + "_attention_heads": + new_key = "num_attention_heads" + # e.g. encoder_hidden_act -> hidden_act + else: + new_key = key[len(prefix_to_keep) + 1 :] + + # Does the class map the new key into a different attribute name at read time? if so, let's write + # into that attribute instead + if new_key in config_to_return.attribute_map: + new_key = config_to_return.attribute_map[new_key] + + value = getattr(config_to_return, key) + delattr(config_to_return, key) + setattr(config_to_return, new_key, value) return config_to_return diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 85841c557df1..077a54a0faec 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1034,7 +1034,8 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): config, inputs = self.model_tester.prepare_config_and_inputs_for_common() # 1. 
If it doesn't support cache, skip the test - if not hasattr(config.get_text_config(), "use_cache"): + decoder_config = config.get_text_config(decoder=True) + if not hasattr(decoder_config, "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") model = model_class(config).to(torch_device) @@ -1050,39 +1051,19 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): past_kv = outputs["past_key_values"] is_legacy_cache = not isinstance(past_kv, Cache) - text_config = config.get_text_config() - num_decoder_layers = ( - getattr(text_config, "decoder_layers", None) - or getattr(text_config, "num_decoder_layers", None) - or text_config.num_hidden_layers - ) - + num_decoder_layers = decoder_config.num_hidden_layers if custom_all_cache_shapes is None: - num_query_attention_heads = getattr( - text_config, "decoder_attention_heads", text_config.num_attention_heads - ) - embed_dim = getattr(text_config, "d_model", text_config.hidden_size) - per_head_embed_dim = embed_dim // num_query_attention_heads - num_key_value_heads = ( - text_config.num_key_value_heads - if getattr(text_config, "num_key_value_heads", None) is not None - else num_query_attention_heads + num_query_attention_heads = decoder_config.num_attention_heads + embed_dim = getattr(decoder_config, "d_model", decoder_config.hidden_size) + per_head_embed_dim = ( + getattr(decoder_config, "head_dim", None) or embed_dim // num_query_attention_heads ) + num_key_value_heads = getattr(decoder_config, "num_key_value_heads", None) or num_query_attention_heads if config.is_encoder_decoder: - encoder_num_attention_heads = ( - text_config.encoder_attention_heads - if hasattr(text_config, "encoder_attention_heads") - else text_config.num_attention_heads - ) - encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads batch_size, seq_length = inputs["decoder_input_ids"].shape[:2] # The sequence length for the encoder K V depends on the model. Since it is not manipulated in # autoregressive generation, we're keeping the test general and not checking the 3rd dim - default_cross_attention_shape = ( - batch_size, - encoder_num_attention_heads, - encoder_per_head_embed_dim, - ) + default_cross_attention_shape = (batch_size, num_key_value_heads, per_head_embed_dim) default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) all_cache_shapes = [ [ @@ -1138,9 +1119,13 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None): # 3.2. 
Decoder-only checks else: num_cache_decoder_layers = len(past_kv) - self.assertEqual(num_cache_decoder_layers, num_decoder_layers) + self.assertEqual( + # we may have skipped layers + num_cache_decoder_layers + getattr(decoder_config, "num_kv_shared_layers", 0), + num_decoder_layers, + ) - for i in range(num_decoder_layers): + for i in range(num_cache_decoder_layers): if is_legacy_cache: self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple diff --git a/tests/models/dia/test_modeling_dia.py b/tests/models/dia/test_modeling_dia.py index b927b903085b..900bb0cef73d 100644 --- a/tests/models/dia/test_modeling_dia.py +++ b/tests/models/dia/test_modeling_dia.py @@ -517,10 +517,6 @@ def test_generate_continue_from_past_key_values(self): ) ) - @unittest.skip(reason="Indirectly checked in Dia through the generate methods.") - def test_past_key_values_format(self, custom_all_cache_shapes=None): - pass - @unittest.skip(reason="Indirectly checked in Dia through the generate methods.") def test_hidden_states_output(self): pass diff --git a/tests/models/gemma3n/test_modeling_gemma3n.py b/tests/models/gemma3n/test_modeling_gemma3n.py index 4e2581757a1c..674ec644e70b 100644 --- a/tests/models/gemma3n/test_modeling_gemma3n.py +++ b/tests/models/gemma3n/test_modeling_gemma3n.py @@ -28,7 +28,6 @@ AutoModelForCausalLM, AutoProcessor, AutoTokenizer, - Cache, Gemma3nAudioConfig, Gemma3nAudioFeatureExtractor, Gemma3nConfig, @@ -572,141 +571,6 @@ def test_generate_with_static_cache(self): dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict) self.assertTrue(has_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)) - @pytest.mark.generate - def test_past_key_values_format(self, custom_all_cache_shapes=None): - """ - Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test, or pass the - expected cache shapes. - Having a standard KV cache format is important for a consistent API (and for advanced generation methods). - """ - for model_class in self.all_generative_model_classes: - config, inputs = self.model_tester.prepare_config_and_inputs_for_common() - - # 1. If it doesn't support cache, skip the test - if not hasattr(config.get_text_config(), "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") - - model = model_class(config).to(torch_device) - model = model.eval() - if "use_cache" not in inputs: - inputs["use_cache"] = True - outputs = model(**inputs) - - if "past_key_values" not in outputs: - self.skipTest(reason="This model doesn't return `past_key_values`") - - # 2. 
retrieve the KV cache and compute its default expected shapes (if no custom shapes are provided) - past_kv = outputs["past_key_values"] - is_legacy_cache = not isinstance(past_kv, Cache) - - text_config = config.get_text_config() - num_decoder_layers = ( - getattr(text_config, "decoder_layers", None) - or getattr(text_config, "num_decoder_layers", None) - or text_config.num_hidden_layers - ) - - if custom_all_cache_shapes is None: - num_query_attention_heads = getattr( - text_config, "decoder_attention_heads", text_config.num_attention_heads - ) - embed_dim = getattr(text_config, "d_model", text_config.hidden_size) - per_head_embed_dim = embed_dim // num_query_attention_heads - num_key_value_heads = ( - text_config.num_key_value_heads - if getattr(text_config, "num_key_value_heads", None) is not None - else num_query_attention_heads - ) - if config.is_encoder_decoder: - encoder_num_attention_heads = ( - text_config.encoder_attention_heads - if hasattr(text_config, "encoder_attention_heads") - else text_config.num_attention_heads - ) - encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads - batch_size, seq_length = inputs["decoder_input_ids"].shape[:2] - # The sequence length for the encoder K V depends on the model. Since it is not manipulated in - # autoregressive generation, we're keeping the test general and not checking the 3rd dim - default_cross_attention_shape = ( - batch_size, - encoder_num_attention_heads, - encoder_per_head_embed_dim, - ) - default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) - all_cache_shapes = [ - [ - default_self_attention_shape, - default_self_attention_shape, - default_cross_attention_shape, - default_cross_attention_shape, - ] - for _ in range(num_decoder_layers) - ] - else: - batch_size, seq_length = inputs["input_ids"].shape[:2] - default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) - all_cache_shapes = [ - [default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers) - ] - - else: - all_cache_shapes = custom_all_cache_shapes - - # 3. Check cache shapes - # 3.1. 
Encoder-Decoder checks - if config.is_encoder_decoder: - num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache) - self.assertEqual(num_cache_decoder_layers, num_decoder_layers) - - for i in range(num_decoder_layers): - if is_legacy_cache: - self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple - - # Self attention - self_attention_layer_keys = ( - past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys - ) - self_attention_layer_values = ( - past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values - ) - self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) - - # Cross attention (ignore 3rd dim, see default shape preparation) - cross_attention_layer_keys = ( - past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys - ) - cross_attention_layer_values = ( - past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values - ) - cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :] - cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :] - self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2]) - self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3]) - - # 3.2. Decoder-only checks - else: - num_cache_decoder_layers = len(past_kv) - self.assertEqual(num_cache_decoder_layers, num_decoder_layers - text_config.num_kv_shared_layers) - - for i in range(num_decoder_layers - text_config.num_kv_shared_layers): - if is_legacy_cache: - self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple - - # Self attention - if is_legacy_cache: - self_attention_layer_keys = past_kv[i][0] - self_attention_layer_values = past_kv[i][1] - elif getattr(past_kv, "layers", None) is None: - # Cache is lot layered (i.e, Mamba derivatives) - self_attention_layer_keys = past_kv.key_cache[i] - self_attention_layer_values = past_kv.value_cache[i] - else: - self_attention_layer_keys = past_kv.layers[i].keys - self_attention_layer_values = past_kv.layers[i].values - self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) - class Gemma3nVision2TextModelTester: text_config = {"activation_sparsity_pattern": None} diff --git a/tests/models/got_ocr2/test_modeling_got_ocr2.py b/tests/models/got_ocr2/test_modeling_got_ocr2.py index 87f182ac9cdb..59577106b069 100644 --- a/tests/models/got_ocr2/test_modeling_got_ocr2.py +++ b/tests/models/got_ocr2/test_modeling_got_ocr2.py @@ -177,12 +177,6 @@ def test_initialization(self): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - @unittest.skip( - reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format" - ) - def test_past_key_values_format(self): - pass - @require_torch class GotOcr2IntegrationTest(unittest.TestCase): diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index 642759b00dcc..3b7ce7954f33 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -20,7 +20,6 @@ from transformers import SpeechT5Config, SpeechT5HifiGanConfig from transformers.testing_utils import ( - is_flaky, is_torch_available, 
require_deterministic_for_xpu, require_sentencepiece, @@ -728,10 +727,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @is_flaky(max_attempts=5, description="Flaky for some input configurations.") - def test_past_key_values_format(self): - super().test_past_key_values_format() - # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: diff --git a/tests/models/t5gemma/test_modeling_t5gemma.py b/tests/models/t5gemma/test_modeling_t5gemma.py index b603d2b266da..1749cc33a12f 100644 --- a/tests/models/t5gemma/test_modeling_t5gemma.py +++ b/tests/models/t5gemma/test_modeling_t5gemma.py @@ -45,7 +45,6 @@ T5GemmaForTokenClassification, T5GemmaModel, ) - from transformers.cache_utils import Cache class T5GemmaModelTester: @@ -983,126 +982,6 @@ def test_attention_outputs(self): [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], ) - # Based on tests.generation.test_utils.GenerationTesterMixin.test_past_key_values_format - # Adjust encoder attention number for cross-attention caching and update attention head dimension - @pytest.mark.generate - def test_past_key_values_format(self, custom_all_cache_shapes=None): - """ - Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test, or pass the - expected cache shapes. - Having a standard KV cache format is important for a consistent API (and for advanced generation methods). - """ - for model_class in self.all_generative_model_classes: - config, inputs = self.model_tester.prepare_config_and_inputs_for_common() - - # 1. If it doesn't support cache, skip the test - if not hasattr(config.get_text_config(), "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") - - model = model_class(config).to(torch_device) - model = model.eval() - if "use_cache" not in inputs: - inputs["use_cache"] = True - outputs = model(**inputs) - - if "past_key_values" not in outputs: - self.skipTest(reason="This model doesn't return `past_key_values`") - - # 2. retrieve the KV cache and compute its default expected shapes (if no custom shapes are provided) - past_kv = outputs["past_key_values"] - is_legacy_cache = not isinstance(past_kv, Cache) - - text_config = config.get_text_config().decoder - num_decoder_layers = text_config.num_hidden_layers - - if custom_all_cache_shapes is None: - num_query_attention_heads = getattr( - text_config, "decoder_attention_heads", text_config.num_attention_heads - ) - per_head_embed_dim = text_config.head_dim - num_key_value_heads = ( - text_config.num_key_value_heads - if getattr(text_config, "num_key_value_heads", None) is not None - else num_query_attention_heads - ) - if config.is_encoder_decoder: - encoder_num_attention_heads = num_key_value_heads - encoder_per_head_embed_dim = per_head_embed_dim - batch_size, seq_length = inputs["decoder_input_ids"].shape[:2] - # The sequence length for the encoder K V depends on the model. 
Since it is not manipulated in - # autoregressive generation, we're keeping the test general and not checking the 3rd dim - default_cross_attention_shape = ( - batch_size, - encoder_num_attention_heads, - encoder_per_head_embed_dim, - ) - default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) - all_cache_shapes = [ - [ - default_self_attention_shape, - default_self_attention_shape, - default_cross_attention_shape, - default_cross_attention_shape, - ] - for _ in range(num_decoder_layers) - ] - else: - batch_size, seq_length = inputs["input_ids"].shape[:2] - default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) - all_cache_shapes = [ - [default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers) - ] - - else: - all_cache_shapes = custom_all_cache_shapes - - # 3. Check cache shapes - # 3.1. Encoder-Decoder checks - if config.is_encoder_decoder: - num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache) - self.assertEqual(num_cache_decoder_layers, num_decoder_layers) - - for i in range(num_decoder_layers): - if is_legacy_cache: - self.assertEqual(len(past_kv[0]), 5) # legacy check: confirm number of elements in tuple - - # Self attention - self_attention_layer_keys = ( - past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys - ) - self_attention_layer_values = ( - past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values - ) - self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) - - # Cross attention (ignore 3rd dim, see default shape preparation) - cross_attention_layer_keys = ( - past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys - ) - cross_attention_layer_values = ( - past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values - ) - cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :] - cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :] - self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2]) - self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3]) - - # 3.2. Decoder-only checks - else: - num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv) - self.assertEqual(num_cache_decoder_layers, num_decoder_layers) - - for i in range(num_decoder_layers): - if is_legacy_cache: - self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple - - # Self attention - self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys - self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values - self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0]) - self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1]) - @unittest.skip("Mismatch issue doesn't exist in T5Gemma.") def test_load_with_mismatched_shapes(self): pass