From fc2fc14c8761c4617ca01c484aeb1af9217d20fe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20Sass?=
Date: Thu, 19 May 2022 19:46:54 +0200
Subject: [PATCH] Version 1.3.3

* Hotfix: Since the multi-objective implementation depends on normalized costs, it is now
  ensured that the cached costs are updated every time a new entry is added.
* Removed macOS-specific files.
* Added an entry point for the CLI.
* Added `ConfigSpace` to the known third parties so that import sorting is the same across
  different operating systems.
* Fixed bugs in the Makefile in which tools were specified incorrectly.
* Executed isort/black on examples and tests.
* Updated README.
* Fixed a problem in which the time was incremented twice before taking the log (#833).
* New wrapper for multi-objective models (base_uncorrelated_mo_model). Makes it easier to
  develop new multi-objective models.
* Raise an error if the acquisition function is incompatible with the EPM model.
* Restricted pynisher to versions below 1.0.0.

Co-authored-by: Difan Deng
Co-authored-by: Deyao Chen
Co-authored-by: BastianZim
---
 .DS_Store                                     | Bin 6148 -> 0 bytes
 .github/stale.yml                             |   2 +-
 .gitignore                                    |   5 +-
 Makefile                                      |   6 +-
 README.md                                     |   2 +-
 changelog.md                                  |  17 ++
 .../scripts/genericWrapper.py                 |   4 +-
 examples/python/plot_gb_non_deterministic.py  |   8 +-
 examples/python/plot_mlp_mf.py                |  12 +-
 .../python/plot_scalarized_multi_objective.py |   6 +-
 examples/python/plot_sgd_instances.py         |   9 +-
 .../python/plot_simple_multi_objective.py     |   2 +-
 examples/python/plot_svm_cv.py                |   6 +-
 examples/python/plot_svm_eips.py              | 122 ++++++++++++
 examples/python/plot_synthetic_function.py    |   1 -
 ...er_prior_mlp.py => plot_user_prior_mlp.py} |  22 ++-
 pyproject.toml                                |   1 +
 setup.py                                      |   5 +-
 smac/__init__.py                              |   2 +-
 smac/epm/base_uncorrelated_mo_model.py        | 183 ++++++++++++++++++
 smac/epm/uncorrelated_mo_rf_with_instances.py | 147 +++-----------
 smac/facade/smac_ac_facade.py                 |   8 +
 smac/runhistory/runhistory.py                 |  43 ++--
 smac/runhistory/runhistory2epm.py             |   3 +-
 tests/test_epm/test_gp_priors.py              |   2 +-
 .../test_uncorrelated_mo_rf_with_instances.py |   8 +-
 tests/test_facade/test_smac_facade.py         |   5 +-
 tests/test_multi_objective/test_schaffer.py   |   2 +-
 .../test_runhistory_multi_objective.py        |  57 +++++-
 tests/test_scenario/test_scenario.py          |   2 +-
 30 files changed, 502 insertions(+), 190 deletions(-)
 delete mode 100644 .DS_Store
 create mode 100644 examples/python/plot_svm_eips.py
 rename examples/python/{user_prior_mlp.py => plot_user_prior_mlp.py} (93%)
 create mode 100644 smac/epm/base_uncorrelated_mo_model.py

diff --git a/.DS_Store b/.DS_Store
deleted file mode 100644
index ce6e8b2f5316005cac6d8bf7eae76c7c2c447aa0..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 6148
zcmeHK%}T>S5Z-O0O({YS3WApfuN9lBZNW>Z^#zRRL8UgPXfS3=liEWmaNqw&s5WovD{Vpr^%ed9mKsh63gTs#)|kNj*u82
z28aP-U_}`)2ZCK&(fX-uVt^RedkNm|HvWNj<;GZ$Tt)AcOLQ&>y{Z<~HwF25LG!%@>Q2_yc;SvA_?jzgEsr>?V
Zi1Q5Q8gUjJSLuLs5m1CsM-2P|10Ov8O8Ec)

diff --git a/.github/stale.yml b/.github/stale.yml
index 8cfed16c9..f1508d0f2 100644
--- a/.github/stale.yml
+++ b/.github/stale.yml
@@ -17,6 +17,6 @@ markComment: >
   This issue has been automatically marked as stale because it has not had
   recent activity. It will be closed if no further activity occurs. Thank you
   for your contributions.
-  
+
 # Comment to post when closing a stale issue. Set to `false` to disable
 closeComment: false
\ No newline at end of file

diff --git a/.gitignore b/.gitignore
index ba354aa6a..10f0fcdd6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -133,4 +133,7 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
-*smac3-output_*
\ No newline at end of file
+*smac3-output_*
+
+# macOS files
+.DS_Store

diff --git a/Makefile b/Makefile
index 93c43b4ba..645483da0 100644
--- a/Makefile
+++ b/Makefile
@@ -49,7 +49,7 @@ check-black:
 
 check-isort:
 	$(ISORT) ${SOURCE_DIR} --check || :
-	$(BLACK) ${EXAMPLES_DIR} --check || :
+	$(ISORT) ${EXAMPLES_DIR} --check || :
 	$(ISORT) ${TESTS_DIR} --check || :
 
 check-pydocstyle:
@@ -60,7 +60,7 @@ check-mypy:
 
 check-flake8:
 	$(FLAKE8) ${SOURCE_DIR} || :
-	$(BLACK) ${EXAMPLES_DIR} --check || :
+	$(FLAKE8) ${EXAMPLES_DIR} || :
 	$(FLAKE8) ${TESTS_DIR} || :
 
 check: check-black check-isort check-mypy check-flake8 check-pydocstyle
@@ -75,7 +75,7 @@ format-black:
 
 format-isort:
 	$(ISORT) ${SOURCE_DIR}
-	$(BLACK) ${EXAMPLES_DIR}
+	$(ISORT) ${EXAMPLES_DIR}
 	$(ISORT) ${TESTS_DIR}
 
 format: format-black format-isort

diff --git a/README.md b/README.md
index d95f689f6..d06e47747 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,7 @@ arbitrary algorithms, including hyperparameter optimization of Machine Learning
 Bayesian Optimization in combination with an aggressive racing mechanism to
 efficiently decide which of two configurations performs better.
 
-SMAC3 is written in Python3 and continuously tested with Python 3.7, 3.8 and 3.9. Its Random
+SMAC3 is written in Python 3 and continuously tested with Python 3.7, 3.8, 3.9, and 3.10. Its Random
 Forest is written in C++. In the following, SMAC refers to SMAC3.
 
 [Documentation](https://automl.github.io/SMAC3)

diff --git a/changelog.md b/changelog.md
index 4b17cc650..65e156b2d 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,3 +1,20 @@
+# 1.3.3
+* Hotfix: Since the multi-objective implementation depends on normalized costs, it is now ensured that the
+cached costs are updated every time a new entry is added.
+* Removed macOS-specific files.
+* Added an entry point for the CLI.
+* Added `ConfigSpace` to the known third parties so that import sorting is the same across different
+operating systems.
+* Fixed bugs in the Makefile in which tools were specified incorrectly.
+* Executed isort/black on examples and tests.
+* Updated README.
+* Fixed a problem in which the time was incremented twice before taking the log (#833).
+* New wrapper for multi-objective models (base_uncorrelated_mo_model). Makes it easier to
+develop new multi-objective models.
+* Raise an error if the acquisition function is incompatible with the EPM model.
+* Restricted pynisher to versions below 1.0.0.
+
+
 # 1.3.2
 * Added stale bot support.
* If package version 0.0.0 via `get_distribution` is found, the version of the module is used diff --git a/examples/commandline/spear_qcp/target_algorithm/scripts/genericWrapper.py b/examples/commandline/spear_qcp/target_algorithm/scripts/genericWrapper.py index 9b46eb396..41fd27ba4 100755 --- a/examples/commandline/spear_qcp/target_algorithm/scripts/genericWrapper.py +++ b/examples/commandline/spear_qcp/target_algorithm/scripts/genericWrapper.py @@ -705,7 +705,9 @@ def __init__(self): self.required = [] self.args = Arguments() - def add_argument(self, parameter_name, dest, default=None, help="", type=str, required=False): # pylint: disable=built-in + def add_argument( + self, parameter_name, dest, default=None, help="", type=str, required=False + ): # pylint: disable=built-in """ adds arguments to parse from command line Args: diff --git a/examples/python/plot_gb_non_deterministic.py b/examples/python/plot_gb_non_deterministic.py index de3e144d3..2cc1beb4a 100644 --- a/examples/python/plot_gb_non_deterministic.py +++ b/examples/python/plot_gb_non_deterministic.py @@ -16,14 +16,14 @@ logging.basicConfig(level=logging.INFO) import numpy as np -from sklearn.datasets import make_hastie_10_2 -from sklearn.ensemble import GradientBoostingClassifier -from sklearn.model_selection import KFold, cross_val_score - from ConfigSpace.hyperparameters import ( UniformFloatHyperparameter, UniformIntegerHyperparameter, ) +from sklearn.datasets import make_hastie_10_2 +from sklearn.ensemble import GradientBoostingClassifier +from sklearn.model_selection import KFold, cross_val_score + from smac.configspace import ConfigurationSpace from smac.facade.smac_hpo_facade import SMAC4HPO from smac.scenario.scenario import Scenario diff --git a/examples/python/plot_mlp_mf.py b/examples/python/plot_mlp_mf.py index c22b0e951..676d475c4 100644 --- a/examples/python/plot_mlp_mf.py +++ b/examples/python/plot_mlp_mf.py @@ -22,18 +22,18 @@ import warnings -import numpy as np -from sklearn.datasets import load_digits -from sklearn.exceptions import ConvergenceWarning -from sklearn.model_selection import StratifiedKFold, cross_val_score -from sklearn.neural_network import MLPClassifier - import ConfigSpace as CS +import numpy as np from ConfigSpace.hyperparameters import ( CategoricalHyperparameter, UniformFloatHyperparameter, UniformIntegerHyperparameter, ) +from sklearn.datasets import load_digits +from sklearn.exceptions import ConvergenceWarning +from sklearn.model_selection import StratifiedKFold, cross_val_score +from sklearn.neural_network import MLPClassifier + from smac.configspace import ConfigurationSpace from smac.facade.smac_mf_facade import SMAC4MF from smac.scenario.scenario import Scenario diff --git a/examples/python/plot_scalarized_multi_objective.py b/examples/python/plot_scalarized_multi_objective.py index cc1e0ec8c..0854a4343 100644 --- a/examples/python/plot_scalarized_multi_objective.py +++ b/examples/python/plot_scalarized_multi_objective.py @@ -17,15 +17,15 @@ import matplotlib.pyplot as plt import numpy as np -from sklearn import datasets, svm -from sklearn.model_selection import cross_val_score - from ConfigSpace.conditions import InCondition from ConfigSpace.hyperparameters import ( CategoricalHyperparameter, UniformFloatHyperparameter, UniformIntegerHyperparameter, ) +from sklearn import datasets, svm +from sklearn.model_selection import cross_val_score + from smac.configspace import ConfigurationSpace from smac.facade.smac_hpo_facade import SMAC4HPO from smac.scenario.scenario import Scenario 
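The isort changes in these example diffs all converge on the same layout, sketched below with a hypothetical file (not part of the patch). The section order follows the `sections` list configured in pyproject.toml further down in this patch; registering `ConfigSpace` as a known third party is what makes the grouping identical on every operating system.

# Sketch of the import layout that isort now enforces in the examples:
import logging  # standard library imports come first

# Third-party imports follow, sorted case-insensitively, so ConfigSpace
# always lands before sklearn regardless of the operating system.
import numpy as np
from ConfigSpace.hyperparameters import UniformFloatHyperparameter
from sklearn import datasets

# First-party (smac) imports come last.
from smac.configspace import ConfigurationSpace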
diff --git a/examples/python/plot_sgd_instances.py b/examples/python/plot_sgd_instances.py
index 01837a07a..65b4c3ba8 100644
--- a/examples/python/plot_sgd_instances.py
+++ b/examples/python/plot_sgd_instances.py
@@ -21,15 +21,14 @@
 import warnings
 
 import numpy as np
-from sklearn import datasets
-from sklearn.exceptions import ConvergenceWarning
-from sklearn.linear_model import SGDClassifier
-from sklearn.model_selection import StratifiedKFold, cross_val_score
-
 from ConfigSpace.hyperparameters import (
     CategoricalHyperparameter,
     UniformFloatHyperparameter,
 )
+from sklearn import datasets
+from sklearn.exceptions import ConvergenceWarning
+from sklearn.linear_model import SGDClassifier
+from sklearn.model_selection import StratifiedKFold, cross_val_score
 
 # Import ConfigSpace and different types of parameters
 from smac.configspace import ConfigurationSpace

diff --git a/examples/python/plot_simple_multi_objective.py b/examples/python/plot_simple_multi_objective.py
index a7cdf0382..df87556bd 100644
--- a/examples/python/plot_simple_multi_objective.py
+++ b/examples/python/plot_simple_multi_objective.py
@@ -8,9 +8,9 @@
 __license__ = "3-clause BSD"
 
 import numpy as np
+from ConfigSpace.hyperparameters import UniformFloatHyperparameter
 from matplotlib import pyplot as plt
 
-from ConfigSpace.hyperparameters import UniformFloatHyperparameter
 from smac.configspace import ConfigurationSpace
 from smac.facade.smac_bb_facade import SMAC4BB
 from smac.scenario.scenario import Scenario

diff --git a/examples/python/plot_svm_cv.py b/examples/python/plot_svm_cv.py
index 4234f9e82..0e45ad33f 100644
--- a/examples/python/plot_svm_cv.py
+++ b/examples/python/plot_svm_cv.py
@@ -16,15 +16,15 @@
 logging.basicConfig(level=logging.INFO)
 
 import numpy as np
-from sklearn import datasets, svm
-from sklearn.model_selection import cross_val_score
-
 from ConfigSpace.conditions import InCondition
 from ConfigSpace.hyperparameters import (
     CategoricalHyperparameter,
     UniformFloatHyperparameter,
     UniformIntegerHyperparameter,
 )
+from sklearn import datasets, svm
+from sklearn.model_selection import cross_val_score
+
 from smac.configspace import ConfigurationSpace
 from smac.facade.smac_hpo_facade import SMAC4HPO
 from smac.scenario.scenario import Scenario

diff --git a/examples/python/plot_svm_eips.py b/examples/python/plot_svm_eips.py
new file mode 100644
index 000000000..56a83f208
--- /dev/null
+++ b/examples/python/plot_svm_eips.py
@@ -0,0 +1,122 @@
+"""
+SVM with EIPS as acquisition function
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+An example to optimize a simple SVM on the IRIS benchmark with the EIPS (EI per second)
+acquisition function. Since EIPS requires two objectives, namely the EI values and the predicted
+time used for the configurations,
+we need to fit the data with a multi-objective model.
+"""
+
+import logging
+
+logging.basicConfig(level=logging.INFO)
+
+import numpy as np
+from sklearn import datasets, svm
+from sklearn.model_selection import cross_val_score
+
+from ConfigSpace.hyperparameters import CategoricalHyperparameter, UniformFloatHyperparameter
+
+from smac.configspace import ConfigurationSpace
+from smac.facade.smac_ac_facade import SMAC4AC
+
+# Import SMAC-utilities
+from smac.scenario.scenario import Scenario
+
+# EIPS related
+from smac.optimizer.acquisition import EIPS
+from smac.runhistory.runhistory2epm import RunHistory2EPM4EIPS
+from smac.epm.uncorrelated_mo_rf_with_instances import UncorrelatedMultiObjectiveRandomForestWithInstances
+
+__copyright__ = "Copyright 2021, AutoML.org Freiburg-Hannover"
+__license__ = "3-clause BSD"
+
+iris = datasets.load_iris()
+
+
+# Target Algorithm
+def svm_from_cfg(cfg):
+    """Creates an SVM based on a configuration and evaluates it on the
+    iris dataset using cross-validation. Note that the random seed is fixed here.
+
+    Parameters
+    ----------
+    cfg: Configuration (ConfigSpace.ConfigurationSpace.Configuration)
+        Configuration containing the parameters.
+        Configurations are indexable!
+
+    Returns
+    -------
+    The cross-validated mean error (1 - mean accuracy) of the SVM on the loaded dataset.
+    """
+    # For deactivated parameters, the configuration stores None-values.
+    # This is not accepted by the SVM, so we remove them.
+    cfg = {k: cfg[k] for k in cfg if cfg[k]}
+    # And for gamma, we set it to a fixed value or to "auto" (if used)
+    if "gamma" in cfg:
+        cfg["gamma"] = cfg["gamma_value"] if cfg["gamma"] == "value" else "auto"
+        cfg.pop("gamma_value", None)  # Remove "gamma_value"
+
+    clf = svm.SVC(**cfg, random_state=42)
+
+    scores = cross_val_score(clf, iris.data, iris.target, cv=5)
+    return 1 - np.mean(scores)  # Minimize!
+
+
+if __name__ == "__main__":
+    # Build Configuration Space which defines all parameters and their ranges
+    cs = ConfigurationSpace()
+
+    # We define a few possible types of SVM-kernels and add them as "kernel" to our cs
+    kernel = CategoricalHyperparameter("kernel", ["linear", "rbf", "poly", "sigmoid"], default_value="poly")
+    cs.add_hyperparameter(kernel)
+
+    # There are some hyperparameters shared by all kernels
+    C = UniformFloatHyperparameter("C", 0.001, 1000.0, default_value=1.0, log=True)
+    shrinking = CategoricalHyperparameter("shrinking", [True, False], default_value=True)
+    cs.add_hyperparameters([C, shrinking])
+
+    # Scenario object
+    scenario = Scenario(
+        {
+            "run_obj": "quality",  # we optimize quality (alternatively runtime)
+            "runcount-limit": 50,  # max. number of function evaluations
+            "cs": cs,  # configuration space
+            "deterministic": True,
+        }
+    )
+
+    # Example call of the function
+    # It returns: Status, Cost, Runtime, Additional Infos
+    def_value = svm_from_cfg(cs.get_default_configuration())
+    print("Default Value: %.2f" % def_value)
+
+    # Optimize, using a SMAC-object
+    print("Optimizing! Depending on your machine, this might take a few minutes.")
+
+    # Besides the kwargs used for initializing UncorrelatedMultiObjectiveRandomForestWithInstances,
+    # we also need kwargs for initializing the models inside UncorrelatedMultiObjectiveModel
+    model_kwargs = {"target_names": ["loss", "time"], "model_kwargs": {"seed": 1}}
+    smac = SMAC4AC(
+        scenario=scenario,
+        model=UncorrelatedMultiObjectiveRandomForestWithInstances,
+        rng=np.random.RandomState(42),
+        model_kwargs=model_kwargs,
+        tae_runner=svm_from_cfg,
+        acquisition_function=EIPS,
+        runhistory2epm=RunHistory2EPM4EIPS,
+    )
+
+    incumbent = smac.optimize()
+
+    inc_value = svm_from_cfg(incumbent)
+    print("Optimized Value: %.2f" % (inc_value))
+
+    # We can also validate our results (though this makes a lot more sense with instances)
+    smac.validate(
+        config_mode="inc",  # We can choose which configurations to evaluate
+        # instance_mode='train+test',  # Defines what instances to validate
+        repetitions=100,  # Ignored, unless you set "deterministic" to "false" in the scenario above
+        n_jobs=1,
+    )  # How many cores to use in parallel for optimization

diff --git a/examples/python/plot_synthetic_function.py b/examples/python/plot_synthetic_function.py
index af493364f..8f45c864d 100644
--- a/examples/python/plot_synthetic_function.py
+++ b/examples/python/plot_synthetic_function.py
@@ -16,7 +16,6 @@
 logging.basicConfig(level=logging.INFO)
 
 import numpy as np
-
 from ConfigSpace.hyperparameters import UniformFloatHyperparameter
 
 # Import ConfigSpace and different types of parameters

diff --git a/examples/python/user_prior_mlp.py b/examples/python/plot_user_prior_mlp.py
similarity index 93%
rename from examples/python/user_prior_mlp.py
rename to examples/python/plot_user_prior_mlp.py
index 64e431208..390fb64f5 100644
--- a/examples/python/user_prior_mlp.py
+++ b/examples/python/plot_user_prior_mlp.py
@@ -19,26 +19,25 @@
 logging.basicConfig(level=logging.INFO)
 import warnings
 
-import numpy as np
 import ConfigSpace as CS
+import numpy as np
 from ConfigSpace.hyperparameters import (
-    CategoricalHyperparameter,
-    UniformIntegerHyperparameter,
     BetaIntegerHyperparameter,
+    CategoricalHyperparameter,
     NormalFloatHyperparameter,
+    UniformIntegerHyperparameter,
 )
-
 from sklearn.datasets import load_digits
 from sklearn.exceptions import ConvergenceWarning
-from sklearn.model_selection import cross_val_score, StratifiedKFold
+from sklearn.model_selection import StratifiedKFold, cross_val_score
 from sklearn.neural_network import MLPClassifier
 
 from smac.configspace import ConfigurationSpace
-from smac.facade.smac_hpo_facade import SMAC4HPO
 from smac.facade.smac_bb_facade import SMAC4BB
-from smac.scenario.scenario import Scenario
+from smac.facade.smac_hpo_facade import SMAC4HPO
 from smac.initial_design.random_configuration_design import RandomConfigurations
+from smac.scenario.scenario import Scenario
 
 __copyright__ = "Copyright 2021, AutoML.org Freiburg-Hannover"
 __license__ = "3-clause BSD"
@@ -135,7 +134,8 @@ def mlp_from_cfg(cfg, seed):
         }
     )
 
-    # The rate at which SMAC forgets the prior. The higher the value, the more the prior is considered.
+    # The rate at which SMAC forgets the prior.
+    # The higher the value, the more the prior is considered.
     # Defaults to n_iterations / 10
     user_prior_kwargs = {"decay_beta": 1.5}
@@ -144,9 +144,11 @@ def mlp_from_cfg(cfg, seed):
         scenario=scenario,
         rng=np.random.RandomState(42),
         tae_runner=mlp_from_cfg,
-        user_priors=True,  # This flag is required to conduct the optimisation using priors over the optimum
+        # This flag is required to conduct the optimisation using priors over the optimum
+        user_priors=True,
        user_prior_kwargs=user_prior_kwargs,
-        initial_design=RandomConfigurations,  # Using random configurations will cause the initialization to be samples drawn from the prior
+        # Using random configurations causes the initial design to be sampled from the prior
+        initial_design=RandomConfigurations,
     )
 
     # Example call of the function with default values

diff --git a/pyproject.toml b/pyproject.toml
index bdeecd144..a2d47729e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,6 +34,7 @@ src_paths = ["smac", "tests"]
 known_types = ["typing", "abc"]  # We put these in their own section "types"
 known_test = ["tests"]
 known_first_party = ["smac"]
+known_third_party = ["ConfigSpace"]
 sections = [
     "FUTURE",
     "TYPES",

diff --git a/setup.py b/setup.py
index 63244a6df..f22d02acc 100644
--- a/setup.py
+++ b/setup.py
@@ -73,7 +73,7 @@ def read_file(filepath: str) -> str:
         "numpy>=1.7.1",
         "scipy>=1.7.0",
         "psutil",
-        "pynisher>=0.4.1",
+        "pynisher<1.0.0",
         "ConfigSpace>=0.5.0",
         "joblib",
         "scikit-learn>=0.22.0",
@@ -85,6 +85,9 @@ def read_file(filepath: str) -> str:
     extras_require=extras_require,
     test_suite="pytest",
     platforms=["Linux"],
+    entry_points={
+        "console_scripts": ["smac = smac.smac_cli:cmd_line_call"],
+    },
     classifiers=[
         "Programming Language :: Python :: 3.7",
         "Programming Language :: Python :: 3.8",

diff --git a/smac/__init__.py b/smac/__init__.py
index 288850a3d..0993ca306 100644
--- a/smac/__init__.py
+++ b/smac/__init__.py
@@ -22,7 +22,7 @@
     Matthias Feurer, André Biedenkapp, Difan Deng, Carolin Benjamins, Tim Ruhkopf, René Sass and Frank Hutter
 """
 
-version = "1.3.2"
+version = "1.3.3"
 
 if os.name != "posix":

diff --git a/smac/epm/base_uncorrelated_mo_model.py b/smac/epm/base_uncorrelated_mo_model.py
new file mode 100644
index 000000000..94815cdbe
--- /dev/null
+++ b/smac/epm/base_uncorrelated_mo_model.py
@@ -0,0 +1,183 @@
+from abc import abstractmethod
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+
+from smac.configspace import ConfigurationSpace
+from smac.epm.base_epm import AbstractEPM
+
+__copyright__ = "Copyright 2021, AutoML.org Freiburg-Hannover"
+__license__ = "3-clause BSD"
+
+
+class UncorrelatedMultiObjectiveModel(AbstractEPM):
+    """Wrapper for the surrogate models to predict multiple targets.
+
+    Only a list with the target names and the types array for the
+    underlying model are mandatory. All other hyperparameters of the
+    model can be passed via kwargs. Consult the documentation of
+    the corresponding model for the hyperparameters and their meanings.
+
+    Parameters
+    ----------
+    target_names : list
+        List of str, each entry is the name of one target dimension. Length
+        of the list will be ``n_objectives``.
+    types : List[int]
+        Specifies the number of categorical values of an input dimension where
+        the i-th entry corresponds to the i-th input dimension. Let's say we
+        have 2 dimensions, where the first dimension consists of 3 different
+        categorical choices and the second dimension is continuous; then we
+        have to pass [3, 0]. Note that we count starting from 0.
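+        For example (an illustrative pairing with ``bounds`` below),
+        ``types = [3, 0]`` together with ``bounds = [(3, np.nan), (0.0, 1.0)]``
+        describes one categorical dimension with 3 choices followed by one
+        continuous dimension on [0, 1].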
+    bounds : List[Tuple[float, float]]
+        bounds of input dimensions: (lower, upper) for continuous dims; (n_cat, np.nan)
+        for categorical dims
+    instance_features : np.ndarray (I, K)
+        Contains the K dimensional instance features of I different instances
+    pca_components : float
+        Number of components to keep when using PCA to reduce dimensionality of instance features.
+        Requires to set n_feats (> pca_dims).
+    model_kwargs : Optional[Dict[str, Any]]
+        Arguments for initializing the estimators
+
+    Attributes
+    ----------
+    target_names:
+        target names
+    num_targets: int
+        number of targets
+    estimators: List[AbstractEPM]
+        a list of estimators predicting different target values
+    """
+
+    def __init__(
+        self,
+        target_names: List[str],
+        configspace: ConfigurationSpace,
+        types: List[int],
+        bounds: List[Tuple[float, float]],
+        seed: int,
+        instance_features: Optional[np.ndarray] = None,
+        pca_components: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        super().__init__(
+            configspace=configspace,
+            bounds=bounds,
+            types=types,
+            seed=seed,
+            instance_features=instance_features,
+            pca_components=pca_components,
+        )
+        if model_kwargs is None:
+            model_kwargs = {}
+        self.target_names = target_names
+        self.num_targets = len(self.target_names)
+        self.estimators: List[AbstractEPM] = self.construct_estimators(configspace, types, bounds, model_kwargs)
+
+    @abstractmethod
+    def construct_estimators(
+        self,
+        configspace: ConfigurationSpace,
+        types: List[int],
+        bounds: List[Tuple[float, float]],
+        model_kwargs: Dict[str, Any],
+    ) -> List[AbstractEPM]:
+        """
+        Construct a list of estimators. The number of estimators equals `self.num_targets`.
+        Parameters
+        ----------
+        configspace : ConfigurationSpace
+            Configuration space to tune for.
+        types : List[int]
+            Specifies the number of categorical values of an input dimension where
+            the i-th entry corresponds to the i-th input dimension. Let's say we
+            have 2 dimensions, where the first dimension consists of 3 different
+            categorical choices and the second dimension is continuous; then we
+            have to pass [3, 0]. Note that we count starting from 0.
+        bounds : List[Tuple[float, float]]
+            bounds of input dimensions: (lower, upper) for continuous dims; (n_cat, np.nan) for categorical dims
+        model_kwargs : Dict[str, Any]
+            model kwargs for initializing the models
+        Returns
+        -------
+        estimators: List[AbstractEPM]
+            A list of estimators
+        """
+        raise NotImplementedError
+
+    def _train(self, X: np.ndarray, Y: np.ndarray) -> "UncorrelatedMultiObjectiveModel":
+        """Trains the models on X and y.
+
+        Parameters
+        ----------
+        X : np.ndarray [n_samples, n_features (config + instance features)]
+            Input data points.
+        Y : np.ndarray [n_samples, n_objectives]
+            The corresponding target values. n_objectives must match the
+            number of target names specified in the constructor.
+
+        Returns
+        -------
+        self
+        """
+        if len(self.estimators) == 0:
+            raise ValueError("The list of estimators for this model is empty!")
+        for i, estimator in enumerate(self.estimators):
+            estimator.train(X, Y[:, i])
+
+        return self
+
+    def _predict(self, X: np.ndarray, cov_return_type: Optional[str] = "diagonal_cov") -> Tuple[np.ndarray, np.ndarray]:
+        """Predict means and variances for given X.
+
+        Parameters
+        ----------
+        X : np.ndarray of shape = [n_samples, n_features (config + instance
+            features)]
+        cov_return_type: typing.Optional[str]
+            Specifies what to return along with the mean. Refer to ``predict()`` for more information.
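+            For this wrapper only the default ``"diagonal_cov"`` is supported;
+            any other value raises a ValueError.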
+
+        Returns
+        -------
+        means : np.ndarray of shape = [n_samples, n_objectives]
+            Predictive mean
+        vars : np.ndarray of shape = [n_samples, n_objectives]
+            Predictive variance
+        """
+        if cov_return_type != "diagonal_cov":
+            raise ValueError("'cov_return_type' can only take 'diagonal_cov' for this model")
+
+        mean = np.zeros((X.shape[0], self.num_targets))
+        var = np.zeros((X.shape[0], self.num_targets))
+        for i, estimator in enumerate(self.estimators):
+            m, v = estimator.predict(X)
+            assert v is not None  # please mypy
+            mean[:, i] = m.flatten()
+            var[:, i] = v.flatten()
+        return mean, var
+
+    def predict_marginalized_over_instances(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        """Predict mean and variance marginalized over all instances.
+
+        Returns the predictive mean and variance marginalised over all
+        instances for a set of configurations.
+
+        Parameters
+        ----------
+        X : np.ndarray of shape = [n_samples, n_features (config)]
+
+        Returns
+        -------
+        means : np.ndarray of shape = [n_samples, n_objectives]
+            Predictive mean
+        vars : np.ndarray of shape = [n_samples, n_objectives]
+            Predictive variance
+        """
+        mean = np.zeros((X.shape[0], self.num_targets))
+        var = np.zeros((X.shape[0], self.num_targets))
+        for i, estimator in enumerate(self.estimators):
+            m, v = estimator.predict_marginalized_over_instances(X)
+            mean[:, i] = m.flatten()
+            var[:, i] = v.flatten()
+        return mean, var

diff --git a/smac/epm/uncorrelated_mo_rf_with_instances.py b/smac/epm/uncorrelated_mo_rf_with_instances.py
index c2616936e..f628050d5 100644
--- a/smac/epm/uncorrelated_mo_rf_with_instances.py
+++ b/smac/epm/uncorrelated_mo_rf_with_instances.py
@@ -1,150 +1,49 @@
-from typing import Any, Dict, List, Optional, Tuple
-
-import numpy as np
+from typing import Any, Dict, List, Tuple
 
 from smac.configspace import ConfigurationSpace
 from smac.epm.base_epm import AbstractEPM
+from smac.epm.base_uncorrelated_mo_model import UncorrelatedMultiObjectiveModel
 from smac.epm.rf_with_instances import RandomForestWithInstances
 
 __copyright__ = "Copyright 2021, AutoML.org Freiburg-Hannover"
 __license__ = "3-clause BSD"
 
 
-class UncorrelatedMultiObjectiveRandomForestWithInstances(AbstractEPM):
+class UncorrelatedMultiObjectiveRandomForestWithInstances(UncorrelatedMultiObjectiveModel):
     """Wrapper for the random forest to predict multiple targets.
 
-    Only the a list with the target names and the types array for the
+    Only a list with the target names and the types array for the
     underlying forest model are mandatory. All other hyperparameters to
     the random forest can be passed via kwargs. Consult the documentation of
     the random forest for the hyperparameters and their meanings.
-
-
-    Parameters
-    ----------
-    target_names : list
-        List of str, each entry is the name of one target dimension. Length
-        of the list will be ``n_objectives``.
-    types : List[int]
-        Specifies the number of categorical values of an input dimension where
-        the i-th entry corresponds to the i-th input dimension. Let's say we
-        have 2 dimension where the first dimension consists of 3 different
-        categorical choices and the second dimension is continuous than we
-        have to pass [3, 0]. Note that we count starting from 0.
-    bounds : List[Tuple[float, float]]
-        bounds of input dimensions: (lower, uppper) for continuous dims; (n_cat, np.nan) for categorical dims
-    instance_features : np.ndarray (I, K)
-        Contains the K dimensional instance features of the I different instances
-    pca_components : float
-        Number of components to keep when using PCA to reduce dimensionality of instance features. Requires to
-        set n_feats (> pca_dims).
-
-
-    Attributes
-    ----------
-    target_names
-    num_targets
-    estimators
     """
 
-    def __init__(
+    def construct_estimators(
         self,
-        target_names: List[str],
         configspace: ConfigurationSpace,
         types: List[int],
         bounds: List[Tuple[float, float]],
-        seed: int,
-        rf_kwargs: Optional[Dict[str, Any]] = None,
-        instance_features: Optional[np.ndarray] = None,
-        pca_components: Optional[int] = None,
-    ) -> None:
-        super().__init__(
-            configspace=configspace,
-            bounds=bounds,
-            types=types,
-            seed=seed,
-            instance_features=instance_features,
-            pca_components=pca_components,
-        )
-        if rf_kwargs is None:
-            rf_kwargs = {}
-
-        self.target_names = target_names
-        self.num_targets = len(self.target_names)
-        print(seed, rf_kwargs)
-        self.estimators = [
-            RandomForestWithInstances(configspace, types, bounds, **rf_kwargs) for _ in range(self.num_targets)
-        ]
-
-    def _train(self, X: np.ndarray, Y: np.ndarray) -> "UncorrelatedMultiObjectiveRandomForestWithInstances":
-        """Trains the random forest on X and y.
-
-        Parameters
-        ----------
-        X : np.ndarray [n_samples, n_features (config + instance features)]
-            Input data points.
-        Y : np.ndarray [n_samples, n_objectives]
-            The corresponding target values. n_objectives must match the
-            number of target names specified in the constructor.
-
-        Returns
-        -------
-        self
+        model_kwargs: Dict[str, Any],
+    ) -> List[AbstractEPM]:
+        """
-        for i, estimator in enumerate(self.estimators):
-            estimator.train(X, Y[:, i])
-
-        return self
-
-    def _predict(self, X: np.ndarray, cov_return_type: Optional[str] = "diagonal_cov") -> Tuple[np.ndarray, np.ndarray]:
-        """Predict means and variances for given X.
-
+        Construct a list of estimators. The number of estimators equals `self.num_targets`.
         Parameters
         ----------
-        X : np.ndarray of shape = [n_samples, n_features (config + instance
-            features)]
-        cov_return_type: typing.Optional[str]
-            Specifies what to return along with the mean. Refer ``predict()`` for more information.
-
-        Returns
-        -------
-        means : np.ndarray of shape = [n_samples, n_objectives]
-            Predictive mean
-        vars : np.ndarray of shape = [n_samples, n_objectives]
-            Predictive variance
-        """
-        if cov_return_type != "diagonal_cov":
-            raise ValueError("'cov_return_type' can only take 'diagonal_cov' for this model")
-
-        mean = np.zeros((X.shape[0], self.num_targets))
-        var = np.zeros((X.shape[0], self.num_targets))
-        for i, estimator in enumerate(self.estimators):
-            m, v = estimator.predict(X)
-            assert v is not None  # please mypy
-            mean[:, i] = m.flatten()
-            var[:, i] = v.flatten()
-        return mean, var
-
-    def predict_marginalized_over_instances(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
-        """Predict mean and variance marginalized over all instances.
-
-        Returns the predictive mean and variance marginalised over all
-        instances for a set of configurations.
-
-        Parameters
-        ----------
-        X : np.ndarray of shape = [n_features (config), ]
-
+        configspace : ConfigurationSpace
+            Configuration space to tune for.
+        types : List[int]
+            Specifies the number of categorical values of an input dimension where
+            the i-th entry corresponds to the i-th input dimension. Let's say we
+            have 2 dimensions, where the first dimension consists of 3 different
+            categorical choices and the second dimension is continuous; then we
+            have to pass [3, 0]. Note that we count starting from 0.
+        bounds : List[Tuple[float, float]]
+            bounds of input dimensions: (lower, upper) for continuous dims; (n_cat, np.nan) for categorical dims
+        model_kwargs : Dict[str, Any]
+            model kwargs for initializing the models
         Returns
         -------
-        means : np.ndarray of shape = [n_samples, n_objectives]
-            Predictive mean
-        vars : np.ndarray of shape = [n_samples, n_objectives]
-            Predictive variance
+        estimators: List[AbstractEPM]
+            A list of Random Forests
         """
-        mean = np.zeros((X.shape[0], self.num_targets))
-        var = np.zeros((X.shape[0], self.num_targets))
-        for i, estimator in enumerate(self.estimators):
-            m, v = estimator.predict_marginalized_over_instances(X)
-            mean[:, i] = m.flatten()
-            var[:, i] = v.flatten()
-        return mean, var
+        return [RandomForestWithInstances(configspace, types, bounds, **model_kwargs) for _ in range(self.num_targets)]

diff --git a/smac/facade/smac_ac_facade.py b/smac/facade/smac_ac_facade.py
index b1f2c1bb2..c6173fbab 100644
--- a/smac/facade/smac_ac_facade.py
+++ b/smac/facade/smac_ac_facade.py
@@ -9,6 +9,7 @@
 
 from smac.configspace import Configuration
 from smac.epm.base_epm import AbstractEPM
+from smac.epm.base_uncorrelated_mo_model import UncorrelatedMultiObjectiveModel
 
 # epm
 from smac.epm.rf_with_instances import RandomForestWithInstances
@@ -30,6 +31,7 @@
 from smac.intensification.successive_halving import SuccessiveHalving
 from smac.optimizer.acquisition import (
     EI,
+    EIPS,
     AbstractAcquisitionFunction,
     IntegratedAcquisitionFunction,
     LogEI,
@@ -364,6 +366,12 @@ def __init__(
                 "Argument acquisition_function must be None or an object implementing the "
                 "AbstractAcquisitionFunction, not %s." % type(acquisition_function)
             )
+        if isinstance(acquisition_function_instance, EIPS) and not isinstance(
+            model_instance, UncorrelatedMultiObjectiveModel
+        ):
+            raise TypeError(
+                "If the acquisition function is EIPS, the surrogate model must support multi-objective prediction!"
+            )
         if integrate_acquisition_function:
             acquisition_function_instance = IntegratedAcquisitionFunction(
                 acquisition_function=acquisition_function_instance,  # type: ignore

diff --git a/smac/runhistory/runhistory.py b/smac/runhistory/runhistory.py
index 616cfbc13..a84a06a7c 100644
--- a/smac/runhistory/runhistory.py
+++ b/smac/runhistory/runhistory.py
@@ -354,11 +354,18 @@ def _update_objective_bounds(self) -> None:
         self.objective_bounds += [(min_v, max_v)]
 
     def _add(self, k: RunKey, v: RunValue, status: StatusType, origin: DataOrigin) -> None:
-        """Actual function to add new entry to data structures."""
+        """
+        Actual function to add a new entry to the data structures.
+
+        Note
+        ----
+        This method always calls `update_all_costs` in the
+        multi-objective setting.
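+
+        For illustration, with a hypothetical runhistory ``rh`` and two configs
+        ``c1``/``c2`` (the numbers mirror the multi-objective runhistory tests):
+
+        >>> rh.add(config=c1, cost=[50, 100], time=20, status=StatusType.SUCCESS)
+        >>> rh.get_cost(c1)  # a single entry normalizes to 1.0
+        1.0
+        >>> rh.add(config=c2, cost=[150, 50], time=30, status=StatusType.SUCCESS)
+        >>> rh.get_cost(c1)  # the cached cost of c1 has been re-normalized
+        0.5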
+        """
         self.data[k] = v
         self.external[k] = origin
 
-        # Update objective bounds
+        # Update objective bounds based on raw data
         self._update_objective_bounds()
 
         # Capped data is added above
@@ -377,17 +384,24 @@ def _add(self, k: RunKey, v: RunValue, status: StatusType, origin: DataOrigin) -
             # append new budget to existing inst-seed-key dict
             self._configid_to_inst_seed_budget[k.config_id][is_k].append(k.budget)
 
-        # if budget is used, then update cost instead of incremental updates
-        if not self.overwrite_existing_runs and k.budget == 0:
-            # assumes an average across runs as cost function aggregation, this is used for algorithm configuration
-            # (incremental updates are used to save time as getting the cost for > 100 instances is high)
-            self.incremental_update_cost(self.ids_config[k.config_id], v.cost)
+        # Update costs in the multi-objective setting so that all costs are
+        # normalized in the same way.
+        # TODO: This is only a temporary solution because the caching is not used.
+        if self.num_obj > 1:
+            self.update_all_costs()
         else:
-            # this is when budget > 0 (only successive halving and hyperband so far)
-            self.update_cost(config=self.ids_config[k.config_id])
-            if k.budget > 0:
-                if self.num_runs_per_config[k.config_id] != 1:  # This is updated in update_cost
-                    raise ValueError("This should not happen!")
+            # if budget is used, then update cost instead of incremental updates
+            if not self.overwrite_existing_runs and k.budget == 0:
+                # assumes an average across runs as cost function aggregation, this is used for
+                # algorithm configuration (incremental updates are used to save time as getting the
+                # cost for > 100 instances is high)
+                self.incremental_update_cost(self.ids_config[k.config_id], v.cost)
+            else:
+                # this is when budget > 0 (only successive halving and hyperband so far)
+                self.update_cost(config=self.ids_config[k.config_id])
+                if k.budget > 0:
+                    if self.num_runs_per_config[k.config_id] != 1:  # This is updated in update_cost
+                        raise ValueError("This should not happen!")
 
     def update_cost(self, config: Configuration) -> None:
         """Stores the performance of a configuration across the instances in self.cost_per_config
@@ -411,6 +425,11 @@ def update_cost(self, config: Configuration) -> None:
         all_inst_seed_budgets = list(dict.fromkeys(self.get_runs_for_config(config, only_max_observed_budget=False)))
         self._min_cost_per_config[config_id] = self.min_cost(config, all_inst_seed_budgets)
 
+    def update_all_costs(self) -> None:
+        """Update all costs in the runhistory."""
+        for config in self.ids_config.values():
+            self.update_cost(config)
+
     def incremental_update_cost(self, config: Configuration, cost: Union[np.ndarray, list, float, int]) -> None:
         """Incrementally updates the performance of a configuration by using a moving average.
 

diff --git a/smac/runhistory/runhistory2epm.py b/smac/runhistory/runhistory2epm.py
index f46d6a598..768e41345 100644
--- a/smac/runhistory/runhistory2epm.py
+++ b/smac/runhistory/runhistory2epm.py
@@ -704,7 +704,7 @@ def _build_matrix(
             else:
                 y[row, 0] = run.cost
 
-            y[row, 1] = 1 + run.time
+            y[row, 1] = run.time
 
         y_transformed = self.transform_response_values(values=y)
 
@@ -725,5 +725,6 @@ def transform_response_values(self, values: np.ndarray) -> np.ndarray:
         -------
         np.ndarray
         """
+        # We need to ensure that time remains positive after the log transform.
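+        # A runtime of 0 would give log(0) = -inf, so the values are shifted by
+        # one before taking the log. The former `1 + run.time` in `_build_matrix`
+        # above is gone; otherwise the shift would be applied twice (#833).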
values[:, 1] = np.log(1 + values[:, 1]) return values diff --git a/tests/test_epm/test_gp_priors.py b/tests/test_epm/test_gp_priors.py index 702cf47e9..5bc889bc0 100644 --- a/tests/test_epm/test_gp_priors.py +++ b/tests/test_epm/test_gp_priors.py @@ -1,5 +1,5 @@ -from functools import partial import unittest +from functools import partial import numpy as np import scipy.optimize diff --git a/tests/test_epm/test_uncorrelated_mo_rf_with_instances.py b/tests/test_epm/test_uncorrelated_mo_rf_with_instances.py index 2e97344c5..e3d44cee1 100644 --- a/tests/test_epm/test_uncorrelated_mo_rf_with_instances.py +++ b/tests/test_epm/test_uncorrelated_mo_rf_with_instances.py @@ -41,7 +41,7 @@ def test_train_and_predict_with_rf(self): (0, np.nan), ], seed=1, - rf_kwargs={"seed": 1}, + model_kwargs={"seed": 1}, pca_components=5, ) self.assertEqual(model.estimators[0].seed, 1) @@ -52,6 +52,10 @@ def test_train_and_predict_with_rf(self): self.assertEqual(m.shape, (10, 2)) self.assertEqual(v.shape, (10, 2)) + m, v = model.predict_marginalized_over_instances(X[10:]) + self.assertEqual(m.shape, (10, 2)) + self.assertEqual(v.shape, (10, 2)) + # We need to track how often the base model was called! @mock.patch.object(RandomForestWithInstances, "predict") def test_predict_mocked(self, rf_mock): @@ -87,7 +91,7 @@ def __call__(self, X): (0, np.nan), ], seed=1, - rf_kwargs={"seed": 1}, + model_kwargs={"seed": 1}, ) model.train(X[:10], Y[:10]) diff --git a/tests/test_facade/test_smac_facade.py b/tests/test_facade/test_smac_facade.py index 0cd9efc89..7f19695e6 100644 --- a/tests/test_facade/test_smac_facade.py +++ b/tests/test_facade/test_smac_facade.py @@ -317,7 +317,7 @@ def test_init_EIPS_as_arguments(self): smbo = SMAC4AC( self.scenario, model=UncorrelatedMultiObjectiveRandomForestWithInstances, - model_kwargs={"target_names": ["a", "b"], "rf_kwargs": {"seed": 1}}, + model_kwargs={"target_names": ["a", "b"], "model_kwargs": {"seed": 1}}, acquisition_function=EIPS, runhistory2epm=RunHistory2EPM4EIPS, ).solver @@ -332,6 +332,9 @@ def test_init_EIPS_as_arguments(self): ) self.assertIsInstance(smbo.epm_chooser.rh2EPM, RunHistory2EPM4EIPS) + with self.assertRaisesRegex(TypeError, "the surrogate model must support multi-objective prediction!"): + SMAC4AC(self.scenario, acquisition_function=EIPS, runhistory2epm=RunHistory2EPM4EIPS) + #################################################################################################################### # Other tests... 
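In practice, the new facade guard tested above behaves as in this sketch (`scenario` is assumed to be a valid Scenario object; without an explicit multi-objective model, SMAC4AC builds its default single-objective random forest, which EIPS now rejects):

from smac.facade.smac_ac_facade import SMAC4AC
from smac.optimizer.acquisition import EIPS
from smac.runhistory.runhistory2epm import RunHistory2EPM4EIPS

try:
    SMAC4AC(scenario=scenario, acquisition_function=EIPS, runhistory2epm=RunHistory2EPM4EIPS)
except TypeError as err:
    # "If the acquisition function is EIPS, the surrogate model must
    # support multi-objective prediction!"
    print(err)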
diff --git a/tests/test_multi_objective/test_schaffer.py b/tests/test_multi_objective/test_schaffer.py
index e4be292b2..cffa8fddd 100644
--- a/tests/test_multi_objective/test_schaffer.py
+++ b/tests/test_multi_objective/test_schaffer.py
@@ -8,10 +8,10 @@
 from matplotlib import pyplot as plt
 
 from smac.configspace import ConfigurationSpace
+from smac.facade.roar_facade import ROAR
 from smac.facade.smac_ac_facade import SMAC4AC
 from smac.facade.smac_bb_facade import SMAC4BB
 from smac.facade.smac_hpo_facade import SMAC4HPO
-from smac.facade.roar_facade import ROAR
 from smac.optimizer.multi_objective.parego import ParEGO
 from smac.scenario.scenario import Scenario

diff --git a/tests/test_runhistory/test_runhistory_multi_objective.py b/tests/test_runhistory/test_runhistory_multi_objective.py
index cf026fdb4..e21c5f2bc 100644
--- a/tests/test_runhistory/test_runhistory_multi_objective.py
+++ b/tests/test_runhistory/test_runhistory_multi_objective.py
@@ -135,6 +135,54 @@ def test_add_multiple_times(self):
         # We expect to get 1.0 and 2.0 because runhistory does not overwrite by default
         self.assertEqual(list(rh.data.values())[0].cost, [1.0, 2.0])
 
+    def test_full(self):
+        rh = RunHistory()
+        cs = get_config_space()
+        config1 = Configuration(cs, values={"a": 1, "b": 2})
+        config2 = Configuration(cs, values={"a": 1, "b": 3})
+        config3 = Configuration(cs, values={"a": 1, "b": 4})
+        rh.add(
+            config=config1,
+            cost=[50, 100],
+            time=20,
+            status=StatusType.SUCCESS,
+        )
+
+        print(rh._cost_per_config)
+
+        # Only one value: Normalization goes to 1.0
+        self.assertEqual(rh.get_cost(config1), 1.0)
+
+        rh.add(
+            config=config2,
+            cost=[150, 50],
+            time=30,
+            status=StatusType.SUCCESS,
+        )
+
+        # The cost of the first config must be updated
+        # We would expect [0, 1] and the normalized value would be 0.5
+        self.assertEqual(rh.get_cost(config1), 0.5)
+
+        # We would expect [1, 0] and the normalized value would be 0.5
+        self.assertEqual(rh.get_cost(config2), 0.5)
+
+        rh.add(
+            config=config3,
+            cost=[100, 0],
+            time=40,
+            status=StatusType.SUCCESS,
+        )
+
+        # [0, 1] -> 0.5
+        self.assertEqual(rh.get_cost(config1), 0.5)
+
+        # [1, 0.5] -> 0.75
+        self.assertEqual(rh.get_cost(config2), 0.75)
+
+        # [0.5, 0] -> 0.25
+        self.assertEqual(rh.get_cost(config3), 0.25)
+
     def test_full_update(self):
         rh = RunHistory(overwrite_existing_runs=True)
         cs = get_config_space()
@@ -214,9 +262,9 @@ def test_incremental_update(self):
             seed=1,
         )
 
-        # We except 0.75 because of moving average
-        # First we have 1 and then 0.5, the moving average is then 0.75
-        self.assertEqual(rh.get_cost(config1), 0.75)
+        # We don't expect a moving average of 0.75 here because
+        # the costs should always be updated.
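+        # (A moving average would blend the new normalized cost into the stale
+        # cached value; recomputing all costs from the runhistory yields exactly 0.5.)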
+ self.assertEqual(rh.get_cost(config1), 0.5) rh.add( config=config1, @@ -227,7 +275,7 @@ def test_incremental_update(self): seed=1, ) - self.assertAlmostEqual(rh.get_cost(config1), 0.694, places=3) + self.assertAlmostEqual(rh.get_cost(config1), 0.583, places=3) def test_multiple_budgets(self): @@ -562,4 +610,3 @@ def test_budgets(self): if __name__ == "__main__": t = RunhistoryMultiObjectiveTest() - t.test_add_and_pickle() diff --git a/tests/test_scenario/test_scenario.py b/tests/test_scenario/test_scenario.py index 636dae8c0..e47a6ba7f 100644 --- a/tests/test_scenario/test_scenario.py +++ b/tests/test_scenario/test_scenario.py @@ -219,7 +219,7 @@ def test_write(self): pcs- or instance-files, so they are checked manually.""" def check_scen_eq(scen1, scen2): - """ Customized check for scenario-equality, ignoring file-paths """ + """Customized check for scenario-equality, ignoring file-paths""" print("check_scen_eq") for name in scen1._arguments: dest = scen1._arguments[name]["dest"]
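As a closing illustration of the new multi-objective wrapper introduced by this patch: a custom model only needs to implement `construct_estimators`. The class below is a hypothetical sketch that mirrors the random-forest subclass above; any `AbstractEPM` subclass could back each objective instead.

from typing import Any, Dict, List, Tuple

from smac.configspace import ConfigurationSpace
from smac.epm.base_epm import AbstractEPM
from smac.epm.base_uncorrelated_mo_model import UncorrelatedMultiObjectiveModel
from smac.epm.rf_with_instances import RandomForestWithInstances


class MyUncorrelatedMultiObjectiveModel(UncorrelatedMultiObjectiveModel):
    """Hypothetical example: one independent random forest per objective."""

    def construct_estimators(
        self,
        configspace: ConfigurationSpace,
        types: List[int],
        bounds: List[Tuple[float, float]],
        model_kwargs: Dict[str, Any],
    ) -> List[AbstractEPM]:
        # The base class takes care of training, prediction, and marginalization
        # over instances; subclasses only decide which estimator backs each target.
        return [
            RandomForestWithInstances(configspace, types, bounds, **model_kwargs)
            for _ in range(self.num_targets)
        ]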