
Conversation

Szustarol

What does this PR do?

Fixes #40635

The problem with loading legacy LongT5 checkpoints stems from the way the shared embedding tokens are defined in the model:

class LongT5Stack(LongT5PreTrainedModel):
    def __init__(self, config, embed_tokens=None):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

In most legacy checkpoints these embeddings are not tied via the tied-weight keys, but they were still trained from the same shared weight that is created here:
self.shared = nn.Embedding(config.vocab_size, config.d_model)
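
For context, LongT5ForConditionalGeneration creates this shared module and hands it to both stacks, so at training time all of these embeddings are backed by the same tensor. A condensed sketch of the constructor as I read it (the config-copying details are omitted here):

class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)
        # the single source embedding
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        # both stacks reuse its weight via the embed_tokens assignment shown above;
        # encoder_config / decoder_config are copies of config with the
        # encoder/decoder flags adjusted (omitted in this sketch)
        self.encoder = LongT5Stack(encoder_config, self.shared)
        self.decoder = LongT5Stack(decoder_config, self.shared)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)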

The solution to this problem is to manually assign the weights in an overridden from_pretrained if the checkpoint contains the shared weights.
This fix is somewhat dirty; it would be much better to update them directly in the state_dict to avoid the missing-keys warning, but there is currently no public interface of the PreTrainedModel class that allows for that.
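
A minimal sketch of what such an override could look like (illustrative only, not the exact code in this PR; it assumes it lives in modeling_longt5.py where the LongT5 classes are defined):

import torch


class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin):
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        # Legacy checkpoints only store "shared.weight"; the encoder/decoder
        # embeddings then come back as missing keys and stay randomly initialized,
        # so copy the shared weight into them here. A real implementation would
        # first check that the checkpoint actually provided "shared.weight".
        with torch.no_grad():
            model.encoder.embed_tokens.weight.copy_(model.shared.weight)
            model.decoder.embed_tokens.weight.copy_(model.shared.weight)
        return model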

If the reviewer has a better idea of how this could be accounted for I'm ready to devote my time towards fixing this issue.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@Rocketknight1

@Szustarol
Author

Just a word of commentary that we might discuss as well: I believe lm_head (if it exists) should also be loaded from the shared weights, judging by how the model would usually be trained, but this is not evident from the model's architecture.

@Rocketknight1
Member

Hi @Szustarol it seems like you've identified the right issue, but is there a reason we can't just set tied_weights_keys in the modeling code?

@Szustarol
Author

Szustarol commented Sep 8, 2025

Thanks for the follow-up @Rocketknight1. The _tied_weights_keys are actually set in the modeling code, but the weights are not saved under the name of any of the actually tied weights ("encoder.embed_tokens.weight", "decoder.embed_tokens.weight", or "lm_head.weight"). Instead, they are saved as "shared.weight", since this is the source parameter name.
We could add "shared.weight" to the _tied_weights_keys, which would nicely resolve the issue by tying all the layers that need to be tied, but then we would still have to go against the model's config.
The weights are actually only tied if this line runs:

def _tie_weights(self):
    if self.config.tie_word_embeddings:
        self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
        self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

But most T5 checkpoints don't actually set this (i.e. they ship with config.tie_word_embeddings=False). I have also verified this for popular LongT5 models such as google/long-t5-local-base (see https://huggingface.co/google/long-t5-local-base/blob/main/config.json#L27). As a side note, this custom _tie_weights is actually not needed, since modeling_utils already ties the decoder and encoder embeddings when tie_word_embeddings=True, here:
if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
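
For example, the config flag can be checked directly from the Hub (a quick diagnostic, not part of the PR):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("google/long-t5-local-base")
# Prints False for this checkpoint, so the body of the custom _tie_weights never runs.
print(config.tie_word_embeddings)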

After reviewing the modeling code it seems to me that the weights are always shared through the "shared" layer, so I don't understand why it was implemented this way instead of just tying them.

Furthermore, the _tie_weights code seems to also be wrong when the lm_head should be tied too:

class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin):
    _keys_to_ignore_on_load_unexpected = [
        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    [...]

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

lm_head remains untied in the above example, even though it should be.
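
One quick way to see which parameters actually end up sharing storage after loading is to compare their data pointers (a diagnostic snippet, not part of the PR; the expected results follow the observation above):

import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-local-base")

def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # tied parameters point at the same underlying storage
    return a.data_ptr() == b.data_ptr()

print(shares_storage(model.shared.weight, model.encoder.embed_tokens.weight))
print(shares_storage(model.shared.weight, model.decoder.embed_tokens.weight))
print(shares_storage(model.shared.weight, model.lm_head.weight))  # False when lm_head is left untied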

Seems like _tie_weights was added to multiple models as part of 95020f2, which made LongT5 nonfunctional.
To be honest, I'm not completely sure why it was added, since, as mentioned previously, this is already done in modeling_utils.

To retain maximum compatibility we could:

  • remove the _tie_weights override, which simply seems to be wrong,
  • and add shared.weight to the _tied_weights_keys.

I will proceed with testing this solution, but since it changes the original source code, I would appreciate some comment on this proposal.
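
Concretely, the proposal would amount to something like this in modeling_longt5.py (a sketch of the idea, not the final patch):

class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin):
    _keys_to_ignore_on_load_unexpected = [
        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
    # "shared.weight" added so that checkpoints which only store the source
    # parameter still populate the encoder/decoder embeddings and lm_head
    _tied_weights_keys = [
        "shared.weight",
        "encoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
        "lm_head.weight",
    ]
    # ...with the custom _tie_weights override from 95020f2 removed entirely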

A similar issue might affect other models to which the _tie_weights override has been added.

@Rocketknight1
Member

@Szustarol updating the modeling code is totally okay, especially if part of it seems to be wrong. Models like LongT5 are not widely used and so it is very possible that a past update broke them and no-one realized.

I think if you make the update and you can show that:

  1. Tests (including slow tests) for LongT5 still pass
  2. Weight sharing now works correctly for LongT5 models on the Hub

that is enough evidence for us to accept the change. Either way, thank you for your careful approach to this PR!

Contributor

github-actions bot commented Sep 8, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: longt5

@Szustarol
Author

Thank you for the clarification @Rocketknight1, I think I now have a more resilient solution.

After a bit more consideration and model inspection I noticed that some models are indeed saved erroneously (i.e. they don't contain the lm_head weights). This concerns the Google models, and only in the safetensors format (.bin works correctly).
Given that those models are the most used ones (judging by download numbers), I decided to implement a solution that takes care of all the possible use cases. This means that lm_head cannot always be tied, because sometimes it is indeed supplied as a separate weight in the .bin or .safetensors checkpoint.
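
This can be verified by listing the keys stored in a checkpoint file, for example (an inspection snippet, not part of the PR; per the observation above, the affected Google safetensors files are expected to show "shared.weight" but no "lm_head.weight"):

from huggingface_hub import hf_hub_download
from safetensors import safe_open

path = hf_hub_download("google/long-t5-local-base", "model.safetensors")
with safe_open(path, framework="pt") as f:
    keys = list(f.keys())

print(sorted(k for k in keys if "shared" in k or "embed_tokens" in k or "lm_head" in k))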

The current solution involves iterating over the missing keys and checking whether they should be tied. If so, a warning is emitted (since in this case tie_word_embeddings should have been True) and a fallback weight value is loaded from the shared weights.
This is the only solution that works for the three types of models on the Hub:

  1. models with saved shared weights requiring lm_head, decoder and encoder tying,
  2. models with saved shared weights and lm_head, requiring decoder and encoder tying,
  3. models with saved decoder, encoder and optionally lm_head weights, requiring no tying.

Implementing the tying method would make use case 2) non-functional, since lm_head would be tied too. Also, the current solution allows us to keep the current _tie_weights, for whatever reason it was introduced.
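
A minimal sketch of that fallback logic (the names apply_shared_fallback and FALLBACK_TO_SHARED are illustrative; the hook point and exact code in the PR differ):

import torch

# Keys that legacy checkpoints expect to be filled from the shared embedding.
FALLBACK_TO_SHARED = {
    "encoder.embed_tokens.weight",
    "decoder.embed_tokens.weight",
    "lm_head.weight",
}

def apply_shared_fallback(model, missing_keys):
    """Copy shared.weight into any tied-looking parameter that the checkpoint did not provide."""
    params = dict(model.named_parameters())
    for key in missing_keys:
        if key in FALLBACK_TO_SHARED and key in params:
            # A correctly exported checkpoint would either store this weight
            # explicitly or set config.tie_word_embeddings=True, hence the warning.
            print(f"Warning: {key} missing from checkpoint, falling back to shared.weight")
            with torch.no_grad():
                params[key].copy_(model.shared.weight)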

I evaluated my changes on 4 models, with the following snippet:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# model_id = "google/long-t5-tglobal-base"
model_id = "google/long-t5-local-base"
# model_id = "pszemraj/long-t5-tglobal-base-sci-simplify"
# model_id = "agemagician/mlong-t5-tglobal-base"
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    use_safetensors=True
)

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
)

text = "A brown fox did"
inputs = tokenizer(text, return_tensors="pt", max_length=20, truncation=True)
outputs = model.generate(**inputs, max_length=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
| model_id | Output - before the fix - .bin | Output - before the fix - .safetensor | Output - fixed - .bin | Output - fixed - .safetensor |
| --- | --- | --- | --- | --- |
| google/long-t5-tglobal-base | Is it a fox or a fox terrier that is s | | Is it a fox or a fox terrier that is s | Is it a fox or a fox terrier that is s |
| google/long-t5-local-base | ),),),),),),),),),),),),),),),),),),), | | ),),),),),),),),),),),),),),),),),),), | ),),),),),),),),),),),),),),),),),),), |
| pszemraj/long-t5-tglobal-base-sci-simplify | Foxes were foxes that did not do their own things. | thea: and to of | Foxes were foxes that did not do their own things. | Foxes were foxes that did not do their own things. |
| agemagician/mlong-t5-tglobal-base | have a good time with the foxes. They are very good at laying eggs and | | have a good time with the foxes. They are very good at laying eggs and | have a good time with the foxes. They are very good at laying eggs and |

Seems like "google/long-t5-tlocal-base" is corrupted, by the way.

Anyway, I think this wraps it up.

PyTest report:
report.xml

@Rocketknight1
Member

Seems good - cc @ArthurZucker to confirm since this overrides from_pretrained for a text model

Successfully merging this pull request may close these issues.

Safetensors files for long-t5-tglobal models fail to load correctly