From 5cd58d16af0cd1aa358b4287bebbd4129086527a Mon Sep 17 00:00:00 2001
From: Sanju C Sudhakaran
Date: Sun, 5 Jan 2025 18:14:48 +0200
Subject: [PATCH 1/2] Enable long-contexts + LoRA support for Intel Gaudi

Signed-off-by: Sanju C Sudhakaran
---
 vllm/lora/punica_wrapper/utils.py             | 26 +++++++++++++++----
 .../model_executor/layers/rotary_embedding.py |  3 ++-
 vllm/worker/hpu_model_runner.py               | 17 ++++++++++--
 3 files changed, 38 insertions(+), 8 deletions(-)

diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py
index dbc2d27c597f..b1759a489223 100644
--- a/vllm/lora/punica_wrapper/utils.py
+++ b/vllm/lora/punica_wrapper/utils.py
@@ -88,10 +88,18 @@ def convert_mapping(
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
     long_lora_offsets: Optional[torch.Tensor] = None
+
+    # Updating each element in `long_lora_offsets` with `lora_offset` slows
+    # down perf in HPU due to a series of `strided_insert` ops during lazy
+    # graph accumulation. Hence HPU appends `lora_offset` to a list and
+    # converts it to a tensor only after it is ready.
     if long_lora_context:
-        long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device=device,
-                                        dtype=torch.long)
+        if device == torch.device("hpu"):
+            long_lora_offsets_list: List[int] = []
+        else:
+            long_lora_offsets = torch.zeros(len(index_mapping_indices),
+                                            device=device,
+                                            dtype=torch.long)
     prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -104,10 +112,18 @@ def convert_mapping(
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
         if long_lora_context:
-            assert long_lora_offsets is not None
             lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                 index_mapping_indices[i], 0)
-            long_lora_offsets[i] = lora_offset
+            if device == torch.device("hpu"):
+                long_lora_offsets_list.append(lora_offset)
+            else:
+                assert long_lora_offsets is not None
+                long_lora_offsets[i] = lora_offset
+
+    if long_lora_context and device == torch.device("hpu"):
+        long_lora_offsets = torch.tensor(long_lora_offsets_list,
+                                         device=device,
+                                         dtype=torch.long)
 
     indices_list: List[Union[List[int], torch.Tensor]] = [
         index_mapping_indices,
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index b3b9b0e87605..2fffcb7b6e6e 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -206,9 +206,10 @@ def forward_hpu(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from habana_frameworks.torch.hpex.kernels import (
             RotaryPosEmbeddingMode, apply_rotary_pos_emb)
-        positions = positions.flatten()
         if offsets is not None:
+            offsets = offsets.view(positions.shape[0], -1)
             positions = positions + offsets
+        positions = positions.flatten()
         num_tokens = positions.shape[0]
         cos_sin = self.cos_sin_cache.index_select(0, positions).view(
             num_tokens, 1, -1)
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index b846d4387ba5..774049a5281e 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -639,12 +639,25 @@ def load_model(self) -> None:
                 "Bias support in LoRA is not enabled in HPU yet."
             assert not self.lora_config.fully_sharded_loras, \
                 "Fully sharded LoRAs is not enabled in HPU yet."
+            # It's necessary to distinguish between the
+            # max_position_embeddings of VLMs and LLMs.
+            if hasattr(self.model.config, "max_position_embeddings"):
+                max_pos_embeddings = (
+                    self.model.config.max_position_embeddings)
+            else:
+                max_pos_embeddings = (
+                    self.model.config.text_config.max_position_embeddings)
+
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens,
-                self.vocab_size, self.lora_config, self.device,
+                self.vocab_size,
+                self.lora_config,
+                self.device,
                 self.model.embedding_modules,
-                self.model.embedding_padding_modules)
+                self.model.embedding_padding_modules,
+                max_position_embeddings=max_pos_embeddings,
+            )
             self.model = self.lora_manager.create_lora_manager(self.model)
 
         if self.model_config.quantization == 'inc':

From 25dd63da1f4396db8305da4058bf9f4363d273eb Mon Sep 17 00:00:00 2001
From: Sanju C Sudhakaran
Date: Fri, 7 Feb 2025 11:48:33 +0200
Subject: [PATCH 2/2] Handle long-context + lora explicitly after convert_mapping

Signed-off-by: Sanju C Sudhakaran
---
 vllm/lora/punica_wrapper/punica_hpu.py | 57 +++++++++++++++++++++++++-
 vllm/lora/punica_wrapper/utils.py      | 26 +++---------
 2 files changed, 61 insertions(+), 22 deletions(-)

diff --git a/vllm/lora/punica_wrapper/punica_hpu.py b/vllm/lora/punica_wrapper/punica_hpu.py
index 51e1bfab3f51..3661a7214648 100644
--- a/vllm/lora/punica_wrapper/punica_hpu.py
+++ b/vllm/lora/punica_wrapper/punica_hpu.py
@@ -1,12 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Optional, Tuple, Union, final
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
 
 import torch
 from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
                                     dispatch_bgmv_linear)
 
 from .punica_base import PunicaWrapperBase
+from .utils import convert_mapping
+
+if TYPE_CHECKING:
+    # avoid circular import
+    from vllm.lora.layers import LoRAMapping
+    from vllm.lora.models import LongContextLoRAContext
 
 
 @final
@@ -19,6 +25,55 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
         PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                    max_batches, device)
 
+    def _update_base_metadata(
+        self,
+        mapping: "LoRAMapping",
+        lora_index_to_id: List[Optional[int]],
+        max_loras: int,
+        vocab_size: int,
+        extra_vocab_size: int,
+        long_lora_context: Optional["LongContextLoRAContext"] = None,
+    ):
+        (
+            base_indices,
+            sampler_indices,
+            sampler_indices_padded,
+            embeddings_indices,
+            long_lora_offsets_tensor,
+            indices_len,
+        ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size,
+                            extra_vocab_size, self.device, None)
+        # Updating each element in `long_lora_offsets` with `lora_offset` slows
+        # down perf in HPU due to a series of `strided_insert` ops during lazy
+        # graph accumulation. Hence HPU appends `lora_offset` to a list and
+        # converts it to a tensor only after it is ready.
+        if long_lora_context:
+            index_mapping_indices: List[int] = list(
+                mapping.index_mapping).copy()
+            long_lora_offsets: List[int] = []
+            for i in range(len(index_mapping_indices)):
+                lora_offset: int = long_lora_context.offsets_by_lora_id.get(
+                    index_mapping_indices[i], 0)
+                long_lora_offsets.append(lora_offset)
+            long_lora_offsets_tensor = torch.tensor(long_lora_offsets,
+                                                    device=self.device,
+                                                    dtype=torch.long)
+            indices_len[-1] = long_lora_offsets_tensor.shape[-1]
+
+        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
+        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
+        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
+            sampler_indices_padded)
+        self._embeddings_indices[:embeddings_indices.
+                                 shape[0], :embeddings_indices.shape[1]].copy_(
+                                     embeddings_indices)
+        if long_lora_offsets_tensor is not None:
+            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
+                long_lora_offsets_tensor)
+        else:
+            self._long_lora_indices.zero_()
+        self.indices_len[:] = indices_len
+
     def add_lora_embedding(self,
                            y: torch.Tensor,
                            x: torch.Tensor,
diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py
index b1759a489223..dbc2d27c597f 100644
--- a/vllm/lora/punica_wrapper/utils.py
+++ b/vllm/lora/punica_wrapper/utils.py
@@ -88,18 +88,10 @@ def convert_mapping(
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
     long_lora_offsets: Optional[torch.Tensor] = None
-
-    # Updating each element in `long_lora_offsets` with `lora_offset` slows
-    # down perf in HPU due to a series of `strided_insert` ops during lazy
-    # graph accumulation. Hence HPU appends `lora_offset` to a list and
-    # converts it to a tensor only after it is ready.
     if long_lora_context:
-        if device == torch.device("hpu"):
-            long_lora_offsets_list: List[int] = []
-        else:
-            long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                            device=device,
-                                            dtype=torch.long)
+        long_lora_offsets = torch.zeros(len(index_mapping_indices),
+                                        device=device,
+                                        dtype=torch.long)
     prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -112,18 +104,10 @@ def convert_mapping(
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
         if long_lora_context:
+            assert long_lora_offsets is not None
             lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                 index_mapping_indices[i], 0)
-            if device == torch.device("hpu"):
-                long_lora_offsets_list.append(lora_offset)
-            else:
-                assert long_lora_offsets is not None
-                long_lora_offsets[i] = lora_offset
-
-    if long_lora_context and device == torch.device("hpu"):
-        long_lora_offsets = torch.tensor(long_lora_offsets_list,
-                                         device=device,
-                                         dtype=torch.long)
+            long_lora_offsets[i] = lora_offset
 
     indices_list: List[Union[List[int], torch.Tensor]] = [
         index_mapping_indices,
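
Illustration (not part of the patches): the sketch below contrasts the two ways of building `long_lora_offsets` that the comments in both commits describe. The offset dict, the index mapping values, and the use of `torch.device("cpu")` as a stand-in for the HPU device are all made up for this example; only PyTorch is required.

    import torch

    offsets_by_lora_id = {1: 4096, 2: 0}    # hypothetical per-LoRA offsets
    index_mapping_indices = [1, 1, 2, 0]    # hypothetical token -> LoRA id map
    device = torch.device("cpu")            # stand-in for torch.device("hpu")

    # Per-element writes: every assignment is a separate in-place update, which
    # under HPU lazy mode accumulates as a chain of `strided_insert` ops.
    per_element = torch.zeros(len(index_mapping_indices),
                              device=device,
                              dtype=torch.long)
    for i, idx in enumerate(index_mapping_indices):
        per_element[i] = offsets_by_lora_id.get(idx, 0)

    # List-then-tensor: gather the offsets in a Python list first and
    # materialize the tensor in a single op once all values are known, which
    # is the pattern the HPU path adopts in `_update_base_metadata`.
    gathered = [offsets_by_lora_id.get(idx, 0) for idx in index_mapping_indices]
    as_tensor = torch.tensor(gathered, device=device, dtype=torch.long)

    assert torch.equal(per_element, as_tensor)  # same values either way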