diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py
index e6cf9fe37d..ffef62c715 100644
--- a/mmengine/logging/logger.py
+++ b/mmengine/logging/logger.py
@@ -43,6 +43,28 @@ def filter(self, record: LogRecord) -> bool:
             self.seen.add(record.msg)
             return True
         return False
+
+
+class FilterMasterRank(logging.Filter):
+    """Filter log messages from non-master processes in distributed training.
+
+    Args:
+        name (str): name of the filter. Defaults to 'mmengine'.
+    """
+
+    def __init__(self, name: str = 'mmengine') -> None:
+        super().__init__(name)
+
+    def filter(self, record: LogRecord) -> bool:
+        """Filter the log message of non-master processes.
+
+        Args:
+            record (LogRecord): The log record.
+
+        Returns:
+            bool: True if the log is from the master process (rank 0).
+        """
+        return int(os.environ.get("LOCAL_RANK", 0)) == 0
 
 
 class MMFormatter(logging.Formatter):
@@ -221,6 +243,7 @@ def __init__(self,
             else:
                 stream_handler.setLevel(logging.ERROR)
             stream_handler.addFilter(FilterDuplicateWarning(logger_name))
+            stream_handler.addFilter(FilterMasterRank(logger_name))
             self.handlers.append(stream_handler)
 
         if log_file is not None:
@@ -267,6 +290,7 @@ def __init__(self,
                 MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S'))
             file_handler.setLevel(log_level)
             file_handler.addFilter(FilterDuplicateWarning(logger_name))
+            file_handler.addFilter(FilterMasterRank(logger_name))
             self.handlers.append(file_handler)
         self._log_file = log_file
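The new `FilterMasterRank` is attached to both the stream and file handlers and keys off the `LOCAL_RANK` environment variable that launchers such as `torchrun` export, so only local rank 0 emits records. A minimal standalone sketch of the same idea (the `RankZeroFilter` name is hypothetical, not mmengine's class):

import logging
import os


class RankZeroFilter(logging.Filter):
    """Drop records unless the current process is local rank 0."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Single-process runs have no LOCAL_RANK, so treat them as rank 0.
        return int(os.environ.get('LOCAL_RANK', 0)) == 0


if __name__ == '__main__':
    logger = logging.getLogger('demo')
    handler = logging.StreamHandler()
    handler.addFilter(RankZeroFilter())
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

    os.environ['LOCAL_RANK'] = '1'  # simulate a non-master worker
    logger.info('suppressed: the handler filter drops this record')
    os.environ['LOCAL_RANK'] = '0'  # simulate the master process
    logger.info('printed: rank 0 passes the filter')

Because the filter sits on the handlers rather than on the logger, records from other ranks are still created and still reach any handler that lacks the filter; they are only dropped at emission time. Note also that `LOCAL_RANK` is the per-node rank, so rank 0 of every node would still log.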
diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py
index 2bf5f50f7c..e07ceabbcd 100644
--- a/mmengine/runner/checkpoint.py
+++ b/mmengine/runner/checkpoint.py
@@ -19,7 +19,7 @@
 from mmengine.logging import print_log
 from mmengine.model import BaseTTAModel, is_model_wrapper
 from mmengine.utils import (apply_to, deprecated_function, digit_version,
-                            mkdir_or_exist)
+                            mkdir_or_exist,)
 from mmengine.utils.dl_utils import load_url
 
 # `MMENGINE_HOME` is the highest priority directory to save checkpoints
@@ -810,6 +810,6 @@ def find_latest_checkpoint(path: str) -> Optional[str]:
         with open(save_file) as f:
             last_saved = f.read().strip()
     else:
-        print_log('Did not find last_checkpoint to be resumed.')
+        print_log('Did not find last_checkpoint to be resumed.', logger='current')
         last_saved = None
     return last_saved
diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index 5a678db7b9..d8ccc182ac 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -162,27 +162,39 @@ def __init__(self, dataloader: DataLoader) -> None:
     def __iter__(self):
         return self
 
-    def __next__(self) -> Sequence[dict]:
+    def skip_iter(self, iter: int) -> None:
+        for _ in range(iter):
+            self._next_data(skip_loading=True)
+
+    def __next__(self) -> Union[Sequence[dict], None]:
+        return self._next_data()
+
+    def _next_data(self, skip_loading=False) -> Union[Sequence[dict], None]:
+        data = None
         try:
-            data = next(self._iterator)
+            if skip_loading:
+                self._iterator._next_index()
+            else:
+                data = next(self._iterator)
         except StopIteration:
             print_log(
-                'Reach the end of the dataloader, it will be '
-                'restarted and continue to iterate. It is '
-                'recommended to use '
-                '`mmengine.dataset.InfiniteSampler` to enable the '
-                'dataloader to iterate infinitely.',
-                logger='current',
-                level=logging.WARNING)
+                "Reach the end of the dataloader, it will be "
+                "restarted and continue to iterate. It is "
+                "recommended to use "
+                "`mmengine.dataset.InfiniteSampler` to enable the "
+                "dataloader to iterate infinitely.",
+                logger="current",
+                level=logging.WARNING,
+            )
             self._epoch += 1
-            if hasattr(self._dataloader, 'sampler') and hasattr(
-                    self._dataloader.sampler, 'set_epoch'):
+            if hasattr(self._dataloader, "sampler") and hasattr(self._dataloader.sampler, "set_epoch"):
                 # In case the` _SingleProcessDataLoaderIter` has no sampler,
                 # or data loader uses `SequentialSampler` in Pytorch.
                 self._dataloader.sampler.set_epoch(self._epoch)
-            elif hasattr(self._dataloader, 'batch_sampler') and hasattr(
-                    self._dataloader.batch_sampler.sampler, 'set_epoch'):
+            elif hasattr(self._dataloader, "batch_sampler") and hasattr(
+                self._dataloader.batch_sampler.sampler, "set_epoch"
+            ):
                 # In case the` _SingleProcessDataLoaderIter` has no batch
                 # sampler. batch sampler in pytorch warps the sampler as its
                 # attributes.
                 self._dataloader.batch_sampler.sampler.set_epoch(self._epoch)
@@ -280,8 +292,7 @@ def run(self) -> None:
                 'that has already been trained',
                 logger='current',
                 level=logging.WARNING)
-            for _ in range(self._iter):
-                next(self.dataloader_iterator)
+            self.dataloader_iterator.skip_iter(self._iter)
 
         while self._iter < self._max_iters and not self.stop_training:
             self.runner.model.train()
@@ -299,7 +310,7 @@ def run(self) -> None:
             self.runner.call_hook('after_train')
             return self.runner.model
 
-    def run_iter(self, data_batch: Sequence[dict]) -> None:
+    def run_iter(self, data_batch: Union[Sequence[dict], None]) -> None:
         """Iterate one mini-batch.
 
         Args:
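On resume, `IterBasedTrainLoop.run` no longer replays `next(self.dataloader_iterator)` for every already-trained step; `skip_iter` advances the underlying iterator with the private `_next_index()` call, so the sampler state moves forward without any samples being loaded or collated. A standalone sketch of that idea under stated assumptions: `CountingDataset` and `fast_forward` are illustrative names rather than mmengine code, `_next_index()` is a private PyTorch API that may change between releases, and the dataloader is single-process (`num_workers=0`):

from torch.utils.data import DataLoader, Dataset


class CountingDataset(Dataset):
    """Toy dataset whose __getitem__ stands in for expensive decoding."""

    def __init__(self, length: int = 10) -> None:
        self.length = length
        self.loaded = 0  # number of samples actually materialized

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int) -> int:
        self.loaded += 1
        return index


def fast_forward(loader_iter, num_batches: int) -> None:
    # Consume batch indices only; nothing is fetched or collated.
    for _ in range(num_batches):
        loader_iter._next_index()


if __name__ == '__main__':
    dataset = CountingDataset()
    loader = DataLoader(dataset, batch_size=2)  # num_workers=0 by default

    it = iter(loader)
    fast_forward(it, 3)  # pretend 3 batches were consumed before the restart
    print(next(it))  # tensor([6, 7]): iteration resumes at the 4th batch
    print('samples loaded:', dataset.loaded)  # 2, not 8

Fast-forwarding with `next()` would have materialized (and discarded) the first six samples as well; skipping at the index level keeps resume time independent of how expensive `__getitem__` is. With multi-worker dataloaders the workers prefetch against their own index queue, so this private-API shortcut may not behave the same way there.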