diff --git a/composer/callbacks/optimizer_monitor.py b/composer/callbacks/optimizer_monitor.py
index f16d84bcef..485c0037c9 100644
--- a/composer/callbacks/optimizer_monitor.py
+++ b/composer/callbacks/optimizer_monitor.py
@@ -9,6 +9,11 @@
 from composer.loggers import Logger
 from composer.utils import dist
 
+try:
+    from torch.distributed.tensor import DTensor
+except ImportError:
+    DTensor = None
+
 __all__ = ['OptimizerMonitor']
 
@@ -86,6 +91,11 @@ def batch_end(self, state: State, logger: Logger):
                    param_grad_norm = torch.linalg.vector_norm(p.grad)
                    optimizer_metrics[f'l2_norm/grad/{name}'] = param_grad_norm
 
+        # For compatibility with FSDP2: convert DTensor metrics to local tensors, since all-reducing DTensors directly can have issues
+        for name, metric_val in optimizer_metrics.items():
+            if DTensor is not None and isinstance(metric_val, DTensor):
+                optimizer_metrics[name] = metric_val.to_local()
+
        if state.fsdp_enabled and dist.get_world_size() > 0 and self.log_optimizer_metrics:
            # If FSDP is enabled, the optimizer state lives on different ranks and must be reduced
            # and combined before we can compute metrics.
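For context, the conversion added in batch_end can be thought of as the standalone helper sketched below. This is only an illustration; localize_dtensor_metrics is a hypothetical name and is not part of Composer's API.

from typing import Dict

import torch

try:
    from torch.distributed.tensor import DTensor  # public DTensor API in recent PyTorch
except ImportError:
    DTensor = None  # older PyTorch builds without DTensor


def localize_dtensor_metrics(optimizer_metrics: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Hypothetical helper mirroring the loop added in batch_end: replace any
    # DTensor-valued metric with its local shard so the later all-reduce
    # operates on plain torch.Tensor values.
    if DTensor is None:
        return optimizer_metrics
    for name, metric_val in optimizer_metrics.items():
        if isinstance(metric_val, DTensor):
            # to_local() returns this rank's shard as a regular torch.Tensor.
            optimizer_metrics[name] = metric_val.to_local()
    return optimizer_metrics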