Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding layer specific cutoff and correlation to multi-GPU branch #214

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions mace/data/atomic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional, Sequence

import torch.utils.data
import numpy as np

from mace.tools import (
AtomicNumberTable,
Expand Down Expand Up @@ -46,6 +47,7 @@ class AtomicData(torch_geometric.data.Data):
def __init__(
self,
edge_index: torch.Tensor, # [2, n_edges]
edge_index_mask: Optional[torch.Tensor], # [n_layers, n_edges]
node_attrs: torch.Tensor, # [n_nodes, n_node_feats]
positions: torch.Tensor, # [n_nodes, 3]
shifts: torch.Tensor, # [n_edges, 3],
Expand All @@ -67,6 +69,7 @@ def __init__(
num_nodes = node_attrs.shape[0]

assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2
assert edge_index_mask.shape[1] == edge_index.shape[1]
assert positions.shape == (num_nodes, 3)
assert shifts.shape[1] == 3
assert unit_shifts.shape[1] == 3
Expand All @@ -87,6 +90,7 @@ def __init__(
data = {
"num_nodes": num_nodes,
"edge_index": edge_index,
'edge_mask': edge_index_mask.T,
"positions": positions,
"shifts": shifts,
"unit_shifts": unit_shifts,
Expand All @@ -108,11 +112,23 @@ def __init__(

@classmethod
def from_config(
cls, config: Configuration, z_table: AtomicNumberTable, cutoff: float
cls, config: Configuration, z_table: AtomicNumberTable, cutoff: list
) -> "AtomicData":

# Get edge index for the largest cutoff
edge_index, shifts, unit_shifts = get_neighborhood(
positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell
positions=config.positions,
cutoff=torch.max(cutoff).item(),
pbc=config.pbc,
cell=config.cell,
)

# Create edge mask for each cutoff distance
edge_distance = np.linalg.norm(
config.positions[edge_index[0]] - config.positions[edge_index[1]] - shifts,
axis=1,
)
edge_index_mask = torch.tensor(edge_distance, device=cutoff.device) < cutoff[:, None]
indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table)
one_hot = to_one_hot(
torch.tensor(indices, dtype=torch.long).unsqueeze(-1),
Expand Down Expand Up @@ -192,6 +208,7 @@ def from_config(

return cls(
edge_index=torch.tensor(edge_index, dtype=torch.long),
edge_index_mask=torch.tensor(edge_index_mask, dtype=torch.bool),
positions=torch.tensor(config.positions, dtype=torch.get_default_dtype()),
shifts=torch.tensor(shifts, dtype=torch.get_default_dtype()),
unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()),
Expand Down
60 changes: 40 additions & 20 deletions mace/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
class MACE(torch.nn.Module):
def __init__(
self,
r_max: float,
r_max: List[float],
num_bessel: int,
num_polynomial_cutoff: int,
max_ell: int,
Expand All @@ -52,7 +52,7 @@ def __init__(
atomic_energies: np.ndarray,
avg_num_neighbors: float,
atomic_numbers: List[int],
correlation: int,
correlations: List[int],
gate: Optional[Callable],
):
super().__init__()
Expand All @@ -63,18 +63,23 @@ def __init__(
self.register_buffer(
"num_interactions", torch.tensor(num_interactions, dtype=torch.int64)
)

# Radial embedding, interactions and readouts
self.radial_embeddings = torch.nn.ModuleList()

# Embedding
node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
self.node_embedding = LinearNodeEmbeddingBlock(
irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
)
self.radial_embedding = RadialEmbeddingBlock(
r_max=r_max,
radial_embedding_first = RadialEmbeddingBlock(
r_max=r_max[0],
num_bessel=num_bessel,
num_polynomial_cutoff=num_polynomial_cutoff,
)
edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
edge_feats_irreps_first = o3.Irreps(f"{radial_embedding_first.out_dim}x0e")
self.radial_embeddings.append(radial_embedding_first)

sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
num_features = hidden_irreps.count(o3.Irrep(0, 1))
Expand All @@ -90,7 +95,7 @@ def __init__(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=node_feats_irreps,
edge_attrs_irreps=sh_irreps,
edge_feats_irreps=edge_feats_irreps,
edge_feats_irreps=edge_feats_irreps_first,
target_irreps=interaction_irreps,
hidden_irreps=hidden_irreps,
avg_num_neighbors=avg_num_neighbors,
Expand All @@ -106,7 +111,7 @@ def __init__(
prod = EquivariantProductBasisBlock(
node_feats_irreps=node_feats_irreps_out,
target_irreps=hidden_irreps,
correlation=correlation,
correlation=correlations[0],
num_elements=num_elements,
use_sc=use_sc_first,
)
Expand All @@ -122,6 +127,15 @@ def __init__(
) # Select only scalars for last layer
else:
hidden_irreps_out = hidden_irreps

radial_embedding = RadialEmbeddingBlock(
r_max=r_max[i + 1],
num_bessel=num_bessel,
num_polynomial_cutoff=num_polynomial_cutoff,
)
edge_feats_irreps = o3.Irreps(f"{radial_embedding.out_dim}x0e")
self.radial_embeddings.append(radial_embedding)

inter = interaction_cls(
node_attrs_irreps=node_attr_irreps,
node_feats_irreps=hidden_irreps,
Expand All @@ -135,7 +149,7 @@ def __init__(
prod = EquivariantProductBasisBlock(
node_feats_irreps=interaction_irreps,
target_irreps=hidden_irreps_out,
correlation=correlation,
correlation=correlations[i + 1],
num_elements=num_elements,
use_sc=True,
)
Expand Down Expand Up @@ -193,20 +207,23 @@ def forward(
shifts=data["shifts"],
)
edge_attrs = self.spherical_harmonics(vectors)
edge_feats = self.radial_embedding(lengths)

# Interactions
energies = [e0]
node_energies_list = [node_e0]
for interaction, product, readout in zip(
self.interactions, self.products, self.readouts
for i, (interaction, product, readout, radial_embedding) in enumerate(
zip(self.interactions, self.products, self.readouts, self.radial_embeddings)
):
edge_index_mask = data["edge_mask"][i, :]
edge_attrsi = edge_attrs[edge_index_mask]
edge_featsi = radial_embedding(lengths[edge_index_mask])

node_feats, sc = interaction(
node_attrs=data["node_attrs"],
node_feats=node_feats,
edge_attrs=edge_attrs,
edge_feats=edge_feats,
edge_index=data["edge_index"],
edge_attrs=edge_attrsi,
edge_feats=edge_featsi,
edge_index=data["edge_index"][:, edge_index_mask],
)
node_feats = product(
node_feats=node_feats,
Expand Down Expand Up @@ -307,19 +324,22 @@ def forward(
shifts=data["shifts"],
)
edge_attrs = self.spherical_harmonics(vectors)
edge_feats = self.radial_embedding(lengths)

# Interactions
node_es_list = []
for interaction, product, readout in zip(
self.interactions, self.products, self.readouts
for i, (interaction, product, readout, radial_embedding) in enumerate(
zip(self.interactions, self.products, self.readouts, self.radial_embeddings)
):
edge_index_mask = data["edge_mask"][:,i]
edge_attrsi = edge_attrs[edge_index_mask]
edge_featsi = radial_embedding(lengths[edge_index_mask])

node_feats, sc = interaction(
node_attrs=data["node_attrs"],
node_feats=node_feats,
edge_attrs=edge_attrs,
edge_feats=edge_feats,
edge_index=data["edge_index"],
edge_attrs=edge_attrsi,
edge_feats=edge_featsi,
edge_index=data["edge_index"][:, edge_index_mask],
)
node_feats = product(
node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"]
Expand Down
28 changes: 15 additions & 13 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
],
)
parser.add_argument(
"--r_max", help="distance cutoff (in Ang)",
type=float,
default=5.0
"--r_max", help="distance cutoff (in Ang). Float or list of values for each layer", type=str, default=5.0
)
parser.add_argument(
"--num_radial_basis",
Expand Down Expand Up @@ -131,7 +129,11 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"--max_ell", help=r"highest \ell of spherical harmonics", type=int, default=3
)
parser.add_argument(
"--correlation", help="correlation order at each layer", type=int, default=3
"--correlation",
help="correlation order at each layer. "
"Can be list of values for each layer or single int ",
type=str,
default=3,
)
parser.add_argument(
"--num_interactions", help="number of interactions", type=int, default=2
Expand Down Expand Up @@ -202,7 +204,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:

# Dataset
parser.add_argument(
"--train_file", help="Training set file, format is .xyz or .h5", type=str,
"--train_file", help="Training set file, format is .xyz or .h5", type=str,
required=True,
)
parser.add_argument(
Expand Down Expand Up @@ -247,7 +249,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--pin_memory",
help="Pin memory for data loading",
default=True,
default=True,
type=bool,
)
parser.add_argument(
Expand Down Expand Up @@ -553,8 +555,8 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
"--num_process",
help="The user defined number of processes to use, as well as the number of files created.",
type=int,
help="The user defined number of processes to use, as well as the number of files created.",
type=int,
default=int(os.cpu_count()/4)
)
parser.add_argument(
Expand All @@ -578,8 +580,8 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser:
default="",
)
parser.add_argument(
"--r_max", help="distance cutoff (in Ang)",
type=float,
"--r_max", help="distance cutoff (in Ang)",
type=float,
default=5.0
)
parser.add_argument(
Expand Down Expand Up @@ -638,9 +640,9 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser:
default=False,
)
parser.add_argument(
"--batch_size",
help="batch size to compute average number of neighbours",
type=int,
"--batch_size",
help="batch size to compute average number of neighbours",
type=int,
default=16,
)

Expand Down
Loading