From b9013696b23dde372cccecdbaf69f0c852008844 Mon Sep 17 00:00:00 2001 From: Shomy Date: Thu, 25 Jul 2024 20:54:15 +0000 Subject: [PATCH 1/3] optimizations for process output step --- setup_cython.py | 1 + vllm/core/block_manager_v1.py | 12 ++++++---- vllm/engine/llm_engine.py | 22 ++++++++++-------- vllm/engine/output_processor/single_step.py | 24 ++++++++++++++++++++ vllm/engine/output_processor/stop_checker.py | 24 +++++++++++--------- vllm/sequence.py | 2 ++ 6 files changed, 60 insertions(+), 25 deletions(-) diff --git a/setup_cython.py b/setup_cython.py index 5ea5c39b4e2..8ae3e978774 100644 --- a/setup_cython.py +++ b/setup_cython.py @@ -24,6 +24,7 @@ "vllm/model_executor/layers/sampler.py", "vllm/sampling_params.py", "vllm/utils.py", + "vllm/block.py", ] diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 201cba309f6..f8384466816 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -612,12 +612,14 @@ def _free_block_table(self, block_table: BlockTable) -> None: self.cpu_allocator.free(block) def free(self, seq: Sequence) -> None: - if seq.seq_id not in self.block_tables: - # Already freed or haven't been scheduled yet. - return - block_table = self.block_tables[seq.seq_id] + seq_id = seq.seq_id + block_table = self.block_tables.pop(seq_id,[]) + #if seq.seq_id not in self.block_tables: + # # Already freed or haven't been scheduled yet. + # return + #block_table = self.block_tables[seq.seq_id] self._free_block_table(block_table) - del self.block_tables[seq.seq_id] + #del self.block_tables[seq.seq_id] def free_cross(self, seq_group: SequenceGroup) -> None: if seq_group.request_id not in self.cross_block_tables: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index affd5ae9d80..e67f1d71bc8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -681,12 +681,13 @@ def _process_model_outputs( """ now = time.time() - # Organize outputs by [sequence group][step] instead of # [step][sequence group]. output_by_sequence_group = create_output_by_sequence_group( output, num_seq_groups=len(scheduled_seq_groups)) + seq_groups = [scheduled_seq_group.seq_group for scheduled_seq_group in scheduled_seq_groups] + # Update the scheduled sequence groups with the model outputs. for scheduled_seq_group, outputs, seq_group_meta in zip( scheduled_seq_groups, output_by_sequence_group, @@ -708,14 +709,17 @@ def _process_model_outputs( # Create the outputs. 
request_outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] - for scheduled_seq_group in scheduled_seq_groups: - seq_group = scheduled_seq_group.seq_group - seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create(seq_group) - request_outputs.append(request_output) - for seq_group in ignored_seq_groups: - request_output = RequestOutputFactory.create(seq_group) - request_outputs.append(request_output) + [seq_group.maybe_set_first_token_time(now) for seq_group in seq_groups] + request_outputs = [RequestOutputFactory.create(seq_group) for seq_group in seq_groups] + #for scheduled_seq_group in scheduled_seq_groups: + # seq_group = scheduled_seq_group.seq_group + # seq_group.maybe_set_first_token_time(now) + # request_output = RequestOutputFactory.create(seq_group) + # request_outputs.append(request_output) + request_outputs.extend([RequestOutputFactory.create(seq_group) for seq_group in ignored_seq_groups]) + #for seq_group in ignored_seq_groups: + # request_output = RequestOutputFactory.create(seq_group) + # request_outputs.append(request_output) return request_outputs def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 44de1d7ec56..fc8885a6965 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -74,6 +74,30 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput) -> None: # Process samples samples = outputs.samples + if len(samples)==1: + #if there's only 1 sample, it has to be from 1 running seq in seq group + parent_seq = next(iter(seq_group.seqs_dict.values())) + child_sample = samples[0] + if not seq_group.sampling_params.use_beam_search: + #fastpath + parent_seq.append_token_id(child_sample.output_token, + child_sample.logprobs) + if self.detokenizer and seq_group.sampling_params.detokenize: + new_char_count = self.detokenizer.decode_sequence_inplace( + parent_seq, seq_group.sampling_params) + else: + new_char_count = 0 + + stopped = self.stop_checker.maybe_stop_sequence( + parent_seq, + new_char_count, + seq_group.sampling_params, + lora_req=seq_group.lora_request, + ) + #if parent_seq.is_finished(): + if stopped: + self.scheduler.free_seq(parent_seq) + return parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) existing_finished_seqs = seq_group.get_finished_seqs() parent_child_dict: Dict[int, List[SequenceOutput]] = { diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 96f0d114261..eda6ca16ac5 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -33,7 +33,7 @@ def maybe_stop_sequence( new_char_count: int, sampling_params: SamplingParams, lora_req: Optional[LoRARequest] = None, - ) -> None: + ) -> bool: """Stop the finished sequences. new_char_count is the number of chars added to the @@ -42,23 +42,24 @@ def maybe_stop_sequence( # Check if the minimum number of tokens has been generated yet; # skip the stop string/token checks if not - if seq.get_output_len() < sampling_params.min_tokens: - return + outlen = seq.get_output_len() + if outlen < sampling_params.min_tokens: + return False + last_token_id = seq.get_last_token_id() # Check if the sequence has generated the EOS token. 
if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == seq.eos_token_id): + and last_token_id == seq.eos_token_id): # Remove the last EOS token unless explicitly specified # This prevents unintended exposure of the EOS token if new_char_count and ( not sampling_params.include_stop_str_in_output): seq.output_text = seq.output_text[:-new_char_count] seq.status = SequenceStatus.FINISHED_STOPPED - return + return True # Check if a stop token was encountered. # This assumes a single token produced per step. - last_token_id = seq.get_last_token_id() if last_token_id in sampling_params.stop_token_ids: if new_char_count and ( not sampling_params.include_stop_str_in_output): @@ -66,7 +67,7 @@ def maybe_stop_sequence( seq.output_text = seq.output_text[:-new_char_count] seq.status = SequenceStatus.FINISHED_STOPPED seq.stop_reason = last_token_id - return + return True # Check if any stop strings are matched. stop_str = self._check_stop_strings(seq, new_char_count, @@ -74,17 +75,18 @@ def maybe_stop_sequence( if stop_str is not None: seq.status = SequenceStatus.FINISHED_STOPPED seq.stop_reason = stop_str - return + return True # Check if the sequence has reached max_model_len. if seq.get_len() > self._get_max_model_len(lora_req): seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return + return True # Check if the sequence has reached max_tokens. - if seq.get_output_len() == sampling_params.max_tokens: + if outlen == sampling_params.max_tokens: seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return + return True + return False @staticmethod def _check_stop_strings(seq: Sequence, new_char_count: int, diff --git a/vllm/sequence.py b/vllm/sequence.py index 9dcef7f8041..522355adc8f 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -166,6 +166,8 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int): self._num_computed_tokens += num_new_computed_tokens assert self._num_computed_tokens <= self.get_len(), ( self._num_computed_tokens, self.get_len()) + if self._stage == SequenceStage.DECODE: + return # If all tokens are computed, it means it is in decoding phase. 
if self.get_num_uncomputed_tokens() == 0: self._stage = SequenceStage.DECODE From 90f15dafcb10ab8f3a98a3314229d067b1e9d34c Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:45:40 -0400 Subject: [PATCH 2/3] Llama3.1 (#129) * Add support for a rope extension method (#6553) * [BugFix] Fix RoPE error in Llama 3.1 (#6693) --------- Co-authored-by: Simon Mo Co-authored-by: Woosuk Kwon --- vllm/config.py | 54 ++++-- .../model_executor/layers/rotary_embedding.py | 166 +++++++++++++++++- 2 files changed, 202 insertions(+), 18 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 160ec63d4d5..562bd6a4f69 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -11,7 +11,8 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron +from vllm.utils import (get_cpu_memory, is_cpu, is_hip, is_neuron, + print_warning_once) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -133,6 +134,17 @@ def __init__( code_revision, rope_scaling) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + + if (not self.disable_sliding_window + and self.hf_text_config.model_type == "gemma2" + and self.hf_text_config.sliding_window is not None): + print_warning_once( + "Gemma 2 uses sliding window attention for every odd layer, " + "which is currently not supported by vLLM. Disabling sliding " + "window and capping the max length to the sliding window size " + f"({self.hf_text_config.sliding_window}).") + self.disable_sliding_window = True + self.max_model_len = _get_and_verify_max_len( hf_config=self.hf_text_config, max_model_len=max_model_len, @@ -1224,20 +1236,32 @@ def _get_and_verify_max_len( derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) - if rope_scaling is not None and rope_scaling["type"] != "su": - if disable_sliding_window: - # TODO(robertgshaw): Find a model that supports rope_scaling - # with sliding window to see if this case should be allowed. - raise NotImplementedError( - "Disabling sliding window is not supported for models " - "with rope_scaling. Please raise an issue so we can " - "investigate.") - assert "factor" in rope_scaling - scaling_factor = rope_scaling["factor"] - if rope_scaling["type"] == "yarn": - derived_max_model_len = rope_scaling[ - "original_max_position_embeddings"] - derived_max_model_len *= scaling_factor + if rope_scaling is not None: + if "type" in rope_scaling: + rope_type = rope_scaling["type"] + elif "rope_type" in rope_scaling: + rope_type = rope_scaling["rope_type"] + else: + raise ValueError( + "rope_scaling must have a 'type' or 'rope_type' key.") + + # The correct one should be "longrope", kept "su" here + # to be backward compatible + if rope_type not in ("su", "longrope", "llama3"): + if disable_sliding_window: + # TODO(robertgshaw): Find a model that supports rope_scaling + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with rope_scaling. 
Please raise an issue so we can " + "investigate.") + + assert "factor" in rope_scaling + scaling_factor = rope_scaling["factor"] + if rope_type == "yarn": + derived_max_model_len = rope_scaling[ + "original_max_position_embeddings"] + derived_max_model_len *= scaling_factor # If the user specified a max length, make sure it is smaller than the # derived length from the HF model config. diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index d03903d206d..c15be15a9f0 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -503,6 +503,159 @@ def forward( return query.flatten(-2), key.flatten(-2) +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + device="cuda", + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + print("Cache shape", cache.shape) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + + self.cos_sin_cache: torch.Tensor = 
self.cos_sin_cache.to( + positions.device) + cos_sin = self.cos_sin_cache[torch.add(positions, offsets) + if offsets is not None else positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key + + +class GemmaRotaryEmbedding(RotaryEmbedding): + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107 + inv_freq = 1.0 / (base**( + torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() / + self.rotary_dim)) + return inv_freq + + +class ExtendedRotaryEmbedding(RotaryEmbedding): + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + return self.apply_scaling(inv_freqs) + + def apply_scaling(self, freqs: torch.Tensor): + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor) + new_freqs.append((1 - smooth) * freq / scale_factor + + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} @@ -534,10 +687,17 @@ def get_rope( rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, dtype) else: - scaling_type = rope_scaling["type"] - if scaling_type != "su": + scaling_type = rope_scaling[ + "type"] if "type" in rope_scaling else rope_scaling["rope_type"] + # The correct one should be "longrope" but keep "su" here + # for backward compatible + if scaling_type not in {"su", "longrope", "llama3"}: scaling_factor = rope_scaling["factor"] - if scaling_type == "linear": + if scaling_type == "llama3": + rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype) + elif scaling_type == "linear": rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, From ffa6d0a571f5c8201e54bc22e8d69770b17d4484 Mon Sep 17 00:00:00 2001 From: Jeremy Arnold Date: Mon, 26 Aug 2024 18:11:58 +0000 Subject: [PATCH 3/3] Update hipblaslt and FA revs to match what was used for MLPerf --- Dockerfile.rocm | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 28b5b267104..b16b35619c9 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -40,7 +40,7 @@ WORKDIR ${COMMON_WORKDIR} # 
----------------------- # hipBLASLt build stages FROM base AS build_hipblaslt -ARG HIPBLASLT_BRANCH="6f65c6e" +ARG HIPBLASLT_BRANCH="8b71e7a8d26ba95774fdc372883ee0be57af3d28" RUN git clone https://github.com/ROCm/hipBLASLt \ && cd hipBLASLt \ && git checkout ${HIPBLASLT_BRANCH} \ @@ -70,7 +70,7 @@ FROM export_rccl_${BUILD_RCCL} AS export_rccl # ----------------------- # flash attn build stages FROM base AS build_flash_attn -ARG FA_BRANCH="ae7928c" +ARG FA_BRANCH="23a2b1c2f21de2289db83de7d42e125586368e66" ARG FA_REPO="https://github.com/ROCm/flash-attention.git" ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" RUN git clone ${FA_REPO} \
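
Note: below is a minimal, standalone sketch of the two idioms that PATCH 1/3 leans on — dict.pop() with a default to collapse a membership check, lookup and delete into a single operation, and a stop check that returns a bool so the caller can free a finished sequence immediately instead of re-querying its state. This is plain Python with no vLLM imports; the class and function names are illustrative stand-ins under those assumptions, not the real vLLM BlockSpaceManager or StopChecker APIs.

    # Standalone illustration of the PATCH 1/3 idioms; simplified stand-ins,
    # not the actual vLLM classes.
    from typing import Dict, List


    class TinyBlockManager:

        def __init__(self) -> None:
            self.block_tables: Dict[int, List[int]] = {}
            self.freed_blocks: List[int] = []

        def free_before(self, seq_id: int) -> None:
            # Original pattern: membership test + lookup + delete
            # (three dict probes).
            if seq_id not in self.block_tables:
                return  # already freed or never scheduled
            block_table = self.block_tables[seq_id]
            self.freed_blocks.extend(block_table)
            del self.block_tables[seq_id]

        def free_after(self, seq_id: int) -> None:
            # Patched pattern: pop() with a default does the same work in one
            # probe and is naturally a no-op for unknown seq_ids.
            block_table = self.block_tables.pop(seq_id, [])
            self.freed_blocks.extend(block_table)


    def maybe_stop_sequence(output_len: int, max_tokens: int) -> bool:
        # Returning a bool (instead of None, as before the patch) lets the
        # caller free the sequence right away without a second
        # "is it finished?" lookup.
        return output_len >= max_tokens


    if __name__ == "__main__":
        mgr = TinyBlockManager()
        mgr.block_tables[7] = [0, 1, 2]
        if maybe_stop_sequence(output_len=16, max_tokens=16):
            mgr.free_after(7)   # free immediately on the fast path
        mgr.free_after(7)       # safe no-op: entry was already popped
        assert mgr.freed_blocks == [0, 1, 2]

The same shape appears in the patched _process_sequence_group_outputs(): when there is a single sample, the parent sequence is known up front, so the token is appended, the stop check runs once, and the sequence is freed in the same step.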