Fixed loading LongT5 from legacy checkpoints #40724
Conversation
Just a word of commentary that we might discuss as well: I believe [...]
Hi @Szustarol, it seems like you've identified the right issue, but is there a reason we can't just set [...]
Thanks for the follow-up @Rocketknight1, the [...]
transformers/src/transformers/models/longt5/modeling_longt5.py
Lines 1741 to 1744 in e481228
But most T5 checkpoints don't actually set this (i.e. config.tie_word_embeddings=False); I have also checked this for popular LongT5 models such as google/long-t5-local-base (see https://huggingface.co/google/long-t5-local-base/blob/main/config.json#L27). As a side note, this custom _tie_weights is actually not needed, since modeling_utils ties the decoder and encoder embeddings anyway when tie_word_embeddings=True, here:
transformers/src/transformers/modeling_utils.py
Line 2790 in 896e9ce
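A quick way to confirm the config claim above for one of the released checkpoints (a small check, not part of the PR):

```python
from transformers import AutoConfig

# The released LongT5 configs leave word embeddings untied, so the guarded
# tying discussed above never runs for them.
cfg = AutoConfig.from_pretrained("google/long-t5-local-base")
print(cfg.tie_word_embeddings)  # expected: False
```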
After reviewing the modeling code, it seems to me that the weights are always shared through the "shared" layer, so I don't understand why it was implemented that way instead of just tying them. Furthermore, the _tie_weights code also seems to be wrong when the lm_head should be tied too:
transformers/src/transformers/models/longt5/modeling_longt5.py
Lines 1904 to 1908 in e481228
[...]
transformers/src/transformers/models/longt5/modeling_longt5.py
Lines 1941 to 1944 in e481228
lm_head remains untied in the above example, even though it should be.
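One way to inspect this from Python (a sketch; it assumes the standard LongT5ForConditionalGeneration attribute names and the default, non-meta-device loading path):

```python
import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-local-base")

shared = model.shared.weight
# The per-stack embedding tables alias the shared table...
print(shared.data_ptr() == model.encoder.embed_tokens.weight.data_ptr())
print(shared.data_ptr() == model.decoder.embed_tokens.weight.data_ptr())
# ...while lm_head keeps its own storage when the tying above never runs,
# so this comparison may print False.
print(shared.data_ptr() == model.lm_head.weight.data_ptr())
```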
Seems like [...] To retain maximum compatibility we could:
[...]
A similar issue might affect other models to which the [...]
@Szustarol updating the modeling code is totally okay, especially if part of it seems to be wrong. Models like LongT5 are not widely used, so it is very possible that a past update broke them and no one realized. I think if you make the update and you can show that:
[...]
that is enough evidence for us to accept the change. Either way, thank you for your careful approach to this PR!
[For maintainers] Suggested jobs to run (before merge): run-slow: longt5
Thank you for the clarification @Rocketknight1, I think I now have a more resilient solution. After a bit more consideration and model inspection, I noticed that some models are indeed saved erroneously (i.e. the lm_head weights are not saved). This concerns the Google models, and only in the safetensors format (.bin works correctly). The current solution involves iterating over the missing keys and checking whether they should be tied; if yes, a warning is emitted (as in this case, [...])
Implementing the tying method would make use case 2) non-functional, since lm_head would be tied too. Also, the current solution allows us to keep the current [...]

I evaluated my changes on 4 models, with the following snippet:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# model_id = "google/long-t5-tglobal-base"
model_id = "google/long-t5-local-base"
# model_id = "pszemraj/long-t5-tglobal-base-sci-simplify"
# model_id = "agemagician/mlong-t5-tglobal-base"

model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    use_safetensors=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
)

text = "A brown fox did"
inputs = tokenizer(text, return_tensors="pt", max_length=20, truncation=True)
outputs = model.generate(**inputs, max_length=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
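As an extra sanity check on the snippet above (not part of the PR; it reuses the model variable and assumes the standard LongT5 attribute names), one can verify that the per-stack embeddings ended up consistent with the shared table, which is what legacy checkpoints expect:

```python
import torch

# Both should print True once the shared weights are propagated correctly.
print(torch.equal(model.shared.weight, model.encoder.embed_tokens.weight))
print(torch.equal(model.shared.weight, model.decoder.embed_tokens.weight))
```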
Seems like "google/long-t5-tlocal-base" is corrupted, by the way. Anyway, I think this wraps it up. PyTest report: |
Seems good - cc @ArthurZucker to confirm since this overrides [...]
What does this PR do?
Fixes #40635
The problem with loading legacy LongT5 checkpoints stems from the way the shared embedding tokens are defined in the model:
transformers/src/transformers/models/longt5/modeling_longt5.py
Lines 1350 to 1356 in 96a5774
In most legacy checkpoints they are not tied via the tied keys, but were still trained from the same shared weight that is created here:
transformers/src/transformers/models/longt5/modeling_longt5.py
Line 1901 in 96a5774
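A toy, self-contained illustration of that wiring (all names below are invented; only the aliasing pattern mirrors the pinned lines):

```python
import torch.nn as nn

class TinyStack(nn.Module):
    """Toy stand-in for LongT5Stack."""
    def __init__(self, vocab_size, d_model, embed_tokens=None):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, d_model)
        if embed_tokens is not None:
            # Reassigning the Parameter makes the tensors alias each other
            # without the key ever being declared as a tied weight.
            self.embed_tokens.weight = embed_tokens.weight

class TinyModel(nn.Module):
    """Toy stand-in for LongT5ForConditionalGeneration."""
    def __init__(self, vocab_size=32, d_model=8):
        super().__init__()
        self.shared = nn.Embedding(vocab_size, d_model)  # the "shared" layer
        self.encoder = TinyStack(vocab_size, d_model, self.shared)
        self.decoder = TinyStack(vocab_size, d_model, self.shared)

m = TinyModel()
# All three keys are present in the state dict even though they alias one
# tensor; a legacy checkpoint that only stored `shared.weight` therefore
# reports the other two as missing keys at load time.
print(list(m.state_dict().keys()))
```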
The solution to this problem is to manually assign the weights in an overridden from_pretrained if the checkpoint contains the shared weights. This fix is somewhat dirty; it would be much better to update them directly in the state_dict to avoid the missing-keys warning, but there is currently no public interface of the PreTrainedModel class that allows for that. If the reviewer has a better idea of how this could be handled, I'm ready to devote my time to fixing this issue.
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@Rocketknight1