diff --git a/CHANGELOG.md b/CHANGELOG.md index d137eae9..7a099175 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,8 @@ ### v1.0.1 (2024-06-05) -- add `max_order=1` to `TabularExplainer` -- +- add `max_order=1` to `TabularExplainer` and `TreeExplainer` +- fix `TreeExplainer.explain_X(..., njobs=2, random_state=0)` ### v1.0.0 (2024-06-04) diff --git a/shapiq/__init__.py b/shapiq/__init__.py index b6495626..be97dde5 100644 --- a/shapiq/__init__.py +++ b/shapiq/__init__.py @@ -2,7 +2,7 @@ the well established Shapley value and its generalization to interaction. """ -__version__ = "1.0.0.9000" +__version__ = "1.0.1" # approximator classes from .approximator import ( diff --git a/shapiq/explainer/_base.py b/shapiq/explainer/_base.py index 3ca7c26d..a8b26acb 100644 --- a/shapiq/explainer/_base.py +++ b/shapiq/explainer/_base.py @@ -73,9 +73,12 @@ def explain_X( """ assert len(X.shape) == 2 if random_state is not None: - self._imputer._rng = np.random.default_rng(random_state) - self._approximator._rng = np.random.default_rng(random_state) - self._approximator._sampler._rng = np.random.default_rng(random_state) + if hasattr(self, "_imputer"): + self._imputer._rng = np.random.default_rng(random_state) + if hasattr(self, "_approximator"): + self._approximator._rng = np.random.default_rng(random_state) + if hasattr(self._approximator, "_sampler"): + self._approximator._sampler._rng = np.random.default_rng(random_state) if n_jobs: import joblib diff --git a/shapiq/explainer/tree/explainer.py b/shapiq/explainer/tree/explainer.py index 7fdd61da..836e3fc4 100644 --- a/shapiq/explainer/tree/explainer.py +++ b/shapiq/explainer/tree/explainer.py @@ -32,7 +32,7 @@ class TreeExplainer(Explainer): interaction values up to that order. Defaults to ``2``. min_order: The minimum interaction order to be computed. Defaults to ``1``. index: The type of interaction to be computed. It can be one of - ``["k-SII", "SII", "STII", "FSII", "BII"]``. All indices apart from ``"BII"`` will + ``["k-SII", "SII", "STII", "FSII", "BII", "SV"]``. All indices apart from ``"BII"`` will reduce to the ``"SV"`` (Shapley value) for order 1. Defaults to ``"k-SII"``. class_label: The class label of the model to explain. """ @@ -52,6 +52,9 @@ def __init__( if index == "SV" and max_order > 1: warnings.warn("For index='SV' the max_order is set to 1.") max_order = 1 + elif max_order == 1 and index != "SV": + warnings.warn("For max_order=1 the index is set to 'SV'.") + index = "SV" # validate and parse model validated_model = validate_tree_model(model, class_label=class_label)