
Commit 238c991

krishnakalyan3, rohitgr7, and carmocca authored
Do not force sync_dist=True on epoch end (#13364)
Co-authored-by: rohitgr7 <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 9596fab commit 238c991

3 files changed (+47 lines, -4 lines)

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -144,6 +144,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Enabled using any Sampler in distributed environment in Lite ([#13646](https://github.com/PyTorchLightning/pytorch-lightning/pull/13646))


+- Raised a warning instead of forcing `sync_dist=True` on epoch end ([13364](https://github.com/Lightning-AI/lightning/pull/13364))
+
+
 - Updated `val_check_interval`(int) to consider total train batches processed instead of `_batches_that_stepped` for validation check during training ([#12832](https://github.com/Lightning-AI/lightning/pull/12832)

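For users, the practical effect of this change is that epoch-level values logged without `sync_dist=True` in a multi-device run are no longer silently reduced across ranks; Lightning now emits a `PossibleUserWarning` instead. A minimal, illustrative sketch of the recommended call follows (the `LitModel` class, its layer sizes, and the optimizer settings are made up for this example):

    import torch
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).sum()
            # Passing sync_dist=True reduces the epoch-level value across devices,
            # matching the previously forced behaviour and avoiding the new warning.
            self.log("train_loss", loss, on_epoch=True, sync_dist=True)
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)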

src/pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 20 additions & 3 deletions
@@ -24,11 +24,13 @@
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device
 from pytorch_lightning.utilities.data import extract_batch_size
+from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
-from pytorch_lightning.utilities.warnings import WarningCache
+from pytorch_lightning.utilities.warnings import PossibleUserWarning, WarningCache

 _IN_METRIC = Union[Metric, Tensor]  # Do not include scalars as they were converted to tensors
 _OUT_METRIC = Union[Tensor, Dict[str, Tensor]]
@@ -522,12 +524,26 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
             cache = result_metric._forward_cache
         elif not on_step and result_metric.meta.on_epoch:
             if result_metric._computed is None:
-                # always reduce on epoch end
                 should = result_metric.meta.sync.should
-                result_metric.meta.sync.should = True
+                if not result_metric.meta.sync.should and distributed_available():
+                    # ensure sync happens for FT since during a failure, the metrics are synced and saved to the
+                    # checkpoint, so during restart, metrics on rank 0 are from the accumulated ones from the previous
+                    # run, and on other ranks, they are 0. So we need to make sure they are synced in further training
+                    # to ensure correct calculation.
+                    if _fault_tolerant_training():
+                        result_metric.meta.sync.should = True
+                    else:
+                        warning_cache.warn(
+                            f"It is recommended to use `self.log({result_metric.meta.name!r}, ..., sync_dist=True)`"
+                            " when logging on epoch level in distributed setting to accumulate the metric across"
+                            " devices.",
+                            category=PossibleUserWarning,
+                        )
                 result_metric.compute()
                 result_metric.meta.sync.should = should
+
             cache = result_metric._computed
+
         if cache is not None:
             if not isinstance(cache, Tensor):
                 raise ValueError(
@@ -536,6 +552,7 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
                 )
             if not result_metric.meta.enable_graph:
                 return cache.detach()
+
         return cache

     def valid_items(self) -> Generator:
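The following is a standalone sketch of the decision the new branch encodes, with simplified names (`resolve_epoch_end_sync` and its arguments are not Lightning APIs, and the warning class is a local stand-in): the sync is only forced when fault-tolerant training is active; otherwise a `PossibleUserWarning` is raised and the user's `sync_dist` choice is respected.

    import warnings


    class PossibleUserWarning(UserWarning):
        # Stand-in for pytorch_lightning.utilities.warnings.PossibleUserWarning.
        pass


    def resolve_epoch_end_sync(should_sync: bool, distributed: bool, fault_tolerant: bool, name: str) -> bool:
        # Mirror of the new epoch-end logic: force the sync only for fault-tolerant
        # training (so restored checkpoints stay consistent across ranks); otherwise
        # warn and keep the user's setting.
        if not should_sync and distributed:
            if fault_tolerant:
                return True
            warnings.warn(
                f"It is recommended to use `self.log({name!r}, ..., sync_dist=True)` when logging on epoch"
                " level in distributed setting to accumulate the metric across devices.",
                category=PossibleUserWarning,
            )
        return should_sync


    # A distributed run without fault tolerance now warns instead of forcing the sync.
    assert resolve_epoch_end_sync(False, distributed=True, fault_tolerant=False, name="val_acc") is False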

tests/tests_pytorch/core/test_metric_result_integration.py

Lines changed: 24 additions & 1 deletion
@@ -34,7 +34,9 @@
     _ResultMetric,
     _Sync,
 )
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from tests_pytorch.helpers.runif import RunIf
+from tests_pytorch.helpers.utils import no_warning_call


 class DummyMetric(Metric):
@@ -456,6 +458,8 @@ def on_train_epoch_end(self) -> None:
         "limit_val_batches": 0,
         "accelerator": accelerator,
         "devices": devices,
+        "enable_progress_bar": False,
+        "enable_model_summary": False,
     }
     trainer_kwargs.update(kwargs)
     trainer = Trainer(**trainer_kwargs)
@@ -471,7 +475,7 @@ def on_train_epoch_end(self) -> None:
     )
     ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")

-    trainer = Trainer(**trainer_kwargs, enable_progress_bar=False, enable_model_summary=False)
+    trainer = Trainer(**trainer_kwargs)
     trainer.fit(model, ckpt_path=ckpt_path)
     assert model.has_validated_sum

@@ -659,3 +663,22 @@ def on_train_start(self):
     )
     with pytest.raises(ValueError, match=r"compute\(\)` return of.*foo' must be a tensor"):
         trainer.fit(model)
+
+
+@pytest.mark.parametrize("distributed_env", [True, False])
+def test_logger_sync_dist(distributed_env):
+    # self.log('bar', 7, ..., sync_dist=False)
+    meta = _Metadata("foo", "bar")
+    meta.sync = _Sync(_should=False)
+    result_metric = _ResultMetric(metadata=meta, is_tensor=True)
+    result_metric.update(torch.tensor(7.0), 10)
+
+    warning_ctx = pytest.warns if distributed_env else no_warning_call
+
+    with mock.patch(
+        "pytorch_lightning.trainer.connectors.logger_connector.result.distributed_available",
+        return_value=distributed_env,
+    ):
+        with warning_ctx(PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`"):
+            value = _ResultCollection._get_cache(result_metric, on_step=False)
+    assert value == 7.0
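The new test patches `distributed_available` so the warning path can be exercised on a single process. If logging an unsynced, per-rank value is intentional, the warning can also be filtered instead of passing `sync_dist=True`; a hedged sketch, assuming the import path shown in the diff above:

    import warnings

    from pytorch_lightning.utilities.warnings import PossibleUserWarning

    # Silence only the sync_dist recommendation; other PossibleUserWarnings still show.
    warnings.filterwarnings("ignore", message=".*sync_dist=True.*", category=PossibleUserWarning)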
