-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add proposal for global config context-manager * Add register * Example test Taking the randvars.Normal cholesky inversion damping factor as an example * test * Extra config file per module * In test, set attribute to config to add coverage * Improved tests * Configuration documentation * probnum --> ProbNum * Made damping non-optional in low-level functions * Moved everything to top-level * Remove unused import * Changes in docstring * Changes in documentation * Documentation Co-authored-by: Jonathan Schmidt <[email protected]>
- Loading branch information
1 parent
50dc05e
commit 888f911
Showing
7 changed files
with
229 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
************** | ||
probnum.config | ||
************** | ||
|
||
.. automodapi:: probnum._config | ||
:no-heading: | ||
:headings: "=" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import contextlib | ||
from typing import Any | ||
|
||
|
||
class Configuration: | ||
""" | ||
Configuration over which some mechanics of ProbNum can be controlled dynamically. | ||
ProbNum provides some configurations together with default values. These | ||
are listed in the tables below. | ||
Additionally, users can register their own configuration entries via | ||
:meth:`register`. Configuration entries can only be registered once and can only | ||
be used (accessed or overwritten) once they have been registered. | ||
+----------------------------------+---------------+----------------------------------------------+ | ||
| Config entry | Default value | Description | | ||
+==================================+===============+==============================================+ | ||
| ``covariance_inversion_damping`` | ``1e-12`` | A (typically small) value that is per | | ||
| | | default added to the diagonal of covariance | | ||
| | | matrices in order to make inversion | | ||
| | | numerically stable. | | ||
+----------------------------------+---------------+----------------------------------------------+ | ||
| ``...`` | ``...`` | ... | | ||
+----------------------------------+---------------+----------------------------------------------+ | ||
Examples | ||
======== | ||
>>> import probnum | ||
>>> probnum.config.covariance_inversion_damping | ||
1e-12 | ||
>>> with probnum.config( | ||
... covariance_inversion_damping=1e-2, | ||
... ): | ||
... probnum.config.covariance_inversion_damping | ||
0.01 | ||
""" | ||
|
||
@contextlib.contextmanager | ||
def __call__(self, **kwargs) -> None: | ||
old_entries = dict() | ||
|
||
for key, value in kwargs.items(): | ||
if not hasattr(self, key): | ||
raise KeyError( | ||
f"Configuration entry {key} does not exist yet." | ||
"Configuration entries must be `register`ed before they can be " | ||
"accessed." | ||
) | ||
|
||
old_entries[key] = getattr(self, key) | ||
|
||
setattr(self, key, value) | ||
|
||
try: | ||
yield | ||
finally: | ||
self.__dict__.update(old_entries) | ||
|
||
def __setattr__(self, key: str, value: Any) -> None: | ||
if not hasattr(self, key): | ||
raise KeyError( | ||
f"Configuration entry {key} does not exist yet." | ||
"Configuration entries must be `register`ed before they can be " | ||
"accessed." | ||
) | ||
|
||
self.__dict__[key] = value | ||
|
||
def register(self, key: str, default_value: Any) -> None: | ||
"""Register new configuration option. | ||
Parameters | ||
---------- | ||
key: | ||
The name of the configuration option. This will be the ``key`` when calling | ||
``with config(key=<some_value>): ...``. | ||
default_value: | ||
The default value of the configuration option. | ||
""" | ||
if hasattr(self, key): | ||
raise KeyError( | ||
f"Configuration entry {key} does already exist and " | ||
"cannot be registered again." | ||
) | ||
self.__dict__[key] = default_value | ||
|
||
|
||
# Create a single, global configuration object,... | ||
_GLOBAL_CONFIG_SINGLETON = Configuration() | ||
|
||
# ... define some configuration options, and the respective default values | ||
# (which have to be documented in the Configuration-class docstring!!), ... | ||
_DEFAULT_CONFIG_OPTIONS = [ | ||
# list of tuples (config_key, default_value) | ||
("covariance_inversion_damping", 1e-12), | ||
] | ||
|
||
# ... and register the default configuration options. | ||
for key, default_value in _DEFAULT_CONFIG_OPTIONS: | ||
_GLOBAL_CONFIG_SINGLETON.register(key, default_value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import pytest | ||
|
||
import probnum | ||
from probnum._config import _DEFAULT_CONFIG_OPTIONS | ||
|
||
|
||
def test_defaults(): | ||
none_vals = {key: None for (key, _) in _DEFAULT_CONFIG_OPTIONS} | ||
|
||
for key, default_val in _DEFAULT_CONFIG_OPTIONS: | ||
# Check if default is correct before context manager | ||
assert getattr(probnum.config, key) == default_val | ||
# Temporarily set all config values to None | ||
with probnum.config(**none_vals): | ||
assert getattr(probnum.config, key) is None | ||
|
||
# Check if the original (default) values are set after exiting the context | ||
# manager | ||
assert getattr(probnum.config, key) == default_val | ||
|
||
|
||
def test_register(): | ||
# Check if registering a new config entry works | ||
probnum.config.register("some_config", 3.14) | ||
assert probnum.config.some_config == 3.14 | ||
|
||
# When registering a new entry with an already existing name, throw | ||
with pytest.raises(KeyError): | ||
probnum.config.register("some_config", 4.2) | ||
|
||
# Check if temporarily setting the config entry to a different value (via | ||
# the context manager) works | ||
with probnum.config(some_config=9.9): | ||
assert probnum.config.some_config == 9.9 | ||
|
||
# Upon exiting the context manager, the previous value is restored | ||
assert probnum.config.some_config == 3.14 | ||
|
||
# Setting the config entry permanently also works by | ||
# accessing the attribute directly | ||
probnum.config.some_config = 4.5 | ||
assert probnum.config.some_config == 4.5 | ||
|
||
# Setting a config entry before registering it, does not work. Neither via | ||
# the context manager ... | ||
with pytest.raises(KeyError): | ||
with probnum.config(unknown_config=False): | ||
pass | ||
|
||
# ... nor by accessing the attribute directly. | ||
with pytest.raises(KeyError): | ||
probnum.config.unknown_config = False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters