diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 4fb3b3cf00..41f1a47284 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -129,6 +129,7 @@ GPTNeoXPreprocessor, ) from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer +from keras_nlp.src.models.image_classifier import ImageClassifier from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import ( @@ -194,6 +195,8 @@ from keras_nlp.src.models.t5.t5_backbone import T5Backbone from keras_nlp.src.models.t5.t5_tokenizer import T5Tokenizer from keras_nlp.src.models.task import Task +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone +from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import ( WhisperAudioFeatureExtractor, ) diff --git a/keras_nlp/src/models/image_classifier.py b/keras_nlp/src/models/image_classifier.py new file mode 100644 index 0000000000..f0cc031dbc --- /dev/null +++ b/keras_nlp/src/models/image_classifier.py @@ -0,0 +1,90 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.task import Task + + +@keras_nlp_export("keras_nlp.models.ImageClassifier") +class ImageClassifier(Task): + """Base class for all image classification tasks. + + `ImageClassifier` tasks wrap a `keras_nlp.models.Backbone` and + a `keras_nlp.models.Preprocessor` to create a model that can be used for + image classification. `ImageClassifier` tasks take an additional + `num_classes` argument, controlling the number of predicted output classes. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a string and `y` is a integer from `[0, num_classes)`. + + All `ImageClassifier` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Default compilation. + self.compile() + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `ImageClassifier` task for training. + + The `ImageClassifier` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a + `keras.losses.SparseCategoricalCrossentropy` loss will be + applied for the classification task. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a list of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.SparseCategoricalAccuracy` will be + applied to track the accuracy of the model during training. + See `keras.Model.compile` and `keras.metrics` for + more info on possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. + """ + if optimizer == "auto": + optimizer = keras.optimizers.Adam(5e-5) + if loss == "auto": + activation = getattr(self, "activation", None) + activation = keras.activations.get(activation) + from_logits = activation != keras.activations.softmax + loss = keras.losses.SparseCategoricalCrossentropy(from_logits) + if metrics == "auto": + metrics = [keras.metrics.SparseCategoricalAccuracy()] + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) diff --git a/keras_nlp/src/models/vgg/__init__.py b/keras_nlp/src/models/vgg/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/models/vgg/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/models/vgg/vgg_backbone.py b/keras_nlp/src/models/vgg/vgg_backbone.py new file mode 100644 index 0000000000..497381c0fc --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_backbone.py @@ -0,0 +1,159 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras +from keras import layers + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + + +@keras_nlp_export("keras_nlp.models.VGGBackbone") +class VGGBackbone(Backbone): + """ + This class represents Keras Backbone of VGG model. + + This class implements a VGG backbone as described in [Very Deep + Convolutional Networks for Large-Scale Image Recognition]( + https://arxiv.org/abs/1409.1556)(ICLR 2015). + + Args: + stackwise_num_repeats: list of ints, number of repeated convolutional + blocks per VGG block. For VGG16 this is [2, 2, 3, 3, 3] and for + VGG19 this is [2, 2, 4, 4, 4]. + stackwise_num_filters: list of ints, filter size for convolutional + blocks per VGG block. For both VGG16 and VGG19 this is [ + 64, 128, 256, 512, 512]. + include_rescaling: bool, whether to rescale the inputs. If set to + True, inputs will be passed through a `Rescaling(1/255.0)` layer. + input_shape: tuple, optional shape tuple, defaults to (224, 224, 3). + pooling: bool, Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model will be + the 4D tensor output of the + last convolutional block. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional block, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will + be applied. + + Examples: + ```python + input_data = np.ones((2, 224, 224, 3), dtype="float32") + + # Pretrained VGG backbone. + model = keras_nlp.models.VGGBackbone.from_preset("vgg16") + model(input_data) + + # Randomly initialized VGG backbone with a custom config. + model = keras_nlp.models.VGGBackbone( + stackwise_num_repeats = [2, 2, 3, 3, 3], + stackwise_num_filters = [64, 128, 256, 512, 512], + input_shape = (224, 224, 3), + include_rescaling = False, + pooling = "avg", + ) + model(input_data) + ``` + """ + + def __init__( + self, + stackwise_num_repeats, + stackwise_num_filters, + include_rescaling, + input_image_shape=(224, 224, 3), + pooling="avg", + **kwargs, + ): + + # === Functional Model === + img_input = keras.layers.Input(shape=input_image_shape) + x = img_input + + if include_rescaling: + x = layers.Rescaling(scale=1 / 255.0)(x) + for stack_index in range(len(stackwise_num_repeats) - 1): + x = apply_vgg_block( + x=x, + num_layers=stackwise_num_repeats[stack_index], + filters=stackwise_num_filters[stack_index], + kernel_size=(3, 3), + activation="relu", + padding="same", + max_pool=True, + name=f"block{stack_index + 1}", + ) + if pooling == "avg": + x = layers.GlobalAveragePooling2D()(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D()(x) + + super().__init__(inputs=img_input, outputs=x, **kwargs) + + # === Config === + self.stackwise_num_repeats = stackwise_num_repeats + self.stackwise_num_filters = stackwise_num_filters + self.include_rescaling = include_rescaling + self.input_image_shape = input_image_shape + self.pooling = pooling + + def get_config(self): + return { + "stackwise_num_repeats": self.stackwise_num_repeats, + "stackwise_num_filters": self.stackwise_num_filters, + "include_rescaling": self.include_rescaling, + "input_image_shape": self.input_image_shape, + "pooling": self.pooling, + } + + +def apply_vgg_block( + x, + num_layers, + filters, + kernel_size, + activation, + padding, + max_pool, + name, +): + """ + Applies VGG block + Args: + x: Tensor, input tensor to pass through network + num_layers: int, number of CNN layers in the block + filters: int, filter size of each CNN layer in block + kernel_size: int (or) tuple, kernel size for CNN layer in block + activation: str (or) callable, activation function for each CNN layer in + block + padding: str (or) callable, padding function for each CNN layer in block + max_pool: bool, whether to add MaxPooling2D layer at end of block + name: str, name of the block + + Returns: + keras.KerasTensor + """ + for num in range(1, num_layers + 1): + x = layers.Conv2D( + filters, + kernel_size, + activation=activation, + padding=padding, + name=f"{name}_conv{num}", + )(x) + if max_pool: + x = layers.MaxPooling2D((2, 2), (2, 2), name=f"{name}_pool")(x) + return x diff --git a/keras_nlp/src/models/vgg/vgg_backbone_test.py b/keras_nlp/src/models/vgg/vgg_backbone_test.py new file mode 100644 index 0000000000..05ed33ba0f --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_backbone_test.py @@ -0,0 +1,48 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class VGGBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stackwise_num_repeats": [2, 3, 3], + "stackwise_num_filters": [8, 64, 64], + "input_image_shape": (16, 16, 3), + "include_rescaling": False, + "pooling": "avg", + } + self.input_data = np.ones((2, 16, 16, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_backbone_test( + cls=VGGBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 64), + run_mixed_precision_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=VGGBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/vgg/vgg_image_classifier.py b/keras_nlp/src/models/vgg/vgg_image_classifier.py new file mode 100644 index 0000000000..a26fbfbc30 --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_image_classifier.py @@ -0,0 +1,124 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone + + +@keras_nlp_export("keras_nlp.models.VGGImageClassifier") +class VGGImageClassifier(ImageClassifier): + """VGG16 image classifier task model. + + Args: + backbone: A `keras_nlp.models.VGGBackbone` instance. + num_classes: int, number of classes to predict. + pooling: str, type of pooling layer. Must be one of "avg", "max". + activation: Optional `str` or callable, defaults to "softmax". The + activation function to use on the Dense layer. Set `activation=None` + to return the output logits. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a string and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + + Examples: + Train from preset + ```python + # Load preset and train + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + classifier = keras_nlp.models.VGGImageClassifier.from_preset( + 'vgg_16_image_classifier') + classifier.fit(x=images, y=labels, batch_size=2) + + # Re-compile (e.g., with a new learning rate). + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + jit_compile=True, + ) + + # Access backbone programmatically (e.g., to change `trainable`). + classifier.backbone.trainable = False + # Fit again. + classifier.fit(x=images, y=labels, batch_size=2) + ``` + Custom backbone + ```python + images = np.ones((2, 224, 224, 3), dtype="float32") + labels = [0, 3] + + backbone = keras_nlp.models.VGGBackbone( + stackwise_num_repeats = [2, 2, 3, 3, 3], + stackwise_num_filters = [64, 128, 256, 512, 512], + input_shape = (224, 224, 3), + include_rescaling = False, + pooling = "avg", + ) + classifier = keras_nlp.models.VGGImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + + backbone_cls = VGGBackbone + + def __init__( + self, + backbone, + num_classes, + activation="softmax", + preprocessor=None, # adding this dummy arg for saved model test + # TODO: once preprocessor flow is figured out, this needs to be updated + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + + # Instantiate using Functional API Model constructor + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "activation": self.activation, + } + ) + return config diff --git a/keras_nlp/src/models/vgg/vgg_image_classifier_test.py b/keras_nlp/src/models/vgg/vgg_image_classifier_test.py new file mode 100644 index 0000000000..4a2573e496 --- /dev/null +++ b/keras_nlp/src/models/vgg/vgg_image_classifier_test.py @@ -0,0 +1,61 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone +from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier +from keras_nlp.src.tests.test_case import TestCase + + +class VGGImageClassifierTest(TestCase): + def setUp(self): + # Setup model. + self.images = np.ones((2, 4, 4, 3), dtype="float32") + self.labels = [0, 3] + self.backbone = VGGBackbone( + stackwise_num_repeats=[2, 4, 4], + stackwise_num_filters=[2, 16, 16], + input_image_shape=(4, 4, 3), + include_rescaling=False, + pooling="max", + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "activation": "softmax", + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_task_test( + cls=VGGImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=VGGImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index 7e8e0cec95..fc1ce77e1e 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -419,20 +419,22 @@ def run_backbone_test( self.assertEqual(output[key].shape, expected_output_shape[key]) else: self.assertEqual(output.shape, expected_output_shape) - - # Check we can embed tokens eagerly. - output = backbone.token_embedding(ops.zeros((2, 3), dtype="int32")) - - # Check variable length sequences. - if variable_length_data is None: - # If no variable length data passed, assume the second axis of all - # inputs is our sequence axis and create it ourselves. - variable_length_data = [ - tree.map_structure(lambda x: x[:, :seq_length, ...], input_data) - for seq_length in (2, 3, 4) - ] - for batch in variable_length_data: - backbone(batch) + if backbone.token_embedding is not None: + # Check we can embed tokens eagerly. + output = backbone.token_embedding(ops.zeros((2, 3), dtype="int32")) + + # Check variable length sequences. + if variable_length_data is None: + # If no variable length data passed, assume the second axis of all + # inputs is our sequence axis and create it ourselves. + variable_length_data = [ + tree.map_structure( + lambda x: x[:, :seq_length, ...], input_data + ) + for seq_length in (2, 3, 4) + ] + for batch in variable_length_data: + backbone(batch) # Check compiled predict function. backbone.predict(input_data)