18 changes: 15 additions & 3 deletions vllm/lora/ops/triton_ops/bgmv_expand_slice.py
@@ -15,6 +15,18 @@
from .utils import get_lora_op_configs


def get_autotune_config():
return [
triton.Config({'BLOCK_N': 32}, num_warps=8),
triton.Config({'BLOCK_N': 64}, num_warps=8),
triton.Config({'BLOCK_N': 128}, num_warps=8),
triton.Config({'BLOCK_N': 256}, num_warps=8),
]


@triton.autotune(configs=get_autotune_config(),
key=['N', 'K'],
restore_value=["out_ptr"])
@triton.jit
def _bgmv_expand_slice_kernel(
input_ptr,
@@ -65,8 +77,8 @@ def _bgmv_expand_slice_kernel(
# sliding to next row-block
b_ptr = (lora_ptr + l0_stride * lora_index +
pid_sn * split_n_length * lora_k_stride)
c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length +
slice_offset * cn_stride)
c_ptr = (out_ptr + cur_batch * cm_stride +
pid_sn * split_n_length * cn_stride + slice_offset * cn_stride)

for n in range(0, split_n_length, BLOCK_N):
current_n = n + offset_n
@@ -177,7 +189,7 @@ def _bgmv_expand_slice(
EVEN_K=EVEN_K,
ADD_INPUTS=ADD_INPUTS,
CAST_TYPE=CAST_TYPE,
**config,
SPLIT_N=config["SPLIT_N"],
)
return

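For context, the decorator added above follows Triton's standard autotuning API: `triton.autotune` benchmarks each `triton.Config` the first time the kernel is launched with a new `key` value (here the problem sizes `N` and `K`) and caches the winner, and `restore_value=["out_ptr"]` snapshots and restores that buffer between trial runs so a kernel that accumulates into its output is benchmarked on clean data. A minimal, self-contained sketch of the same pattern (the toy kernel and its names are illustrative, not part of this PR):

```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}, num_warps=8),
        triton.Config({'BLOCK_N': 64}, num_warps=8),
        triton.Config({'BLOCK_N': 128}, num_warps=8),
    ],
    key=['N'],                  # re-tune whenever the problem size changes
    restore_value=['out_ptr'],  # out is read-modify-write, so restore it between trials
)
@triton.jit
def _toy_accumulate_kernel(x_ptr, out_ptr, N, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(out_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


def toy_accumulate(x: torch.Tensor, out: torch.Tensor):
    N = x.numel()
    # BLOCK_N is injected by the autotuner, so the grid must be a lambda over META.
    grid = lambda META: (triton.cdiv(N, META['BLOCK_N']), )
    _toy_accumulate_kernel[grid](x, out, N)
```

The call site in this file works the same way: only the non-tuned meta-parameters (such as `SPLIT_N`) are passed explicitly, and the selected config supplies the rest.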
62 changes: 58 additions & 4 deletions vllm/lora/ops/triton_ops/bgmv_shrink.py
@@ -12,9 +12,63 @@

from vllm.utils import direct_register_custom_op

from .utils import get_lora_op_configs

def get_autotune_config():
return [
triton.Config({
'BLOCK_K': 32,
'SPLIT_K': 64
}, num_warps=8),
triton.Config({
'BLOCK_K': 32,
'SPLIT_K': 96
}, num_warps=8),
triton.Config({
'BLOCK_K': 32,
'SPLIT_K': 128
}, num_warps=8),
triton.Config({
'BLOCK_K': 32,
'SPLIT_K': 256
}, num_warps=8),
triton.Config({
'BLOCK_K': 64,
'SPLIT_K': 64
}, num_warps=8),
triton.Config({
'BLOCK_K': 64,
'SPLIT_K': 96
}, num_warps=8),
triton.Config({
'BLOCK_K': 64,
'SPLIT_K': 128
}, num_warps=8),
triton.Config({
'BLOCK_K': 128,
'SPLIT_K': 64
}, num_warps=8),
triton.Config({
'BLOCK_K': 128,
'SPLIT_K': 96
}, num_warps=8),
triton.Config({
'BLOCK_K': 128,
'SPLIT_K': 128
}, num_warps=8),
triton.Config({
'BLOCK_K': 256,
'SPLIT_K': 64
}, num_warps=8),
triton.Config({
'BLOCK_K': 256,
'SPLIT_K': 96
}, num_warps=8),
]


@triton.autotune(configs=get_autotune_config(),
key=['N', 'K'],
restore_value=["out_ptr"])
@triton.jit
def _bgmv_shrink_kernel(
input_ptr,
@@ -45,6 +99,9 @@ def _bgmv_shrink_kernel(
if lora_index == -1:
return

if pid_sk * BLOCK_K >= K:
return

offset_n = tl.arange(0, BLOCK_N)
offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K
a_ptr = input_ptr + cur_batch * xm_stride
@@ -117,8 +174,6 @@ def _bgmv_shrink(
batches = lora_indices_tensor.size(0)
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
BLOCK_N = triton.next_power_of_2(N)
# First try to load optimal config from the file
config = get_lora_op_configs("bgmv_shrink", batches, K)

grid = lambda META: (
META["SPLIT_K"],
@@ -140,7 +195,6 @@ def _bgmv_shrink(
output_tensor.stride(0),
output_tensor.stride(1),
BLOCK_N=BLOCK_N,
**config,
)
return

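The new `if pid_sk * BLOCK_K >= K: return` guard is worth unpacking. With the file-based `get_lora_op_configs` lookup removed, `BLOCK_K` and `SPLIT_K` now come from whichever config the autotuner picks, and the grid launches `SPLIT_K` programs along the split-K axis regardless of `K`; a program whose starting offset `pid_sk * BLOCK_K` already lies past `K` therefore has no work and must exit before touching memory. A small host-side sketch of the work partition (plain Python, `split_k_ranges` is a hypothetical helper; it assumes the usual split-K layout where program `pid_sk` starts at `pid_sk * BLOCK_K` and then strides by `BLOCK_K * SPLIT_K`):

```python
def split_k_ranges(K: int, BLOCK_K: int, SPLIT_K: int):
    """Return, per split-K program id, the k-offsets it would start at."""
    ranges = []
    for pid_sk in range(SPLIT_K):
        starts = list(range(pid_sk * BLOCK_K, K, BLOCK_K * SPLIT_K))
        ranges.append(starts)
    return ranges


# Example: K=4096 with an aggressive config. Programs whose first start
# offset is >= K get an empty list -- exactly the case the new in-kernel
# guard handles by returning early.
print(split_k_ranges(K=4096, BLOCK_K=64, SPLIT_K=96)[-1])  # -> []
```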
105 changes: 88 additions & 17 deletions vllm/lora/ops/triton_ops/sgmv_shrink.py
@@ -18,6 +18,86 @@
from .utils import _get_lora_a_ptr


def get_autotune_config():
return [
triton.Config({
'BLOCK_M': 32,
'BLOCK_N': 16,
'BLOCK_K': 32,
'SPLIT_K': 8
}),
triton.Config({
'BLOCK_M': 32,
'BLOCK_N': 16,
'BLOCK_K': 32,
'SPLIT_K': 16
}),
triton.Config({
'BLOCK_M': 32,
'BLOCK_N': 16,
'BLOCK_K': 32,
'SPLIT_K': 32
}),
triton.Config({
'BLOCK_M': 32,
'BLOCK_N': 16,
'BLOCK_K': 32,
'SPLIT_K': 64
}),
triton.Config({
'BLOCK_M': 32,
'BLOCK_N': 32,
'BLOCK_K': 32,
'SPLIT_K': 8
}),
triton.Config({
'BLOCK_M': 32,
'BLOCK_N': 32,
'BLOCK_K': 32,
'SPLIT_K': 16
}),
triton.Config({
'BLOCK_M': 32,
'BLOCK_N': 32,
'BLOCK_K': 32,
'SPLIT_K': 32
}),
triton.Config({
'BLOCK_M': 32,
'BLOCK_N': 32,
'BLOCK_K': 32,
'SPLIT_K': 64
}),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 32,
'BLOCK_K': 32,
'SPLIT_K': 8
}),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 32,
'BLOCK_K': 32,
'SPLIT_K': 16
}),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 32,
'BLOCK_K': 32,
'SPLIT_K': 32
}),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 32,
'BLOCK_K': 32,
'SPLIT_K': 64
}),
]


@triton.autotune(configs=get_autotune_config(),
key=['N', 'K'],
restore_value=["out_ptr"])
@triton.jit
def _sgmv_shrink_kernel(
input_ptr,
Expand All @@ -37,12 +117,11 @@ def _sgmv_shrink_kernel(
output_d0_stride,
output_d1_stride,
output_d2_stride, # 1
SLICE_NUM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
SLICE_NUM: tl.constexpr):
SPLIT_K: tl.constexpr):
"""
The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
@@ -77,6 +156,7 @@ def _sgmv_shrink_kernel(
ram = cta_m_offset + tl.max_contiguous(
tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)

EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
do_shrink_kernel(
pid_n,
pid_sk,
@@ -158,14 +238,10 @@ def _sgmv_shrink(
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, b_seq_start_loc.device)
# TODO tuning this config
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
BLOCK_M = 32
BLOCK_N = 16
BLOCK_K = 32
SPLIT_K = 8
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
grid = (
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
SPLIT_K * len(lora_a_weights),
grid = lambda META: (
triton.cdiv(max_seq_length, META["BLOCK_M"]) * triton.cdiv(
N, META["BLOCK_N"]),
META["SPLIT_K"] * len(lora_a_weights),
batches,
)
_sgmv_shrink_kernel[grid](
Expand All @@ -186,12 +262,7 @@ def _sgmv_shrink(
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor.stride(2),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
len(lora_a_weights),
SLICE_NUM=len(lora_a_weights),
)
return

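To make the call-site change concrete: with the hard-coded `BLOCK_M`/`BLOCK_N`/`BLOCK_K`/`SPLIT_K` constants gone, the grid is computed from the META dictionary of whichever config the autotuner selects, which Triton passes to the grid callable at launch. A small worked example of the grid arithmetic with assumed shapes (the numbers below are illustrative, not taken from the PR):

```python
import triton

# Suppose the autotuner picked this config for a given (N, K) key,
# and the batch has these shapes.
META = {'BLOCK_M': 32, 'BLOCK_N': 16, 'BLOCK_K': 32, 'SPLIT_K': 8}
max_seq_length = 500
N = 16            # LoRA rank
num_slices = 3    # len(lora_a_weights)
batches = 4

grid = (
    triton.cdiv(max_seq_length, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),
    META['SPLIT_K'] * num_slices,
    batches,
)
print(grid)  # (16, 24, 4)
```

Note also that `EVEN_K` is now derived inside the kernel from the autotuned `BLOCK_K` and `SPLIT_K` rather than passed in from the wrapper, since the wrapper no longer knows which config will be chosen at launch time.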