Skip to content

Fix group offloading synchronization bug for parameter-only GroupModule's #12077

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 6, 2025
17 changes: 15 additions & 2 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ def _offload_to_memory(self):
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]

else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
Expand Down Expand Up @@ -303,9 +302,23 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.group.onload_leader == module:
if self.group.onload_self:
self.group.onload_()
if self.next_group is not None and not self.next_group.onload_self:

should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
if should_onload_next_group:
self.next_group.onload_()

should_synchronize = (
not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
)
if should_synchronize:
# If this group didn't onload itself, it means it was asynchronously onloaded by the
# previous group. We need to synchronize the side stream to ensure parameters
# are completely loaded to proceed with forward pass. Without this, uninitialized
# weights will be used in the computation, leading to incorrect results
# Also, we should only do this synchronization if we don't already do it from the sync call in
# self.next_group.onload_, hence the `not should_onload_next_group` check.
self.group.stream.synchronize()

args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
return args, kwargs
Expand Down
Loading