-
-
Notifications
You must be signed in to change notification settings - Fork 10.6k
[Speculative Decoding] Medusa Implementation with Top-1 proposer #4978
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8f219db
381fad5
1a09ca5
0cd042c
920fcc6
46cc72e
06ef396
13d5356
ed3a43e
0ef282f
7995e65
5f86ed8
b712449
24c9e91
e613a1d
da2cc47
bfc8e13
d9eb7ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
"""This docstring details important information on the testing methodology. | ||
|
||
Most of the tests rely on "greedy equality", where we expect the output of | ||
speculative decoding on a sequence to exactly match the output of normal non- | ||
speculative decoding. | ||
|
||
Since speculative decoding with rejection sampling guarantees that the output | ||
distribution matches the target model's output distribution (up to hardware | ||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy | ||
equality. | ||
|
||
However, we still need to verify that the following scenarios pass: | ||
* Batch size 1 greedy equality | ||
* Batch size >1 greedy equality | ||
* Test greedy equality under preemption | ||
* Test greedy equality under various numbers of speculative tokens. | ||
|
||
With those tests, we can at least say that Medusa does not break the | ||
correctness of the target model's outputs. | ||
""" | ||
|
||
import pytest | ||
|
||
from .conftest import run_greedy_equality_correctness_test | ||
|
||
# Target ("main") model.
# lmsys/vicuna-7b-v1.3 was the intended choice, but it causes OOM in the CI
# pipeline, so a much smaller model is used instead.
MAIN_MODEL = "JackFram/llama-68m"

# Draft (speculative) model.
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"

# Maximum number of speculative tokens; this corresponds to `num_heads` in
# the config.json of the speculator model.
MAX_SPEC_TOKENS = 5

# Precision (dtype) passed to the models under test.
PRECISION = "float32"
|
||
|
||
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            # Eager mode: no CUDA graph capture, which keeps the test fast.
            "enforce_eager": True,
            # Speculative decoding needs the v2 block manager.
            "use_v2_block_manager": True,
            # Keep stats logging on so speculative metrics get printed.
            "disable_log_stats": False,
            # Precision
            "dtype": PRECISION,
            # Main model
            "model": MAIN_MODEL,
        },
    ],
)
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    ],
)
@pytest.mark.parametrize("output_len", [128])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
                                    batch_size: int, output_len: int):
    """Check greedy equality between the baseline and the speculative run
    for a single-sequence batch and for a larger batch.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
|
||
|
||
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            "block_size": 8,
            # Block budget: 2 for the small prompt, 256 // 8 for the
            # generated tokens.
            "num_gpu_blocks_override": 2 + 256 // 8,
            "max_model_len": (2 + 256 // 8) * 8,
            # Eager mode: no CUDA graph capture, which keeps the test fast.
            "enforce_eager": True,
            # Speculative decoding needs the v2 block manager.
            "use_v2_block_manager": True,
            # Precision
            "dtype": PRECISION,
            # Main model
            "model": MAIN_MODEL,
        },
    ],
)
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
        },
    ],
)
@pytest.mark.parametrize(
    "output_len",
    [
        # A short output length keeps the test fast.
        128,
    ],
)
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                                    test_llm_generator,
                                                    batch_size: int,
                                                    output_len: int):
    """Check that greedy equality still holds when some sequences get
    preempted in the middle of generation.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
|
||
|
||
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            # Eager mode: no CUDA graph capture, which keeps the test fast.
            "enforce_eager": True,
            # Speculative decoding needs the v2 block manager.
            "use_v2_block_manager": True,
            # Precision
            "dtype": PRECISION,
            # Main model
            "model": MAIN_MODEL,
        },
    ],
)
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": SPEC_MODEL,
            "num_speculative_tokens": k,
        }
        # Sweep every speculative-token count from 1 up to the maximum.
        for k in range(1, MAX_SPEC_TOKENS + 1)
    ],
)
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # A shorter output length keeps the test fast.
        32,
    ],
)
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
                         batch_size: int, output_len: int):
    """Check that Medusa speculative decoding matches non-speculative
    decoding exactly for each value of num_speculative_tokens.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
|
||
|
||
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            # Eager mode: no CUDA graph capture, which keeps the test fast.
            "enforce_eager": True,
            # Speculative decoding needs the v2 block manager.
            "use_v2_block_manager": True,
            # Precision
            "dtype": PRECISION,
            # Main model
            "model": MAIN_MODEL,
        },
    ],
)
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": SPEC_MODEL,
            "num_speculative_tokens": MAX_SPEC_TOKENS,
            "speculative_disable_by_batch_size": 4,
        },
    ],
)
# One batch size below and one above the disable threshold configured above.
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # A shorter output length keeps the test fast.
        32,
    ],
)
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """Check that Medusa speculative decoding matches non-speculative
    decoding exactly when speculation is disabled for large batch sizes.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
|
||
|
||
if __name__ == "__main__":
    # Running this file as a script hands it to pytest directly.
    import pytest

    pytest.main([__file__])
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
from typing import Iterable, List, Optional, Tuple | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from vllm.model_executor.layers.logits_processor import LogitsProcessor | ||
from vllm.model_executor.layers.vocab_parallel_embedding import ( | ||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) | ||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader | ||
from vllm.model_executor.sampling_metadata import SamplingMetadata | ||
from vllm.sequence import SamplerOutput | ||
from vllm.transformers_utils.configs.medusa import MedusaConfig | ||
|
||
|
||
class ResidualBlock(nn.Module):
    """Stack of bias-free square linear layers, each wrapped in a SiLU
    residual connection: ``x <- x + SiLU(W x)``.
    """

    def __init__(self, hidden_size: int, num_layers: int) -> None:
        """Create ``num_layers`` linear layers of shape
        ``hidden_size x hidden_size`` (no bias) and the shared activation.
        """
        super().__init__()

        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size, bias=False)
            for _ in range(num_layers))
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every residual layer in order and return the result."""
        out = x
        for linear in self.layers:
            update = self.act(linear(out))
            out = out + update
        return out
|
||
|
||
class Medusa(nn.Module):
    """Medusa draft model used as a top-1 proposer for speculative decoding.

    The model holds ``config.num_heads`` independent heads. Each head is a
    ``ResidualBlock`` followed by its own ``ParallelLMHead``. Given the
    previous hidden states, every head produces logits for one speculative
    position.

    If the checkpoint ships a ``token_map`` (truncated vocabulary), the head
    logits are scattered back into a tensor of the original vocabulary size,
    with every unmapped entry set to ``-inf``.
    """

    def __init__(self, config: MedusaConfig, **_) -> None:
        """Build the per-head residual blocks, LM heads and logits processor.

        Args:
            config: Medusa configuration; the fields read here are
                ``hidden_size``, ``num_hidden_layers``, ``num_heads``,
                ``vocab_size``, ``truncated_vocab_size`` and, optionally,
                ``logit_scale`` (defaults to 1.0 when absent).
            **_: Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList([
            ResidualBlock(hidden_size=self.config.hidden_size,
                          num_layers=self.config.num_hidden_layers)
            for _ in range(self.config.num_heads)
        ])
        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

        self.lm_heads = nn.ModuleList([
            ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=self.truncated_vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            ) for _ in range(self.config.num_heads)
        ])

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        # Populated by load_weights() when the checkpoint provides a token
        # map and the vocabulary is actually truncated.
        self.token_map = None

    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
        """Run every head's residual block on the same hidden states.

        Returns:
            One tensor per head, in head order.
        """
        return [block(hidden_states) for block in self.blocks]

    def compute_logits(
            self, hidden_states: List[torch.Tensor],
            sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
        """Turn per-head hidden states into per-head logits.

        Args:
            hidden_states: One tensor per head, as returned by ``forward``.
            sampling_metadata: Passed through to the logits processor.

        Returns:
            One logits tensor per head. Without a token map these are the
            logits processor's outputs; with a token map they are expanded
            to ``orig_vocab_size`` along the last dimension.
        """
        logits = []

        for hs, lm_head in zip(hidden_states, self.lm_heads):
            head_logits = self.logits_processor(lm_head, hs,
                                                sampling_metadata)

            if self.token_map is None:
                logits.append(head_logits)
                continue

            # Scatter the truncated-vocab logits into a full-vocab tensor;
            # tokens outside the map stay at -inf so they are never chosen.
            full_logits = torch.full(
                size=(*head_logits.shape[:-1], self.orig_vocab_size),
                fill_value=-torch.inf,
                device=head_logits.device,
                dtype=head_logits.dtype)
            full_logits[..., self.token_map] = head_logits
            logits.append(full_logits)

        return logits

    def sample(
        self,
        logits: List[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        """Greedily pick one token per head for every sequence group.

        Args:
            logits: One logits tensor per head (see ``compute_logits``).
            sampling_metadata: Supplies ``seq_groups``; each group's
                ``sample_indices`` selects its rows from the logits.

        Returns:
            One ``SamplerOutput`` per sequence group, carrying the per-head
            token ids, probabilities and log-probabilities
            (``outputs`` is left as ``None``).
        """
        # Stack heads along a new leading dimension and work in float32.
        logits = torch.stack(logits, dim=0).float()
        logprobs = torch.log_softmax(logits, dim=-1)
        token_ids = logits.argmax(-1)  # support only top-1 for now
        probs = torch.softmax(logits, dim=-1)
        # NOTE(review): proposals come from a plain argmax rather than vLLM's
        # standard sampling routine, which was found to add too much
        # overhead. Reviewers noted that with a lossless rejection sampler
        # the draft distribution must be modified in the same way as the
        # scoring distribution, otherwise the output distribution can drift.
        # The token mismatches observed so far were attributed to bf16
        # precision and did not reproduce with fp32.

        outputs: List[SamplerOutput] = []
        for seq_group in sampling_metadata.seq_groups:
            indices = seq_group.sample_indices
            # squeeze(1) drops the per-group index dimension when the group
            # selects a single row.
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_probs=probs[:, indices].squeeze(1),
                    logprobs=logprobs[:, indices].squeeze(1),
                    sampled_token_ids=token_ids[:, indices].squeeze(1),
                ))

        return outputs

    def generate_proposals(
        self,
        previous_hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        """Run forward -> compute_logits -> sample in one call.

        Args:
            previous_hidden_states: Hidden states fed to every head.
            sampling_metadata: Forwarded to both logits and sampling steps.

        Returns:
            The list produced by ``sample``.
        """
        return self.sample(
            logits=self.compute_logits(
                hidden_states=self.forward(previous_hidden_states),
                sampling_metadata=sampling_metadata,
            ),
            sampling_metadata=sampling_metadata,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, handling an optional truncated vocab.

        Args:
            weights: ``(name, tensor)`` pairs. A ``medusa_heads.`` substring
                is stripped from each name. The special ``token_map`` entry
                is stored (only when the vocabulary is truncated); all other
                names that match a model parameter are loaded.

        Raises:
            AssertionError: If the vocabulary is truncated but no token map
                was found in ``weights``.
        """
        params_dict = dict(self.named_parameters())

        # First pass: collect everything. The token map may appear after the
        # LM-head weights, so slicing must wait until all names were seen.
        weights_map = {}

        for name, loaded_weight in weights:
            name = name.replace("medusa_heads.", "")

            if name == "token_map":
                if self.truncated_vocab_size < self.orig_vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name in params_dict:
                weights_map[name] = loaded_weight

        # Second pass: load, cutting full-vocab LM heads down to the rows
        # selected by the token map.
        for name, loaded_weight in weights_map.items():
            if ("lm_head" in name and self.token_map is not None
                    and loaded_weight.shape[0] > self.token_map.shape[0]):
                loaded_weight = loaded_weight[self.token_map]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        if self.token_map is not None:
            # Tensor.to() is not in-place: the result must be stored back,
            # otherwise the map stays on the device it was loaded on.
            self.token_map = nn.Parameter(
                self.token_map.to(device=self.lm_heads[0].weight.device),
                requires_grad=False)

        assert (self.truncated_vocab_size == self.orig_vocab_size) or (
            self.token_map is not None), (
                "token_map is required when truncated_vocab_size differs "
                "from vocab_size")
Uh oh!
There was an error while loading. Please reload this page.