diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py
index 31170d3d..64c19752 100644
--- a/mace/data/atomic_data.py
+++ b/mace/data/atomic_data.py
@@ -7,6 +7,7 @@ from typing import Optional, Sequence

 import torch.utils.data
+import numpy as np

 from mace.tools import (
     AtomicNumberTable,
@@ -46,6 +47,7 @@ class AtomicData(torch_geometric.data.Data):
     def __init__(
         self,
         edge_index: torch.Tensor,  # [2, n_edges]
+        edge_index_mask: Optional[torch.Tensor],  # [n_layers, n_edges]
         node_attrs: torch.Tensor,  # [n_nodes, n_node_feats]
         positions: torch.Tensor,  # [n_nodes, 3]
         shifts: torch.Tensor,  # [n_edges, 3],
@@ -67,6 +69,7 @@ def __init__(
         num_nodes = node_attrs.shape[0]

         assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2
+        assert edge_index_mask.shape[1] == edge_index.shape[1]
         assert positions.shape == (num_nodes, 3)
         assert shifts.shape[1] == 3
         assert unit_shifts.shape[1] == 3
@@ -87,6 +90,7 @@ def __init__(
         data = {
             "num_nodes": num_nodes,
             "edge_index": edge_index,
+            "edge_mask": edge_index_mask.T,
             "positions": positions,
             "shifts": shifts,
             "unit_shifts": unit_shifts,
@@ -108,11 +112,23 @@ def __init__(

     @classmethod
     def from_config(
-        cls, config: Configuration, z_table: AtomicNumberTable, cutoff: float
+        cls, config: Configuration, z_table: AtomicNumberTable, cutoff: torch.Tensor
     ) -> "AtomicData":
+
+        # Get the edge index for the largest cutoff
         edge_index, shifts, unit_shifts = get_neighborhood(
-            positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell
+            positions=config.positions,
+            cutoff=torch.max(cutoff).item(),
+            pbc=config.pbc,
+            cell=config.cell,
+        )
+
+        # Create an edge mask for each per-layer cutoff distance
+        edge_distance = np.linalg.norm(
+            config.positions[edge_index[0]] - config.positions[edge_index[1]] - shifts,
+            axis=1,
         )
+        edge_index_mask = torch.tensor(edge_distance, device=cutoff.device) < cutoff[:, None]
         indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table)
         one_hot = to_one_hot(
             torch.tensor(indices, dtype=torch.long).unsqueeze(-1),
@@ -192,6 +208,7 @@ def from_config(

         return cls(
             edge_index=torch.tensor(edge_index, dtype=torch.long),
+            edge_index_mask=edge_index_mask.to(torch.bool),
             positions=torch.tensor(config.positions, dtype=torch.get_default_dtype()),
             shifts=torch.tensor(shifts, dtype=torch.get_default_dtype()),
             unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()),
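For reference, the masking logic added to `from_config` boils down to a single broadcast comparison. A minimal standalone sketch; the toy `positions`, `edge_index` and `cutoffs` values below are illustrative only, not part of the patch:

    import numpy as np
    import torch

    # Toy inputs: 4 atoms, edges found once with the largest cutoff.
    positions = np.random.rand(4, 3) * 3.0
    edge_index = np.array([[0, 1, 2, 3], [1, 0, 3, 2]])  # [2, n_edges]
    shifts = np.zeros((4, 3))                            # no PBC shifts here
    cutoffs = torch.tensor([5.0, 3.0])                   # one cutoff per layer

    # Distance of every edge, computed as in from_config.
    edge_distance = np.linalg.norm(
        positions[edge_index[0]] - positions[edge_index[1]] - shifts, axis=1
    )

    # Broadcasting [n_layers, 1] against [n_edges] yields a [n_layers, n_edges]
    # boolean mask: edge e is active in layer l iff its length is below cutoffs[l].
    edge_index_mask = torch.tensor(edge_distance) < cutoffs[:, None]
    print(edge_index_mask.shape)  # torch.Size([2, 4])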
diff --git a/mace/modules/models.py b/mace/modules/models.py
index 4a05fa8b..0c3abdf2 100644
--- a/mace/modules/models.py
+++ b/mace/modules/models.py
@@ -39,7 +39,7 @@ class MACE(torch.nn.Module):
     def __init__(
         self,
-        r_max: float,
+        r_max: List[float],
         num_bessel: int,
         num_polynomial_cutoff: int,
         max_ell: int,
@@ -52,7 +52,7 @@ def __init__(
         atomic_energies: np.ndarray,
         avg_num_neighbors: float,
         atomic_numbers: List[int],
-        correlation: int,
+        correlations: List[int],
         gate: Optional[Callable],
     ):
         super().__init__()
@@ -63,18 +63,23 @@ def __init__(
         self.register_buffer(
             "num_interactions", torch.tensor(num_interactions, dtype=torch.int64)
         )
+
+        # Radial embeddings, one per interaction layer
+        self.radial_embeddings = torch.nn.ModuleList()
+
         # Embedding
         node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
         node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
         self.node_embedding = LinearNodeEmbeddingBlock(
             irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
         )
-        self.radial_embedding = RadialEmbeddingBlock(
-            r_max=r_max,
+        radial_embedding_first = RadialEmbeddingBlock(
+            r_max=r_max[0],
             num_bessel=num_bessel,
             num_polynomial_cutoff=num_polynomial_cutoff,
         )
-        edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
+        edge_feats_irreps_first = o3.Irreps(f"{radial_embedding_first.out_dim}x0e")
+        self.radial_embeddings.append(radial_embedding_first)

         sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
         num_features = hidden_irreps.count(o3.Irrep(0, 1))
@@ -90,7 +95,7 @@ def __init__(
             node_attrs_irreps=node_attr_irreps,
             node_feats_irreps=node_feats_irreps,
             edge_attrs_irreps=sh_irreps,
-            edge_feats_irreps=edge_feats_irreps,
+            edge_feats_irreps=edge_feats_irreps_first,
             target_irreps=interaction_irreps,
             hidden_irreps=hidden_irreps,
             avg_num_neighbors=avg_num_neighbors,
@@ -106,7 +111,7 @@ def __init__(
         prod = EquivariantProductBasisBlock(
             node_feats_irreps=node_feats_irreps_out,
             target_irreps=hidden_irreps,
-            correlation=correlation,
+            correlation=correlations[0],
             num_elements=num_elements,
             use_sc=use_sc_first,
         )
@@ -122,6 +127,15 @@ def __init__(
             )  # Select only scalars for last layer
         else:
             hidden_irreps_out = hidden_irreps
+
+        radial_embedding = RadialEmbeddingBlock(
+            r_max=r_max[i + 1],
+            num_bessel=num_bessel,
+            num_polynomial_cutoff=num_polynomial_cutoff,
+        )
+        edge_feats_irreps = o3.Irreps(f"{radial_embedding.out_dim}x0e")
+        self.radial_embeddings.append(radial_embedding)
+
         inter = interaction_cls(
             node_attrs_irreps=node_attr_irreps,
             node_feats_irreps=hidden_irreps,
@@ -135,7 +149,7 @@ def __init__(
         prod = EquivariantProductBasisBlock(
             node_feats_irreps=interaction_irreps,
             target_irreps=hidden_irreps_out,
-            correlation=correlation,
+            correlation=correlations[i + 1],
             num_elements=num_elements,
             use_sc=True,
         )
@@ -193,20 +207,23 @@ def forward(
             shifts=data["shifts"],
         )
         edge_attrs = self.spherical_harmonics(vectors)
-        edge_feats = self.radial_embedding(lengths)

         # Interactions
         energies = [e0]
         node_energies_list = [node_e0]
-        for interaction, product, readout in zip(
-            self.interactions, self.products, self.readouts
+        for i, (interaction, product, readout, radial_embedding) in enumerate(
+            zip(self.interactions, self.products, self.readouts, self.radial_embeddings)
         ):
+            edge_index_mask = data["edge_mask"][:, i]
+            edge_attrs_i = edge_attrs[edge_index_mask]
+            edge_feats_i = radial_embedding(lengths[edge_index_mask])
+
             node_feats, sc = interaction(
                 node_attrs=data["node_attrs"],
                 node_feats=node_feats,
-                edge_attrs=edge_attrs,
-                edge_feats=edge_feats,
-                edge_index=data["edge_index"],
+                edge_attrs=edge_attrs_i,
+                edge_feats=edge_feats_i,
+                edge_index=data["edge_index"][:, edge_index_mask],
             )
             node_feats = product(
                 node_feats=node_feats,
@@ -307,19 +324,22 @@ def forward(
             shifts=data["shifts"],
         )
         edge_attrs = self.spherical_harmonics(vectors)
-        edge_feats = self.radial_embedding(lengths)

         # Interactions
         node_es_list = []
-        for interaction, product, readout in zip(
-            self.interactions, self.products, self.readouts
+        for i, (interaction, product, readout, radial_embedding) in enumerate(
+            zip(self.interactions, self.products, self.readouts, self.radial_embeddings)
         ):
+            edge_index_mask = data["edge_mask"][:, i]
+            edge_attrs_i = edge_attrs[edge_index_mask]
+            edge_feats_i = radial_embedding(lengths[edge_index_mask])
+
             node_feats, sc = interaction(
                 node_attrs=data["node_attrs"],
                 node_feats=node_feats,
-                edge_attrs=edge_attrs,
-                edge_feats=edge_feats,
-                edge_index=data["edge_index"],
+                edge_attrs=edge_attrs_i,
+                edge_feats=edge_feats_i,
+                edge_index=data["edge_index"][:, edge_index_mask],
             )
             node_feats = product(
                 node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"]
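Note that the stored mask has shape [n_edges, n_layers] (it is transposed in AtomicData), so both forward passes index it as `data["edge_mask"][:, i]`. A schematic sketch of just that per-layer sub-graph selection, with made-up tensor sizes standing in for the real batch:

    import torch

    n_edges, n_layers, sh_dim = 10, 2, 16
    edge_index = torch.randint(0, 5, (2, n_edges))   # [2, n_edges]
    edge_attrs = torch.randn(n_edges, sh_dim)        # spherical harmonics per edge
    lengths = torch.rand(n_edges, 1)                 # edge lengths
    edge_mask = torch.rand(n_edges, n_layers) < 0.7  # stands in for data["edge_mask"]

    for i in range(n_layers):
        mask = edge_mask[:, i]                       # column i = layer i's active edges
        edge_attrs_i = edge_attrs[mask]              # [n_active, sh_dim]
        lengths_i = lengths[mask]                    # radial basis input for this layer
        edge_index_i = edge_index[:, mask]           # sub-graph seen by interaction i
        # interaction/product/readout would consume edge_attrs_i, edge_index_i, ...
        print(i, edge_index_i.shape)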
diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py
index 4f8e747e..3ae0c599 100644
--- a/mace/tools/arg_parser.py
+++ b/mace/tools/arg_parser.py
@@ -91,9 +91,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
         ],
     )
     parser.add_argument(
-        "--r_max", help="distance cutoff (in Ang)",
-        type=float,
-        default=5.0
+        "--r_max", help="distance cutoff (in Ang). Float or list of values, one per layer", type=str, default="5.0"
     )
     parser.add_argument(
         "--num_radial_basis",
@@ -131,7 +129,11 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
         "--max_ell", help=r"highest \ell of spherical harmonics", type=int, default=3
     )
     parser.add_argument(
-        "--correlation", help="correlation order at each layer", type=int, default=3
+        "--correlation",
+        help="correlation order at each layer. "
+        "Can be a single int or a list of values, one per layer",
+        type=str,
+        default="3",
    )
    parser.add_argument(
        "--num_interactions", help="number of interactions", type=int, default=2
    )
@@ -202,7 +204,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:

    # Dataset
    parser.add_argument(
-        "--train_file", help="Training set file, format is .xyz or .h5", type=str,
+        "--train_file", help="Training set file, format is .xyz or .h5", type=str, required=True,
    )
    parser.add_argument(
@@ -247,7 +249,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
    parser.add_argument(
        "--pin_memory",
        help="Pin memory for data loading",
-        default=True,
+        default=True, type=bool,
    )
    parser.add_argument(
@@ -553,8 +555,8 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser:
    )
    parser.add_argument(
        "--num_process",
-        help="The user defined number of processes to use, as well as the number of files created.",
-        type=int,
+        help="The user defined number of processes to use, as well as the number of files created.",
+        type=int,
        default=int(os.cpu_count()/4)
    )
    parser.add_argument(
@@ -578,8 +580,8 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser:
        default="",
    )
    parser.add_argument(
-        "--r_max", help="distance cutoff (in Ang)",
-        type=float,
+        "--r_max", help="distance cutoff (in Ang)",
+        type=float,
        default=5.0
    )
    parser.add_argument(
@@ -638,9 +640,9 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser:
        default=False,
    )
    parser.add_argument(
-        "--batch_size",
-        help="batch size to compute average number of neighbours",
-        type=int,
+        "--batch_size",
+        help="batch size to compute average number of neighbours",
+        type=int,
        default=16,
    )
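With both flags now parsed as strings, `ast.literal_eval` in run_train.py (below) accepts either a scalar or a Python-style list literal. A quick illustration of the accepted spellings (the values are examples only):

    import ast

    # e.g. --r_max "5.0"        -> 5.0 (float, replicated over layers later)
    # e.g. --r_max "[5.0, 4.0]" -> [5.0, 4.0] (one cutoff per layer)
    for raw in ("5.0", "[5.0, 4.0]", "3", "[3, 2]"):
        value = ast.literal_eval(raw)
        print(raw, "->", value, type(value).__name__)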
logging.info(f"Atomic energies: {atomic_energies.tolist()}") + # Support different settings for each layer + r_max = ast.literal_eval(args.r_max) + correlation = ast.literal_eval(args.correlation) + print(type(r_max), type(correlation)) + + if isinstance(r_max, (list, tuple, np.ndarray)): + r_max = torch.tensor(r_max, dtype=torch.get_default_dtype()) + else: + r_max = torch.tensor( + [r_max] * args.num_interactions, dtype=torch.get_default_dtype() + ) + if isinstance(correlation, (list, tuple, np.ndarray)): + correlation = torch.tensor(correlation, dtype=torch.int) + else: + correlation = torch.tensor( + [correlation] * args.num_interactions, dtype=torch.int + ) + assert r_max.shape == correlation.shape, f"Rmax and Correlation must have same shape: {r_max.shape} != {correlation.shape}" + assert r_max.shape[0] == args.num_interactions, f"Rmax and Correlation must have length num_interactions: {r_max.shape[0]} != {args.num_interactions}" if args.train_file.endswith(".xyz"): train_set = [ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) + data.AtomicData.from_config(config, z_table=z_table, cutoff=r_max) for config in collections.train ] valid_set = [ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) + data.AtomicData.from_config(config, z_table=z_table, cutoff=r_max) for config in collections.valid ] elif args.train_file.endswith(".h5"): train_set = HDF5Dataset( - args.train_file, r_max=args.r_max, z_table=z_table + args.train_file, r_max=r_max, z_table=z_table ) valid_set = HDF5Dataset( - args.valid_file, r_max=args.r_max, z_table=z_table + args.valid_file, r_max=r_max, z_table=z_table ) else: # This case would be for when the file path is to a directory of multiple .h5 files train_set = dataset_from_sharded_hdf5( - args.train_file, r_max=args.r_max, z_table=z_table + args.train_file, r_max=r_max, z_table=z_table ) valid_set = dataset_from_sharded_hdf5( - args.valid_file, r_max=args.r_max, z_table=z_table + args.valid_file, r_max=r_max, z_table=z_table ) - + train_sampler, valid_sampler = None, None if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler( - train_set, - num_replicas=world_size, + train_set, + num_replicas=world_size, rank=rank, shuffle=True, drop_last=True, seed=args.seed, ) valid_sampler = torch.utils.data.distributed.DistributedSampler( - valid_set, - num_replicas=world_size, + valid_set, + num_replicas=world_size, rank=rank, shuffle=True, drop_last=True, seed=args.seed, ) - + train_loader = torch_geometric.dataloader.DataLoader( dataset=train_set, - batch_size=args.batch_size, + batch_size=args.batch_size, sampler=train_sampler, shuffle=(train_sampler is None), drop_last=False, @@ -227,7 +246,7 @@ def main() -> None: pin_memory=args.pin_memory, num_workers=args.num_workers, ) - + loss_fn: torch.nn.Module = get_loss_fn( args.loss, args.energy_weight, @@ -285,7 +304,7 @@ def main() -> None: logging.info(f"Hidden irreps: {args.hidden_irreps}") model_config = dict( - r_max=args.r_max, + r_max=r_max, num_bessel=args.num_radial_basis, num_polynomial_cutoff=args.num_cutoff_basis, max_ell=args.max_ell, @@ -311,7 +330,7 @@ def main() -> None: if args.model == "MACE": model = modules.ScaleShiftMACE( **model_config, - correlation=args.correlation, + correlations=correlation, gate=modules.gate_dict[args.gate], interaction_cls_first=modules.interaction_classes[ "RealAgnosticInteractionBlock" @@ -323,7 +342,7 @@ def main() -> None: elif args.model == "ScaleShiftMACE": model = modules.ScaleShiftMACE( 
@@ -323,7 +342,7 @@ def main() -> None:
     elif args.model == "ScaleShiftMACE":
         model = modules.ScaleShiftMACE(
             **model_config,
-            correlation=args.correlation,
+            correlations=correlation,
             gate=modules.gate_dict[args.gate],
             interaction_cls_first=modules.interaction_classes[args.interaction_first],
             MLP_irreps=o3.Irreps(args.MLP_irreps),
@@ -580,35 +599,35 @@ def main() -> None:
     )

     logging.info("Computing metrics for training, validation, and test sets")
-    
+
     all_data_loaders = {
         "train": train_loader,
         "valid": valid_loader,
     }
-    
+
     test_sets = {}
     if args.train_file.endswith(".xyz"):
         for name, subset in collections.tests:
             test_sets[name] = [
-                data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max)
+                data.AtomicData.from_config(config, z_table=z_table, cutoff=r_max)
                 for config in subset
             ]
     elif not args.multi_processed_test:
         test_files = get_files_with_suffix(args.test_dir, "_test.h5")
         for test_file in test_files:
             name = os.path.splitext(os.path.basename(test_file))[0]
-            test_sets[name] = HDF5Dataset(test_file, r_max=args.r_max, z_table=z_table)
+            test_sets[name] = HDF5Dataset(test_file, r_max=r_max, z_table=z_table)
     else:
         test_folders = glob(args.test_dir + "/*")
         for folder in test_folders:
-            test_sets[name] = dataset_from_sharded_hdf5(folder, r_max=args.r_max, z_table=z_table)
-    
+            name = os.path.splitext(os.path.basename(folder))[0]
+            test_sets[name] = dataset_from_sharded_hdf5(folder, r_max=r_max, z_table=z_table)
+
     for test_name, test_set in test_sets.items():
         test_sampler = None
         if args.distributed:
             test_sampler = torch.utils.data.distributed.DistributedSampler(
-                test_set,
-                num_replicas=world_size,
+                test_set,
+                num_replicas=world_size,
                 rank=rank,
                 shuffle=True,
                 drop_last=True,
@@ -623,7 +642,7 @@ def main() -> None:
             pin_memory=args.pin_memory,
         )
         all_data_loaders[test_name] = test_loader
-    
+
     for swa_eval in swas:
         epoch = checkpoint_handler.load_latest(
             state=tools.CheckpointState(model, optimizer, lr_scheduler),
@@ -647,7 +666,7 @@ def main() -> None:
             distributed=args.distributed,
         )
         logging.info("\n" + str(table))
-    
+
     if rank == 0:
         # Save entire model
         if swa_eval:
             torch.save(model, Path(args.model_dir) / (args.name + "_swa.model"))
         else:
             torch.save(model, Path(args.model_dir) / (args.name + ".model"))
-    
+
     if args.distributed:
         torch.distributed.barrier()

-    logging.info("Done")
+    logging.info("Done")
     if args.distributed:
         torch.distributed.destroy_process_group()
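The scalar-or-list normalisation done inline in run_train.py could also be read as a small helper; a sketch under the same assumptions as the patch (the name `to_per_layer_tensor` is invented here and is not part of the diff):

    import ast
    import numpy as np
    import torch

    def to_per_layer_tensor(raw: str, num_layers: int, dtype: torch.dtype) -> torch.Tensor:
        """Parse a CLI string into a [num_layers] tensor, replicating scalars."""
        value = ast.literal_eval(raw)
        if isinstance(value, (list, tuple, np.ndarray)):
            tensor = torch.tensor(value, dtype=dtype)
        else:
            tensor = torch.tensor([value] * num_layers, dtype=dtype)
        assert tensor.shape[0] == num_layers, (
            f"expected {num_layers} per-layer values, got {tensor.shape[0]}"
        )
        return tensor

    r_max = to_per_layer_tensor("[5.0, 4.0]", num_layers=2, dtype=torch.get_default_dtype())
    correlation = to_per_layer_tensor("3", num_layers=2, dtype=torch.int)
    print(r_max, correlation)  # tensor([5., 4.]) tensor([3, 3], dtype=torch.int32)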