From 8aec957635d62a0895c6fd2b1963fe3e9a7b31f4 Mon Sep 17 00:00:00 2001 From: Nick Knyazev Date: Tue, 26 Sep 2023 13:25:03 +0100 Subject: [PATCH] EarlyStopping based no improvement interval (#1560) * adding no_improvement_stopper callback and test for it. Early stop when no loss reduce during tolerance_window asks tested python3 -m pytest test_callbacks.py example of using : ''' optimizer.register_callback("ask", ng.callbacks.EarlyStopping.no_improvement_stopper(no_imp_window)) # should not get triggered ''' * comment fix * english fix * reformated with black and fix typing in test * add typing spect --- nevergrad/optimization/callbacks.py | 27 +++++++++++++++++++++ nevergrad/optimization/test_callbacks.py | 30 +++++++++++++++++++++--- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/nevergrad/optimization/callbacks.py b/nevergrad/optimization/callbacks.py index c11c195b45..f59e77f31f 100644 --- a/nevergrad/optimization/callbacks.py +++ b/nevergrad/optimization/callbacks.py @@ -347,6 +347,11 @@ def timer(cls, max_duration: float) -> "EarlyStopping": """Early stop when max_duration seconds has been reached (from the first ask)""" return cls(_DurationCriterion(max_duration)) + @classmethod + def no_improvement_stopper(cls, tolerance_window: int) -> "EarlyStopping": + """Early stop when loss didn't reduce during tolerance_window asks""" + return cls(_LossImprovementToleranceCriterion(tolerance_window)) + class _DurationCriterion: def __init__(self, max_duration: float) -> None: @@ -357,3 +362,25 @@ def __call__(self, optimizer: base.Optimizer) -> bool: if np.isinf(self._start): self._start = time.time() return time.time() > self._start + self._max_duration + + +class _LossImprovementToleranceCriterion: + def __init__(self, tolerance_window: int) -> None: + self._tolerance_window: int = tolerance_window + self._best_value: tp.Optional[np.ndarray] = None + self._tolerance_count: int = 0 + + def __call__(self, optimizer: base.Optimizer) -> bool: + best_param = optimizer.provide_recommendation() + if best_param is None or (best_param.loss is None and best_param._losses is None): + return False + best_last_losses = best_param.losses + if self._best_value is None: + self._best_value = best_last_losses + return False + if self._best_value <= best_last_losses: + self._tolerance_count += 1 + else: + self._tolerance_count = 0 + self._best_value = best_last_losses + return self._tolerance_count > self._tolerance_window diff --git a/nevergrad/optimization/test_callbacks.py b/nevergrad/optimization/test_callbacks.py index 512add4985..4d0345185f 100644 --- a/nevergrad/optimization/test_callbacks.py +++ b/nevergrad/optimization/test_callbacks.py @@ -105,12 +105,16 @@ def test_progressbar_dump(tmp_path: Path) -> None: class _EarlyStoppingTestee: - def __init__(self) -> None: + def __init__(self, val=None, multi=False) -> None: self.num_calls = 0 + self.val = val + self.multi = False - def __call__(self, *args, **kwds) -> float: + def __call__(self, *args, **kwds) -> tp.Union[float, tp.Tuple]: self.num_calls += 1 - return np.random.rand() + if self.val is not None: + return self.val if not self.multi else (self.val, self.val) + return np.random.rand() if not self.multi else (np.random.rand(), np.random.rand()) def test_early_stopping() -> None: @@ -127,6 +131,26 @@ def test_early_stopping() -> None: assert optimizer.current_bests["minimum"].mean < 12 assert optimizer.recommend().loss < 12 # type: ignore + # test for no improvement + func = _EarlyStoppingTestee(5) + optimizer = optimizerlib.OnePlusOne(parametrization=instrum, budget=100) + no_imp_window = 7 + optimizer.register_callback( + "ask", ng.callbacks.EarlyStopping.no_improvement_stopper(no_imp_window) + ) # should get triggered + optimizer.minimize(func, verbosity=2) + assert func.num_calls == no_imp_window + 2 + + # test for no improvement multi objective + func = _EarlyStoppingTestee(5, multi=True) + optimizer = optimizerlib.OnePlusOne(parametrization=instrum, budget=100) + no_imp_window = 5 + optimizer.register_callback( + "ask", ng.callbacks.EarlyStopping.no_improvement_stopper(no_imp_window) + ) # should get triggered + optimizer.minimize(func, verbosity=2) + assert func.num_calls == no_imp_window + 2 + def test_duration_criterion() -> None: optim = optimizerlib.OnePlusOne(2, budget=100)