Skip to content

Commit

Permalink
Global Configuration (#465)
Browse files Browse the repository at this point in the history
* Add proposal for global config context-manager

* Add register

* Example test

Taking the randvars.Normal cholesky inversion damping factor as an example

* test

* Extra config file per module

* In test, set attribute to config to add coverage

* Improved tests

* Configuration documentation

* probnum --> ProbNum

* Made damping non-optional in low-level functions

* Moved everything to top-level

* Remove unused import

* Changes in docstring

* Changes in documentation

* Documentation

Co-authored-by: Jonathan Schmidt <[email protected]>
  • Loading branch information
marvinpfoertner and schmidtjonathan authored Jul 13, 2021
1 parent 50dc05e commit 888f911
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 34 deletions.
53 changes: 28 additions & 25 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,33 @@ API Reference

.. table::

+----------------------------------+--------------------------------------------------------------+
| **Subpackage** | **Description** |
+----------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.diffeq` | Probabilistic solvers for ordinary differential equations. |
+----------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.filtsmooth` | Bayesian filtering and smoothing. |
+----------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.kernels` | Kernels / covariance functions. |
+----------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.linalg` | Probabilistic numerical linear algebra. |
+----------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.linops` | Finite-dimensional linear operators. |
+----------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.problems` | Definitions and collection of problems solved by PN methods. |
+----------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.quad` | Bayesian quadrature / numerical integration. |
+----------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.randprocs` | Random processes representing uncertain functions. |
+----------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.randvars` | Random variables representing uncertain values. |
+----------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.statespace` | Probabilistic state space models. |
+----------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.utils` | Utility functions. |
+----------------------------------+--------------------------------------------------------------+
+-------------------------------------------------+--------------------------------------------------------------+
| **Subpackage** | **Description** |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.diffeq` | Probabilistic solvers for ordinary differential equations. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.filtsmooth` | Bayesian filtering and smoothing. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.kernels` | Kernels / covariance functions. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.linalg` | Probabilistic numerical linear algebra. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.linops` | Finite-dimensional linear operators. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.problems` | Definitions and collection of problems solved by PN methods. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.quad` | Bayesian quadrature / numerical integration. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.randprocs` | Random processes representing uncertain functions. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.randvars` | Random variables representing uncertain values. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.statespace` | Probabilistic state space models. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.utils` | Utility functions. |
+-------------------------------------------------+--------------------------------------------------------------+
| :class:`config <probnum._config.Configuration>` | Global configuration options. |
+-------------------------------------------------+--------------------------------------------------------------+

.. toctree::
:maxdepth: 2
Expand All @@ -47,3 +49,4 @@ API Reference
api/randvars
api/statespace
api/utils
api/config
7 changes: 7 additions & 0 deletions docs/source/api/config.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
**************
probnum.config
**************

.. automodapi:: probnum._config
:no-heading:
:headings: "="
9 changes: 9 additions & 0 deletions src/probnum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@

from pkg_resources import DistributionNotFound, get_distribution

from . import _config

# Global Configuration
from ._config import _GLOBAL_CONFIG_SINGLETON as config

"""The global configuration registry. Can be used as a context manager to create local
contexts in which configuration is temporarily overwritten. This object contains
unguarded global state and is hence not thread-safe!"""

from . import (
diffeq,
filtsmooth,
Expand Down
101 changes: 101 additions & 0 deletions src/probnum/_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import contextlib
from typing import Any


class Configuration:
"""
Configuration over which some mechanics of ProbNum can be controlled dynamically.
ProbNum provides some configurations together with default values. These
are listed in the tables below.
Additionally, users can register their own configuration entries via
:meth:`register`. Configuration entries can only be registered once and can only
be used (accessed or overwritten) once they have been registered.
+----------------------------------+---------------+----------------------------------------------+
| Config entry | Default value | Description |
+==================================+===============+==============================================+
| ``covariance_inversion_damping`` | ``1e-12`` | A (typically small) value that is per |
| | | default added to the diagonal of covariance |
| | | matrices in order to make inversion |
| | | numerically stable. |
+----------------------------------+---------------+----------------------------------------------+
| ``...`` | ``...`` | ... |
+----------------------------------+---------------+----------------------------------------------+
Examples
========
>>> import probnum
>>> probnum.config.covariance_inversion_damping
1e-12
>>> with probnum.config(
... covariance_inversion_damping=1e-2,
... ):
... probnum.config.covariance_inversion_damping
0.01
"""

@contextlib.contextmanager
def __call__(self, **kwargs) -> None:
old_entries = dict()

for key, value in kwargs.items():
if not hasattr(self, key):
raise KeyError(
f"Configuration entry {key} does not exist yet."
"Configuration entries must be `register`ed before they can be "
"accessed."
)

old_entries[key] = getattr(self, key)

setattr(self, key, value)

try:
yield
finally:
self.__dict__.update(old_entries)

def __setattr__(self, key: str, value: Any) -> None:
if not hasattr(self, key):
raise KeyError(
f"Configuration entry {key} does not exist yet."
"Configuration entries must be `register`ed before they can be "
"accessed."
)

self.__dict__[key] = value

def register(self, key: str, default_value: Any) -> None:
"""Register new configuration option.
Parameters
----------
key:
The name of the configuration option. This will be the ``key`` when calling
``with config(key=<some_value>): ...``.
default_value:
The default value of the configuration option.
"""
if hasattr(self, key):
raise KeyError(
f"Configuration entry {key} does already exist and "
"cannot be registered again."
)
self.__dict__[key] = default_value


# Create a single, global configuration object,...
_GLOBAL_CONFIG_SINGLETON = Configuration()

# ... define some configuration options, and the respective default values
# (which have to be documented in the Configuration-class docstring!!), ...
_DEFAULT_CONFIG_OPTIONS = [
# list of tuples (config_key, default_value)
("covariance_inversion_damping", 1e-12),
]

# ... and register the default configuration options.
for key, default_value in _DEFAULT_CONFIG_OPTIONS:
_GLOBAL_CONFIG_SINGLETON.register(key, default_value)
22 changes: 14 additions & 8 deletions src/probnum/randvars/_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import scipy.linalg
import scipy.stats

from probnum import linops
from probnum import config, linops
from probnum import utils as _utils
from probnum.typing import (
ArrayLikeGetitemArgType,
Expand All @@ -24,8 +24,6 @@
from cached_property import cached_property


COV_CHOLESKY_DAMPING = 10 ** -12

_ValueType = Union[np.floating, np.ndarray, linops.LinearOperator]


Expand Down Expand Up @@ -246,9 +244,12 @@ def cov_cholesky(self) -> _ValueType:
return self._cov_cholesky

def precompute_cov_cholesky(
self, damping_factor: Optional[FloatArgType] = COV_CHOLESKY_DAMPING
self,
damping_factor: Optional[FloatArgType] = None,
):
"""(P)recompute Cholesky factors (careful: in-place operation!)."""
if damping_factor is None:
damping_factor = config.covariance_inversion_damping
if self.cov_cholesky_is_precomputed:
raise Exception("A Cholesky factor is already available.")
self._cov_cholesky = self._compute_cov_cholesky(damping_factor=damping_factor)
Expand Down Expand Up @@ -403,7 +404,8 @@ def _sub_normal(self, other: "Normal") -> "Normal":

# Univariate Gaussians
def _univariate_cov_cholesky(
self, damping_factor: Optional[FloatArgType] = COV_CHOLESKY_DAMPING
self,
damping_factor: FloatArgType,
) -> np.floating:
return np.sqrt(self.cov + damping_factor)

Expand Down Expand Up @@ -452,10 +454,13 @@ def _univariate_entropy(self: _ValueType) -> np.float_:

# Multi- and matrixvariate Gaussians
def dense_cov_cholesky(
self, damping_factor: Optional[FloatArgType] = COV_CHOLESKY_DAMPING
self,
damping_factor: Optional[FloatArgType] = None,
) -> np.ndarray:
"""Compute the Cholesky factorization of the covariance from its dense
representation."""
if damping_factor is None:
damping_factor = config.covariance_inversion_damping
dense_cov = self.dense_cov

return scipy.linalg.cholesky(
Expand Down Expand Up @@ -530,7 +535,8 @@ def _dense_entropy(self) -> np.float_:

# Matrixvariate Gaussian with Kronecker covariance
def _kronecker_cov_cholesky(
self, damping_factor: Optional[FloatArgType] = COV_CHOLESKY_DAMPING
self,
damping_factor: FloatArgType,
) -> linops.Kronecker:
assert isinstance(self.cov, linops.Kronecker)

Expand All @@ -552,7 +558,7 @@ def _kronecker_cov_cholesky(
# factors
def _symmetric_kronecker_identical_factors_cov_cholesky(
self,
damping_factor: Optional[FloatArgType] = COV_CHOLESKY_DAMPING,
damping_factor: FloatArgType,
) -> linops.SymmetricKronecker:
assert (
isinstance(self.cov, linops.SymmetricKronecker)
Expand Down
52 changes: 52 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest

import probnum
from probnum._config import _DEFAULT_CONFIG_OPTIONS


def test_defaults():
none_vals = {key: None for (key, _) in _DEFAULT_CONFIG_OPTIONS}

for key, default_val in _DEFAULT_CONFIG_OPTIONS:
# Check if default is correct before context manager
assert getattr(probnum.config, key) == default_val
# Temporarily set all config values to None
with probnum.config(**none_vals):
assert getattr(probnum.config, key) is None

# Check if the original (default) values are set after exiting the context
# manager
assert getattr(probnum.config, key) == default_val


def test_register():
# Check if registering a new config entry works
probnum.config.register("some_config", 3.14)
assert probnum.config.some_config == 3.14

# When registering a new entry with an already existing name, throw
with pytest.raises(KeyError):
probnum.config.register("some_config", 4.2)

# Check if temporarily setting the config entry to a different value (via
# the context manager) works
with probnum.config(some_config=9.9):
assert probnum.config.some_config == 9.9

# Upon exiting the context manager, the previous value is restored
assert probnum.config.some_config == 3.14

# Setting the config entry permanently also works by
# accessing the attribute directly
probnum.config.some_config = 4.5
assert probnum.config.some_config == 4.5

# Setting a config entry before registering it, does not work. Neither via
# the context manager ...
with pytest.raises(KeyError):
with probnum.config(unknown_config=False):
pass

# ... nor by accessing the attribute directly.
with pytest.raises(KeyError):
probnum.config.unknown_config = False
19 changes: 18 additions & 1 deletion tests/test_randvars/test_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import scipy.sparse
import scipy.stats

from probnum import linops, randvars
from probnum import config, linops, randvars
from probnum.problems.zoo.linalg import random_spd_matrix
from tests.testing import NumpyAssertions

Expand Down Expand Up @@ -451,6 +451,23 @@ def test_precompute_cov_cholesky(self):
with self.subTest("Cholesky is precomputed"):
self.assertTrue(rv.cov_cholesky_is_precomputed)

def test_damping_factor_config(self):
mean, cov = self.params
rv = randvars.Normal(mean, cov)

chol_standard_damping = rv.dense_cov_cholesky(damping_factor=None)
self.assertAllClose(
chol_standard_damping,
np.sqrt(rv.cov + 1e-12),
)

with config(covariance_inversion_damping=1e-3):
chol_altered_damping = rv.dense_cov_cholesky(damping_factor=None)
self.assertAllClose(
chol_altered_damping,
np.sqrt(rv.cov + 1e-3),
)

def test_cov_cholesky_cov_cholesky_passed(self):
"""A value for cov_cholesky is passed in init.
Expand Down

0 comments on commit 888f911

Please sign in to comment.