Skip to content

Commit

Permalink
Merge pull request #261 from apax-hub/moredocs-nico
Browse files Browse the repository at this point in the history
Moredocs nico
  • Loading branch information
M-R-Schaefer authored Apr 10, 2024
2 parents 095cfe4 + 8cf08b1 commit 80ea38a
Show file tree
Hide file tree
Showing 29 changed files with 881 additions and 880 deletions.
16 changes: 14 additions & 2 deletions apax/cli/templates/md_config_minimal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,23 @@ ensemble:
name: nvt
dt: 0.5 # fs time step
temperature: <T> # K
thermostat_chain:
chain_length: 3
chain_steps: 2
sy_steps: 3
tau: 100

duration: <DURATION> # fs
n_inner: 1 # compiled innner steps
sampling_rate: 1 # dump interval
n_inner: 100 # compiled innner steps
sampling_rate: 10 # dump interval
buffer_size: 100
dr_threshold: 0.5 # Neighborlist skin
extra_capacity: 0

sim_dir: md
initial_structure: <INITIAL_STRUCTURE>
load_momenta: false
traj_name: md.h5
restart: true
checkpoint_interval: 50_000
disable_pbar: false
27 changes: 22 additions & 5 deletions apax/cli/templates/train_config_full.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
n_epochs: <NUMBER OF EPOCHS>
seed: 1
patience: null
n_models: 1
n_jitted_steps: 1
data_parallel: True

data:
directory: models/
Expand All @@ -11,6 +15,8 @@ data:
#train_data_path: <PATH>
#val_data_path: <PATH>
#test_data_path: <PATH>
additional_properties_info: {}
ds_type: cached

n_train: 1000
n_valid: 100
Expand All @@ -20,6 +26,10 @@ data:

shift_method: "per_element_regression_shift"
shift_options: {"energy_regularisation": 1.0}

scale_method: "per_element_force_rms_scale"
scale_options: {}

shuffle_buffer_size: 1000

pos_unit: Ang
Expand All @@ -28,25 +38,30 @@ data:
model:
n_basis: 7
n_radial: 5
n_contr: -1
nn: [512, 512]

r_max: 6.0
r_min: 0.5

calc_stress: true
use_zbl: false

b_init: normal
descriptor_dtype: fp32
descriptor_dtype: fp64
readout_dtype: fp32
scale_shift_dtype: fp32
emb_init: uniform

loss:
- loss_type: structures
name: energy
- name: energy
loss_type: mse
weight: 1.0
- loss_type: structures
name: forces
atoms_exponent: 1
- name: forces
loss_type: mse
weight: 4.0
atoms_exponent: 1

metrics:
- name: energy
Expand All @@ -66,6 +81,7 @@ optimizer:
shift_lr: 0.05
zbl_lr: 0.001
transition_begin: 0
sam_rho: 0.0

callbacks:
- name: csv
Expand All @@ -78,3 +94,4 @@ checkpoints:

progress_bar:
disable_epoch_pbar: false
disable_batch_pbar: true
10 changes: 7 additions & 3 deletions apax/config/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
def parse_config(config: Union[str, os.PathLike, dict], mode: str = "train") -> Config:
"""Load the training configuration from file or a dictionary.
Attributes
Parameters
----------
config: Path to the config file or a dictionary
containing the config.
config : str | os.PathLike | dict
Path to the config file or a dictionary
containing the config.
mode: str, default = train
Defines if the config is validated for training ("train")
or MD simulation("md").
"""
if isinstance(config, (str, os.PathLike)):
with open(config, "r") as stream:
Expand Down
140 changes: 100 additions & 40 deletions apax/config/md_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,77 +4,136 @@
from typing import Literal, Union

import yaml
from pydantic import BaseModel, Field, PositiveFloat, PositiveInt
from pydantic import BaseModel, Field, NonNegativeInt, PositiveFloat, PositiveInt


class NHCOptions(BaseModel, extra="forbid"):
"""
Options for Nose-Hoover chain thermostat.
Parameters
----------
chain_length : PositiveInt, default = 3
Number of thermostats in the chain.
chain_steps : PositiveInt, default = 2
Number of steps per chain.
sy_steps : PositiveInt, default = 3
Number of steps for Suzuki-Yoshida integration.
tau : PositiveFloat, default = 100
Relaxation time parameter.
"""

chain_length: PositiveInt = 3
chain_steps: PositiveInt = 2
sy_steps: PositiveInt = 3
tau: PositiveFloat = 100


class Integrator(BaseModel, extra="forbid"):
"""
Molecular dynamics integrator options.
Parameters
----------
dt : PositiveFloat, default = 0.5
Time step size in femtoseconds (fs).
"""

dt: PositiveFloat = 0.5 # fs


class NVEOptions(Integrator, extra="forbid"):
"""
Options for NVE ensemble simulations.
Attributes
----------
name : Literal["nve"]
Name of the ensemble.
"""

name: Literal["nve"]


class NVTOptions(Integrator, extra="forbid"):
"""
Options for NVT ensemble simulations.
Parameters
----------
name : Literal["nvt"]
Name of the ensemble.
temperature : PositiveFloat, default = 298.15
Temperature in Kelvin (K).
thermostat_chain : NHCOptions, default = NHCOptions()
Thermostat chain options.
"""

name: Literal["nvt"]
temperature: PositiveFloat = 298.15 # K
thermostat_chain: NHCOptions = NHCOptions()


class NPTOptions(NVTOptions, extra="forbid"):
"""
Options for NPT ensemble simulations.
Parameters
----------
name : Literal["npt"]
Name of the ensemble.
pressure : PositiveFloat, default = 1.01325
Pressure in bar.
barostat_chain : NHCOptions, default = NHCOptions(tau=1000)
Barostat chain options.
"""

name: Literal["npt"]
pressure: PositiveFloat = 1.01325 # bar
barostat_chain: NHCOptions = NHCOptions(tau=1000)


class MDConfig(BaseModel, frozen=True, extra="forbid"):
"""Configuration for a NHC molecular dynamics simulation.
"""
Configuration for a NHC molecular dynamics simulation.
Full config :ref:`here <md_config>`:
Attributes
Parameters
----------
seed:
Random seed for momentum initialization.
temperature:
Temperature of the simulation in Kelvin.
dt:
Time step in fs.
duration:
Total simulation time in fs.
n_inner:
Number of compiled simulation steps (i.e. number of iterations of the
`jax.lax.fori_loop` loop). Also determines atoms buffer size.
sampling_rate:
Interval between saving frames.
buffer_size:
Number of collected frames to be dumped at once.
dr_threshold:
Skin of the neighborlist.
extra_capacity:
JaxMD allocates a maximal number of neighbors.
This argument lets you add additional capacity to avoid recompilation.
The default is usually fine.
initial_structure:
Path to the starting structure of the simulation.
sim_dir:
Directory where simulation file will be stored.
traj_name:
Name of the trajectory file.
restart:
Whether the simulation should restart from the latest configuration
in `traj_name`.
checkpoint_interval:
Number of time steps between saving
full simulation state checkpoints. These will be loaded
with the `restart` option.
disable_pbar:
Disables the MD progressbar.
seed : int, default = 1
| Random seed for momentum initialization.
temperature : float, default = 298.15
| Temperature of the simulation in Kelvin.
dt : float, default = 0.5
| Time step in fs.
duration : float, required
| Total simulation time in fs.
n_inner : int, default = 100
| Number of compiled simulation steps (i.e. number of iterations of the
| `jax.lax.fori_loop` loop). Also determines atoms buffer size.
sampling_rate : int, default = 10
| Interval between saving frames.
buffer_size : int, default = 100
| Number of collected frames to be dumped at once.
dr_threshold : float, default = 0.5
| Skin of the neighborlist.
extra_capacity : int, default = 0
| JaxMD allocates a maximal number of neighbors. This argument lets you add
| additional capacity to avoid recompilation. The default is usually fine.
initial_structure : str, required
| Path to the starting structure of the simulation.
sim_dir : str, default = "."
| Directory where simulation file will be stored.
traj_name : str, default = "md.h5"
| Name of the trajectory file.
restart : bool, default = True
| Whether the simulation should restart from the latest configuration in
| `traj_name`.
checkpoint_interval : int, default = 50_000
| Number of time steps between saving full simulation state checkpoints.
| These will be loaded with the `restart` option.
disable_pbar : bool, False
| Disables the MD progressbar.
"""

seed: int = 1
Expand All @@ -89,7 +148,7 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
sampling_rate: PositiveInt = 10
buffer_size: PositiveInt = 100
dr_threshold: PositiveFloat = 0.5
extra_capacity: PositiveInt = 0
extra_capacity: NonNegativeInt = 0

initial_structure: str
load_momenta: bool = False
Expand All @@ -102,6 +161,7 @@ class MDConfig(BaseModel, frozen=True, extra="forbid"):
def dump_config(self):
"""
Writes the current config file to the MD directory.
"""
with open(os.path.join(self.sim_dir, "md_config.yaml"), "w") as conf:
yaml.dump(self.model_dump(), conf, default_flow_style=False)
Loading

0 comments on commit 80ea38a

Please sign in to comment.