
Commit d9ee2ec

dsikka authored and llmpros committed
[Misc] Fix get_min_capability (vllm-project#5971)
1 parent 1299216 commit d9ee2ec

File tree

5 files changed: +17 −6 lines

vllm/model_executor/layers/quantization/awq.py

Lines changed: 2 additions & 1 deletion
@@ -43,7 +43,8 @@ def get_name(self) -> str:
     def get_supported_act_dtypes(self) -> List[torch.dtype]:
         return [torch.half]

-    def get_min_capability(self) -> int:
+    @classmethod
+    def get_min_capability(cls) -> int:
         # The AWQ kernel only supports Turing or newer GPUs.
         return 75

vllm/model_executor/layers/quantization/base_config.py

Lines changed: 2 additions & 1 deletion
@@ -44,8 +44,9 @@ def get_supported_act_dtypes(self) -> List[torch.dtype]:
         """List of supported activation dtypes."""
         raise NotImplementedError

+    @classmethod
     @abstractmethod
-    def get_min_capability(self) -> int:
+    def get_min_capability(cls) -> int:
         """Minimum GPU capability to support the quantization method.

         E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
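
Because get_min_capability is now a classmethod on the base quantization config, callers can read a method's minimum compute capability straight from the config class, without constructing an instance first. A minimal sketch of that usage, assuming the AWQConfig class from awq.py above (the supports_current_gpu helper is illustrative, not part of this change):

import torch

from vllm.model_executor.layers.quantization.awq import AWQConfig


def supports_current_gpu(config_cls) -> bool:
    # Compare the class-level requirement against the active GPU, using the
    # same major * 10 + minor encoding as the diff (e.g. (7, 5) -> 75).
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= config_cls.get_min_capability()


# No AWQConfig instance is needed to read the requirement.
print(AWQConfig.get_min_capability())   # 75 after this change
print(supports_current_gpu(AWQConfig))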

vllm/model_executor/layers/quantization/bitsandbytes.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def get_supported_act_dtypes(self) -> List[torch.dtype]:
         return [torch.float32, torch.float16, torch.bfloat16]

     @classmethod
-    def get_min_capability(self) -> int:
+    def get_min_capability(cls) -> int:
         return 70

     @staticmethod

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 10 additions & 2 deletions
@@ -33,10 +33,9 @@ def get_scaled_act_names(self) -> List[str]:
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.float16, torch.bfloat16]

-    # Need to figure it out
     @classmethod
     def get_min_capability(cls) -> int:
-        return 60
+        return 75

     def get_name(self) -> str:
         return "compressed_tensors"

@@ -84,6 +83,14 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
     def get_config_filenames(cls) -> List[str]:
         return []

+    def _check_gptq_and_marlin_can_run(self):
+        capability = torch.cuda.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        if capability < 80:
+            raise RuntimeError("The quantization config is not supported for ",
+                               "the current GPU. Minimum capability: 80. ",
+                               f"Current capability: {capability}.")
+
     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8

@@ -126,6 +133,7 @@ def _get_schema(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> "CompressedTensorsScheme":

         if self._is_wNa16_group_channel(weight_quant, input_quant):
+            self._check_gptq_and_marlin_can_run()
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
                 return CompressedTensorsW4A16Sparse24(
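
The new _check_gptq_and_marlin_can_run helper encodes the device's compute capability as major * 10 + minor and rejects anything below 80 (Ampere) before a wNa16 group/channel scheme is selected. As written in the diff, the RuntimeError receives three comma-separated strings, so its message ends up as a tuple; a standalone sketch of the same check with the message joined into a single string (helper name and threshold taken from the diff, the rest illustrative):

import torch


def check_marlin_capability(min_capability: int = 80) -> None:
    # Encode compute capability as one integer, e.g. (8, 6) -> 86.
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    if capability < min_capability:
        raise RuntimeError(
            "The quantization config is not supported for the current GPU. "
            f"Minimum capability: {min_capability}. "
            f"Current capability: {capability}.")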

vllm/model_executor/layers/quantization/squeezellm.py

Lines changed: 2 additions & 1 deletion
@@ -39,7 +39,8 @@ def get_name(self) -> str:
     def get_supported_act_dtypes(self) -> List[torch.dtype]:
         return [torch.half]

-    def get_min_capability(self) -> int:
+    @classmethod
+    def get_min_capability(cls) -> int:
         return 70

     @staticmethod
