Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from typing import Any, Dict, List, Literal, Optional, Type, Union

import torch
from pydantic import Field, ValidationInfo, field_validator, model_validator
from pydantic import Field, PrivateAttr, ValidationInfo, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

from tensorrt_llm.models.modeling_utils import QuantConfig

from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, _ParallelConfig
from ...llmapi.utils import get_type_repr
from .models import ModelFactory, ModelFactoryRegistry
Expand Down Expand Up @@ -259,6 +261,18 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
)
garbage_collection_gen0_threshold: int = Field(default=20000, description="See TorchLlmArgs.")

_quant_config: Optional[QuantConfig] = PrivateAttr(default=None)

@property
def quant_config(self) -> QuantConfig:
    """Return the quantization config held in ``_quant_config``.

    If nothing has been stored yet (the private attribute is ``None``), a
    default-constructed ``QuantConfig`` is created, stored, and returned, so
    repeated reads yield the stored object.
    """
    if self._quant_config is not None:
        return self._quant_config

    # First access: create the default config lazily and keep it.
    self._quant_config = QuantConfig()
    return self._quant_config

@quant_config.setter
def quant_config(self, value: QuantConfig):
    """Store ``value`` as the quantization config, replacing any previous one."""
    self._quant_config = value

### VALIDATION #################################################################################
@field_validator("build_config", mode="before")
@classmethod
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/defs/accuracy/accuracy_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import tensorrt_llm.evaluate
from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
from tensorrt_llm.builder import BuildConfig
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig
Expand Down Expand Up @@ -144,7 +145,7 @@ def get_num_samples_and_threshold(self, **acc_specs):
return num_samples, threshold

def evaluate(self,
llm: Union[LLM, PyTorchLLM],
llm: Union[LLM, PyTorchLLM, AutoDeployLLM],
extra_acc_spec: Optional[str] = None,
extra_evaluator_kwargs: Optional[dict] = None,
sampling_params: Optional[SamplingParams] = None,
Expand Down
66 changes: 66 additions & 0 deletions tests/integration/defs/accuracy/test_llm_api_autodeploy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
from tensorrt_llm.sampling_params import SamplingParams

from ..conftest import llm_models_root
from .accuracy_core import MMLU, CnnDailymail, LlmapiAccuracyTestHarness


class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
    """Accuracy test for Llama-3.1-8B run through the AutoDeploy ``LLM`` API."""

    MODEL_NAME = "meta-llama/Llama-3.1-8B"
    MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Meta-Llama-3.1-8B"

    def get_default_kwargs(self):
        """Return the keyword arguments used to construct ``AutoDeployLLM``."""
        kv_cache_config = dict(enable_block_reuse=False)
        # Powers of two from 1 through 256.
        cuda_graph_batch_sizes = [1 << shift for shift in range(9)]

        return dict(
            skip_tokenizer_init=False,
            trust_remote_code=True,
            kv_cache_config=kv_cache_config,
            max_batch_size=512,
            # The model itself supports sequences up to 131072 tokens; the
            # test limits it to 8192.
            max_seq_len=8192,
            # max_num_tokens is normally derived in build_config, which the
            # AutoDeploy llm args do not use, so it is set explicitly here to
            # 8192, the build_config default.
            max_num_tokens=8192,
            skip_loading_weights=False,
            compile_backend='torch-opt',
            free_mem_ratio=0.7,
            cuda_graph_batch_sizes=cuda_graph_batch_sizes,
        )

    def get_default_sampling_params(self):
        """Return ``SamplingParams`` with end/pad id -1 and a beam width of 1."""
        beam_width = 1
        eos_id = -1
        # With a beam width of 1 this evaluates to False.
        use_beam_search = beam_width > 1
        return SamplingParams(
            end_id=eos_id,
            pad_id=eos_id,
            n=beam_width,
            use_beam_search=use_beam_search,
        )

    @pytest.mark.skip_less_device_memory(32000)
    def test_auto_dtype(self):
        """Evaluate CnnDailymail and then MMLU on one AutoDeploy LLM instance."""
        llm_kwargs = self.get_default_kwargs()
        sampling_params = self.get_default_sampling_params()

        model_path = self.MODEL_PATH
        with AutoDeployLLM(model=model_path, tokenizer=model_path,
                           **llm_kwargs) as llm:
            # CnnDailymail runs with the task's own sampling defaults.
            CnnDailymail(self.MODEL_NAME).evaluate(llm)
            # MMLU is given the explicit sampling parameters built above.
            MMLU(self.MODEL_NAME).evaluate(llm,
                                           sampling_params=sampling_params)