Skip to content
Merged
3 changes: 2 additions & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,6 @@ steps:
- pytest -v -s compile/test_functionalization.py
- pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py
- pytest -v -s compile/test_fusion_all_reduce.py
- pytest -v -s compile/test_decorator.py
- pytest -v -s compile/test_noop_elimination.py
Expand Down Expand Up @@ -920,6 +919,7 @@ steps:
- vllm/worker/worker_base.py
- vllm/v1/engine/
- vllm/v1/worker/
- tests/compile/test_async_tp.py
- tests/compile/test_basic_correctness.py
- tests/compile/test_wrapper.py
- tests/distributed/
Expand All @@ -933,6 +933,7 @@ steps:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- pytest -v -s ./compile/test_async_tp.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
Expand Down
28 changes: 22 additions & 6 deletions vllm/compilation/collective_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,23 @@ def replacement(
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
scale_b,
"avg",
scatter_dim=0,
out_dtype=self.dtype,
group_name=self.tp.device_group.group_name,
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
output_shape,
None, # bias
None, # result_scale
self.dtype, # out_dtype
False, # use_fast_accum
)

return gemm_rs
Expand Down Expand Up @@ -296,15 +304,23 @@ def replacement(
scale_b: torch.Tensor,
cutlass_mm_output: torch.Tensor,
) -> torch.Tensor:
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
scale_b,
"avg",
scatter_dim=0,
out_dtype=self.dtype,
group_name=self.tp.device_group.group_name,
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
output_shape,
None, # bias
None, # result_scale
self.dtype, # out_dtype
False, # use_fast_accum
)

return gemm_rs
Expand Down