From 2056ddadba5d4cbd63f1946f5f8d8c0ed67b652a Mon Sep 17 00:00:00 2001 From: Umberto Lupo <umberto.lupo@epfl.ch> Date: Sat, 14 Oct 2023 00:47:43 +0200 Subject: [PATCH] Generalize InformationAndReciprocalBestHits to include Blosum62, generalize Blosum62 --- diffpass/_modidx.py | 10 ++++++ diffpass/base.py | 53 +++++++++++++++++++++++++++++ diffpass/model.py | 8 ++++- diffpass/train.py | 60 +++++++++++++------------------- nbs/_example_prokaryotic.ipynb | 17 +++++++--- nbs/base.ipynb | 53 +++++++++++++++++++++++++++++ nbs/model.ipynb | 8 ++++- nbs/train.ipynb | 62 +++++++++++++--------------------- 8 files changed, 190 insertions(+), 81 deletions(-) diff --git a/diffpass/_modidx.py b/diffpass/_modidx.py index 12aedb9..8e4b84c 100644 --- a/diffpass/_modidx.py +++ b/diffpass/_modidx.py @@ -8,8 +8,18 @@ 'syms': { 'diffpass.base': { 'diffpass.base.DiffPASSMixin': ('base.html#diffpassmixin', 'diffpass/base.py'), 'diffpass.base.DiffPASSMixin.reduce_num_tokens': ( 'base.html#diffpassmixin.reduce_num_tokens', 'diffpass/base.py'), + 'diffpass.base.DiffPASSMixin.validate_information_measure': ( 'base.html#diffpassmixin.validate_information_measure', + 'diffpass/base.py'), 'diffpass.base.DiffPASSMixin.validate_inputs': ( 'base.html#diffpassmixin.validate_inputs', 'diffpass/base.py'), + 'diffpass.base.DiffPASSMixin.validate_permutation_cfg': ( 'base.html#diffpassmixin.validate_permutation_cfg', + 'diffpass/base.py'), + 'diffpass.base.DiffPASSMixin.validate_reciprocal_best_hits_cfg': ( 'base.html#diffpassmixin.validate_reciprocal_best_hits_cfg', + 'diffpass/base.py'), + 'diffpass.base.DiffPASSMixin.validate_similarities_cfg': ( 'base.html#diffpassmixin.validate_similarities_cfg', + 'diffpass/base.py'), + 'diffpass.base.DiffPASSMixin.validate_similarity_kind': ( 'base.html#diffpassmixin.validate_similarity_kind', + 'diffpass/base.py'), 'diffpass.base.EnsembleMixin': ('base.html#ensemblemixin', 'diffpass/base.py'), 'diffpass.base.EnsembleMixin._reshape_ensemble_param': ( 'base.html#ensemblemixin._reshape_ensemble_param', 'diffpass/base.py'), diff --git a/diffpass/base.py b/diffpass/base.py index 3ee6ccc..20ef1cc 100644 --- a/diffpass/base.py +++ b/diffpass/base.py @@ -21,7 +21,24 @@ # %% ../nbs/base.ipynb 3 class DiffPASSMixin: + allowed_permutation_cfg_keys = { + "tau", + "n_iter", + "noise", + "noise_factor", + "noise_std", + } + allowed_information_measures = {"MI", "TwoBodyEntropy"} + allowed_similarity_kinds = {"Hamming", "Blosum62"} + allowed_similarities_cfg_keys = { + "Hamming": {"use_dot", "p"}, + "Blosum62": {"aa_to_int", "gaps_as_stars"}, + } + allowed_reciprocal_best_hits_cfg_keys = {"tau"} + group_sizes: Iterable[int] + information_measure: str + similarity_kind: str @staticmethod def reduce_num_tokens(x: torch.Tensor) -> torch.Tensor: @@ -32,6 +49,42 @@ def reduce_num_tokens(x: torch.Tensor) -> torch.Tensor: return x[..., used_tokens] + def validate_permutation_cfg(self, permutation_cfg: dict) -> None: + if not set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys): + raise ValueError( + f"Invalid keys in `permutation_cfg`: {set(permutation_cfg) - self.allowed_permutation_cfg_keys}" + ) + + def validate_information_measure(self, information_measure: str) -> None: + if information_measure not in self.allowed_information_measures: + raise ValueError( + f"Invalid information measure: {self.information_measure}. " + f"Allowed values are: {self.allowed_information_measures}" + ) + + def validate_similarity_kind(self, similarity_kind: str) -> None: + if similarity_kind not in self.allowed_similarity_kinds: + raise ValueError( + f"Invalid similarity kind: {self.similarity_kind}. " + f"Allowed values are: {self.allowed_similarity_kinds}" + ) + + def validate_similarities_cfg(self, similarities_cfg: dict) -> None: + if not set(similarities_cfg).issubset( + self.allowed_similarities_cfg_keys[self.similarity_kind] + ): + raise ValueError( + f"Invalid keys in `similarities_cfg`: {set(similarities_cfg) - self.allowed_similarities_cfg_keys[self.similarity_kind]}" + ) + + def validate_reciprocal_best_hits_cfg(self, reciprocal_best_hits_cfg: dict) -> None: + if not set(reciprocal_best_hits_cfg).issubset( + self.allowed_reciprocal_best_hits_cfg_keys + ): + raise ValueError( + f"Invalid keys in `reciprocal_best_hits_cfg`: {set(reciprocal_best_hits_cfg) - self.allowed_reciprocal_best_hits_cfg_keys}" + ) + def validate_inputs( self, x: torch.Tensor, y: torch.Tensor, check_same_alphabet_size: bool = False ) -> None: diff --git a/diffpass/model.py b/diffpass/model.py index 09ba272..1658176 100644 --- a/diffpass/model.py +++ b/diffpass/model.py @@ -303,6 +303,7 @@ def __init__( self, *, group_sizes: Optional[Iterable[int]] = None, + use_scoredist: bool = False, aa_to_int: Optional[dict[str, int]] = None, gaps_as_stars: bool = True, ) -> None: @@ -310,6 +311,7 @@ def __init__( self.group_sizes = ( tuple(s for s in group_sizes) if group_sizes is not None else None ) + self.use_scoredist = use_scoredist self.aa_to_int = aa_to_int self.gaps_as_stars = gaps_as_stars @@ -317,6 +319,7 @@ def __init__( aa_to_int=self.aa_to_int, gaps_as_stars=self.gaps_as_stars ) self.register_buffer("subs_mat", blosum62_data.mat) + self.expected_value = blosum62_data.expected_value self._group_slices = _consecutive_slices_from_sizes(self.group_sizes) @@ -328,7 +331,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for sl in self._group_slices: out[..., sl, sl].copy_( smooth_substitution_matrix_similarities( - x[..., sl, :, :], subs_mat=self.subs_mat + x[..., sl, :, :], + subs_mat=self.subs_mat, + expected_value=self.expected_value, + use_scoredist=self.use_scoredist, ) ) diff --git a/diffpass/train.py b/diffpass/train.py index 112111a..3595a5b 100644 --- a/diffpass/train.py +++ b/diffpass/train.py @@ -88,17 +88,6 @@ class DiffPASSResults: # %% ../nbs/train.ipynb 6 class InformationAndReciprocalBestHits(Module, EnsembleMixin, DiffPASSMixin): - allowed_permutation_cfg_keys = { - "tau", - "n_iter", - "noise", - "noise_factor", - "noise_std", - } - allowed_information_measures = {"MI", "TwoBodyEntropy"} - allowed_hamming_similarities_cfg_keys = {"use_dot", "p"} - allowed_reciprocal_best_hits_cfg_keys = {"tau"} - def __init__( self, group_sizes: Iterable[int], @@ -106,7 +95,8 @@ def __init__( loss_weights: Optional[dict[str, Union[float, torch.Tensor]]] = None, permutation_cfg: Optional[dict[str, Any]] = None, information_measure: Literal["MI", "TwoBodyEntropy"] = "TwoBodyEntropy", - hamming_similarities_cfg: Optional[dict[str, Any]] = None, + similarity_kind: Literal["Hamming", "Blosum62"] = "Hamming", + similarities_cfg: Optional[dict[str, Any]] = None, reciprocal_best_hits_cfg: Optional[dict[str, Any]] = None, inter_group_loss_score_fn: Optional[callable] = None, ): @@ -116,7 +106,8 @@ def __init__( self.loss_weights = loss_weights self.permutation_cfg = permutation_cfg self.information_measure = information_measure - self.hamming_similarities_cfg = hamming_similarities_cfg + self.similarity_kind = similarity_kind + self.similarities_cfg = similarities_cfg self.reciprocal_best_hits_cfg = reciprocal_best_hits_cfg self.inter_group_loss_score_fn = inter_group_loss_score_fn @@ -126,7 +117,7 @@ def __init__( if permutation_cfg is None: permutation_cfg = {} else: - assert set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys) + self.validate_permutation_cfg(permutation_cfg) permutation_cfg = deepcopy(permutation_cfg) _dim_in_ensemble = self._adjust_cfg_and_ensemble_shape( ensemble_shape=ensemble_shape, @@ -136,24 +127,21 @@ def __init__( ) self.effective_permutation_cfg_ = permutation_cfg - assert self.information_measure in self.allowed_information_measures + self.validate_information_measure(information_measure) self.loss_weights_keys = {self.information_measure, "ReciprocalBestHits"} - if hamming_similarities_cfg is None: - hamming_similarities_cfg = {} + self.validate_similarity_kind(similarity_kind) + if similarities_cfg is None: + similarities_cfg = {} else: - assert set(hamming_similarities_cfg).issubset( - self.allowed_hamming_similarities_cfg_keys - ) - hamming_similarities_cfg = deepcopy(hamming_similarities_cfg) - self.effective_hamming_similarities_cfg_ = hamming_similarities_cfg + self.validate_similarities_cfg(similarities_cfg) + similarities_cfg = deepcopy(similarities_cfg) + self.effective_similarities_cfg_ = similarities_cfg if reciprocal_best_hits_cfg is None: reciprocal_best_hits_cfg = {} else: - assert set(reciprocal_best_hits_cfg).issubset( - self.allowed_reciprocal_best_hits_cfg_keys - ) + self.validate_reciprocal_best_hits_cfg(reciprocal_best_hits_cfg) reciprocal_best_hits_cfg = deepcopy(reciprocal_best_hits_cfg) _dim_in_ensemble = self._adjust_cfg_and_ensemble_shape( ensemble_shape=ensemble_shape, @@ -190,16 +178,14 @@ def __init__( self.information_loss = TwoBodyEntropyLoss() elif self.information_measure == "MI": self.information_loss = MILoss() - else: - # FIXME Redundant check - raise ValueError( - f"``information_measure`` must be one of {self.allowed_information_measures}." - ) - self.information_loss.register_buffer( "weight", loss_weights[self.information_measure] ) - self.hamming_similarities = HammingSimilarities(**hamming_similarities_cfg) + + if similarity_kind == "Blosum62": + self.similarities = Blosum62Similarities(**similarities_cfg) + elif similarity_kind == "Hamming": + self.similarities = HammingSimilarities(**similarities_cfg) self.reciprocal_best_hits = ReciprocalBestHits( group_sizes=self.group_sizes, ensemble_shape=ensemble_shape, @@ -301,9 +287,9 @@ def _adjust_loss_weights_and_ensemble_shape( def _precompute_rbh(self, x: torch.Tensor, y: torch.Tensor) -> None: # Temporarily switch to hard RBH self.reciprocal_best_hits.hard_() - similarities_x = self.hamming_similarities(x) + similarities_x = self.similarities(x) self._rbh_hard_x = self.reciprocal_best_hits(similarities_x) - similarities_y = self.hamming_similarities(y) + similarities_y = self.similarities(y) self._rbh_hard_y = self.reciprocal_best_hits(similarities_y) # Revert to soft (default) RBH @@ -333,7 +319,7 @@ def forward( if mode == "soft": if x_perm_hard is not None: x_perm = (x_perm_hard - x_perm).detach() + x_perm - similarities_x = self.hamming_similarities(x_perm) + similarities_x = self.similarities(x_perm) rbh_x = self.reciprocal_best_hits(similarities_x) else: rbh_x = apply_hard_permutation_batch_to_similarity( @@ -366,7 +352,7 @@ def fit( optimizer_name: Optional[str] = "SGD", optimizer_kwargs: Optional[dict[str, Any]] = None, mean_centering: bool = True, - hamming_gradient_bypass: bool = False, + similarity_gradient_bypass: bool = False, ) -> DiffPASSResults: # Validate inputs self.validate_inputs(x, y) @@ -427,7 +413,7 @@ def fit( epoch_results = self(x, y) loss_info = epoch_results["loss_info"] loss_rbh = epoch_results["loss_rbh"] - if hamming_gradient_bypass: + if similarity_gradient_bypass: x_perm_hard = epoch_results["x_perm"] perms = epoch_results["perms"] results.log_alphas.append( diff --git a/nbs/_example_prokaryotic.ipynb b/nbs/_example_prokaryotic.ipynb index 7954a34..bc9ab59 100644 --- a/nbs/_example_prokaryotic.ipynb +++ b/nbs/_example_prokaryotic.ipynb @@ -168,14 +168,15 @@ "information_measure = \"TwoBodyEntropy\"\n", "\n", "# Settings affecting the reciprocal best hits part of the loss\n", - "hamming_similarities_cfg = {\n", + "similarity_kind = \"Hamming\"\n", + "similarities_cfg = {\n", " \"use_dot\": False,\n", " \"p\": 1\n", "}\n", "reciprocal_best_hits_cfg = {\n", " \"tau\": torch.tensor(1e-1)\n", "}\n", - "inter_group_loss_score_fn = torch.dot #torch.nn.CosineSimilarity(dim=-1)\n", + "inter_group_loss_score_fn = torch.nn.CosineSimilarity(dim=-1)\n", "\n", "# Loss weights\n", "loss_weights = {\n", @@ -202,7 +203,8 @@ " loss_weights=loss_weights,\n", " permutation_cfg=permutation_cfg,\n", " information_measure=information_measure,\n", - " hamming_similarities_cfg=hamming_similarities_cfg,\n", + " similarity_kind=similarity_kind,\n", + " similarities_cfg=similarities_cfg,\n", " reciprocal_best_hits_cfg=reciprocal_best_hits_cfg,\n", " inter_group_loss_score_fn=inter_group_loss_score_fn,\n", ")\n", @@ -231,7 +233,7 @@ " \"optimizer_name\": \"SGD\",\n", " \"optimizer_kwargs\": {\"lr\": 1e-1, \"weight_decay\": 0.},\n", " \"mean_centering\": True,\n", - " \"hamming_gradient_bypass\": False\n", + " \"similarity_gradient_bypass\": False\n", "}" ] }, @@ -305,6 +307,13 @@ "plt.title(\"Total, soft\")\n", "plt.show()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/nbs/base.ipynb b/nbs/base.ipynb index 648c0cd..d1bbc40 100644 --- a/nbs/base.ipynb +++ b/nbs/base.ipynb @@ -51,7 +51,24 @@ "#| export\n", "\n", "class DiffPASSMixin:\n", + " allowed_permutation_cfg_keys = {\n", + " \"tau\",\n", + " \"n_iter\",\n", + " \"noise\",\n", + " \"noise_factor\",\n", + " \"noise_std\",\n", + " }\n", + " allowed_information_measures = {\"MI\", \"TwoBodyEntropy\"}\n", + " allowed_similarity_kinds = {\"Hamming\", \"Blosum62\"}\n", + " allowed_similarities_cfg_keys = {\n", + " \"Hamming\": {\"use_dot\", \"p\"},\n", + " \"Blosum62\": {\"aa_to_int\", \"gaps_as_stars\"},\n", + " }\n", + " allowed_reciprocal_best_hits_cfg_keys = {\"tau\"}\n", + "\n", " group_sizes: Iterable[int]\n", + " information_measure: str\n", + " similarity_kind: str\n", "\n", " @staticmethod\n", " def reduce_num_tokens(x: torch.Tensor) -> torch.Tensor:\n", @@ -62,6 +79,42 @@ "\n", " return x[..., used_tokens]\n", "\n", + " def validate_permutation_cfg(self, permutation_cfg: dict) -> None:\n", + " if not set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys):\n", + " raise ValueError(\n", + " f\"Invalid keys in `permutation_cfg`: {set(permutation_cfg) - self.allowed_permutation_cfg_keys}\"\n", + " )\n", + "\n", + " def validate_information_measure(self, information_measure: str) -> None:\n", + " if information_measure not in self.allowed_information_measures:\n", + " raise ValueError(\n", + " f\"Invalid information measure: {self.information_measure}. \"\n", + " f\"Allowed values are: {self.allowed_information_measures}\"\n", + " )\n", + "\n", + " def validate_similarity_kind(self, similarity_kind: str) -> None:\n", + " if similarity_kind not in self.allowed_similarity_kinds:\n", + " raise ValueError(\n", + " f\"Invalid similarity kind: {self.similarity_kind}. \"\n", + " f\"Allowed values are: {self.allowed_similarity_kinds}\"\n", + " )\n", + "\n", + " def validate_similarities_cfg(self, similarities_cfg: dict) -> None:\n", + " if not set(similarities_cfg).issubset(\n", + " self.allowed_similarities_cfg_keys[self.similarity_kind]\n", + " ):\n", + " raise ValueError(\n", + " f\"Invalid keys in `similarities_cfg`: {set(similarities_cfg) - self.allowed_similarities_cfg_keys[self.similarity_kind]}\"\n", + " )\n", + "\n", + " def validate_reciprocal_best_hits_cfg(self, reciprocal_best_hits_cfg: dict) -> None:\n", + " if not set(reciprocal_best_hits_cfg).issubset(\n", + " self.allowed_reciprocal_best_hits_cfg_keys\n", + " ):\n", + " raise ValueError(\n", + " f\"Invalid keys in `reciprocal_best_hits_cfg`: {set(reciprocal_best_hits_cfg) - self.allowed_reciprocal_best_hits_cfg_keys}\"\n", + " )\n", + "\n", " def validate_inputs(\n", " self, x: torch.Tensor, y: torch.Tensor, check_same_alphabet_size: bool = False\n", " ) -> None:\n", diff --git a/nbs/model.ipynb b/nbs/model.ipynb index 0fced33..6731b26 100644 --- a/nbs/model.ipynb +++ b/nbs/model.ipynb @@ -481,6 +481,7 @@ " self,\n", " *,\n", " group_sizes: Optional[Iterable[int]] = None,\n", + " use_scoredist: bool = False,\n", " aa_to_int: Optional[dict[str, int]] = None,\n", " gaps_as_stars: bool = True,\n", " ) -> None:\n", @@ -488,6 +489,7 @@ " self.group_sizes = (\n", " tuple(s for s in group_sizes) if group_sizes is not None else None\n", " )\n", + " self.use_scoredist = use_scoredist\n", " self.aa_to_int = aa_to_int\n", " self.gaps_as_stars = gaps_as_stars\n", "\n", @@ -495,6 +497,7 @@ " aa_to_int=self.aa_to_int, gaps_as_stars=self.gaps_as_stars\n", " )\n", " self.register_buffer(\"subs_mat\", blosum62_data.mat)\n", + " self.expected_value = blosum62_data.expected_value\n", "\n", " self._group_slices = _consecutive_slices_from_sizes(self.group_sizes)\n", "\n", @@ -506,7 +509,10 @@ " for sl in self._group_slices:\n", " out[..., sl, sl].copy_(\n", " smooth_substitution_matrix_similarities(\n", - " x[..., sl, :, :], subs_mat=self.subs_mat\n", + " x[..., sl, :, :],\n", + " subs_mat=self.subs_mat,\n", + " expected_value=self.expected_value,\n", + " use_scoredist=self.use_scoredist,\n", " )\n", " )\n", "\n", diff --git a/nbs/train.ipynb b/nbs/train.ipynb index 26396dc..0804d16 100644 --- a/nbs/train.ipynb +++ b/nbs/train.ipynb @@ -6,7 +6,7 @@ "source": [ "# train\n", "\n", - "> Perform optimization using DiffPASS" + "> Perform optimization using DiffPASS models" ] }, { @@ -160,17 +160,6 @@ "#| export\n", "\n", "class InformationAndReciprocalBestHits(Module, EnsembleMixin, DiffPASSMixin):\n", - " allowed_permutation_cfg_keys = {\n", - " \"tau\",\n", - " \"n_iter\",\n", - " \"noise\",\n", - " \"noise_factor\",\n", - " \"noise_std\",\n", - " }\n", - " allowed_information_measures = {\"MI\", \"TwoBodyEntropy\"}\n", - " allowed_hamming_similarities_cfg_keys = {\"use_dot\", \"p\"}\n", - " allowed_reciprocal_best_hits_cfg_keys = {\"tau\"}\n", - "\n", " def __init__(\n", " self,\n", " group_sizes: Iterable[int],\n", @@ -178,7 +167,8 @@ " loss_weights: Optional[dict[str, Union[float, torch.Tensor]]] = None,\n", " permutation_cfg: Optional[dict[str, Any]] = None,\n", " information_measure: Literal[\"MI\", \"TwoBodyEntropy\"] = \"TwoBodyEntropy\",\n", - " hamming_similarities_cfg: Optional[dict[str, Any]] = None,\n", + " similarity_kind: Literal[\"Hamming\", \"Blosum62\"] = \"Hamming\",\n", + " similarities_cfg: Optional[dict[str, Any]] = None,\n", " reciprocal_best_hits_cfg: Optional[dict[str, Any]] = None,\n", " inter_group_loss_score_fn: Optional[callable] = None,\n", " ):\n", @@ -188,7 +178,8 @@ " self.loss_weights = loss_weights\n", " self.permutation_cfg = permutation_cfg\n", " self.information_measure = information_measure\n", - " self.hamming_similarities_cfg = hamming_similarities_cfg\n", + " self.similarity_kind = similarity_kind\n", + " self.similarities_cfg = similarities_cfg\n", " self.reciprocal_best_hits_cfg = reciprocal_best_hits_cfg\n", " self.inter_group_loss_score_fn = inter_group_loss_score_fn\n", "\n", @@ -198,7 +189,7 @@ " if permutation_cfg is None:\n", " permutation_cfg = {}\n", " else:\n", - " assert set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys)\n", + " self.validate_permutation_cfg(permutation_cfg)\n", " permutation_cfg = deepcopy(permutation_cfg)\n", " _dim_in_ensemble = self._adjust_cfg_and_ensemble_shape(\n", " ensemble_shape=ensemble_shape,\n", @@ -208,24 +199,21 @@ " )\n", " self.effective_permutation_cfg_ = permutation_cfg\n", "\n", - " assert self.information_measure in self.allowed_information_measures\n", + " self.validate_information_measure(information_measure)\n", " self.loss_weights_keys = {self.information_measure, \"ReciprocalBestHits\"}\n", "\n", - " if hamming_similarities_cfg is None:\n", - " hamming_similarities_cfg = {}\n", + " self.validate_similarity_kind(similarity_kind)\n", + " if similarities_cfg is None:\n", + " similarities_cfg = {}\n", " else:\n", - " assert set(hamming_similarities_cfg).issubset(\n", - " self.allowed_hamming_similarities_cfg_keys\n", - " )\n", - " hamming_similarities_cfg = deepcopy(hamming_similarities_cfg)\n", - " self.effective_hamming_similarities_cfg_ = hamming_similarities_cfg\n", + " self.validate_similarities_cfg(similarities_cfg)\n", + " similarities_cfg = deepcopy(similarities_cfg)\n", + " self.effective_similarities_cfg_ = similarities_cfg\n", "\n", " if reciprocal_best_hits_cfg is None:\n", " reciprocal_best_hits_cfg = {}\n", " else:\n", - " assert set(reciprocal_best_hits_cfg).issubset(\n", - " self.allowed_reciprocal_best_hits_cfg_keys\n", - " )\n", + " self.validate_reciprocal_best_hits_cfg(reciprocal_best_hits_cfg)\n", " reciprocal_best_hits_cfg = deepcopy(reciprocal_best_hits_cfg)\n", " _dim_in_ensemble = self._adjust_cfg_and_ensemble_shape(\n", " ensemble_shape=ensemble_shape,\n", @@ -262,16 +250,14 @@ " self.information_loss = TwoBodyEntropyLoss()\n", " elif self.information_measure == \"MI\":\n", " self.information_loss = MILoss()\n", - " else:\n", - " # FIXME Redundant check\n", - " raise ValueError(\n", - " f\"``information_measure`` must be one of {self.allowed_information_measures}.\"\n", - " )\n", - "\n", " self.information_loss.register_buffer(\n", " \"weight\", loss_weights[self.information_measure]\n", " )\n", - " self.hamming_similarities = HammingSimilarities(**hamming_similarities_cfg)\n", + "\n", + " if similarity_kind == \"Blosum62\":\n", + " self.similarities = Blosum62Similarities(**similarities_cfg)\n", + " elif similarity_kind == \"Hamming\":\n", + " self.similarities = HammingSimilarities(**similarities_cfg)\n", " self.reciprocal_best_hits = ReciprocalBestHits(\n", " group_sizes=self.group_sizes,\n", " ensemble_shape=ensemble_shape,\n", @@ -373,9 +359,9 @@ " def _precompute_rbh(self, x: torch.Tensor, y: torch.Tensor) -> None:\n", " # Temporarily switch to hard RBH\n", " self.reciprocal_best_hits.hard_()\n", - " similarities_x = self.hamming_similarities(x)\n", + " similarities_x = self.similarities(x)\n", " self._rbh_hard_x = self.reciprocal_best_hits(similarities_x)\n", - " similarities_y = self.hamming_similarities(y)\n", + " similarities_y = self.similarities(y)\n", " self._rbh_hard_y = self.reciprocal_best_hits(similarities_y)\n", "\n", " # Revert to soft (default) RBH\n", @@ -405,7 +391,7 @@ " if mode == \"soft\":\n", " if x_perm_hard is not None:\n", " x_perm = (x_perm_hard - x_perm).detach() + x_perm\n", - " similarities_x = self.hamming_similarities(x_perm)\n", + " similarities_x = self.similarities(x_perm)\n", " rbh_x = self.reciprocal_best_hits(similarities_x)\n", " else:\n", " rbh_x = apply_hard_permutation_batch_to_similarity(\n", @@ -438,7 +424,7 @@ " optimizer_name: Optional[str] = \"SGD\",\n", " optimizer_kwargs: Optional[dict[str, Any]] = None,\n", " mean_centering: bool = True,\n", - " hamming_gradient_bypass: bool = False,\n", + " similarity_gradient_bypass: bool = False,\n", " ) -> DiffPASSResults:\n", " # Validate inputs\n", " self.validate_inputs(x, y)\n", @@ -499,7 +485,7 @@ " epoch_results = self(x, y)\n", " loss_info = epoch_results[\"loss_info\"]\n", " loss_rbh = epoch_results[\"loss_rbh\"]\n", - " if hamming_gradient_bypass:\n", + " if similarity_gradient_bypass:\n", " x_perm_hard = epoch_results[\"x_perm\"]\n", " perms = epoch_results[\"perms\"]\n", " results.log_alphas.append(\n",