diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 27b47acea39..a4ce0092a0b 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -1,11 +1,14 @@
+import contextlib
 import json
 import os
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Dict, Generic, List, Optional, TypeVar
 
+import filelock
 import torch
 import transformers
+from transformers.utils import HF_MODULES_CACHE
 
 from tensorrt_llm import logger
 from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
@@ -58,6 +61,35 @@ def get_layer_initial_global_assignments(self, layer_idx: int) -> List[int]:
         return None
 
 
+@contextlib.contextmanager
+def config_file_lock(timeout: int = 10):
+    """
+    Context manager for file locking when loading pretrained configs.
+
+    This prevents race conditions when multiple processes try to download/load
+    the same model configuration simultaneously.
+
+    Args:
+        timeout: Maximum time to wait for lock acquisition in seconds
+    """
+    # Use a single global lock file in the HF cache directory.
+    # This serializes all model loading operations to prevent race conditions.
+    lock_path = Path(HF_MODULES_CACHE) / "_remote_code.lock"
+
+    # Create and acquire the lock
+    lock = filelock.FileLock(str(lock_path), timeout=timeout)
+
+    try:
+        with lock:
+            yield
+    except filelock.Timeout:
+        logger.warning(
+            f"Failed to acquire config lock within {timeout} seconds, proceeding without lock"
+        )
+        # Fallback: proceed without locking to avoid blocking indefinitely
+        yield
+
+
 @dataclass(kw_only=True)
 class ModelConfig(Generic[TConfig]):
     pretrained_config: Optional[TConfig] = None
@@ -358,16 +390,20 @@ def from_pretrained(cls,
                         checkpoint_dir: str,
                         trust_remote_code=False,
                         **kwargs):
-        pretrained_config = transformers.AutoConfig.from_pretrained(
-            checkpoint_dir,
-            trust_remote_code=trust_remote_code,
-        )
+        # Use a file lock to prevent race conditions when multiple processes
+        # try to import/cache the same remote model config file.
+        with config_file_lock():
+            pretrained_config = transformers.AutoConfig.from_pretrained(
+                checkpoint_dir,
+                trust_remote_code=trust_remote_code,
+            )
+
+            # Find the cache path by looking for the config.json file, which
+            # should be present in all Hugging Face models.
+            model_dir = Path(
+                transformers.utils.hub.cached_file(checkpoint_dir,
+                                                   'config.json')).parent
 
-        # Find the cache path by looking for the config.json file which should be in all
-        # huggingface models
-        model_dir = Path(
-            transformers.utils.hub.cached_file(checkpoint_dir,
-                                               'config.json')).parent
         quant_config = QuantConfig()
         layer_quant_config = None
         moe_backend = kwargs.get('moe_backend', 'CUTLASS')
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index a4951a2012c..ddbd3eb3e91 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -281,8 +281,6 @@ triton_server/test_triton.py::test_t5_ib[t5-ib] SKIP (https://nvbugs/5456482)
 triton_server/test_triton_llm.py::test_gpt_speculative_decoding_bls[False-False-1---False-True-True-0-128-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-guaranteed_no_evict---1-1-1-False-ensemble] SKIP (https://nvbugs/5456485)
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
 accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] SKIP (https://nvbugs/5437384)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen] SKIP (https://nvbugs/5445466)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency] SKIP (https://nvbugs/5445466)
 llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_mtp SKIP (https://nvbugs/5461796)
 accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_gather_generation_logits_cuda_graph SKIP (https://nvbugs/5365525)
 examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3-mini-128k-instruct] SKIP (https://nvbugs/5465143)
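
Note: the patch serializes concurrent config loads through one well-known lock file, warning and falling back to unlocked execution if acquisition times out. Below is a minimal self-contained sketch of that pattern, assuming only that the `filelock` package is installed; `demo_config_lock`, the lock directory, and the placeholder body are illustrative names, not part of the patch.

import contextlib
import tempfile
from pathlib import Path

import filelock


@contextlib.contextmanager
def demo_config_lock(lock_dir: str = tempfile.gettempdir(), timeout: int = 10):
    # One lock file shared by every process on the host serializes the
    # critical section across processes, not just threads.
    lock = filelock.FileLock(str(Path(lock_dir) / "_demo_config.lock"),
                             timeout=timeout)
    try:
        with lock:
            yield
    except filelock.Timeout:
        # Same fallback as the patch: warn and continue rather than block
        # indefinitely on a lock another process never releases.
        print(f"could not acquire lock within {timeout}s, proceeding unlocked")
        yield


# Usage: at most one process at a time runs the body; a second process blocks
# on FileLock until the first releases it or `timeout` expires.
with demo_config_lock():
    pass  # e.g. download/import a remote model config here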