From 76eaa7535ad248e1b8cf09793a4302a68d8665a0 Mon Sep 17 00:00:00 2001 From: Michael Ding Date: Thu, 13 Jun 2024 18:53:48 +0800 Subject: [PATCH] support constraint generation --- moltx/pipelines.py | 57 +++++++++++++++-------------------------- tests/test_pipelines.py | 38 +++++++++++++++++++++------ 2 files changed, 51 insertions(+), 44 deletions(-) diff --git a/moltx/pipelines.py b/moltx/pipelines.py index 7ac82c3..7c067ee 100644 --- a/moltx/pipelines.py +++ b/moltx/pipelines.py @@ -4,7 +4,7 @@ from moltx import tokenizers -class Base: +class _Base: def __init__(self, tokenizer: tokenizers.MoltxTokenizer, model: nn.Module, device: torch.device = torch.device('cpu')) -> None: self.tokenizer = tokenizer model = model.to(device) @@ -26,6 +26,8 @@ def _tokens2tensor(self, tokens: typing.Sequence[int], size: int) -> torch.Tenso out[i] = tk return out.to(self.device) + +class _GenBase(_Base): @torch.no_grad() def _greedy_search(self, tgt: torch.Tensor, **kwds: torch.Tensor) -> typing.Tuple[typing.Sequence[int], float]: maxlen = self.model.conf.max_len @@ -102,7 +104,7 @@ def _beam_search(self, tgt: torch.Tensor, beam_width: int = 3, **kwds: torch.Ten return sorted(zip(smiles, probs), key=lambda x: x[1], reverse=True) -class AdaMR(Base): +class AdaMR(_GenBase): def _model_args(self, smiles: str) -> typing.Tuple[torch.Tensor]: src = self._tokenize(smiles) @@ -110,25 +112,14 @@ def _model_args(self, smiles: str) -> typing.Tuple[torch.Tensor]: return src, tgt # gentype: greedy, beam - def __call__(self, smiles: str = "") -> typing.Mapping: - src, tgt = self._model_args(smiles) - if len(smiles) > 0: - meth = self._do_canonicalize - else: - meth = self._do_generate - smi, prob = meth(src, tgt) + def __call__(self) -> typing.Mapping: + src, tgt = self._model_args("") + smi, prob = self._random_sample(src=src, tgt=tgt) return { 'smiles': smi, 'probability': prob } - def _do_generate(self, src: torch.Tensor, tgt: torch.Tensor) -> typing.Mapping: - return self._random_sample(src=src, tgt=tgt) - - def _do_canonicalize(self, src: torch.Tensor, tgt: torch.Tensor) -> typing.Mapping: - out = self._beam_search(src=src, tgt=tgt, beam_width=3) - return out[0] - class AdaMRClassifier(AdaMR): def _model_args(self, smiles: str) -> typing.Tuple[torch.Tensor]: @@ -197,31 +188,19 @@ def __call__(self, goal: float, k: int = 1) -> typing.Mapping: } -class AdaMR2(Base): +class AdaMR2(_GenBase): # gentype: greedy, beam - def __call__(self, smiles: str = "") -> typing.Mapping: - tgt = self._tokenize(f"{smiles}{self.tokenizer.BOS}") - if len(smiles) > 0: - meth = self._do_canonicalize - else: - meth = self._do_generate - - smi, prob = meth(tgt) + def __call__(self) -> typing.Mapping: + tgt = self._tokenize(self.tokenizer.BOS) + smi, prob = self._random_sample(tgt=tgt) return { 'smiles': smi, 'probability': prob } - def _do_generate(self, tgt: torch.Tensor) -> typing.Mapping: - return self._random_sample(tgt=tgt) - - def _do_canonicalize(self, tgt: torch.Tensor) -> typing.Mapping: - out = self._beam_search(tgt=tgt) - return out[0] - -class AdaMR2Classifier(AdaMR): +class AdaMR2Classifier(AdaMR2): def __call__(self, smiles: str) -> typing.Mapping: tgt = self._tokenize(f"{self.tokenizer.BOS}{smiles}{self.tokenizer.EOS}") @@ -233,7 +212,7 @@ def __call__(self, smiles: str) -> typing.Mapping: } -class AdaMR2Regression(AdaMR): +class AdaMR2Regression(AdaMR2): def __call__(self, smiles: str) -> typing.Mapping: tgt = self._tokenize(f"{self.tokenizer.BOS}{smiles}{self.tokenizer.EOS}") @@ -243,7 +222,7 @@ def __call__(self, smiles: str) -> typing.Mapping: } -class AdaMR2DistGeneration(AdaMR): +class AdaMR2DistGeneration(AdaMR2): def __call__(self, k: int = 1) -> typing.Mapping: assert k <= 10 @@ -259,7 +238,7 @@ def __call__(self, k: int = 1) -> typing.Mapping: } -class AdaMR2GoalGeneration(AdaMR): +class AdaMR2GoalGeneration(AdaMR2): def __call__(self, goal: float, k: int = 1) -> typing.Mapping: assert k <= 10 @@ -274,3 +253,9 @@ def __call__(self, goal: float, k: int = 1) -> typing.Mapping: 'smiles': smis, 'probabilities': probs } + + +class AdaMR2SuperGeneration(AdaMR2): + + # TODO + pass diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 32720eb..a6eda12 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -15,13 +15,33 @@ def adamr2_conf(): token_size=16, max_len=32, d_model=8, nhead=2,num_layers=2) +def test__GenBase(tokenizer, adamr2_conf): + pipeline = pipelines._GenBase(tokenizer, models.AdaMR2(adamr2_conf)) + tgt = pipeline._tokenize(tokenizer.BOS) + smi1, prob1 = pipeline._greedy_search(tgt) + assert isinstance(smi1, str) and isinstance(prob1, float) + smi2, prob2 = pipeline._greedy_search(tgt) + assert smi2 == smi1 and prob1 == prob2 + tgt2 = pipeline._tokenize(f"{tokenizer.BOS}C=C") + smi2, prob2 = pipeline._greedy_search(tgt2) + assert smi2 != smi1 and prob1 != prob2 + + smi1, prob1 = pipeline._random_sample(tgt) + assert isinstance(smi1, str) and isinstance(prob1, float) + smi2, prob2 = pipeline._random_sample(tgt) + assert smi2 != smi1 and prob1 != prob2 + + out1 = pipeline._beam_search(tgt, beam_width=2) + assert isinstance(out1, list) and len(out1) == 2 + assert out1[0][1] >= out1[1][1] + out2 = pipeline._beam_search(tgt, beam_width=2) + assert out2 == out1 + + def test_AdaMR(tokenizer, adamr_conf): pipeline = pipelines.AdaMR(tokenizer, models.AdaMR(adamr_conf)) - out = pipeline("CC[N+](C)(C)Br") - assert 'smiles' in out and 'probability' in out - assert isinstance(out['smiles'], str) - out = pipeline("") + out = pipeline() assert 'smiles' in out and 'probability' in out assert isinstance(out['smiles'], str) @@ -65,11 +85,8 @@ def test_AdaMRGoalGeneration(tokenizer, adamr_conf): def test_AdaMR2(tokenizer, adamr2_conf): pipeline = pipelines.AdaMR2(tokenizer, models.AdaMR2(adamr2_conf)) - out = pipeline("CC[N+](C)(C)Br") - assert 'smiles' in out and 'probability' in out - assert isinstance(out['smiles'], str) - out = pipeline("") + out = pipeline() assert 'smiles' in out and 'probability' in out assert isinstance(out['smiles'], str) @@ -109,3 +126,8 @@ def test_AdaMR2GoalGeneration(tokenizer, adamr2_conf): assert isinstance(out['smiles'], list) and len(out['smiles']) == 2 and isinstance(out['smiles'][0], str) assert isinstance(out['probabilities'], list) and len( out['probabilities']) == 2 + + +def test_AdaMR2SuperGeneration(tokenizer, adamr2_conf): + # TODO + pass