Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AlkEthOH interaction-typing task #20

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,9 @@ dmypy.json

# Parm@Frosst download
parm_at_Frosst.tgz

# PyCharm
.idea


.DS_Store
3 changes: 3 additions & 0 deletions espaloma/data/alkethoh/.gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
AlkEthOH_rings.npz filter=lfs diff=lfs merge=lfs -text
AlkEthOH_rings.smi filter=lfs diff=lfs merge=lfs -text
AlkEthOH_rings_offmols.pkl filter=lfs diff=lfs merge=lfs -text
3 changes: 3 additions & 0 deletions espaloma/data/alkethoh/AlkEthOH_rings.npz
Git LFS file not shown
3 changes: 3 additions & 0 deletions espaloma/data/alkethoh/AlkEthOH_rings.smi
Git LFS file not shown
3 changes: 3 additions & 0 deletions espaloma/data/alkethoh/AlkEthOH_rings_offmols.pkl
Git LFS file not shown
Empty file.
119 changes: 119 additions & 0 deletions espaloma/data/alkethoh/label_molecules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""label every molecule in AlkEthOH rings set"""

import os
import urllib
from pickle import dump

import numpy as np
from openforcefield.topology import Molecule
from openforcefield.typing.engines.smirnoff import ForceField
from pkg_resources import resource_filename
from tqdm import tqdm

# URL of the AlkEthOH rings SMILES file; the path pins a specific commit hash of the source repository
alkethoh_url = 'https://raw.githubusercontent.com/openforcefield/open-forcefield-data/e07bde16c34a3fa1d73ab72e2b8aeab7cd6524df/Model-Systems/AlkEthOH_distrib/AlkEthOH_rings.smi'

# File locations resolved inside the `espaloma.data.alkethoh` package:
# the SMILES file, the pickled dict of molecules, and the compressed array of labels
path_to_smiles = resource_filename('espaloma.data.alkethoh', 'AlkEthOH_rings.smi')
path_to_offmols = resource_filename('espaloma.data.alkethoh', 'AlkEthOH_rings_offmols.pkl')
path_to_npz = resource_filename('espaloma.data.alkethoh', 'AlkEthOH_rings.npz')


def download_alkethoh():
    """Download the AlkEthOH rings SMILES file unless it is already on disk.

    Fetches ``alkethoh_url`` and writes the raw response bytes to
    ``path_to_smiles``. Does nothing when ``path_to_smiles`` already exists.
    """
    # A bare `import urllib` does not guarantee that the `urllib.request`
    # submodule has been loaded, so import it explicitly before using it.
    import urllib.request

    if os.path.exists(path_to_smiles):
        return

    with urllib.request.urlopen(alkethoh_url) as response:
        smi = response.read()
    with open(path_to_smiles, 'wb') as f:
        f.write(smi)


# Load the OpenFF "Parsley" force field once, at module import time.
# The unconstrained variant is picked so that Hbond stretches are sampled.
forcefield = ForceField('openff_unconstrained-1.0.0.offxml')


## loading molecules
# TODO: replace mol_from_smiles with something that reads the mol2 files directly...
def mol_from_smiles(smiles):
    """Build a ``Molecule`` from a SMILES string.

    The conversion is requested with ``hydrogens_are_explicit=False`` and
    ``allow_undefined_stereo=True``.
    """
    return Molecule.from_smiles(
        smiles,
        hydrogens_are_explicit=False,
        allow_undefined_stereo=True,
    )


## labeling molecules
def label_mol(mol):
    """Label a single molecule with the module-level ``forcefield``.

    The molecule is converted to a topology, the topology is labeled, and the
    first entry of the labeling result is returned.
    """
    topology = mol.to_topology()
    labels_per_molecule = forcefield.label_molecules(topology)
    return labels_per_molecule[0]


def get_inds_and_labels(labeled_mol, type_name='vdW'):
    """Extract index tuples and integer type labels for one class of terms.

    Parameters
    ----------
    labeled_mol
        Mapping from a term class name (e.g. ``'vdW'``, ``'Bonds'``) to a
        mapping of index tuples -> parameter objects exposing an ``id``
        attribute.
    type_name
        Which term class of ``labeled_mol`` to read.

    Returns
    -------
    inds, labels
        Two integer arrays with one entry per term: the index tuples, and the
        integer obtained by dropping the first character of each parameter
        ``id`` (e.g. an id like ``'b83'`` becomes ``83``).

    Raises
    ------
    ValueError
        If an ``id`` is not a single prefix character followed by an integer.
    """
    terms = labeled_mol[type_name]

    # Force an integer dtype: without it an empty term collection would give
    # float64 arrays, breaking the "values are integer arrays" contract.
    inds = np.array(list(terms.keys()), dtype=int)
    labels = np.array([int(term.id[1:]) for term in terms.values()], dtype=int)

    return inds, labels


from functools import partial

# Convenience wrappers around `get_inds_and_labels`, each with `type_name`
# fixed to one term class of the labeled molecule.
get_labeled_atoms = partial(get_inds_and_labels, type_name='vdW')
get_labeled_bonds = partial(get_inds_and_labels, type_name='Bonds')
get_labeled_angles = partial(get_inds_and_labels, type_name='Angles')
get_labeled_torsions = partial(get_inds_and_labels, type_name='ProperTorsions')

if __name__ == '__main__':
    # download, if it isn't already present
    download_alkethoh()

    # load molecules: each non-blank line is "<smiles> <name> ..."
    # (single pass over the file; blank lines are skipped instead of crashing)
    smiles, names = [], []
    with open(path_to_smiles, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            smiles.append(fields[0])
            names.append(fields[1])

    # go through the shared helper so molecule construction lives in one place
    mols = {name: mol_from_smiles(smi) for name, smi in zip(names, smiles)}

    with open(path_to_offmols, 'wb') as f:
        dump(mols, f)

    # Label molecules using forcefield
    # Takes about ~200ms per molecule -- can do ~1000 molecules in ~5-6 minutes, sequentially
    labeled_mols = dict()
    for name in tqdm(names):
        labeled_mols[name] = label_mol(mols[name])

    # one extractor per term kind; the kind is also the middle part of the npz keys
    extractors = {
        'atom': get_labeled_atoms,
        'bond': get_labeled_bonds,
        'angle': get_labeled_angles,
        'torsion': get_labeled_torsions,
    }
    label_dict = dict()
    counts = {kind: 0 for kind in extractors}

    for name in names:
        labeled_mol = labeled_mols[name]
        for kind, extract in extractors.items():
            inds, labels = extract(labeled_mol)
            label_dict[f'{name}_{kind}_inds'] = inds
            label_dict[f'{name}_{kind}_labels'] = labels
            counts[kind] += len(inds)

    n_atoms, n_bonds = counts['atom'], counts['bond']
    n_angles, n_torsions = counts['angle'], counts['torsion']
    summary = f'# atoms: {n_atoms}, # bonds: {n_bonds}, # angles: {n_angles}, # torsions: {n_torsions}'
    print(summary)

    # save to compressed array
    description = f"""
Each of the molecules in AlkEthOH_rings.smi
{alkethoh_url}

is labeled according to the forcefield `openff_unconstrained-1.0.0.offxml`:
https://github.com/openforcefield/openforcefields/blob/master/openforcefields/offxml/openff_unconstrained-1.0.0.offxml

Keys are of the form
<name>_<atom|bond|angle|torsion>_<inds|labels>

such as 'AlkEthOH_r0_atom_inds' or 'AlkEthOH_r0_torsion_labels'

and values are integer arrays.

{summary}
"""
    np.savez_compressed(path_to_npz,
                        description=description,
                        **label_dict)
108 changes: 108 additions & 0 deletions espaloma/data/alkethoh/pytorch_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""provide pytorch dataset views of the interaction typing dataset"""

import os
from pickle import load
from typing import Tuple

import numpy as np
import torch
from openforcefield.topology import Molecule
from torch import tensor, Tensor
from torch.nn import CrossEntropyLoss
from torch.utils.data.dataset import Dataset


class AlkEthOHDataset(Dataset):
    """Base dataset over the labeled AlkEthOH rings molecules.

    Loads the pickled ``{name: molecule}`` dict and the ``.npz`` archive of
    per-molecule index/label arrays, and exposes lookups keyed by molecule
    name and term kind (the ``<name>_<type_name>_<inds|labels>`` npz keys).
    Subclasses are expected to provide ``__getitem__`` and ``loss``.
    """

    def __init__(self, data_dir=None):
        """Load molecules and labels.

        Parameters
        ----------
        data_dir
            Directory holding ``AlkEthOH_rings_offmols.pkl`` and
            ``AlkEthOH_rings.npz``. Defaults to the directory of this module,
            so the dataset does not depend on the current working directory.
        """
        if data_dir is None:
            data_dir = os.path.dirname(os.path.abspath(__file__))

        # NOTE: unpickling can execute arbitrary code -- only point this at
        # trusted files.
        with open(os.path.join(data_dir, 'AlkEthOH_rings_offmols.pkl'), 'rb') as f:
            self._mols = load(f)

        self._label_dict = np.load(os.path.join(data_dir, 'AlkEthOH_rings.npz'))
        # sorted names give a stable integer index -> molecule mapping
        self._mol_names = sorted(self._mols)

    def _get_inds(self, mol_name: str, type_name: str) -> Tensor:
        """Return the stored index array for one molecule and term kind."""
        return tensor(self._label_dict[f'{mol_name}_{type_name}_inds'])

    def _get_labels(self, mol_name: str, type_name: str) -> Tensor:
        """Return the stored label array for one molecule and term kind."""
        return tensor(self._label_dict[f'{mol_name}_{type_name}_labels'])

    def _get_all_unique_labels(self, type_name: str) -> list:
        """Return the sorted distinct labels of a term kind over all molecules."""
        all_labels = set()
        for mol_name in self._mol_names:
            mol_labels = self._label_dict[f'{mol_name}_{type_name}_labels']
            # tolist() yields plain Python ints rather than numpy scalars
            all_labels.update(np.asarray(mol_labels).ravel().tolist())
        return sorted(all_labels)

    def _get_mol_inds_labels(self, index: int, type_name: str) -> Tuple[Molecule, Tensor, Tensor]:
        """Return ``(molecule, inds, labels)`` for the index-th sorted name."""
        mol_name = self._mol_names[index]
        mol = self._mols[mol_name]
        inds = self._get_inds(mol_name, type_name)
        labels = self._get_labels(mol_name, type_name)
        return mol, inds, labels

    def loss(self, index: int, predictions: Tensor) -> Tensor:
        """Score ``predictions`` for molecule ``index``; implemented by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Number of molecules in the dataset."""
        return len(self._mol_names)


class AlkEthOHTypesDataset(AlkEthOHDataset):
    """Classification view of the dataset for one term kind.

    The raw integer labels of the chosen ``type_name`` are remapped to
    contiguous class indices ``0 .. n_classes - 1`` (in sorted label order).
    """

    def __init__(self, type_name='atom'):
        """Load the data and build the raw-label -> class-index mapping.

        Parameters
        ----------
        type_name
            Term kind used in the npz keys, e.g. ``'atom'`` or ``'bond'``.
        """
        super().__init__()
        self.type_name = type_name
        all_labels = self._get_all_unique_labels(self.type_name)

        self._label_mapping = {label: i for i, label in enumerate(all_labels)}
        self.n_classes = len(self._label_mapping)

    def __getitem__(self, index) -> Tuple[Molecule, Tensor, Tensor]:
        """Return ``(molecule, inds, class_indices)`` for molecule ``index``."""
        mol, inds, raw_labels = self._get_mol_inds_labels(index, self.type_name)
        # Explicit long dtype: an empty label list would otherwise become a
        # float tensor, which is not a valid class-index target.
        labels = tensor(
            [self._label_mapping[int(raw)] for raw in raw_labels],
            dtype=torch.long,
        )
        return mol, inds, labels

    def loss(self, index: int, predictions: Tensor) -> Tensor:
        """cross entropy loss

        ``predictions`` must have shape ``(n_labeled_terms, n_classes)`` for
        molecule ``index``.
        """
        _, _, target = self[index]
        expected_shape = (len(target), self.n_classes)
        assert predictions.shape == expected_shape, \
            f'expected predictions of shape {expected_shape}, got {tuple(predictions.shape)}'

        return CrossEntropyLoss()(predictions, target)


class AlkEthOHAtomTypesDataset(AlkEthOHTypesDataset):
    """``AlkEthOHTypesDataset`` with ``type_name`` fixed to ``'atom'``."""

    def __init__(self):
        super().__init__(type_name="atom")


class AlkEthOHBondTypesDataset(AlkEthOHTypesDataset):
    """``AlkEthOHTypesDataset`` with ``type_name`` fixed to ``'bond'``."""

    def __init__(self):
        super().__init__(type_name="bond")


class AlkEthOHAngleTypesDataset(AlkEthOHTypesDataset):
    """``AlkEthOHTypesDataset`` with ``type_name`` fixed to ``'angle'``."""

    def __init__(self):
        super().__init__(type_name="angle")


class AlkEthOHTorsionTypesDataset(AlkEthOHTypesDataset):
    """``AlkEthOHTypesDataset`` with ``type_name`` fixed to ``'torsion'``."""

    def __init__(self):
        super().__init__(type_name="torsion")


if __name__ == '__main__':
    # TODO: move this from __main__ into doctests

    # Smoke test: for every typed dataset, pick a random molecule, score random
    # predictions, and confirm the loss can be differentiated w.r.t. them.
    dataset_classes = (
        AlkEthOHAtomTypesDataset,
        AlkEthOHBondTypesDataset,
        AlkEthOHAngleTypesDataset,
        AlkEthOHTorsionTypesDataset,
    )
    # instantiate all of them up front, before anything is printed
    datasets = [dataset_class() for dataset_class in dataset_classes]

    for dataset in datasets:
        print(type(dataset).__name__)

        n_classes = dataset.n_classes

        index = np.random.randint(0, len(dataset))
        _mol, _inds, labels = dataset[index]
        predictions = torch.randn(len(labels), n_classes, requires_grad=True)

        dataset.loss(index, predictions).backward()
        print(predictions.grad)