
Commit 302d4b1 (parent: 952e012)

fix

Signed-off-by: Superjomn <[email protected]>

38 files changed: +194 −175 lines
docs/source/torch.md

Lines changed: 3 additions & 3 deletions
@@ -11,7 +11,7 @@ The PyTorch backend of TensorRT-LLM is available in version 0.17 and later. You
 
 ## Quick Start
 
-Here is a simple example to show how to use `tensorrt_llm._torch.LLM` API with Llama model.
+Here is a simple example to show how to use `tensorrt_llm.LLM` API with Llama model.
 
 ```{literalinclude} ../../examples/pytorch/quickstart.py
 :language: python
@@ -24,7 +24,7 @@ The PyTorch backend supports FP8 and NVFP4 quantization. You can pass quantized
 which are generated by [TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer).
 
 ```python
-from tensorrt_llm._torch import LLM
+from tensorrt_llm import LLM
 llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8')
 llm.generate("Hello, my name is")
 ```
@@ -44,7 +44,7 @@ The PyTorch backend supports most of the sampling features that are supported on
 In order to use this feature, it is necessary to enable option `enable_trtllm_sampler` in the `LLM` class, and pass a `SamplingParams` object with the desired options as well. The following example prepares two identical prompts which will give different results due to the sampling parameters chosen:
 
 ```python
-from tensorrt_llm._torch import LLM
+from tensorrt_llm import LLM
 llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
           enable_trtllm_sampler=True)
 sampling_params = SamplingParams(
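
For reference, the unified import composes with the sampler option roughly like this (a minimal sketch assembled from the snippets above; the sampling values and prompts are illustrative, not from the commit):

```python
from tensorrt_llm import LLM, SamplingParams

# The PyTorch backend is the default for tensorrt_llm.LLM after this change.
llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
          enable_trtllm_sampler=True)

# Illustrative sampling settings; two identical prompts may diverge.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)
for output in llm.generate(["Hello, my name is", "Hello, my name is"],
                           sampling_params):
    print(output.outputs[0].text)
```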

docs/source/torch/adding_new_model.md

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ __all__ = [
 Alternatively, you can register the new model as an out-of-tree model, so that you can use the new model without touching the TensorRT-LLM codebase. To do so, place `modeling_mymodel.py` (and potentially `configuration_mymodel.py`) in your working directory, and import the modeling code in your script:
 
 ```python
-from tensorrt_llm._torch import LLM
+from tensorrt_llm import LLM
 import modeling_mymodel
 
 def main():
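
The full out-of-tree flow then looks roughly like this (a sketch; `modeling_mymodel` and the checkpoint path are the doc's placeholders, not real modules):

```python
import modeling_mymodel  # noqa: F401 -- registers the custom model on import

from tensorrt_llm import LLM


def main():
    # The LLM resolves the architecture registered by modeling_mymodel.
    llm = LLM(model='path/to/mymodel_checkpoint')
    print(llm.generate("Hello, my name is"))


if __name__ == '__main__':
    main()
```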

docs/source/torch/arch_overview.md

Lines changed: 2 additions & 2 deletions
@@ -5,10 +5,10 @@ Besides TensorRT, PyTorch can also serve as the backend for TensorRT-LLM. This d
 
 ## Top Level API
 
-The interface for PyTorch backend is `tensorrt._torch.LLM`.
+The interface for PyTorch backend is `tensorrt_llm.LLM`.
 
 ```python
-from tensorrt_llm._torch import LLM
+from tensorrt_llm import LLM
 llm = LLM(model=<path_to_llama_from_hf>)
 ```
 
examples/pytorch/out_of_tree_example/main.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import modeling_opt  # noqa
 
-from tensorrt_llm._torch import LLM
+from tensorrt_llm import LLM
 
 
 def main():

examples/pytorch/quickstart.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
-from tensorrt_llm import SamplingParams
-from tensorrt_llm._torch import LLM
+from tensorrt_llm import LLM, SamplingParams
 
 
 def main():

examples/pytorch/quickstart_advanced.py

Lines changed: 1 addition & 2 deletions
@@ -1,7 +1,6 @@
 import argparse
 
-from tensorrt_llm import SamplingParams
-from tensorrt_llm._torch import LLM
+from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm.llmapi import (DraftTargetDecodingConfig, EagleDecodingConfig,
                                  KvCacheConfig, MTPDecodingConfig,
                                  NGramDecodingConfig, TorchCompileConfig)

examples/pytorch/star_attention.py

Lines changed: 1 addition & 2 deletions
@@ -6,8 +6,7 @@
 
 import torch
 
-from tensorrt_llm import SamplingParams
-from tensorrt_llm._torch import LLM
+from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
 
 
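
These quantization imports compose with the unified LLM import roughly as follows (a hypothetical sketch, not code from this commit; whether `quant_config` is accepted this way depends on the backend's argument set):

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig

# Hypothetical: request FP8 quantization when loading the checkpoint.
quant_config = QuantConfig(quant_algo=QuantAlgo.FP8)
llm = LLM(model='path/to/model', quant_config=quant_config)
print(llm.generate("Hello, my name is", SamplingParams(max_tokens=16)))
```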

tensorrt_llm/_torch/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
 from .llm import LLM
+from .model_config import MoeLoadBalancerConfig
 
-__all__ = ["LLM"]
+__all__ = ["LLM", "MoeLoadBalancerConfig"]

tensorrt_llm/bench/build/build.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 from tensorrt_llm.bench.utils.data import create_dataset_from_stream, initialize_tokenizer
 from tensorrt_llm.bench.utils import VALID_QUANT_ALGOS
 from tensorrt_llm.builder import BuildConfig
-from tensorrt_llm.llmapi import LLM
+from tensorrt_llm._tensorrt_engine import LLM
 from tensorrt_llm.llmapi.llm_utils import QuantConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.mode import QuantAlgo
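
The bench build path keeps the TensorRT engine flavor of `LLM`, now reached through the private `_tensorrt_engine` module. A sketch of the resulting usage (model path and build options are illustrative, and `build_config` is assumed to remain a TensorRT-backend argument):

```python
from tensorrt_llm._tensorrt_engine import LLM  # TensorRT engine backend
from tensorrt_llm.builder import BuildConfig

# Illustrative engine-build configuration, as in the bench flow.
build_config = BuildConfig(max_batch_size=8)
llm = LLM(model='path/to/hf_model', build_config=build_config)
```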

tensorrt_llm/llmapi/llm.py

Lines changed: 24 additions & 0 deletions
@@ -97,6 +97,7 @@ def _repr_fields(self):
 
     Attributes:
         tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any.
+        llm_id (str): The unique ID of the LLM instance.
     """
 
 
@@ -883,6 +884,9 @@ def __init__(self,
         # TODO: deprecate backend in LLM kwargs
         kwargs.pop("backend", None)
 
+        # Validate that users don't pass TrtLlmArgs-specific arguments
+        self._validate_args_for_torch_backend(kwargs)
+
         super().__init__(model,
                          tokenizer,
                          tokenizer_mode,
@@ -895,6 +899,26 @@ def __init__(self,
                          backend='pytorch',
                          **kwargs)
 
+    def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
+        """Validate that users don't pass TrtLlmArgs-specific arguments when using PyTorch backend.
+        """
+        trtllm_fields = set(TrtLlmArgs.model_fields.keys())
+        torchllm_fields = set(TorchLlmArgs.model_fields.keys())
+
+        trtllm_specific_fields = trtllm_fields - torchllm_fields
+
+        # Check if any TrtLlmArgs-specific arguments are passed
+        trtllm_specific_args = []
+        for key in kwargs:
+            if key in trtllm_specific_fields:
+                trtllm_specific_args.append(key)
+
+        if trtllm_specific_args:
+            raise ValueError(
+                f"The following arguments are specific to TensorRT backend and cannot be used with PyTorch backend: {trtllm_specific_args}.\n"
+                f"Please use 'from tensorrt_llm._tensorrt_engine import LLM' instead to use the TensorRT backend."
+            )
+
 
 class LLM(_TorchLLM):
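
The net effect of the new validation, sketched below (assumes `build_config` is one of the TensorRT-only fields; the exact rejected set is computed at runtime as `TrtLlmArgs.model_fields` minus `TorchLlmArgs.model_fields`):

```python
from tensorrt_llm import LLM  # PyTorch-backend LLM

try:
    # Passing an assumed TensorRT-only argument now fails fast;
    # presence of the key is enough, whatever its value.
    llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8', build_config=None)
except ValueError as err:
    print(err)  # directs users to tensorrt_llm._tensorrt_engine.LLM
```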
