Skip to content

Commit

Permalink
Use context for expected failures
Browse files Browse the repository at this point in the history
  • Loading branch information
Scienfitz committed Jan 20, 2025
1 parent 0baf9b7 commit 8797dab
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions tests/insights/test_shap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for insights subpackage."""

import inspect
from contextlib import nullcontext
from unittest import mock

import numpy as np
Expand Down Expand Up @@ -80,26 +81,35 @@ def test_non_shap_signature(explainer_name):

def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap):
"""Helper function for general SHAP explainer tests."""
try:
context = nullcontext()
if (
(not use_comp_rep)
and (explainer_cls != "KernelExplainer")
and any(not p.is_numerical for p in campaign.parameters)
):
# We expect a validation error in case an explanation with an unsupported
# explainer type is attempted on a search space representation with
# non-numerical entries
context = pytest.raises(IncompatibleExplainerError)

with context:
shap_insight = SHAPInsight.from_campaign(
campaign,
explainer_cls=explainer_cls,
use_comp_rep=use_comp_rep,
)
except IncompatibleExplainerError:
pytest.skip("Unsupported model/explainer combination.")

# Sanity check explainer
assert isinstance(shap_insight, insights.SHAPInsight)
assert isinstance(shap_insight.explainer, _get_explainer_cls(explainer_cls))
assert shap_insight.uses_shap_explainer == is_shap

# Sanity check explanation
df = campaign.measurements[[p.name for p in campaign.parameters]]
if use_comp_rep:
df = campaign.searchspace.transform(df)
shap_explanation = shap_insight.explain(df)
assert isinstance(shap_explanation, shap.Explanation)
# Sanity check explainer
assert isinstance(shap_insight, insights.SHAPInsight)
assert isinstance(shap_insight.explainer, _get_explainer_cls(explainer_cls))
assert shap_insight.uses_shap_explainer == is_shap

# Sanity check explanation
df = campaign.measurements[[p.name for p in campaign.parameters]]
if use_comp_rep:
df = campaign.searchspace.transform(df)
shap_explanation = shap_insight.explain(df)
assert isinstance(shap_explanation, shap.Explanation)


@mark.slow
Expand Down

0 comments on commit 8797dab

Please sign in to comment.