6 changes: 6 additions & 0 deletions keras_nlp/api/models/__init__.py
@@ -161,6 +161,12 @@
MistralPreprocessor,
)
from keras_nlp.src.models.mistral.mistral_tokenizer import MistralTokenizer
from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import (
MiTBackbone,
)
from keras_nlp.src.models.mix_transformer.mix_transformer_classifier import (
MixTransformerImageClassifier,
)
from keras_nlp.src.models.opt.opt_backbone import OPTBackbone
from keras_nlp.src.models.opt.opt_causal_lm import OPTCausalLM
from keras_nlp.src.models.opt.opt_causal_lm_preprocessor import (
13 changes: 13 additions & 0 deletions keras_nlp/src/models/mix_transformer/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
157 changes: 157 additions & 0 deletions keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py
@@ -0,0 +1,157 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
import numpy as np
from keras import ops

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
from keras_nlp.src.models.mix_transformer.mix_transformer_layers import (
HierarchicalTransformerEncoder,
)
from keras_nlp.src.models.mix_transformer.mix_transformer_layers import (
OverlappingPatchingAndEmbedding,
)


@keras_nlp_export("keras_nlp.models.MiTBackbone")
class MiTBackbone(FeaturePyramidBackbone):
def __init__(
self,
depths,
include_rescaling=True,
input_image_shape=(224, 224, 3),
embedding_dims=None,
**kwargs,
):
"""A Backbone implementing the MixTransformer.

This architecture is to be used as a backbone for the SegFormer
architecture [SegFormer: Simple and Efficient Design for Semantic
Segmentation with Transformers](https://arxiv.org/abs/2105.15203),
based on the [TensorFlow implementation from DeepVision](
https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer).

Args:
depths: the number of transformer encoders to be used per stage in the
network.
include_rescaling: bool, whether to rescale the inputs. If set
to `True`, inputs will be passed through a `Rescaling(1/255.0)`
layer. Defaults to `True`.
input_image_shape: optional shape tuple, defaults to `(224, 224, 3)`.
embedding_dims: the embedding dims per hierarchical stage, used as
the levels of the feature pyramid.

Examples:

Using the class with a preset `backbone`:

```python
images = np.ones(shape=(1, 96, 96, 3))
labels = np.zeros(shape=(1, 96, 96, 1))
backbone = keras_nlp.models.MiTBackbone.from_preset("mit_b0_imagenet")

# Evaluate backbone
backbone(images)

# Train backbone
backbone.compile(
optimizer="adam",
loss=keras.losses.BinaryCrossentropy(from_logits=False),
metrics=["accuracy"],
)
backbone.fit(images, labels, epochs=3)
```
"""
drop_path_rate = 0.1
dpr = list(np.linspace(0.0, drop_path_rate, sum(depths)))
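# Per-stage attention heads and spatial-reduction ratios for the
# efficient self-attention blocks, following the SegFormer paper.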
blockwise_num_heads = [1, 2, 5, 8]
blockwise_sr_ratios = [8, 4, 2, 1]
num_stages = 4

# === Layers ===
cur = 0
patch_embedding_layers = []
transformer_blocks = []
layer_norms = []

for i in range(num_stages):
patch_embed_layer = OverlappingPatchingAndEmbedding(
project_dim=embedding_dims[i],
patch_size=7 if i == 0 else 3,
stride=4 if i == 0 else 2,
name=f"patch_and_embed_{i}",
)
patch_embedding_layers.append(patch_embed_layer)

transformer_block = [
HierarchicalTransformerEncoder(
project_dim=embedding_dims[i],
num_heads=blockwise_num_heads[i],
sr_ratio=blockwise_sr_ratios[i],
drop_prob=dpr[cur + k],
name=f"hierarchical_encoder_{i}_{k}",
)
for k in range(depths[i])
]
transformer_blocks.append(transformer_block)
cur += depths[i]
layer_norms.append(keras.layers.LayerNormalization())

# === Functional Model ===
image_input = keras.layers.Input(shape=input_image_shape)
x = image_input

if include_rescaling:
x = keras.layers.Rescaling(scale=1 / 255)(x)

pyramid_outputs = {}
for i in range(num_stages):
# Compute new height/width after the `proj`
# call in `OverlappingPatchingAndEmbedding`
stride = 4 if i == 0 else 2
new_height, new_width = (
int(ops.shape(x)[1] / stride),
int(ops.shape(x)[2] / stride),
)
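# e.g. a 224x224 input is reduced to 56x56 by stage 0 (stride 4),
# then to 28x28, 14x14 and 7x7 by the stride-2 stages.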

x = patch_embedding_layers[i](x)
for blk in transformer_blocks[i]:
x = blk(x)
x = layer_norms[i](x)
x = keras.layers.Reshape(
(new_height, new_width, -1), name=f"output_level_{i}"
)(x)
pyramid_outputs[f"P{i + 1}"] = x

super().__init__(inputs=image_input, outputs=x, **kwargs)

# === Config ===
self.depths = depths
self.include_rescaling = include_rescaling
self.input_image_shape = input_image_shape
self.embedding_dims = embedding_dims
self.pyramid_outputs = pyramid_outputs

def get_config(self):
config = super().get_config()
config.update(
{
"depths": self.depths,
"include_rescaling": self.include_rescaling,
"embedding_dims": self.embedding_dims,
"input_image_shape": self.input_image_shape,
}
)
return config
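As a usage sketch (not part of the diff): the new backbone exposes its hierarchical stages through the `pyramid_outputs` dict, which can be wired into a feature-extractor model. This assumes only the constructor arguments shown above; the values mirror the unit test below.

```python
import numpy as np
from keras import models

import keras_nlp

# Build a small MiT backbone from the constructor arguments above.
backbone = keras_nlp.models.MiTBackbone(
    depths=[2, 2, 2, 2],
    include_rescaling=True,
    input_image_shape=(64, 64, 3),
    embedding_dims=[32, 64, 160, 256],
)

# Expose the hierarchical feature pyramid (P1-P4) as model outputs.
feature_extractor = models.Model(backbone.inputs, backbone.pyramid_outputs)
features = feature_extractor(np.ones((2, 64, 64, 3), dtype="float32"))

# Strides are (4, 2, 2, 2), so a 64x64 input yields
# P1: (2, 16, 16, 32), P2: (2, 8, 8, 64),
# P3: (2, 4, 4, 160), P4: (2, 2, 2, 256).
for name, feature in features.items():
    print(name, feature.shape)
```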
67 changes: 67 additions & 0 deletions keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py
@@ -0,0 +1,67 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
from keras import models

from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import (
MiTBackbone,
)
from keras_nlp.src.tests.test_case import TestCase


class MiTBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"depths": [2, 2, 2, 2],
"include_rescaling": True,
"input_image_shape": (64, 64, 3),
"embedding_dims": [32, 64, 160, 256],
}
self.input_size = 32
self.input_data = np.ones((2, 64, 64, 3), dtype="float32")

def test_backbone_basics(self):
self.run_backbone_test(
cls=MiTBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 2, 2, 256),
run_quantization_check=False,
run_mixed_precision_check=False,
)

def test_pyramid_output_format(self):
init_kwargs = self.init_kwargs
backbone = MiTBackbone(**init_kwargs)
model = models.Model(backbone.inputs, backbone.pyramid_outputs)
output_data = model(self.input_data)

self.assertIsInstance(output_data, dict)
self.assertEqual(
list(output_data.keys()), list(backbone.pyramid_outputs.keys())
)
self.assertEqual(list(output_data.keys()), ["P1", "P2", "P3", "P4"])
for k, v in output_data.items():
size = self.input_size // (2 ** int(k[1:]))
self.assertEqual(tuple(v.shape[:3]), (2, size, size))

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=MiTBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)
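A note on the shape arithmetic in `test_pyramid_output_format`: `input_size` is set to 32 for a 64x64 input so that `input_size // 2**level` reproduces the cumulative strides (4, 8, 16, 32). Spelled out under that assumption:

```python
# The backbone downsamples by 4 at stage 0 and by 2 at each later stage.
strides = [4, 2, 2, 2]
size = 64
for level, stride in enumerate(strides, start=1):
    size //= stride
    assert size == 32 // 2**level  # the formula used in the test
```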
133 changes: 133 additions & 0 deletions keras_nlp/src/models/mix_transformer/mix_transformer_classifier.py
@@ -0,0 +1,133 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.image_classifier import ImageClassifier
from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import (
MiTBackbone,
)


@keras_nlp_export("keras_nlp.models.MixTransformerImageClassifier")
class MixTransformerImageClassifier(ImageClassifier):
"""MixTransformerImageClassifier image classifier model.

Args:
backbone: A `keras_nlp.models.MiTBackbone` instance.
num_classes: int. The number of classes to predict.
activation: `None`, str or callable. The activation function to use on
the `Dense` layer. Set `activation=None` to return the output
logits. Defaults to `"softmax"`.

To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
where `x` is a tensor and `y` is an integer from `[0, num_classes)`.
All `ImageClassifier` tasks include a `from_preset()` constructor which can
be used to load a pre-trained config and weights.

Examples:

Call `predict()` to run inference.
```python
# Load preset and run inference
images = np.ones((2, 224, 224, 3), dtype="float32")
classifier = keras_nlp.models.MixTransformerImageClassifier.from_preset(
"mit_b0_imagenet")
classifier.predict(images)
```

Call `fit()` on a single batch.
```python
# Load preset and train
images = np.ones((2, 224, 224, 3), dtype="float32")
labels = [0, 3]
classifier = keras_nlp.models.MixTransformerImageClassifier.from_preset(
"mit_b0_imagenet")
classifier.fit(x=images, y=labels, batch_size=2)
```

Call `fit()` with custom loss, optimizer and backbone.
```python
images = np.ones((2, 224, 224, 3), dtype="float32")
labels = [0, 3]
classifier = keras_nlp.models.MixTransformerImageClassifier.from_preset(
"mit_b0_imagenet")
classifier.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
)
classifier.backbone.trainable = False
classifier.fit(x=images, y=labels, batch_size=2)
```

Custom backbone.
```python
images = np.ones((2, 224, 224, 3), dtype="float32")
labels = [0, 3]
backbone = keras_nlp.models.MiTBackbone(
depths=[2, 2, 2, 2],
include_rescaling=False,
input_image_shape=(224, 224, 3),
embedding_dims=[32, 64, 160, 256],
)
classifier = keras_nlp.models.MixTransformerImageClassifier(
backbone=backbone,
num_classes=4,
)
classifier.fit(x=images, y=labels, batch_size=2)
```
"""

backbone_cls = MiTBackbone

def __init__(
self,
backbone,
num_classes,
activation="softmax",
preprocessor=None, # adding this dummy arg for saved model test
# TODO: once preprocessor flow is figured out, this needs to be updated
**kwargs,
):
# === Layers ===
self.backbone = backbone
self.output_dense = keras.layers.Dense(
num_classes,
activation=activation,
name="predictions",
)

# === Functional Model ===
inputs = self.backbone.input
x = self.backbone(inputs)
outputs = self.output_dense(x)
super().__init__(
inputs=inputs,
outputs=outputs,
**kwargs,
)

# === Config ===
self.num_classes = num_classes
self.activation = activation

def get_config(self):
# Backbone serialized in `super`
config = super().get_config()
config.update(
{
"num_classes": self.num_classes,
"activation": self.activation,
}
)
return config
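A minimal end-to-end sketch (not part of the diff) of the new classifier with a custom backbone, assuming only the constructor signatures above. Note that `preprocessor` is still a placeholder in this revision, so raw images are passed in directly.

```python
import numpy as np

import keras_nlp

images = np.ones((2, 64, 64, 3), dtype="float32")

backbone = keras_nlp.models.MiTBackbone(
    depths=[2, 2, 2, 2],
    include_rescaling=True,
    input_image_shape=(64, 64, 3),
    embedding_dims=[32, 64, 160, 256],
)
classifier = keras_nlp.models.MixTransformerImageClassifier(
    backbone=backbone,
    num_classes=4,
    activation=None,  # return logits
)

preds = classifier.predict(images)
# As written in this revision, the Dense head is applied directly to the
# backbone's 4D feature map, so `preds` has shape (2, 2, 2, 4) here.
print(preds.shape)
```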