Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,10 @@ void run_blockwise_scaled_group_mm(
reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(
layout_sfb.data_ptr())};

cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = a_ptrs.get_device();
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
int device_id = a_ptrs.device().index();
static const cutlass::KernelHardwareInfo hw_info{
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
device_id)};

// Epilogue Arguments
typename GemmKernel::EpilogueArguments epilogue_args{
Expand Down
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/serving_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ async def _cross_encoding_score(
# cross_encoder models defaults to using pad_token.
tokenized_prompts = await asyncio.gather(*(
tokenize_async(
text=t1, # type: ignore[arg-type]
text_pair=t2, # type: ignore[arg-type]
text=t1, # type: ignore[arg-type]
text_pair=t2, # type: ignore[arg-type]
**tokenization_kwargs) for t1, t2 in input_pairs))
else:
# `llm as reranker` models defaults to not using pad_token.
Expand Down
29 changes: 27 additions & 2 deletions vllm/model_executor/layers/fused_moe/cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,8 +547,10 @@ def cutlass_moe_fp4(a: torch.Tensor,
return out.to(dtype=out_dtype)


def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
w2: torch.Tensor) -> bool:
def _valid_cutlass_block_scaled_grouped_gemm(
w1: torch.Tensor, w2: torch.Tensor, inplace: bool, activation: str,
apply_router_weight_on_input: bool,
expert_map: Optional[torch.Tensor]) -> bool:

def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
return N % 128 == 0 and K % 128 == 0
Expand All @@ -564,6 +566,29 @@ def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
"CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s).")
return False

if expert_map is not None:
logger.debug(
"CutlassBlockScaledGroupedGemm disabled: expert_parallel is"
" not supported.")
return False

if activation != "silu":
logger.debug(
"CutlassBlockScaledGroupedGemm disabled: only activation silu is"
" supported.")
return False

if apply_router_weight_on_input:
logger.debug("CutlassBlockScaledGroupedGemm disabled:"
" apply_router_weight_on_input is not supported.")
return False

if inplace:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @djmmoss , why do we need to disable when inplace is True ?

logger.debug(
"CutlassBlockScaledGroupedGemm disabled: inplace is not supported."
)
return False

return True


Expand Down
5 changes: 3 additions & 2 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,8 +1190,9 @@ def fused_experts(
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)):
assert apply_router_weight_on_input is False
and _valid_cutlass_block_scaled_grouped_gemm(
w1, w2, inplace, activation, apply_router_weight_on_input,
expert_map)):
return run_cutlass_block_scaled_fused_experts(
a=hidden_states,
w1=w1,
Expand Down