10 changes: 10 additions & 0 deletions composer/callbacks/optimizer_monitor.py
@@ -9,6 +9,11 @@
from composer.loggers import Logger
from composer.utils import dist

try:
from torch.distributed.tensor import DTensor
except ImportError:
DTensor = None

__all__ = ['OptimizerMonitor']


@@ -86,6 +91,11 @@ def batch_end(self, state: State, logger: Logger):
param_grad_norm = torch.linalg.vector_norm(p.grad)
optimizer_metrics[f'l2_norm/grad/{name}'] = param_grad_norm

# Added for compatibility with FSDP2, since all-reduce on an arbitrary DTensor can have issues;
# convert any DTensor metrics to plain local tensors first.
for name, metric_val in optimizer_metrics.items():
if DTensor is not None and isinstance(metric_val, DTensor):
optimizer_metrics[name] = metric_val.to_local()

if state.fsdp_enabled and dist.get_world_size() > 0 and self.log_optimizer_metrics:
# If FSDP is enabled, the optimizer state lives on different ranks and must be reduced
# and combined before we can compute metrics.
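For context beyond the diff, a minimal sketch of the pattern the added lines implement: under FSDP2 a sharded parameter's gradient is a `DTensor`, so `torch.linalg.vector_norm(p.grad)` also yields a `DTensor`, and `.to_local()` converts it to a plain `torch.Tensor` before the later cross-rank reduction. The helper name `maybe_to_local` and the sample metric dict below are illustrative only, not part of the PR.

```python
# Illustrative sketch only (not part of the PR): converting DTensor metrics
# to plain local tensors before any collective reduction.
import torch

try:
    from torch.distributed.tensor import DTensor
except ImportError:  # older torch without a public DTensor
    DTensor = None


def maybe_to_local(value: torch.Tensor) -> torch.Tensor:
    """Return a plain local tensor so a later all-reduce sees a regular Tensor."""
    if DTensor is not None and isinstance(value, DTensor):
        return value.to_local()
    return value


# Hypothetical usage mirroring the callback's loop over optimizer_metrics:
optimizer_metrics = {'l2_norm/grad/layer.weight': torch.tensor(0.5)}
optimizer_metrics = {name: maybe_to_local(val) for name, val in optimizer_metrics.items()}
```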