Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_ARCHS)

#
Expand Down Expand Up @@ -445,8 +446,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
# (Build 8.9 for FP8)
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
"7.5;8.0;8.9+PTX" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
Expand Down Expand Up @@ -675,7 +677,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
CUDA_ARCHS "${CUDA_ARCHS}")

list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)

#
Expand Down
89 changes: 69 additions & 20 deletions cmake/utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,26 @@ macro(set_gencode_flags_for_srcs)
"${multiValueArgs}" ${ARGN} )

foreach(_ARCH ${arg_CUDA_ARCHS})
string(REPLACE "." "" _ARCH "${_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_ARCH}"
CODE "sm_${_ARCH}")
# handle +PTX suffix: generate both sm and ptx codes if requested
string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
if(NOT _HAS_PTX EQUAL -1)
string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_STRIPPED_ARCH}"
CODE "sm_${_STRIPPED_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_STRIPPED_ARCH}"
CODE "compute_${_STRIPPED_ARCH}")
else()
string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_STRIPPED_ARCH}"
CODE "sm_${_STRIPPED_ARCH}")
endif()
endforeach()

if (${arg_BUILD_PTX_FOR_ARCH})
Expand All @@ -251,7 +266,10 @@ endmacro()
#
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
# `<major>.<minor>[letter]` compute the "loose intersection" with the
# `TGT_CUDA_ARCHS` list of gencodes.
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
# architecture in `SRC_CUDA_ARCHS`.
# The loose intersection is defined as:
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
# where `<=` is the version comparison operator.
Expand All @@ -268,44 +286,63 @@ endmacro()
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
#
# Example With PTX:
# SRC_CUDA_ARCHS="8.0+PTX"
# TGT_CUDA_ARCHS="9.0"
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0+PTX"
#
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})

# handle +PTX suffix: separate base arch for matching, record PTX requests
set(_PTX_ARCHS)
foreach(_arch ${_SRC_CUDA_ARCHS})
if(_arch MATCHES "\\+PTX$")
string(REPLACE "+PTX" "" _base "${_arch}")
list(APPEND _PTX_ARCHS "${_base}")
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
list(APPEND _SRC_CUDA_ARCHS "${_base}")
endif()
endforeach()
list(REMOVE_DUPLICATES _PTX_ARCHS)
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)

# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
set(_CUDA_ARCHS)
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
set(_CUDA_ARCHS "9.0a")
endif()
endif()

if ("10.0a" IN_LIST SRC_CUDA_ARCHS)
list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a")
if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0")
list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
set(_CUDA_ARCHS "10.0a")
endif()
endif()

list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do the _TGT_CUDA_ARCHS need to be sorted too?

Copy link
Collaborator Author

@LucasWilkinson LucasWilkinson May 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for a more general utility you're right they should be!; but the target arches come from extract_unique_cuda_archs_ascending so they are already sorted. I can open up a now PR to refactor some of this though; kinda want to preserve current behavior as much as possible in this one


# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
# is less or equal to ARCH (but has the same major version since SASS binary
# compatibility is only forward compatible within the same major version).
foreach(_ARCH ${TGT_CUDA_ARCHS_})
foreach(_ARCH ${_TGT_CUDA_ARCHS})
set(_TMP_ARCH)
# Extract the major version of the target arch
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
# Extract the major version of the source arch
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
# Check major-version match AND version-less-or-equal
# Check version-less-or-equal, and allow PTX arches to match across majors
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
set(_TMP_ARCH "${_SRC_ARCH}")
endif()
else()
Expand All @@ -321,6 +358,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
endforeach()

list(REMOVE_DUPLICATES _CUDA_ARCHS)

# reapply +PTX suffix to architectures that requested PTX
set(_FINAL_ARCHS)
foreach(_arch ${_CUDA_ARCHS})
if(_arch IN_LIST _PTX_ARCHS)
list(APPEND _FINAL_ARCHS "${_arch}+PTX")
else()
list(APPEND _FINAL_ARCHS "${_arch}")
endif()
endforeach()
set(_CUDA_ARCHS ${_FINAL_ARCHS})

set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
endfunction()

Expand Down