diff --git a/hydra_plugins/hyper_pbt/hyper_pbt.py b/hydra_plugins/hyper_pbt/hyper_pbt.py index a034bf6..0fe3243 100644 --- a/hydra_plugins/hyper_pbt/hyper_pbt.py +++ b/hydra_plugins/hyper_pbt/hyper_pbt.py @@ -2,13 +2,15 @@ from __future__ import annotations +import os +import shutil + import numpy as np -from ConfigSpace.hyperparameters import ( - CategoricalHyperparameter, - NormalIntegerHyperparameter, - OrdinalHyperparameter, - UniformIntegerHyperparameter, -) +from ConfigSpace.hyperparameters import (CategoricalHyperparameter, + NormalIntegerHyperparameter, + OrdinalHyperparameter, + UniformIntegerHyperparameter) + from hydra_plugins.hypersweeper import Info @@ -171,12 +173,15 @@ def tell(self, info, value): def remove_checkpoints(self, iteration: int) -> None: """Remove checkpoints.""" - import os - # Delete all files in checkpoints dir starting with iteration_{iteration} for file in os.listdir(self.checkpoint_dir): if file.startswith(f"iteration_{iteration}"): - os.remove(os.path.join(self.checkpoint_dir, file)) + file_path = os.path.join(self.checkpoint_dir, file) + if os.path.isfile(file_path): + os.remove(file_path) + else: + shutil.rmtree(file_path) + def make_pbt(configspace, pbt_args): """Make a PBT instance for optimization."""