From ee92aa89988fa8144cce5b5ee141886ed1de81b4 Mon Sep 17 00:00:00 2001 From: hjwei Date: Thu, 26 Dec 2024 01:36:55 -0800 Subject: [PATCH 1/3] FEAT(ROCM Version): Replace HIPCC version with more precise ROCm version retrieval Signed-off-by: hjwei --- setup.py | 58 +++++++++++++++++++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/setup.py b/setup.py index 61d2d710aa20..0b3248461f4e 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ import re import subprocess import sys +import ctypes from pathlib import Path from shutil import which from typing import Dict, List @@ -13,7 +14,7 @@ from setuptools import Extension, find_packages, setup from setuptools.command.build_ext import build_ext from setuptools_scm import get_version -from torch.utils.cpp_extension import CUDA_HOME +from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME def load_module_from_path(module_name, path): @@ -379,25 +380,35 @@ def _build_custom_ops() -> bool: return _is_cuda() or _is_hip() or _is_cpu() -def get_hipcc_rocm_version(): - # Run the hipcc --version command - result = subprocess.run(['hipcc', '--version'], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True) - - # Check if the command was executed successfully - if result.returncode != 0: - print("Error running 'hipcc --version'") +def get_rocm_version(): + # Get the ROCm version from the ROCM_HOME/lib/librocm-core.so + # see https://github.com/ROCm/rocm-core/blob/d11f5c20d500f729c393680a01fa902ebf92094b/rocm_version.cpp#L21 + try: + librocm_core_file = Path(ROCM_HOME) / "lib" / "librocm-core.so" + if not librocm_core_file.is_file(): + return None + librocm_core = ctypes.CDLL(librocm_core_file) + VerErrors = ctypes.c_uint32 + get_rocm_core_version = librocm_core.getROCmVersion + get_rocm_core_version.restype = VerErrors + get_rocm_core_version.argtypes = [ + ctypes.POINTER(ctypes.c_uint32), + ctypes.POINTER(ctypes.c_uint32), + ctypes.POINTER(ctypes.c_uint32), + ] +
major = ctypes.c_uint32() + minor = ctypes.c_uint32() + patch = ctypes.c_uint32() + + if ( + get_rocm_core_version( + ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch) + ) + == 0 + ): + return "%d.%d.%d" % (major.value, minor.value, patch.value) return None - - # Extract the version using a regular expression - match = re.search(r'HIP version: (\S+)', result.stdout) - if match: - # Return the version string - return match.group(1) - else: - print("Could not find HIP version in the output") + except: return None @@ -479,11 +490,10 @@ def get_vllm_version() -> str: if "sdist" not in sys.argv: version += f"{sep}cu{cuda_version_str}" elif _is_hip(): - # Get the HIP version - hipcc_version = get_hipcc_rocm_version() - if hipcc_version != MAIN_CUDA_VERSION: - rocm_version_str = hipcc_version.replace(".", "")[:3] - version += f"{sep}rocm{rocm_version_str}" + # Get the Rocm Version + rocm_version = get_rocm_version() or torch.version.hip + if rocm_version and rocm_version != MAIN_CUDA_VERSION: + version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}" elif _is_neuron(): # Get the Neuron version neuron_version = str(get_neuronxcc_version()) From bd25b5590758405b745347d7035fec2f3cfee9d3 Mon Sep 17 00:00:00 2001 From: hjwei Date: Thu, 26 Dec 2024 03:38:21 -0800 Subject: [PATCH 2/3] fix ruff check with except Signed-off-by: hjwei --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0b3248461f4e..1f0d24bfb9dc 100644 --- a/setup.py +++ b/setup.py @@ -408,7 +408,7 @@ def get_rocm_version(): ): return "%d.%d.%d" % (major.value, minor.value, patch.value) return None - except: + except Exception: return None From ebda7d9ae37b26e971316d7522e2b93b0c11c56d Mon Sep 17 00:00:00 2001 From: hjwei Date: Thu, 26 Dec 2024 18:29:07 -0800 Subject: [PATCH 3/3] fix ruff & yapf lint check Signed-off-by: hjwei --- setup.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index 
1f0d24bfb9dc..ba6953dbdc17 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,10 @@ +import ctypes import importlib.util import logging import os import re import subprocess import sys -import ctypes from pathlib import Path from shutil import which from typing import Dict, List @@ -400,12 +400,8 @@ def get_rocm_version(): minor = ctypes.c_uint32() patch = ctypes.c_uint32() - if ( - get_rocm_core_version( - ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch) - ) - == 0 - ): + if (get_rocm_core_version(ctypes.byref(major), ctypes.byref(minor), + ctypes.byref(patch)) == 0): return "%d.%d.%d" % (major.value, minor.value, patch.value) return None except Exception: