Skip to content

Commit

Permalink
EarlyStopping based no improvement interval (#1560)
Browse files Browse the repository at this point in the history
* adding no_improvement_stopper callback and test for it.
Early stop when no loss reduce during tolerance_window asks

tested python3 -m pytest test_callbacks.py

example of using :
'''
    optimizer.register_callback("ask", ng.callbacks.EarlyStopping.no_improvement_stopper(no_imp_window))  # should not get triggered
'''

* comment fix

* english fix

* reformated with black and fix typing in test

* add typing spect
  • Loading branch information
irumata authored Sep 26, 2023
1 parent d9f767b commit 8aec957
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
27 changes: 27 additions & 0 deletions nevergrad/optimization/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,11 @@ def timer(cls, max_duration: float) -> "EarlyStopping":
"""Early stop when max_duration seconds has been reached (from the first ask)"""
return cls(_DurationCriterion(max_duration))

@classmethod
def no_improvement_stopper(cls, tolerance_window: int) -> "EarlyStopping":
"""Early stop when loss didn't reduce during tolerance_window asks"""
return cls(_LossImprovementToleranceCriterion(tolerance_window))


class _DurationCriterion:
def __init__(self, max_duration: float) -> None:
Expand All @@ -357,3 +362,25 @@ def __call__(self, optimizer: base.Optimizer) -> bool:
if np.isinf(self._start):
self._start = time.time()
return time.time() > self._start + self._max_duration


class _LossImprovementToleranceCriterion:
def __init__(self, tolerance_window: int) -> None:
self._tolerance_window: int = tolerance_window
self._best_value: tp.Optional[np.ndarray] = None
self._tolerance_count: int = 0

def __call__(self, optimizer: base.Optimizer) -> bool:
best_param = optimizer.provide_recommendation()
if best_param is None or (best_param.loss is None and best_param._losses is None):
return False
best_last_losses = best_param.losses
if self._best_value is None:
self._best_value = best_last_losses
return False
if self._best_value <= best_last_losses:
self._tolerance_count += 1
else:
self._tolerance_count = 0
self._best_value = best_last_losses
return self._tolerance_count > self._tolerance_window
30 changes: 27 additions & 3 deletions nevergrad/optimization/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,16 @@ def test_progressbar_dump(tmp_path: Path) -> None:


class _EarlyStoppingTestee:
def __init__(self) -> None:
def __init__(self, val=None, multi=False) -> None:
self.num_calls = 0
self.val = val
self.multi = False

def __call__(self, *args, **kwds) -> float:
def __call__(self, *args, **kwds) -> tp.Union[float, tp.Tuple]:
self.num_calls += 1
return np.random.rand()
if self.val is not None:
return self.val if not self.multi else (self.val, self.val)
return np.random.rand() if not self.multi else (np.random.rand(), np.random.rand())


def test_early_stopping() -> None:
Expand All @@ -127,6 +131,26 @@ def test_early_stopping() -> None:
assert optimizer.current_bests["minimum"].mean < 12
assert optimizer.recommend().loss < 12 # type: ignore

# test for no improvement
func = _EarlyStoppingTestee(5)
optimizer = optimizerlib.OnePlusOne(parametrization=instrum, budget=100)
no_imp_window = 7
optimizer.register_callback(
"ask", ng.callbacks.EarlyStopping.no_improvement_stopper(no_imp_window)
) # should get triggered
optimizer.minimize(func, verbosity=2)
assert func.num_calls == no_imp_window + 2

# test for no improvement multi objective
func = _EarlyStoppingTestee(5, multi=True)
optimizer = optimizerlib.OnePlusOne(parametrization=instrum, budget=100)
no_imp_window = 5
optimizer.register_callback(
"ask", ng.callbacks.EarlyStopping.no_improvement_stopper(no_imp_window)
) # should get triggered
optimizer.minimize(func, verbosity=2)
assert func.num_calls == no_imp_window + 2


def test_duration_criterion() -> None:
optim = optimizerlib.OnePlusOne(2, budget=100)
Expand Down

0 comments on commit 8aec957

Please sign in to comment.