diff --git a/CMakeLists.txt b/CMakeLists.txt
index e4423efcc6..8d47298cbd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -135,7 +135,7 @@ if (FA2_ENABLED)
 
   # For CUDA we set the architectures on a per file basis
   if (VLLM_GPU_LANG STREQUAL "CUDA")
-    cuda_archs_loose_intersection(FA2_ARCHS "8.0;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
+    cuda_archs_loose_intersection(FA2_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
     message(STATUS "FA2_ARCHS: ${FA2_ARCHS}")
 
     set_gencode_flags_for_srcs(
diff --git a/cmake/utils.cmake b/cmake/utils.cmake
index fcf1632a80..3be7655117 100644
--- a/cmake/utils.cmake
+++ b/cmake/utils.cmake
@@ -62,8 +62,8 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
   #
   set(SRCS ${ORIG_SRCS})
   set(CXX_SRCS ${ORIG_SRCS})
-  list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$")
-  list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$")
+  list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)|(hip)$")
+  list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)|(hip)$")
 
   #
   # Generate ROCm/HIP source file names from CUDA file names.
@@ -80,7 +80,7 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
   set(CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/csrc)
   add_custom_target(
     hipify${NAME}
-    COMMAND ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS}
+    COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS}
     DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
     BYPRODUCTS ${HIP_SRCS}
     COMMENT "Running hipify on ${NAME} extension source files.")
@@ -232,11 +232,26 @@ macro(set_gencode_flags_for_srcs)
     "${multiValueArgs}" ${ARGN} )
 
   foreach(_ARCH ${arg_CUDA_ARCHS})
-    string(REPLACE "." "" _ARCH "${_ARCH}")
-    set_gencode_flag_for_srcs(
-      SRCS ${arg_SRCS}
-      ARCH "compute_${_ARCH}"
-      CODE "sm_${_ARCH}")
+    # handle +PTX suffix: generate both sm and ptx codes if requested
+    string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
+    if(NOT _HAS_PTX EQUAL -1)
+      string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
+      string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
+      set_gencode_flag_for_srcs(
+        SRCS ${arg_SRCS}
+        ARCH "compute_${_STRIPPED_ARCH}"
+        CODE "sm_${_STRIPPED_ARCH}")
+      set_gencode_flag_for_srcs(
+        SRCS ${arg_SRCS}
+        ARCH "compute_${_STRIPPED_ARCH}"
+        CODE "compute_${_STRIPPED_ARCH}")
+    else()
+      string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
+      set_gencode_flag_for_srcs(
+        SRCS ${arg_SRCS}
+        ARCH "compute_${_STRIPPED_ARCH}"
+        CODE "sm_${_STRIPPED_ARCH}")
+    endif()
   endforeach()
 
   if (${arg_BUILD_PTX_FOR_ARCH})
@@ -255,15 +270,18 @@ endmacro()
 #
 # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
 #  `<major>.<minor>[letter]` compute the "loose intersection" with the
-# `TGT_CUDA_ARCHS` list of gencodes.
+# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
+# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
+# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
+# architecture in `SRC_CUDA_ARCHS`.
 # The loose intersection is defined as:
 #   { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
 # where `<=` is the version comparison operator.
 # In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
 # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
-# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
-# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
-# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS).
+# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
+# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
+# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
 # The result is stored in `OUT_CUDA_ARCHS`.
 #
 # Example:
@@ -272,36 +290,63 @@ endmacro()
 #   cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
 #   OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
 #
+# Example With PTX:
+#   SRC_CUDA_ARCHS="8.0+PTX"
+#   TGT_CUDA_ARCHS="9.0"
+#   cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
+#   OUT_CUDA_ARCHS="8.0+PTX"
+#
 function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
-  list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
-  set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
+  set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
+  set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
+
+  # handle +PTX suffix: separate base arch for matching, record PTX requests
+  set(_PTX_ARCHS)
+  foreach(_arch ${_SRC_CUDA_ARCHS})
+    if(_arch MATCHES "\\+PTX$")
+      string(REPLACE "+PTX" "" _base "${_arch}")
+      list(APPEND _PTX_ARCHS "${_base}")
+      list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
+      list(APPEND _SRC_CUDA_ARCHS "${_base}")
+    endif()
+  endforeach()
+  list(REMOVE_DUPLICATES _PTX_ARCHS)
+  list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
 
-  # if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should
-  # remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
+  # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
+  # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
   set(_CUDA_ARCHS)
-  if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
-    list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
-    if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
-      list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
+  if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
+    list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
+    if ("9.0" IN_LIST TGT_CUDA_ARCHS)
+      list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
       set(_CUDA_ARCHS "9.0a")
     endif()
   endif()
 
-  list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
+  if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
+    list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
+    if ("10.0" IN_LIST TGT_CUDA_ARCHS)
+      list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
+      set(_CUDA_ARCHS "10.0a")
+    endif()
+  endif()
+
+  list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
 
   # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
   # is less or equal to ARCH (but has the same major version since SASS binary
   # compatibility is only forward compatible within the same major version).
-  foreach(_ARCH ${TGT_CUDA_ARCHS_})
+  foreach(_ARCH ${_TGT_CUDA_ARCHS})
     set(_TMP_ARCH)
     # Extract the major version of the target arch
     string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
-    foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
+    foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
       # Extract the major version of the source arch
       string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
-      # Check major-version match AND version-less-or-equal
+      # Check version-less-or-equal, and allow PTX arches to match across majors
      if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
-        if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
+        if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
           set(_TMP_ARCH "${_SRC_ARCH}")
         endif()
       else()
@@ -317,6 +362,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
   endforeach()
 
   list(REMOVE_DUPLICATES _CUDA_ARCHS)
+
+  # reapply +PTX suffix to architectures that requested PTX
+  set(_FINAL_ARCHS)
+  foreach(_arch ${_CUDA_ARCHS})
+    if(_arch IN_LIST _PTX_ARCHS)
+      list(APPEND _FINAL_ARCHS "${_arch}+PTX")
+    else()
+      list(APPEND _FINAL_ARCHS "${_arch}")
+    endif()
+  endforeach()
+  set(_CUDA_ARCHS ${_FINAL_ARCHS})
+
   set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
 endfunction()
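The loose-intersection rule and the new +PTX behaviour are described above only in comment prose and set notation; the compact C++ restatement below may be easier to scan. It is illustrative only and not part of the patch: loose_intersection, SrcArch, and the example architecture lists are invented names, and the x.0a special casing handled by the CMake code is deliberately omitted.

// Illustrative sketch (not part of the patch): the selection rule of
// cuda_archs_loose_intersection, restated in C++.
#include <cstdio>
#include <set>
#include <utility>
#include <vector>

struct SrcArch { int major_ver; int minor_ver; bool ptx; };
using Arch = std::pair<int, int>;  // compared lexicographically: major first, then minor

std::set<std::pair<Arch, bool>> loose_intersection(std::vector<SrcArch> const& src,
                                                   std::vector<Arch> const& tgt) {
    std::set<std::pair<Arch, bool>> out;
    for (Arch const& t : tgt) {
        SrcArch const* best = nullptr;
        for (SrcArch const& s : src) {
            bool less_or_equal = Arch{s.major_ver, s.minor_ver} <= t;
            // SASS is only forward compatible within a major version, but a +PTX
            // source arch may be selected for a newer target of any major version.
            bool major_ok = s.ptx || s.major_ver == t.first;
            bool better = !best || Arch{s.major_ver, s.minor_ver} > Arch{best->major_ver, best->minor_ver};
            if (less_or_equal && major_ok && better) { best = &s; }
        }
        if (best) { out.insert({Arch{best->major_ver, best->minor_ver}, best->ptx}); }
    }
    return out;
}

int main() {
    // The FA2 case from CMakeLists.txt: SRC_CUDA_ARCHS="8.0+PTX", TGT_CUDA_ARCHS=${CUDA_ARCHS}.
    std::vector<SrcArch> src = {{8, 0, /*ptx=*/true}};
    std::vector<Arch> tgt = {{8, 6}, {9, 0}, {12, 0}};
    for (auto const& [arch, ptx] : loose_intersection(src, tgt)) {
        std::printf("%d.%d%s\n", arch.first, arch.second, ptx ? "+PTX" : "");
    }
    return 0;  // prints: 8.0+PTX
}

With SRC_CUDA_ARCHS="8.0+PTX" every newer target falls back to the single 8.0 PTX entry, which matches the "Example With PTX" block above and is what lets CMakeLists.txt replace the explicit "8.0;9.0;10.0;10.1;12.0" list with "8.0+PTX".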
diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h
index a22e05969d..3aa5484cbc 100644
--- a/hopper/flash_fwd_combine_kernel.h
+++ b/hopper/flash_fwd_combine_kernel.h
@@ -122,16 +122,24 @@ class FlashAttnFwdCombine {
     using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen, head, batch)
     using StrideLSE = cute::Stride<_1, int64_t, int64_t>;  // (seqlen, head, batch)
 
+    struct BlockCoord {
+        int block_m;
+        int block_k;
+        int bidb;
+    };
+
     struct SharedStorage : cute::aligned_struct<128> {
         cute::array_aligned<float, cute::cosize_v<SmemLayoutLSE>> smem_lse_partial;
         cute::array_aligned<int, kBlockM> smem_max_valid_split;
         cute::array_aligned<ElementPartial, cute::cosize_v<SmemLayoutO>> smem_o_partial;
+        BlockCoord block_coord;
     };
 
     static constexpr int SharedStorageSize = sizeof(SharedStorage);
 
     // Device side arguments
     struct Arguments {
+        int b;
         ElementPartial const* const ptr_O_partial;
         ShapeOPartial const shape_O_partial;
         StrideOPartial const stride_O_partial;
@@ -149,7 +157,8 @@ class FlashAttnFwdCombine {
     };
 
     // Kernel entry point API
-    struct Params {
+    struct CollectiveParams {
+        int b;
         ElementPartial const* const ptr_O_partial;
         ShapeOPartial const shape_O_partial;
         StrideOPartial const stride_O_partial;
@@ -169,10 +178,11 @@ class FlashAttnFwdCombine {
 
     // Convert to underlying arguments. In this case, a simple copy for the aliased type.
     static
-    Params
+    CollectiveParams
     to_underlying_arguments(Arguments const& args) {
         assert(get<1>(args.shape_LSE_partial) <= kMaxSplits);
         return {
+            args.b,
             args.ptr_O_partial,
             args.shape_O_partial,
             args.stride_O_partial,
@@ -191,33 +201,243 @@ class FlashAttnFwdCombine {
         };
     }
 
+    struct SchedulerArguments {
+        int b;
+        int seqlen_q;
+        int total_q;
+        int num_heads;
+        int dv;
+        int const* cu_seqlens_q;
+        int const* seqused_q;
+    };
+
+    struct StaticTileScheduler {
+        struct Params {};
+        static Params to_underlying_arguments(SchedulerArguments const& args) { return {}; }
+
+        SharedStorage& shared_storage;
+        CUTE_DEVICE StaticTileScheduler(SharedStorage& shared_storage): shared_storage(shared_storage) {}
+
+        static dim3 get_grid_shape(SchedulerArguments const& args) {
+            unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
+            unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM);
+            return {num_blocks_m, num_blocks_k, static_cast<unsigned int>(args.b)};
+        }
+
+        CUTE_DEVICE BlockCoord get_block_coord(Params const& params) {
+            int block_m = blockIdx.x;
+            int block_k = blockIdx.y;
+            int bidb = blockIdx.z;
+            return {block_m, block_k, bidb};
+        }
+    };
+
+    struct StaticVarlenTileScheduler {
+        //
+        // For varlen we have two scheduling algos:
+        //  1) STANDARD, same as StaticTileScheduler
+        //  2) LINEARIZE_M_AND_BATCH, this flattens the tiled M dimension and
+        //     batch dimension into a linear tile index. The grid is then a
+        //     2D grid of (tile_id, k_block). We then map the linear tile id
+        //     to (m_block, bidb) in the get_block_coord function. This mapping
+        //     is non-trivial since each batch element can have a different
+        //     number of m_blocks. This has overhead when computing the block
+        //     coordinates, but it is more efficient when prefills and decodes
+        //     are mixed, since in that case the STANDARD scheduling algo will
+        //     have a lot of empty (no work) blocks in the grid.
+        //
+
+        enum SchedulingAlgo {
+            STANDARD,               // Same as StaticTileScheduler
+            LINEARIZE_M_AND_BATCH,  // Linearize the M and batch dimensions into a single tile index
+        };
+
+        struct Params {
+            int b;
+            int num_heads;
+            int const* const cu_seqlens_q;
+            int const* const seqused_q;
+            SchedulingAlgo algo;
+        };
+
+        SharedStorage& shared_storage;
+        CUTE_DEVICE StaticVarlenTileScheduler(SharedStorage& shared_storage): shared_storage(shared_storage) {}
+
+        static SchedulingAlgo choose_scheduling_algo(SchedulerArguments const& args) {
+            // Choose the scheduling algorithm based on how dense the grid of tiles that
+            // do actual work is. If the grid is more than 50% sparse, we linearize the M
+            // and batch. If the grid is more than 50% dense, we use the standard scheduling
+            // algorithm since it's more efficient at calculating the block coordinates.
+            // NOTE: in the varlen case args.seqlen_q is the max seqlen_q across all batches;
+            //  use a lower bound to estimate when the density is more than 50%
+            int lower_bound_on_non_empty_tiles = cute::ceil_div(args.total_q, kBlockM);
+            int grid_size = args.b * cute::ceil_div(args.seqlen_q, kBlockM);
+            return 2 * lower_bound_on_non_empty_tiles >= grid_size ?
+                SchedulingAlgo::STANDARD :
+                SchedulingAlgo::LINEARIZE_M_AND_BATCH;
+        }
+
+        static Params to_underlying_arguments(SchedulerArguments const& args) {
+            return {
+                args.b,
+                args.num_heads,
+                args.cu_seqlens_q,
+                args.seqused_q,
+                choose_scheduling_algo(args)
+            };
+        }
+
+        static dim3 get_grid_shape(SchedulerArguments const& args) {
+            unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
+
+            switch (choose_scheduling_algo(args)) {
+                case SchedulingAlgo::STANDARD: {
+                    unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
+                    unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM);
+                    return {num_blocks_m, num_blocks_k, static_cast<unsigned int>(args.b)};
+                }
+                case SchedulingAlgo::LINEARIZE_M_AND_BATCH: {
+                    // rough worst case upper bound on the number of blocks required
+                    // (assuming each batch has an additional partial block)
+                    unsigned int num_blocks_m = cute::ceil_div(args.total_q * args.num_heads, kBlockM) + args.b;
+                    return {num_blocks_m, num_blocks_k, 1};
+            }}
+
+            // rough worst case upper bound on the number of blocks required
+            // (assuming each batch has an additional partial block)
+            unsigned int num_blocks_m = cute::ceil_div(args.total_q * args.num_heads, kBlockM) + args.b;
+            return {num_blocks_m, num_blocks_k, 1};
+        }
+
+        CUTE_DEVICE BlockCoord get_block_coord_linearized_m_and_batch(Params const& params) {
+            int num_heads = params.num_heads;
+            int curr_tile_id = blockIdx.x;
+
+            // Scan through the batches to find the batch that contains the current
+            // tile_id. Compute using only the first warp of the block.
+            if (threadIdx.x < 32) {
+                // We compute linearized tile index start and ends for each batch
+                // in groups of 32 in parallel
+                int group_start_bidb = -(cutlass::NumThreadsPerWarp);
+                int group_end_bidb = 0;
+                int group_end_tile_id = 0;
+                int group_start_tile_id = 0;
+                int group_total_num_tiles = 0;
+
+                int local_num_m_blocks = 0;
+                int local_num_m_blocks_cumulative = 0;
+
+                do {
+                    group_start_bidb += cutlass::NumThreadsPerWarp;
+                    group_end_bidb += cutlass::NumThreadsPerWarp;
+
+                    auto get_num_m_blocks = [&](int bidb) {
+                        if (bidb >= params.b) return 0;
+                        flash::SeqlenInfo seqlen_info{bidb, 0, params.cu_seqlens_q, params.seqused_q};
+                        return cute::ceil_div(seqlen_info.seqlen * num_heads, Int<kBlockM>{}());
+                    };
+
+                    // Cumulative number of blocks for the next 31 batches
+                    local_num_m_blocks = get_num_m_blocks(group_start_bidb + threadIdx.x);
+                    local_num_m_blocks_cumulative = warp_prefix_sum(local_num_m_blocks);
+                    // Total number of blocks for the next 32 batches
+                    group_total_num_tiles = warp_shfl_get_last(local_num_m_blocks_cumulative);
+
+                    group_start_tile_id = group_end_tile_id;
+                    group_end_tile_id += group_total_num_tiles;
+                } while (curr_tile_id >= group_end_tile_id && group_end_bidb < params.b);
+
+                int local_batch_end_tile_id = group_start_tile_id + local_num_m_blocks_cumulative;
+                // Find the last batch idx in the group where `local_batch_end_tile_id <= curr_tile_id`;
+                // these values below are now common to all threads in the warp
+                int batch_idx_in_group = warp_last_true_laneid(local_batch_end_tile_id <= curr_tile_id);
+                int batch_num_m_blocks = warp_shfl_get(local_num_m_blocks, batch_idx_in_group);
+                int batch_m_start_tile_id = group_start_tile_id + (batch_idx_in_group > 0 ?
+                    warp_shfl_get(local_num_m_blocks_cumulative, batch_idx_in_group - 1) : 0);
+
+                int bidb = group_start_bidb + batch_idx_in_group;
+                int block_m = curr_tile_id - batch_m_start_tile_id;
+                // NOTE(lucas): not sure why this causes a block_k unused warning;
+                //  just inlined `blockIdx.y` to suppress the warning
+                // int block_k = blockIdx.y;
+                // shared_storage.block_coord = {block_m, block_k, bidb};
+                BlockCoord block_coord{block_m, static_cast<int>(blockIdx.y), bidb};
+                if (threadIdx.x == 0) { shared_storage.block_coord = block_coord; }
+            }
+
+            __syncthreads();
+            return shared_storage.block_coord;
+        }
+
+
+        CUTE_DEVICE BlockCoord get_block_coord_standard(Params const& params) {
+            int block_m = blockIdx.x;
+            int block_k = blockIdx.y;
+            int bidb = blockIdx.z;
+            return {block_m, block_k, bidb};
+        }
+
+        CUTE_DEVICE BlockCoord get_block_coord(Params const& params) {
+            switch (params.algo) {
+                case SchedulingAlgo::STANDARD:
+                    return get_block_coord_standard(params);
+                case SchedulingAlgo::LINEARIZE_M_AND_BATCH:
+                    return get_block_coord_linearized_m_and_batch(params);
+            }
+            return {0, 0, 0}; // Should never reach here
+        }
+    };
+
+    using TileScheduler = std::conditional_t<
+        Varlen,
+        StaticVarlenTileScheduler,
+        StaticTileScheduler
+    >;
+
+    using SchedulerParams = typename TileScheduler::Params;
+
+    struct Params {
+        CollectiveParams params;
+        SchedulerParams scheduler_params;
+    };
+
     CUTLASS_DEVICE
     void
-    operator()(Params const& params, char* smem_buf) {
+    operator()(Params const& kernel_params, char* smem_buf) {
+        CollectiveParams const& params = kernel_params.params;
 
         SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
+        TileScheduler tile_scheduler{shared_storage};
+
         Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{});
         Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape<Int<kBlockM>>{});
         Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{});
 
         int const thread_idx = threadIdx.x;
-        int const m_block = blockIdx.x;
-        int const k_block = blockIdx.y;
-        int const batch = blockIdx.z;
-        int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
+
+        BlockCoord block_coord = tile_scheduler.get_block_coord(kernel_params.scheduler_params);
+
+        int const m_block = block_coord.block_m;
+        int const k_block = block_coord.block_k;
+        int const batch = block_coord.bidb;
 
         if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) {
             cutlass::arch::wait_on_dependent_grids();
             *params.semaphore_to_reset = 0;
         }
-        if (num_splits <= 1) { return; }
+
         flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
         int const offset = seqlen_info.offset;
         int const seqlen = seqlen_info.seqlen;
         int max_idx = seqlen * get<2>(params.shape_LSE_partial);
-        if constexpr (Varlen) {
-            if (m_block * kBlockM >= max_idx) { return; }
-        }
+
+        bool block_coord_valid =
+            block_coord.block_m < cute::ceil_div(max_idx, Int<kBlockM>{}) &&
+            block_coord.bidb < params.b;
+        if (!block_coord_valid) { return; }
+
+        int const num_splits = params.num_splits_dynamic_ptr ?
+            params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
+        if (num_splits <= 1) { return; }
 
         cutlass::FastDivmod seqlen_divmod_dynamic(seqlen);
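The warp-parallel search in StaticVarlenTileScheduler::get_block_coord_linearized_m_and_batch above can be easier to follow against a sequential reference. The sketch below is illustrative only and is not part of the patch: map_tile, rows_per_batch, and the kBlockM value of 8 are invented for the example, and the 32-batch grouping, warp_prefix_sum, and shuffles that the device code uses to do the same search in parallel are intentionally left out.

// Illustrative sketch (not part of the patch): sequential equivalent of the
// warp-parallel tile-id -> (block_m, bidb) mapping used by LINEARIZE_M_AND_BATCH.
#include <cstdio>
#include <vector>

struct RefBlockCoord { int block_m; int bidb; };

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

// rows_per_batch[bidb] plays the role of seqlen(bidb) * num_heads.
RefBlockCoord map_tile(int tile_id, std::vector<int> const& rows_per_batch, int kBlockM) {
    int batch_start_tile_id = 0;  // linearized tile id where batch `bidb` begins
    for (int bidb = 0; bidb < static_cast<int>(rows_per_batch.size()); ++bidb) {
        int num_m_blocks = ceil_div(rows_per_batch[bidb], kBlockM);
        if (tile_id < batch_start_tile_id + num_m_blocks) {
            return {tile_id - batch_start_tile_id, bidb};  // block_m is the offset inside this batch
        }
        batch_start_tile_id += num_m_blocks;
    }
    return {-1, -1};  // tile id beyond the last batch: an empty (no work) block
}

int main() {
    // Mixed prefill/decode example with an assumed kBlockM of 8:
    // batch 0 needs 5 m-blocks, batches 1 and 2 need 1 each, batch 3 needs 3.
    std::vector<int> rows_per_batch = {33, 1, 1, 17};
    for (int tile_id = 0; tile_id < 11; ++tile_id) {
        RefBlockCoord c = map_tile(tile_id, rows_per_batch, /*kBlockM=*/8);
        std::printf("tile %2d -> block_m=%d bidb=%d\n", tile_id, c.block_m, c.bidb);
    }
    return 0;
}

The device version computes the same per-batch starting tile ids 32 batches at a time with warp_prefix_sum and then selects the containing batch with the ballot-based helper added to hopper/utils.h below.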
diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h
index 11d422924b..c99efdadfe 100644
--- a/hopper/flash_fwd_combine_launch_template.h
+++ b/hopper/flash_fwd_combine_launch_template.h
@@ -25,6 +25,7 @@ void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool e
                                IsEvenK, Varlen, Element, ElementPartial, ArchTag>;
 
     typename CombineKernel::Arguments args {
+        params.b,
         static_cast<ElementPartial*>(params.oaccum_ptr),
         {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1},  // shape_O_partial
         {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0},  // stride_O_partial
@@ -38,10 +39,17 @@ void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool e
         params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore
     };
 
-    typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args);
-    int num_blocks_k = cute::ceil_div(params.dv, kBlockK);
-    int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM);
-    dim3 grid_m(num_blocks_m, num_blocks_k, params.b);
+    typename CombineKernel::SchedulerArguments scheduler_args {
+        params.b, params.seqlen_q, params.total_q, params.h, params.dv,
+        params.cu_seqlens_q, params.seqused_q
+    };
+
+    typename CombineKernel::Params kernel_params = {
+        CombineKernel::to_underlying_arguments(args),
+        CombineKernel::TileScheduler::to_underlying_arguments(scheduler_args)
+    };
+
+    dim3 grid_m = CombineKernel::TileScheduler::get_grid_shape(scheduler_args);
     auto kernel = cutlass::device_kernel<CombineKernel>;
     int smem_size = CombineKernel::SharedStorageSize;
     if (smem_size >= 48 * 1024) {
diff --git a/hopper/utils.h b/hopper/utils.h
index 3f76ea66e9..3719eab920 100644
--- a/hopper/utils.h
+++ b/hopper/utils.h
@@ -646,6 +646,22 @@ CUTE_DEVICE T warp_prefix_sum(T val) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template <typename T>
+CUTE_DEVICE T warp_shfl_get(T val, int src_lane) {
+    return __shfl_sync(0xffffffff, val, src_lane);
+};
+
+template <typename T>
+CUTE_DEVICE T warp_shfl_get_last(T val) {
+    return __shfl_sync(0xffffffff, val, cutlass::NumThreadsPerWarp - 1);
+};
+
+CUTE_DEVICE int warp_last_true_laneid(bool cond) {
+    return __popc(__ballot_sync(0xffffffff, cond));
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 template <typename T>
 CUTE_DEVICE T warp_uniform(T a) {
     return __shfl_sync(0xffffffff, a, 0);
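The helpers added to hopper/utils.h wrap single warp intrinsics; the scheduler applies warp_last_true_laneid to per-lane values that are non-decreasing with lane index, so the popcount of the ballot equals the index of the first lane whose predicate is false. A minimal standalone CUDA sketch of that convention follows; it is illustrative only and not part of the patch (demo_kernel, the per-lane end_tile_id values, and the curr_tile_id argument are all invented for the example).

// Illustrative sketch (not part of the patch): the ballot/popc lane convention
// behind warp_last_true_laneid + warp_shfl_get in the combine-kernel scheduler.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void demo_kernel(int curr_tile_id) {
    int lane = threadIdx.x;                 // one warp: lanes 0..31
    int end_tile_id = (lane + 1) * 4;       // stand-in for the inclusive prefix sum of
                                            // m-block counts (non-decreasing per lane)
    bool batch_is_before = end_tile_id <= curr_tile_id;   // true on a prefix of lanes
    unsigned ballot = __ballot_sync(0xffffffffu, batch_is_before);
    int batch_idx_in_group = __popc(ballot);   // what warp_last_true_laneid(batch_is_before) returns:
                                               // the first lane where the predicate is false
    int picked_end = __shfl_sync(0xffffffffu, end_tile_id, batch_idx_in_group);
    if (lane == 0) {
        printf("curr_tile_id=%d -> batch_idx_in_group=%d (its end_tile_id=%d)\n",
               curr_tile_id, batch_idx_in_group, picked_end);
    }
}

int main() {
    demo_kernel<<<1, 32>>>(/*curr_tile_id=*/10);  // lanes 0 and 1 end at 4 and 8 -> picks lane 2
    cudaDeviceSynchronize();
    return 0;
}

If the predicate were not true on a contiguous prefix of lanes, the popcount would no longer identify a single boundary lane, which is why the scheduler pairs this helper with an inclusive prefix sum of the per-batch block counts.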