3 changes: 2 additions & 1 deletion setup.py
@@ -64,6 +64,7 @@ def _read_requirements(filename: str) -> list[str]:
extras_require={},
entry_points={
"vllm.platform_plugins": ["hpu = vllm_gaudi:register"],
"vllm.general_plugins": ["hpu_custom_ops = vllm_gaudi:register_ops"],
"vllm.general_plugins":
["hpu_custom_ops = vllm_gaudi:register_ops", "hpu_custom_models = vllm_gaudi:register_models"],
},
)
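
For context, a minimal sketch of how such entry points are typically discovered and invoked (assumption: vLLM loads its general plugins through standard importlib.metadata entry points; the exact loader internals are not part of this diff):

    from importlib.metadata import entry_points

    # Iterate over every plugin registered under the "vllm.general_plugins" group
    # and call it; for vllm-gaudi this would invoke register_ops() and
    # register_models() as declared in setup.py above. (Python 3.10+ selection API.)
    for ep in entry_points(group="vllm.general_plugins"):
        register_fn = ep.load()
        register_fn()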
6 changes: 6 additions & 0 deletions vllm_gaudi/__init__.py
@@ -20,3 +20,9 @@ def register_ops():
import vllm_gaudi.ops.hpu_gptq # noqa: F401
import vllm_gaudi.ops.hpu_awq # noqa: F401
import vllm_gaudi.ops.hpu_multihead_attn # noqa: F401


def register_models():
import vllm_gaudi.models.utils # noqa: F401
from .models import register_model
register_model()
10 changes: 7 additions & 3 deletions vllm_gaudi/attention/backends/hpu_attn.py
@@ -570,9 +570,13 @@ def forward(

common_args = self.common_attention_args(block_list, key_cache, value_cache, attn_metadata.block_size)

-if self.sliding_window and hasattr(attn_metadata,
-                                   'window_attn_bias') and attn_metadata.window_attn_bias is not None:
-    attn_bias = attn_metadata.window_attn_bias
+if self.sliding_window:
+    if hasattr(attn_metadata, 'window_attn_bias') and attn_metadata.window_attn_bias is not None:
+        attn_bias = attn_metadata.window_attn_bias
+    else:
+        attn_bias = None
+    window_size = (self.sliding_window, 0)
+    common_args['window_size'] = window_size

out = ops.prompt_attention(impl=self.prefill_impl,
query=query.view(query_shape),
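As a rough illustration of the window_size = (self.sliding_window, 0) convention used above, i.e. a lookback of N tokens and no lookahead, here is the equivalent boolean attention mask. This is only a semantic sketch, not the HPU fused kernel:

    import torch

    def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
        # True where query position i may attend to key position j:
        # causal (j <= i) and within the lookback window (i - j < window).
        i = torch.arange(seq_len).unsqueeze(1)
        j = torch.arange(seq_len).unsqueeze(0)
        return (j <= i) & (i - j < window)

    print(sliding_window_mask(6, 3).int())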
25 changes: 24 additions & 1 deletion vllm_gaudi/extension/bucketing/common.py
@@ -59,6 +59,18 @@ def initialize(self, max_num_seqs, max_num_prefill_seqs, block_size, max_num_bat
self.fallback_seq_base_step = 32
self.fallback_blocks_base_step = 32

self.use_sliding_window = get_config().PT_HPU_SDPA_QKV_SLICE_MODE_FWD
if self.use_sliding_window:
self.slice_size = get_config().PT_HPU_SDPA_BC_FACTOR if \
get_config().PT_HPU_SDPA_BC_FACTOR is not None else 1024
self.slice_thld = get_config().VLLM_FUSEDSDPA_SLIDE_THLD if \
get_config().VLLM_FUSEDSDPA_SLIDE_THLD is not None else 8192

msg = (
f"use_sliding_window {self.use_sliding_window}, slice_size {self.slice_size}, threshold {self.slice_thld}"
)
logger().info(msg)

### GENERATE BUCKETS FUNCTIONS ###

def get_bucketing_strategy(self):
@@ -120,6 +132,13 @@ def generate_prompt_buckets(self):
self.max_num_seqs, self.max_num_prefill_seqs,
self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks)
self.log_generate_info(True)
if self.use_sliding_window:
self.prompt_buckets = [
t for t in self.prompt_buckets
if t[2] != 0 or (t[2] == 0 and (t[1] < self.slice_thld or
(t[1] >= self.slice_thld and t[1] % self.slice_size == 0)))
]
self.log_generate_info(True)
else:
logger().info("Bucketing is off - skipping prompt buckets generation")
self.prompt_buckets = []
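
A toy illustration of the filter in the hunk above, assuming the (batch_size, seq_len, ctx) bucket layout used by this bucketing code and the defaults slice_thld=8192, slice_size=1024: prompt buckets with ctx == 0 are kept only if their sequence length is below the threshold or is a multiple of the slice size (the original predicate simplifies to the one below):

    slice_thld, slice_size = 8192, 1024
    buckets = [(1, 4096, 0), (1, 8192, 0), (1, 8320, 0), (1, 9216, 0), (1, 8320, 4)]
    kept = [b for b in buckets
            if b[2] != 0 or b[1] < slice_thld or b[1] % slice_size == 0]
    print(kept)  # drops (1, 8320, 0): >= threshold and not a multiple of 1024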
@@ -164,7 +183,11 @@ def log_generate_info(self, is_prompt=False):
def generate_fallback_bucket(self, batch_size, seq_len, ctx):
assert self.max_num_batched_tokens is not None
new_batch_size = calc_fallback_value(batch_size, self.fallback_bs_base_step)
-new_seq_len = min(calc_fallback_value(seq_len, self.fallback_seq_base_step), self.max_num_batched_tokens)
+if self.use_sliding_window and seq_len >= self.slice_thld:
+    new_seq_len = math.ceil(seq_len / self.slice_size) * self.slice_size
+else:
+    new_seq_len = min(calc_fallback_value(seq_len, self.fallback_seq_base_step), self.max_num_batched_tokens)

if self.num_hpu_blocks is None:
new_ctx = 0
else:
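The fallback path above rounds long prompts up to the slicing granularity instead of using the generic fallback step. A quick numeric check with the default slice_size=1024 and slice_thld=8192:

    import math

    slice_size, slice_thld = 1024, 8192
    for seq_len in (4096, 8192, 8500):
        if seq_len >= slice_thld:
            print(seq_len, '->', math.ceil(seq_len / slice_size) * slice_size)
        else:
            print(seq_len, '-> generic fallback bucketing')
    # 8500 -> 9216 (next multiple of 1024); 4096 stays on the generic path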
158 changes: 158 additions & 0 deletions vllm_gaudi/extension/bucketing/vision.py
@@ -0,0 +1,158 @@
import os

import torch
from vllm.logger import init_logger

logger = init_logger(__name__)

MULTIMODAL_CONFIG = {
# Batch-based models
'gemma-3': {
'is_batch_based': True,
'buckets': [1, 2, 4, 8]
},

# Pixel-based models
'ovis2.5': {
'is_batch_based': False,
'buckets': [784, 1600, 3136, 4096, 6400, 7744, 9216, 12544]
}
}


class HPUVisionBucketManager:
'''
This class is used to bucket image tokens
'''

def __init__(self, model_name, is_batch_based=None):
config = self._get_multimodal_config(model_name)

self.is_batch_based = is_batch_based if is_batch_based is not None else config['is_batch_based']

envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "")
if envvar == 'None':
self.multimodal_buckets = None
else:
if envvar == "":
multimodal_buckets = config['buckets']
else:
multimodal_buckets = [int(x) for x in envvar.split(',')]
self.multimodal_buckets = self._process_buckets(multimodal_buckets)

def _get_multimodal_config(self, model_name):
"""Get configuration for model"""
model_name_lower = model_name.lower()

# Find matching config
for key, config in MULTIMODAL_CONFIG.items():
if key.replace('-', '').replace('.', '') in model_name_lower.replace('-', '').replace('.', ''):
return config

# Default config
logger.info("MultiModal bucket config file for {model_name} not found.")
return {'is_batch_based': True, 'buckets': [1, 2, 4, 8]}

def _process_buckets(self, buckets):
# TODO: if buckets have constraints (e.g. batch buckets must be aligned to n), add the assert check here.

return sorted(buckets)

def get_multimodal_bucket(self, curr_num_image_patches):
if self.multimodal_buckets is not None:
for mm_bucket in self.multimodal_buckets:
if curr_num_image_patches <= mm_bucket:
return mm_bucket
return curr_num_image_patches
else:
return 0

def find_factor(self, desired_patches, orig):
for i in range(orig + 1, desired_patches + 1):
if desired_patches % i == 0:
if i % 2 != 0:
continue
else:
return i
return None

def find_padding(self, h_orig, w_orig, desired_patches):
best_pad_h, best_pad_w = 0, 0
if desired_patches % h_orig == 0:
best_pad_h = 0
w_factor = desired_patches // h_orig
best_pad_w = w_factor - w_orig if (w_factor > w_orig and w_factor % 2 == 0) else 0
elif desired_patches % w_orig == 0:
best_pad_w = 0
h_factor = desired_patches // w_orig
best_pad_h = h_factor - h_orig if (h_factor > h_orig and h_factor % 2 == 0) else 0
elif desired_patches % h_orig != 0 and desired_patches % w_orig != 0:
if h_orig > w_orig:
w_factor = self.find_factor(desired_patches, w_orig)
if w_factor is not None:
best_pad_w = w_factor - w_orig
h_factor = desired_patches // w_factor
if h_factor > h_orig:
best_pad_h = h_factor - h_orig
else:
h_factor = self.find_factor(desired_patches, h_orig)
if h_factor is not None:
best_pad_h = h_factor - h_orig
w_factor = desired_patches // h_factor
if w_factor > w_orig:
best_pad_w = w_factor - w_orig

if (best_pad_h + h_orig) * (best_pad_w + w_orig) != desired_patches:
best_pad_h, best_pad_w = 0, 0

return best_pad_h, best_pad_w
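
A worked example of the search above (values chosen for illustration; assumes the ovis2.5 config and no VLLM_MULTIMODAL_BUCKETS override):

    mgr = HPUVisionBucketManager('ovis2.5', is_batch_based=False)
    print(mgr.find_padding(28, 26, 784))
    # -> (0, 2): 784 % 28 == 0, so w_factor = 784 // 28 = 28 and only the width
    #    is padded (26 -> 28), giving exactly 28 * 28 = 784 patches.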

def pad_multimodal_data(self, pixel_values, image_grid_thw):

desired_number_of_pixels = self.get_multimodal_bucket(pixel_values.shape[0])
padding_len = desired_number_of_pixels - pixel_values.shape[0]
if padding_len <= 0:
return pixel_values, image_grid_thw

logger_msg = "Padding current number pixel " \
+ str(pixel_values.shape[0]) \
+ " to " \
+ str(desired_number_of_pixels)
logger.info(logger_msg)

h_orig, w_orig = image_grid_thw[0, 1].item(), image_grid_thw[0, 2].item()
pad_h, pad_w = self.find_padding(h_orig, w_orig, desired_number_of_pixels)
if pad_h == 0 and pad_w == 0:
return pixel_values, image_grid_thw

constant_value = -100
pixel_values = torch.cat([
pixel_values,
torch.ones((padding_len, pixel_values.shape[1]), device=pixel_values.device) * constant_value
])

image_grid_thw = torch.tensor([[1, h_orig + pad_h, w_orig + pad_w]],
device=image_grid_thw.device,
dtype=image_grid_thw.dtype)

assert image_grid_thw.prod(-1).sum() == desired_number_of_pixels
return pixel_values, image_grid_thw

def greedy_plan(self, batchsize, available_batchsizes):
# sort descending
available_batchsizes_sorted = sorted(available_batchsizes, key=lambda x: -x)
idx = 0
left_to_process = batchsize
result = []
while (left_to_process > 0 and idx < len(available_batchsizes_sorted)):
if available_batchsizes_sorted[idx] <= left_to_process:
result += [available_batchsizes_sorted[idx]]
left_to_process -= available_batchsizes_sorted[idx]
else:
idx += 1
if left_to_process > 0:
result += [available_batchsizes_sorted[-1]] # this will be padded
return result

def __repr__(self):
return str(self.multimodal_buckets)
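
A short usage sketch of the manager above (model names are just the keys from MULTIMODAL_CONFIG; assumes VLLM_MULTIMODAL_BUCKETS is unset):

    pixel_mgr = HPUVisionBucketManager('ovis2.5', is_batch_based=False)
    print(pixel_mgr.get_multimodal_bucket(900))      # -> 1600, smallest bucket >= 900
    print(pixel_mgr.get_multimodal_bucket(20000))    # -> 20000, larger than every bucket

    batch_mgr = HPUVisionBucketManager('gemma-3', is_batch_based=True)
    print(batch_mgr.greedy_plan(11, batch_mgr.multimodal_buckets))   # -> [8, 2, 1]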
5 changes: 5 additions & 0 deletions vllm_gaudi/extension/features.py
@@ -38,6 +38,11 @@ def get_user_flags():
# Non-vllm flags that are also important to print
Env('EXPERIMENTAL_WEIGHT_SHARING', str),
Env('PT_HPU_WEIGHT_SHARING', str),

# Sliding window flags
Env('PT_HPU_SDPA_QKV_SLICE_MODE_FWD', boolean),
Env('PT_HPU_SDPA_BC_FACTOR', int),
Env('VLLM_FUSEDSDPA_SLIDE_THLD', int),
]
return to_dict(flags)

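These flags feed the sliding-window bucketing config read in vllm_gaudi/extension/bucketing/common.py above. A minimal sketch of setting them from Python before the engine starts (the numeric values are the defaults the code falls back to, and '1' as the enable value is an assumption):

    import os

    os.environ['PT_HPU_SDPA_QKV_SLICE_MODE_FWD'] = '1'   # enable the QKV slicing / sliding-window path
    os.environ['PT_HPU_SDPA_BC_FACTOR'] = '1024'         # slice size
    os.environ['VLLM_FUSEDSDPA_SLIDE_THLD'] = '8192'     # seq-len threshold for slicing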
12 changes: 10 additions & 2 deletions vllm_gaudi/extension/ops.py
@@ -302,10 +302,12 @@ def _fsdpa_prompt_attention(query: torch.Tensor,
is_causal: bool,
attn_bias: Optional[torch.Tensor] = None,
valid_seq_lengths: Optional[torch.Tensor] = None,
window_size: Optional[int] = None,
**ignored_args) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
padding_side = 'right'
if get_config().fp32_softmax:
softmax_mode = 'fp32'
else:
@@ -317,8 +319,14 @@ def _fsdpa_prompt_attention(query: torch.Tensor,
# TODO: causal + attn_bias is not yet supported
is_causal = False
valid_seq_lengths = None
-attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode,
-                        valid_seq_lengths, 'right')
+
+args = [
+    query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths,
+    padding_side
+]
+args += [window_size] if window_size else []
+attn_weights = fsdpa_op(*args)

attn_weights = attn_weights.transpose(1, 2)
return attn_weights

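Both this change and the utils.py change below follow the same pattern: the new window_size argument is appended only when it is set, so the kernel keeps receiving its original positional signature otherwise. A tiny stand-alone illustration (fake_kernel is hypothetical):

    def fake_kernel(*args):
        return len(args)

    base_args = ['q', 'k', 'v', None, 0.0, False, 1.0, 'None', False, None, 'right']
    window_size = (128, 0)   # (lookback, lookahead)

    print(fake_kernel(*(base_args + ([window_size] if window_size else []))))  # 12
    print(fake_kernel(*base_args))                                             # 11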
21 changes: 8 additions & 13 deletions vllm_gaudi/extension/utils.py
@@ -148,20 +148,15 @@ def forward(
recompute_mode,
valid_sequence_lengths,
padding_side="left",
+window_size=None,
):
-return self._hpu_kernel_fsdpa.apply(
-    query,
-    key,
-    value,
-    attn_mask,
-    dropout_p,
-    is_causal,
-    scale,
-    softmax_mode,
-    recompute_mode,
-    valid_sequence_lengths,
-    padding_side,
-)
+if window_size is not None:
+    return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
+                                        recompute_mode, valid_sequence_lengths, padding_side, False, False,
+                                        window_size)
+else:
+    return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
+                                        recompute_mode, valid_sequence_lengths, padding_side)


def pad_list(input, target_len, val_generator):
9 changes: 9 additions & 0 deletions vllm_gaudi/models/__init__.py
@@ -0,0 +1,9 @@
from vllm.model_executor.models.registry import ModelRegistry


def register_model():
from vllm_gaudi.models.gemma3_mm import HpuGemma3ForConditionalGeneration # noqa: F401

ModelRegistry.register_model(
"Gemma3ForConditionalGeneration", # Original architecture identifier in vLLM
"vllm_gaudi.models.gemma3_mm:HpuGemma3ForConditionalGeneration")
57 changes: 57 additions & 0 deletions vllm_gaudi/models/gemma3_mm.py
@@ -0,0 +1,57 @@
import torch
from vllm.config import VllmConfig
from vllm.model_executor.models.gemma3_mm import (Gemma3ForConditionalGeneration, Gemma3MultiModalProcessor,
Gemma3ProcessingInfo, Gemma3DummyInputsBuilder, Gemma3ImageInputs)
from vllm.multimodal import MULTIMODAL_REGISTRY


@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
info=Gemma3ProcessingInfo,
dummy_inputs=Gemma3DummyInputsBuilder)
class HpuGemma3ForConditionalGeneration(Gemma3ForConditionalGeneration):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)

# For HPU optimization, process the vision tower using buckets to reduce recipe recompilation overhead
def _process_image_input(self, image_input: Gemma3ImageInputs) -> list[torch.Tensor]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
num_patches = image_input["num_patches"]

batch_breakdown = self.vision_bucket_manager.greedy_plan(pixel_values.shape[0],
self.vision_bucket_manager.multimodal_buckets)
start_idx = 0
image_embeds_multibatches = []

for i in batch_breakdown:
end_idx = start_idx + i
indices = torch.arange(start_idx, end_idx)
batch_sliced_pixel_values = torch.index_select(pixel_values, dim=0, index=indices)

image_features = self._image_pixels_to_features(
self.vision_tower,
batch_sliced_pixel_values,
)
image_embeds = self.multi_modal_projector(image_features)
image_embeds_multibatches += [image_embeds.clone()]
start_idx = end_idx
image_embeds = torch.cat(image_embeds_multibatches, dim=0)

return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())]

def greedy_plan(self, batchsize, available_batchsizes):
# sort descending
available_batchsizes_sorted = sorted(available_batchsizes, key=lambda x: -x)
idx = 0
left_to_process = batchsize
result = []
while (left_to_process > 0 and idx < len(available_batchsizes_sorted)):
if available_batchsizes_sorted[idx] <= left_to_process:
result += [available_batchsizes_sorted[idx]]
left_to_process -= available_batchsizes_sorted[idx]
else:
idx += 1
if left_to_process > 0:
result += [available_batchsizes_sorted[-1]] # this will be padded
return result
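
A toy trace of the bucketed batching in _process_image_input above, assuming gemma-3's batch buckets [1, 2, 4, 8] and an 11-image request; the arithmetic is what matters, and the multiply stands in for the vision tower plus projector:

    import torch

    pixel_values = torch.randn(11, 3, 224, 224)
    plan = [8, 2, 1]          # greedy_plan(11, [1, 2, 4, 8])
    chunks, start = [], 0
    for n in plan:
        end = start + n
        chunks.append(pixel_values[start:end] * 1.0)  # stand-in for tower + projector
        start = end
    out = torch.cat(chunks, dim=0)
    assert out.shape == pixel_values.shape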