Skip to content

Commit

Permalink
Insights module and SHAPInsight (#391)
Browse files Browse the repository at this point in the history
Adds the `insights` subpackage with `SHAPInsight` as its first content
  • Loading branch information
AdrianSosic authored Jan 20, 2025
2 parents deb3d5a + 8797dab commit 79a94eb
Show file tree
Hide file tree
Showing 15 changed files with 745 additions and 35 deletions.
78 changes: 58 additions & 20 deletions .lockfiles/py310-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ anyio==4.4.0
# via
# httpx
# jupyter-server
appnope==0.1.4 ; platform_system == 'Darwin'
appnope==0.1.4 ; sys_platform == 'darwin'
# via ipykernel
argon2-cffi==23.1.0
# via jupyter-server
Expand Down Expand Up @@ -84,7 +84,9 @@ click==8.1.7
# pydoclint
# streamlit
cloudpickle==3.0.0
# via dask
# via
# dask
# shap
colorama==0.4.6
# via
# click
Expand Down Expand Up @@ -230,6 +232,8 @@ idna==3.7
# httpx
# jsonschema
# requests
imageio==2.36.1
# via scikit-image
imagesize==1.4.1
# via sphinx
importlib-metadata==7.1.0
Expand All @@ -238,7 +242,7 @@ importlib-metadata==7.1.0
# opentelemetry-api
iniconfig==2.0.0
# via pytest
intel-openmp==2021.4.0 ; platform_system == 'Windows'
intel-openmp==2021.4.0 ; sys_platform == 'win32'
# via mkl
interface-meta==1.3.0
# via formulaic
Expand Down Expand Up @@ -344,10 +348,14 @@ kiwisolver==1.4.5
# via matplotlib
latexcodec==3.0.0
# via pybtex
lazy-loader==0.4
# via scikit-image
license-expression==30.3.0
# via cyclonedx-python-lib
lifelines==0.29.0
# via ngboost
lime==0.2.0.1
# via shap
linear-operator==0.5.2
# via
# botorch
Expand All @@ -370,6 +378,7 @@ matplotlib==3.9.1
# via
# baybe (pyproject.toml)
# lifelines
# lime
# seaborn
# types-seaborn
matplotlib-inline==0.1.7
Expand All @@ -386,7 +395,7 @@ mdurl==0.1.2
# via markdown-it-py
mistune==3.0.2
# via nbconvert
mkl==2021.4.0 ; platform_system == 'Windows'
mkl==2021.4.0 ; sys_platform == 'win32'
# via torch
mmh3==5.0.1
# via e3fp
Expand Down Expand Up @@ -424,6 +433,7 @@ nest-asyncio==1.6.0
networkx==3.3
# via
# mordredcommunity
# scikit-image
# torch
ngboost==0.5.1
# via baybe (pyproject.toml)
Expand All @@ -436,7 +446,9 @@ notebook-shim==0.2.4
# jupyterlab
# notebook
numba==0.60.0
# via scikit-fingerprints
# via
# scikit-fingerprints
# shap
numpy==1.26.4
# via
# baybe (pyproject.toml)
Expand All @@ -449,7 +461,9 @@ numpy==1.26.4
# e3fp
# formulaic
# h5py
# imageio
# lifelines
# lime
# matplotlib
# mordredcommunity
# ngboost
Expand All @@ -465,44 +479,47 @@ numpy==1.26.4
# pyro-ppl
# rdkit
# scikit-fingerprints
# scikit-image
# scikit-learn
# scikit-learn-extra
# scipy
# seaborn
# shap
# streamlit
# tifffile
# types-seaborn
# xarray
# xyzpy
nvidia-cublas-cu12==12.1.3.1 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cublas-cu12==12.1.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-cupti-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cuda-runtime-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-runtime-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cudnn-cu12==8.9.2.26 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cudnn-cu12==8.9.2.26 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cufft-cu12==11.0.2.54 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cufft-cu12==11.0.2.54 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-curand-cu12==10.3.2.106 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-curand-cu12==10.3.2.106 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cusolver-cu12==11.4.5.107 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cusolver-cu12==11.4.5.107 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cusparse-cu12==12.1.0.106 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cusparse-cu12==12.1.0.106 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cusolver-cu12
# torch
nvidia-nccl-cu12==2.20.5 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nccl-cu12==2.20.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-nvjitlink-cu12==12.5.82 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nvjitlink-cu12==12.5.82 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nvtx-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
onnx==1.16.1
# via
Expand Down Expand Up @@ -566,6 +583,7 @@ packaging==24.1
# jupyterlab
# jupyterlab-server
# jupytext
# lazy-loader
# matplotlib
# mordredcommunity
# nbconvert
Expand All @@ -576,7 +594,9 @@ packaging==24.1
# plotly
# pyproject-api
# pytest
# scikit-image
# setuptools-scm
# shap
# sphinx
# streamlit
# tox
Expand All @@ -592,6 +612,7 @@ pandas==2.2.2
# pandas-flavor
# scikit-fingerprints
# seaborn
# shap
# streamlit
# xarray
# xyzpy
Expand All @@ -612,8 +633,10 @@ pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
pillow==10.4.0
# via
# baybe (pyproject.toml)
# imageio
# matplotlib
# rdkit
# scikit-image
# streamlit
pip==24.1.2
# via pip-api
Expand Down Expand Up @@ -784,13 +807,17 @@ s3transfer==0.10.4
# via boto3
scikit-fingerprints==1.9.0
# via baybe (pyproject.toml)
scikit-image==0.25.0
# via lime
scikit-learn==1.5.1
# via
# baybe (pyproject.toml)
# gpytorch
# lime
# ngboost
# scikit-fingerprints
# scikit-learn-extra
# shap
# skl2onnx
scikit-learn-extra==0.3.0
# via baybe (pyproject.toml)
Expand All @@ -805,11 +832,14 @@ scipy==1.14.0
# formulaic
# gpytorch
# lifelines
# lime
# linear-operator
# ngboost
# scikit-fingerprints
# scikit-image
# scikit-learn
# scikit-learn-extra
# shap
sdaxen-python-utilities==0.1.5
# via e3fp
seaborn==0.13.2
Expand All @@ -822,6 +852,8 @@ setuptools==71.1.0
# setuptools-scm
setuptools-scm==8.1.0
# via baybe (pyproject.toml)
shap==0.46.0
# via baybe (pyproject.toml)
six==1.16.0
# via
# asttokens
Expand All @@ -833,6 +865,8 @@ six==1.16.0
# rfc3339-validator
skl2onnx==1.17.0
# via baybe (pyproject.toml)
slicer==0.0.8
# via shap
smart-open==7.0.5
# via e3fp
smmap==5.0.1
Expand Down Expand Up @@ -889,7 +923,7 @@ sympy==1.13.1
# via
# onnxruntime
# torch
tbb==2021.13.0 ; platform_system == 'Windows'
tbb==2021.13.0 ; sys_platform == 'win32'
# via mkl
tenacity==8.5.0
# via
Expand All @@ -902,6 +936,8 @@ terminado==0.18.1
# jupyter-server-terminals
threadpoolctl==3.5.0
# via scikit-learn
tifffile==2024.12.12
# via scikit-image
tinycss2==1.3.0
# via nbconvert
tokenize-rt==6.1.0
Expand Down Expand Up @@ -950,9 +986,11 @@ tox-uv==1.9.1
tqdm==4.66.4
# via
# huggingface-hub
# lime
# ngboost
# pyro-ppl
# scikit-fingerprints
# shap
# xyzpy
traitlets==5.14.3
# via
Expand All @@ -970,7 +1008,7 @@ traitlets==5.14.3
# nbclient
# nbconvert
# nbformat
triton==2.3.1 ; python_full_version < '3.12' and platform_machine == 'x86_64' and platform_system == 'Linux'
triton==2.3.1 ; python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
typeguard==2.13.3
# via
Expand Down Expand Up @@ -1013,7 +1051,7 @@ virtualenv==20.26.3
# via
# pre-commit
# tox
watchdog==4.0.1 ; platform_system != 'Darwin'
watchdog==4.0.1 ; sys_platform != 'darwin'
# via streamlit
wcwidth==0.2.13
# via prompt-toolkit
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
the corresponding parameter/target column labels

### Added
- Optional `insights` dependency group
- SHAP explanations via the new `SHAPInsight` class
- `allow_missing` and `allow_extra` keyword arguments to `Objective.transform`
- Example for a traditional mixture
- `add_noise_to_perturb_degenerate_rows` utility
Expand Down
2 changes: 2 additions & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,5 @@
`scikit-fingerprints` support
- Fabian Liebig (Merck KGaA, Darmstadt, Germany):\
Benchmarking structure and persistence capabilities for benchmarking results
- Alexander Wieczorek (Swiss Federal Institute for Materials Science and Technology, Dübendorf, Switzerland):\
SHAP explainers for insights
12 changes: 9 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,15 @@
The **Bay**esian **B**ack **E**nd (**BayBE**) is a general-purpose toolbox for Bayesian Design
of Experiments, focusing on additions that enable real-world experimental campaigns.

Besides functionality to perform a typical recommend-measure loop, BayBE's highlights are:
- ✨ Custom parameter encodings: Improve your campaign with domain knowledge
## 🔋 Batteries Included
Besides its core functionality to perform a typical recommend-measure loop, BayBE
offers a range of ✨**built&#8209;in&nbsp;features**✨ crucial for real-world use cases.
The following provides a non-comprehensive overview:

- 🛠️ Custom parameter encodings: Improve your campaign with domain knowledge
- 🧪 Built-in chemical encodings: Improve your campaign with chemical knowledge
- 🎯 Single and multiple targets with min, max and match objectives
- 🔍 Insights: Easily analyze feature importance and model behavior
- 🎭 Hybrid (mixed continuous and discrete) spaces
- 🚀 Transfer learning: Mix data from multiple campaigns and accelerate optimization
- 🎰 Bandit models: Efficiently find the best among many options in noisy environments (e.g. A/B Testing)
Expand Down Expand Up @@ -296,7 +301,8 @@ The available groups are:
- `lint`: Required for linting and formatting.
- `mypy`: Required for static type checking.
- `onnx`: Required for using custom surrogate models in [ONNX format](https://onnx.ai).
- `polars`: Required for optimized search space construction via [Polars](https://docs.pola.rs/)
- `polars`: Required for optimized search space construction via [Polars](https://docs.pola.rs/).
- `insights`: Required for built-in model and campaign analysis (e.g. using [SHAP](https://shap.readthedocs.io/)).
- `simulation`: Enabling the [simulation](https://emdgroup.github.io/baybe/stable/_autosummary/baybe.simulation.html) module.
- `test`: Required for running the tests.
- `benchmarking`: Required for running the benchmarking module.
Expand Down
3 changes: 3 additions & 0 deletions baybe/_optional/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ def exclude_sys_path(path: str, /): # noqa: DOC402, DOC404
# Individual packages
with exclude_sys_path(os.getcwd()):
FLAKE8_INSTALLED = find_spec("flake8") is not None
LIME_INSTALLED = find_spec("lime") is not None
ONNX_INSTALLED = find_spec("onnxruntime") is not None
POLARS_INSTALLED = find_spec("polars") is not None
PRE_COMMIT_INSTALLED = find_spec("pre_commit") is not None
PYDOCLINT_INSTALLED = find_spec("pydoclint") is not None
RUFF_INSTALLED = find_spec("ruff") is not None
SHAP_INSTALLED = find_spec("shap") is not None
SKFP_INSTALLED = find_spec("skfp") is not None # scikit-fingerprints
STREAMLIT_INSTALLED = find_spec("streamlit") is not None
XYZPY_INSTALLED = find_spec("xyzpy") is not None
Expand All @@ -44,6 +46,7 @@ def exclude_sys_path(path: str, /): # noqa: DOC402, DOC404

# Information on whether all required packages for certain functionality are available
CHEM_INSTALLED = SKFP_INSTALLED
INSIGHTS_INSTALLED = SHAP_INSTALLED and LIME_INSTALLED
LINT_INSTALLED = all(
(
FLAKE8_INSTALLED,
Expand Down
16 changes: 16 additions & 0 deletions baybe/_optional/insights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Optional import for the insights subpackage."""

from baybe.exceptions import OptionalImportError

# Guarded import: `shap` is only present when the optional 'insights'
# dependency group is installed. If it is missing, surface a BayBE-specific
# error that tells the user how to install it instead of a bare
# ModuleNotFoundError.
try:
    import shap
except ModuleNotFoundError as err:
    raise OptionalImportError(
        "Explainer functionality is unavailable because 'insights' is not installed. "
        "Consider installing BayBE with 'insights' dependency, "
        "e.g. via `pip install baybe[insights]`."
    ) from err

# Re-export the lazily guarded module so that callers import it from here.
__all__ = ["shap"]
Loading

0 comments on commit 79a94eb

Please sign in to comment.