diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 6b6871f9dc2a..38f291f5203c 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -245,7 +245,6 @@ def _offload_to_memory(self): param.data = self.cpu_param_dict[param] for buffer in self.buffers: buffer.data = self.cpu_param_dict[buffer] - else: for group_module in self.modules: group_module.to(self.offload_device, non_blocking=False) @@ -303,9 +302,23 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs): if self.group.onload_leader == module: if self.group.onload_self: self.group.onload_() - if self.next_group is not None and not self.next_group.onload_self: + + should_onload_next_group = self.next_group is not None and not self.next_group.onload_self + if should_onload_next_group: self.next_group.onload_() + should_synchronize = ( + not self.group.onload_self and self.group.stream is not None and not should_onload_next_group + ) + if should_synchronize: + # If this group didn't onload itself, it means it was asynchronously onloaded by the + # previous group. We need to synchronize the side stream to ensure the parameters + # are completely loaded before proceeding with the forward pass. Without this, uninitialized + # weights will be used in the computation, leading to incorrect results. + # Also, we should only do this synchronization if it isn't already done by the sync call in + # self.next_group.onload_, hence the `not should_onload_next_group` check. + self.group.stream.synchronize() + args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking) kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) return args, kwargs