def close(self):
    """Release the resources held by the facade's components.

    Calls ``close()`` on the model, the acquisition function, the acquisition maximizer and the config
    selector, in that order. Components that are unset (``None``) or were never assigned are skipped.
    """
    # `getattr` with a default keeps this safe when it is reached through `__del__` on an object whose
    # `__init__` raised before the attribute was assigned. `is not None` (instead of truthiness) makes sure a
    # component that merely evaluates as falsy still gets closed.
    for attr in ("_model", "_acquisition_function", "_acquisition_maximizer", "_config_selector"):
        component = getattr(self, attr, None)
        if component is not None:
            component.close()


def __del__(self):
    """Close the components when the facade is garbage collected."""
    self.close()
def close(self):
    """Release the resources held by the config selector's components.

    Calls ``close()`` on the model, the acquisition maximizer and the acquisition function, in that order.
    Components that are unset (``None``) or were never assigned are skipped.
    """
    # `getattr` with a default: `__del__` also runs on objects whose `__init__` did not finish.
    # `is not None`: a component that evaluates as falsy must still be closed.
    for attr in ("_model", "_acquisition_maximizer", "_acquisition_function"):
        component = getattr(self, attr, None)
        if component is not None:
            component.close()


def __del__(self):
    """Close the components when the config selector is garbage collected."""
    self.close()
def SharedMemory(*args, **kwargs) -> "UntrackableSharedMemory":
    """Create or attach to a shared-memory segment with ``track=False``."""
    return UntrackableSharedMemory(*args, track=False, **kwargs)


def dtypes_are_equal(dtype1: np.dtype, dtype2: np.dtype) -> bool:
    """Return True when each dtype is a sub-dtype of the other."""
    return np.issubdtype(dtype2, dtype1) and np.issubdtype(dtype1, dtype2)


class GrowingSharedArrayReaderView:
    """Read access to the ``X`` / ``y`` float64 arrays stored in two named shared-memory segments.

    The segments are named ``f'{basename_X}_{shm_id}'`` and ``f'{basename_y}_{shm_id}'``. A reader attaches to
    them lazily in :meth:`open` and copies the data out under ``lock`` in :meth:`get_data`.
    """

    basename_X: str = 'X'
    basename_y: str = 'y'

    def __init__(self, lock: Lock):
        self.lock = lock
        self.shm_id: Optional[int] = None
        self.shm_X: Optional[UntrackableSharedMemory] = None
        self.shm_y: Optional[UntrackableSharedMemory] = None
        self.size: Optional[int] = None

    def open(self, shm_id: int, size: int):
        """Attach to the segments identified by ``shm_id`` (re-attaching only if the id changed)."""
        if shm_id != self.shm_id:
            self.close()
            self.shm_X = SharedMemory(f'{self.basename_X}_{shm_id}')
            self.shm_y = SharedMemory(f'{self.basename_y}_{shm_id}')
            self.shm_id = shm_id
        self.size = size

    def close_shm(self, unlink=False):
        """Detach from both segments, optionally unlinking them, and reset the bookkeeping fields."""
        # `getattr` with a default: this is reachable from `__del__` even if `__init__` did not complete.
        for attr in ('shm_X', 'shm_y'):
            shm = getattr(self, attr, None)
            if shm is not None:
                shm.close()
                if unlink:
                    shm.unlink()
            setattr(self, attr, None)
        self.shm_id = None
        self.size = None

    def close(self):
        """Detach from the segments without unlinking them."""
        self.close_shm()

    def __del__(self):
        self.close()

    @property
    def capacity(self) -> int:
        """Number of float64 values that fit into the ``y`` segment (0 when nothing is attached)."""
        if self.shm_y is None:
            return 0
        # NOTE(review): this derives the capacity from the size the OS reports for the segment. The Python
        # docs allow a platform to round the size up to a page multiple, in which case this (and `row_size`)
        # would be off -- confirm on non-Linux platforms.
        itemsize = np.dtype(np.float64).itemsize
        assert self.shm_y.size % itemsize == 0
        return self.shm_y.size // itemsize

    @property
    def row_size(self) -> Optional[int]:
        """Number of columns of ``X``, computed as the ratio of the two segment sizes (None if unknown)."""
        if self.shm_X is None or self.shm_y is None:
            return None
        if self.shm_X.size == 0:
            return None
        assert self.shm_X.size % self.shm_y.size == 0
        return self.shm_X.size // self.shm_y.size

    def _require_attached(self) -> None:
        """Raise a clear error instead of an AttributeError on ``None`` when no segment is attached."""
        if self.shm_X is None or self.shm_y is None:
            raise RuntimeError('no shared memory is attached; call open() or set_data() first')

    @property
    def X(self):
        """View of the first ``size`` rows of the shared ``X`` buffer."""
        self._require_attached()
        X = np.ndarray(shape=(self.capacity, self.row_size), dtype=np.float64, buffer=self.shm_X.buf)
        return X[:self.size]

    @property
    def y(self):
        """View of the first ``size`` entries of the shared ``y`` buffer."""
        self._require_attached()
        y = np.ndarray(shape=(self.capacity,), dtype=np.float64, buffer=self.shm_y.buf)
        return y[:self.size]

    def get_data(self, shm_id: int, size: int) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
        """Attach to ``shm_id`` and return private copies of the first ``size`` rows of ``X`` and ``y``."""
        with self.lock:
            self.open(shm_id, size)
            X, y = np.array(self.X), np.array(self.y)  # make copies and release lock to minimize critical section

        return X, y


class GrowingSharedArray(GrowingSharedArrayReaderView):
    """Writer side: owns the segments, reallocates them geometrically and unlinks them on close."""

    def __init__(self):
        self.growth_rate = 1.5
        super().__init__(lock=Lock())

    def close(self):
        """Detach from the segments and unlink them."""
        self.close_shm(unlink=True)

    # `__del__` is inherited; it calls `self.close()`, which resolves to the unlinking override above.

    def set_data(self, X: npt.NDArray[np.float64], y: npt.NDArray[np.float64]) -> None:
        """Copy ``X`` (2-D) and ``y`` (1-D) into shared memory, growing the segments if they are too small.

        Raises
        ------
        ValueError
            If the shapes are inconsistent, if the number of columns differs from the data already stored, or
            if an empty data set is given before any shared memory has been allocated.
        TypeError
            If ``X`` or ``y`` is not float64.
        """
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if X.ndim != 2 or y.ndim != 1:
            raise ValueError(f'expected a 2-D X and a 1-D y, got X.ndim={X.ndim} and y.ndim={y.ndim}')
        if len(X) != len(y):
            raise ValueError(f'X and y must have the same length, got {len(X)} and {len(y)}')
        for name, arr in (('X', X), ('y', y)):
            if not dtypes_are_equal(arr.dtype, np.float64) or arr.dtype.itemsize != 8:
                raise TypeError(f'{name} must have dtype float64, got {arr.dtype}')

        size = len(y)
        row_size = X.shape[1]
        # Checked on every call, not only when growing: writing a one-column X into a wider buffer would
        # otherwise broadcast silently.
        if self.row_size is not None and row_size != self.row_size:
            raise ValueError(f'X has {row_size} columns but the shared array holds rows of {self.row_size}')

        grow = size > self.capacity
        if not grow and self.shm_X is None:
            # only reachable with size == 0 and nothing allocated yet
            raise ValueError('cannot store an empty data set before any shared memory has been allocated')

        shm_id = shm_X = shm_y = None
        if grow:
            if self.capacity:
                n_growth = math.ceil(math.log(size / self.capacity, self.growth_rate))
                capacity = int(math.ceil(self.capacity * self.growth_rate ** n_growth))
                capacity = max(capacity, size)  # guard against floating point rounding in the line above
            else:
                capacity = size

            # NOTE(review): this produces names of up to ~41 characters; some platforms limit the length of
            # shared memory names -- confirm this is acceptable on all supported platforms.
            shm_id = uuid.uuid4().int

            shm_X = SharedMemory(f'{self.basename_X}_{shm_id}', create=True,
                                 size=capacity * row_size * X.dtype.itemsize)
            try:
                shm_y = SharedMemory(f'{self.basename_y}_{shm_id}', create=True, size=capacity * y.dtype.itemsize)
            except BaseException:
                # the segments are created with track=False, so clean up ourselves instead of leaking the
                # first segment
                shm_X.close()
                shm_X.unlink()
                raise

        with self.lock:
            if grow:
                if self.capacity:
                    # here, before reallocating, we unlink the underlying shared memory without making sure
                    # that the training loop process has had a chance to close() it first
                    # references:
                    # - https://stackoverflow.com/a/63004750/2447427
                    # - https://github.com/python/cpython/issues/84140
                    # - https://github.com/python/cpython/issues/82300
                    #   - comment provides a fix that turns off tracking:
                    #     https://github.com/python/cpython/issues/82300#issuecomment-2169035092
                    self.close()
                self.shm_X = shm_X
                self.shm_y = shm_y
                self.shm_id = shm_id
            self.size = size
            self.X[...] = X
            self.y[...] = y
def rf_training_loop(
        model_queue: Queue, data_queue: Queue, data_lock: Lock,
        # init rf train
        bounds: Iterable[tuple[float, float]], seed: int,
        # rf opts
        n_trees: int, bootstrapping: bool, max_features: int, min_samples_split: int, min_samples_leaf: int,
        max_depth: int, eps_purity: float, max_nodes: int, n_points_per_tree: int
) -> None:
    """Entry point of the trainer process: train a forest for every data message until told to shut down.

    Each message on ``data_queue`` is either ``SHUTDOWN`` or a ``(shm_id, size)`` pair that is passed to
    ``GrowingSharedArrayReaderView.get_data``. Only the most recent queued message is used; older ones are
    discarded. Every trained forest is put on ``model_queue``, replacing whatever is still queued there.

    On the way out -- after a shutdown request and also when training raises -- the shared memory view is
    closed and ``SHUTDOWN`` is put on ``model_queue``, so the process waiting on that queue is not left
    blocking forever. An exception raised during training is re-raised after this cleanup.
    """
    if setproctitle is not None:
        setproctitle(f'rf_trainer_{uuid.uuid4().int}'[:15])

    rf_opts = get_rf_opts(n_trees, bootstrapping, max_features, min_samples_split, min_samples_leaf, max_depth,
                          eps_purity, max_nodes, n_points_per_tree)

    # Cast to `int` in case we get an `np.integer` type
    rng = DefaultRandomEngine(int(seed))
    shared_arrs = GrowingSharedArrayReaderView(data_lock)

    def send_to_optimization_loop_process(msg: Optional[BinaryForest]) -> None:
        # remove previous models from queue, if any, before pushing the latest model
        while True:
            try:
                model_queue.get(block=False)
            except queue.Empty:
                break
        label = "SHUTDOWN CONFIRM" if msg is SHUTDOWN else "MODEL"
        debug_print(f'TRAINER SENDING {label}', file=sys.stderr)
        model_queue.put(msg)
        debug_print(f'TRAINER SENDING {label} DONE', file=sys.stderr)

    try:
        while True:
            debug_print('TRAINER WAIT MSG', file=sys.stderr)
            data_msg = data_queue.get()  # if queue is empty, wait for training data or shutdown signal
            debug_print(f'TRAINER GOT MSG: {data_msg}', file=sys.stderr)
            must_shutdown = data_msg is SHUTDOWN
            if must_shutdown:
                debug_print('TRAINER GOT SHUTDOWN 1', file=sys.stderr)

            # discard all but the last data_msg in the queue; a shutdown request anywhere in it wins
            while True:
                try:
                    data_msg = data_queue.get(block=False)
                except queue.Empty:
                    break
                if data_msg is SHUTDOWN:
                    debug_print('TRAINER GOT SHUTDOWN 2', file=sys.stderr)
                    must_shutdown = True
            if must_shutdown:
                break

            shm_id, size = data_msg
            X, y = shared_arrs.get_data(shm_id, size)

            rf = train(rng, rf_opts, n_points_per_tree, bounds, X, y)

            send_to_optimization_loop_process(rf)
    finally:
        # Runs for a regular shutdown and for a crash alike: detach from the shared memory first, then confirm
        # the shutdown, so the receiver of SHUTDOWN can rely on close() having been called here.
        shared_arrs.close()
        send_to_optimization_loop_process(SHUTDOWN)
        # don't kill current process until we make sure the queue's underlying pipe is flushed
        model_queue.close()
        model_queue.join_thread()
    debug_print('TRAINER BYE BYE', file=sys.stderr)
def close(self):
    """Stop the trainer process (if any) and release the queues and the shared memory.

    The steps are: send ``SHUTDOWN`` on the data queue, read the model queue until the trainer confirms with
    ``SHUTDOWN`` (keeping the last model received on the way), join the process, and finally close -- and
    thereby unlink -- the shared arrays. Calling this more than once is harmless because every released
    attribute is reset to ``None``.
    """
    proc = self.training_loop_proc

    # send kill signal to training process
    if self.data_queue is not None:
        # a trainer that is already gone cannot read the message, so don't bother sending it
        if proc is not None and proc.is_alive():
            debug_print('MAIN SEND SHUTDOWN')
            self.send_to_training_loop_proc(SHUTDOWN)
            debug_print('MAIN FINISHED SEND SHUTDOWN')
        # make sure the shutdown message is flushed before moving on
        self.data_queue.close()
        self.data_queue.join_thread()
        self.data_queue = None

    # Wait for the SHUTDOWN confirmation, because that guarantees that shared_arrs.close() has been called
    # within the training process before we unlink below. Poll with a timeout and re-check liveness instead of
    # blocking on get(): if the trainer died without confirming, an unbounded get() would hang forever.
    if self.model_queue is not None and proc is not None:
        debug_print('MAIN WAIT SHUTDOWN CONFIRM')
        while proc.is_alive():
            try:
                msg = self.model_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            debug_print(f'MAIN RECEIVED {"SHUTDOWN CONFIRMATION" if msg is SHUTDOWN else "MODEL"}'
                        f' AFTER WAITING FOR SHUTDOWN CONFIRMATION')
            if msg is SHUTDOWN:
                break
            self._model = msg  # keep the latest model that was still in flight

    if proc is not None:
        # wait for training to finish
        if proc.is_alive():
            proc.join()
        self.training_loop_proc = None

    self.model_queue = None

    # make sure this is called after SHUTDOWN was received because we want the trainer process to call
    # shared_arrs.close() before we call unlink
    if self.shared_arrs is not None:
        self.shared_arrs.close()
        self.shared_arrs = None
if sys.version_info < (3, 13):
    class SharedMemory(_mpshm.SharedMemory):
        """``multiprocessing.shared_memory.SharedMemory`` with an additional keyword-only ``track`` flag.

        With ``track=False`` the resource tracker's ``register`` function is replaced by a no-op for the
        duration of the base class initialisation, and :meth:`unlink` skips the matching ``unregister`` call.
        """

        __registration_guard = threading.Lock()

        def __init__(
            self, name: Optional[str] = None, create: bool = False,
            size: int = 0, *, track: bool = True
        ) -> None:
            # must be set before the base initialiser runs
            self._track = track

            if not track:
                # serialise the swap so that two untracked initialisations in different threads cannot
                # interleave their patch/restore steps
                with self.__registration_guard:
                    saved_register = _mprt.register
                    _mprt.register = self.__ignore_registration
                    try:
                        super().__init__(name=name, create=create, size=size)
                    finally:
                        # always put the real register function back
                        _mprt.register = saved_register
            else:
                # tracked segments need nothing beyond the regular initialisation
                super().__init__(name=name, create=create, size=size)

        @staticmethod
        def __ignore_registration(*args, **kwargs) -> None:
            return None

        def unlink(self) -> None:
            """Unlink the segment; unregister it from the resource tracker only if it was tracked."""
            if not (_mpshm._USE_POSIX and self._name):
                return
            _mpshm._posixshmem.shm_unlink(self._name)
            if self._track:
                _mprt.unregister(self._name, "shared_memory")
else:
    # from this version on the class is used unchanged
    SharedMemory = _mpshm.SharedMemory
def close(self):
    """Shut down the random forest trainer held in ``_rf_trainer``, if there is one."""
    # `getattr` with a default: `__del__` also runs when `__init__` raised before `_rf_trainer` was assigned,
    # and an AttributeError there would only show up as an "Exception ignored in __del__" message.
    trainer = getattr(self, "_rf_trainer", None)
    if trainer is not None:
        trainer.close()


def __del__(self):
    """Shut down the trainer when the model is garbage collected."""
    self.close()
data = self._init_data_container(X, y) - self._rf.fit(data, rng=self._rng) + # call this to make sure that there exists a trained model before returning (actually, not sure this is + # required, since we check within predict() anyway) + # _ = self._rf.model return self - def _init_data_container(self, X: np.ndarray, y: np.ndarray) -> DataContainer: - """Fills a pyrfr default data container s.t. the forest knows categoricals and bounds for continous data. - - Parameters - ---------- - X : np.ndarray [#samples, #hyperparameter + #features] - Input data points. - Y : np.ndarray [#samples, #objectives] - The corresponding target values. - - Returns - ------- - data : DataContainer - The filled data container that pyrfr can interpret. - """ - # Retrieve the types and the bounds from the ConfigSpace - data = regression.default_data_container(X.shape[1]) - - for i, (mn, mx) in enumerate(self._bounds): - if np.isnan(mx): - data.set_type_of_feature(i, mn) - else: - data.set_bounds_of_feature(i, mn, mx) - - for row_X, row_y in zip(X, y): - data.add_data_point(row_X, row_y) - - return data - def _predict( self, X: np.ndarray, @@ -198,7 +163,9 @@ def _predict( if covariance_type != "diagonal": raise ValueError("`covariance_type` can only take `diagonal` for this model.") - assert self._rf is not None + rf = self._rf_trainer.model + + assert rf is not None X = self._impute_inactive(X) if self._log_y: @@ -207,13 +174,13 @@ def _predict( # Gather data in a list of 2d arrays and get statistics about the required size of the 3d array for row_X in X: - preds_per_tree = self._rf.all_leaf_values(row_X) + preds_per_tree = rf.all_leaf_values(row_X) all_preds.append(preds_per_tree) max_num_leaf_data = max(map(len, preds_per_tree)) third_dimension = max(max_num_leaf_data, third_dimension) # Transform list of 2d arrays into a 3d array - preds_as_array = np.zeros((X.shape[0], self._rf_opts.num_trees, third_dimension)) * np.nan + preds_as_array = np.zeros((X.shape[0], self._n_trees, 
def get_rf_opts(n_trees: int, bootstrapping: bool, max_features: int, min_samples_split: int, min_samples_leaf: int,
                max_depth: int, eps_purity: float, max_nodes: int, n_points_per_tree: int) -> ForestOpts:
    """Create a pyrfr ``forest_opts`` object from the given hyperparameters.

    ``num_data_points_per_tree`` is only set here when ``n_points_per_tree`` is positive.
    """
    opts = ForestOpts()

    # forest-level settings
    opts.num_trees = n_trees
    opts.do_bootstrapping = bootstrapping

    # tree-level settings
    tree_opts = opts.tree_opts
    tree_opts.max_features = max_features
    tree_opts.min_samples_to_split = min_samples_split
    tree_opts.min_samples_in_leaf = min_samples_leaf
    tree_opts.max_depth = max_depth
    tree_opts.epsilon_purity = eps_purity
    tree_opts.max_num_nodes = max_nodes

    opts.compute_law_of_total_variance = False
    if n_points_per_tree > 0:
        opts.num_data_points_per_tree = n_points_per_tree

    return opts


def init_data_container(X: npt.NDArray[np.float64], y: npt.NDArray[np.float64],
                        bounds: Iterable[tuple[float, float]]) -> DataContainer:
    """Fill a pyrfr default data container with the feature descriptions and the data points.

    Parameters
    ----------
    X : np.ndarray [#samples, #hyperparameter + #features]
        Input data points.
    y : np.ndarray [#samples]
        The corresponding target values.
    bounds : Iterable[tuple[float, float]]
        One pair per feature. If the second entry is NaN, the first entry is passed to
        ``set_type_of_feature``; otherwise the pair is passed to ``set_bounds_of_feature``.

    Returns
    -------
    data : DataContainer
        The filled data container that pyrfr can interpret.
    """
    container = DataContainer(X.shape[1])

    for feature_idx, (first, second) in enumerate(bounds):
        if not np.isnan(second):
            container.set_bounds_of_feature(feature_idx, first, second)
        else:
            container.set_type_of_feature(feature_idx, first)

    for features, target in zip(X, y):
        container.add_data_point(features, target)

    return container


def train(rng: DefaultRandomEngine, rf_opts: ForestOpts, n_points_per_tree: int, bounds: Iterable[tuple[float, float]],
          X: npt.NDArray[np.float64], y: npt.NDArray[np.float64]) -> BinaryForest:
    """Fit a new pyrfr forest on ``X`` / ``y`` and return it.

    When ``n_points_per_tree`` is not positive, ``rf_opts.num_data_points_per_tree`` is set to ``len(X)``
    (this modifies the ``rf_opts`` object that was passed in).
    """
    container = init_data_container(X, y, bounds)

    use_all_points = n_points_per_tree <= 0
    if use_all_points:
        rf_opts.num_data_points_per_tree = len(X)

    forest = BinaryForest()
    forest.options = rf_opts
    forest.fit(container, rng)

    return forest
MockModelDual: def __init__(self, num_targets=1): @@ -46,6 +49,9 @@ def predict_marginalized(self, X): [np.mean(X, axis=1).reshape((1, -1))] * self.num_targets ).reshape((-1, 2)) + def close(self): + pass + class MockPrior: def __init__(self, pdf, max_density): @@ -116,6 +122,9 @@ def predict_marginalized(self, X): def update_prior(self, hyperparameter_dict): self._configspace.get_hyperparameters_dict.return_value = hyperparameter_dict + def close(self): + pass + class MockModelRNG(MockModel): def __init__(self, num_targets=1, seed=0): @@ -154,9 +163,13 @@ def acquisition_function(model): # Test AbstractAcquisitionFunction # -------------------------------------------------------------- +class CloseableString(str): + def close(self): + pass + def test_update_model_and_eta(model, acquisition_function): - model = "abc" + model = CloseableString("abc") assert acquisition_function._eta is None acquisition_function.update(model=model, eta=0.1) assert acquisition_function.model == model @@ -164,7 +177,8 @@ def test_update_model_and_eta(model, acquisition_function): def test_update_with_kwargs(acquisition_function): - acquisition_function.update(model="abc", eta=0.0, other="hi there:)") + model = CloseableString("abc") + acquisition_function.update(model=model, eta=0.0, other="hi there:)") assert acquisition_function.model == "abc" diff --git a/tests/test_acquisition/test_maximizers.py b/tests/test_acquisition/test_maximizers.py index d7698e0a29..fdac066199 100644 --- a/tests/test_acquisition/test_maximizers.py +++ b/tests/test_acquisition/test_maximizers.py @@ -1,21 +1,14 @@ from __future__ import annotations -from typing import Any - -import os import unittest import unittest.mock import numpy as np import pytest from ConfigSpace import ( - Categorical, Configuration, ConfigurationSpace, - EqualsCondition, Float, - InCondition, - Integer, ) from ConfigSpace.hyperparameters import ( BetaIntegerHyperparameter, @@ -24,7 +17,6 @@ UniformFloatHyperparameter, 
UniformIntegerHyperparameter, ) -from ConfigSpace.read_and_write import pcs from scipy.spatial.distance import euclidean from smac.acquisition.function import EI @@ -35,8 +27,6 @@ RandomSearch, ) from smac.model.random_forest.random_forest import RandomForest -from smac.runhistory.runhistory import RunHistory -from smac.runner.abstract_runner import StatusType __copyright__ = "Copyright 2021, AutoML.org Freiburg-Hannover" __license__ = "3-clause BSD" @@ -185,9 +175,7 @@ def test_get_next_by_random_search(): # TestLocalSearch # -------------------------------------------------------------- - -@pytest.fixture -def configspace() -> ConfigurationSpace: +def get_configspace() -> ConfigurationSpace: cs = ConfigurationSpace(seed=0) a = Float("a", (0, 1), default=0.5) @@ -201,25 +189,42 @@ def configspace() -> ConfigurationSpace: @pytest.fixture -def model(configspace: ConfigurationSpace): +def configspace() -> ConfigurationSpace: + return get_configspace() + + +def get_model(configspace: ConfigurationSpace) -> RandomForest: model = RandomForest(configspace) np.random.seed(0) - X = np.random.rand(100, len(configspace.get_hyperparameters())) - y = 1 - (np.sum(X, axis=1) / len(configspace.get_hyperparameters())) + X = np.random.rand(100, len(configspace.values())) + y = 1 - (np.sum(X, axis=1) / len(configspace.values())) model.train(X, y) return model @pytest.fixture -def acquisition_function(model): +def model(configspace: ConfigurationSpace) -> RandomForest: + model = get_model(configspace) + # return model + yield model + model.close() + + + +def get_acquisition_function(model): ei = EI() ei.update(model=model, eta=0.5) return ei +@pytest.fixture +def acquisition_function(model): + return get_acquisition_function(model) + + def test_local_search(configspace): def acquisition_function(points): rval = [] @@ -267,6 +272,9 @@ class AcquisitionFunction: def __call__(self, X): return np.array([x.get_array().sum() for x in X]).reshape((-1, 1)) + def close(self): + pass + ls = 
LocalSearch( configspace=configspace, acquisition_function=AcquisitionFunction(), @@ -385,6 +393,9 @@ def __call__(self, arrays): rval.append([-rosenbrock_4d(array)]) return np.array(rval) + def close(self): + pass + budget_kwargs = {"max_steps": 2, "n_steps_plateau_walk": 2, "local_search_iterations": 2} prs_0 = LocalAndSortedRandomSearch( @@ -427,3 +438,60 @@ def test_differential_evolution(configspace, acquisition_function): values = rs._maximize(start_points, 1) values[0][1].origin == "Acquisition Function Maximizer: Differential Evolution" + + +# manual testing + +def differential_evolution(): + cs = get_configspace() + m = get_model(cs) + af = get_acquisition_function(m) + test_differential_evolution(cs, af) + + +def min_repro_differential_evolution_bug(): + cs = ConfigurationSpace(seed=0) + a = Float("a", (0, 1), default=0.5) + cs.add(a) + + model = RandomForest(cs) + + af = EI() + af.update(model=model, eta=0.5) + + np.random.seed(0) + X = np.random.rand(100, len(cs.values())) + y = 1 - (np.sum(X, axis=1) / len(cs.values())) + model.train(X, y) + + start_points = cs.sample_configuration(100) + # start_point = cs.sample_configuration() # this circumvents bug + rs = DifferentialEvolution(cs, af, challengers=1000) + values = rs._maximize(start_points, 1) + values[0][1].origin == "Acquisition Function Maximizer: Differential Evolution" + # model._rf_trainer.close() # this circumvents the bug + + +def random_search(): + cs = get_configspace() + m = get_model(cs) + af = get_acquisition_function(m) + test_random_search(cs, af) + + +def main(): + from smac.model.random_forest.multiproc_util import RFTrainer + RFTrainer.ENABLE_DBG_PRINT = True + # TODO: running ALL these three IN THIS ORDER causes a hang, probably because of the dependency graph growing too + # complex and circular for the garbage collector to handle, so RFTrainer.close() is never called. 
In order to avoid + # hangs while running tests, we explicitly call RFTrainer.close() in model fixture teardown + print('differential_evolution:') + differential_evolution() + print('\nmin_repro_differential_evolution_bug:') + min_repro_differential_evolution_bug() + print('\nrandom_search:') + random_search() + + +if __name__ == '__main__': + main() diff --git a/tests/test_ask_and_tell/test_ask_and_tell_intensifier.py b/tests/test_ask_and_tell/test_ask_and_tell_intensifier.py index 4fd5ab1cab..2c491bdb9f 100644 --- a/tests/test_ask_and_tell/test_ask_and_tell_intensifier.py +++ b/tests/test_ask_and_tell/test_ask_and_tell_intensifier.py @@ -9,8 +9,7 @@ __license__ = "3-clause BSD" -@pytest.fixture -def make_facade(digits_dataset, make_sgd) -> HyperparameterOptimizationFacade: +def get_make_facade(digits_dataset, make_sgd) -> HyperparameterOptimizationFacade: def create( deterministic: bool = True, use_instances: bool = False, max_config_calls: int = 5 ) -> HyperparameterOptimizationFacade: @@ -47,6 +46,11 @@ def create( return create +@pytest.fixture +def make_facade(digits_dataset, make_sgd) -> HyperparameterOptimizationFacade: + return get_make_facade(digits_dataset, make_sgd) + + # -------------------------------------------------------------- # Test tell without ask # -------------------------------------------------------------- @@ -171,3 +175,23 @@ def test_multiple_asks_successively(make_facade): # Make sure the trials are different assert trial_info not in info info += [trial_info] + + +def ask_and_tell_after_optimization(): + from ..fixtures.datasets import DigitsDataset, Dataset + from ..fixtures.models import SGD + digits_dataset = DigitsDataset() + + def make_sgd(dataset: Dataset) -> SGD: + return SGD(dataset) + + make_facade = get_make_facade(digits_dataset, make_sgd) + test_ask_and_tell_after_optimization(make_facade) + + +def main(): + ask_and_tell_after_optimization() + + +if __name__ == '__main__': + main() diff --git a/tests/test_model/test_rf.py 
b/tests/test_model/test_rf.py index c81549a16e..8fd80e5d24 100644 --- a/tests/test_model/test_rf.py +++ b/tests/test_model/test_rf.py @@ -19,7 +19,7 @@ def _get_cs(n_dimensions): configspace = ConfigurationSpace(seed=0) for i in range(n_dimensions): - configspace.add_hyperparameter(UniformFloatHyperparameter("x%d" % i, 0, 1)) + configspace.add(UniformFloatHyperparameter("x%d" % i, 0, 1)) return configspace