54 changes: 45 additions & 9 deletions tensorrt_llm/_torch/model_config.py
@@ -1,11 +1,14 @@
import contextlib
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Generic, List, Optional, TypeVar

import filelock
import torch
import transformers
from transformers.utils import HF_MODULES_CACHE

from tensorrt_llm import logger
from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
@@ -58,6 +61,35 @@ def get_layer_initial_global_assignments(self, layer_idx: int) -> List[int]:
return None


@contextlib.contextmanager
def config_file_lock(timeout: int = 10):
"""
Context manager for file locking when loading pretrained configs.

This prevents race conditions when multiple processes try to download/load
the same model configuration simultaneously.

Args:
timeout: Maximum time to wait for lock acquisition in seconds
"""
# Use a single global lock file in HF cache directory
# This serializes all model loading operations to prevent race conditions
lock_path = Path(HF_MODULES_CACHE) / "_remote_code.lock"

# Create and acquire the lock
lock = filelock.FileLock(str(lock_path), timeout=timeout)

try:
with lock:
yield
except filelock.Timeout:
logger.warning(
f"Failed to acquire config lock within {timeout} seconds, proceeding without lock"
)
# Fallback: proceed without locking to avoid blocking indefinitely
yield


@dataclass(kw_only=True)
class ModelConfig(Generic[TConfig]):
pretrained_config: Optional[TConfig] = None
@@ -358,16 +390,20 @@ def from_pretrained(cls,
checkpoint_dir: str,
trust_remote_code=False,
**kwargs):
pretrained_config = transformers.AutoConfig.from_pretrained(
checkpoint_dir,
trust_remote_code=trust_remote_code,
)
# Use file lock to prevent race conditions when multiple processes
# try to import/cache the same remote model config file
with config_file_lock():
pretrained_config = transformers.AutoConfig.from_pretrained(
checkpoint_dir,
trust_remote_code=trust_remote_code,
)

# Find the cache path by looking for the config.json file which should be in all
# huggingface models
model_dir = Path(
transformers.utils.hub.cached_file(checkpoint_dir,
'config.json')).parent

# Find the cache path by looking for the config.json file which should be in all
# huggingface models
model_dir = Path(
transformers.utils.hub.cached_file(checkpoint_dir,
'config.json')).parent
quant_config = QuantConfig()
layer_quant_config = None
moe_backend = kwargs.get('moe_backend', 'CUTLASS')
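The new config_file_lock helper and its use in from_pretrained follow a common cross-process serialization pattern: acquire a filelock.FileLock with a timeout, run the critical section under the lock, and fall back to running unlocked if acquisition times out rather than blocking forever. Below is a minimal, self-contained sketch of that pattern. It assumes only that the filelock package is installed; the lock path, helper name, and load_config stand-in are illustrative and not taken from this PR.

import contextlib
import tempfile
from pathlib import Path

import filelock


@contextlib.contextmanager
def serialized_section(lock_name: str = "demo_config.lock", timeout: int = 10):
    """Serialize a critical section across processes using a file lock.

    If the lock cannot be acquired within `timeout` seconds, proceed
    without it so callers are never blocked indefinitely (the same
    fallback behavior as config_file_lock in the diff above).
    """
    lock_path = Path(tempfile.gettempdir()) / lock_name
    lock = filelock.FileLock(str(lock_path), timeout=timeout)
    try:
        with lock:
            yield
    except filelock.Timeout:
        # Could not get the lock in time; run the section unlocked.
        yield


def load_config(checkpoint_dir: str) -> dict:
    # Stand-in for transformers.AutoConfig.from_pretrained(...).
    return {"checkpoint_dir": checkpoint_dir}


if __name__ == "__main__":
    # Only one process at a time runs the body (best effort): concurrent
    # callers either wait up to `timeout` seconds or proceed unlocked.
    with serialized_section():
        cfg = load_config("some/checkpoint-dir")
        print(cfg)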
2 changes: 0 additions & 2 deletions tests/integration/test_lists/waives.txt
@@ -281,8 +281,6 @@ triton_server/test_triton.py::test_t5_ib[t5-ib] SKIP (https://nvbugs/5456482)
triton_server/test_triton_llm.py::test_gpt_speculative_decoding_bls[False-False-1---False-True-True-0-128-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-guaranteed_no_evict---1-1-1-False-ensemble] SKIP (https://nvbugs/5456485)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] SKIP (https://nvbugs/5437384)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen] SKIP (https://nvbugs/5445466)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency] SKIP (https://nvbugs/5445466)
llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_mtp SKIP (https://nvbugs/5461796)
accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_gather_generation_logits_cuda_graph SKIP (https://nvbugs/5365525)
examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3-mini-128k-instruct] SKIP (https://nvbugs/5465143)