
Commit

feat: fix tokenize s2 atomwise in pretrain
yandy committed Sep 21, 2024
1 parent ff1e25e commit b7ffc68
Showing 11 changed files with 186 additions and 162 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -33,7 +33,7 @@ jobs:
strategy:
fail-fast: false
matrix:
py-version: ['3.8', '3.9', '3.10', '3.11']
py-version: ['3.9', '3.10', '3.11']
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
2 changes: 1 addition & 1 deletion moltx/_version.py
@@ -1 +1 @@
__version__ = "1.0.5"
__version__ = "2.0.0"
117 changes: 81 additions & 36 deletions moltx/datasets.py
@@ -1,15 +1,15 @@
import typing
import torch
from moltx import tokenizers
from moltx import tokenizers, models


class Base:
def __init__(self, tokenizer: tokenizers.MoltxTokenizer, device: torch.device = torch.device('cpu')) -> None:
self.tokenizer = tokenizer
self.device = device

def _tokenize(self, smiles: typing.Sequence[str], seq_len: int = None) -> torch.Tensor:
tks_list = [self.tokenizer(smi) for smi in smiles]
def _tokenize(self, smiles: typing.Sequence[str], seq_len: int = None, spe_dropout: float = 0) -> torch.Tensor:
tks_list = [self.tokenizer(smi, spe_dropout) for smi in smiles]
size = seq_len or max(map(len, tks_list))
out = [self._tokens2tensor(tks, size).unsqueeze(0) for tks in tks_list]
return torch.concat(out)
@@ -24,27 +24,30 @@ def _tokens2tensor(self, tokens: typing.Sequence[int], size: int) -> torch.Tenso


class AdaMR(Base):
def __call__(self, s1: typing.Sequence[str], s2: typing.Sequence[str], seq_len: int = None) -> typing.Tuple[torch.Tensor]:
def __init__(self, device: torch.device = torch.device('cpu')) -> None:
tokenizer = tokenizers.MoltxTokenizer.from_pretrain(models.AdaMRTokenizerConfig.Spe)
super().__init__(tokenizer=tokenizer, device=device)

def __call__(self, s1: typing.Sequence[str], s2: typing.Sequence[str]) -> typing.Tuple[torch.Tensor]:
if len(s1) != len(s2):
raise RuntimeError("the length of s1 and s2 must be the same!")
bos = self.tokenizer[self.tokenizer.BOS]
eos = self.tokenizer[self.tokenizer.EOS]
src = self._tokenize(s1)
s2tokens = [self.tokenizer(smi) for smi in s2]
size = seq_len or max(map(len, s2tokens)) + 1
tgt = [self._tokens2tensor([bos] + tks, size).unsqueeze(0) for tks in s2tokens]
out = [self._tokens2tensor(tks + [eos], size).unsqueeze(0) for tks in s2tokens]
return src, torch.concat(tgt), torch.concat(out)
src = self._tokenize(s1, spe_dropout=0.2)
ts2 = self._tokenize(s2, spe_dropout=1.0)
bos = self._tokenize([self.tokenizer.BOS for _ in range(len(s2))])
eos = self._tokenize([self.tokenizer.EOS for _ in range(len(s2))])
tgt = torch.concat([bos, ts2], dim=1)
out = torch.concat([ts2, eos], dim=1)
return src, tgt, out


class AdaMRClassifier(AdaMR):
def __call__(self, smiles: typing.Sequence[str], labels: typing.Sequence[int], seq_len: int = None) -> typing.Tuple[torch.Tensor]:
if len(smiles) != len(labels):
raise RuntimeError(
"the length of smiles and labels must be the same!")
head = [self.tokenizer.CLS for _ in range(len(smiles))]
src = self._tokenize(smiles, seq_len)
tgt = self._tokenize(
[f"{self.tokenizer.BOS}{smi}{self.tokenizer.EOS}" for smi in smiles], seq_len)
tgt = self._tokenize(head)
out = torch.tensor(labels, device=self.device)
return src, tgt, out

@@ -54,43 +54,64 @@ def __call__(self, smiles: typing.Sequence[str], values: typing.Sequence[float],
if len(smiles) != len(values):
raise RuntimeError(
"the length of smiles and values must be the same!")
head = [self.tokenizer.CLS for _ in range(len(smiles))]
src = self._tokenize(smiles, seq_len)
tgt = self._tokenize(
[f"{self.tokenizer.BOS}{smi}{self.tokenizer.EOS}" for smi in smiles], seq_len)
tgt = self._tokenize(head)
out = torch.tensor(values, device=self.device).unsqueeze(-1)
return src, tgt, out


class AdaMRDistGeneration(AdaMR):
def __init__(self, device: torch.device = torch.device('cpu')) -> None:
tokenizer = tokenizers.MoltxTokenizer.from_pretrain(models.AdaMRTokenizerConfig.Atom)
super(AdaMR, self).__init__(tokenizer=tokenizer, device=device)

def __call__(self, smiles: typing.Sequence[str], seq_len: int = None) -> typing.Tuple[torch.Tensor]:
head = [self.tokenizer.CLS for _ in range(len(smiles))]
return super().__call__(head, smiles, seq_len)
seq_len = seq_len and seq_len - 1
src = self._tokenize([self.tokenizer.CLS for _ in range(len(smiles))])
smi = self._tokenize(smiles, seq_len=seq_len)
bos = self._tokenize([self.tokenizer.BOS for _ in range(len(smiles))])
eos = self._tokenize([self.tokenizer.EOS for _ in range(len(smiles))])
tgt = torch.concat([bos, smi], dim=1)
out = torch.concat([smi, eos], dim=1)
return src, tgt, out


class AdaMRGoalGeneration(AdaMR):
def __init__(self, device: torch.device = torch.device('cpu')) -> None:
tokenizer = tokenizers.MoltxTokenizer.from_pretrain(models.AdaMRTokenizerConfig.Atom)
super(AdaMR, self).__init__(tokenizer=tokenizer, device=device)

def __call__(self, smiles: typing.Sequence[str], goals: typing.Sequence[float], seq_len: int = None) -> typing.Tuple[torch.Tensor]:
if len(smiles) != len(goals):
raise RuntimeError(
"the length of smiles and goals must be the same!")
head = [self.tokenizer.CLS for _ in range(len(smiles))]
src, tgt, out = super().__call__(head, smiles, seq_len)
seq_len = seq_len and seq_len - 1
src = self._tokenize([self.tokenizer.CLS for _ in range(len(smiles))])
smi = self._tokenize(smiles, seq_len=seq_len)
bos = self._tokenize([self.tokenizer.BOS for _ in range(len(smiles))])
eos = self._tokenize([self.tokenizer.EOS for _ in range(len(smiles))])
tgt = torch.concat([bos, smi], dim=1)
out = torch.concat([smi, eos], dim=1)
goal = torch.tensor(goals, device=self.device).unsqueeze(-1)
return goal, src, tgt, out


class AdaMR2(Base):
def __call__(self, s1: typing.Sequence[str], s2: typing.Sequence[str], seq_len: int = None) -> typing.Tuple[torch.Tensor]:
def __init__(self, device: torch.device = torch.device('cpu')) -> None:
tokenizer = tokenizers.MoltxTokenizer.from_pretrain(models.AdaMRTokenizerConfig.Spe)
super().__init__(tokenizer=tokenizer, device=device)

def __call__(self, s1: typing.Sequence[str], s2: typing.Sequence[str]) -> typing.Tuple[torch.Tensor]:
if len(s1) != len(s2):
raise RuntimeError("the length of s1 and s2 must be the same!")
bos = self.tokenizer[self.tokenizer.BOS]
eos = self.tokenizer[self.tokenizer.EOS]
s1tokens = [self.tokenizer(smi) for smi in s1]
s2tokens = [self.tokenizer(smi) for smi in s2]
tgt_tokens = [tks1 + [bos] + tks2 for tks1, tks2 in zip(s1tokens, s2tokens)]
out_tokens = [[0] * len(tks1) + tks2 + [eos] for tks1, tks2 in zip(s1tokens, s2tokens)]
size = seq_len or max(map(len, out_tokens))
tgt = torch.concat([self._tokens2tensor(tks, size).unsqueeze(0) for tks in tgt_tokens])
out = torch.concat([self._tokens2tensor(tks, size).unsqueeze(0) for tks in out_tokens])
ts1 = self._tokenize(s1, spe_dropout=0.2)
ts2 = self._tokenize(s2, spe_dropout=1.0)
zero = torch.zeros_like(ts1)
bos = self._tokenize([self.tokenizer.BOS for _ in range(len(s2))])
eos = self._tokenize([self.tokenizer.EOS for _ in range(len(s2))])
tgt = torch.concat([ts1, bos, ts2], dim=1)
out = torch.concat([zero, ts2, eos], dim=1)
return tgt, out


@@ -100,7 +124,7 @@ def __call__(self, smiles: typing.Sequence[str], labels: typing.Sequence[int], s
raise RuntimeError(
"the length of smiles and labels must be the same!")
tgt = self._tokenize(
[f"{self.tokenizer.BOS}{smi}{self.tokenizer.EOS}" for smi in smiles], seq_len)
[f"{smi}{self.tokenizer.CLS}" for smi in smiles], seq_len)
out = torch.tensor(labels, device=self.device)
return tgt, out

@@ -111,23 +135,44 @@ def __call__(self, smiles: typing.Sequence[str], values: typing.Sequence[float],
raise RuntimeError(
"the length of smiles and values must be the same!")
tgt = self._tokenize(
[f"{self.tokenizer.BOS}{smi}{self.tokenizer.EOS}" for smi in smiles], seq_len)
[f"{smi}{self.tokenizer.CLS}" for smi in smiles], seq_len)
out = torch.tensor(values, device=self.device).unsqueeze(-1)
return tgt, out


class AdaMR2DistGeneration(AdaMR2):
def __init__(self, device: torch.device = torch.device('cpu')) -> None:
tokenizer = tokenizers.MoltxTokenizer.from_pretrain(models.AdaMRTokenizerConfig.Atom)
super(AdaMR2, self).__init__(tokenizer=tokenizer, device=device)

def __call__(self, smiles: typing.Sequence[str], seq_len: int = None) -> typing.Tuple[torch.Tensor]:
head = [self.tokenizer.CLS for _ in range(len(smiles))]
return super().__call__(head, smiles, seq_len)
seq_len = seq_len and seq_len - 2
head = self._tokenize([self.tokenizer.CLS for _ in range(len(smiles))])
zero = torch.zeros_like(head)
smi = self._tokenize(smiles, seq_len=seq_len)
bos = self._tokenize([self.tokenizer.BOS for _ in range(len(smiles))])
eos = self._tokenize([self.tokenizer.EOS for _ in range(len(smiles))])
tgt = torch.concat([head, bos, smi], dim=1)
out = torch.concat([zero, smi, eos], dim=1)
return tgt, out


class AdaMR2GoalGeneration(AdaMR2):
def __init__(self, device: torch.device = torch.device('cpu')) -> None:
tokenizer = tokenizers.MoltxTokenizer.from_pretrain(models.AdaMRTokenizerConfig.Atom)
super(AdaMR2, self).__init__(tokenizer=tokenizer, device=device)

def __call__(self, smiles: typing.Sequence[str], goals: typing.Sequence[float], seq_len: int = None) -> typing.Tuple[torch.Tensor]:
if len(smiles) != len(goals):
raise RuntimeError(
"the length of smiles and goals must be the same!")
head = [self.tokenizer.CLS for _ in range(len(smiles))]
tgt, out = super().__call__(head, smiles, seq_len)
seq_len = seq_len and seq_len - 2
head = self._tokenize([self.tokenizer.CLS for _ in range(len(smiles))])
zero = torch.zeros_like(head)
smi = self._tokenize(smiles, seq_len=seq_len)
bos = self._tokenize([self.tokenizer.BOS for _ in range(len(smiles))])
eos = self._tokenize([self.tokenizer.EOS for _ in range(len(smiles))])
tgt = torch.concat([head, bos, smi], dim=1)
out = torch.concat([zero, smi, eos], dim=1)
goal = torch.tensor(goals, device=self.device).unsqueeze(-1)
return goal, tgt, out
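For reference, a minimal sketch of how the reworked datasets are meant to be called after this change. The constructors now build their own tokenizer from `AdaMRTokenizerConfig` (SPE for the seq2seq pretraining pair, atom-wise for generation), so callers only pass a device; the SMILES strings below are arbitrary examples, not part of the commit, and exact tensor shapes depend on the pretrained vocabulary.

```python
import torch
from moltx import datasets

# AdaMR builds an SPE tokenizer internally: s1 is tokenized with
# spe_dropout=0.2 (augmentation), s2 with spe_dropout=1.0 (atom-wise).
ds = datasets.AdaMR(device=torch.device("cpu"))

s1 = ["CC[N+](C)(C)Cc1ccccc1Br"]  # example SMILES only
s2 = ["CC[N+](C)(C)Cc1ccccc1Br"]

# tgt/out form the usual teacher-forcing pair, concatenated along the
# sequence dimension: tgt = [<bos>, s2 tokens], out = [s2 tokens, <eos>].
src, tgt, out = ds(s1, s2)

# The generation datasets switch to the atom-wise tokenizer and use a
# <cls> prompt as src.
gen = datasets.AdaMRDistGeneration(device=torch.device("cpu"))
src, tgt, out = gen(["CC[N+](C)(C)Cc1ccccc1Br"], seq_len=64)
```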
15 changes: 2 additions & 13 deletions moltx/models.py
@@ -4,27 +4,16 @@


class AdaMRTokenizerConfig:
Pretrain = tokenizers.MoltxPretrainConfig(
token_size=512,
fmt='smiles',
spe=True,
spe_dropout=0.2,
spe_merges=240
)

Generation = tokenizers.MoltxPretrainConfig(
Atom = tokenizers.MoltxPretrainConfig(
token_size=512,
fmt='smiles',
spe=False,
spe_dropout=1.0,
spe_merges=240
)

Prediction = tokenizers.MoltxPretrainConfig(
Spe = tokenizers.MoltxPretrainConfig(
token_size=512,
fmt='smiles',
spe=True,
spe_dropout=0.0,
spe_merges=240
)

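The config rename above is the hook the datasets and pipelines use to pick a tokenization mode. A sketch based only on the fields visible in this diff; how `from_pretrain` resolves a config into a vocabulary is assumed, and the SMILES string is a placeholder.

```python
from moltx import models, tokenizers

# Atom: spe=False, spe_dropout=1.0 -> atom-wise tokens, used by the
# *Generation datasets and pipelines.
atom_tk = tokenizers.MoltxTokenizer.from_pretrain(models.AdaMRTokenizerConfig.Atom)

# Spe: spe=True, spe_dropout=0.0 -> deterministic SPE merges by default;
# Base._tokenize can still pass a per-call dropout (see datasets.py above).
spe_tk = tokenizers.MoltxTokenizer.from_pretrain(models.AdaMRTokenizerConfig.Spe)

smi = "c1ccccc1O"              # arbitrary example SMILES
tokens_det = spe_tk(smi)       # deterministic merges
tokens_aug = spe_tk(smi, 0.2)  # 20% merge dropout, as used for s1 in AdaMR
```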
36 changes: 17 additions & 19 deletions moltx/pipelines.py
@@ -1,7 +1,7 @@
import typing
import torch
import torch.nn as nn
from moltx import tokenizers
from moltx import tokenizers, models


class _Base:
@@ -105,16 +105,18 @@ def _beam_search(self, tgt: torch.Tensor, beam_width: int = 3, **kwds: torch.Ten


class AdaMR(_GenBase):
def __init__(self, model: models.AdaMR, device: torch.device = torch.device('cpu')) -> None:
tokenizer = tokenizers.MoltxTokenizer.from_pretrain(models.AdaMRTokenizerConfig.Spe)
super().__init__(tokenizer, model, device)

def _model_args(self, smiles: str) -> typing.Tuple[torch.Tensor]:
src = self._tokenize(smiles)
tgt = self._tokenize(self.tokenizer.BOS)
return src, tgt

# gentype: greedy, beam
def __call__(self) -> typing.Mapping:
src, tgt = self._model_args("")
smi, prob = self._random_sample(src=src, tgt=tgt)
def __call__(self, smiles: str = "") -> typing.Mapping:
src, tgt = self._model_args(smiles)
smi, prob = self._beam_search(src=src, tgt=tgt, beam_width=3)[0]
return {
'smiles': smi,
'probability': prob
@@ -124,7 +126,7 @@ def __call__(self) -> typing.Mapping:
class AdaMRClassifier(AdaMR):
def _model_args(self, smiles: str) -> typing.Tuple[torch.Tensor]:
src = self._tokenize(smiles)
tgt = self._tokenize(f"{self.tokenizer.BOS}{smiles}{self.tokenizer.EOS}")
tgt = self._tokenize(self.tokenizer.CLS)
return src, tgt

def __call__(self, smiles: str) -> typing.Mapping:
@@ -140,7 +142,7 @@ def __call__(self, smiles: str) -> typing.Mapping:
class AdaMRRegression(AdaMR):
def _model_args(self, smiles: str) -> typing.Tuple[torch.Tensor]:
src = self._tokenize(smiles)
tgt = self._tokenize(f"{self.tokenizer.BOS}{smiles}{self.tokenizer.EOS}")
tgt = self._tokenize(self.tokenizer.CLS)
return src, tgt

def __call__(self, smiles: str) -> typing.Mapping:
@@ -189,11 +191,13 @@ def __call__(self, goal: float, k: int = 1) -> typing.Mapping:


class AdaMR2(_GenBase):
def __init__(self, model: models.AdaMR2, device: torch.device = torch.device('cpu')) -> None:
tokenizer = tokenizers.MoltxTokenizer.from_pretrain(models.AdaMRTokenizerConfig.Spe)
super().__init__(tokenizer, model, device)

# gentype: greedy, beam
def __call__(self) -> typing.Mapping:
tgt = self._tokenize(self.tokenizer.BOS)
smi, prob = self._random_sample(tgt=tgt)
def __call__(self, smiles: str = "") -> typing.Mapping:
tgt = self._tokenize(f"{smiles}{self.tokenizer.BOS}")
smi, prob = self._beam_search(tgt=tgt, beam_width=3)[0]
return {
'smiles': smi,
'probability': prob
@@ -203,7 +207,7 @@ def __call__(self) -> typing.Mapping:
class AdaMR2Classifier(AdaMR2):

def __call__(self, smiles: str) -> typing.Mapping:
tgt = self._tokenize(f"{self.tokenizer.BOS}{smiles}{self.tokenizer.EOS}")
tgt = self._tokenize(f"{smiles}{self.tokenizer.CLS}")
out = self.model(tgt)
prob, label = out.softmax(-1).max(-1)
return {
@@ -215,7 +219,7 @@ def __call__(self, smiles: str) -> typing.Mapping:
class AdaMR2Regression(AdaMR2):

def __call__(self, smiles: str) -> typing.Mapping:
tgt = self._tokenize(f"{self.tokenizer.BOS}{smiles}{self.tokenizer.EOS}")
tgt = self._tokenize(f"{smiles}{self.tokenizer.CLS}")
out = self.model(tgt)
return {
'value': out.item()
@@ -253,9 +257,3 @@ def __call__(self, goal: float, k: int = 1) -> typing.Mapping:
'smiles': smis,
'probabilities': probs
}


class AdaMR2SuperGeneration(AdaMR2):

# TODO
pass
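Finally, a hedged sketch of the updated generation pipeline call path. Only the constructor signature and `__call__` changes are taken from this diff; how a trained `models.AdaMR2` instance is created or loaded is not shown here, so it is left as a placeholder.

```python
import torch
from moltx import pipelines

# Placeholder: obtain a trained models.AdaMR2 however your training code
# does it; the loading API is not part of this commit.
model = ...  # trained models.AdaMR2 instance

pipe = pipelines.AdaMR2(model=model, device=torch.device("cpu"))

# __call__ now accepts an optional SMILES prefix and decodes with beam
# search (beam_width=3) instead of random sampling.
print(pipe())                 # unconditional generation
print(pipe("CC[N+](C)(C)C"))  # generation conditioned on a prefix

# Classification/regression pipelines now frame their input as
# "<smiles><cls>" instead of "<bos><smiles><eos>"; call signatures unchanged.
```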

