-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from mmschlk/development
Development
- Loading branch information
Showing
16 changed files
with
750 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "0.0.2" | ||
__version__ = "0.0.3" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"""This module contains the approximators to estimate the Shapley interaction values.""" | ||
from .permutation.sii import PermutationSamplingSII | ||
from .permutation.sti import PermutationSamplingSTI | ||
|
||
# Public API of the approximator package. Every approximator imported above is
# re-exported here so that ``from <package> import *`` exposes both of them.
__all__ = [
    "PermutationSamplingSII",
    "PermutationSamplingSTI",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
"""This module contains the base approximator classes for the shapiq package.""" | ||
import copy | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass | ||
from typing import Callable, Union, Optional | ||
|
||
import numpy as np | ||
|
||
|
||
# Interaction indices known to this module. Used to validate the ``index`` of both
# ``InteractionValues`` results and ``Approximator`` instances.
AVAILABLE_INDICES = {"SII", "nSII", "STI", "FSI"}


@dataclass
class InteractionValues:
    """Container for the interaction values estimated by an approximator.

    Attributes:
        values: The interaction values of the model. Mapping from order to the interaction
            values of that order.
        index: The interaction index estimated. Available indices are 'SII', 'nSII', 'STI', and
            'FSI'.
        order: The order of the approximation. Must equal the largest key of ``values``.

    Raises:
        ValueError: If ``index`` is not an available index, or if ``order`` is smaller than 1 or
            differs from the largest key in ``values`` (which includes ``values`` being empty).
    """

    values: dict[int, np.ndarray]
    index: str
    order: int

    def __post_init__(self) -> None:
        """Validate ``index`` and ``order`` right after the dataclass fields are set."""
        # Validate against the shared constant so the two checks in this module cannot drift.
        if self.index not in AVAILABLE_INDICES:
            raise ValueError(
                f"Index {self.index} is not valid. "
                f"Available indices are 'SII', 'nSII', 'STI', and 'FSI'."
            )
        # An empty mapping has no maximum key; report it as an invalid order instead of letting
        # ``max`` fail with an unrelated "empty sequence" message.
        if not self.values or self.order < 1 or self.order != max(self.values):
            raise ValueError(
                f"Order {self.order} is not valid. "
                f"Order should be a positive integer equal to the maximum key of the values."
            )
|
||
|
||
class Approximator(ABC):
    """Base class for all approximators.

    Approximators are used to estimate the interaction values of a model or any value function.
    Different approximators can be used to estimate different interaction indices. Some can be
    used to estimate all indices.

    Args:
        n: The number of players.
        max_order: The interaction order of the approximation.
        index: The interaction index to be estimated. Available indices are 'SII', 'nSII', 'STI',
            and 'FSI'.
        top_order: If True, the approximation is performed only for the top order interactions.
            If False, the approximation is performed for all orders up to the specified order.
        random_state: The random state to use for the approximation. Defaults to None.

    Attributes:
        n: The number of players.
        N: The set of players (starting from 0 to n - 1).
        max_order: The interaction order of the approximation.
        index: The interaction index to be estimated.
        top_order: If True, the approximation is performed only for the top order interactions.
            If False, the approximation is performed for all orders up to the specified order.
        min_order: The minimum order of the approximation. If top_order is True, min_order is
            equal to max_order. Otherwise, min_order is equal to 1.

    Raises:
        ValueError: If ``index`` is not one of the available indices.
    """

    def __init__(
        self,
        n: int,
        max_order: int,
        index: str,
        top_order: bool,
        random_state: Optional[int] = None,
    ) -> None:
        """Initializes the approximator."""
        self.index: str = index
        if self.index not in AVAILABLE_INDICES:
            raise ValueError(
                f"Index {self.index} is not valid. "
                f"Available indices are {AVAILABLE_INDICES}."
            )
        self.n: int = n
        self.N: set = set(range(self.n))
        self.max_order: int = max_order
        self.top_order: bool = top_order
        # With top_order only the highest order is estimated, so the range collapses to it.
        self.min_order: int = self.max_order if self.top_order else 1
        self._random_state: Optional[int] = random_state
        # Seeded generator stored for use by the approximation algorithms.
        self._rng: np.random.Generator = np.random.default_rng(seed=self._random_state)

    @abstractmethod
    def approximate(
        self,
        budget: int,
        game: Callable[[Union[set, tuple]], float],
        *args,
        **kwargs,
    ) -> "InteractionValues":
        """Approximates the interaction values.

        Abstract method that needs to be implemented for each approximator.

        Args:
            budget: The budget for the approximation.
            game: The game function.

        Returns:
            The interaction values.

        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError

    def _init_result(self, dtype=float) -> dict[int, np.ndarray]:
        """Initializes the result dictionary mapping from order to the interaction values.

        For order 1 the interaction values are of shape (n,), for order 2 of shape (n, n) and so
        on. Only the orders from ``min_order`` to ``max_order`` get an entry.

        Args:
            dtype: The data type of the result dictionary values. Defaults to float.

        Returns:
            The result dictionary.
        """
        return {
            order: self._get_empty_array(self.n, order, dtype=dtype)
            for order in self._order_iterator
        }

    @staticmethod
    def _get_empty_array(n: int, order: int, dtype=float) -> np.ndarray:
        """Returns a zero-filled array of the appropriate shape for the given order.

        Args:
            n: The number of players.
            order: The order of the array, i.e. its number of dimensions.
            dtype: The data type of the array. Defaults to float.

        Returns:
            An array of zeros with ``order`` dimensions of length ``n`` each.
        """
        # Allocate the target shape directly instead of building a flat array and reshaping it.
        return np.zeros((n,) * order, dtype=dtype)

    @property
    def _order_iterator(self) -> range:
        """Returns an iterator over the orders of the approximation.

        Returns:
            The range from ``min_order`` to ``max_order`` (inclusive).
        """
        return range(self.min_order, self.max_order + 1)

    def _finalize_result(self, result: dict[int, np.ndarray]) -> "InteractionValues":
        """Finalizes the result dictionary.

        Args:
            result: The result dictionary mapping from order to the interaction values.

        Returns:
            The interaction values, tagged with this approximator's index and maximum order.
        """
        return InteractionValues(result, self.index, self.max_order)

    @staticmethod
    def _smooth_with_epsilon(
        interaction_results: Union[dict, np.ndarray],
        eps: float = 0.00001,
    ) -> Union[dict, np.ndarray]:
        """Smooth the interaction results with a small epsilon to avoid numerical issues.

        Every entry whose absolute value is below ``eps`` is set to zero. The input is never
        modified; the smoothing is applied to a copy.

        Args:
            interaction_results: Interaction results, either a single array or a dictionary
                mapping from order to an array.
            eps: Small epsilon. Defaults to 0.00001.

        Returns:
            Union[dict, np.ndarray]: Smoothed copy of the interaction results.
        """

        def _zero_small(values: np.ndarray) -> np.ndarray:
            # Copy before masking so the caller's array keeps its original entries.
            cleaned = copy.deepcopy(values)
            cleaned[np.abs(cleaned) < eps] = 0
            return cleaned

        if isinstance(interaction_results, dict):
            return {
                order: _zero_small(values)
                for order, values in interaction_results.items()
            }
        return _zero_small(interaction_results)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""This module contains all permutation-based sampling algorithms to estimate SII/nSII and STI.""" | ||
from ._base import PermutationSampling | ||
from .sii import PermutationSamplingSII | ||
from .sti import PermutationSamplingSTI | ||
|
||
# Names re-exported by the permutation sampling subpackage: the shared base class
# followed by the two index-specific samplers imported above.
__all__ = ["PermutationSampling", "PermutationSamplingSII", "PermutationSamplingSTI"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
"""This module contains the base permutation sampling algorithms to estimate SII/nSII and STI.""" | ||
from typing import Optional, Callable, Union | ||
|
||
import numpy as np | ||
|
||
from approximator._base import Approximator, InteractionValues | ||
|
||
|
||
# Interaction indices accepted by the permutation sampling approximator.
AVAILABLE_INDICES_PERMUTATION = {
    "SII",
    "nSII",
    "STI",
}
|
||
|
||
class PermutationSampling(Approximator):
    """Permutation sampling approximator.

    This class contains the permutation sampling algorithm to estimate SII/nSII and STI values.

    Args:
        n: The number of players.
        max_order: The interaction order of the approximation.
        index: The interaction index to be estimated. Available indices are 'SII', 'nSII', and
            'STI'.
        top_order: Whether to approximate only the top order interactions (`True`) or all orders
            up to the specified order (`False`).
        random_state: The random state to use for the approximation. Defaults to None.

    Attributes:
        n (int): The number of players.
        N (set): The set of players (starting from 0 to n - 1).
        max_order (int): The interaction order of the approximation.
        index (str): The interaction index to be estimated.
        top_order (bool): Whether to approximate only the top order interactions or all orders
            up to the specified order.
        min_order (int): The minimum order of the approximation. If top_order is True, min_order
            is equal to max_order. Otherwise, min_order is equal to 1.

    Raises:
        ValueError: If ``index`` is not one of 'SII', 'nSII', or 'STI'.
    """

    def __init__(
        self,
        n: int,
        max_order: int,
        index: str,
        top_order: bool,
        random_state: Optional[int] = None,
    ) -> None:
        """Validates the index against the permutation-specific set, then defers to the base."""
        # Checked before the base initializer runs because this set is narrower than the
        # base class's own list of indices (it does not contain 'FSI').
        if index not in AVAILABLE_INDICES_PERMUTATION:
            raise ValueError(
                f"Index {index} is not valid. "
                f"Available indices are {AVAILABLE_INDICES_PERMUTATION}."
            )
        super().__init__(n, max_order, index, top_order, random_state)

    def approximate(
        self,
        budget: int,
        game: Callable[[Union[set, tuple]], float],
    ) -> InteractionValues:
        """Approximates the interaction values.

        Args:
            budget: The budget for the approximation.
            game: The game function.

        Raises:
            NotImplementedError: Always; this class provides no approximation algorithm itself.
        """
        raise NotImplementedError

    @staticmethod
    def _get_n_iterations(budget: int, batch_size: int, iteration_cost: int) -> tuple[int, int]:
        """Computes the number of iterations and the size of the last batch.

        The budget is spent in batches of ``batch_size`` iterations costing ``iteration_cost``
        each. If budget for at least one more iteration is left after the full batches, a final
        smaller batch is added. For example, ``budget=10``, ``batch_size=3`` and
        ``iteration_cost=1`` yield ``(4, 1)``: three full batches and a last batch of size one.

        Args:
            budget: The budget for the approximation.
            batch_size: The size of the batch. Must be at least 1.
            iteration_cost: The cost of a single iteration. Must be at least 1.

        Returns:
            int, int: The number of iterations and the size of the last batch.

        Raises:
            ValueError: If ``batch_size`` or ``iteration_cost`` is smaller than 1.
        """
        # Zero would divide by zero below and negative values would yield meaningless counts,
        # so reject both with an explicit error.
        if batch_size < 1 or iteration_cost < 1:
            raise ValueError(
                f"batch_size and iteration_cost must be positive integers, "
                f"got batch_size={batch_size} and iteration_cost={iteration_cost}."
            )
        batch_cost = iteration_cost * batch_size
        n_iterations = budget // batch_cost
        last_batch_size = batch_size
        # Iterations still affordable once all full batches are paid for.
        leftover_iterations = (budget - n_iterations * batch_cost) // iteration_cost
        if leftover_iterations > 0:
            last_batch_size = leftover_iterations
            n_iterations += 1
        return n_iterations, last_batch_size
Oops, something went wrong.