|
| 1 | +"""Minimal implementation of CLIPVisionModel intended to be only used |
| 2 | +within a vision language model.""" |
| 3 | +from typing import Optional, Tuple |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.nn as nn |
| 7 | +from transformers import CLIPVisionConfig |
| 8 | +from transformers.models.clip.modeling_clip import CLIPAttention |
| 9 | + |
| 10 | +from vllm.model_executor.layers.activation import get_act_fn |
| 11 | +from vllm.model_executor.layers.linear import (ColumnParallelLinear, |
| 12 | + RowParallelLinear) |
| 13 | +from vllm.model_executor.layers.quantization.base_config import ( |
| 14 | + QuantizationConfig) |
| 15 | + |
| 16 | + |
| 17 | +def get_clip_num_patches(image_size: int, patch_size: int) -> int: |
| 18 | + assert image_size % patch_size == 0 |
| 19 | + return (image_size // patch_size)**2 |
| 20 | + |
| 21 | + |
| 22 | +# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa |
| 23 | +class CLIPVisionEmbeddings(nn.Module): |
| 24 | + |
| 25 | + def __init__(self, config: CLIPVisionConfig): |
| 26 | + super().__init__() |
| 27 | + self.config = config |
| 28 | + self.embed_dim = config.hidden_size |
| 29 | + self.image_size = config.image_size |
| 30 | + self.patch_size = config.patch_size |
| 31 | + |
| 32 | + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) |
| 33 | + |
| 34 | + self.patch_embedding = nn.Conv2d( |
| 35 | + in_channels=config.num_channels, |
| 36 | + out_channels=self.embed_dim, |
| 37 | + kernel_size=self.patch_size, |
| 38 | + stride=self.patch_size, |
| 39 | + bias=False, |
| 40 | + ) |
| 41 | + |
| 42 | + self.num_patches = get_clip_num_patches(self.image_size, |
| 43 | + self.patch_size) |
| 44 | + self.num_positions = self.num_patches + 1 |
| 45 | + self.position_embedding = nn.Embedding(self.num_positions, |
| 46 | + self.embed_dim) |
| 47 | + self.register_buffer("position_ids", |
| 48 | + torch.arange(self.num_positions).expand((1, -1)), |
| 49 | + persistent=False) |
| 50 | + |
| 51 | + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: |
| 52 | + batch_size = pixel_values.shape[0] |
| 53 | + target_dtype = self.patch_embedding.weight.dtype |
| 54 | + patch_embeds = self.patch_embedding(pixel_values.to( |
| 55 | + dtype=target_dtype)) # shape = [*, width, grid, grid] |
| 56 | + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) |
| 57 | + |
| 58 | + class_embeds = self.class_embedding.expand(batch_size, 1, -1) |
| 59 | + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) |
| 60 | + embeddings = embeddings + self.position_embedding(self.position_ids) |
| 61 | + |
| 62 | + return embeddings |
| 63 | + |
| 64 | + |
| 65 | +class CLIPMLP(nn.Module): |
| 66 | + |
| 67 | + def __init__(self, |
| 68 | + config: CLIPVisionConfig, |
| 69 | + quant_config: Optional[QuantizationConfig] = None): |
| 70 | + super().__init__() |
| 71 | + self.config = config |
| 72 | + self.activation_fn = get_act_fn(config.hidden_act) |
| 73 | + self.fc1 = ColumnParallelLinear(config.hidden_size, |
| 74 | + config.intermediate_size, |
| 75 | + bias=True, |
| 76 | + quant_config=quant_config) |
| 77 | + self.fc2 = RowParallelLinear(config.intermediate_size, |
| 78 | + config.hidden_size, |
| 79 | + bias=True, |
| 80 | + quant_config=quant_config) |
| 81 | + |
| 82 | + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| 83 | + hidden_states, _ = self.fc1(hidden_states) |
| 84 | + hidden_states = self.activation_fn(hidden_states) |
| 85 | + hidden_states, _ = self.fc2(hidden_states) |
| 86 | + |
| 87 | + return hidden_states |
| 88 | + |
| 89 | + |
| 90 | +class CLIPEncoderLayer(nn.Module): |
| 91 | + |
| 92 | + def __init__(self, |
| 93 | + config: CLIPVisionConfig, |
| 94 | + quant_config: Optional[QuantizationConfig] = None): |
| 95 | + super().__init__() |
| 96 | + |
| 97 | + self.self_attn = CLIPAttention(config) |
| 98 | + self.layer_norm1 = nn.LayerNorm(config.hidden_size, |
| 99 | + eps=config.layer_norm_eps) |
| 100 | + self.mlp = CLIPMLP(config, quant_config=quant_config) |
| 101 | + self.layer_norm2 = nn.LayerNorm(config.hidden_size, |
| 102 | + eps=config.layer_norm_eps) |
| 103 | + |
| 104 | + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: |
| 105 | + |
| 106 | + residual = hidden_states |
| 107 | + |
| 108 | + hidden_states = self.layer_norm1(hidden_states) |
| 109 | + hidden_states, _ = self.self_attn(hidden_states=hidden_states) |
| 110 | + hidden_states = residual + hidden_states |
| 111 | + |
| 112 | + residual = hidden_states |
| 113 | + hidden_states = self.layer_norm2(hidden_states) |
| 114 | + hidden_states = self.mlp(hidden_states) |
| 115 | + hidden_states = residual + hidden_states |
| 116 | + |
| 117 | + return hidden_states |
| 118 | + |
| 119 | + |
| 120 | +class CLIPEncoder(nn.Module): |
| 121 | + """ |
| 122 | + Transformer encoder consisting of `config.num_hidden_layers` self |
| 123 | + attention layers. Each layer is a [`CLIPEncoderLayer`]. |
| 124 | +
|
| 125 | + Args: |
| 126 | + config: CLIPConfig |
| 127 | + """ |
| 128 | + |
| 129 | + def __init__(self, |
| 130 | + config: CLIPVisionConfig, |
| 131 | + quant_config: Optional[QuantizationConfig] = None): |
| 132 | + super().__init__() |
| 133 | + self.config = config |
| 134 | + self.layers = nn.ModuleList([ |
| 135 | + CLIPEncoderLayer(config=config, quant_config=quant_config) |
| 136 | + for _ in range(config.num_hidden_layers) |
| 137 | + ]) |
| 138 | + |
| 139 | + def forward(self, |
| 140 | + inputs_embeds: torch.Tensor, |
| 141 | + vision_feature_layer: int = -1): |
| 142 | + |
| 143 | + # Encoder forward pass only up to the required layer |
| 144 | + num_layer = len(self.layers) + vision_feature_layer + 1 |
| 145 | + hidden_states = inputs_embeds |
| 146 | + for encoder_layer in self.layers[:num_layer]: |
| 147 | + hidden_states = encoder_layer(hidden_states) |
| 148 | + |
| 149 | + return hidden_states |
| 150 | + |
| 151 | + |
| 152 | +class CLIPVisionTransformer(nn.Module): |
| 153 | + |
| 154 | + def __init__(self, |
| 155 | + config: CLIPVisionConfig, |
| 156 | + quant_config: Optional[QuantizationConfig] = None): |
| 157 | + super().__init__() |
| 158 | + self.config = config |
| 159 | + embed_dim = config.hidden_size |
| 160 | + |
| 161 | + self.embeddings = CLIPVisionEmbeddings(config) |
| 162 | + |
| 163 | + # NOTE: This typo of "layrnorm" is not fixed on purpose to match |
| 164 | + # the original transformers code and name of the model weights. |
| 165 | + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) |
| 166 | + self.encoder = CLIPEncoder(config=config, quant_config=quant_config) |
| 167 | + |
| 168 | + def forward( |
| 169 | + self, |
| 170 | + pixel_values: torch.Tensor, |
| 171 | + vision_feature_layer: int = -1, |
| 172 | + ) -> torch.Tensor: |
| 173 | + |
| 174 | + hidden_states = self.embeddings(pixel_values) |
| 175 | + hidden_states = self.pre_layrnorm(hidden_states) |
| 176 | + hidden_states = self.encoder(inputs_embeds=hidden_states, |
| 177 | + vision_feature_layer=vision_feature_layer) |
| 178 | + |
| 179 | + return hidden_states |
| 180 | + |
| 181 | + |
| 182 | +class CLIPVisionModel(nn.Module): |
| 183 | + |
| 184 | + config_class = CLIPVisionConfig |
| 185 | + main_input_name = "pixel_values" |
| 186 | + |
| 187 | + def __init__(self, |
| 188 | + config: CLIPVisionConfig, |
| 189 | + quant_config: Optional[QuantizationConfig] = None): |
| 190 | + super().__init__() |
| 191 | + self.vision_model = CLIPVisionTransformer(config=config, |
| 192 | + quant_config=quant_config) |
| 193 | + |
| 194 | + def forward(self, |
| 195 | + pixel_values: Optional[torch.Tensor] = None, |
| 196 | + vision_feature_layer: int = -1): |
| 197 | + |
| 198 | + return self.vision_model(pixel_values=pixel_values, |
| 199 | + vision_feature_layer=vision_feature_layer) |
| 200 | + |
| 201 | + @property |
| 202 | + def device(self): |
| 203 | + return next(self.parameters()).device |
0 commit comments