From 49e4959f16309be841dca0ca0e3a31ffa57b4c4b Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Tue, 19 Nov 2024 10:17:05 +0100 Subject: [PATCH] Make toggle_discrete_candidates expect a collection of constraints --- baybe/campaign.py | 25 ++++++++++++++----------- tests/test_campaign.py | 6 +++--- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/baybe/campaign.py b/baybe/campaign.py index c3f61bf23..8813998fe 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -4,7 +4,8 @@ import gc import json -from functools import singledispatchmethod +from collections.abc import Collection +from functools import reduce, singledispatchmethod from typing import TYPE_CHECKING import cattrs @@ -291,7 +292,7 @@ def _mark_as_measured( @singledispatchmethod def toggle_discrete_candidates( # noqa: DOC501 self, - constraint: DiscreteConstraint | pd.DataFrame, + constraints: Collection[DiscreteConstraint] | pd.DataFrame, exclude: bool, complement: bool = False, dry_run: bool = False, @@ -299,9 +300,9 @@ def toggle_discrete_candidates( # noqa: DOC501 """In-/exclude certain discrete points in/from the candidate set. Args: - constraint: A filtering mechanism determining the candidates subset to be - in-/excluded. Can be either a - :class:`~baybe.constraints.base.DiscreteConstraint` or a dataframe. + constraints: A filtering mechanism determining the candidates subset to be + in-/excluded. Can be either a collection of + :class:`~baybe.constraints.base.DiscreteConstraint`s or a dataframe. For the latter, see :func:`~baybe.utils.dataframe.filter_df` for details. exclude: If ``True``, the specified candidates are excluded. @@ -320,20 +321,22 @@ def toggle_discrete_candidates( # noqa: DOC501 """ raise NotImplementedError( f"Candidate toggling is not implemented for constraint specifications of " - f"type {type(constraint)}." + f"type {type(constraints)}." ) - @toggle_discrete_candidates.register + @toggle_discrete_candidates.register(Collection) def _( self, - constraint: DiscreteConstraint, + constraints: Collection[DiscreteConstraint], exclude: bool, complement: bool = False, dry_run: bool = False, ) -> pd.DataFrame: # Filter search space dataframe according to the given constraint df = self.searchspace.discrete.exp_rep - idx = constraint.get_valid(df) + idx = reduce( + lambda x, y: x.intersection(y), (c.get_valid(df) for c in constraints) + ) # Determine the candidate subset to be toggled points = df.drop(index=idx) if complement else df.loc[idx].copy() @@ -346,13 +349,13 @@ def _( @toggle_discrete_candidates.register def _( self, - constraint: pd.DataFrame, + constraints: pd.DataFrame, exclude: bool, complement: bool = False, dry_run: bool = False, ) -> pd.DataFrame: # Determine the candidate subset to be toggled - points = filter_df(self.searchspace.discrete.exp_rep, constraint, complement) + points = filter_df(self.searchspace.discrete.exp_rep, constraints, complement) if not dry_run: self._searchspace_metadata.loc[points.index, _EXCLUDED] = exclude diff --git a/tests/test_campaign.py b/tests/test_campaign.py index 271a100c8..bc4c41765 100644 --- a/tests/test_campaign.py +++ b/tests/test_campaign.py @@ -41,14 +41,14 @@ def test_get_surrogate(campaign, n_iterations, batch_size): @pytest.mark.parametrize("complement", [False, True], ids=["regular", "complement"]) @pytest.mark.parametrize("exclude", [True, False], ids=["exclude", "include"]) @pytest.mark.parametrize( - "constraint", + "constraints", [ pd.DataFrame({"a": [0]}), DiscreteExcludeConstraint(["a"], [SubSelectionCondition([1])]), ], ids=["dataframe", "constraints"], ) -def test_candidate_toggling(constraint, exclude, complement): +def test_candidate_toggling(constraints, exclude, complement): """Toggling discrete candidates updates the campaign metadata accordingly.""" subspace = SubspaceDiscrete.from_product( [ @@ -62,7 +62,7 @@ def test_candidate_toggling(constraint, exclude, complement): campaign._searchspace_metadata[_EXCLUDED] = not exclude # Toggle the candidates - campaign.toggle_discrete_candidates(constraint, exclude, complement=complement) + campaign.toggle_discrete_candidates(constraints, exclude, complement=complement) # Extract row indices of candidates whose metadata should have been toggled matches = campaign.searchspace.discrete.exp_rep["a"] == 0