
Add Esm #2244


Open · pass-lin wants to merge 11 commits into master

Conversation

@pass-lin (Contributor) commented May 3, 2025

From #2177. This port achieves a smaller numerical error against the Hugging Face implementation.

import os

os.environ["KERAS_BACKEND"] = "torch"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from keras import ops
from transformers.models.esm.modeling_esm import EsmModel

from keras_hub.src.models.esm.esm_backbone import ESMBackbone

# Load the Hugging Face reference model.
weights_path = "facebook/esm2_t6_8M_UR50D"
hf_model = EsmModel.from_pretrained(weights_path)
hf_model.cuda().eval()
hf_model.embeddings.token_dropout = False

# Load the same checkpoint through the KerasHub converter.
keras_model = ESMBackbone.from_preset("hf://" + weights_path)
keras_model.summary()

# Compare the two models on a small dummy batch.
x = ops.array([[1, 2, 3, 4, 5]]) + 1
hf_out = hf_model(x, ops.ones_like(x))[0]
keras_out = keras_model({"token_ids": x})

print(ops.all(ops.isclose(hf_out, keras_out, atol=1e-4)))
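
For completeness, a small extra check (not part of the original script) that reports the worst-case difference rather than only the boolean closeness result:

# Hypothetical addition: print the maximum absolute difference between
# the HF and KerasHub outputs computed above.
max_abs_err = ops.max(ops.abs(hf_out - keras_out))
print("max abs error:", float(max_abs_err))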

@pass-lin (Contributor, Author) commented May 3, 2025

ruff.....................................................................Passed
ruff-format..............................................................Passed
Error: Process completed with exit code 1.

Please help me figure out how to solve this problem.

@mattdangerw (Member)

Probably an issue with generating the API symbols. Looks like you need to sync with the latest changes on master, then you could try running ./shell/api_gen.sh

@sachinprasadhs (Collaborator)

> ruff.....................................................................Passed
> ruff-format..............................................................Passed
> Error: Process completed with exit code 1.
>
> Please help me figure out how to solve this problem.

You can rebase it onto the latest master code and then run:
pre-commit run --all-files
pip install -U namex

@pass-lin (Contributor, Author)

FAILED keras_hub/src/layers/modeling/reversible_embedding_test.py::ReversibleEmbeddingTest::test_quantize_dtype_argument_tie_weights - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/layers/modeling/reversible_embedding_test.py::ReversibleEmbeddingTest::test_quantize_dtype_argument_untie_weights - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/layers/modeling/reversible_embedding_test.py::ReversibleEmbeddingTest::test_quantize_int8_tie_weights - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/layers/modeling/reversible_embedding_test.py::ReversibleEmbeddingTest::test_quantize_int8_untie_weights - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/albert/albert_backbone_test.py::AlbertBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/bart/bart_backbone_test.py::BartBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/bert/bert_backbone_test.py::BertBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/bloom/bloom_backbone_test.py::BloomBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/clip/clip_backbone_test.py::CLIPBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/deberta_v3/deberta_v3_backbone_test.py::DebertaV3BackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/distil_bert/distil_bert_backbone_test.py::DistilBertBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/electra/electra_backbone_test.py::ElectraBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/f_net/f_net_backbone_test.py::FNetBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/falcon/falcon_backbone_test.py::FalconBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/gemma/gemma_backbone_test.py::GemmaBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/gemma/gemma_backbone_test.py::Gemma2BackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/gpt2/gpt2_backbone_test.py::GPT2BackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone_test.py::GPTNeoXBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/llama/llama_backbone_test.py::LlamaTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/mistral/mistral_backbone_test.py::MistralBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/opt/opt_backbone_test.py::OPTBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py::PaliGemmaBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py::PaliGemma2BackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/phi3/phi3_backbone_test.py::Phi3Test::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/phi3/phi3_backbone_test.py::Phi3Test::test_backbone_basics_with_su_rotary - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/roberta/roberta_backbone_test.py::RobertaBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/siglip/siglip_backbone_test.py::SigLIPBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/siglip/siglip_backbone_test.py::SigLIP2BackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/t5/t5_backbone_test.py::T5BackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/whisper/whisper_backbone_test.py::WhisperBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/xlm_roberta/xlm_roberta_backbone_test.py

@mattdangerw @sachinprasadhs
Is this a problem with the test environment? Why are there so many failures that are unrelated to my changes?

@sachinprasadhs (Collaborator)

It's not related to your code; it looks like an issue with the JAX backend. We will look into it.

@sachinprasadhs (Collaborator) left a comment

Thanks for the PR, I have added my comments. Please also add the checkpoint conversion script under keras-hub/tools/checkpoint_conversion.
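
For reference, a minimal sketch of the shape such a conversion script could take, assuming the hf:// converter added in this PR and KerasHub's save_to_preset API; the output directory is a placeholder, and numeric verification against HF can be done as in the PR description above:

import os

os.environ["KERAS_BACKEND"] = "torch"

from keras_hub.src.models.esm.esm_backbone import ESMBackbone

# Convert directly from the Hugging Face checkpoint via the hf:// converter.
backbone = ESMBackbone.from_preset("hf://facebook/esm2_t6_8M_UR50D")

# Write a local preset directory (placeholder path) that can be uploaded later.
backbone.save_to_preset("./esm2_t6_8M_UR50D")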

Comment on lines +27 to +29
Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind.


You can remove this; we can mention it on Kaggle or Hugging Face if needed.

intermediate_dim: int. The output dimension of the first Dense layer in
a two-layer feedforward network for each transformer.
dropout: float. Dropout probability for the Transformer encoder.
layer_norm_eps:bool.Should we use ln after embedding?

I didn't get the point here: are you asking for our input, or is this the argument's documentation? If it is the argument documentation, it needs to be rephrased. Avoid question marks, and note that the argument name is emb_layer_norm_before.

The layer_norm_eps description also needs to be updated.
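
For illustration, one possible rephrasing (a sketch only; the defaults shown are the ones used in the conversion code later in this PR):

emb_layer_norm_before: bool. Whether to apply layer normalization to the
    token embeddings before the transformer layers. Defaults to False.
layer_norm_eps: float. Epsilon used by the layer normalization layers.
    Defaults to 1e-12.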

hidden_dim: int. The size of the transformer encoding and pooler layers.
intermediate_dim: int. The output dimension of the first Dense layer in
a two-layer feedforward network for each transformer.
dropout: float. Dropout probability for the Transformer encoder.

Add "Defaults to 0.1." to the dropout description.

Comment on lines +47 to +51
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for model computations and weights. Note that some computations,
such as softmax and layer normalization, will always be done at
float32 precision regardless of dtype.


The documented type should be: None or str or keras.mixed_precision.DTypePolicy.

Comment on lines +94 to +98
support_positon_type = ["rotary", "absolute"]
if position_embedding_type.lower() not in support_positon_type:
    raise (
        f"This model only support below position embedding type: {support_positon_type}"  # noqa: E501
    )

Suggested change

Before:
support_positon_type = ["rotary", "absolute"]
if position_embedding_type.lower() not in support_positon_type:
    raise (
        f"This model only support below position embedding type: {support_positon_type}"  # noqa: E501
    )

After:
if position_embedding_type not in (
    "rotary",
    "absolute",
):
    raise ValueError(
        '`position_embedding_type` must be either `"rotary"`, or '
        f'`"absolute"`. Received position_embedding_type={position_embedding_type}.'
    )

init_kwargs=self.init_kwargs,
train_data=self.train_data,
expected_output_shape=(2, 2),
)

The saved-model test is missing.

init_kwargs=self.init_kwargs,
train_data=self.train_data,
expected_output_shape=(2, 5, 10),
)

The saved-model test is missing here as well.
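
For reference, other KerasHub backbone tests usually cover this with something like the following (a sketch only, assuming the test class already defines self.init_kwargs and self.input_data, that pytest is imported at the top of the file, and that the standard TestCase.run_model_saving_test helper applies here too):

# Sketch of a saved-model test for the backbone, following the pattern
# used by other backbone tests in keras-hub.
@pytest.mark.large
def test_saved_model(self):
    self.run_model_saving_test(
        cls=ESMBackbone,
        init_kwargs=self.init_kwargs,
        input_data=self.input_data,
    )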

Comment on lines +30 to +34
Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind. The underlying model is provided by a
third party and subject to a separate license, available
[here](https://github.com/facebookresearch/esm).


Remove this; it will be covered by the license and the model card.

Comment on lines +25 to +27
Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind.


Remove this

Comment on lines +23 to +36
    ),  # default is None
    "layer_norm_eps": transformers_config.get(
        "layer_norm_eps", 1e-12
    ),  # default is 1e-12
    "emb_layer_norm_before": transformers_config.get(
        "emb_layer_norm_before", False
    ),  # default is False
    "activation": transformers_config.get(
        "activation", "gelu"
    ),  # default is "gelu"
    "max_wavelength": transformers_config.get(
        "max_wavelength", 10000
    ),  # default is 10000
}

Remove comments
