From 6fad605582059e5d10076a30d98cd831a0f83844 Mon Sep 17 00:00:00 2001
From: tinyinl
Date: Tue, 15 Jul 2025 20:11:12 -0700
Subject: [PATCH 1/2] whisper

Signed-off-by: tinyinl
---
 cpp/tensorrt_llm/common/attentionOp.cpp       |  1 +
 .../bertAttentionPlugin.cpp                   | 29 ++++++++++++++-----
 .../bertAttentionPlugin/bertAttentionPlugin.h |  2 ++
 tensorrt_llm/functional.py                    | 11 ++++---
 tensorrt_llm/layers/attention.py              |  6 ++--
 5 files changed, 33 insertions(+), 16 deletions(-)

diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp
index aba735f8258..06911b0be39 100644
--- a/cpp/tensorrt_llm/common/attentionOp.cpp
+++ b/cpp/tensorrt_llm/common/attentionOp.cpp
@@ -2562,6 +2562,7 @@ int AttentionOp::initialize() noexcept
     fmhaParams.attnLogitSoftcappingScale = mAttnLogitSoftcappingScale;
     fmhaParams.hasAlibi = isALiBi();
     fmhaParams.scaleAlibi = isAliBiWithScale();
+    fmhaParams.dataTypeOut = data_type;

     // Load kernels from the pre-compiled cubins.
     mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
diff --git a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp
index e2fab9044c0..1eabbfa554a 100644
--- a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp
+++ b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp
@@ -520,7 +520,7 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
             cudaMemsetAsync(fmhaParams.outputPtr, 0, ring_block_output_size, stream);
             cudaMemcpyAsync(fmhaParams.tileCounterPtr, fmha_scheduler_counter_h, sizeof(uint32_t),
                 cudaMemcpyHostToDevice, stream);
-            mFMHARunner->run(fmhaParams);
+            mFmhaDispatcher->run(fmhaParams);
             if (iter != 0)
             {
                 invokeRecoverFromRA((T*) context_buf_, (float*) ring_softmax_accu_stats_buf_,
@@ -703,8 +703,18 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
             fmhaParams.vMaxNBlock = (input_seq_len + mSageAttnVBlockSize - 1) / mSageAttnVBlockSize;
         }

+        fmhaParams.totalKvSeqLen = num_tokens;
+
+        fmhaParams.cuKvSeqLenPtr = cu_seqlens;
+        fmhaParams.cuMaskRowsPtr = cu_seqlens;
+        fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
+
+        fmhaParams.scaleBmm1Ptr = scale_bmm1_ptr;
+        fmhaParams.scaleBmm2Ptr = scale_bmm2_ptr;
+        fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;
+
         // Run the fmha kernel.
-        mFMHARunner->run(fmhaParams);
+        mFmhaDispatcher->run(fmhaParams);
         sync_check_cuda_error(stream);
         if (mSageAttn)
         {
@@ -946,12 +956,15 @@ int BertAttentionPlugin::initialize() noexcept
         fmhaParams.attentionInputLayout = AttentionInputLayout::Q_CONTIGUOUS_KV;
         fmhaParams.saveSoftmax = true;
     }
-
-    // Load kernels from the pre-compiled cubins.
-    mFMHARunner.reset(new FusedMHARunnerV2(fmhaParams));
-
-    // Fall back to unfused MHA kernels if not supported.
-    mEnableContextFMHA = mFMHARunner->isFmhaSupported();
+    // The KV input data type. The default is the same as dataType.
+    fmhaParams.dataTypeKv = data_type;
+    fmhaParams.forceFp32Acc = false;
+    fmhaParams.headSizeV = mHeadSize;
+
+    // Load kernels from the pre-compiled cubins for Blackwell.
+    mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
+    // Fall back to unfused MHA kernels if not supported on Blackwell.
+    mEnableContextFMHA = mFmhaDispatcher->isSupported();
 }

 #if ENABLE_MULTI_DEVICE
diff --git a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.h b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.h
index 0c5fdc15b60..3e34c2477b5 100644
--- a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.h
+++ b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.h
@@ -19,6 +19,7 @@
 #include "tensorrt_llm/common/cublasMMWrapper.h"
 #include "tensorrt_llm/common/quantization.h"
 #include "tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h"
+#include "tensorrt_llm/kernels/fmhaDispatcher.h"
 #include "tensorrt_llm/kernels/gptKernels.h"
 #include "tensorrt_llm/plugins/common/plugin.h"
 #include "tensorrt_llm/runtime/utils/mpiUtils.h"
@@ -116,6 +117,7 @@ class BertAttentionPlugin : public BasePlugin
     // The default copy constructor will leave them as nullptr. clone() shall initialize it.
     UniqPtrWNullCopy<tensorrt_llm::kernels::FusedMHARunnerV2> mFMHARunner;
     UniqPtrWNullCopy<tensorrt_llm::common::CublasMMWrapper> mCublasWrapper;
+    UniqPtrWNullCopy<tensorrt_llm::kernels::FmhaDispatcher> mFmhaDispatcher;
 };

 class BertAttentionPluginCreator : public BaseCreator
diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py
index 52c96d40f59..02b76f94f43 100755
--- a/tensorrt_llm/functional.py
+++ b/tensorrt_llm/functional.py
@@ -30,9 +30,9 @@
 from ._common import default_net, default_trtnet, precision
 from ._utils import (QuantModeWrapper, bf16_array, bool_array,
                      dim_resolve_negative, dim_to_trt_axes, dims_array,
-                     fp16_array, fp32_array, int32_array, int64_array,
-                     np_dtype_to_trt, str_dtype_to_trt, trt_dtype_to_np,
-                     trt_dtype_to_str)
+                     fp16_array, fp32_array, get_sm_version, int32_array,
+                     int64_array, np_dtype_to_trt, str_dtype_to_trt,
+                     trt_dtype_to_np, trt_dtype_to_str)
 from .network import PluginInfo, set_np_weight, set_plugin_info
 from .plugin import TRT_LLM_PLUGIN_NAMESPACE, current_all_reduce_helper
 from .quantization import QuantMode
@@ -5680,6 +5680,9 @@ def gpt_attention(
         # context fmha needs packed mask.
         assert attention_packed_mask is not None
         mask_type = AttentionMaskType.custom_mask
+    if get_sm_version(
+    ) >= 100:  # and model_type == "whisper": whisper uses a causal mask
+        mask_type = AttentionMaskType.causal

     mask_type_filed = trt.PluginField("mask_type",
                                       np.array([int(mask_type)], np.int32),
@@ -5804,7 +5807,7 @@ def gpt_attention(
     if attention_mask is not None and mask_type == AttentionMaskType.custom_mask:
         # useFullCustomMask
         plug_inputs += [attention_mask]
-    if attention_packed_mask is not None:
+    if attention_packed_mask is not None and mask_type == AttentionMaskType.custom_mask:
         # usePackedCustomMask
         plug_inputs += [attention_packed_mask]
     if use_cache:
diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py
index ffe07590781..f46bc9ff858 100755
--- a/tensorrt_llm/layers/attention.py
+++ b/tensorrt_llm/layers/attention.py
@@ -20,8 +20,8 @@
 import torch

 from .._common import default_net, precision
-from .._utils import (fp32_array, get_sm_version, int32_array, is_same_dtype,
-                      set_obj_attrs, trt_dtype_to_np, trt_dtype_to_str)
+from .._utils import (fp32_array, int32_array, is_same_dtype, set_obj_attrs,
+                      trt_dtype_to_np, trt_dtype_to_str)

 # isort: off
 from ..functional import (
@@ -1755,8 +1755,6 @@ def forward(self,
         if default_net().plugin_config.bert_attention_plugin:
             # TRT plugin mode
             assert input_lengths is not None
-            assert get_sm_version() < 100 or get_sm_version() >= 120, \
-                "bert_attention_plugin does not support SM100"
             context = bert_attention(
                 qkv,
                 input_lengths,

From 225c3d0a322df2eccd3934bbe22e301931e2399b Mon Sep 17 00:00:00 2001
From: tinyinl
Date: Wed, 16 Jul 2025 01:42:02 -0700
Subject: [PATCH 2/2] Add cmake fix on triton_server based on PR6076

Signed-off-by: tinyinl
---
 cpp/tensorrt_llm/common/attentionOp.cpp | 1 -
 jenkins/Build.groovy                    | 3 ++-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp
index 06911b0be39..aba735f8258 100644
--- a/cpp/tensorrt_llm/common/attentionOp.cpp
+++ b/cpp/tensorrt_llm/common/attentionOp.cpp
@@ -2562,7 +2562,6 @@ int AttentionOp::initialize() noexcept
     fmhaParams.attnLogitSoftcappingScale = mAttnLogitSoftcappingScale;
     fmhaParams.hasAlibi = isALiBi();
     fmhaParams.scaleAlibi = isAliBiWithScale();
-    fmhaParams.dataTypeOut = data_type;

     // Load kernels from the pre-compiled cubins.
     mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
diff --git a/jenkins/Build.groovy b/jenkins/Build.groovy
index 83e09453811..276ebbfa50f 100644
--- a/jenkins/Build.groovy
+++ b/jenkins/Build.groovy
@@ -427,7 +427,8 @@ def runLLMBuild(pipeline, buildFlags, tarName, is_linux_x86_64)

     // Build tritonserver artifacts
     def llmPath = sh (script: "realpath ${LLM_ROOT}",returnStdout: true).trim()
-    sh "cd ${LLM_ROOT}/triton_backend/inflight_batcher_llm && mkdir build && cd build && cmake .. -DTRTLLM_DIR=${llmPath} -DUSE_CXX11_ABI=ON && make -j${BUILD_JOBS} install"
+    // TODO: Remove after the cmake version is upgraded to 3.31.8
+    sh "cd ${LLM_ROOT}/triton_backend/inflight_batcher_llm && mkdir build && cd build && cmake .. -DTRTLLM_DIR=${llmPath} -DTRITON_COMMON_REPO_TAG=r25.05 -DTRITON_CORE_REPO_TAG=r25.05 -DTRITON_THIRD_PARTY_REPO_TAG=r25.05 -DTRITON_BACKEND_REPO_TAG=r25.05 -DUSE_CXX11_ABI=ON && make -j${BUILD_JOBS} install"

     // Step 3: packaging wheels into tarfile
     sh "cp ${LLM_ROOT}/build/tensorrt_llm-*.whl TensorRT-LLM/"
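
Illustration of the mask-type selection introduced in PATCH 1/2: the sketch below is a simplified,
standalone Python rendering, not the shipped code. select_mask_type and the enum values are
hypothetical stand-ins for TensorRT-LLM's AttentionMaskType and get_sm_version().

# Hedged sketch only: mirrors the mask-type choice added to gpt_attention().
from enum import IntEnum


class AttentionMaskType(IntEnum):
    causal = 1       # assumed value, for illustration
    custom_mask = 4  # assumed value, for illustration


def select_mask_type(sm_version: int, has_packed_mask: bool) -> AttentionMaskType:
    # Context FMHA normally uses a packed custom mask when one is provided...
    mask_type = (AttentionMaskType.custom_mask
                 if has_packed_mask else AttentionMaskType.causal)
    # ...but the patch forces a causal mask on SM100 (Blackwell), which is what
    # the Whisper decoder expects.
    if sm_version >= 100:
        mask_type = AttentionMaskType.causal
    return mask_type


print(select_mask_type(100, True))  # Blackwell -> AttentionMaskType.causal
print(select_mask_type(90, True))   # Hopper    -> AttentionMaskType.custom_mask

The second functional.py hunk is the matching guard: attention_packed_mask is only appended to
plug_inputs while mask_type is still custom_mask, so the SM100 causal path passes no packed-mask
input to the plugin.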