Skip to content

Commit

Permalink
Generalize InformationAndReciprocalBestHits to include Blosum62, gene…
Browse files Browse the repository at this point in the history
…ralize Blosum62
  • Loading branch information
ulupo committed Oct 13, 2023
1 parent 582f524 commit 2056dda
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 81 deletions.
10 changes: 10 additions & 0 deletions diffpass/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,18 @@
'syms': { 'diffpass.base': { 'diffpass.base.DiffPASSMixin': ('base.html#diffpassmixin', 'diffpass/base.py'),
'diffpass.base.DiffPASSMixin.reduce_num_tokens': ( 'base.html#diffpassmixin.reduce_num_tokens',
'diffpass/base.py'),
'diffpass.base.DiffPASSMixin.validate_information_measure': ( 'base.html#diffpassmixin.validate_information_measure',
'diffpass/base.py'),
'diffpass.base.DiffPASSMixin.validate_inputs': ( 'base.html#diffpassmixin.validate_inputs',
'diffpass/base.py'),
'diffpass.base.DiffPASSMixin.validate_permutation_cfg': ( 'base.html#diffpassmixin.validate_permutation_cfg',
'diffpass/base.py'),
'diffpass.base.DiffPASSMixin.validate_reciprocal_best_hits_cfg': ( 'base.html#diffpassmixin.validate_reciprocal_best_hits_cfg',
'diffpass/base.py'),
'diffpass.base.DiffPASSMixin.validate_similarities_cfg': ( 'base.html#diffpassmixin.validate_similarities_cfg',
'diffpass/base.py'),
'diffpass.base.DiffPASSMixin.validate_similarity_kind': ( 'base.html#diffpassmixin.validate_similarity_kind',
'diffpass/base.py'),
'diffpass.base.EnsembleMixin': ('base.html#ensemblemixin', 'diffpass/base.py'),
'diffpass.base.EnsembleMixin._reshape_ensemble_param': ( 'base.html#ensemblemixin._reshape_ensemble_param',
'diffpass/base.py'),
Expand Down
53 changes: 53 additions & 0 deletions diffpass/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,24 @@

# %% ../nbs/base.ipynb 3
class DiffPASSMixin:
allowed_permutation_cfg_keys = {
"tau",
"n_iter",
"noise",
"noise_factor",
"noise_std",
}
allowed_information_measures = {"MI", "TwoBodyEntropy"}
allowed_similarity_kinds = {"Hamming", "Blosum62"}
allowed_similarities_cfg_keys = {
"Hamming": {"use_dot", "p"},
"Blosum62": {"aa_to_int", "gaps_as_stars"},
}
allowed_reciprocal_best_hits_cfg_keys = {"tau"}

group_sizes: Iterable[int]
information_measure: str
similarity_kind: str

@staticmethod
def reduce_num_tokens(x: torch.Tensor) -> torch.Tensor:
Expand All @@ -32,6 +49,42 @@ def reduce_num_tokens(x: torch.Tensor) -> torch.Tensor:

return x[..., used_tokens]

def validate_permutation_cfg(self, permutation_cfg: dict) -> None:
if not set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys):
raise ValueError(
f"Invalid keys in `permutation_cfg`: {set(permutation_cfg) - self.allowed_permutation_cfg_keys}"
)

def validate_information_measure(self, information_measure: str) -> None:
if information_measure not in self.allowed_information_measures:
raise ValueError(
f"Invalid information measure: {self.information_measure}. "
f"Allowed values are: {self.allowed_information_measures}"
)

def validate_similarity_kind(self, similarity_kind: str) -> None:
if similarity_kind not in self.allowed_similarity_kinds:
raise ValueError(
f"Invalid similarity kind: {self.similarity_kind}. "
f"Allowed values are: {self.allowed_similarity_kinds}"
)

def validate_similarities_cfg(self, similarities_cfg: dict) -> None:
if not set(similarities_cfg).issubset(
self.allowed_similarities_cfg_keys[self.similarity_kind]
):
raise ValueError(
f"Invalid keys in `similarities_cfg`: {set(similarities_cfg) - self.allowed_similarities_cfg_keys[self.similarity_kind]}"
)

def validate_reciprocal_best_hits_cfg(self, reciprocal_best_hits_cfg: dict) -> None:
if not set(reciprocal_best_hits_cfg).issubset(
self.allowed_reciprocal_best_hits_cfg_keys
):
raise ValueError(
f"Invalid keys in `reciprocal_best_hits_cfg`: {set(reciprocal_best_hits_cfg) - self.allowed_reciprocal_best_hits_cfg_keys}"
)

def validate_inputs(
self, x: torch.Tensor, y: torch.Tensor, check_same_alphabet_size: bool = False
) -> None:
Expand Down
8 changes: 7 additions & 1 deletion diffpass/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,20 +303,23 @@ def __init__(
self,
*,
group_sizes: Optional[Iterable[int]] = None,
use_scoredist: bool = False,
aa_to_int: Optional[dict[str, int]] = None,
gaps_as_stars: bool = True,
) -> None:
super().__init__()
self.group_sizes = (
tuple(s for s in group_sizes) if group_sizes is not None else None
)
self.use_scoredist = use_scoredist
self.aa_to_int = aa_to_int
self.gaps_as_stars = gaps_as_stars

blosum62_data = get_blosum62_data(
aa_to_int=self.aa_to_int, gaps_as_stars=self.gaps_as_stars
)
self.register_buffer("subs_mat", blosum62_data.mat)
self.expected_value = blosum62_data.expected_value

self._group_slices = _consecutive_slices_from_sizes(self.group_sizes)

Expand All @@ -328,7 +331,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
for sl in self._group_slices:
out[..., sl, sl].copy_(
smooth_substitution_matrix_similarities(
x[..., sl, :, :], subs_mat=self.subs_mat
x[..., sl, :, :],
subs_mat=self.subs_mat,
expected_value=self.expected_value,
use_scoredist=self.use_scoredist,
)
)

Expand Down
60 changes: 23 additions & 37 deletions diffpass/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,25 +88,15 @@ class DiffPASSResults:

# %% ../nbs/train.ipynb 6
class InformationAndReciprocalBestHits(Module, EnsembleMixin, DiffPASSMixin):
allowed_permutation_cfg_keys = {
"tau",
"n_iter",
"noise",
"noise_factor",
"noise_std",
}
allowed_information_measures = {"MI", "TwoBodyEntropy"}
allowed_hamming_similarities_cfg_keys = {"use_dot", "p"}
allowed_reciprocal_best_hits_cfg_keys = {"tau"}

def __init__(
self,
group_sizes: Iterable[int],
fixed_matchings: Optional[Sequence[Sequence[Sequence[int]]]] = None,
loss_weights: Optional[dict[str, Union[float, torch.Tensor]]] = None,
permutation_cfg: Optional[dict[str, Any]] = None,
information_measure: Literal["MI", "TwoBodyEntropy"] = "TwoBodyEntropy",
hamming_similarities_cfg: Optional[dict[str, Any]] = None,
similarity_kind: Literal["Hamming", "Blosum62"] = "Hamming",
similarities_cfg: Optional[dict[str, Any]] = None,
reciprocal_best_hits_cfg: Optional[dict[str, Any]] = None,
inter_group_loss_score_fn: Optional[callable] = None,
):
Expand All @@ -116,7 +106,8 @@ def __init__(
self.loss_weights = loss_weights
self.permutation_cfg = permutation_cfg
self.information_measure = information_measure
self.hamming_similarities_cfg = hamming_similarities_cfg
self.similarity_kind = similarity_kind
self.similarities_cfg = similarities_cfg
self.reciprocal_best_hits_cfg = reciprocal_best_hits_cfg
self.inter_group_loss_score_fn = inter_group_loss_score_fn

Expand All @@ -126,7 +117,7 @@ def __init__(
if permutation_cfg is None:
permutation_cfg = {}
else:
assert set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys)
self.validate_permutation_cfg(permutation_cfg)
permutation_cfg = deepcopy(permutation_cfg)
_dim_in_ensemble = self._adjust_cfg_and_ensemble_shape(
ensemble_shape=ensemble_shape,
Expand All @@ -136,24 +127,21 @@ def __init__(
)
self.effective_permutation_cfg_ = permutation_cfg

assert self.information_measure in self.allowed_information_measures
self.validate_information_measure(information_measure)
self.loss_weights_keys = {self.information_measure, "ReciprocalBestHits"}

if hamming_similarities_cfg is None:
hamming_similarities_cfg = {}
self.validate_similarity_kind(similarity_kind)
if similarities_cfg is None:
similarities_cfg = {}
else:
assert set(hamming_similarities_cfg).issubset(
self.allowed_hamming_similarities_cfg_keys
)
hamming_similarities_cfg = deepcopy(hamming_similarities_cfg)
self.effective_hamming_similarities_cfg_ = hamming_similarities_cfg
self.validate_similarities_cfg(similarities_cfg)
similarities_cfg = deepcopy(similarities_cfg)
self.effective_similarities_cfg_ = similarities_cfg

if reciprocal_best_hits_cfg is None:
reciprocal_best_hits_cfg = {}
else:
assert set(reciprocal_best_hits_cfg).issubset(
self.allowed_reciprocal_best_hits_cfg_keys
)
self.validate_reciprocal_best_hits_cfg(reciprocal_best_hits_cfg)
reciprocal_best_hits_cfg = deepcopy(reciprocal_best_hits_cfg)
_dim_in_ensemble = self._adjust_cfg_and_ensemble_shape(
ensemble_shape=ensemble_shape,
Expand Down Expand Up @@ -190,16 +178,14 @@ def __init__(
self.information_loss = TwoBodyEntropyLoss()
elif self.information_measure == "MI":
self.information_loss = MILoss()
else:
# FIXME Redundant check
raise ValueError(
f"``information_measure`` must be one of {self.allowed_information_measures}."
)

self.information_loss.register_buffer(
"weight", loss_weights[self.information_measure]
)
self.hamming_similarities = HammingSimilarities(**hamming_similarities_cfg)

if similarity_kind == "Blosum62":
self.similarities = Blosum62Similarities(**similarities_cfg)
elif similarity_kind == "Hamming":
self.similarities = HammingSimilarities(**similarities_cfg)
self.reciprocal_best_hits = ReciprocalBestHits(
group_sizes=self.group_sizes,
ensemble_shape=ensemble_shape,
Expand Down Expand Up @@ -301,9 +287,9 @@ def _adjust_loss_weights_and_ensemble_shape(
def _precompute_rbh(self, x: torch.Tensor, y: torch.Tensor) -> None:
# Temporarily switch to hard RBH
self.reciprocal_best_hits.hard_()
similarities_x = self.hamming_similarities(x)
similarities_x = self.similarities(x)
self._rbh_hard_x = self.reciprocal_best_hits(similarities_x)
similarities_y = self.hamming_similarities(y)
similarities_y = self.similarities(y)
self._rbh_hard_y = self.reciprocal_best_hits(similarities_y)

# Revert to soft (default) RBH
Expand Down Expand Up @@ -333,7 +319,7 @@ def forward(
if mode == "soft":
if x_perm_hard is not None:
x_perm = (x_perm_hard - x_perm).detach() + x_perm
similarities_x = self.hamming_similarities(x_perm)
similarities_x = self.similarities(x_perm)
rbh_x = self.reciprocal_best_hits(similarities_x)
else:
rbh_x = apply_hard_permutation_batch_to_similarity(
Expand Down Expand Up @@ -366,7 +352,7 @@ def fit(
optimizer_name: Optional[str] = "SGD",
optimizer_kwargs: Optional[dict[str, Any]] = None,
mean_centering: bool = True,
hamming_gradient_bypass: bool = False,
similarity_gradient_bypass: bool = False,
) -> DiffPASSResults:
# Validate inputs
self.validate_inputs(x, y)
Expand Down Expand Up @@ -427,7 +413,7 @@ def fit(
epoch_results = self(x, y)
loss_info = epoch_results["loss_info"]
loss_rbh = epoch_results["loss_rbh"]
if hamming_gradient_bypass:
if similarity_gradient_bypass:
x_perm_hard = epoch_results["x_perm"]
perms = epoch_results["perms"]
results.log_alphas.append(
Expand Down
17 changes: 13 additions & 4 deletions nbs/_example_prokaryotic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,15 @@
"information_measure = \"TwoBodyEntropy\"\n",
"\n",
"# Settings affecting the reciprocal best hits part of the loss\n",
"hamming_similarities_cfg = {\n",
"similarity_kind = \"Hamming\"\n",
"similarities_cfg = {\n",
" \"use_dot\": False,\n",
" \"p\": 1\n",
"}\n",
"reciprocal_best_hits_cfg = {\n",
" \"tau\": torch.tensor(1e-1)\n",
"}\n",
"inter_group_loss_score_fn = torch.dot #torch.nn.CosineSimilarity(dim=-1)\n",
"inter_group_loss_score_fn = torch.nn.CosineSimilarity(dim=-1)\n",
"\n",
"# Loss weights\n",
"loss_weights = {\n",
Expand All @@ -202,7 +203,8 @@
" loss_weights=loss_weights,\n",
" permutation_cfg=permutation_cfg,\n",
" information_measure=information_measure,\n",
" hamming_similarities_cfg=hamming_similarities_cfg,\n",
" similarity_kind=similarity_kind,\n",
" similarities_cfg=similarities_cfg,\n",
" reciprocal_best_hits_cfg=reciprocal_best_hits_cfg,\n",
" inter_group_loss_score_fn=inter_group_loss_score_fn,\n",
")\n",
Expand Down Expand Up @@ -231,7 +233,7 @@
" \"optimizer_name\": \"SGD\",\n",
" \"optimizer_kwargs\": {\"lr\": 1e-1, \"weight_decay\": 0.},\n",
" \"mean_centering\": True,\n",
" \"hamming_gradient_bypass\": False\n",
" \"similarity_gradient_bypass\": False\n",
"}"
]
},
Expand Down Expand Up @@ -305,6 +307,13 @@
"plt.title(\"Total, soft\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
53 changes: 53 additions & 0 deletions nbs/base.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,24 @@
"#| export\n",
"\n",
"class DiffPASSMixin:\n",
" allowed_permutation_cfg_keys = {\n",
" \"tau\",\n",
" \"n_iter\",\n",
" \"noise\",\n",
" \"noise_factor\",\n",
" \"noise_std\",\n",
" }\n",
" allowed_information_measures = {\"MI\", \"TwoBodyEntropy\"}\n",
" allowed_similarity_kinds = {\"Hamming\", \"Blosum62\"}\n",
" allowed_similarities_cfg_keys = {\n",
" \"Hamming\": {\"use_dot\", \"p\"},\n",
" \"Blosum62\": {\"aa_to_int\", \"gaps_as_stars\"},\n",
" }\n",
" allowed_reciprocal_best_hits_cfg_keys = {\"tau\"}\n",
"\n",
" group_sizes: Iterable[int]\n",
" information_measure: str\n",
" similarity_kind: str\n",
"\n",
" @staticmethod\n",
" def reduce_num_tokens(x: torch.Tensor) -> torch.Tensor:\n",
Expand All @@ -62,6 +79,42 @@
"\n",
" return x[..., used_tokens]\n",
"\n",
" def validate_permutation_cfg(self, permutation_cfg: dict) -> None:\n",
" if not set(permutation_cfg).issubset(self.allowed_permutation_cfg_keys):\n",
" raise ValueError(\n",
" f\"Invalid keys in `permutation_cfg`: {set(permutation_cfg) - self.allowed_permutation_cfg_keys}\"\n",
" )\n",
"\n",
" def validate_information_measure(self, information_measure: str) -> None:\n",
" if information_measure not in self.allowed_information_measures:\n",
" raise ValueError(\n",
" f\"Invalid information measure: {self.information_measure}. \"\n",
" f\"Allowed values are: {self.allowed_information_measures}\"\n",
" )\n",
"\n",
" def validate_similarity_kind(self, similarity_kind: str) -> None:\n",
" if similarity_kind not in self.allowed_similarity_kinds:\n",
" raise ValueError(\n",
" f\"Invalid similarity kind: {self.similarity_kind}. \"\n",
" f\"Allowed values are: {self.allowed_similarity_kinds}\"\n",
" )\n",
"\n",
" def validate_similarities_cfg(self, similarities_cfg: dict) -> None:\n",
" if not set(similarities_cfg).issubset(\n",
" self.allowed_similarities_cfg_keys[self.similarity_kind]\n",
" ):\n",
" raise ValueError(\n",
" f\"Invalid keys in `similarities_cfg`: {set(similarities_cfg) - self.allowed_similarities_cfg_keys[self.similarity_kind]}\"\n",
" )\n",
"\n",
" def validate_reciprocal_best_hits_cfg(self, reciprocal_best_hits_cfg: dict) -> None:\n",
" if not set(reciprocal_best_hits_cfg).issubset(\n",
" self.allowed_reciprocal_best_hits_cfg_keys\n",
" ):\n",
" raise ValueError(\n",
" f\"Invalid keys in `reciprocal_best_hits_cfg`: {set(reciprocal_best_hits_cfg) - self.allowed_reciprocal_best_hits_cfg_keys}\"\n",
" )\n",
"\n",
" def validate_inputs(\n",
" self, x: torch.Tensor, y: torch.Tensor, check_same_alphabet_size: bool = False\n",
" ) -> None:\n",
Expand Down
Loading

0 comments on commit 2056dda

Please sign in to comment.