diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 783cfd5087..6916bf4288 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -106,6 +106,7 @@ ) from keras_nlp.src.models.falcon.falcon_preprocessor import FalconPreprocessor from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM from keras_nlp.src.models.gemma.gemma_causal_lm_preprocessor import ( diff --git a/keras_nlp/src/models/backbone.py b/keras_nlp/src/models/backbone.py index a58072dfce..0f41c63c81 100644 --- a/keras_nlp/src/models/backbone.py +++ b/keras_nlp/src/models/backbone.py @@ -30,6 +30,7 @@ from keras_nlp.src.utils.preset_utils import save_metadata from keras_nlp.src.utils.preset_utils import save_serialized_object from keras_nlp.src.utils.python_utils import classproperty +from keras_nlp.src.utils.timm.convert import load_timm_backbone from keras_nlp.src.utils.transformers.convert import load_transformers_backbone @@ -204,6 +205,8 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from if format == "transformers": return load_transformers_backbone(cls, preset, load_weights) + elif format == "timm": + return load_timm_backbone(cls, preset, load_weights, **kwargs) preset_cls = check_config_class(preset) if not issubclass(preset_cls, cls): diff --git a/keras_nlp/src/models/feature_pyramid_backbone.py b/keras_nlp/src/models/feature_pyramid_backbone.py new file mode 100644 index 0000000000..989d9fbd64 --- /dev/null +++ b/keras_nlp/src/models/feature_pyramid_backbone.py @@ -0,0 +1,73 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.backbone import Backbone + + +@keras_nlp_export("keras_nlp.models.FeaturePyramidBackbone") +class FeaturePyramidBackbone(Backbone): + """A backbone with feature pyramid outputs. + + `FeaturePyramidBackbone` extends `Backbone` with a single `pyramid_outputs` + property for accessing the feature pyramid outputs of the model. Subclassers + should set the `pyramid_outputs` property during the model constructor. + + Example: + + ```python + input_data = np.random.uniform(0, 255, size=(2, 224, 224, 3)) + + # Convert to feature pyramid output format using ResNet. + backbone = ResNetBackbone.from_preset("resnet50") + model = keras.Model( + inputs=backbone.inputs, outputs=backbone.pyramid_outputs + ) + model(input_data) # A dict containing the keys ["P2", "P3", "P4", "P5"] + ``` + """ + + @property + def pyramid_outputs(self): + """A dict for feature pyramid outputs. + + The key is a string represents the name of the feature output and the + value is a `keras.KerasTensor`. A typical feature pyramid has multiple + levels corresponding to scales such as `["P2", "P3", "P4", "P5"]`. Scale + `Pn` represents a feature map `2^n` times smaller in width and height + than the inputs. + """ + return getattr(self, "_pyramid_outputs", {}) + + @pyramid_outputs.setter + def pyramid_outputs(self, value): + if not isinstance(value, dict): + raise TypeError( + "`pyramid_outputs` must be a dictionary. " + f"Received: value={value} of type {type(value)}" + ) + for k, v in value.items(): + if not isinstance(k, str): + raise TypeError( + "The key of `pyramid_outputs` must be a string. " + f"Received: key={k} of type {type(k)}" + ) + if not isinstance(v, keras.KerasTensor): + raise TypeError( + "The value of `pyramid_outputs` must be a " + "`keras.KerasTensor`. " + f"Received: value={v} of type {type(v)}" + ) + self._pyramid_outputs = value diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py index bec5ba60b5..0f4d7c139a 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone.py +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -13,20 +13,23 @@ # limitations under the License. import keras from keras import layers +from keras import ops from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone from keras_nlp.src.utils.keras_utils import standardize_data_format @keras_nlp_export("keras_nlp.models.ResNetBackbone") -class ResNetBackbone(Backbone): +class ResNetBackbone(FeaturePyramidBackbone): """ResNet and ResNetV2 core network with hyperparameters. This class implements a ResNet backbone as described in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)( - CVPR 2016) and [Identity Mappings in Deep Residual Networks]( - https://arxiv.org/abs/1603.05027)(ECCV 2016). + CVPR 2016), [Identity Mappings in Deep Residual Networks]( + https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An + improved training procedure in timm](https://arxiv.org/abs/2110.00476)( + NeurIPS 2021 Workshop). The difference in ResNet and ResNetV2 rests in the structure of their individual building blocks. In ResNetV2, the batch normalization and @@ -34,6 +37,9 @@ class ResNetBackbone(Backbone): the batch normalization and ReLU activation are applied after the convolution layers. + Note that `ResNetBackbone` expects the inputs to be images with a value + range of `[0, 255]` when `include_rescaling=True`. + Args: stackwise_num_filters: list of ints. The number of filters for each stack. @@ -46,8 +52,8 @@ class ResNetBackbone(Backbone): use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet. include_rescaling: boolean. If `True`, rescale the input using - `Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to - `True`. + `Rescaling` and `Normalization` layers. If `False`, do nothing. + Defaults to `True`. input_image_shape: tuple. The input shape without the batch size. Defaults to `(None, None, 3)`. pooling: `None` or str. Pooling mode for feature extraction. Defaults @@ -70,11 +76,11 @@ class ResNetBackbone(Backbone): `~/.keras/keras.json`. If you never set it, then it will be `"channels_last"`. dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype - to use for the models computations and weights. + to use for the model's computations and weights. Examples: ```python - input_data = np.ones((2, 224, 224, 3), dtype="float32") + input_data = np.random.uniform(0, 255, size=(2, 224, 224, 3)) # Pretrained ResNet backbone. model = keras_nlp.models.ResNetBackbone.from_preset("resnet50") @@ -136,34 +142,66 @@ def __init__( image_input = layers.Input(shape=input_image_shape) if include_rescaling: x = layers.Rescaling(scale=1 / 255.0, dtype=dtype)(image_input) + x = layers.Normalization( + axis=bn_axis, + mean=(0.485, 0.456, 0.406), + variance=(0.229**2, 0.224**2, 0.225**2), + dtype=dtype, + name="normalization", + )(x) else: x = image_input + # The padding between torch and tensorflow/jax differs when `strides>1`. + # Therefore, we need to manually pad the tensor. + x = layers.ZeroPadding2D( + 3, + data_format=data_format, + dtype=dtype, + name="conv1_pad", + )(x) x = layers.Conv2D( 64, 7, strides=2, - padding="same", data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name="conv1_conv", )(x) if not use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="conv1_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name="conv1_bn", )(x) x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x) - x = layers.MaxPool2D( + if use_pre_activation: + # A workaround for ResNetV2: we need -inf padding to prevent zeros + # from being the max values in the following `MaxPooling2D`. + pad_width = [[1, 1], [1, 1]] + if data_format == "channels_last": + pad_width += [[0, 0]] + else: + pad_width = [[0, 0]] + pad_width + pad_width = [[0, 0]] + pad_width + x = ops.pad(x, pad_width=pad_width, constant_values=float("-inf")) + else: + x = layers.ZeroPadding2D( + 1, data_format=data_format, dtype=dtype, name="pool1_pad" + )(x) + x = layers.MaxPooling2D( 3, strides=2, - padding="same", data_format=data_format, dtype=dtype, name="pool1_pool", )(x) + pyramid_outputs = {} for stack_index in range(num_stacks): x = apply_stack( x, @@ -179,10 +217,15 @@ def __init__( dtype=dtype, name=f"{version}_stack{stack_index}", ) + pyramid_outputs[f"P{stack_index + 2}"] = x if use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name="post_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name="post_bn", )(x) x = layers.Activation("relu", dtype=dtype, name="post_relu")(x) @@ -213,18 +256,23 @@ def __init__( self.include_rescaling = include_rescaling self.input_image_shape = input_image_shape self.pooling = pooling + self.pyramid_outputs = pyramid_outputs def get_config(self): - return { - "stackwise_num_filters": self.stackwise_num_filters, - "stackwise_num_blocks": self.stackwise_num_blocks, - "stackwise_num_strides": self.stackwise_num_strides, - "block_type": self.block_type, - "use_pre_activation": self.use_pre_activation, - "include_rescaling": self.include_rescaling, - "input_image_shape": self.input_image_shape, - "pooling": self.pooling, - } + config = super().get_config() + config.update( + { + "stackwise_num_filters": self.stackwise_num_filters, + "stackwise_num_blocks": self.stackwise_num_blocks, + "stackwise_num_strides": self.stackwise_num_strides, + "block_type": self.block_type, + "use_pre_activation": self.use_pre_activation, + "include_rescaling": self.include_rescaling, + "input_image_shape": self.input_image_shape, + "pooling": self.pooling, + } + ) + return config def apply_basic_block( @@ -269,68 +317,81 @@ def apply_basic_block( if use_pre_activation: x_preact = layers.BatchNormalization( axis=bn_axis, - epsilon=1.001e-5, + epsilon=1e-5, + momentum=0.9, dtype=dtype, - name=f"{name}_use_preactivation_bn", + name=f"{name}_pre_activation_bn", )(x) x_preact = layers.Activation( - "relu", dtype=dtype, name=f"{name}_use_preactivation_relu" + "relu", dtype=dtype, name=f"{name}_pre_activation_relu" )(x_preact) if conv_shortcut: + x = x_preact if x_preact is not None else x shortcut = layers.Conv2D( filters, 1, strides=stride, data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name=f"{name}_0_conv", - )(x_preact if x_preact is not None else x) + )(x) if not use_pre_activation: shortcut = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_0_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_0_bn", )(shortcut) else: - if not use_pre_activation or stride == 1: - shortcut = x - else: - shortcut = layers.MaxPooling2D( - 1, - strides=stride, - data_format=data_format, - dtype=dtype, - name=f"{name}_0_max_pooling", - )(x) + shortcut = x + x = x_preact if x_preact is not None else x + if stride > 1: + x = layers.ZeroPadding2D( + (kernel_size - 1) // 2, + data_format=data_format, + dtype=dtype, + name=f"{name}_1_pad", + )(x) x = layers.Conv2D( filters, kernel_size, - strides=stride if not use_pre_activation else 1, - padding="same", + strides=stride, + padding="valid" if stride > 1 else "same", data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_1_conv", - )(x_preact if x_preact is not None else x) + )(x) x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_1_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_1_bn", )(x) x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + x = layers.Conv2D( filters, kernel_size, - strides=1 if not use_pre_activation else stride, + strides=1, padding="same", data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_2_conv", )(x) - if not use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_2_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_2_bn", )(x) x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) @@ -381,79 +442,97 @@ def apply_bottleneck_block( if use_pre_activation: x_preact = layers.BatchNormalization( axis=bn_axis, - epsilon=1.001e-5, + epsilon=1e-5, + momentum=0.9, dtype=dtype, - name=f"{name}_use_preactivation_bn", + name=f"{name}_pre_activation_bn", )(x) x_preact = layers.Activation( - "relu", dtype=dtype, name=f"{name}_use_preactivation_relu" + "relu", dtype=dtype, name=f"{name}_pre_activation_relu" )(x_preact) if conv_shortcut: + x = x_preact if x_preact is not None else x shortcut = layers.Conv2D( 4 * filters, 1, strides=stride, data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name=f"{name}_0_conv", - )(x_preact if x_preact is not None else x) + )(x) if not use_pre_activation: shortcut = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_0_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_0_bn", )(shortcut) else: - if not use_pre_activation or stride == 1: - shortcut = x - else: - shortcut = layers.MaxPooling2D( - 1, - strides=stride, - data_format=data_format, - dtype=dtype, - name=f"{name}_0_max_pooling", - )(x) + shortcut = x + x = x_preact if x_preact is not None else x x = layers.Conv2D( filters, 1, - strides=stride if not use_pre_activation else 1, + strides=1, data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_1_conv", - )(x_preact if x_preact is not None else x) + )(x) x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_1_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_1_bn", )(x) x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x) + + if stride > 1: + x = layers.ZeroPadding2D( + (kernel_size - 1) // 2, + data_format=data_format, + dtype=dtype, + name=f"{name}_2_pad", + )(x) x = layers.Conv2D( filters, kernel_size, - strides=1 if not use_pre_activation else stride, - padding="same", + strides=stride, + padding="valid" if stride > 1 else "same", data_format=data_format, use_bias=False, dtype=dtype, name=f"{name}_2_conv", )(x) x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_2_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_2_bn", )(x) x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x) + x = layers.Conv2D( 4 * filters, 1, data_format=data_format, - use_bias=use_pre_activation, + use_bias=False, dtype=dtype, name=f"{name}_3_conv", )(x) - if not use_pre_activation: x = layers.BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, dtype=dtype, name=f"{name}_3_bn" + axis=bn_axis, + epsilon=1e-5, + momentum=0.9, + dtype=dtype, + name=f"{name}_3_bn", )(x) x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x]) x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x) @@ -513,32 +592,21 @@ def apply_stack( '`block_type` must be either `"basic_block"` or ' f'`"bottleneck_block"`. Received block_type={block_type}.' ) - x = block_fn( - x, - filters, - stride=stride if not use_pre_activation else 1, - conv_shortcut=first_shortcut, - use_pre_activation=use_pre_activation, - data_format=data_format, - dtype=dtype, - name=f"{name}_block1", - ) - for i in range(2, blocks): + for i in range(blocks): + if i == 0: + stride = stride + conv_shortcut = first_shortcut + else: + stride = 1 + conv_shortcut = False x = block_fn( x, filters, + stride=stride, + conv_shortcut=conv_shortcut, use_pre_activation=use_pre_activation, data_format=data_format, dtype=dtype, name=f"{name}_block{str(i)}", ) - x = block_fn( - x, - filters, - stride=1 if not use_pre_activation else stride, - use_pre_activation=use_pre_activation, - data_format=data_format, - dtype=dtype, - name=f"{name}_block{str(blocks)}", - ) return x diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py index 2113bcd131..6d3f774559 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone_test.py +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -14,6 +14,7 @@ import pytest from absl.testing import parameterized +from keras import models from keras import ops from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone @@ -29,8 +30,8 @@ def setUp(self): "input_image_shape": (None, None, 3), "pooling": "avg", } - self.input_size = (16, 16) - self.input_data = ops.ones((2, 16, 16, 3)) + self.input_size = 64 + self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) @parameterized.named_parameters( ("v1_basic", False, "basic_block"), @@ -52,6 +53,24 @@ def test_backbone_basics(self, use_pre_activation, block_type): ), ) + def test_pyramid_output_format(self): + init_kwargs = self.init_kwargs.copy() + init_kwargs.update( + {"block_type": "basic_block", "use_pre_activation": False} + ) + backbone = ResNetBackbone(**init_kwargs) + model = models.Model(backbone.inputs, backbone.pyramid_outputs) + output_data = model(self.input_data) + + self.assertIsInstance(output_data, dict) + self.assertEqual( + list(output_data.keys()), list(backbone.pyramid_outputs.keys()) + ) + self.assertEqual(list(output_data.keys()), ["P2", "P3", "P4"]) + for k, v in output_data.items(): + size = self.input_size // (2 ** int(k[1:])) + self.assertEqual(tuple(v.shape[:3]), (2, size, size)) + @parameterized.named_parameters( ("v1_basic", False, "basic_block"), ("v1_bottleneck", False, "bottleneck_block"), @@ -65,7 +84,7 @@ def test_saved_model(self, use_pre_activation, block_type): { "block_type": block_type, "use_pre_activation": use_pre_activation, - "input_image_shape": (16, 16, 3), + "input_image_shape": (None, None, 3), } ) self.run_model_saving_test( diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier.py b/keras_nlp/src/models/resnet/resnet_image_classifier.py index 02c8c78b27..815dc7fcca 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier.py @@ -28,6 +28,8 @@ class ResNetImageClassifier(ImageClassifier): activation: `None`, str or callable. The activation function to use on the `Dense` layer. Set `activation=None` to return the output logits. Defaults to `"softmax"`. + head_dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The + dtype to use for the classification head's computations and weights. To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` where `x` is a tensor and `y` is a integer from `[0, num_classes)`. @@ -92,16 +94,19 @@ def __init__( backbone, num_classes, activation="softmax", + head_dtype=None, preprocessor=None, # adding this dummy arg for saved model test # TODO: once preprocessor flow is figured out, this needs to be updated **kwargs, ): + head_dtype = head_dtype or backbone.dtype_policy + # === Layers === self.backbone = backbone self.output_dense = keras.layers.Dense( num_classes, activation=activation, - dtype=self.backbone.dtype_policy, + dtype=head_dtype, name="predictions", ) diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py index bbbda72d64..f3f63a14a1 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py @@ -53,6 +53,10 @@ def test_classifier_basics(self): expected_output_shape=(2, 2), ) + def test_head_dtype(self): + model = ResNetImageClassifier(**self.init_kwargs, head_dtype="bfloat16") + self.assertEqual(model.output_dense.compute_dtype, "bfloat16") + @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index f797bf9f18..9e3f51c43a 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -544,6 +544,10 @@ def check_format(preset): if check_file_exists(preset, SAFETENSOR_FILE) or check_file_exists( preset, SAFETENSOR_CONFIG_FILE ): + # Determine the format by parsing the config file. + config = load_config(preset, HF_CONFIG_FILE) + if "hf://timm" in preset or "architecture" in config: + return "timm" return "transformers" if not check_file_exists(preset, METADATA_FILE): diff --git a/keras_nlp/src/utils/timm/__init__.py b/keras_nlp/src/utils/timm/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/src/utils/timm/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/src/utils/timm/convert.py b/keras_nlp/src/utils/timm/convert.py new file mode 100644 index 0000000000..edfde3316b --- /dev/null +++ b/keras_nlp/src/utils/timm/convert.py @@ -0,0 +1,37 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert timm models to KerasNLP.""" + +from keras_nlp.src.utils.timm.convert_resnet import load_resnet_backbone + + +def load_timm_backbone(cls, preset, load_weights, **kwargs): + """Load a timm model config and weights as a KerasNLP backbone. + + Args: + cls (class): Keras model class. + preset (str): Preset configuration name. + load_weights (bool): Whether to load the weights. + + Returns: + backbone: Initialized Keras model backbone. + """ + if cls is None: + raise ValueError("Backbone class is None") + if cls.__name__ == "ResNetBackbone": + return load_resnet_backbone(cls, preset, load_weights, **kwargs) + raise ValueError( + f"{cls} has not been ported from the Hugging Face format yet. " + "Please check Hugging Face Hub for the Keras model. " + ) diff --git a/keras_nlp/src/utils/timm/convert_resnet.py b/keras_nlp/src/utils/timm/convert_resnet.py new file mode 100644 index 0000000000..de2224eb9e --- /dev/null +++ b/keras_nlp/src/utils/timm/convert_resnet.py @@ -0,0 +1,171 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE +from keras_nlp.src.utils.preset_utils import jax_memory_cleanup +from keras_nlp.src.utils.preset_utils import load_config +from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader + + +def convert_backbone_config(timm_config): + timm_architecture = timm_config["architecture"] + + if "resnetv2_" in timm_architecture: + use_pre_activation = True + else: + use_pre_activation = False + + if timm_architecture == "resnet18": + stackwise_num_blocks = [2, 2, 2, 2] + block_type = "basic_block" + elif timm_architecture == "resnet26": + stackwise_num_blocks = [2, 2, 2, 2] + block_type = "bottleneck_block" + elif timm_architecture == "resnet34": + stackwise_num_blocks = [3, 4, 6, 3] + block_type = "basic_block" + elif timm_architecture in ("resnet50", "resnetv2_50"): + stackwise_num_blocks = [3, 4, 6, 3] + block_type = "bottleneck_block" + elif timm_architecture in ("resnet101", "resnetv2_101"): + stackwise_num_blocks = [3, 4, 23, 3] + block_type = "bottleneck_block" + elif timm_architecture in ("resnet152", "resnetv2_152"): + stackwise_num_blocks = [3, 8, 36, 3] + block_type = "bottleneck_block" + else: + raise ValueError( + f"Currently, the architecture {timm_architecture} is not supported." + ) + + return dict( + stackwise_num_filters=[64, 128, 256, 512], + stackwise_num_blocks=stackwise_num_blocks, + stackwise_num_strides=[1, 2, 2, 2], + block_type=block_type, + use_pre_activation=use_pre_activation, + ) + + +def convert_weights(backbone, loader, timm_config): + def port_conv2d(keras_layer_name, hf_weight_prefix): + loader.port_weight( + backbone.get_layer(keras_layer_name).kernel, + hf_weight_key=f"{hf_weight_prefix}.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + + def port_batch_normalization(keras_layer_name, hf_weight_prefix): + loader.port_weight( + backbone.get_layer(keras_layer_name).gamma, + hf_weight_key=f"{hf_weight_prefix}.weight", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).beta, + hf_weight_key=f"{hf_weight_prefix}.bias", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).moving_mean, + hf_weight_key=f"{hf_weight_prefix}.running_mean", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).moving_variance, + hf_weight_key=f"{hf_weight_prefix}.running_var", + ) + + version = "v1" if not backbone.use_pre_activation else "v2" + block_type = backbone.block_type + + # Stem + if version == "v1": + port_conv2d("conv1_conv", "conv1") + port_batch_normalization("conv1_bn", "bn1") + else: + port_conv2d("conv1_conv", "stem.conv") + + # Stages + num_stacks = len(backbone.stackwise_num_filters) + for stack_index in range(num_stacks): + for block_idx in range(backbone.stackwise_num_blocks[stack_index]): + if version == "v1": + keras_name = f"v1_stack{stack_index}_block{block_idx}" + hf_name = f"layer{stack_index+1}.{block_idx}" + else: + keras_name = f"v2_stack{stack_index}_block{block_idx}" + hf_name = f"stages.{stack_index}.blocks.{block_idx}" + + if version == "v1": + if block_idx == 0 and ( + block_type == "bottleneck_block" or stack_index > 0 + ): + port_conv2d( + f"{keras_name}_0_conv", f"{hf_name}.downsample.0" + ) + port_batch_normalization( + f"{keras_name}_0_bn", f"{hf_name}.downsample.1" + ) + port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1") + port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.bn1") + port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2") + port_batch_normalization(f"{keras_name}_2_bn", f"{hf_name}.bn2") + if block_type == "bottleneck_block": + port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3") + port_batch_normalization( + f"{keras_name}_3_bn", f"{hf_name}.bn3" + ) + else: + if block_idx == 0 and ( + block_type == "bottleneck_block" or stack_index > 0 + ): + port_conv2d( + f"{keras_name}_0_conv", f"{hf_name}.downsample.conv" + ) + port_batch_normalization( + f"{keras_name}_pre_activation_bn", f"{hf_name}.norm1" + ) + port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1") + port_batch_normalization( + f"{keras_name}_1_bn", f"{hf_name}.norm2" + ) + port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2") + if block_type == "bottleneck_block": + port_batch_normalization( + f"{keras_name}_2_bn", f"{hf_name}.norm3" + ) + port_conv2d(f"{keras_name}_3_conv", f"{hf_name}.conv3") + + # Post + if version == "v2": + port_batch_normalization("post_bn", "norm") + + # Rebuild normalization layer with pretrained mean & std + mean = timm_config["pretrained_cfg"]["mean"] + std = timm_config["pretrained_cfg"]["std"] + normalization_layer = backbone.get_layer("normalization") + normalization_layer.input_mean = mean + normalization_layer.input_variance = [s**2 for s in std] + normalization_layer.build(normalization_layer._build_input_shape) + + +def load_resnet_backbone(cls, preset, load_weights, **kwargs): + timm_config = load_config(preset, HF_CONFIG_FILE) + keras_config = convert_backbone_config(timm_config) + backbone = cls(**keras_config, **kwargs) + if load_weights: + jax_memory_cleanup(backbone) + # Use prefix="" to avoid using `get_prefixed_key`. + with SafetensorLoader(preset, prefix="") as loader: + convert_weights(backbone, loader, timm_config) + return backbone diff --git a/keras_nlp/src/utils/timm/convert_resnet_test.py b/keras_nlp/src/utils/timm/convert_resnet_test.py new file mode 100644 index 0000000000..a30bee46af --- /dev/null +++ b/keras_nlp/src/utils/timm/convert_resnet_test.py @@ -0,0 +1,28 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from keras import ops + +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.tests.test_case import TestCase + + +class TimmResNetBackboneTest(TestCase): + @pytest.mark.large + def test_convert_resnet18_preset(self): + model = ResNetBackbone.from_preset("hf://timm/resnet18.a1_in1k") + outputs = model.predict(ops.ones((1, 224, 224, 3))) + self.assertEqual(outputs.shape, (1, 512)) + + # TODO: compare numerics with timm model diff --git a/keras_nlp/src/utils/transformers/safetensor_utils.py b/keras_nlp/src/utils/transformers/safetensor_utils.py index 40ef473ff3..2fbd7e1aba 100644 --- a/keras_nlp/src/utils/transformers/safetensor_utils.py +++ b/keras_nlp/src/utils/transformers/safetensor_utils.py @@ -26,7 +26,7 @@ class SafetensorLoader(contextlib.ExitStack): - def __init__(self, preset): + def __init__(self, preset, prefix=None): super().__init__() if safetensors is None: @@ -42,7 +42,7 @@ def __init__(self, preset): else: self.safetensor_config = None self.safetensor_files = {} - self.prefix = None + self.prefix = prefix def get_prefixed_key(self, hf_weight_key, dict_like): """