Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve pipelines and test #39

Merged
merged 1 commit into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 21 additions & 36 deletions moltx/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from moltx import tokenizers


class Base:
class _Base:
def __init__(self, tokenizer: tokenizers.MoltxTokenizer, model: nn.Module, device: torch.device = torch.device('cpu')) -> None:
self.tokenizer = tokenizer
model = model.to(device)
Expand All @@ -26,6 +26,8 @@ def _tokens2tensor(self, tokens: typing.Sequence[int], size: int) -> torch.Tenso
out[i] = tk
return out.to(self.device)


class _GenBase(_Base):
@torch.no_grad()
def _greedy_search(self, tgt: torch.Tensor, **kwds: torch.Tensor) -> typing.Tuple[typing.Sequence[int], float]:
maxlen = self.model.conf.max_len
Expand Down Expand Up @@ -102,33 +104,22 @@ def _beam_search(self, tgt: torch.Tensor, beam_width: int = 3, **kwds: torch.Ten
return sorted(zip(smiles, probs), key=lambda x: x[1], reverse=True)


class AdaMR(_GenBase):
    """Unconditional SMILES generation pipeline for the encoder-decoder AdaMR model.

    NOTE(review): this span rendered both sides of a diff with the +/- markers
    stripped; the body below is the post-change (merged) version reconstructed
    from the visible added lines — confirm against the merged file.
    """

    def _model_args(self, smiles: str) -> typing.Tuple[torch.Tensor]:
        """Build (src, tgt) tensors: encoder gets *smiles*, decoder is primed with BOS."""
        src = self._tokenize(smiles)
        tgt = self._tokenize(self.tokenizer.BOS)
        return src, tgt

    def __call__(self) -> typing.Mapping:
        """Randomly sample one molecule.

        Returns:
            Mapping with keys 'smiles' (str) and 'probability' (float).
        """
        # Generation is unconditional: the encoder input is the empty string.
        src, tgt = self._model_args("")
        smi, prob = self._random_sample(src=src, tgt=tgt)
        return {
            'smiles': smi,
            'probability': prob
        }


class AdaMRClassifier(AdaMR):
def _model_args(self, smiles: str) -> typing.Tuple[torch.Tensor]:
Expand Down Expand Up @@ -197,31 +188,19 @@ def __call__(self, goal: float, k: int = 1) -> typing.Mapping:
}


class AdaMR2(_GenBase):
    """Unconditional SMILES generation pipeline for the decoder-only AdaMR2 model.

    NOTE(review): this span rendered both sides of a diff with the +/- markers
    stripped; the body below is the post-change (merged) version reconstructed
    from the visible added lines — confirm against the merged file.
    """

    def __call__(self) -> typing.Mapping:
        """Randomly sample one molecule.

        Returns:
            Mapping with keys 'smiles' (str) and 'probability' (float).
        """
        # Decoder-only model: prime generation with the begin-of-sequence token.
        tgt = self._tokenize(self.tokenizer.BOS)
        smi, prob = self._random_sample(tgt=tgt)
        return {
            'smiles': smi,
            'probability': prob
        }


class AdaMR2Classifier(AdaMR):
class AdaMR2Classifier(AdaMR2):

def __call__(self, smiles: str) -> typing.Mapping:
tgt = self._tokenize(f"{self.tokenizer.BOS}{smiles}{self.tokenizer.EOS}")
Expand All @@ -233,7 +212,7 @@ def __call__(self, smiles: str) -> typing.Mapping:
}


class AdaMR2Regression(AdaMR):
class AdaMR2Regression(AdaMR2):

def __call__(self, smiles: str) -> typing.Mapping:
tgt = self._tokenize(f"{self.tokenizer.BOS}{smiles}{self.tokenizer.EOS}")
Expand All @@ -243,7 +222,7 @@ def __call__(self, smiles: str) -> typing.Mapping:
}


class AdaMR2DistGeneration(AdaMR):
class AdaMR2DistGeneration(AdaMR2):

def __call__(self, k: int = 1) -> typing.Mapping:
assert k <= 10
Expand All @@ -259,7 +238,7 @@ def __call__(self, k: int = 1) -> typing.Mapping:
}


class AdaMR2GoalGeneration(AdaMR):
class AdaMR2GoalGeneration(AdaMR2):

def __call__(self, goal: float, k: int = 1) -> typing.Mapping:
assert k <= 10
Expand All @@ -274,3 +253,9 @@ def __call__(self, goal: float, k: int = 1) -> typing.Mapping:
'smiles': smis,
'probabilities': probs
}


class AdaMR2SuperGeneration(AdaMR2):
    """Placeholder pipeline for a future AdaMR2 generation mode; not implemented yet."""

    # TODO
    pass
38 changes: 30 additions & 8 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,33 @@ def adamr2_conf():
token_size=16, max_len=32, d_model=8, nhead=2,num_layers=2)


def test__GenBase(tokenizer, adamr2_conf):
    """Exercise the three private decoding strategies exposed by _GenBase."""
    pipeline = pipelines._GenBase(tokenizer, models.AdaMR2(adamr2_conf))
    bos = pipeline._tokenize(tokenizer.BOS)

    # Greedy search is deterministic: the same prompt yields the same output,
    # while a different prompt yields a different one.
    first_smi, first_prob = pipeline._greedy_search(bos)
    assert isinstance(first_smi, str) and isinstance(first_prob, float)
    repeat_smi, repeat_prob = pipeline._greedy_search(bos)
    assert repeat_smi == first_smi and first_prob == repeat_prob
    primed = pipeline._tokenize(f"{tokenizer.BOS}C=C")
    primed_smi, primed_prob = pipeline._greedy_search(primed)
    assert primed_smi != first_smi and first_prob != primed_prob

    # Random sampling should produce different draws from the same prompt.
    draw_smi, draw_prob = pipeline._random_sample(bos)
    assert isinstance(draw_smi, str) and isinstance(draw_prob, float)
    redraw_smi, redraw_prob = pipeline._random_sample(bos)
    assert redraw_smi != draw_smi and draw_prob != redraw_prob

    # Beam search returns beams sorted by probability and is deterministic.
    beams = pipeline._beam_search(bos, beam_width=2)
    assert isinstance(beams, list) and len(beams) == 2
    assert beams[0][1] >= beams[1][1]
    rerun = pipeline._beam_search(bos, beam_width=2)
    assert rerun == beams


def test_AdaMR(tokenizer, adamr_conf):
    """AdaMR generation returns a single SMILES with its probability.

    NOTE(review): the scraped span interleaved both diff sides; this is the
    post-change version — the new AdaMR.__call__ takes no arguments, so only
    the no-argument call survives (deletion count in the diff header agrees).
    """
    pipeline = pipelines.AdaMR(tokenizer, models.AdaMR(adamr_conf))

    out = pipeline()
    assert 'smiles' in out and 'probability' in out
    assert isinstance(out['smiles'], str)

Expand Down Expand Up @@ -65,11 +85,8 @@ def test_AdaMRGoalGeneration(tokenizer, adamr_conf):

def test_AdaMR2(tokenizer, adamr2_conf):
    """AdaMR2 generation returns a single SMILES with its probability.

    NOTE(review): the scraped span interleaved both diff sides; this is the
    post-change version — the new AdaMR2.__call__ takes no arguments, so only
    the no-argument call survives (deletion count in the diff header agrees).
    """
    pipeline = pipelines.AdaMR2(tokenizer, models.AdaMR2(adamr2_conf))

    out = pipeline()
    assert 'smiles' in out and 'probability' in out
    assert isinstance(out['smiles'], str)

Expand Down Expand Up @@ -109,3 +126,8 @@ def test_AdaMR2GoalGeneration(tokenizer, adamr2_conf):
assert isinstance(out['smiles'], list) and len(out['smiles']) == 2 and isinstance(out['smiles'][0], str)
assert isinstance(out['probabilities'], list) and len(
out['probabilities']) == 2


def test_AdaMR2SuperGeneration(tokenizer, adamr2_conf):
    """Placeholder test for the unimplemented AdaMR2SuperGeneration pipeline."""
    # TODO
    pass
Loading