1 change: 1 addition & 0 deletions docs/requirements-docs.txt

@@ -3,6 +3,7 @@ sphinx-book-theme==1.0.1
 sphinx-copybutton==0.5.2
 myst-parser==2.0.0
 sphinx-argparse==0.4.0
+msgspec
 
 # packages to install to build the documentation
 pydantic
47 changes: 39 additions & 8 deletions vllm/platforms/__init__.py

@@ -1,23 +1,54 @@
-import torch
-
 from .interface import Platform, PlatformEnum, UnspecifiedPlatform
 
 current_platform: Platform
 
+# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
+# they only indicate the build configuration, not the runtime environment.
+# For example, people can install a cuda build of pytorch but run on tpu.
+
+is_tpu = False
+try:
+    import torch_xla.core.xla_model as xm
+    xm.xla_device(devkind="TPU")
+    is_tpu = True
+except Exception:
+    pass
Comment on lines +5 to +15

Member Author:

cc @WoosukKwon

We don't use a successful import as a flag. Instead, we only trust that we are on a given platform once some device code actually executes successfully there.

Technically, the libtpu Python package can be installed on any platform.
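
A quick illustration of that point (a hypothetical snippet, not part of this diff): installing a vendor wheel such as pynvml makes the import succeed on any machine; only the driver initialization fails without matching hardware, which is exactly the distinction the probes in this file rely on.

```python
# Hypothetical sketch: a successful import only proves the package is
# installed; executing device code proves the hardware is actually there.
try:
    import pynvml        # succeeds wherever the wheel is installed
    pynvml.nvmlInit()    # raises on a machine without an NVIDIA driver
    print("NVIDIA runtime detected")
    pynvml.nvmlShutdown()
except ImportError:
    print("pynvml is not installed")
except Exception:
    print("pynvml is installed, but no usable NVIDIA runtime was found")
```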


+is_cuda = False
+
+try:
+    import pynvml
+    pynvml.nvmlInit()
+    try:
+        if pynvml.nvmlDeviceGetCount() > 0:
+            is_cuda = True
+    finally:
+        pynvml.nvmlShutdown()
+except Exception:
+    pass
+
+is_rocm = False
+
 try:
-    import libtpu
-except ImportError:
-    libtpu = None
+    import amdsmi
+    amdsmi.amdsmi_init()
+    try:
+        if len(amdsmi.amdsmi_get_processor_handles()) > 0:
+            is_rocm = True
+    finally:
+        amdsmi.amdsmi_shut_down()
+except Exception:
+    pass
 
-if libtpu is not None:
+if is_tpu:
     # people might install pytorch built with cuda but run on tpu
     # so we need to check tpu first
     from .tpu import TpuPlatform
     current_platform = TpuPlatform()
-elif torch.version.cuda is not None:
+elif is_cuda:
     from .cuda import CudaPlatform
     current_platform = CudaPlatform()
-elif torch.version.hip is not None:
+elif is_rocm:
     from .rocm import RocmPlatform
     current_platform = RocmPlatform()
 else:
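
To make the NOTE at the top of this diff concrete, here is a minimal sketch of the failure mode it describes (illustrative, not code from this PR): `torch.version.cuda` is fixed when the PyTorch wheel is built, so it can report a CUDA version on a host with no GPU at all, while a runtime query tells the truth.

```python
# Hedged illustration: on a CPU-only (or TPU) host that has a CUDA build
# of PyTorch installed, build metadata and runtime state disagree.
import torch

print(torch.version.cuda)         # e.g. "12.1", baked in at build time
print(torch.cuda.is_available())  # False: no usable GPU at runtime
```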