Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add global config interface #149

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
21 changes: 21 additions & 0 deletions docs/source/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,27 @@ Base Classes
BaseObject
BaseEstimator

.. _global_config:

Configure ``skbase``
====================

.. automodule:: skbase.config
:no-members:
:no-inherited-members:

.. currentmodule:: skbase.config

.. autosummary::
:toctree: api_reference/auto_generated/
:template: function.rst

get_config
get_default_config
set_config
reset_config
config_context

.. _obj_retrieval:

Object Retrieval
Expand Down
15 changes: 15 additions & 0 deletions docs/source/user_documentation/user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ that ``skbase`` provides, see the :ref:`api_ref`.
user_guide/lookup
user_guide/validate
user_guide/testing
user_guide/configuration


.. grid:: 1 2 2 2
Expand Down Expand Up @@ -103,3 +104,17 @@ that ``skbase`` provides, see the :ref:`api_ref`.
:expand:

Testing

.. grid-item-card:: Configuration
:text-align: center

Configure ``skbase``.

+++

.. button-ref:: user_guide/configuration
:color: primary
:click-parent:
:expand:

Configuration
11 changes: 11 additions & 0 deletions docs/source/user_documentation/user_guide/configuration.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
.. _user_guide_global_config:

====================
Configure ``skbase``
====================

.. note::

The user guide is under development. We have created a basic
structure and are looking for contributions to develop the user guide
further.
65 changes: 60 additions & 5 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class name: BaseEstimator

from skbase._exceptions import NotFittedError
from skbase.base._tagmanager import _FlagManager
from skbase.config import get_config
from skbase.config._config import _CONFIG_REGISTRY

__author__: List[str] = ["mloning", "RNKuhns", "fkiraly"]
__all__: List[str] = ["BaseEstimator", "BaseObject"]
Expand Down Expand Up @@ -437,16 +439,69 @@ def clone_tags(self, estimator, tag_names=None):
return self

def get_config(self):
"""Get config flags for self.
"""Get configuration parameters impacting the object.

The configuration is retrieved in the following order:

- ``skbase`` global configuration,
- downstream package configurations, and
- local configuration set via the object's _config class variable
or the object's `set_config` parameter.

Returns
-------
config_dict : dict
Dictionary of config name : config value pairs. Collected from _config
class attribute via nested inheritance and then any overrides
and new tags from _onfig_dynamic object attribute.
Dictionary of config name : config value pairs.
"""
return self._get_flags(flag_attr_name="_config")
# Configuration is collected in a specific order from the farthest to
# most local. skbase global config -> downstream package (optional) -> local
# Start by collecting skbase's global config
config = get_config().copy()

# Use the object config extension interface to optionally retrieve the
# configuration of any downstream package that is subclassing BaseObject
if hasattr(self, "__skbase_get_config__") and callable(
self.__skbase_get_config__
):
skbase_get_config_extension_dict = self.__skbase_get_config__()
else:
skbase_get_config_extension_dict = {}

# If the config extension dunder returned a dict, use it to update
# the dict of configs that have been retrieved so far
if isinstance(skbase_get_config_extension_dict, dict):
config.update(skbase_get_config_extension_dict)
# Otherwise warn the user that a dict wasn't returned (extension not
# done properly) and ignore the returned result
else:
msg = "Use of `__skbase_get_config__` to extend the interface for local "
msg += "overrides of the global configuration must return a dictionary.\n"
msg += f"But a {type(skbase_get_config_extension_dict)} was found."
msg += "Ignoring result returned from `__skbase_get_config__`."
warnings.warn(msg, UserWarning, stacklevel=2)

# Finally get the local config in case any optional instance config
# overrides were made)
local_config = self._get_flags(flag_attr_name="_config").copy()

# If the local config param name is one of the skbase global config
# options (as opposed to config of downstream package) then we want
# to make sure we don't return an invalid value. In this case, we'll
# fallback to the default value if an invalid value was set as the local
# override of the global config
for config_param, config_value in local_config.items():
if config_param in _CONFIG_REGISTRY:
msg = "Invalid value encountered for global configuration parameter "
msg += f"{config_param}. Using global parameter configuration value.\n"
config_value = _CONFIG_REGISTRY[
config_param
].get_valid_param_or_default(
config_value, default_value=config[config_param]
)
local_config[config_param] = config_value
config.update(local_config)

return config

def set_config(self, **config_dict):
"""Set config flags to given values.
Expand Down
32 changes: 32 additions & 0 deletions skbase/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
""":mod:`skbase.config` provides tools for the global configuration of ``skbase``.

For more information on configuration usage patterns see the
:ref:`user guide <user_guide_global_config>`.
"""
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
# Includes functionality like get_config, set_config, and config_context
# that is similar to scikit-learn. These elements are copyrighted by the
# scikit-learn developers, BSD-3-Clause License. For conditions see
# https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
from typing import List

from skbase.config._config import (
GlobalConfigParamSetting,
config_context,
get_config,
get_default_config,
reset_config,
set_config,
)

__author__: List[str] = ["RNKuhns"]
__all__: List[str] = [
"GlobalConfigParamSetting",
"get_default_config",
"get_config",
"set_config",
"reset_config",
"config_context",
]
Loading