Skip to content

Commit

Permalink
fix: self-destruct
Browse files Browse the repository at this point in the history
  • Loading branch information
becktepe committed Oct 11, 2024
1 parent ea8b850 commit c559733
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions hydra_plugins/hyper_pbt/hyper_pbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from __future__ import annotations

import numpy as np
from ConfigSpace.hyperparameters import (CategoricalHyperparameter,
NormalIntegerHyperparameter,
OrdinalHyperparameter,
UniformIntegerHyperparameter)

from ConfigSpace.hyperparameters import (
CategoricalHyperparameter,
NormalIntegerHyperparameter,
OrdinalHyperparameter,
UniformIntegerHyperparameter,
)
from hydra_plugins.hypersweeper import Info


Expand Down Expand Up @@ -88,6 +89,7 @@ def ask(self):
self.population_id += 1
if iteration_end:
self.iteration += 1

return Info(
config=config,
budget=self.budget_per_run,
Expand Down Expand Up @@ -160,18 +162,21 @@ def tell(self, info, value):
if self.model_based:
self.fit_model(self.performance_history, self.config_history)

if self.self_destruct and self.iteration > 1:
import shutil

print(info)
# Try to remove the checkpoint without seeds
path = self.checkpoint_dir / f"{info.load_path!s}{self.checkpoint_path_typing}"
shutil.rmtree(path, ignore_errors=True)
# Try to remove the checkpoint with seeds
for s in self.seeds:
path = self.checkpoint_dir / f"{info.load_path!s}_s{s}{self.checkpoint_path_typing}"
shutil.rmtree(path, ignore_errors=True)

# Now that we have finished the iteration,
# we can safely remove all checkpoints from the previous iteration
print(f"Finished iteration {self.iteration}")
print("Remove checkpoints")
if self.self_destruct and self.iteration > 1:
self.remove_checkpoints(self.iteration - 2)

def remove_checkpoints(self, iteration: int) -> None:
"""Remove checkpoints."""
import os

# Delete all files in checkpoints dir starting with iteration_{iteration}
for file in os.listdir(self.checkpoint_dir):
if file.startswith(f"iteration_{iteration}"):
os.remove(os.path.join(self.checkpoint_dir, file))

def make_pbt(configspace, pbt_args):
"""Make a PBT instance for optimization."""
Expand Down

0 comments on commit c559733

Please sign in to comment.