34 commits
- da5b20f  remove hard code cuda related API. (jikunshang, Feb 29, 2024)
- 0c6c9f5  fix ut due to API change (jikunshang, Feb 29, 2024)
- 13d2989  refactor measure_cuda_memory, remove cuda (jikunshang, Mar 12, 2024)
- 7fb8306  fix format (jikunshang, Mar 12, 2024)
- dccc37d  fix (jikunshang, Mar 12, 2024)
- 028834a  remove hard code cuda related API. (jikunshang, Feb 29, 2024)
- dcb4313  add xpu build, dependency and kernels (jikunshang, Feb 27, 2024)
- 92bdeae  add xpu (jikunshang, Feb 29, 2024)
- 87fa3aa  enable tensor parallel on pvc (jikunshang, Mar 6, 2024)
- 4705fb0  add doc (jikunshang, Mar 6, 2024)
- 0b9faa2  rebase and fix (jikunshang, Mar 14, 2024)
- 476fcad  fix compiler warnings and rename headers (abhilash1910, Feb 28, 2024)
- 29cc50d  fix compiler warnings and rename headers (abhilash1910, Feb 28, 2024)
- f31e815  fix compiler warnings and rename headers (abhilash1910, Feb 28, 2024)
- 875e09c  add xpu header (abhilash1910, Feb 28, 2024)
- 9cb705b  fix (jikunshang, Mar 4, 2024)
- 992c303  default build xpu (jikunshang, Mar 4, 2024)
- b79fa7f  fix build issues (jikunshang, Mar 4, 2024)
- 182053d  use torch_sdpa backend and other fix (jikunshang, Mar 14, 2024)
- b46cf23  enable kernel uts (jikunshang, Feb 27, 2024)
- 178fa39  fix sdpa (jikunshang, Mar 15, 2024)
- 4e7c357  fix get_device() (jikunshang, Mar 15, 2024)
- e2c2cda  fix sdpa tensor shape (jikunshang, Mar 18, 2024)
- 5115131  reset setup.py (jikunshang, Mar 19, 2024)
- 317e0a7  add xpu build cmake system (jikunshang, Mar 12, 2024)
- 99e3b34  add sycl cmake build system (jikunshang, Mar 19, 2024)
- 3fe9067  add kernels (jikunshang, Mar 19, 2024)
- e4a5ba3  minor (jikunshang, Mar 20, 2024)
- 054583a  fix (jikunshang, Mar 20, 2024)
- 65ce184  fix (jikunshang, Mar 20, 2024)
- dfb5a89  update requirements (jikunshang, Mar 20, 2024)
- eb1ec5c  minor (jikunshang, Mar 20, 2024)
- d828af0  update requirements (jikunshang, Mar 21, 2024)
- 6cabc4e  refactor worker (jikunshang, Mar 22, 2024)
78 changes: 55 additions & 23 deletions CMakeLists.txt
@@ -101,6 +101,9 @@ elseif(HIP_FOUND)
message(WARNING "PyTorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} "
"expected for ROCm 6.x build, saw ${Torch_VERSION} instead.")
endif()
elseif(VLLM_BUILD_XPU_OPS)
message(STATUS "Building XPU")
set(VLLM_GPU_LANG "SYCL")
else()
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
endif()
@@ -136,17 +139,41 @@ endif()
# _C extension
#

set(VLLM_EXT_SRC
"csrc/cache_kernels.cu"
"csrc/attention/attention_kernels.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp")

if(VLLM_GPU_LANG STREQUAL "SYCL")
append_cmake_prefix_path("intel_extension_for_pytorch" "intel_extension_for_pytorch.cmake_prefix_path")
find_package(IPEX REQUIRED)
# IPEX will overwrite TORCH_LIBRARIES, so re-add torch_python lib.
append_torchlib_if_found(torch_python)
include_directories(${IPEX_INCLUDE_DIRS})
set(CMPLR_ROOT $ENV{CMPLR_ROOT})
set(CMAKE_CXX_COMPILER icpx)
set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl)
set(VLLM_EXT_SRC
"csrc/xpu/activation_xpu.cpp"
"csrc/xpu/attention_xpu.cpp"
"csrc/xpu/cache_ops_xpu.cpp"
"csrc/xpu/gemm_kernels_xpu.cpp"
"csrc/xpu/layernorm_xpu.cpp"
"csrc/xpu/pos_encoding_xpu.cpp"
"csrc/xpu/utils.cpp"
"csrc/pybind.cpp")
list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" "-fsycl" "-fsycl-targets=spir64")
list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64")
list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "dnnl" "intel-ext-pt-gpu" )
else()
set(VLLM_EXT_SRC
"csrc/cache_kernels.cu"
"csrc/attention/attention_kernels.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp")
endif()

if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
@@ -161,25 +188,30 @@ define_gpu_extension_target(
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
LINK_FLAGS ${VLLM_GPU_LINK_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${VLLM_EXTRA_INCLUDE_DIRECTORIES}
LIBRARIES ${VLLM_LINK_LIBRARIES}
WITH_SOABI)

#
# _moe_C extension
#

set(VLLM_MOE_EXT_SRC
"csrc/moe/moe_ops.cpp"
"csrc/moe/topk_softmax_kernels.cu")
if(NOT VLLM_GPU_LANG STREQUAL "SYCL")
set(VLLM_MOE_EXT_SRC
"csrc/moe/moe_ops.cpp"
"csrc/moe/topk_softmax_kernels.cu")

define_gpu_extension_target(
_moe_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_MOE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
WITH_SOABI)
define_gpu_extension_target(
_moe_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_MOE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
WITH_SOABI)
endif()

#
# _punica_C extension
@@ -259,7 +291,7 @@ endif()
#
add_custom_target(default)

if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP" OR VLLM_GPU_LANG STREQUAL "SYCL")
message(STATUS "Enabling C extension.")
add_dependencies(default _C)
endif()
2 changes: 1 addition & 1 deletion benchmarks/benchmark_throughput.py
@@ -308,7 +308,7 @@ def main(args: argparse.Namespace):
"--device",
type=str,
default="cuda",
choices=["cuda"],
choices=["cuda", "xpu"],
help='device type for vLLM execution, supporting CUDA and XPU.')
parser.add_argument(
"--enable-prefix-caching",
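For reference, here is a minimal sketch (not part of this PR) of how the extended --device choice could be consumed by a caller; the intel_extension_for_pytorch import and the torch.xpu availability check are assumptions about the XPU runtime environment, not code from this change.

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    choices=["cuda", "xpu"],
    help="device type for vLLM execution.")
args = parser.parse_args()

if args.device == "xpu":
    # Importing IPEX registers the torch.xpu backend (assumption about the
    # runtime; IPEX is the same dependency pulled in by the CMake changes above).
    import intel_extension_for_pytorch  # noqa: F401
    assert torch.xpu.is_available(), "no XPU device found"

device = torch.device(args.device)
print(f"running on {device}")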
8 changes: 6 additions & 2 deletions cmake/utils.cmake
@@ -115,7 +115,6 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc")

endif()
set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
endfunction()
@@ -282,6 +281,7 @@ endmacro()
# COMPILE_FLAGS <flags> - Extra compiler flags passed to NVCC/hip.
# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
# LINK_LIBRARIES <libraries> - Extra link libraries.
# LINK_FLAGS <flags> - Extra link flags.
# WITH_SOABI - Generate library with python SOABI suffix name.
#
# Note: optimization level/debug info is set via cmake build type.
@@ -291,7 +291,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
GPU
"WITH_SOABI"
"DESTINATION;LANGUAGE"
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES;LINK_FLAGS")

# Add hipify preprocessing step when building with HIP/ROCm.
if (GPU_LANGUAGE STREQUAL "HIP")
@@ -329,6 +329,10 @@ function (define_gpu_extension_target GPU_MOD_NAME)

target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES}
${GPU_LIBRARIES})
if (GPU_LANGUAGE STREQUAL "SYCL")
target_compile_options(${GPU_MOD_NAME} PRIVATE ${GPU_COMPILE_FLAGS})
target_link_options(${GPU_MOD_NAME} PRIVATE ${GPU_LINK_FLAGS})
endif()

install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION})
endfunction()
11 changes: 10 additions & 1 deletion csrc/pybind.cpp
@@ -1,8 +1,16 @@
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include "dispatch_utils.h"
#include <torch/extension.h>

#ifdef VLLM_BUILD_XPU_OPS
#include "xpu/xpu_ops.h"
int get_device_attribute(
int attribute,
int device_id) { return 94387; }
#endif

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// vLLM custom ops
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
@@ -75,7 +83,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"moe_align_block_size",
&moe_align_block_size,
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");

// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def(
@@ -108,6 +115,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Gets the maximum shared memory per block device attribute.");

#ifndef USE_ROCM
#ifndef VLLM_BUILD_XPU_OPS
// Custom all-reduce kernels
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
@@ -122,5 +130,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
custom_ar.def("register_graph_buffers", &register_graph_buffers,
"register_graph_buffers");
#endif
#endif

}
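A minimal sketch of how the resulting extension could be probed from Python, assuming it is importable as vllm._C (per define_gpu_extension_target(_C ... DESTINATION vllm) in CMakeLists.txt); the capability check below is illustrative, not code from this PR.

import vllm._C as _C

# The ops, cache_ops and cuda_utils submodules are always registered, while
# the custom all-reduce submodule is compiled out for ROCm and XPU builds,
# so its presence doubles as a capability check.
has_custom_ar = hasattr(_C, "custom_ar")
print("custom all-reduce available:", has_custom_ar)

Note that, as the #ifdef VLLM_BUILD_XPU_OPS block above shows, get_device_attribute() is stubbed on the XPU build to return a fixed value (94387) rather than querying the device.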