Commit 6ff930b

[None][autodeploy] Add quantization_source to factory interface (#100)
* Move quant_config handling to _load_quantization_config
  - move kv_cache_dtype into _quant_config in the HF factory
  - remove the quant_source getter
* Add a QuantConfigReader class
  - temporary Llama4 FP8 patch for BMM testing, later reverted
  - move _quant_config into QuantConfigReader
* Move quantize and quantize_moe to the end of the pattern matcher
* Delegate quant_config processing to QuantConfigReader, pass the reader to the transformation, and split the transformation into config-based and graph-based passes
  - have QuantConfigReader return the dtype for NVFP4
  - move quantization target collection into a transform
  - temporary fix for the modelopt graph-based quantizer loading path, later removed
  - rebase quantization onto BaseTransform and remove QuantizationTarget
  - update the transform docstrings and fix unit tests
* Move quantization to the end of the transformations

Signed-off-by: Fridah-nv <[email protected]>
Signed-off-by: h-guo18 <[email protected]>
1 parent bdbfb05 commit 6ff930b
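
Taken together, the changes below wire a quantization source into the factory interface: the factory resolves a QuantConfigReader from a registry, the reader parses the checkpoint's hf_quant_config.json, and the new quantize_from_config transform consumes factory.get_quant_config(). A minimal sketch of that flow, assuming the reader module is importable as tensorrt_llm._torch.auto_deploy.models.quant_config_reader and using a placeholder checkpoint path:

from tensorrt_llm._torch.auto_deploy.models.quant_config_reader import QuantConfigReaderRegistry

ckpt_dir = "/path/to/modelopt/checkpoint"  # placeholder path, not part of this commit

# The factory does this inside _load_quantization_config():
reader_cls = QuantConfigReaderRegistry.get("modelopt")
result = reader_cls.from_file(ckpt_dir)  # None when hf_quant_config.json is absent
if result is not None:
    reader, extra_model_kwargs = result
    # quantize_from_config later reads this dict via factory.get_quant_config()
    print(reader.get_config().get("quant_algo"), extra_model_kwargs)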

File tree

6 files changed (+221 lines, -77 lines)


tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 6 additions & 4 deletions
@@ -19,10 +19,6 @@ transforms:
     stage: post_export
   cleanup_input_constraints:
     stage: post_export
-  quantize:
-    stage: pattern_matcher
-  quantize_moe:
-    stage: pattern_matcher
   match_repeat_kv:
     stage: pattern_matcher
   match_eager_attention:
@@ -41,3 +37,9 @@ transforms:
   # see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
   optimize_rope:
     stage: pattern_matcher
+  quantize_from_config:
+    stage: pattern_matcher
+  quantize_from_graph:
+    stage: pattern_matcher
+  quantize_moe:
+    stage: pattern_matcher

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Interface to initialize and load HF models."""
22

3-
import json
43
import os
54
import types
65
from contextlib import contextmanager, nullcontext
@@ -31,6 +30,7 @@
3130
from ..utils._config import deep_merge_dicts
3231
from ..utils.logger import ad_logger
3332
from .factory import ModelFactory, ModelFactoryRegistry
33+
from .quant_config_reader import QuantConfigReader, QuantConfigReaderRegistry
3434

3535

3636
@contextmanager
@@ -84,9 +84,7 @@ def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
8484

8585
def __init__(self, *args, **kwargs):
8686
super().__init__(*args, **kwargs)
87-
88-
self._quant_config: Optional[Dict] = None
89-
87+
self._quant_config_reader: Optional[QuantConfigReader] = None
9088
# Ingest defaults for tokenizer and model kwargs
9189
self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
9290
self.model_kwargs = deep_merge_dicts(
@@ -156,9 +154,6 @@ def _recursive_update_config(self, config: PretrainedConfig, update_dict: Dict[s
156154

157155
def _build_model(self, device: DeviceLikeType) -> nn.Module:
158156
"""Build the model on the desired device."""
159-
# We only support fp16 to fp4 conversion.
160-
if self._quant_config and self._quant_config.get("quant_algo", None) == "NVFP4":
161-
self.model_kwargs["torch_dtype"] = torch.half
162157

163158
# NOTE (lucaslie): HF doesn't recursively update nested PreTrainedConfig objects. Instead,
164159
# the entire subconfig will be overwritten.
@@ -178,23 +173,24 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
178173
model.forward = types.MethodType(self._simple_forward, model)
179174

180175
model.eval()
176+
181177
return model
182178

183179
def get_quant_config(self) -> Dict:
184-
return self._quant_config or {}
180+
"""Returns the quantization config for this model or None if not quantized."""
181+
if self._quant_config_reader is not None:
182+
return self._quant_config_reader.get_config()
183+
return {}
185184

186185
def get_cache_config(self):
187-
"""Setup cache information based on quantization information."""
188-
if self._quant_config is not None and "kv_cache_quant_algo" in self._quant_config.keys():
189-
kv_cache_format = self._quant_config.get("kv_cache_quant_algo", None)
190-
if kv_cache_format is not None:
191-
assert kv_cache_format == "FP8", (
192-
f"KV cache quantization format {kv_cache_format} is not supported."
193-
)
194-
kv_cache_dtype = torch.float8_e4m3fn if kv_cache_format is not None else None
195-
else:
196-
kv_cache_dtype = None
197-
return CacheConfig(dtype=kv_cache_dtype)
186+
"""Return kv cache dtype configuration."""
187+
if not self._quant_config_reader:
188+
return CacheConfig(dtype=None)
189+
190+
kv_cache_dtype = self._quant_config_reader.get_config().get("kv_cache_dtype")
191+
torch_dtype = {"float8_e4m3fn": torch.float8_e4m3fn}.get(kv_cache_dtype, None)
192+
193+
return CacheConfig(dtype=torch_dtype)
198194

199195
def init_tokenizer(self) -> Optional[Any]:
200196
"""Initialize the tokenizer—either a custom name or the model's default."""
@@ -325,22 +321,18 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
325321

326322
def _load_quantization_config(self, fetched_dir: str):
327323
"""Load the quantization config from the model directory if not done already."""
328-
if self._quant_config is not None:
324+
if self._quant_config_reader is not None:
325+
return
326+
# TODO: specified by user or auto-detect
327+
reader_cls = QuantConfigReaderRegistry.get("modelopt")
328+
result = reader_cls.from_file(fetched_dir)
329+
if result is None:
329330
return
331+
reader, extra_model_kwargs = result
330332

331-
assert self.model
332-
hf_quant_config_file = os.path.join(fetched_dir, "hf_quant_config.json")
333-
if os.path.exists(hf_quant_config_file):
334-
with open(hf_quant_config_file, "r") as file:
335-
quantization_config = json.load(file)
336-
assert quantization_config.get("producer", {}).get("name", None) == "modelopt", (
337-
"Only support modelopt quantized checkpoint"
338-
)
339-
self._quant_config = quantization_config.get("quantization", {})
340-
341-
# We do not quantize lm_head.
342-
if "exclude_modules" not in self._quant_config:
343-
self._quant_config["exclude_modules"] = ["lm_head"]
333+
if reader is not None:
334+
self._quant_config_reader = reader
335+
self.model_kwargs = deep_merge_dicts(self.model_kwargs, extra_model_kwargs)
344336

345337

346338
@ModelFactoryRegistry.register("AutoModelForImageTextToText")
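
As a small illustration of the new get_cache_config() behavior above: the reader stores the kv cache dtype as a string, and the factory maps it to a torch dtype with a lookup table, falling back to None for unquantized models. A minimal standalone sketch of that mapping, using the values visible in the diff:

import torch

kv_cache_dtype = "float8_e4m3fn"  # written by the modelopt reader when kv_cache_quant_algo == "FP8"
torch_dtype = {"float8_e4m3fn": torch.float8_e4m3fn}.get(kv_cache_dtype, None)
assert torch_dtype is torch.float8_e4m3fn

# Any other value, or a missing key for unquantized models, maps to None:
assert {"float8_e4m3fn": torch.float8_e4m3fn}.get(None, None) is None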
tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py

Lines changed: 130 additions & 0 deletions

@@ -0,0 +1,130 @@
+"""
+Quantization Config Reader Registry.
+
+This module defines a registry system for parsing quantization configurations
+from various sources (e.g., 'modelopt'). It enables extensible support for different
+quantization producers by delegating parsing logic to dedicated subclasses.
+"""
+
+import json
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, Optional, Tuple, Type
+
+import torch
+
+
+class QuantConfigReader(ABC):
+    """Base class for reading and parsing quantization config."""
+
+    def __init__(self):
+        self._quant_config: Optional[Dict] = None
+
+    def get_config(self) -> Dict:
+        """Return the parsed quantization config."""
+        return self._quant_config or {}
+
+    @abstractmethod
+    def read_config(self, config: Dict) -> Dict:
+        """
+        Parse and normalize a quantization config dictionary.
+
+        Args:
+            config: The raw "quantization" field from the JSON file.
+
+        Returns:
+            A processed and normalized config dictionary.
+        """
+        pass
+
+    @classmethod
+    @abstractmethod
+    def from_file(cls, file_path: str) -> Optional["QuantConfigReader"]:
+        """
+        Load and parse a quantization config file from disk.
+
+        This method is implemented by each reader to handle loading and parsing logic.
+
+        Args:
+            file_path: Path to the quant config JSON file.
+
+        Returns:
+            An initialized QuantConfigReader instance, or None if the file doesn't exist.
+        """
+        pass
+
+
+class QuantConfigReaderRegistry:
+    _registry: Dict[str, Type[QuantConfigReader]] = {}
+
+    @classmethod
+    def register(cls, name: str) -> Callable[[Type[QuantConfigReader]], Type[QuantConfigReader]]:
+        def inner(reader_cls: Type[QuantConfigReader]) -> Type[QuantConfigReader]:
+            cls._registry[name] = reader_cls
+            return reader_cls
+
+        return inner
+
+    @classmethod
+    def get(cls, name: str) -> Type[QuantConfigReader]:
+        if name not in cls._registry:
+            raise ValueError(f"QuantConfigReader for '{name}' not registered.")
+        return cls._registry[name]
+
+    @classmethod
+    def has(cls, reader_cls: str) -> bool:
+        return reader_cls in cls._registry
+
+
+@QuantConfigReaderRegistry.register("modelopt")
+class ModelOPTQuantConfigReader(QuantConfigReader):
+    def read_config(self, config: Dict) -> Dict:
+        # Inject default exclusion
+        config.setdefault("exclude_modules", ["lm_head"])
+
+        # Update dtype
+        if config.get("quant_algo") == "NVFP4":
+            config["torch_dtype"] = "float16"
+
+        # Handle kv cache
+        kv_algo = config.get("kv_cache_quant_algo")
+        if kv_algo:
+            if kv_algo != "FP8":
+                raise ValueError(f"KV cache quantization format {kv_algo} not supported.")
+            config["kv_cache_dtype"] = "float8_e4m3fn"
+
+        self._quant_config = config
+        return self._quant_config
+
+    @classmethod
+    def from_file(
+        cls, ckpt_dir: str
+    ) -> Optional[Tuple["ModelOPTQuantConfigReader", Optional[torch.dtype]]]:
+        """
+        Load and parse a modelopt-style quantization config from a checkpoint directory.
+
+        Args:
+            ckpt_dir: Path to the root directory containing the checkpoint.
+
+        Returns:
+            An initialized ModelOPTQuantConfigReader instance, or None if the file doesn't exist.
+        """
+        quant_file = os.path.join(ckpt_dir, "hf_quant_config.json")
+        if not os.path.exists(quant_file):
+            return None
+
+        with open(quant_file, "r") as f:
+            raw = json.load(f)
+
+        producer = raw.get("producer", {}).get("name")
+        # sanity check
+        if producer != "modelopt":
+            raise ValueError(f"Expected producer 'modelopt', got '{producer}'")
+
+        quant_config = raw.get("quantization", {})
+        reader = cls()
+        reader.read_config(quant_config)
+        extra_model_kwargs: Dict[str, Any] = {}
+        if quant_config and quant_config.get("quant_algo", None) == "NVFP4":
+            extra_model_kwargs["torch_dtype"] = "float16"
+        return reader, extra_model_kwargs
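
A usage sketch for the new reader, assuming it is importable from the models package referenced in hf.py; the JSON below is the minimal layout the parser expects, and real ModelOpt checkpoints may carry additional fields:

import json
import os
import tempfile

from tensorrt_llm._torch.auto_deploy.models.quant_config_reader import ModelOPTQuantConfigReader

example = {
    "producer": {"name": "modelopt"},
    "quantization": {"quant_algo": "FP8", "kv_cache_quant_algo": "FP8"},
}

with tempfile.TemporaryDirectory() as ckpt_dir:
    with open(os.path.join(ckpt_dir, "hf_quant_config.json"), "w") as f:
        json.dump(example, f)

    reader, extra_model_kwargs = ModelOPTQuantConfigReader.from_file(ckpt_dir)
    cfg = reader.get_config()
    assert cfg["exclude_modules"] == ["lm_head"]     # default exclusion injected by read_config()
    assert cfg["kv_cache_dtype"] == "float8_e4m3fn"  # derived from kv_cache_quant_algo == "FP8"
    assert extra_model_kwargs == {}                  # torch_dtype is only added for NVFP4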

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py

Lines changed: 57 additions & 38 deletions
@@ -1,6 +1,5 @@
-from collections import defaultdict
 from functools import partial
-from typing import Dict, Tuple
+from typing import Tuple
 
 import torch.nn as nn
 from torch.fx import GraphModule, Node
@@ -166,67 +165,87 @@ def get_scale_name(scale_name):
         node.args = (*node.args, *scale_values)
 
 
-@TransformRegistry.register("quantize")
-class Quantization(BaseTransform):
-    """Quantize the GraphModule and replace linear/BMM with quantized linear/BMM."""
+@TransformRegistry.register("quantize_from_config")
+class QuantizationFromConfig(BaseTransform):
+    """
+    Quantize linear and BMM ops using a quantization config.
+
+    Replaces eligible ops with quantized equivalents based on the quantization algorithm
+    and exclude patterns defined in the config.
+    """
 
     def _apply(
         self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
     ) -> Tuple[GraphModule, TransformInfo]:
-        # extract info from quant_config
         quant_config = factory.get_quant_config()
-        if not quant_config:
+        quant_algo = quant_config.get("quant_algo")
+        excluded_patterns = quant_config.get("exclude_modules", [])
+
+        if not quant_config or not quant_algo:
             return gm, TransformInfo(
                 skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
             )
 
+        num_matches = 0
+
+        for n in gm.graph.nodes:
+            if should_skip_quantization(n, excluded_patterns):
+                continue
+
+            if is_linear_op(n, include_quantization=False):
+                impl = QuantizationImpl.create(quant_algo, is_bmm=False)
+                _insert_quantized_linear(gm, n, impl, False)
+                num_matches += 1
+
+            elif is_bmm_op(n):
+                impl = QuantizationImpl.create(quant_algo, is_bmm=True)
+                _insert_quantized_bmm(gm, n, impl, False)
+                num_matches += 1
+
+        info = TransformInfo(
+            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
+        )
+
+        return gm, info
+
+
+@TransformRegistry.register("quantize_from_graph")
+class QuantizationFromGraph(BaseTransform):
+    """
+    Replace ModelOpt-quantized linear ops with fused quantized implementations.
+
+    Detects quantized nodes in the graph of a ModelOpt checkpoint and replaces them with
+    fused linear ops based on the quantization type.
+    """
+
+    def _apply(
+        self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
+    ) -> Tuple[GraphModule, TransformInfo]:
        is_quant_graph = is_quantized_graph(gm)
-        quant_algo = quant_config.get("quant_algo")
-        excluded_patterns = quant_config.get("exclude_modules", [])
-        if not quant_algo:
+
+        # no quantization to do
+        if not is_quant_graph:
            return gm, TransformInfo(
                 skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
             )
 
         # tracking quantized operations in the graph
-        quantized_nodes: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
+        num_matches = 0
         for n in gm.graph.nodes:
-            if should_skip_quantization(n, excluded_patterns):
-                continue
-
             # Process linear operations
             if is_linear_op(n, include_quantization=False):
                 # get per-layer quantization format from the node
-                quant_algo_n: str = (
-                    get_quantization_from_linear_node(n) if is_quant_graph else quant_algo
-                )
+                quant_algo_n: str = get_quantization_from_linear_node(n)
                 if not quant_algo_n:
                     continue
 
                 # insert quantized linear node
-                _insert_quantized_linear(
-                    gm, n, QuantizationImpl.create(quant_algo_n), is_quant_graph
-                )
-                quantized_nodes[quant_algo_n]["linear"] += 1
+                _insert_quantized_linear(gm, n, QuantizationImpl.create(quant_algo_n), True)
+                num_matches += 1
 
-            # Process BMM operations
-            elif is_bmm_op(n):
-                if not quant_algo:
-                    continue
-
-                # insert quantized bmm node
-                _insert_quantized_bmm(
-                    gm, n, QuantizationImpl.create(quant_algo, is_bmm=True), is_quant_graph
-                )
-                quantized_nodes[quant_algo]["bmm"] += 1
+            # TODO: check whether quantized BMM also needs a graph-based pass
 
-        if is_quant_graph:
-            remove_output_quantizers(gm)
-
-        num_matches = 0
-        for quant_algo in quantized_nodes:
-            for op_type, count in quantized_nodes[quant_algo].items():
-                num_matches += count
+        remove_output_quantizers(gm)
 
         info = TransformInfo(
             skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
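
The split above means quantize_from_config keys off factory.get_quant_config(), while quantize_from_graph keys off quantizer nodes already present in the exported graph. For reference, a hypothetical skeleton of a transform following the same registration pattern; it would live alongside the transforms in this file and reuse its imports, and the registry key, class name, and matching logic are illustrative only:

@TransformRegistry.register("my_custom_quantize")  # illustrative key, not part of this commit
class MyCustomQuantize(BaseTransform):
    """Skeleton showing how a quantization-style transform plugs into the registry."""

    def _apply(
        self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
    ) -> Tuple[GraphModule, TransformInfo]:
        quant_config = factory.get_quant_config()
        if not quant_config:
            # Nothing to do: report a skipped, clean pass so the pipeline moves on.
            return gm, TransformInfo(
                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
            )

        num_matches = 0
        # ... rewrite matching nodes in gm.graph here, incrementing num_matches ...
        return gm, TransformInfo(
            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
        )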

tensorrt_llm/_torch/auto_deploy/transformations/transform.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ def __call__(self, cm: CachedSequenceInterface) -> nn.Module:
         ############################################################################################
 
         # Match MoE pattern
+        # TODO:remove quantized linear handling inside this transformation
         match_moe_pattern(egm)
 
         ############################################################################################
