diff --git a/ignite/handlers/__init__.py b/ignite/handlers/__init__.py
index d7f3002ced01..f5d4a09db168 100644
--- a/ignite/handlers/__init__.py
+++ b/ignite/handlers/__init__.py
@@ -3,7 +3,7 @@
 from ignite.engine import Engine
 from ignite.engine.events import Events
 from ignite.handlers.checkpoint import Checkpoint, DiskSaver, ModelCheckpoint
-from ignite.handlers.early_stopping import EarlyStopping
+from ignite.handlers.early_stopping import EarlyStopping, NoImprovementHandler
 from ignite.handlers.ema_handler import EMAHandler
 from ignite.handlers.lr_finder import FastaiLRFinder
 from ignite.handlers.param_scheduler import (
@@ -38,6 +38,7 @@
     "Checkpoint",
     "DiskSaver",
     "Timer",
+    "NoImprovementHandler",
     "EarlyStopping",
     "TerminateOnNan",
     "global_step_from_engine",
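Since `EarlyStopping` is reworked below as a subclass of the new `NoImprovementHandler`, the existing import surface keeps working. A quick sanity check, assuming a build of this branch is installed:

    from ignite.handlers import EarlyStopping, NoImprovementHandler

    assert issubclass(EarlyStopping, NoImprovementHandler)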
diff --git a/ignite/handlers/early_stopping.py b/ignite/handlers/early_stopping.py
index 7613f73dc1e5..3ba80fce3262 100644
--- a/ignite/handlers/early_stopping.py
+++ b/ignite/handlers/early_stopping.py
@@ -5,36 +5,49 @@
 from ignite.engine import Engine
 from ignite.utils import setup_logger
 
-__all__ = ["EarlyStopping"]
+__all__ = ["NoImprovementHandler", "EarlyStopping"]
 
 
-class EarlyStopping(Serializable):
-    """EarlyStopping handler can be used to stop the training if no improvement after a given number of events.
+class NoImprovementHandler(Serializable):
+    """NoImprovementHandler is a generalised version of early stopping that lets you define what should
+    happen if no improvement occurs after a given number of events.
 
     Args:
-        patience: Number of events to wait if no improvement and then stop the training.
+        patience: Number of events to wait if no improvement and then call ``stop_function``.
         score_function: It should be a function taking a single argument, an :class:`~ignite.engine.engine.Engine`
             object, and return a score `float`. An improvement is considered if the score is higher.
+        pass_function: It should be a function taking a single argument, the trainer
+            object, and defines what to do when the stopping condition is not met.
+        stop_function: It should be a function taking a single argument, the trainer
+            object, and defines what to do when the stopping condition is met.
         trainer: Trainer engine to stop the run if no improvement.
         min_delta: A minimum increase in the score to qualify as an improvement,
             i.e. an increase of less than or equal to `min_delta`, will count as no improvement.
-        cumulative_delta: It True, `min_delta` defines an increase since the last `patience` reset, otherwise,
+        cumulative_delta: If True, `min_delta` defines an increase since the last `patience` reset, otherwise,
             it defines an increase after the last event. Default value is False.
 
     Examples:
 
     .. code-block:: python
 
-        from ignite.engine import Engine, Events
-        from ignite.handlers import EarlyStopping
+        # Example where, if the score does not improve, a user-defined value `alpha` is doubled.
+        from ignite.engine import Engine, Events
+        from ignite.handlers import NoImprovementHandler
 
         def score_function(engine):
             val_loss = engine.state.metrics['nll']
             return -val_loss
 
+        def pass_function(engine):
+            pass
+
+        def stop_function(trainer):
+            trainer.state.alpha *= 2
+
+        trainer = Engine(lambda engine, batch: None)
+        trainer.state_dict_user_keys.append("alpha")
+        trainer.state.alpha = 0.1
+
+        h = NoImprovementHandler(patience=3, score_function=score_function, pass_function=pass_function,
+                                 stop_function=stop_function, trainer=trainer)
 
-        handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
         # Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
-        evaluator.add_event_handler(Events.COMPLETED, handler)
+        evaluator.add_event_handler(Events.COMPLETED, h)
     """
 
     _state_dict_all_req_keys = (
@@ -46,7 +59,9 @@ def __init__(
         self,
         patience: int,
         score_function: Callable,
+        stop_function: Callable,
         trainer: Engine,
+        pass_function: Callable = lambda engine: 0,
         min_delta: float = 0.0,
         cumulative_delta: bool = False,
     ):
@@ -54,37 +69,51 @@ def __init__(
 
         if not callable(score_function):
             raise TypeError("Argument score_function should be a function.")
 
+        if not callable(pass_function):
+            raise TypeError("Argument pass_function should be a function.")
+
+        if not callable(stop_function):
+            raise TypeError("Argument stop_function should be a function.")
+
+        if not isinstance(trainer, Engine):
+            raise TypeError("Argument trainer should be an instance of Engine.")
+
         if patience < 1:
             raise ValueError("Argument patience should be positive integer.")
 
         if min_delta < 0.0:
             raise ValueError("Argument min_delta should not be a negative number.")
 
-        if not isinstance(trainer, Engine):
-            raise TypeError("Argument trainer should be an instance of Engine.")
-
-        self.score_function = score_function
         self.patience = patience
-        self.min_delta = min_delta
-        self.cumulative_delta = cumulative_delta
+        self.score_function = score_function
+        self.pass_function = pass_function
+        self.stop_function = stop_function
         self.trainer = trainer
         self.counter = 0
         self.best_score = None  # type: Optional[float]
+        self.min_delta = min_delta
+        self.cumulative_delta = cumulative_delta
         self.logger = setup_logger(__name__ + "." + self.__class__.__name__)
 
     def __call__(self, engine: Engine) -> None:
         score = self.score_function(engine)
+        self._update_state(score)
+        if self.counter >= self.patience:
+            self.stop_function(self.trainer)
+        else:
+            self.pass_function(self.trainer)
+
+    def _update_state(self, score: float) -> None:
+
         if self.best_score is None:
             self.best_score = score
+
         elif score <= self.best_score + self.min_delta:
             if not self.cumulative_delta and score > self.best_score:
                 self.best_score = score
             self.counter += 1
-            self.logger.debug("EarlyStopping: %i / %i" % (self.counter, self.patience))
-            if self.counter >= self.patience:
-                self.logger.info("EarlyStopping: Stop training")
-                self.trainer.terminate()
+
         else:
             self.best_score = score
             self.counter = 0
@@ -104,3 +133,62 @@ def load_state_dict(self, state_dict: Mapping) -> None:
         super().load_state_dict(state_dict)
         self.counter = state_dict["counter"]
         self.best_score = state_dict["best_score"]
+
+
+class EarlyStopping(NoImprovementHandler):
+    """EarlyStopping handler can be used to stop the training if no improvement after a given number of events.
+
+    Args:
+        patience: Number of events to wait if no improvement and then stop the training.
+        score_function: It should be a function taking a single argument, an :class:`~ignite.engine.engine.Engine`
+            object, and return a score `float`. An improvement is considered if the score is higher.
+        trainer: Trainer engine to stop the run if no improvement.
+        min_delta: A minimum increase in the score to qualify as an improvement,
+            i.e. an increase of less than or equal to `min_delta`, will count as no improvement.
+        cumulative_delta: If True, `min_delta` defines an increase since the last `patience` reset, otherwise,
+            it defines an increase after the last event. Default value is False.
+
+    Examples:
+
+    .. code-block:: python
+
+        from ignite.engine import Engine, Events
+        from ignite.handlers import EarlyStopping
+
+        def score_function(engine):
+            val_loss = engine.state.metrics['nll']
+            return -val_loss
+
+        handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
+        # Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
+        evaluator.add_event_handler(Events.COMPLETED, handler)
+    """
+
+    _state_dict_all_req_keys = (
+        "counter",
+        "best_score",
+    )
+
+    def __init__(
+        self,
+        patience: int,
+        score_function: Callable,
+        trainer: Engine,
+        min_delta: float = 0.0,
+        cumulative_delta: bool = False,
+    ):
+        super(EarlyStopping, self).__init__(
+            patience=patience,
+            score_function=score_function,
+            pass_function=self.pass_function,
+            stop_function=self.stop_function,
+            trainer=trainer,
+            min_delta=min_delta,
+            cumulative_delta=cumulative_delta,
+        )
+
+        self.logger = setup_logger(__name__ + "." + self.__class__.__name__)
+
+    def __call__(self, engine: Engine) -> None:
+        super(EarlyStopping, self).__call__(engine)
+
+    def pass_function(self, trainer: Engine) -> None:
+        pass
+
+    def stop_function(self, trainer: Engine) -> None:
+        self.logger.info("EarlyStopping: Stop training")
+        trainer.terminate()
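The handler above never terminates the run by itself; `stop_function` decides what happens once `patience` is exhausted. A minimal sketch of a reduce-LR-on-plateau variant built on this API (the toy no-op engines, the hard-coded `val_losses` sequence and the 0.5 decay factor are illustrative assumptions, not part of this patch):

    import torch
    from ignite.engine import Engine, Events
    from ignite.handlers import NoImprovementHandler

    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    trainer = Engine(lambda engine, batch: None)    # toy no-op training step
    evaluator = Engine(lambda engine, batch: None)  # toy no-op validation step

    val_losses = iter([1.0, 0.9, 0.91, 0.92])       # plateaus after the second epoch

    def score_function(engine):
        return -next(val_losses)  # higher is better, so negate the loss

    def halve_lr(trainer):
        # Called once `patience` evaluations pass without improvement.
        for group in optimizer.param_groups:
            group["lr"] *= 0.5

    handler = NoImprovementHandler(
        patience=2, score_function=score_function, stop_function=halve_lr, trainer=trainer
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        evaluator.run([0])

    evaluator.add_event_handler(Events.COMPLETED, handler)
    trainer.run([0], max_epochs=4)
    print(optimizer.param_groups[0]["lr"])  # 0.05: halved once, on the 4th epoch

`pass_function` is omitted here and defaults to a no-op, and the run keeps going after `halve_lr` fires because nothing calls `trainer.terminate()`.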
diff --git a/tests/ignite/handlers/test_early_stopping.py b/tests/ignite/handlers/test_early_stopping.py
index 7382c7ec1b20..bf0f960dac2c 100644
--- a/tests/ignite/handlers/test_early_stopping.py
+++ b/tests/ignite/handlers/test_early_stopping.py
@@ -5,14 +5,466 @@
 import ignite.distributed as idist
 from ignite.engine import Engine, Events
-from ignite.handlers import EarlyStopping
+from ignite.handlers import EarlyStopping, NoImprovementHandler
 
 
 def do_nothing_update_fn(engine, batch):
     pass
 
 
-def test_args_validation():
+@pytest.fixture
+def trainer():
+    trainer = Engine(do_nothing_update_fn)
+    trainer.state_dict_user_keys.append("alpha")
+    trainer.state.alpha = 0.1
+    return trainer
+
+
+def test_args_validation_no_improvement_handler():
+
+    trainer = Engine(do_nothing_update_fn)
+
+    with pytest.raises(TypeError, match=r"Argument score_function should be a function."):
+        NoImprovementHandler(
+            patience=2, score_function=12345, pass_function=12345, stop_function=12345, trainer=trainer
+        )
+
+    with pytest.raises(TypeError, match=r"Argument pass_function should be a function."):
+        NoImprovementHandler(
+            patience=2, score_function=lambda engine: 0, pass_function=12345, stop_function=12345, trainer=trainer
+        )
+
+    with pytest.raises(TypeError, match=r"Argument stop_function should be a function."):
+        NoImprovementHandler(
+            patience=2,
+            score_function=lambda engine: 0,
+            pass_function=lambda engine: 0,
+            stop_function=12345,
+            trainer=trainer,
+        )
+
+    with pytest.raises(TypeError, match=r"Argument trainer should be an instance of Engine."):
+        NoImprovementHandler(
+            patience=2,
+            score_function=lambda engine: 0,
+            pass_function=lambda engine: 0,
+            stop_function=lambda engine: 0,
+            trainer=None,
+        )
+
+    with pytest.raises(ValueError, match=r"Argument patience should be positive integer."):
+        NoImprovementHandler(
+            patience=-1,
+            score_function=lambda engine: 0,
+            pass_function=lambda engine: 0,
+            stop_function=lambda engine: 0,
+            trainer=trainer,
+        )
+
+    with pytest.raises(ValueError, match=r"Argument min_delta should not be a negative number."):
+        NoImprovementHandler(
+            patience=2,
+            min_delta=-0.1,
+            score_function=lambda engine: 0,
+            pass_function=lambda engine: 0,
+            stop_function=lambda engine: 0,
+            trainer=trainer,
+        )
+
+
+def test_simple_no_improvement_handler(trainer):
+
+    scores = iter([1.0, 0.8, 0.88])
+
+    def score_function(engine):
+        return next(scores)
+
+    def stop_function(trainer):
+        trainer.state.alpha *= 2
+
+    def pass_function(trainer):
+        pass
+
+    h = NoImprovementHandler(
+        patience=2,
+        score_function=score_function,
+        pass_function=pass_function,
+        stop_function=stop_function,
+        trainer=trainer,
+    )
+    # Call 3 times and check that stop_function gets called
+    assert trainer.state.alpha == 0.1
+    h(None)
+    assert trainer.state.alpha == 0.1
+    h(None)
+    assert trainer.state.alpha == 0.1
+    h(None)
+    assert trainer.state.alpha == 0.2
+
+
+def test_pass_function_no_improvement_handler(trainer):
+
+    scores = iter([1.0, 0.8, 0.88])
+
+    def score_function(engine):
+        return next(scores)
+
+    def stop_function(trainer):
+        trainer.state.alpha = 10
+
+    def pass_function(trainer):
+        trainer.state.alpha *= 2
+
+    h = NoImprovementHandler(
+        patience=2,
+        score_function=score_function,
+        stop_function=stop_function,
+        pass_function=pass_function,
+        trainer=trainer,
+    )
+    assert trainer.state.alpha == 0.1
+    h(None)
+    # pass_function should double the value
+    assert trainer.state.alpha == 0.2
+    h(None)
+    # pass_function should double the value
+    assert trainer.state.alpha == 0.4
+    h(None)
+    # stop_function should set the value to 10
+    assert trainer.state.alpha == 10
+
+
+def test_repeated_pass_stop_no_improvement_handler(trainer):
+    scores = iter([1.0, 0.8, 0.88, 1.1, 0.9, 0.8, 1.2])
+
+    def score_function(engine):
+        return next(scores)
+
+    def stop_function(trainer):
+        trainer.state.alpha = -1
+
+    def pass_function(trainer):
+        trainer.state.alpha *= 2
+
+    h = NoImprovementHandler(
+        patience=2,
+        score_function=score_function,
+        stop_function=stop_function,
+        pass_function=pass_function,
+        trainer=trainer,
+    )
+    assert trainer.state.alpha == 0.1
+    h(None)
+    assert trainer.state.alpha == 0.2
+    h(None)
+    assert trainer.state.alpha == 0.4
+    h(None)
+    # stop_function gets called
+    assert trainer.state.alpha == -1
+    h(None)
+    # pass_function resumes
+    assert trainer.state.alpha == -2
+    h(None)
+    assert trainer.state.alpha == -4
+    h(None)
+    # stop_function gets called again
+    assert trainer.state.alpha == -1
+
+
+def test_state_dict_no_improvement_handler(trainer):
+
+    scores = iter([1.0, 0.8, 0.88])
+
+    def score_function(engine):
+        return next(scores)
+
+    def stop_function(trainer):
+        trainer.state.alpha *= 10
+
+    h = NoImprovementHandler(patience=2, score_function=score_function, stop_function=stop_function, trainer=trainer)
+
+    assert trainer.state.alpha == 0.1
+    h(None)
+    assert trainer.state.alpha == 0.1
+
+    # Swap to new object, but maintain state
+    h2 = NoImprovementHandler(patience=2, score_function=score_function, stop_function=stop_function, trainer=trainer)
+    h2.load_state_dict(h.state_dict())
+
+    h2(None)
+    assert trainer.state.alpha == 0.1
+    h2(None)
+    assert trainer.state.alpha == 1
+
+
+def test_no_improvement_handler_on_delta(trainer):
+
+    scores = iter([1.0, 2.0, 2.01, 3.0, 3.01, 3.02])
+
+    def score_function(engine):
+        return next(scores)
+
+    def stop_function(trainer):
+        trainer.state.alpha *= 10
+
+    h = NoImprovementHandler(
+        patience=2, min_delta=0.1, score_function=score_function, stop_function=stop_function, trainer=trainer
+    )
+
+    assert trainer.state.alpha == 0.1
+    h(None)  # counter == 0
+    assert trainer.state.alpha == 0.1
+    h(None)  # delta == 1.0; counter == 0
+    assert trainer.state.alpha == 0.1
+    h(None)  # delta == 0.01; counter == 1
+    assert trainer.state.alpha == 0.1
+    h(None)  # delta == 0.99; counter == 0
+    assert trainer.state.alpha == 0.1
+    h(None)  # delta == 0.01; counter == 1
+    assert trainer.state.alpha == 0.1
+    h(None)  # delta == 0.01; counter == 2
+    assert trainer.state.alpha == 1
+
+
+def test_no_improvement_handler_on_last_event_delta(trainer):
+
+    scores = iter([0.0, 0.3, 0.6])
+
+    def score_function(engine):
+        return next(scores)
+
+    def stop_function(trainer):
+        trainer.state.alpha *= 10
+
+    h = NoImprovementHandler(
+        patience=2, min_delta=0.4, score_function=score_function, stop_function=stop_function, trainer=trainer
+    )
+
+    assert trainer.state.alpha == 0.1
+    h(None)  # counter == 0
+    assert trainer.state.alpha == 0.1
+    h(None)  # delta == 0.3; counter == 1
+    assert trainer.state.alpha == 0.1
+    h(None)  # delta == 0.3; counter == 2
+    assert trainer.state.alpha == 1
+
+
+def test_no_improvement_on_cumulative_delta(trainer):
+
+    scores = iter([0.0, 0.3, 0.6])
+
+    def score_function(engine):
+        return next(scores)
+
+    def stop_function(trainer):
+        trainer.state.alpha *= 10
+
+    h = NoImprovementHandler(
+        patience=2,
+        min_delta=0.4,
+        cumulative_delta=True,
+        score_function=score_function,
+        stop_function=stop_function,
+        trainer=trainer,
+    )
+
+    assert trainer.state.alpha == 0.1
+    h(None)  # counter == 0
+    assert trainer.state.alpha == 0.1
+    h(None)  # delta == 0.3; counter == 1
+    assert trainer.state.alpha == 0.1
+    h(None)  # delta == 0.6; counter == 0
+    assert trainer.state.alpha == 0.1
+
+
+def test_simple_no_improvement_on_plateau(trainer):
+    def score_function(engine):
+        return 42
+
+    def stop_function(trainer):
+        trainer.state.alpha *= 10
+
+    h = NoImprovementHandler(patience=1, score_function=score_function, stop_function=stop_function, trainer=trainer)
+    assert trainer.state.alpha == 0.1
+    h(None)
+    assert trainer.state.alpha == 0.1
+    h(None)
+    assert trainer.state.alpha == 1
+
+
+def test_simple_no_improvement_on_plateau_then_pass(trainer):
+
+    scores = iter([42, 42, 43, 45])
+
+    def score_function(engine):
+        return next(scores)
+
+    def stop_function(trainer):
+        trainer.state.alpha = -1
+
+    def pass_function(trainer):
+        trainer.state.alpha *= 10
+
+    h = NoImprovementHandler(
+        patience=1,
+        score_function=score_function,
+        pass_function=pass_function,
+        stop_function=stop_function,
+        trainer=trainer,
+    )
+    assert trainer.state.alpha == 0.1
+    h(None)
+    assert trainer.state.alpha == 1
+    h(None)
+    # Stops here due to plateau
+    assert trainer.state.alpha == -1
+    # pass_function gets called
+    h(None)
+    assert trainer.state.alpha == -10
+    # pass_function gets called
+    h(None)
+    assert trainer.state.alpha == -100
+
+
+def test_simple_not_trigger_no_improvement_handler(trainer):
+
+    scores = iter([1.0, 0.8, 1.2])
+
+    def score_function(engine):
+        return next(scores)
+
+    def stop_function(trainer):
+        trainer.state.alpha *= 10
+
+    h = NoImprovementHandler(patience=2, score_function=score_function, stop_function=stop_function, trainer=trainer)
+
+    # Call 3 times and check that stop_function is not called
+    assert trainer.state.alpha == 0.1
+    h(None)
+    h(None)
+    h(None)
+    assert trainer.state.alpha == 0.1
+
+
+def test_with_engine_no_improvement_handler(trainer):
+    class Counter(object):
+        def __init__(self, count=0):
+            self.count = count
+
+    n_epochs_counter = Counter()
+
+    scores = iter([1.0, 0.8, 1.2, 1.5, 0.9, 1.0, 0.99, 1.1, 0.9, 0.9])
+
+    def score_function(engine):
+        return next(scores)
+
+    def stop_function(trainer):
+        trainer.terminate()
+
+    def pass_function(trainer):
+        trainer.state.alpha *= 2
+
+    evaluator = Engine(do_nothing_update_fn)
+
+    h = NoImprovementHandler(
+        patience=3,
+        score_function=score_function,
+        pass_function=pass_function,
+        stop_function=stop_function,
+        trainer=trainer,
+    )
+
+    @trainer.on(Events.EPOCH_COMPLETED)
+    def evaluation(engine):
+        evaluator.run([0])
+        n_epochs_counter.count += 1
+
+    evaluator.add_event_handler(Events.COMPLETED, h)
+    trainer.run([0], max_epochs=10)
+    assert trainer.state.alpha == 6.4
+    assert n_epochs_counter.count == 7
+    assert trainer.state.epoch == 7
+
+
+def test_with_engine_no_improvement_on_plateau(trainer):
+    class Counter(object):
+        def __init__(self, count=0):
+            self.count = count
+
+    n_epochs_counter = Counter()
+
+    def score_function(engine):
+        return 0.03899
+
+    def stop_function(trainer):
+        trainer.state.alpha = -1
+
+    def pass_function(trainer):
+        trainer.state.alpha *= 2
+
+    evaluator = Engine(do_nothing_update_fn)
+    h = NoImprovementHandler(
+        patience=4,
+        score_function=score_function,
+        pass_function=pass_function,
+        stop_function=stop_function,
+        trainer=trainer,
+    )
+
+    @trainer.on(Events.EPOCH_COMPLETED)
+    def evaluation(engine):
+        evaluator.run([0])
+        n_epochs_counter.count += 1
+
+    evaluator.add_event_handler(Events.COMPLETED, h)
+    trainer.run([0], max_epochs=4)
+    # pass_function runs 4 times, so alpha == 0.1 * 2 ** 4 == 1.6
+    assert trainer.state.alpha == 1.6
+    assert n_epochs_counter.count == 4
+    assert trainer.state.epoch == 4
+
+    trainer.run([0], max_epochs=10)
+    # stop_function should get called for epochs > 4
+    assert trainer.state.alpha == -1
+    assert n_epochs_counter.count == 10
+    assert trainer.state.epoch == 10
+
+
+def test_with_engine_not_triggering_no_improvement_handler(trainer):
+    class Counter(object):
+        def __init__(self, count=0):
+            self.count = count
+
+    n_epochs_counter = Counter()
+
+    scores = iter([1.0, 0.8, 1.2, 1.23, 0.9, 1.0, 1.1, 1.253, 1.26, 1.2])
+
+    def score_function(engine):
+        return next(scores)
+
+    def stop_function(trainer):
+        trainer.terminate()
+
+    def pass_function(trainer):
+        trainer.state.alpha *= 2
+
+    evaluator = Engine(do_nothing_update_fn)
+    h = NoImprovementHandler(
+        patience=5,
+        score_function=score_function,
+        pass_function=pass_function,
+        stop_function=stop_function,
+        trainer=trainer,
+    )
+
+    @trainer.on(Events.EPOCH_COMPLETED)
+    def evaluation(engine):
+        evaluator.run([0])
+        n_epochs_counter.count += 1
+
+    evaluator.add_event_handler(Events.COMPLETED, h)
+    trainer.run([0], max_epochs=10)
+    # pass_function runs 10 times, so alpha == 0.1 * 2 ** 10 == 102.4
+    assert trainer.state.alpha == 102.4
+    assert n_epochs_counter.count == 10
+    assert trainer.state.epoch == 10
+
+
+def test_args_validation_early_stopping():
 
     trainer = Engine(do_nothing_update_fn)
 
@@ -247,6 +699,144 @@ def evaluation(engine):
     assert trainer.state.epoch == 10
 
 
+def _test_distrib_with_engine_no_improvement_handler(device):
+
+    if device is None:
+        device = idist.device()
+    if isinstance(device, str):
+        device = torch.device(device)
+
+    torch.manual_seed(12)
+
+    class Counter(object):
+        def __init__(self, count=0):
+            self.count = count
+
+    n_epochs_counter = Counter()
+
+    scores = torch.tensor([1.0, 0.8, 1.2, 1.5, 0.9, 1.0, 0.99, 1.6, 0.9], requires_grad=False).to(device)
+
+    def score_function(engine):
+        i = trainer.state.epoch - 1
+        v = scores[i]
+        idist.all_reduce(v)
+        v /= idist.get_world_size()
+        return v.item()
+
+    def stop_function(trainer):
+        trainer.state.alpha = -1
+
+    def pass_function(trainer):
+        trainer.state.alpha *= 2
+
+    trainer = Engine(do_nothing_update_fn)
+    trainer.state_dict_user_keys.append("alpha")
+    trainer.state.alpha = 0.1
+
+    evaluator = Engine(do_nothing_update_fn)
+    nih = NoImprovementHandler(
+        patience=3,
+        score_function=score_function,
+        stop_function=stop_function,
+        pass_function=pass_function,
+        trainer=trainer,
+    )
+
+    @trainer.on(Events.EPOCH_COMPLETED)
+    def evaluation(engine):
+        evaluator.run([0])
+        n_epochs_counter.count += 1
+
+    evaluator.add_event_handler(Events.COMPLETED, nih)
+
+    # Only pass_function runs in this case
+    trainer.run([0], max_epochs=6)
+    assert trainer.state.alpha == 6.4
+
+    # Stopping condition is met on epoch 7: stop_function sets alpha to -1
+    trainer.run([0], max_epochs=7)
+    assert trainer.state.alpha == -1
+
+    # Unlike EarlyStopping, this handler doesn't terminate the run, so it keeps comparing
+    # scores even after the stopping condition has been met; pass_function gets called again
+    trainer.run([0], max_epochs=8)
+    assert trainer.state.alpha == -2
+
+
+def _test_distrib_integration_engine_no_improvement_handler(device):
+
+    from ignite.metrics import Accuracy
+
+    if device is None:
+        device = idist.device()
+    if isinstance(device, str):
+        device = torch.device(device)
+    metric_device = device
+    if device.type == "xla":
+        metric_device = "cpu"
+
+    rank = idist.get_rank()
+    ws = idist.get_world_size()
+    torch.manual_seed(12)
+
+    n_epochs = 10
+    n_iters = 20
+
+    y_preds = (
+        [torch.randint(0, 2, size=(n_iters, ws)).to(device)]
+        + [torch.ones(n_iters, ws).to(device)]
+        + [torch.randint(0, 2, size=(n_iters, ws)).to(device) for _ in range(n_epochs - 2)]
+    )
+
+    y_true = (
+        [torch.randint(0, 2, size=(n_iters, ws)).to(device)]
+        + [torch.ones(n_iters, ws).to(device)]
+        + [torch.randint(0, 2, size=(n_iters, ws)).to(device) for _ in range(n_epochs - 2)]
+    )
+
+    def update(engine, _):
+        e = trainer.state.epoch - 1
+        i = engine.state.iteration - 1
+        return y_preds[e][i, rank], y_true[e][i, rank]
+
+    evaluator = Engine(update)
+    acc = Accuracy(device=metric_device)
+    acc.attach(evaluator, "acc")
+
+    def score_function(engine):
+        return engine.state.metrics["acc"]
+
+    def stop_function(trainer):
+        trainer.state.alpha = -1
+
+    def pass_function(trainer):
+        trainer.state.alpha *= 2
+
+    trainer = Engine(lambda e, b: None)
+    trainer.state_dict_user_keys.append("alpha")
+    trainer.state.alpha = 0.1
+
+    nih = NoImprovementHandler(
+        patience=3,
+        score_function=score_function,
+        stop_function=stop_function,
+        pass_function=pass_function,
+        trainer=trainer,
+    )
+
+    @trainer.on(Events.EPOCH_COMPLETED)
+    def evaluation(engine):
+        data = list(range(n_iters))
+        evaluator.run(data=data)
+
+    evaluator.add_event_handler(Events.COMPLETED, nih)
+    trainer.run([0], max_epochs=4)
+    assert trainer.state.alpha == 1.6
+
+    trainer.run([0], max_epochs=5)
+    assert trainer.state.alpha == -1
+
+
 def _test_distrib_with_engine_early_stopping(device):
 
     if device is None:
@@ -359,6 +949,8 @@ def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
     device = idist.device()
     _test_distrib_with_engine_early_stopping(device)
     _test_distrib_integration_engine_early_stopping(device)
+    _test_distrib_with_engine_no_improvement_handler(device)
+    _test_distrib_integration_engine_no_improvement_handler(device)
 
 
 @pytest.mark.distributed
@@ -371,6 +963,8 @@ def test_distrib_hvd(gloo_hvd_executor):
 
     gloo_hvd_executor(_test_distrib_with_engine_early_stopping, (device,), np=nproc, do_init=True)
    gloo_hvd_executor(_test_distrib_integration_engine_early_stopping, (device,), np=nproc, do_init=True)
+    gloo_hvd_executor(_test_distrib_with_engine_no_improvement_handler, (device,), np=nproc, do_init=True)
+    gloo_hvd_executor(_test_distrib_integration_engine_no_improvement_handler, (device,), np=nproc, do_init=True)
 
 
 @pytest.mark.multinode_distributed
@@ -381,6 +975,8 @@ def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
     device = idist.device()
     _test_distrib_with_engine_early_stopping(device)
     _test_distrib_integration_engine_early_stopping(device)
+    _test_distrib_with_engine_no_improvement_handler(device)
+    _test_distrib_integration_engine_no_improvement_handler(device)
 
 
 @pytest.mark.multinode_distributed
@@ -391,6 +987,8 @@ def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
     device = idist.device()
     _test_distrib_with_engine_early_stopping(device)
     _test_distrib_integration_engine_early_stopping(device)
+    _test_distrib_with_engine_no_improvement_handler(device)
+    _test_distrib_integration_engine_no_improvement_handler(device)
 
 
 @pytest.mark.tpu
@@ -400,12 +998,16 @@ def test_distrib_single_device_xla():
     device = idist.device()
     _test_distrib_with_engine_early_stopping(device)
     _test_distrib_integration_engine_early_stopping(device)
+    _test_distrib_with_engine_no_improvement_handler(device)
+    _test_distrib_integration_engine_no_improvement_handler(device)
 
 
 def _test_distrib_xla_nprocs(index):
     device = idist.device()
     _test_distrib_with_engine_early_stopping(device)
     _test_distrib_integration_engine_early_stopping(device)
+    _test_distrib_with_engine_no_improvement_handler(device)
+    _test_distrib_integration_engine_no_improvement_handler(device)
 
 
 @pytest.mark.tpu
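Since `NoImprovementHandler` inherits `state_dict` / `load_state_dict` from `Serializable` (exercised by `test_state_dict_no_improvement_handler` above), the patience counter can survive a restart. A small sketch under the same assumptions as the tests (toy no-op engine, hand-fed scores):

    from ignite.engine import Engine
    from ignite.handlers import NoImprovementHandler

    trainer = Engine(lambda engine, batch: None)
    scores = iter([1.0, 0.8, 0.88])

    def make_handler():
        return NoImprovementHandler(
            patience=2,
            score_function=lambda engine: next(scores),
            stop_function=lambda t: t.terminate(),
            trainer=trainer,
        )

    h = make_handler()
    h(None)                    # best_score=1.0, counter=0
    saved = h.state_dict()     # contains 'counter' and 'best_score'

    h2 = make_handler()        # e.g. recreated after a restart
    h2.load_state_dict(saved)  # resumes counting from the saved state
    assert h2.best_score == 1.0 and h2.counter == 0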