@@ -520,7 +520,7 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
cudaMemsetAsync(fmhaParams.outputPtr, 0, ring_block_output_size, stream);
cudaMemcpyAsync(fmhaParams.tileCounterPtr, fmha_scheduler_counter_h, sizeof(uint32_t),
cudaMemcpyHostToDevice, stream);
mFMHARunner->run(fmhaParams);
mFmhaDispatcher->run(fmhaParams);
if (iter != 0)
{
invokeRecoverFromRA<T>((T*) context_buf_, (float*) ring_softmax_accu_stats_buf_,
@@ -703,8 +703,18 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
fmhaParams.vMaxNBlock = (input_seq_len + mSageAttnVBlockSize - 1) / mSageAttnVBlockSize;
}

fmhaParams.totalKvSeqLen = num_tokens;

fmhaParams.cuKvSeqLenPtr = cu_seqlens;
fmhaParams.cuMaskRowsPtr = cu_seqlens;
fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;

fmhaParams.scaleBmm1Ptr = scale_bmm1_ptr;
fmhaParams.scaleBmm2Ptr = scale_bmm2_ptr;
fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;

// Run the fmha kernel.
mFMHARunner->run(fmhaParams);
mFmhaDispatcher->run(fmhaParams);
sync_check_cuda_error(stream);
if (mSageAttn)
{
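Read together, the added lines in this hunk plumb the extra runtime parameters the dispatcher consumes before the kernel launch. Below is a condensed, commented sketch of the resulting enqueue-time flow, using only names that appear in the diff; the comments reflect my reading of these parameters rather than authoritative documentation, and the earlier buffer setup (cu_seqlens, fmha_tile_counter_ptr, the BMM scale pointers) is assumed to happen as before.

```cpp
// Total number of KV tokens across the batch, plus the cumulative (prefix-sum)
// KV sequence lengths and mask rows; this plugin reuses the same cu_seqlens
// buffer for both cumulative pointers.
fmhaParams.totalKvSeqLen = num_tokens;
fmhaParams.cuKvSeqLenPtr = cu_seqlens;
fmhaParams.cuMaskRowsPtr = cu_seqlens;

// Device counter consumed by the FMHA tile scheduler.
fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;

// Device-side scale factors for the two batched GEMMs (Q*K^T and P*V), and the
// FP32-accumulation override carried over from the plugin attribute.
fmhaParams.scaleBmm1Ptr = scale_bmm1_ptr;
fmhaParams.scaleBmm2Ptr = scale_bmm2_ptr;
fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;

// Launch through FmhaDispatcher instead of FusedMHARunnerV2 directly, so the same
// call site covers both the pre-compiled cubins and the Blackwell (SM100) backend.
mFmhaDispatcher->run(fmhaParams);
sync_check_cuda_error(stream);
```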
@@ -946,12 +956,15 @@ int BertAttentionPlugin::initialize() noexcept
fmhaParams.attentionInputLayout = AttentionInputLayout::Q_CONTIGUOUS_KV;
fmhaParams.saveSoftmax = true;
}

// Load kernels from the pre-compiled cubins.
mFMHARunner.reset(new FusedMHARunnerV2(fmhaParams));

// Fall back to unfused MHA kernels if not supported.
mEnableContextFMHA = mFMHARunner->isFmhaSupported();
// The KV input data type. The default is same as dataType.
fmhaParams.dataTypeKv = data_type;
fmhaParams.forceFp32Acc = false;
fmhaParams.headSizeV = mHeadSize;

// Load kernels from the pre-compiled cubins for blackwell.
mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
// Fall back to unfused MHA kernels if not supported for blackwell.
mEnableContextFMHA = mFmhaDispatcher->isSupported();
}

#if ENABLE_MULTI_DEVICE
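In initialize(), the fixed parameters that previously constructed FusedMHARunnerV2 now also feed the new FmhaDispatcher, and the context-FMHA fallback decision is taken from the dispatcher. A minimal sketch of that setup, again limited to names visible in the diff and assuming the rest of fmhaParams is populated as before:

```cpp
// The KV input data type defaults to the same data type as the attention input.
fmhaParams.dataTypeKv = data_type;
fmhaParams.forceFp32Acc = false;
fmhaParams.headSizeV = mHeadSize;

// The dispatcher wraps kernel selection: the pre-compiled cubins where available,
// and the dedicated Blackwell (SM100) path otherwise.
mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));

// If no fused kernel supports this configuration, fall back to unfused MHA.
mEnableContextFMHA = mFmhaDispatcher->isSupported();
```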
@@ -19,6 +19,7 @@
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h"
#include "tensorrt_llm/kernels/fmhaDispatcher.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/plugins/common/plugin.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
@@ -116,6 +117,7 @@ class BertAttentionPlugin : public BasePlugin
// The default copy constructor will leave them as nullptr. clone() shall initialize it.
UniqPtrWNullCopy<tensorrt_llm::kernels::FusedMHARunnerV2> mFMHARunner;
UniqPtrWNullCopy<tensorrt_llm::common::CublasMMWrapper> mCublasWrapper;
UniqPtrWNullCopy<tensorrt_llm::kernels::FmhaDispatcher> mFmhaDispatcher;
};

class BertAttentionPluginCreator : public BaseCreator
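On the header side, the dispatcher sits next to the existing runner and cuBLAS wrapper behind the same null-on-copy smart pointer, so a cloned plugin re-creates it in initialize(). A sketch of the member declarations; the per-member annotations are my reading of their roles, not text from the header:

```cpp
// The default copy constructor leaves these as nullptr; clone() shall initialize them.
UniqPtrWNullCopy<tensorrt_llm::kernels::FusedMHARunnerV2> mFMHARunner;   // existing fused-MHA runner
UniqPtrWNullCopy<tensorrt_llm::common::CublasMMWrapper> mCublasWrapper;  // GEMM wrapper for the unfused path
UniqPtrWNullCopy<tensorrt_llm::kernels::FmhaDispatcher> mFmhaDispatcher; // new: selects cubins or the SM100 backend
```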
3 changes: 2 additions & 1 deletion jenkins/Build.groovy
Author comment: This change is due to #6076

@@ -427,7 +427,8 @@ def runLLMBuild(pipeline, buildFlags, tarName, is_linux_x86_64)

// Build tritonserver artifacts
def llmPath = sh (script: "realpath ${LLM_ROOT}",returnStdout: true).trim()
sh "cd ${LLM_ROOT}/triton_backend/inflight_batcher_llm && mkdir build && cd build && cmake .. -DTRTLLM_DIR=${llmPath} -DUSE_CXX11_ABI=ON && make -j${BUILD_JOBS} install"
// TODO: Remove after the cmake version is upgraded to 3.31.8
sh "cd ${LLM_ROOT}/triton_backend/inflight_batcher_llm && mkdir build && cd build && cmake .. -DTRTLLM_DIR=${llmPath} -DTRITON_COMMON_REPO_TAG=r25.05 -DTRITON_CORE_REPO_TAG=r25.05 -DTRITON_THIRD_PARTY_REPO_TAG=r25.05 -DTRITON_BACKEND_REPO_TAG=r25.05 -DUSE_CXX11_ABI=ON && make -j${BUILD_JOBS} install"

// Step 3: packaging wheels into tarfile
sh "cp ${LLM_ROOT}/build/tensorrt_llm-*.whl TensorRT-LLM/"
11 changes: 7 additions & 4 deletions tensorrt_llm/functional.py
@@ -30,9 +30,9 @@
from ._common import default_net, default_trtnet, precision
from ._utils import (QuantModeWrapper, bf16_array, bool_array,
dim_resolve_negative, dim_to_trt_axes, dims_array,
fp16_array, fp32_array, int32_array, int64_array,
np_dtype_to_trt, str_dtype_to_trt, trt_dtype_to_np,
trt_dtype_to_str)
fp16_array, fp32_array, get_sm_version, int32_array,
int64_array, np_dtype_to_trt, str_dtype_to_trt,
trt_dtype_to_np, trt_dtype_to_str)
from .network import PluginInfo, set_np_weight, set_plugin_info
from .plugin import TRT_LLM_PLUGIN_NAMESPACE, current_all_reduce_helper
from .quantization import QuantMode
@@ -5680,6 +5680,9 @@ def gpt_attention(
# context fmha needs packed mask.
assert attention_packed_mask is not None
mask_type = AttentionMaskType.custom_mask
if get_sm_version(
) >= 100: #and model_type == "whisper": #whisper use causal mask
mask_type = AttentionMaskType.causal

mask_type_filed = trt.PluginField("mask_type",
np.array([int(mask_type)], np.int32),
@@ -5804,7 +5807,7 @@ def gpt_attention(
if attention_mask is not None and mask_type == AttentionMaskType.custom_mask:
# useFullCustomMask
plug_inputs += [attention_mask]
if attention_packed_mask is not None:
if attention_packed_mask is not None and mask_type == AttentionMaskType.custom_mask:
# usePackedCustomMask
plug_inputs += [attention_packed_mask]
if use_cache:
6 changes: 2 additions & 4 deletions tensorrt_llm/layers/attention.py
@@ -20,8 +20,8 @@
import torch

from .._common import default_net, precision
from .._utils import (fp32_array, get_sm_version, int32_array, is_same_dtype,
set_obj_attrs, trt_dtype_to_np, trt_dtype_to_str)
from .._utils import (fp32_array, int32_array, is_same_dtype, set_obj_attrs,
trt_dtype_to_np, trt_dtype_to_str)

# isort: off
from ..functional import (
@@ -1755,8 +1755,6 @@ def forward(self,
if default_net().plugin_config.bert_attention_plugin:
# TRT plugin mode
assert input_lengths is not None
assert get_sm_version() < 100 or get_sm_version() >= 120, \
"bert_attention_plugin does not support SM100"
context = bert_attention(
qkv,
input_lengths,