[tests] update test_past_key_values_format and delete overwrites (#40701)
@@ -1034,7 +1034,8 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None):
         config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

         # 1. If it doesn't support cache, skip the test
-        if not hasattr(config.get_text_config(), "use_cache"):
+        decoder_config = config.get_text_config(decoder=True)
+        if not hasattr(decoder_config, "use_cache"):
             self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

         model = model_class(config).to(torch_device)
@@ -1050,39 +1051,19 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None):
         past_kv = outputs["past_key_values"]
         is_legacy_cache = not isinstance(past_kv, Cache)

-        text_config = config.get_text_config()
-        num_decoder_layers = (
-            getattr(text_config, "decoder_layers", None)
-            or getattr(text_config, "num_decoder_layers", None)
-            or text_config.num_hidden_layers
-        )
+        num_decoder_layers = decoder_config.num_hidden_layers
Comment on lines -1053 to +1054:
love this pattern

Reply: See my comment above 😅
         if custom_all_cache_shapes is None:
-            num_query_attention_heads = getattr(
-                text_config, "decoder_attention_heads", text_config.num_attention_heads
-            )
-            embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
-            per_head_embed_dim = embed_dim // num_query_attention_heads
-            num_key_value_heads = (
-                text_config.num_key_value_heads
-                if getattr(text_config, "num_key_value_heads", None) is not None
-                else num_query_attention_heads
-            )
+            num_query_attention_heads = decoder_config.num_attention_heads
+            embed_dim = getattr(decoder_config, "d_model", decoder_config.hidden_size)
+            per_head_embed_dim = (
+                getattr(decoder_config, "head_dim", None) or embed_dim // num_query_attention_heads
+            )
+            num_key_value_heads = getattr(decoder_config, "num_key_value_heads", None) or num_query_attention_heads
Comment on lines +1057 to +1061:
Let's directly use the fallback value, it's much easier! (Suggested change: pass the fallback as the `getattr` default.)

Reply: The two expressions are not equivalent :o let's consider the following:
`getattr(decoder_config, "head_dim", None) or embed_dim // num_query_attention_heads`
vs
`getattr(decoder_config, "head_dim", embed_dim // num_query_attention_heads)`
When we have `head_dim` present in the config but set to `None`, only the first form falls back. I implemented it as you suggested at first, but it causes failures because `getattr` doesn't fetch the fallback when the original attribute exists and is set to `None`.

Reply: Ha damn, don't know why it is set to None sometimes in the config 🥲 Thanks for checking!
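To make the distinction concrete, here is a minimal, self-contained sketch of the `getattr` behaviour discussed in this thread; the `DummyConfig` class and its values are made up for illustration:

```python
# Made-up config class, only to illustrate the getattr difference discussed above.
class DummyConfig:
    def __init__(self):
        self.hidden_size = 64
        self.num_attention_heads = 8
        self.head_dim = None  # the attribute exists, but is unset


config = DummyConfig()
fallback = config.hidden_size // config.num_attention_heads  # 8

# getattr only uses its default when the attribute is *missing*,
# so this returns None because head_dim exists on the instance:
print(getattr(config, "head_dim", fallback))  # None

# The `or` form also falls back when the attribute exists but is None:
print(getattr(config, "head_dim", None) or fallback)  # 8
```

(The `or` form would also override a legitimate `head_dim` of 0, but that is not a meaningful value here.)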
             if config.is_encoder_decoder:
-                encoder_num_attention_heads = (
-                    text_config.encoder_attention_heads
-                    if hasattr(text_config, "encoder_attention_heads")
-                    else text_config.num_attention_heads
-                )
-                encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
                 batch_size, seq_length = inputs["decoder_input_ids"].shape[:2]
                 # The sequence length for the encoder K V depends on the model. Since it is not manipulated in
                 # autoregressive generation, we're keeping the test general and not checking the 3rd dim
-                default_cross_attention_shape = (
-                    batch_size,
-                    encoder_num_attention_heads,
-                    encoder_per_head_embed_dim,
-                )
+                default_cross_attention_shape = (batch_size, num_key_value_heads, per_head_embed_dim)
                 default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
                 all_cache_shapes = [
                     [
@@ -1138,9 +1119,13 @@ def test_past_key_values_format(self, custom_all_cache_shapes=None):
             # 3.2. Decoder-only checks
             else:
                 num_cache_decoder_layers = len(past_kv)
-                self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
+                self.assertEqual(
+                    # we may have skipped layers
+                    num_cache_decoder_layers + getattr(decoder_config, "num_kv_shared_layers", 0),
+                    num_decoder_layers,
+                )

-                for i in range(num_decoder_layers):
+                for i in range(num_cache_decoder_layers):
                     if is_legacy_cache:
                         self.assertEqual(len(past_kv[0]), 2)  # legacy check: confirm number of elements in tuple
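As I read the new assertion, configs that define `num_kv_shared_layers` (models where some decoder layers reuse another layer's key/value states) cache fewer layers than `num_hidden_layers`, so the layer count is checked up to that offset. A tiny illustration with made-up numbers:

```python
# Made-up numbers, only to illustrate the updated layer-count check above.
num_decoder_layers = 30        # decoder_config.num_hidden_layers
num_kv_shared_layers = 10      # layers reusing another layer's KV cache (0 when the attribute is absent)
num_cache_decoder_layers = 20  # len(past_kv): only layers with their own KV entry

assert num_cache_decoder_layers + num_kv_shared_layers == num_decoder_layers
```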
Second file (the SpeechT5 model tests):

@@ -20,7 +20,6 @@
 from transformers import SpeechT5Config, SpeechT5HifiGanConfig
 from transformers.testing_utils import (
-    is_flaky,
     is_torch_available,
     require_deterministic_for_xpu,
     require_sentencepiece,
@@ -728,10 +727,6 @@ def test_training_gradient_checkpointing_use_reentrant(self):
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass

-    @is_flaky(max_attempts=5, description="Flaky for some input configurations.")
-    def test_past_key_values_format(self):
-        super().test_past_key_values_format()
-
     # overwrite from test_modeling_common
     def _mock_init_weights(self, module):
         if hasattr(module, "weight") and module.weight is not None:

Comment on the removed @is_flaky overwrite: (double-checked with …)
Review discussion:

not sure I got this. So if we map the new key back to the attribute map, in models like BART we will do `num_attention_head` -> `encoder_attention_heads`. This doesn't look quite right if we asked for a decoder config
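For readers unfamiliar with `attribute_map`, here is a minimal sketch of the redirection being discussed; it is not the actual transformers implementation, and the `BartLikeConfig` class and its values are made up:

```python
# Simplified, made-up stand-in for a BART-style config with an attribute_map.
class BartLikeConfig:
    # class-level map: lookups of the key are redirected to the value
    attribute_map = {"num_attention_heads": "encoder_attention_heads"}

    def __init__(self, encoder_attention_heads=16, decoder_attention_heads=8):
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_attention_heads = decoder_attention_heads

    def __getattr__(self, name):
        # only called when normal attribute lookup fails
        mapped = type(self).attribute_map.get(name)
        if mapped is not None:
            return getattr(self, mapped)
        raise AttributeError(name)


config = BartLikeConfig()
# Even when we conceptually want the decoder, the mapped standard name
# resolves to the *encoder* value, which is the concern raised above:
print(config.num_attention_heads)  # 16 (encoder_attention_heads, not decoder_attention_heads)
```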
Reply: It does not look right indeed 😢 But `attribute_map` is a class-level attribute, so we can't update it for the new configuration either (i.e. for the config instance returned by `get_text_config`). Note that these `encoder`/`decoder` attributes in `attribute_map` are from old models, and that these inconsistencies only show up if they decide to print internal variables 👀

This means we are limited to two options, to maintain BC:
1. [this PR] Rely on the standardized attribute names (e.g. `config.get_text_config(decoder=True).num_attention_head` to get the number of attention heads in the decoder), but accept that some old configs will have odd representation because of their attribute map;
2. [`main`] Have several if/else scattered across our codebase, like the getattr chains this PR removes (if we double-down in direction 2, we need to add more if/else cases, our current logic is not robust in all tests).

Option 1 seems much more reliable in the long run, and also nudges everyone into using the same names everywhere (as opposed to relying on attribute maps) 🤗
Reply: Alternatively, we may be able to update the `attribute_map` logic to read/write into the target variable, as opposed to mapping the read/writes 👀

Example: If we have the `{"a": "b"}` mapping, atm all reads to `config.a` actually read `config.b` without checking if `a` exists in `config`. Same for writes.

We could instead make `config.a` reads read `config.a` first and, if it doesn't exist, try to read `config.b`. All writes would write into `config.a`.

WDYT?
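A rough sketch of how I read that proposal, contrasting the current redirect-everything behaviour with the proposed fallback behaviour; the classes are made up and intentionally much simpler than the real `PretrainedConfig` machinery:

```python
# Made-up classes illustrating the {"a": "b"} example from the comment above.

class CurrentBehavior:
    """Reads and writes of `a` are redirected to `b` (roughly today's attribute_map)."""

    attribute_map = {"a": "b"}

    def __init__(self):
        self.b = 1

    def __setattr__(self, name, value):
        name = type(self).attribute_map.get(name, name)  # writes to `a` land in `b`
        super().__setattr__(name, value)

    def __getattr__(self, name):
        mapped = type(self).attribute_map.get(name)  # reads of `a` fall through to `b`
        if mapped is not None:
            return getattr(self, mapped)
        raise AttributeError(name)


class ProposedBehavior:
    """Reads try `a` first and fall back to `b`; writes go into `a` itself."""

    attribute_map = {"a": "b"}

    def __init__(self):
        self.b = 1

    def __getattr__(self, name):
        # __getattr__ only runs when `a` is not set on the instance,
        # so an existing `a` naturally wins and `b` is just the fallback.
        mapped = type(self).attribute_map.get(name)
        if mapped is not None:
            return getattr(self, mapped)
        raise AttributeError(name)


current, proposed = CurrentBehavior(), ProposedBehavior()
current.a = 2
proposed.a = 2
print(current.a, current.b)    # 2 2 -> the write was redirected into `b`
print(proposed.a, proposed.b)  # 2 1 -> `a` and `b` now coexist
```

This divergence between `a` and `b` is also where the duplicate-key concern in the next comment comes from: after a write, both names can end up serialized.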
Reply: This sounds interesting, and slightly breaking because we will end up with two keys for the same concept. It might raise questions such as which value is correct when inspecting visually or serializing configs. For ex: we might have both `image_token_id`/`image_token_index` in some VLMs.

Coming back to "Option 1", I see we always check for attribute mapping now. I was expecting that `get_text_config()` will return a different config only if the config structure is nested tbh. Otherwise the whole config is a text config and has no other modalities.

In this case I think the current approach is the best we can do, because it helps to reduce LOC and is not much breaking. We can ignore the weird naming as no one would serialize/print the text config, I hope. Let's either keep it as is, and I also have another option below. Feel free to ignore if it doesn't work.
Reply: @zucchini-nlp What I'm reading is "let's go with this PR, and try to nudge users away from `attribute_map`". Is this correct? :)
Reply: yeap, the second one is more of a longer-term thing to make our lives better
Reply: Cool! (Approval please then 💛 )
Reply: Indeed, in general we never want `attribute_map` for new models anyway! We should always be explicit, and use it as a "last resort" to keep BC when we change attribute names or similar. But it's hard to read/follow, and makes our lives harder everywhere, so we definitely want to stay away from it as much as possible!