From 2056ddadba5d4cbd63f1946f5f8d8c0ed67b652a Mon Sep 17 00:00:00 2001
From: Umberto Lupo <umberto.lupo@epfl.ch>
Date: Sat, 14 Oct 2023 00:47:43 +0200
Subject: [PATCH] Generalize InformationAndReciprocalBestHits to include
 Blosum62, generalize Blosum62

---
 diffpass/_modidx.py            | 10 ++++++
 diffpass/base.py               | 53 +++++++++++++++++++++++++++++
 diffpass/model.py              |  8 ++++-
 diffpass/train.py              | 60 +++++++++++++-------------------
 nbs/_example_prokaryotic.ipynb | 17 +++++++---
 nbs/base.ipynb                 | 53 +++++++++++++++++++++++++++++
 nbs/model.ipynb                |  8 ++++-
 nbs/train.ipynb                | 62 +++++++++++++---------------------
 8 files changed, 190 insertions(+), 81 deletions(-)

diff --git a/diffpass/_modidx.py b/diffpass/_modidx.py
index 12aedb9..8e4b84c 100644
--- a/diffpass/_modidx.py
+++ b/diffpass/_modidx.py
@@ -8,8 +8,18 @@
   'syms': { 'diffpass.base': { 'diffpass.base.DiffPASSMixin': ('base.html#diffpassmixin', 'diffpass/base.py'),
                                'diffpass.base.DiffPASSMixin.reduce_num_tokens': ( 'base.html#diffpassmixin.reduce_num_tokens',
                                                                                   'diffpass/base.py'),
+                               'diffpass.base.DiffPASSMixin.validate_information_measure': ( 'base.html#diffpassmixin.validate_information_measure',
+                                                                                             'diffpass/base.py'),
                                'diffpass.base.DiffPASSMixin.validate_inputs': ( 'base.html#diffpassmixin.validate_inputs',
                                                                                 'diffpass/base.py'),
+                               'diffpass.base.DiffPASSMixin.validate_permutation_cfg': ( 'base.html#diffpassmixin.validate_permutation_cfg',
+                                                                                         'diffpass/base.py'),
+                               'diffpass.base.DiffPASSMixin.validate_reciprocal_best_hits_cfg': ( 'base.html#diffpassmixin.validate_reciprocal_best_hits_cfg',
+                                                                                                  'diffpass/base.py'),
+                               'diffpass.base.DiffPASSMixin.validate_similarities_cfg': ( 'base.html#diffpassmixin.validate_similarities_cfg',
+                                                                                          'diffpass/base.py'),
+                               'diffpass.base.DiffPASSMixin.validate_similarity_kind': ( 'base.html#diffpassmixin.validate_similarity_kind',
+                                                                                         'diffpass/base.py'),
                                'diffpass.base.EnsembleMixin': ('base.html#ensemblemixin', 'diffpass/base.py'),
                                'diffpass.base.EnsembleMixin._reshape_ensemble_param': ( 'base.html#ensemblemixin._reshape_ensemble_param',
                                                                                         'diffpass/base.py'),
diff --git a/diffpass/base.py b/diffpass/base.py
index 3ee6ccc..20ef1cc 100644
--- a/diffpass/base.py
+++ b/diffpass/base.py
@@ -21,7 +21,24 @@
 
 # %% ../nbs/base.ipynb 3
 class DiffPASSMixin:
+    allowed_permutation_cfg_keys = {
+        "tau",
+        "n_iter",
+        "noise",
+        "noise_factor",
+        "noise_std",
+    }
+    allowed_information_measures = {"MI", "TwoBodyEntropy"}
+    allowed_similarity_kinds = {"Hamming", "Blosum62"}
+    allowed_similarities_cfg_keys = {
+        "Hamming": {"use_dot", "p"},
+        "Blosum62": {"aa_to_int", "gaps_as_stars", "use_scoredist"},
+    }
+    allowed_reciprocal_best_hits_cfg_keys = {"tau"}
+
     group_sizes: Iterable[int]
+    information_measure: str
+    similarity_kind: str
 
     @staticmethod
     def reduce_num_tokens(x: torch.Tensor) -> torch.Tensor:
@@ -32,6 +49,42 @@ def reduce_num_tokens(x: torch.Tensor) -> torch.Tensor:
 
         return x[..., used_tokens]
 
+    def validate_permutation_cfg(self, permutation_cfg: dict) -> None:
+        if not set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys):
+            raise ValueError(
+                f"Invalid keys in `permutation_cfg`: {set(permutation_cfg) - self.allowed_permutation_cfg_keys}"
+            )
+
+    def validate_information_measure(self, information_measure: str) -> None:
+        if information_measure not in self.allowed_information_measures:
+            raise ValueError(
+                f"Invalid information measure: {information_measure}. "
+                f"Allowed values are: {self.allowed_information_measures}"
+            )
+
+    def validate_similarity_kind(self, similarity_kind: str) -> None:
+        if similarity_kind not in self.allowed_similarity_kinds:
+            raise ValueError(
+                f"Invalid similarity kind: {similarity_kind}. "
+                f"Allowed values are: {self.allowed_similarity_kinds}"
+            )
+
+    def validate_similarities_cfg(self, similarities_cfg: dict) -> None:
+        if not set(similarities_cfg).issubset(
+            self.allowed_similarities_cfg_keys[self.similarity_kind]
+        ):
+            raise ValueError(
+                f"Invalid keys in `similarities_cfg`: {set(similarities_cfg) - self.allowed_similarities_cfg_keys[self.similarity_kind]}"
+            )
+
+    def validate_reciprocal_best_hits_cfg(self, reciprocal_best_hits_cfg: dict) -> None:
+        if not set(reciprocal_best_hits_cfg).issubset(
+            self.allowed_reciprocal_best_hits_cfg_keys
+        ):
+            raise ValueError(
+                f"Invalid keys in `reciprocal_best_hits_cfg`: {set(reciprocal_best_hits_cfg) - self.allowed_reciprocal_best_hits_cfg_keys}"
+            )
+
     def validate_inputs(
         self, x: torch.Tensor, y: torch.Tensor, check_same_alphabet_size: bool = False
     ) -> None:
diff --git a/diffpass/model.py b/diffpass/model.py
index 09ba272..1658176 100644
--- a/diffpass/model.py
+++ b/diffpass/model.py
@@ -303,6 +303,7 @@ def __init__(
         self,
         *,
         group_sizes: Optional[Iterable[int]] = None,
+        use_scoredist: bool = False,
         aa_to_int: Optional[dict[str, int]] = None,
         gaps_as_stars: bool = True,
     ) -> None:
@@ -310,6 +311,7 @@ def __init__(
         self.group_sizes = (
             tuple(s for s in group_sizes) if group_sizes is not None else None
         )
+        self.use_scoredist = use_scoredist
         self.aa_to_int = aa_to_int
         self.gaps_as_stars = gaps_as_stars
 
@@ -317,6 +319,7 @@ def __init__(
             aa_to_int=self.aa_to_int, gaps_as_stars=self.gaps_as_stars
         )
         self.register_buffer("subs_mat", blosum62_data.mat)
+        self.expected_value = blosum62_data.expected_value
 
         self._group_slices = _consecutive_slices_from_sizes(self.group_sizes)
 
@@ -328,7 +331,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         for sl in self._group_slices:
             out[..., sl, sl].copy_(
                 smooth_substitution_matrix_similarities(
-                    x[..., sl, :, :], subs_mat=self.subs_mat
+                    x[..., sl, :, :],
+                    subs_mat=self.subs_mat,
+                    expected_value=self.expected_value,
+                    use_scoredist=self.use_scoredist,
                 )
             )
 
diff --git a/diffpass/train.py b/diffpass/train.py
index 112111a..3595a5b 100644
--- a/diffpass/train.py
+++ b/diffpass/train.py
@@ -88,17 +88,6 @@ class DiffPASSResults:
 
 # %% ../nbs/train.ipynb 6
 class InformationAndReciprocalBestHits(Module, EnsembleMixin, DiffPASSMixin):
-    allowed_permutation_cfg_keys = {
-        "tau",
-        "n_iter",
-        "noise",
-        "noise_factor",
-        "noise_std",
-    }
-    allowed_information_measures = {"MI", "TwoBodyEntropy"}
-    allowed_hamming_similarities_cfg_keys = {"use_dot", "p"}
-    allowed_reciprocal_best_hits_cfg_keys = {"tau"}
-
     def __init__(
         self,
         group_sizes: Iterable[int],
@@ -106,7 +95,8 @@ def __init__(
         loss_weights: Optional[dict[str, Union[float, torch.Tensor]]] = None,
         permutation_cfg: Optional[dict[str, Any]] = None,
         information_measure: Literal["MI", "TwoBodyEntropy"] = "TwoBodyEntropy",
-        hamming_similarities_cfg: Optional[dict[str, Any]] = None,
+        similarity_kind: Literal["Hamming", "Blosum62"] = "Hamming",
+        similarities_cfg: Optional[dict[str, Any]] = None,
         reciprocal_best_hits_cfg: Optional[dict[str, Any]] = None,
         inter_group_loss_score_fn: Optional[callable] = None,
     ):
@@ -116,7 +106,8 @@ def __init__(
         self.loss_weights = loss_weights
         self.permutation_cfg = permutation_cfg
         self.information_measure = information_measure
-        self.hamming_similarities_cfg = hamming_similarities_cfg
+        self.similarity_kind = similarity_kind
+        self.similarities_cfg = similarities_cfg
         self.reciprocal_best_hits_cfg = reciprocal_best_hits_cfg
         self.inter_group_loss_score_fn = inter_group_loss_score_fn
 
@@ -126,7 +117,7 @@ def __init__(
         if permutation_cfg is None:
             permutation_cfg = {}
         else:
-            assert set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys)
+            self.validate_permutation_cfg(permutation_cfg)
             permutation_cfg = deepcopy(permutation_cfg)
             _dim_in_ensemble = self._adjust_cfg_and_ensemble_shape(
                 ensemble_shape=ensemble_shape,
@@ -136,24 +127,21 @@ def __init__(
             )
         self.effective_permutation_cfg_ = permutation_cfg
 
-        assert self.information_measure in self.allowed_information_measures
+        self.validate_information_measure(information_measure)
         self.loss_weights_keys = {self.information_measure, "ReciprocalBestHits"}
 
-        if hamming_similarities_cfg is None:
-            hamming_similarities_cfg = {}
+        self.validate_similarity_kind(similarity_kind)
+        if similarities_cfg is None:
+            similarities_cfg = {}
         else:
-            assert set(hamming_similarities_cfg).issubset(
-                self.allowed_hamming_similarities_cfg_keys
-            )
-            hamming_similarities_cfg = deepcopy(hamming_similarities_cfg)
-        self.effective_hamming_similarities_cfg_ = hamming_similarities_cfg
+            self.validate_similarities_cfg(similarities_cfg)
+            similarities_cfg = deepcopy(similarities_cfg)
+        self.effective_similarities_cfg_ = similarities_cfg
 
         if reciprocal_best_hits_cfg is None:
             reciprocal_best_hits_cfg = {}
         else:
-            assert set(reciprocal_best_hits_cfg).issubset(
-                self.allowed_reciprocal_best_hits_cfg_keys
-            )
+            self.validate_reciprocal_best_hits_cfg(reciprocal_best_hits_cfg)
             reciprocal_best_hits_cfg = deepcopy(reciprocal_best_hits_cfg)
             _dim_in_ensemble = self._adjust_cfg_and_ensemble_shape(
                 ensemble_shape=ensemble_shape,
@@ -190,16 +178,14 @@ def __init__(
             self.information_loss = TwoBodyEntropyLoss()
         elif self.information_measure == "MI":
             self.information_loss = MILoss()
-        else:
-            # FIXME Redundant check
-            raise ValueError(
-                f"``information_measure`` must be one of {self.allowed_information_measures}."
-            )
-
         self.information_loss.register_buffer(
             "weight", loss_weights[self.information_measure]
         )
-        self.hamming_similarities = HammingSimilarities(**hamming_similarities_cfg)
+
+        if similarity_kind == "Blosum62":
+            self.similarities = Blosum62Similarities(**similarities_cfg)
+        elif similarity_kind == "Hamming":
+            self.similarities = HammingSimilarities(**similarities_cfg)
         self.reciprocal_best_hits = ReciprocalBestHits(
             group_sizes=self.group_sizes,
             ensemble_shape=ensemble_shape,
@@ -301,9 +287,9 @@ def _adjust_loss_weights_and_ensemble_shape(
     def _precompute_rbh(self, x: torch.Tensor, y: torch.Tensor) -> None:
         # Temporarily switch to hard RBH
         self.reciprocal_best_hits.hard_()
-        similarities_x = self.hamming_similarities(x)
+        similarities_x = self.similarities(x)
         self._rbh_hard_x = self.reciprocal_best_hits(similarities_x)
-        similarities_y = self.hamming_similarities(y)
+        similarities_y = self.similarities(y)
         self._rbh_hard_y = self.reciprocal_best_hits(similarities_y)
 
         # Revert to soft (default) RBH
@@ -333,7 +319,7 @@ def forward(
         if mode == "soft":
             if x_perm_hard is not None:
                 x_perm = (x_perm_hard - x_perm).detach() + x_perm
-            similarities_x = self.hamming_similarities(x_perm)
+            similarities_x = self.similarities(x_perm)
             rbh_x = self.reciprocal_best_hits(similarities_x)
         else:
             rbh_x = apply_hard_permutation_batch_to_similarity(
@@ -366,7 +352,7 @@ def fit(
         optimizer_name: Optional[str] = "SGD",
         optimizer_kwargs: Optional[dict[str, Any]] = None,
         mean_centering: bool = True,
-        hamming_gradient_bypass: bool = False,
+        similarity_gradient_bypass: bool = False,
     ) -> DiffPASSResults:
         # Validate inputs
         self.validate_inputs(x, y)
@@ -427,7 +413,7 @@ def fit(
                     epoch_results = self(x, y)
                     loss_info = epoch_results["loss_info"]
                     loss_rbh = epoch_results["loss_rbh"]
-                    if hamming_gradient_bypass:
+                    if similarity_gradient_bypass:
                         x_perm_hard = epoch_results["x_perm"]
                     perms = epoch_results["perms"]
                     results.log_alphas.append(
diff --git a/nbs/_example_prokaryotic.ipynb b/nbs/_example_prokaryotic.ipynb
index 7954a34..bc9ab59 100644
--- a/nbs/_example_prokaryotic.ipynb
+++ b/nbs/_example_prokaryotic.ipynb
@@ -168,14 +168,15 @@
     "information_measure = \"TwoBodyEntropy\"\n",
     "\n",
     "# Settings affecting the reciprocal best hits part of the loss\n",
-    "hamming_similarities_cfg = {\n",
+    "similarity_kind = \"Hamming\"\n",
+    "similarities_cfg = {\n",
     "    \"use_dot\": False,\n",
     "    \"p\": 1\n",
     "}\n",
     "reciprocal_best_hits_cfg = {\n",
     "    \"tau\": torch.tensor(1e-1)\n",
     "}\n",
-    "inter_group_loss_score_fn = torch.dot  #torch.nn.CosineSimilarity(dim=-1)\n",
+    "inter_group_loss_score_fn = torch.nn.CosineSimilarity(dim=-1)\n",
     "\n",
     "# Loss weights\n",
     "loss_weights = {\n",
@@ -202,7 +203,8 @@
     "    loss_weights=loss_weights,\n",
     "    permutation_cfg=permutation_cfg,\n",
     "    information_measure=information_measure,\n",
-    "    hamming_similarities_cfg=hamming_similarities_cfg,\n",
+    "    similarity_kind=similarity_kind,\n",
+    "    similarities_cfg=similarities_cfg,\n",
     "    reciprocal_best_hits_cfg=reciprocal_best_hits_cfg,\n",
     "    inter_group_loss_score_fn=inter_group_loss_score_fn,\n",
     ")\n",
@@ -231,7 +233,7 @@
     "    \"optimizer_name\": \"SGD\",\n",
     "    \"optimizer_kwargs\": {\"lr\": 1e-1, \"weight_decay\": 0.},\n",
     "    \"mean_centering\": True,\n",
-    "    \"hamming_gradient_bypass\": False\n",
+    "    \"similarity_gradient_bypass\": False\n",
     "}"
    ]
   },
@@ -305,6 +307,13 @@
     "plt.title(\"Total, soft\")\n",
     "plt.show()"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
diff --git a/nbs/base.ipynb b/nbs/base.ipynb
index 648c0cd..d1bbc40 100644
--- a/nbs/base.ipynb
+++ b/nbs/base.ipynb
@@ -51,7 +51,24 @@
     "#| export\n",
     "\n",
     "class DiffPASSMixin:\n",
+    "    allowed_permutation_cfg_keys = {\n",
+    "        \"tau\",\n",
+    "        \"n_iter\",\n",
+    "        \"noise\",\n",
+    "        \"noise_factor\",\n",
+    "        \"noise_std\",\n",
+    "    }\n",
+    "    allowed_information_measures = {\"MI\", \"TwoBodyEntropy\"}\n",
+    "    allowed_similarity_kinds = {\"Hamming\", \"Blosum62\"}\n",
+    "    allowed_similarities_cfg_keys = {\n",
+    "        \"Hamming\": {\"use_dot\", \"p\"},\n",
+    "        \"Blosum62\": {\"aa_to_int\", \"gaps_as_stars\", \"use_scoredist\"},\n",
+    "    }\n",
+    "    allowed_reciprocal_best_hits_cfg_keys = {\"tau\"}\n",
+    "\n",
     "    group_sizes: Iterable[int]\n",
+    "    information_measure: str\n",
+    "    similarity_kind: str\n",
     "\n",
     "    @staticmethod\n",
     "    def reduce_num_tokens(x: torch.Tensor) -> torch.Tensor:\n",
@@ -62,6 +79,42 @@
     "\n",
     "        return x[..., used_tokens]\n",
     "\n",
+    "    def validate_permutation_cfg(self, permutation_cfg: dict) -> None:\n",
+    "        if not set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys):\n",
+    "            raise ValueError(\n",
+    "                f\"Invalid keys in `permutation_cfg`: {set(permutation_cfg) - self.allowed_permutation_cfg_keys}\"\n",
+    "            )\n",
+    "\n",
+    "    def validate_information_measure(self, information_measure: str) -> None:\n",
+    "        if information_measure not in self.allowed_information_measures:\n",
+    "            raise ValueError(\n",
+    "                f\"Invalid information measure: {information_measure}. \"\n",
+    "                f\"Allowed values are: {self.allowed_information_measures}\"\n",
+    "            )\n",
+    "\n",
+    "    def validate_similarity_kind(self, similarity_kind: str) -> None:\n",
+    "        if similarity_kind not in self.allowed_similarity_kinds:\n",
+    "            raise ValueError(\n",
+    "                f\"Invalid similarity kind: {similarity_kind}. \"\n",
+    "                f\"Allowed values are: {self.allowed_similarity_kinds}\"\n",
+    "            )\n",
+    "\n",
+    "    def validate_similarities_cfg(self, similarities_cfg: dict) -> None:\n",
+    "        if not set(similarities_cfg).issubset(\n",
+    "            self.allowed_similarities_cfg_keys[self.similarity_kind]\n",
+    "        ):\n",
+    "            raise ValueError(\n",
+    "                f\"Invalid keys in `similarities_cfg`: {set(similarities_cfg) - self.allowed_similarities_cfg_keys[self.similarity_kind]}\"\n",
+    "            )\n",
+    "\n",
+    "    def validate_reciprocal_best_hits_cfg(self, reciprocal_best_hits_cfg: dict) -> None:\n",
+    "        if not set(reciprocal_best_hits_cfg).issubset(\n",
+    "            self.allowed_reciprocal_best_hits_cfg_keys\n",
+    "        ):\n",
+    "            raise ValueError(\n",
+    "                f\"Invalid keys in `reciprocal_best_hits_cfg`: {set(reciprocal_best_hits_cfg) - self.allowed_reciprocal_best_hits_cfg_keys}\"\n",
+    "            )\n",
+    "\n",
     "    def validate_inputs(\n",
     "        self, x: torch.Tensor, y: torch.Tensor, check_same_alphabet_size: bool = False\n",
     "    ) -> None:\n",
diff --git a/nbs/model.ipynb b/nbs/model.ipynb
index 0fced33..6731b26 100644
--- a/nbs/model.ipynb
+++ b/nbs/model.ipynb
@@ -481,6 +481,7 @@
     "        self,\n",
     "        *,\n",
     "        group_sizes: Optional[Iterable[int]] = None,\n",
+    "        use_scoredist: bool = False,\n",
     "        aa_to_int: Optional[dict[str, int]] = None,\n",
     "        gaps_as_stars: bool = True,\n",
     "    ) -> None:\n",
@@ -488,6 +489,7 @@
     "        self.group_sizes = (\n",
     "            tuple(s for s in group_sizes) if group_sizes is not None else None\n",
     "        )\n",
+    "        self.use_scoredist = use_scoredist\n",
     "        self.aa_to_int = aa_to_int\n",
     "        self.gaps_as_stars = gaps_as_stars\n",
     "\n",
@@ -495,6 +497,7 @@
     "            aa_to_int=self.aa_to_int, gaps_as_stars=self.gaps_as_stars\n",
     "        )\n",
     "        self.register_buffer(\"subs_mat\", blosum62_data.mat)\n",
+    "        self.expected_value = blosum62_data.expected_value\n",
     "\n",
     "        self._group_slices = _consecutive_slices_from_sizes(self.group_sizes)\n",
     "\n",
@@ -506,7 +509,10 @@
     "        for sl in self._group_slices:\n",
     "            out[..., sl, sl].copy_(\n",
     "                smooth_substitution_matrix_similarities(\n",
-    "                    x[..., sl, :, :], subs_mat=self.subs_mat\n",
+    "                    x[..., sl, :, :],\n",
+    "                    subs_mat=self.subs_mat,\n",
+    "                    expected_value=self.expected_value,\n",
+    "                    use_scoredist=self.use_scoredist,\n",
     "                )\n",
     "            )\n",
     "\n",
diff --git a/nbs/train.ipynb b/nbs/train.ipynb
index 26396dc..0804d16 100644
--- a/nbs/train.ipynb
+++ b/nbs/train.ipynb
@@ -6,7 +6,7 @@
    "source": [
     "# train\n",
     "\n",
-    "> Perform optimization using DiffPASS"
+    "> Perform optimization using DiffPASS models"
    ]
   },
   {
@@ -160,17 +160,6 @@
     "#| export\n",
     "\n",
     "class InformationAndReciprocalBestHits(Module, EnsembleMixin, DiffPASSMixin):\n",
-    "    allowed_permutation_cfg_keys = {\n",
-    "        \"tau\",\n",
-    "        \"n_iter\",\n",
-    "        \"noise\",\n",
-    "        \"noise_factor\",\n",
-    "        \"noise_std\",\n",
-    "    }\n",
-    "    allowed_information_measures = {\"MI\", \"TwoBodyEntropy\"}\n",
-    "    allowed_hamming_similarities_cfg_keys = {\"use_dot\", \"p\"}\n",
-    "    allowed_reciprocal_best_hits_cfg_keys = {\"tau\"}\n",
-    "\n",
     "    def __init__(\n",
     "        self,\n",
     "        group_sizes: Iterable[int],\n",
@@ -178,7 +167,8 @@
     "        loss_weights: Optional[dict[str, Union[float, torch.Tensor]]] = None,\n",
     "        permutation_cfg: Optional[dict[str, Any]] = None,\n",
     "        information_measure: Literal[\"MI\", \"TwoBodyEntropy\"] = \"TwoBodyEntropy\",\n",
-    "        hamming_similarities_cfg: Optional[dict[str, Any]] = None,\n",
+    "        similarity_kind: Literal[\"Hamming\", \"Blosum62\"] = \"Hamming\",\n",
+    "        similarities_cfg: Optional[dict[str, Any]] = None,\n",
     "        reciprocal_best_hits_cfg: Optional[dict[str, Any]] = None,\n",
     "        inter_group_loss_score_fn: Optional[callable] = None,\n",
     "    ):\n",
@@ -188,7 +178,8 @@
     "        self.loss_weights = loss_weights\n",
     "        self.permutation_cfg = permutation_cfg\n",
     "        self.information_measure = information_measure\n",
-    "        self.hamming_similarities_cfg = hamming_similarities_cfg\n",
+    "        self.similarity_kind = similarity_kind\n",
+    "        self.similarities_cfg = similarities_cfg\n",
     "        self.reciprocal_best_hits_cfg = reciprocal_best_hits_cfg\n",
     "        self.inter_group_loss_score_fn = inter_group_loss_score_fn\n",
     "\n",
@@ -198,7 +189,7 @@
     "        if permutation_cfg is None:\n",
     "            permutation_cfg = {}\n",
     "        else:\n",
-    "            assert set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys)\n",
+    "            self.validate_permutation_cfg(permutation_cfg)\n",
     "            permutation_cfg = deepcopy(permutation_cfg)\n",
     "            _dim_in_ensemble = self._adjust_cfg_and_ensemble_shape(\n",
     "                ensemble_shape=ensemble_shape,\n",
@@ -208,24 +199,21 @@
     "            )\n",
     "        self.effective_permutation_cfg_ = permutation_cfg\n",
     "\n",
-    "        assert self.information_measure in self.allowed_information_measures\n",
+    "        self.validate_information_measure(information_measure)\n",
     "        self.loss_weights_keys = {self.information_measure, \"ReciprocalBestHits\"}\n",
     "\n",
-    "        if hamming_similarities_cfg is None:\n",
-    "            hamming_similarities_cfg = {}\n",
+    "        self.validate_similarity_kind(similarity_kind)\n",
+    "        if similarities_cfg is None:\n",
+    "            similarities_cfg = {}\n",
     "        else:\n",
-    "            assert set(hamming_similarities_cfg).issubset(\n",
-    "                self.allowed_hamming_similarities_cfg_keys\n",
-    "            )\n",
-    "            hamming_similarities_cfg = deepcopy(hamming_similarities_cfg)\n",
-    "        self.effective_hamming_similarities_cfg_ = hamming_similarities_cfg\n",
+    "            self.validate_similarities_cfg(similarities_cfg)\n",
+    "            similarities_cfg = deepcopy(similarities_cfg)\n",
+    "        self.effective_similarities_cfg_ = similarities_cfg\n",
     "\n",
     "        if reciprocal_best_hits_cfg is None:\n",
     "            reciprocal_best_hits_cfg = {}\n",
     "        else:\n",
-    "            assert set(reciprocal_best_hits_cfg).issubset(\n",
-    "                self.allowed_reciprocal_best_hits_cfg_keys\n",
-    "            )\n",
+    "            self.validate_reciprocal_best_hits_cfg(reciprocal_best_hits_cfg)\n",
     "            reciprocal_best_hits_cfg = deepcopy(reciprocal_best_hits_cfg)\n",
     "            _dim_in_ensemble = self._adjust_cfg_and_ensemble_shape(\n",
     "                ensemble_shape=ensemble_shape,\n",
@@ -262,16 +250,14 @@
     "            self.information_loss = TwoBodyEntropyLoss()\n",
     "        elif self.information_measure == \"MI\":\n",
     "            self.information_loss = MILoss()\n",
-    "        else:\n",
-    "            # FIXME Redundant check\n",
-    "            raise ValueError(\n",
-    "                f\"``information_measure`` must be one of {self.allowed_information_measures}.\"\n",
-    "            )\n",
-    "\n",
     "        self.information_loss.register_buffer(\n",
     "            \"weight\", loss_weights[self.information_measure]\n",
     "        )\n",
-    "        self.hamming_similarities = HammingSimilarities(**hamming_similarities_cfg)\n",
+    "\n",
+    "        if similarity_kind == \"Blosum62\":\n",
+    "            self.similarities = Blosum62Similarities(**similarities_cfg)\n",
+    "        elif similarity_kind == \"Hamming\":\n",
+    "            self.similarities = HammingSimilarities(**similarities_cfg)\n",
     "        self.reciprocal_best_hits = ReciprocalBestHits(\n",
     "            group_sizes=self.group_sizes,\n",
     "            ensemble_shape=ensemble_shape,\n",
@@ -373,9 +359,9 @@
     "    def _precompute_rbh(self, x: torch.Tensor, y: torch.Tensor) -> None:\n",
     "        # Temporarily switch to hard RBH\n",
     "        self.reciprocal_best_hits.hard_()\n",
-    "        similarities_x = self.hamming_similarities(x)\n",
+    "        similarities_x = self.similarities(x)\n",
     "        self._rbh_hard_x = self.reciprocal_best_hits(similarities_x)\n",
-    "        similarities_y = self.hamming_similarities(y)\n",
+    "        similarities_y = self.similarities(y)\n",
     "        self._rbh_hard_y = self.reciprocal_best_hits(similarities_y)\n",
     "\n",
     "        # Revert to soft (default) RBH\n",
@@ -405,7 +391,7 @@
     "        if mode == \"soft\":\n",
     "            if x_perm_hard is not None:\n",
     "                x_perm = (x_perm_hard - x_perm).detach() + x_perm\n",
-    "            similarities_x = self.hamming_similarities(x_perm)\n",
+    "            similarities_x = self.similarities(x_perm)\n",
     "            rbh_x = self.reciprocal_best_hits(similarities_x)\n",
     "        else:\n",
     "            rbh_x = apply_hard_permutation_batch_to_similarity(\n",
@@ -438,7 +424,7 @@
     "        optimizer_name: Optional[str] = \"SGD\",\n",
     "        optimizer_kwargs: Optional[dict[str, Any]] = None,\n",
     "        mean_centering: bool = True,\n",
-    "        hamming_gradient_bypass: bool = False,\n",
+    "        similarity_gradient_bypass: bool = False,\n",
     "    ) -> DiffPASSResults:\n",
     "        # Validate inputs\n",
     "        self.validate_inputs(x, y)\n",
@@ -499,7 +485,7 @@
     "                    epoch_results = self(x, y)\n",
     "                    loss_info = epoch_results[\"loss_info\"]\n",
     "                    loss_rbh = epoch_results[\"loss_rbh\"]\n",
-    "                    if hamming_gradient_bypass:\n",
+    "                    if similarity_gradient_bypass:\n",
     "                        x_perm_hard = epoch_results[\"x_perm\"]\n",
     "                    perms = epoch_results[\"perms\"]\n",
     "                    results.log_alphas.append(\n",