vllm/config.py (6 additions, 1 deletion)

@@ -27,7 +27,8 @@
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        print_warning_once, resolve_obj_by_qualname)
+                        print_warning_once, random_uuid,
+                        resolve_obj_by_qualname)

 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -2408,6 +2409,7 @@ class VllmConfig:
                                                  init=True)  # type: ignore
     kv_transfer_config: KVTransferConfig = field(default=None,
                                                  init=True)  # type: ignore
+    instance_id: str = ""

     @staticmethod
     def get_graph_batch_size(batch_size: int) -> int:
@@ -2565,6 +2567,9 @@ def __post_init__(self):

         current_platform.check_and_update_config(self)

+        if not self.instance_id:
+            self.instance_id = random_uuid()[:5]
+
     def __str__(self):
         return ("model=%r, speculative_config=%r, tokenizer=%r, "
                 "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
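With this change, VllmConfig owns the instance id, defaulting to a short random one. A minimal runnable sketch of the new behavior, using a simplified stand-in class since constructing the real VllmConfig requires many sub-configs:

import uuid
from dataclasses import dataclass

def random_uuid() -> str:
    # Mirrors vllm.utils.random_uuid: the hex form of a UUID4.
    return str(uuid.uuid4().hex)

@dataclass
class ConfigSketch:
    # Simplified stand-in for VllmConfig; only the new field is modeled.
    instance_id: str = ""

    def __post_init__(self) -> None:
        # Same scheme as the PR: first 5 hex characters of a random UUID.
        if not self.instance_id:
            self.instance_id = random_uuid()[:5]

print(ConfigSketch().instance_id)                   # e.g. '3f9a1'
print(ConfigSketch(instance_id="abc").instance_id)  # an explicit id is kept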
vllm/envs.py (0 additions, 6 deletions)

@@ -8,7 +8,6 @@
     VLLM_RPC_BASE_PATH: str = tempfile.gettempdir()
     VLLM_USE_MODELSCOPE: bool = False
     VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
-    VLLM_INSTANCE_ID: Optional[str] = None
     VLLM_NCCL_SO_PATH: Optional[str] = None
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = False
@@ -175,11 +174,6 @@ def get_default_config_root():
     "VLLM_USE_MODELSCOPE":
     lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true",

-    # Instance id represents an instance of the VLLM. All processes in the same
-    # instance should have the same instance id.
-    "VLLM_INSTANCE_ID":
-    lambda: os.environ.get("VLLM_INSTANCE_ID", None),
-
     # Interval in seconds to log a warning message when the ring buffer is full
     "VLLM_RINGBUFFER_WARNING_INTERVAL":
     lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")),
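Context for the deletion: vllm/envs.py keeps a dict of zero-argument lambdas and resolves them lazily through a module-level __getattr__, so dropping VLLM_INSTANCE_ID is just removing its entry. A minimal sketch of that pattern (simplified, not the full module):

import os
from typing import Any, Callable, Dict

# Each entry maps a variable name to a lambda that reads the environment
# at access time, not at import time.
environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_RINGBUFFER_WARNING_INTERVAL":
    lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")),
}

def __getattr__(name: str) -> Any:
    # PEP 562 module __getattr__: `envs.NAME` becomes a lazy lookup.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

print(environment_variables["VLLM_RINGBUFFER_WARNING_INTERVAL"]())  # 60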
vllm/executor/cpu_executor.py (1 addition, 5 deletions)

@@ -10,8 +10,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import (get_distributed_init_method, get_open_port,
-                        get_vllm_instance_id, make_async)
+from vllm.utils import get_distributed_init_method, get_open_port, make_async
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -31,9 +30,6 @@ def _init_executor(self) -> None:
         # Environment variables for CPU executor
         #

-        # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
-        os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
-
         # Disable torch async compiling which won't work with daemonic processes
         os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
vllm/executor/multiproc_gpu_executor.py (1 addition, 4 deletions)

@@ -16,7 +16,7 @@
 from vllm.triton_utils.importing import HAS_TRITON
 from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
                         cuda_is_initialized, get_distributed_init_method,
-                        get_open_port, get_vllm_instance_id, make_async,
+                        get_open_port, make_async,
                         update_environment_variables)

 if HAS_TRITON:
@@ -37,9 +37,6 @@ def _init_executor(self) -> None:
         world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size

-        # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
-        os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
-
         # Disable torch async compiling which won't work with daemonic processes
         os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
vllm/executor/ray_gpu_executor.py (1 addition, 6 deletions)

@@ -15,8 +15,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
-                        get_ip, get_open_port, get_vllm_instance_id,
-                        make_async)
+                        get_ip, get_open_port, make_async)

 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -220,14 +219,10 @@ def sort_by_driver_then_worker_ip(worker):
                 " environment variable, make sure it is unique for"
                 " each node.")

-        VLLM_INSTANCE_ID = get_vllm_instance_id()
-
        # Set environment variables for the driver and workers.
         all_args_to_update_environment_variables = [({
             "CUDA_VISIBLE_DEVICES":
             ",".join(map(str, node_gpus[node_id])),
-            "VLLM_INSTANCE_ID":
-            VLLM_INSTANCE_ID,
             "VLLM_TRACE_FUNCTION":
             str(envs.VLLM_TRACE_FUNCTION),
             **({
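The list comprehension being edited builds one environment dict per worker; only entries such as CUDA_VISIBLE_DEVICES and VLLM_TRACE_FUNCTION remain now that the instance id rides along inside VllmConfig. A standalone sketch of the resulting structure, with hypothetical node/GPU data:

from typing import Dict, List, Tuple

# Hypothetical placement data: (node_id, gpu_ids) for each worker.
worker_node_and_gpu_ids: List[Tuple[str, List[int]]] = [
    ("node-0", [0, 1]),
    ("node-1", [2]),
]
node_gpus: Dict[str, List[int]] = {"node-0": [0, 1], "node-1": [2]}

# One env-var dict per worker, mirroring
# all_args_to_update_environment_variables; "0" stands in for
# str(envs.VLLM_TRACE_FUNCTION).
all_args = [({
    "CUDA_VISIBLE_DEVICES": ",".join(map(str, node_gpus[node_id])),
    "VLLM_TRACE_FUNCTION": "0",
}, ) for (node_id, _) in worker_node_and_gpu_ids]

print(all_args)
# [({'CUDA_VISIBLE_DEVICES': '0,1', 'VLLM_TRACE_FUNCTION': '0'},),
#  ({'CUDA_VISIBLE_DEVICES': '2', 'VLLM_TRACE_FUNCTION': '0'},)]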
vllm/executor/ray_hpu_executor.py (1 addition, 6 deletions)

@@ -15,8 +15,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
-                        get_ip, get_open_port, get_vllm_instance_id,
-                        make_async)
+                        get_ip, get_open_port, make_async)

 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -196,12 +195,8 @@ def sort_by_driver_then_worker_ip(worker):
                 "environment variable, make sure it is unique for"
                 " each node.")

-        VLLM_INSTANCE_ID = get_vllm_instance_id()
-
         # Set environment variables for the driver and workers.
         all_args_to_update_environment_variables = [({
-            "VLLM_INSTANCE_ID":
-            VLLM_INSTANCE_ID,
             "VLLM_TRACE_FUNCTION":
             str(envs.VLLM_TRACE_FUNCTION),
         }, ) for (node_id, _) in worker_node_and_gpu_ids]
vllm/executor/ray_tpu_executor.py (1 addition, 5 deletions)

@@ -13,7 +13,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
-                        get_vllm_instance_id, make_async)
+                        make_async)

 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -144,12 +144,8 @@ def sort_by_driver_then_worker_ip(worker):
         for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
             node_workers[node_id].append(i)

-        VLLM_INSTANCE_ID = get_vllm_instance_id()
-
         # Set environment variables for the driver and workers.
         all_args_to_update_environment_variables = [({
-            "VLLM_INSTANCE_ID":
-            VLLM_INSTANCE_ID,
             "VLLM_TRACE_FUNCTION":
             str(envs.VLLM_TRACE_FUNCTION),
         }, ) for _ in worker_node_and_gpu_ids]
vllm/executor/ray_xpu_executor.py (1 addition, 5 deletions)

@@ -5,7 +5,7 @@
 from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync
 from vllm.executor.xpu_executor import XPUExecutor
 from vllm.logger import init_logger
-from vllm.utils import get_vllm_instance_id, make_async
+from vllm.utils import make_async

 logger = init_logger(__name__)

@@ -17,12 +17,8 @@ def _get_env_vars_to_be_updated(self):
         worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                     use_dummy_driver=True)

-        VLLM_INSTANCE_ID = get_vllm_instance_id()
-
         # Set environment variables for the driver and workers.
         all_args_to_update_environment_variables = [({
-            "VLLM_INSTANCE_ID":
-            VLLM_INSTANCE_ID,
             "VLLM_TRACE_FUNCTION":
             str(envs.VLLM_TRACE_FUNCTION),
         }, ) for (_, _) in worker_node_and_gpu_ids]
vllm/utils.py (9 additions, 16 deletions)

@@ -24,9 +24,9 @@
 from collections.abc import Iterable, Mapping
 from functools import lru_cache, partial, wraps
 from platform import uname
-from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
-                    Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
-                    Type, TypeVar, Union, overload)
+from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
+                    Dict, Generic, Hashable, List, Literal, Optional,
+                    OrderedDict, Set, Tuple, Type, TypeVar, Union, overload)
 from uuid import uuid4

 import numpy as np
@@ -43,6 +43,9 @@
 from vllm.logger import enable_trace_function_call, init_logger
 from vllm.platforms import current_platform

+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+
 logger = init_logger(__name__)

 # Exception strings for non-implemented encoder/decoder scenarios
@@ -335,17 +338,6 @@ def random_uuid() -> str:
     return str(uuid.uuid4().hex)


-@lru_cache(maxsize=None)
-def get_vllm_instance_id() -> str:
-    """
-    If the environment variable VLLM_INSTANCE_ID is set, return it.
-    Otherwise, return a random UUID.
-    Instance id represents an instance of the VLLM. All processes in the same
-    instance should have the same instance id.
-    """
-    return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}"
-
-
 @lru_cache(maxsize=None)
 def in_wsl() -> bool:
     # Reference: https://github.com/microsoft/WSL/issues/4071
@@ -997,7 +989,7 @@ def find_nccl_library() -> str:
     return so_file


-def enable_trace_function_call_for_thread() -> None:
+def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
     """Set up function tracing for the current thread,
     if enabled via the VLLM_TRACE_FUNCTION environment variable
     """
@@ -1009,7 +1001,8 @@ def enable_trace_function_call_for_thread() -> None:
         filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
                     f"_thread_{threading.get_ident()}_"
                     f"at_{datetime.datetime.now()}.log").replace(" ", "_")
-        log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
-                                filename)
+        log_path = os.path.join(tmp_dir, "vllm",
+                                f"vllm-instance-{vllm_config.instance_id}",
+                                filename)
         os.makedirs(os.path.dirname(log_path), exist_ok=True)
         enable_trace_function_call(log_path)
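With the cached helper gone, the trace log directory is derived from the config's instance_id. A runnable sketch of the new path construction, extracted and simplified from the function above:

import datetime
import os
import tempfile
import threading

def trace_log_path(instance_id: str) -> str:
    # Mirrors the path built in enable_trace_function_call_for_thread:
    # <tmpdir>/vllm/vllm-instance-<id>/VLLM_TRACE_FUNCTION_..._.log
    tmp_dir = tempfile.gettempdir()
    filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
                f"_thread_{threading.get_ident()}_"
                f"at_{datetime.datetime.now()}.log").replace(" ", "_")
    return os.path.join(tmp_dir, "vllm", f"vllm-instance-{instance_id}",
                        filename)

print(trace_log_path("3f9a1"))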
vllm/worker/worker_base.py (1 addition, 1 deletion)

@@ -439,7 +439,7 @@ def init_worker(self, *args, **kwargs):
         Here we inject some common logic before initializing the worker.
         Arguments are passed to the worker class constructor.
         """
-        enable_trace_function_call_for_thread()
+        enable_trace_function_call_for_thread(self.vllm_config)

         # see https://github.com/NVIDIA/nccl/issues/1234
         os.environ['NCCL_CUMEM_ENABLE'] = '0'