diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index c95192a5a1bc..7de5b05a0b05 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -265,24 +265,21 @@ def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: # when custom allreduce is disabled, this will be None - if self.disabled: + if self.disabled or not self.should_custom_ar(input): return None if self._IS_CAPTURING: if torch.cuda.is_current_stream_capturing(): - if self.should_custom_ar(input): - return self.all_reduce_reg(input) + return self.all_reduce_reg(input) else: - if self.should_custom_ar(input): - # if warm up, mimic the allocation pattern - # since custom allreduce is out-of-place - return torch.empty_like(input) + # if warm up, mimic the allocation pattern + # since custom allreduce is out-of-place + return torch.empty_like(input) else: # note: outside of cuda graph context, # custom allreduce incurs a cost of cudaMemcpy, which should # be small(<=1% of overall latency) compared to the performance # gains of using custom kernels - if self.should_custom_ar(input): - return self.all_reduce_unreg(input) + return self.all_reduce_unreg(input) return None diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d3ac4eb78b15..6e1970bfed98 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -105,7 +105,7 @@ def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: group = _groups[group_name]() if group is None: raise ValueError(f"Group {group_name} is destroyed.") - group._all_reduce(tensor) + group._all_reduce_in_place(tensor) @inplace_all_reduce.register_fake def _(tensor: torch.Tensor, group_name: str) -> None: @@ -118,7 +118,7 @@ def outplace_all_reduce(tensor: torch.Tensor, group = _groups[group_name]() if group is None: raise ValueError(f"Group {group_name} is destroyed.") - return group._all_reduce(tensor) + return group._all_reduce_out_place(tensor) @outplace_all_reduce.register_fake def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor: @@ -338,14 +338,17 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: return input_ if not supports_custom_op(): - return self._all_reduce(input_) + self._all_reduce_in_place(input_) + return input_ if self.tpu_communicator is not None and \ not self.tpu_communicator.disabled: # TPU handles Dynamo with its own logic. - return self._all_reduce(input_) + return self.tpu_communicator.all_reduce(input_) - if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_): + if self.ca_comm is not None and \ + not self.ca_comm.disabled and \ + self.ca_comm.should_custom_ar(input_): return torch.ops.vllm.outplace_all_reduce( input_, group_name=self.unique_name) else: @@ -353,25 +356,15 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: group_name=self.unique_name) return input_ - def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor: - """ - The actual all-reduce implementation. - - NOTE: This operation will be applied in-place or out-of-place. - Always assume this function modifies its input, but use the return - value as the output. - """ + def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: ca_comm = self.ca_comm + assert ca_comm is not None + assert not ca_comm.disabled + out = ca_comm.custom_all_reduce(input_) + assert out is not None + return out - # For TPUs, use TPU communicator. - tpu_comm = self.tpu_communicator - if tpu_comm is not None and not tpu_comm.disabled: - return tpu_comm.all_reduce(input_) - - if ca_comm is not None: - out = ca_comm.custom_all_reduce(input_) - if out is not None: - return out + def _all_reduce_in_place(self, input_: torch.Tensor) -> None: pynccl_comm = self.pynccl_comm if (pynccl_comm is not None and not pynccl_comm.disabled): pynccl_comm.all_reduce(input_) @@ -380,7 +373,6 @@ def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor: ipex.distributed.all_reduce(input_, group=self.device_group) else: torch.distributed.all_reduce(input_, group=self.device_group) - return input_ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size