Merged

Changes from all commits (69 commits)
f470b26
gptq_marlin compat dynamic_bits quantize config
ZX-ModelCloud Aug 1, 2024
c56e3de
Merge branch 'main' into compat_dynamic_bits
ZX-ModelCloud Aug 1, 2024
502edb3
Update gptq_marlin.py
Qubitium Aug 2, 2024
18064cd
cleanup
ZX-ModelCloud Aug 2, 2024
1b132c3
cleanup
ZX-ModelCloud Aug 2, 2024
4b63754
cleanup
ZX-ModelCloud Aug 2, 2024
90258d2
cleanup
ZX-ModelCloud Aug 2, 2024
a5d3c8b
cleanup
ZX-ModelCloud Aug 2, 2024
c84793f
Merge remote-tracking branch 'origin/compat_dynamic_bits' into compat…
ZX-ModelCloud Aug 2, 2024
5682124
load "dynamic" field from config
ZX-ModelCloud Aug 2, 2024
d651668
fix key error: change "is_sym" to "sym"
ZX-ModelCloud Aug 2, 2024
9a36694
Merge branch 'main' into compat_dynamic_bits
ZX-ModelCloud Aug 6, 2024
fbc594f
Merge branch 'main' into compat_dynamic_bits
ZX-ModelCloud Aug 6, 2024
e9ae8f5
update quant_type
ZX-ModelCloud Aug 6, 2024
19d7772
update
ZX-ModelCloud Dec 24, 2024
7057dbb
Merge branch 'main' into compat_dynamic_bits
ZX-ModelCloud Dec 24, 2024
8565328
fix judgment error
ZX-ModelCloud Dec 24, 2024
84ada54
cleanup
ZX-ModelCloud Dec 24, 2024
e81a7da
cleanup
ZX-ModelCloud Dec 24, 2024
68291ce
cleanup
ZX-ModelCloud Dec 24, 2024
7867405
cleanup
ZX-ModelCloud Dec 24, 2024
c63ba51
cleanup
ZX-ModelCloud Dec 24, 2024
5f9b712
Update gptq_marlin.py
Qubitium Dec 24, 2024
3692578
Update gptq_marlin.py
Qubitium Dec 24, 2024
f902b2d
cleanup
ZX-ModelCloud Dec 24, 2024
a570509
Merge remote-tracking branch 'origin/compat_dynamic_bits' into compat…
ZX-ModelCloud Dec 24, 2024
9b9d7e3
Update gptq_marlin.py
Qubitium Dec 24, 2024
0559137
cleanup
ZX-ModelCloud Dec 24, 2024
b29a094
Merge remote-tracking branch 'origin/compat_dynamic_bits' into compat…
ZX-ModelCloud Dec 24, 2024
3a2bb94
cleanup
ZX-ModelCloud Dec 24, 2024
3c0d45a
cleanup
ZX-ModelCloud Dec 24, 2024
74b1d42
add test_gptq_dynamic_cfg.py
ZX-ModelCloud Dec 24, 2024
b0672ae
cleanup
ZX-ModelCloud Dec 24, 2024
066f489
Update test_gptq_dynamic_cfg.py
Qubitium Dec 24, 2024
6dc56a6
Update test_gptq_dynamic_cfg.py
Qubitium Dec 24, 2024
98a198e
cleanup
ZX-ModelCloud Dec 24, 2024
b2861d8
Merge remote-tracking branch 'origin/compat_dynamic_bits' into compat…
ZX-ModelCloud Dec 24, 2024
c4a29eb
use PROMPT variable
ZX-ModelCloud Dec 24, 2024
25703e3
cleanup
ZX-ModelCloud Dec 24, 2024
1fd690e
Merge branch 'main' into compat_dynamic_bits
ZX-ModelCloud Jan 8, 2025
4f48d1b
Merge branch 'main' into compat_dynamic_bits
ZX-ModelCloud Feb 6, 2025
070ae3c
rename method and add detailed comments
Qubitium Feb 6, 2025
13b2b7b
Changed VocabParallelEmbedding.linear_method to quant_method to be co…
ZX-ModelCloud Feb 7, 2025
6850e6d
Merge remote-tracking branch 'origin/compat_dynamic_bits' into compat…
ZX-ModelCloud Feb 7, 2025
40562d1
fix unittest
ZX-ModelCloud Feb 7, 2025
7b774bb
cleanup
ZX-ModelCloud Feb 7, 2025
c72125a
cleanup
ZX-ModelCloud Feb 7, 2025
c298195
cleanup
ZX-ModelCloud Feb 7, 2025
bbc049d
Update gptq_marlin.py
Qubitium Feb 7, 2025
78f8818
format
ZX-ModelCloud Feb 7, 2025
2cfec63
Merge branch 'main' into compat_dynamic_bits
ZX-ModelCloud Feb 7, 2025
93ee576
Update gptq_marlin.py
Qubitium Feb 7, 2025
6ebf85c
rename to parallel_lm_head_quantized for clarity
Qubitium Feb 7, 2025
59bdf54
simplify
Qubitium Feb 7, 2025
9de0382
shorten code
Qubitium Feb 7, 2025
67d0882
cleanup
ZX-ModelCloud Feb 7, 2025
5623936
cleanup
ZX-ModelCloud Feb 7, 2025
e41bdd7
make lint pass
Qubitium Feb 7, 2025
965d7da
change model_id
ZX-ModelCloud Feb 11, 2025
1a34027
format
ZX-ModelCloud Feb 11, 2025
0b249a1
format code
ZX-ModelCloud Feb 11, 2025
4de04ae
format code
ZX-ModelCloud Feb 11, 2025
4c0608b
format code
ZX-ModelCloud Feb 11, 2025
8f21375
disable E712 ruff check
ZX-ModelCloud Feb 11, 2025
e3084e3
Extract code to gptq_utils.get_linear_quant_method()
ZX-ModelCloud Feb 11, 2025
25dbd5a
cleanup
ZX-ModelCloud Feb 11, 2025
874076c
cleanup
ZX-ModelCloud Feb 11, 2025
17704df
Merge branch 'main' into compat_dynamic_bits
ZX-ModelCloud Feb 11, 2025
c7f10be
do not use Fraction
ZX-ModelCloud Feb 12, 2025
68 changes: 68 additions & 0 deletions tests/quantization/test_gptq_dynamic.py
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests whether gptq models with dynamic quantized can be loaded.

Run `pytest tests/quantization/test_gptq_dynamic.py --forked`.
"""

import pytest
import torch

from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_dynamic_override)

PROMPT = "On the surface of Mars, we found"

# The first layer is quantized using bits=4, group_size=128
# The second layer is quantized using bits=8, group_size=32
# All other layers (layer index >= 2) are not quantized
MODEL_QUANT = [
("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue",
True),
("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse",
False),
]


@pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT)
def test_gptq_with_dynamic(vllm_runner, model_id: str,
use_marlin_kernel: bool):

vllm_model = vllm_runner(model_id, dtype=torch.float16, max_model_len=2048)

linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else (
GPTQLinearMethod)

for name, submodule in (vllm_model.model.llm_engine.model_executor.
driver_worker.model_runner.model.named_modules()):
if name == "lm_head":
assert isinstance(submodule.quant_method, linear_method_cls)
elif name == 'model.layers.0.self_attn.qkv_proj':
# The first layer is quantized using bits=4, group_size=128
# desc_act=True
assert isinstance(submodule.quant_method, linear_method_cls)
config = submodule.quant_method.quant_config
assert config.weight_bits == 4
assert config.group_size == 128
assert config.desc_act
elif name == 'model.layers.1.self_attn.qkv_proj':
# The second layer is quantized using bits=8, group_size=32
# desc_act=False
assert isinstance(submodule.quant_method, linear_method_cls)
config = submodule.quant_method.quant_config
assert get_dynamic_override(config, layer_name=name,
key="bits") == 8
assert get_dynamic_override(config,
layer_name=name,
key="group_size") == 32
assert not get_dynamic_override(
config, layer_name=name, key="desc_act")
elif (name == 'model.layers.2.self_attn.qkv_proj'
or name == 'model.layers.2.mlp.gate_up_proj'):
# All other layers (layer index >= 2) are not quantized
assert isinstance(submodule.quant_method, UnquantizedLinearMethod)

del vllm_model
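
The per-layer expectations in the test above come from the `dynamic` section of the model's GPTQ quantize config. The actual quantize_config.json of the ModelCloud test models is not reproduced in this diff; the sketch below is a hypothetical Python rendering of a config that would produce the behavior the test asserts (4-bit/group_size 128 base config with desc_act, an 8-bit/group_size 32 override with desc_act False for layer 1, a negative match that skips quantization for layers 2 and above, and a quantized lm_head). All regexes and values here are illustrative assumptions, not the models' real config.

# Hypothetical example only; not copied from the test models' configs.
quantize_config = {
    "bits": 4,            # base: 4-bit
    "group_size": 128,    # base: group_size 128
    "desc_act": True,
    "sym": True,
    "lm_head": True,      # quantize the lm_head as well
    "dynamic": {
        # Positive match: layer 1 is overridden to 8-bit, group_size 32,
        # desc_act False.
        r"+:.*\.layers\.1\..*": {
            "bits": 8,
            "group_size": 32,
            "desc_act": False,
        },
        # Negative match: skip quantization for layers 2 and above.
        r"-:.*\.layers\.(?:[2-9]|[1-9][0-9]+)\..*": {},
    },
}
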
25 changes: 12 additions & 13 deletions tests/quantization/test_lm_head.py
@@ -3,7 +3,6 @@

Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`.
"""
from typing import Tuple

import pytest
import torch
@@ -17,31 +16,31 @@

PROMPT = "On the surface of Mars, we found"

MODELS_QUANT = [(
"LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse",
True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)]
MODELS_QUANT = [
("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head", True),
("ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024", False),
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)
]


@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT)
@pytest.mark.parametrize("model_id, lm_head_quantized", MODELS_QUANT)
def test_lm_head(
vllm_runner,
model_lm_head_quant: Tuple[str, bool],
model_id: str,
lm_head_quantized: bool,
) -> None:
model, lm_head_quantized = model_lm_head_quant

with vllm_runner(model, dtype=torch.float16,
with vllm_runner(model_id, dtype=torch.float16,
max_model_len=2048) as vllm_model:

def check_model(model):
lm_head_layer = model.lm_head

if lm_head_quantized:
assert isinstance(lm_head_layer.linear_method,
assert isinstance(lm_head_layer.quant_method,
(GPTQLinearMethod, GPTQMarlinLinearMethod,
MarlinLinearMethod))
else:
assert isinstance(lm_head_layer.linear_method,
assert isinstance(lm_head_layer.quant_method,
UnquantizedEmbeddingMethod)

vllm_model.apply_model(check_model)
2 changes: 1 addition & 1 deletion vllm/lora/layers.py
@@ -1039,7 +1039,7 @@ def _get_logits(
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = lm_head.linear_method.apply(lm_head, hidden_states)
logits = lm_head.quant_method.apply(lm_head, hidden_states)
if embedding_bias is not None:
logits += embedding_bias

6 changes: 3 additions & 3 deletions vllm/model_executor/layers/logits_processor.py
@@ -108,9 +108,9 @@ def _get_logits(
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
logits = lm_head.quant_method.apply(lm_head,
hidden_states,
bias=embedding_bias)

# Gather logits for TP
logits = self._gather_logits(logits)
47 changes: 38 additions & 9 deletions vllm/model_executor/layers/quantization/gptq.py
@@ -3,16 +3,17 @@
import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
@@ -32,7 +33,33 @@ def __init__(
group_size: int,
desc_act: bool,
lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
) -> None:
# GPTQModel uses the `dynamic` config property to allow per-module
# quantization config, so each module can be individually optimized.
# Format is Dict[str, Dict] where the key is a regex string that can
# perform either positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# If no prefix is used, the match defaults to positive and overrides
# the base quant config. The value is a dict mapping config field
# names to their override values.
# Negative matching skips quantization init for the matched module
# entirely, so it runs non-quantized inference. More details and
# quantization examples can be found at:
# https://github.com/ModelCloud/GPTQModel
# Example:
# # the last 1/2 of the layers (10-21) use 8-bit vs 4-bit for layers 0-9
# # the last 1/4 of the layers (16-21) use 8-bit and group_size 64
# dynamic = {
# # `.*\.` matches the layers_node prefix
# # positive match layers 10-15
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
# # positive match layers 16-21
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
self.dynamic = dynamic

self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
@@ -47,7 +74,8 @@ def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}),"
f"lm_head_quantized={self.lm_head_quantized}")
f"lm_head_quantized={self.lm_head_quantized}), "
f"dynamic={self.dynamic}")

@classmethod
def get_name(cls) -> str:
@@ -68,19 +96,20 @@ def get_config_filenames(cls) -> List[str]:

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic

weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, lm_head_quantized)
return cls(weight_bits, group_size, desc_act, lm_head_quantized,
dynamic)

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQLinearMethod(self)
return None
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)


class ExllamaState(Enum):
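
Both GPTQConfig.get_quant_method (above) and GPTQMarlinConfig.get_quant_method (below) now delegate to gptq_utils.get_linear_quant_method, and the new test reads per-layer overrides through gptq_utils.get_dynamic_override. The bodies of those helpers live in vllm/model_executor/layers/quantization/utils/gptq_utils.py and are not part of this diff; the standalone sketch below only illustrates the regex-matching semantics described in the `dynamic` comment, and its function name, signature, and return conventions are assumptions rather than the real implementation.

import re
from typing import Dict, Optional, Union

DynamicDict = Dict[str, Dict[str, Union[int, bool]]]

def lookup_override(dynamic: DynamicDict,
                    layer_name: str,
                    key: Optional[str] = None,
                    default=None):
    """Illustrative stand-in for gptq_utils.get_dynamic_override()."""
    for pattern, overrides in dynamic.items():
        if pattern.startswith("-:"):
            # Negative match: signal that this module should skip
            # quantization init entirely (unquantized inference).
            if re.match(pattern[2:], layer_name):
                return False
        elif re.match(pattern.removeprefix("+:"), layer_name):
            # Positive match: return one overridden field, or the whole
            # override dict when no key is requested.
            return overrides if key is None else overrides.get(key, default)
    return default
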
59 changes: 47 additions & 12 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -9,17 +9,21 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.linear import (LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, marlin_moe_permute_scales,
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
@@ -47,12 +51,41 @@ def __init__(
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
) -> None:
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False

# GPTQModel uses the `dynamic` config property to allow per-module
# quantization config, so each module can be individually optimized.
# Format is Dict[str, Dict] where the key is a regex string that can
# perform either positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# If no prefix is used, the match defaults to positive and overrides
# the base quant config. The value is a dict mapping config field
# names to their override values.
# Negative matching skips quantization init for the matched module
# entirely, so it runs non-quantized inference. More details and
# quantization examples can be found at:
# https://github.com/ModelCloud/GPTQModel
# Example:
# # the last 1/2 of the layers (10-21) use 8-bit vs 4-bit for layers 0-9
# # the last 1/4 of the layers (16-21) use 8-bit and group_size 64
# dynamic = {
# # `.*\.` matches the layers_node prefix
# # positive match layers 10-15
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
# # positive match layers 16-21
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
self.dynamic = dynamic

self.weight_bits = weight_bits
self.is_sym = is_sym

self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.desc_act = desc_act
@@ -68,7 +101,8 @@ def __repr__(self) -> str:
return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized})")
f"lm_head_quantized={self.lm_head_quantized}), "
f"dynamic={self.dynamic}")

@classmethod
def get_name(cls) -> str:
@@ -88,14 +122,17 @@ def get_config_filenames(cls) -> List[str]:

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic

weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym,
lm_head_quantized)
lm_head_quantized, dynamic)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
@@ -120,17 +157,15 @@ def override_quantization_method(cls, hf_quant_cfg,

def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]:
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
and self.lm_head_quantized):
return GPTQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod",
UnquantizedLinearMethod, UnquantizedEmbeddingMethod]]:
if isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self)
return None
return get_linear_quant_method(self, layer, prefix,
GPTQMarlinLinearMethod)

@classmethod
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
@@ -143,7 +178,7 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
if quant_method != "gptq":
return False

# If we cannot find the info needed in the config, cannot convert.
# Marlin conversion is only valid if required properties are found
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
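
As a quick sanity check of the pattern semantics documented in the `dynamic` comment blocks above, the snippet below reuses the example regexes from those comments and shows which override a few typical module names would pick up. It is illustrative only and uses a simplified first-match lookup.

import re

dynamic = {
    r"+:.*\.(?:1[0-5])\..*": {"bits": 8},                          # layers 10-15
    r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64},  # layers 16-21
    r"-:.*\.moe\..*": {},                                          # skip all `moe` layers
}

def first_match(layer_name: str):
    for pattern, overrides in dynamic.items():
        prefix, _, regex = pattern.partition(":")
        if re.match(regex, layer_name):
            return "skip" if prefix == "-" else overrides
    return None  # no override; the base quant config applies

print(first_match("model.layers.12.self_attn.qkv_proj"))  # {'bits': 8}
print(first_match("model.layers.20.mlp.down_proj"))       # {'bits': 8, 'group_size': 64}
print(first_match("model.layers.3.moe.experts.0.w1"))     # 'skip'
print(first_match("model.layers.5.self_attn.qkv_proj"))   # None -> base config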