
Commit ad137cd

[Model] Port over CLIPVisionModel for VLMs (#5591)
1 parent 111af1f commit ad137cd

File tree: 9 files changed, +269 −21 lines

csrc/activation_kernels.cu

Lines changed: 12 additions & 0 deletions
@@ -135,6 +135,12 @@ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
   return ((T)0.5) * x * (((T)1.0) + t);
 }
 
+template <typename T>
+__device__ __forceinline__ T gelu_quick_kernel(const T& x) {
+  // x * sigmoid(1.702 * x)
+  return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x)));
+}
+
 }  // namespace vllm
 
 void gelu_new(torch::Tensor& out,  // [..., d]
@@ -148,3 +154,9 @@ void gelu_fast(torch::Tensor& out, // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
 }
+
+void gelu_quick(torch::Tensor& out,   // [..., d]
+                torch::Tensor& input) // [..., d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
+}
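
Note: the kernel evaluates quick GELU in its logistic form. A quick sanity check (not part of this commit) that x / (1 + exp(-1.702 * x)) is the same function as the reference definition x * sigmoid(1.702 * x):

    import torch

    x = torch.randn(1024, dtype=torch.float32)
    logistic_form = x / (1.0 + torch.exp(-1.702 * x))  # what gelu_quick_kernel computes
    sigmoid_form = x * torch.sigmoid(1.702 * x)        # quick-GELU definition
    assert torch.allclose(logistic_form, sigmoid_form, atol=1e-6)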

csrc/ops.h

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,8 @@ void gelu_new(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_fast(torch::Tensor& out, torch::Tensor& input);
 
+void gelu_quick(torch::Tensor& out, torch::Tensor& input);
+
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,

csrc/torch_bindings.cpp

Lines changed: 4 additions & 0 deletions
@@ -68,6 +68,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
 
+  // Quick GELU implementation.
+  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
+
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(

vllm/_custom_ops.py

Lines changed: 4 additions & 0 deletions
@@ -66,6 +66,10 @@ def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
     torch.ops._C.gelu_new(out, x)
 
 
+def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+    torch.ops._C.gelu_quick(out, x)
+
+
 # page attention ops
 def paged_attention_v1(
     out: torch.Tensor,
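
Note: the wrapper follows the same out-parameter convention as gelu_new and gelu_fast: the caller allocates the output buffer and the op writes into it. A hedged usage sketch, assuming a CUDA build of vLLM whose _C extension includes the new gelu_quick op:

    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
    out = torch.empty_like(x)  # caller-allocated output buffer
    ops.gelu_quick(out, x)     # writes into `out` in place, returns None
    torch.testing.assert_close(out, x * torch.sigmoid(1.702 * x),
                               rtol=1e-2, atol=1e-3)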

vllm/model_executor/layers/activation.py

Lines changed: 16 additions & 0 deletions
@@ -141,6 +141,21 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         return out
 
 
+class QuickGELU(CustomOp):
+
+    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        return x * torch.sigmoid(1.702 * x)
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
+        out = torch.empty_like(x)
+        ops.gelu_quick(out, x)
+        return out
+
+
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
 
@@ -189,6 +204,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
     "gelu_new": NewGELU(),
     "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
     "relu": nn.ReLU(),
+    "quick_gelu": QuickGELU(),
 }
 
 
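
Note: with the registry entry above, model code that asks for the Hugging Face activation name "quick_gelu" (as the CLIP vision config does) can resolve it through get_act_fn. A minimal sketch, assuming get_act_fn dispatches on this registry; the native path is called directly so no CUDA build is needed:

    import torch
    from vllm.model_executor.layers.activation import QuickGELU, get_act_fn

    act = get_act_fn("quick_gelu")  # assumed to return the QuickGELU instance
    assert isinstance(act, QuickGELU)
    x = torch.randn(4, 16)
    y = act.forward_native(x)       # PyTorch-native path from the diff above
    assert torch.allclose(y, x * torch.sigmoid(1.702 * x))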

vllm/model_executor/models/clip.py

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
+"""Minimal implementation of CLIPVisionModel intended to be only used
+within a vision language model."""
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from transformers import CLIPVisionConfig
+from transformers.models.clip.modeling_clip import CLIPAttention
+
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+
+
+def get_clip_num_patches(image_size: int, patch_size: int) -> int:
+    assert image_size % patch_size == 0
+    return (image_size // patch_size)**2
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
+class CLIPVisionEmbeddings(nn.Module):
+
+    def __init__(self, config: CLIPVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
+        )
+
+        self.num_patches = get_clip_num_patches(self.image_size,
+                                                self.patch_size)
+        self.num_positions = self.num_patches + 1
+        self.position_embedding = nn.Embedding(self.num_positions,
+                                               self.embed_dim)
+        self.register_buffer("position_ids",
+                             torch.arange(self.num_positions).expand((1, -1)),
+                             persistent=False)
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        batch_size = pixel_values.shape[0]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(
+            dtype=target_dtype))  # shape = [*, width, grid, grid]
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+        embeddings = embeddings + self.position_embedding(self.position_ids)
+
+        return embeddings
+
+
+class CLIPMLP(nn.Module):
+
+    def __init__(self,
+                 config: CLIPVisionConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.config = config
+        self.activation_fn = get_act_fn(config.hidden_act)
+        self.fc1 = ColumnParallelLinear(config.hidden_size,
+                                        config.intermediate_size,
+                                        bias=True,
+                                        quant_config=quant_config)
+        self.fc2 = RowParallelLinear(config.intermediate_size,
+                                     config.hidden_size,
+                                     bias=True,
+                                     quant_config=quant_config)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states, _ = self.fc2(hidden_states)
+
+        return hidden_states
+
+
+class CLIPEncoderLayer(nn.Module):
+
+    def __init__(self,
+                 config: CLIPVisionConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+
+        self.self_attn = CLIPAttention(config)
+        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
+                                        eps=config.layer_norm_eps)
+        self.mlp = CLIPMLP(config, quant_config=quant_config)
+        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
+                                        eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
+
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class CLIPEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self
+    attention layers. Each layer is a [`CLIPEncoderLayer`].
+
+    Args:
+        config: CLIPConfig
+    """
+
+    def __init__(self,
+                 config: CLIPVisionConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([
+            CLIPEncoderLayer(config=config, quant_config=quant_config)
+            for _ in range(config.num_hidden_layers)
+        ])
+
+    def forward(self,
+                inputs_embeds: torch.Tensor,
+                vision_feature_layer: int = -1):
+
+        # Encoder forward pass only up to the required layer
+        num_layer = len(self.layers) + vision_feature_layer + 1
+        hidden_states = inputs_embeds
+        for encoder_layer in self.layers[:num_layer]:
+            hidden_states = encoder_layer(hidden_states)
+
+        return hidden_states
+
+
+class CLIPVisionTransformer(nn.Module):
+
+    def __init__(self,
+                 config: CLIPVisionConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.config = config
+        embed_dim = config.hidden_size
+
+        self.embeddings = CLIPVisionEmbeddings(config)
+
+        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
+        # the original transformers code and name of the model weights.
+        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.encoder = CLIPEncoder(config=config, quant_config=quant_config)
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        vision_feature_layer: int = -1,
+    ) -> torch.Tensor:
+
+        hidden_states = self.embeddings(pixel_values)
+        hidden_states = self.pre_layrnorm(hidden_states)
+        hidden_states = self.encoder(inputs_embeds=hidden_states,
+                                     vision_feature_layer=vision_feature_layer)
+
+        return hidden_states
+
+
+class CLIPVisionModel(nn.Module):
+
+    config_class = CLIPVisionConfig
+    main_input_name = "pixel_values"
+
+    def __init__(self,
+                 config: CLIPVisionConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.vision_model = CLIPVisionTransformer(config=config,
+                                                  quant_config=quant_config)
+
+    def forward(self,
+                pixel_values: Optional[torch.Tensor] = None,
+                vision_feature_layer: int = -1):
+
+        return self.vision_model(pixel_values=pixel_values,
+                                 vision_feature_layer=vision_feature_layer)
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
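
Note: the main behavioral difference from the Hugging Face module is in CLIPEncoder.forward: instead of returning every hidden state, the encoder only runs up to the layer selected by vision_feature_layer (a negative index, as in LLaVA configs). A small sketch of that indexing arithmetic, assuming a 24-layer vision tower such as CLIP ViT-L/14:

    num_hidden_layers = 24  # assumed tower depth, for illustration only
    for vision_feature_layer in (-1, -2):
        num_layer = num_hidden_layers + vision_feature_layer + 1
        print(vision_feature_layer, "-> run", num_layer, "layers")
    # -1 -> run all 24 layers (full encoder)
    # -2 -> run 23 layers, matching hidden_states[-2] of the HF model
    #       (the default vision_feature_layer in LLaVA configs)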

vllm/model_executor/models/llava.py

Lines changed: 9 additions & 8 deletions
@@ -2,9 +2,7 @@
 
 import torch
 import torch.nn as nn
-# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
-# transformers' impl.
-from transformers import CLIPVisionModel, LlavaConfig
+from transformers import LlavaConfig
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VisionLanguageConfig
@@ -15,6 +13,7 @@
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -189,12 +188,11 @@ def _select_image_features(self, image_features: torch.Tensor, *,
 
     def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                   pixel_values: torch.Tensor) -> torch.Tensor:
-        # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
-        image_outputs = vision_tower(pixel_values.to(vision_tower.device),
-                                     output_hidden_states=True)
 
-        image_features = image_outputs.hidden_states[
-            self.config.vision_feature_layer]
+        # NOTE: we skip the step to select the vision feature layer since
+        # this is already done inside the vision tower
+        image_features = vision_tower(pixel_values.to(vision_tower.device),
+                                      self.config.vision_feature_layer)
 
         return self._select_image_features(
            image_features,
@@ -317,6 +315,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+            # post_layernorm is not needed in CLIPVisionModel
+            if "vision_model.post_layernorm" in name:
+                continue
             for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                 if key_to_modify in name:
                     name = name.replace(key_to_modify, new_key)
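
Note: the ported CLIPVisionTransformer deliberately omits post_layernorm (LLaVA takes features from an intermediate layer, before the final layer norm), so load_weights now drops those checkpoint entries instead of looking for a matching parameter. A hypothetical illustration of the filter; the checkpoint names below are made up:

    ckpt_names = [
        "vision_tower.vision_model.post_layernorm.weight",  # skipped
        "vision_tower.vision_model.post_layernorm.bias",    # skipped
        "vision_tower.vision_model.encoder.layers.0.mlp.fc1.weight",
    ]
    kept = [n for n in ckpt_names if "vision_model.post_layernorm" not in n]
    print(kept)  # only the encoder-layer weight remains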

vllm/model_executor/models/llava_next.py

Lines changed: 10 additions & 9 deletions
@@ -4,9 +4,7 @@
 import torch
 import torch.nn as nn
 from PIL import Image
-# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
-# transformers' impl.
-from transformers import CLIPVisionModel, LlavaNextConfig
+from transformers import LlavaNextConfig
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
 from typing_extensions import NotRequired
@@ -20,6 +18,7 @@
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
@@ -121,7 +120,7 @@ def __init__(self,
 
         if self.vision_language_config.image_input_type == (
                 VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
-            self.vision_tower = CLIPVisionModel(config.vision_config)
+            self.vision_tower = CLIPVisionModel(config=config.vision_config)
         else:
             raise TypeError("Image features are not supported by LLaVA-NeXT")
 
@@ -219,12 +218,11 @@ def _select_image_features(self, image_features: torch.Tensor, *,
 
     def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                   pixel_values: torch.Tensor) -> torch.Tensor:
-        # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
-        image_outputs = vision_tower(pixel_values.to(vision_tower.device),
-                                     output_hidden_states=True)
 
-        image_features = image_outputs.hidden_states[
-            self.config.vision_feature_layer]
+        # NOTE: we skip the step to select the vision feature layer since
+        # this is already done inside the vision tower
+        image_features = vision_tower(pixel_values.to(vision_tower.device),
+                                      self.config.vision_feature_layer)
 
         return self._select_image_features(
            image_features,
@@ -430,6 +428,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+            # post_layernorm is not needed in CLIPVisionModel
+            if "vision_model.post_layernorm" in name:
+                continue
             for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                 if key_to_modify in name:
                     name = name.replace(key_to_modify, new_key)
