-
Notifications
You must be signed in to change notification settings - Fork 280
Add Esm #2244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Add Esm #2244
Changes from all commits
f9ff098
cc4123b
d3f598d
737a147
140207b
cc9a11c
f8da784
72e9829
16bb9f2
5cbf577
6e9f817
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import keras | ||
from keras import ops | ||
|
||
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding | ||
from keras_hub.src.models.roformer_v2.roformer_v2_attention import ( | ||
RoformerAttention, | ||
) | ||
|
||
|
||
class ESMRotaryEmbedding(RotaryEmbedding): | ||
def _compute_cos_sin_embedding(self, x, position=1): | ||
dim = x.shape[-1] | ||
inv_freq = self.scaling_factor / ( | ||
self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim) | ||
) | ||
t = ops.arange(x.shape[position], dtype=x.dtype) | ||
freqs = ops.outer(t, inv_freq) | ||
emb = ops.concatenate((freqs, freqs), axis=-1) | ||
|
||
cos_emb = ops.cos(emb)[None, :, None, :] | ||
sin_emb = ops.sin(emb)[None, :, None, :] | ||
return cos_emb, sin_emb | ||
|
||
def call(self, q, k, position=1): | ||
cos_emb, sin_emb = self._compute_cos_sin_embedding(q, position) | ||
|
||
return ( | ||
self.apply_rotary_pos_emb(q, cos_emb, sin_emb), | ||
self.apply_rotary_pos_emb(k, cos_emb, sin_emb), | ||
) | ||
|
||
def rotate_half(self, x): | ||
x1, x2 = ops.split(x, 2, -1) | ||
return ops.concatenate((-x2, x1), axis=-1) | ||
|
||
def apply_rotary_pos_emb(self, x, cos, sin): | ||
cos = cos[:, : x.shape[1], :, :] | ||
sin = sin[:, : x.shape[1], :, :] | ||
|
||
return (x * cos) + (self.rotate_half(x) * sin) | ||
|
||
|
||
class EsmSelfAttention(RoformerAttention): | ||
"""MultiHeadAttention by ESM2 | ||
|
||
Referred to the implementation of HuggingFace. | ||
In fact, this part of the calculation is exactly the same as RoFormer. | ||
Only the calculation of the rotary part is different. | ||
""" | ||
|
||
def __init__(self, use_rotary=True, **kwargs): | ||
super().__init__(**kwargs) | ||
self.use_rotary = use_rotary | ||
|
||
def build(self, input_shape): | ||
super().build(input_shape) | ||
if self.use_rotary: | ||
self.rotary_embedding_layer = ESMRotaryEmbedding( | ||
max_wavelength=self.max_wavelength, dtype=self.dtype_policy | ||
) | ||
self.rotary_embedding_layer.build([]) | ||
|
||
def call(self, x, attention_mask=None): | ||
qw = self.q_dense(x) | ||
kw = self.k_dense(x) | ||
vw = self.v_dense(x) | ||
|
||
b, s = ops.shape(qw)[:2] | ||
qw = ops.reshape(qw, (b, s, self.heads, self.head_size)) | ||
kw = ops.reshape(kw, (b, s, self.heads, self.head_size)) | ||
vw = ops.reshape(vw, (b, s, self.heads, self.head_size)) | ||
|
||
if self.use_rotary: | ||
qw, kw = self.rotary_embedding_layer(qw, kw) | ||
if keras.__version__ < "3.6": | ||
raise ("Please make sure your Keras version is >=3.6.") | ||
flash_attention = keras.config.is_flash_attention_enabled() | ||
attention_mask = ops.reshape(attention_mask, [b, 1, s, 1]) | ||
if keras.config.backend() == "torch": | ||
attention_mask = ops.repeat(attention_mask, s, -1) | ||
attention_mask = ops.transpose(attention_mask, [0, 1, 3, 2]) | ||
o = ops.dot_product_attention( | ||
qw, kw, vw, mask=attention_mask, flash_attention=flash_attention | ||
) | ||
return self.o_dense(ops.reshape(o, [b, s, -1])) | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"use_rotary": self.use_rotary, | ||
} | ||
) | ||
return config |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,217 @@ | ||||||||||||||||||||||||||
import keras | ||||||||||||||||||||||||||
from keras import activations | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
from keras_hub.src.api_export import keras_hub_export | ||||||||||||||||||||||||||
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding | ||||||||||||||||||||||||||
from keras_hub.src.models.backbone import Backbone | ||||||||||||||||||||||||||
from keras_hub.src.models.esm.esm_encoder import ESMEncoder | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def esm2_kernel_initializer(stddev=0.02): | ||||||||||||||||||||||||||
return keras.initializers.TruncatedNormal(stddev=stddev) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
@keras_hub_export( | ||||||||||||||||||||||||||
["keras_hub.models.ESM2Backbone", "keras_hub.models.ESMBackbone"] | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
class ESMBackbone(Backbone): | ||||||||||||||||||||||||||
"""A ESM2 and ESM encoder network. | ||||||||||||||||||||||||||
This class implements a bi-directional Transformer-based encoder as | ||||||||||||||||||||||||||
described in ["Roformer"](https://github.com/facebookresearch/esm). | ||||||||||||||||||||||||||
The default constructor gives a fully customizable, randomly initialized | ||||||||||||||||||||||||||
ESM2 encoder with any number of layers, heads, and embed dim.To | ||||||||||||||||||||||||||
load preset architectures and weights, use the `from_preset()` constructor. | ||||||||||||||||||||||||||
Disclaimer: Pre-trained models are provided on an "as is" basis, without | ||||||||||||||||||||||||||
warranties or conditions of any kind. | ||||||||||||||||||||||||||
Comment on lines
+27
to
+29
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can remove this, we can have this mentioned in the Kaggle or HuggingFace if needed. |
||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add |
||||||||||||||||||||||||||
vocabulary_size: int. The size of the token vocabulary. | ||||||||||||||||||||||||||
num_layers: int. The number of transformer layers. | ||||||||||||||||||||||||||
num_heads: int. The number of attention heads for each transformer. | ||||||||||||||||||||||||||
The hidden size must be divisible by the number of attention heads. | ||||||||||||||||||||||||||
hidden_dim: int. The size of the transformer encoding and pooler layers. | ||||||||||||||||||||||||||
intermediate_dim: int. The output dimension of the first Dense layer in | ||||||||||||||||||||||||||
a two-layer feedforward network for each transformer. | ||||||||||||||||||||||||||
dropout: float. Dropout probability for the Transformer encoder. | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add: Defaults to |
||||||||||||||||||||||||||
layer_norm_eps:bool.Should we use ln after embedding? | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Didn't get the point here, are you asking our input or it's the arg detail, if it is the arg details, it needs to be repharsed, avoid question marks and the argument name is
|
||||||||||||||||||||||||||
Since it's pre-norm, the default is false. | ||||||||||||||||||||||||||
max_sequence_length: int. The maximum sequence length that this encoder | ||||||||||||||||||||||||||
can consume. If None, `max_sequence_length` uses the value from | ||||||||||||||||||||||||||
sequence length. This determines the variable shape for positional | ||||||||||||||||||||||||||
embeddings. | ||||||||||||||||||||||||||
position_embedding_type:esm1 use abs position embeding,esm2 use rope. | ||||||||||||||||||||||||||
so this parameter is only except for absolute and rotary. | ||||||||||||||||||||||||||
Comment on lines
+45
to
+46
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. change to: position_embedding_type: str. The position embedding type to use. One of |
||||||||||||||||||||||||||
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use | ||||||||||||||||||||||||||
for model computations and weights. Note that some computations, | ||||||||||||||||||||||||||
such as softmax and layer normalization, will always be done at | ||||||||||||||||||||||||||
float32 precision regardless of dtype. | ||||||||||||||||||||||||||
Comment on lines
+47
to
+51
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||||||||
Examples: | ||||||||||||||||||||||||||
```python | ||||||||||||||||||||||||||
input_data = { | ||||||||||||||||||||||||||
"token_ids": np.ones(shape=(1, 12), dtype="int32"), | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
# Pretrained ESM2 encoder. | ||||||||||||||||||||||||||
model = keras_hub.models.ESM2Backbone.from_preset('hf://facebook/esm2_t6_8M_UR50D') | ||||||||||||||||||||||||||
model(input_data) | ||||||||||||||||||||||||||
# Randomly initialized ESM2 encoder with a custom config. | ||||||||||||||||||||||||||
model = keras_hub.models.ESM2Backbone( | ||||||||||||||||||||||||||
vocabulary_size=30552, | ||||||||||||||||||||||||||
num_layers=4, | ||||||||||||||||||||||||||
num_heads=4, | ||||||||||||||||||||||||||
hidden_dim=256, | ||||||||||||||||||||||||||
intermediate_dim=512, | ||||||||||||||||||||||||||
head_size = 64, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
model(input_data) | ||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||
vocabulary_size, | ||||||||||||||||||||||||||
num_layers, | ||||||||||||||||||||||||||
num_heads, | ||||||||||||||||||||||||||
hidden_dim, | ||||||||||||||||||||||||||
intermediate_dim, | ||||||||||||||||||||||||||
use_bias=True, | ||||||||||||||||||||||||||
activation="gelu", | ||||||||||||||||||||||||||
dropout=0.1, | ||||||||||||||||||||||||||
dtype=None, | ||||||||||||||||||||||||||
max_sequence_length=1024, | ||||||||||||||||||||||||||
max_wavelength=10000, | ||||||||||||||||||||||||||
layer_norm_eps=1e-12, | ||||||||||||||||||||||||||
emb_layer_norm_before=False, | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. instead of |
||||||||||||||||||||||||||
position_embedding_type="rotary", | ||||||||||||||||||||||||||
pad_token_id=0, | ||||||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||
support_positon_type = ["rotary", "absolute"] | ||||||||||||||||||||||||||
if position_embedding_type.lower() not in support_positon_type: | ||||||||||||||||||||||||||
raise ( | ||||||||||||||||||||||||||
f"This model only support below position embedding type: {support_positon_type}" # noqa: E501 | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
Comment on lines
+94
to
+98
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
head_size = hidden_dim // num_heads | ||||||||||||||||||||||||||
# === Layers === | ||||||||||||||||||||||||||
self.token_embedding = keras.layers.Embedding( | ||||||||||||||||||||||||||
input_dim=vocabulary_size, | ||||||||||||||||||||||||||
output_dim=hidden_dim, | ||||||||||||||||||||||||||
embeddings_initializer=esm2_kernel_initializer(), | ||||||||||||||||||||||||||
dtype=dtype, | ||||||||||||||||||||||||||
name="token_embedding", | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
if position_embedding_type == "absolute": | ||||||||||||||||||||||||||
self.position_embedding = PositionEmbedding( | ||||||||||||||||||||||||||
initializer=esm2_kernel_initializer(), | ||||||||||||||||||||||||||
sequence_length=max_sequence_length, | ||||||||||||||||||||||||||
dtype=dtype, | ||||||||||||||||||||||||||
name="position_embedding", | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
self.embeddings_add = keras.layers.Add( | ||||||||||||||||||||||||||
dtype=dtype, | ||||||||||||||||||||||||||
name="embeddings_add", | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
self.output_layer_norm = keras.layers.LayerNormalization( | ||||||||||||||||||||||||||
epsilon=layer_norm_eps, | ||||||||||||||||||||||||||
dtype=dtype, | ||||||||||||||||||||||||||
name="output_layer_norm", | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
if emb_layer_norm_before: | ||||||||||||||||||||||||||
self.emb_layer_norm = keras.layers.LayerNormalization( | ||||||||||||||||||||||||||
epsilon=layer_norm_eps, | ||||||||||||||||||||||||||
dtype=dtype, | ||||||||||||||||||||||||||
name="emb_layer_norm", | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
self.transformer_layers = [] | ||||||||||||||||||||||||||
for i in range(num_layers): | ||||||||||||||||||||||||||
layer = ESMEncoder( | ||||||||||||||||||||||||||
heads=num_heads, | ||||||||||||||||||||||||||
head_size=head_size, | ||||||||||||||||||||||||||
intermediate_size=intermediate_dim, | ||||||||||||||||||||||||||
use_bias=use_bias, | ||||||||||||||||||||||||||
max_wavelength=max_wavelength, | ||||||||||||||||||||||||||
dropout=dropout, | ||||||||||||||||||||||||||
activation=activation, | ||||||||||||||||||||||||||
kernel_initializer=esm2_kernel_initializer(), | ||||||||||||||||||||||||||
layer_norm_eps=layer_norm_eps, | ||||||||||||||||||||||||||
dtype=dtype, | ||||||||||||||||||||||||||
use_rotary=position_embedding_type == "rotary", | ||||||||||||||||||||||||||
name=f"transformer_layer_{i}", | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
self.transformer_layers.append(layer) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# === Functional Model === | ||||||||||||||||||||||||||
token_id_input = keras.Input( | ||||||||||||||||||||||||||
shape=(None,), dtype="int32", name="token_ids" | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
attention_mask = keras.ops.not_equal(token_id_input, pad_token_id) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
token_vector = self.token_embedding(token_id_input) | ||||||||||||||||||||||||||
if position_embedding_type == "absolute": | ||||||||||||||||||||||||||
position_vector = self.position_embedding( | ||||||||||||||||||||||||||
token_vector, start_index=pad_token_id | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
x = self.embeddings_add([token_vector, position_vector]) | ||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||
x = token_vector | ||||||||||||||||||||||||||
if emb_layer_norm_before: | ||||||||||||||||||||||||||
x = self.emb_layer_norm(x) | ||||||||||||||||||||||||||
for transformer_layer in self.transformer_layers: | ||||||||||||||||||||||||||
x = transformer_layer(x, attention_mask=attention_mask) | ||||||||||||||||||||||||||
output = self.output_layer_norm(x) | ||||||||||||||||||||||||||
super().__init__( | ||||||||||||||||||||||||||
inputs={ | ||||||||||||||||||||||||||
"token_ids": token_id_input, | ||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||
outputs=output, | ||||||||||||||||||||||||||
dtype=dtype, | ||||||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# === Config === | ||||||||||||||||||||||||||
self.vocabulary_size = vocabulary_size | ||||||||||||||||||||||||||
self.num_layers = num_layers | ||||||||||||||||||||||||||
self.num_heads = num_heads | ||||||||||||||||||||||||||
self.hidden_dim = hidden_dim | ||||||||||||||||||||||||||
self.intermediate_dim = intermediate_dim | ||||||||||||||||||||||||||
self.dropout = dropout | ||||||||||||||||||||||||||
self.max_wavelength = max_wavelength | ||||||||||||||||||||||||||
self.head_size = head_size | ||||||||||||||||||||||||||
self.dropout = dropout | ||||||||||||||||||||||||||
self.activation = activations.get(activation) | ||||||||||||||||||||||||||
self.use_bias = use_bias | ||||||||||||||||||||||||||
self.start_token_index = 0 | ||||||||||||||||||||||||||
self.layer_norm_eps = layer_norm_eps | ||||||||||||||||||||||||||
self.max_sequence_length = max_sequence_length | ||||||||||||||||||||||||||
self.emb_layer_norm_before = emb_layer_norm_before | ||||||||||||||||||||||||||
self.position_embedding_type = position_embedding_type | ||||||||||||||||||||||||||
self.pad_token_id = pad_token_id | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def get_config(self): | ||||||||||||||||||||||||||
config = super().get_config() | ||||||||||||||||||||||||||
config.update( | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
"vocabulary_size": self.vocabulary_size, | ||||||||||||||||||||||||||
"num_layers": self.num_layers, | ||||||||||||||||||||||||||
"num_heads": self.num_heads, | ||||||||||||||||||||||||||
"hidden_dim": self.hidden_dim, | ||||||||||||||||||||||||||
"intermediate_dim": self.intermediate_dim, | ||||||||||||||||||||||||||
"dropout": self.dropout, | ||||||||||||||||||||||||||
"max_wavelength": self.max_wavelength, | ||||||||||||||||||||||||||
"use_bias": self.use_bias, | ||||||||||||||||||||||||||
"activation": activations.serialize(self.activation), | ||||||||||||||||||||||||||
"layer_norm_eps": self.layer_norm_eps, | ||||||||||||||||||||||||||
"emb_layer_norm_before": self.emb_layer_norm_before, | ||||||||||||||||||||||||||
"position_embedding_type": self.position_embedding_type, | ||||||||||||||||||||||||||
"max_sequence_length": self.max_sequence_length, | ||||||||||||||||||||||||||
"pad_token_id": self.pad_token_id, | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
return config |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import keras | ||
from keras import ops | ||
|
||
from keras_hub.src.models.esm.esm_backbone import ESMBackbone | ||
from keras_hub.src.tests.test_case import TestCase | ||
|
||
|
||
class ESMBackboneTest(TestCase): | ||
def setUp(self): | ||
self.init_kwargs = { | ||
"vocabulary_size": 10, | ||
"num_layers": 2, | ||
"num_heads": 1, | ||
"hidden_dim": 2, | ||
"intermediate_dim": 4, | ||
} | ||
self.input_data = { | ||
"token_ids": ops.ones((2, 5), dtype="int32"), | ||
"segment_ids": ops.zeros((2, 5), dtype="int32"), | ||
} | ||
|
||
def test_backbone_basics(self): | ||
if keras.__version__ < "3.6": | ||
self.skipTest("Failing on keras lower version") | ||
Comment on lines
+23
to
+24
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are the tests failing due to some bug which was addressed in 3.6 release? |
||
self.run_backbone_test( | ||
cls=ESMBackbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
expected_output_shape=(2, 5, 2), | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add more tests to run saved model test and test_smallest_preset |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use generic name --> esm_kernel_initializer, if this is used for both esm and esm2.