From 8dc92c79a1f06e88a3cd20df09192a4beef023be Mon Sep 17 00:00:00 2001 From: Roy-Haolin-Du Date: Mon, 6 Jan 2025 12:09:01 +0000 Subject: [PATCH 1/8] Replace template_config.cfg with Pydantic config --- README.md | 1 - a3fe/__init__.py | 2 +- a3fe/configuration/__init__.py | 1 + a3fe/configuration/engine_config.py | 191 ++++++++++++++++++++++++ a3fe/run/_simulation_runner.py | 11 +- a3fe/run/calc_set.py | 10 ++ a3fe/run/calculation.py | 7 +- a3fe/run/lambda_window.py | 6 + a3fe/run/leg.py | 45 +++--- a3fe/run/simulation.py | 6 +- a3fe/run/stage.py | 6 + a3fe/tests/conftest.py | 10 +- a3fe/tests/test_engine_configuration.py | 81 ++++++++++ docs/getting_started.rst | 2 +- docs/guides.rst | 8 +- 15 files changed, 342 insertions(+), 45 deletions(-) create mode 100644 a3fe/configuration/engine_config.py create mode 100644 a3fe/tests/test_engine_configuration.py diff --git a/README.md b/README.md index 6d7b923..cd3327b 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,6 @@ python -m pip install --no-deps . - Activate your a3fe conda environment - Create a base directory for the calculation and create an directory called `input` within this - Move your input files into the the input directory. For example, if you have parameterised AMBER-format input files, name these bound_param.rst7, bound_param.prm7, free_param.rst7, and free_param.prm7. For more details see the documentation. Alternatively, copy the example input files from a3fe/a3fe/data/example_run_dir to your input directory. -- Copy run template_config.cfg from a3fe/a3fe/data/example_run_dir to your `input` directory. 
- In the calculation base directory, run the following python code, either through ipython or as a python script (you will likely want to run the script with `nohup`or use ipython through tmux to ensure that the calculation is not killed when you lose connection) ```python diff --git a/a3fe/__init__.py b/a3fe/__init__.py index 570eea0..744ab06 100644 --- a/a3fe/__init__.py +++ b/a3fe/__init__.py @@ -27,7 +27,7 @@ enums, ) -from .configuration import SystemPreparationConfig, SlurmConfig +from .configuration import SystemPreparationConfig, SlurmConfig, SomdConfig _sys.modules["EnsEquil"] = _sys.modules["a3fe"] diff --git a/a3fe/configuration/__init__.py b/a3fe/configuration/__init__.py index 9ea7e32..6f4e7ac 100644 --- a/a3fe/configuration/__init__.py +++ b/a3fe/configuration/__init__.py @@ -2,3 +2,4 @@ from .system_prep_config import SystemPreparationConfig from .slurm_config import SlurmConfig +from .engine_config import SomdConfig diff --git a/a3fe/configuration/engine_config.py b/a3fe/configuration/engine_config.py new file mode 100644 index 0000000..f883526 --- /dev/null +++ b/a3fe/configuration/engine_config.py @@ -0,0 +1,191 @@ +"""Configuration classes for SOMD engine configuration.""" + +__all__ = [ + "SomdConfig", +] + +import yaml as _yaml +import os as _os +from typing import Dict as _Dict + +from pydantic import BaseModel as _BaseModel +from pydantic import Field as _Field +from pydantic import ConfigDict as _ConfigDict + + +class SomdConfig(_BaseModel): + """ + Pydantic model for holding SOMD engine configuration. 
+ """ + + ### Integrator - ncycles modified as required by a3fe ### + nmoves: int = _Field(25000, description="Number of moves per cycle") + ncycles: int = _Field(60, description="Number of cycles") + timestep: float = _Field(4.0, description="Timestep in femtoseconds") + constraint: str = _Field("hbonds", description="Constraint type") + hydrogen_mass_factor: float = _Field( + 3.0, + alias="hydrogen mass repartitioning factor", + description="Hydrogen mass repartitioning factor" + ) + integrator: str = _Field("langevinmiddle", description="Integration algorithm") + inverse_friction: float = _Field( + 1.0, + description="Inverse friction in picoseconds", + alias="inverse friction" + ) + temperature: float = _Field(25.0, description="Temperature in Celsius") + # Thermostatting already handled by langevin integrator + thermostat: bool = _Field(False, description="Enable thermostat") + + ### Barostat ### + barostat: bool = _Field(True, description="Enable barostat") + pressure: float = _Field(1.0, description="Pressure in atm") + + ### Non-Bonded Interactions ### + cutoff_type: str = _Field( + "PME", + alias="cutoff type", + description="Type of cutoff to use" + ) + cutoff_distance: float = _Field( + 10.0, + alias="cutoff distance", + description="Cutoff distance in angstroms" + ) + + ### Trajectory ### + buffered_coords_freq: int = _Field( + 5000, + alias="buffered coordinates frequency", + description="Frequency of buffered coordinates output" + ) + center_solute: bool = _Field( + True, + alias="center solute", + description="Center solute in box" + ) + + ### Minimisation ### + minimise: bool = _Field(True, description="Perform energy minimisation") + + ### Alchemistry - restraints added by a3fe ### + perturbed_residue_number: int = _Field( + 1, + alias="perturbed residue number", + description="Residue number to perturb" + ) + energy_frequency: int = _Field( + 200, + alias="energy frequency", + description="Frequency of energy output" + ) + extra_options: 
_Dict[str, str] = _Field( + default_factory=_Dict, + description="Extra options to pass to the SOMD engine" + ) + model_config = _ConfigDict(validate_assignment=True) + + def get_somd_config(self, run_dir: str, config_name: str = "somd_config") -> str: + """ + Generates the SOMD configuration file and returns its path. + + Parameters + ---------- + run_dir : str + Directory to write the configuration file to. + + config_name : str, optional, default="somd_config" + Name of the configuration file to write. Note that when running many jobs from the + same directory, this should be unique to avoid overwriting the config file. + + Returns + ------- + str + Path to the generated configuration file. + """ + # First, generate the configuration string + config_lines = [ + "### Integrator ###", + f"nmoves = {self.nmoves}", + f"ncycles = {self.ncycles}", + f"timestep = {self.timestep} * femtosecond", + f"constraint = {self.constraint}", + f"hydrogen mass repartitioning factor = {self.hydrogen_mass_factor}", + f"integrator = {self.integrator}", + f"inverse friction = {self.inverse_friction} * picosecond", + f"temperature = {self.temperature} * celsius", + f"thermostat = {str(self.thermostat)}", + "", + "### Barostat ###", + f"barostat = {str(self.barostat)}", + f"pressure = {self.pressure} * atm", + "", + "### Non-Bonded Interactions ###", + f"cutoff type = {self.cutoff_type}", + f"cutoff distance = {self.cutoff_distance} * angstrom", + "", + "### Trajectory ###", + f"buffered coordinates frequency = {self.buffered_coords_freq}", + f"center solute = {str(self.center_solute)}", + "", + "### Minimisation ###", + f"minimise = {str(self.minimise)}", + "", + "### Alchemistry ###", + f"perturbed residue number = {self.perturbed_residue_number}", + f"energy frequency = {self.energy_frequency}", + ] + # Add any extra options + if self.extra_options: + config_lines.extend(["", "### Extra Options ###"]) + for key, value in self.extra_options.items(): + config_lines.append(f"{key} = 
{value}") + + # Write the configuration to a file + config_path = _os.path.join(run_dir, f"{config_name}.cfg") + with open(config_path, "w") as f: + f.write("\n".join(config_lines)) + + return config_path + + def dump(self, save_dir: str) -> None: + """ + Dumps the configuration to a YAML file. + + Parameters + ---------- + save_dir : str + Directory to save the YAML file to. + """ + model_dict = self.model_dump() + + save_path = save_dir + "/" + self.get_file_name() + with open(save_path, "w") as f: + _yaml.dump(model_dict, f, default_flow_style=False) + + @classmethod + def load(cls, load_dir: str) -> "SomdConfig": + """ + Loads the configuration from a YAML file. + + Parameters + ---------- + load_dir : str + Directory to load the YAML file from. + + Returns + ------- + SomdConfig + The loaded configuration. + """ + with open(load_dir + "/" + cls.get_file_name(), "r") as f: + model_dict = _yaml.safe_load(f) + return cls(**model_dict) + + @staticmethod + def get_file_name() -> str: + """ + Get the name of the SOMD configuration file. + """ + return "somd_config.yaml" diff --git a/a3fe/run/_simulation_runner.py b/a3fe/run/_simulation_runner.py index a89d8b2..27c457b 100644 --- a/a3fe/run/_simulation_runner.py +++ b/a3fe/run/_simulation_runner.py @@ -28,7 +28,7 @@ from ._logging_formatters import _A3feFileFormatter, _A3feStreamFormatter from ..configuration import SlurmConfig as _SlurmConfig - +from ..configuration import SomdConfig as _SomdConfig class SimulationRunner(ABC): """An abstract base class for simulation runners. Note that @@ -55,6 +55,7 @@ def __init__( output_dir: _Optional[str] = None, slurm_config: _Optional[_SlurmConfig] = None, analysis_slurm_config: _Optional[_SlurmConfig] = None, + engine_config: _Optional[_SomdConfig] = None, stream_log_level: int = _logging.INFO, dg_multiplier: int = 1, ensemble_size: int = 5, @@ -81,6 +82,8 @@ def __init__( This is helpful e.g. 
if you want to submit analysis to the CPU partition, but the main simulation to the GPU partition. If None, the standard slurm_config is used. + engine_config: SomdConfig, default: None + Configuration for the SOMD engine. If None, the default configuration is used. stream_log_level : int, Optional, default: logging.INFO Logging level to use for the steam file handlers for the calculation object and its child objects. @@ -171,6 +174,12 @@ def __init__( else self.slurm_config ) + # Create the SOMD config with default values if none is provided + if engine_config is None: + self.engine_config = _SomdConfig() + else: + self.engine_config = engine_config + # Save state if dump: self._dump() diff --git a/a3fe/run/calc_set.py b/a3fe/run/calc_set.py index 815f370..9809d25 100644 --- a/a3fe/run/calc_set.py +++ b/a3fe/run/calc_set.py @@ -13,6 +13,7 @@ from scipy import stats as _stats from ..configuration import SlurmConfig as _SlurmConfig +from ..configuration import SomdConfig as _SomdConfig from ..analyse.analyse_set import compute_stats as _compute_stats from ..analyse.plot import plot_against_exp as _plt_against_exp @@ -40,6 +41,7 @@ def __init__( stream_log_level: int = _logging.INFO, slurm_config: _Optional[_SlurmConfig] = None, analysis_slurm_config: _Optional[_SlurmConfig] = None, + engine_config: _Optional[_SomdConfig] = None, update_paths: bool = True, ) -> None: """ @@ -73,6 +75,8 @@ def __init__( Configuration for the SLURM job scheduler for the analysis. This is helpful e.g. if you want to submit analysis to the CPU partition, but the main simulation to the GPU partition. If None, + engine_config: SomdConfig, default: None + Configuration for the SOMD engine. If None, the default configuration is used. update_paths: bool, Optional, default: True If True, if the simulation runner is loaded by unpickling, then update_paths() is called. 
@@ -89,6 +93,7 @@ def __init__( update_paths=update_paths, slurm_config=slurm_config, analysis_slurm_config=analysis_slurm_config, + engine_config=engine_config, ) if not self.loaded_from_pickle: @@ -108,6 +113,11 @@ def __init__( calc_args["analysis_slurm_config"] = self.analysis_slurm_config self._calc_args = calc_args + # Ensure that all calculations share the same somd config by adding this if it is not present + if calc_args.get("engine_config") is None: + calc_args["engine_config"] = self.engine_config + self._calc_args = calc_args + # Check that we can load all of the calculations for calc in self.calcs: # Having set them up according to _calc_args, save them diff --git a/a3fe/run/calculation.py b/a3fe/run/calculation.py index 7d0ec1a..9abc4e9 100644 --- a/a3fe/run/calculation.py +++ b/a3fe/run/calculation.py @@ -15,6 +15,7 @@ from ..configuration import ( SystemPreparationConfig as _SystemPreparationConfig, SlurmConfig as _SlurmConfig, + SomdConfig as _SomdConfig, ) @@ -27,7 +28,6 @@ class Calculation(_SimulationRunner): required_input_files = [ "protein.pdb", "ligand.sdf", - "template_config.cfg", ] # Waters.pdb is optional required_legs = [_LegType.FREE, _LegType.BOUND] @@ -43,6 +43,7 @@ def __init__( stream_log_level: int = _logging.INFO, slurm_config: _Optional[_SlurmConfig] = None, analysis_slurm_config: _Optional[_SlurmConfig] = None, + engine_config: _Optional[_SomdConfig] = None, update_paths: bool = True, ) -> None: """ @@ -88,6 +89,8 @@ def __init__( Configuration for the SLURM job scheduler for the analysis. This is helpful e.g. if you want to submit analysis to the CPU partition, but the main simulation to the GPU partition. If None, + engine_config: SomdConfig, default: None + Configuration for the SOMD engine. If None, the default configuration is used. update_paths: bool, Optional, default: True If True, if the simulation runner is loaded by unpickling, then update_paths() is called. 
@@ -105,6 +108,7 @@ def __init__( update_paths=update_paths, slurm_config=slurm_config, analysis_slurm_config=analysis_slurm_config, + engine_config=engine_config, dump=False, ) @@ -210,6 +214,7 @@ def setup( stream_log_level=self.stream_log_level, slurm_config=self.slurm_config, analysis_slurm_config=self.analysis_slurm_config, + engine_config=self.engine_config, ) self.legs.append(leg) leg.setup(configs[leg_type]) diff --git a/a3fe/run/lambda_window.py b/a3fe/run/lambda_window.py index 59471e1..b8fae28 100644 --- a/a3fe/run/lambda_window.py +++ b/a3fe/run/lambda_window.py @@ -21,6 +21,7 @@ from ._virtual_queue import VirtualQueue as _VirtualQueue from .simulation import Simulation as _Simulation from ..configuration import SlurmConfig as _SlurmConfig +from ..configuration import SomdConfig as _SomdConfig class LamWindow(_SimulationRunner): @@ -53,6 +54,7 @@ def __init__( stream_log_level: int = _logging.INFO, slurm_config: _Optional[_SlurmConfig] = None, analysis_slurm_config: _Optional[_SlurmConfig] = None, + engine_config: _Optional[_SomdConfig] = None, update_paths: bool = True, ) -> None: """ @@ -113,6 +115,8 @@ def __init__( Configuration for the SLURM job scheduler for the analysis. This is helpful e.g. if you want to submit analysis to the CPU partition, but the main simulation to the GPU partition. If None, + engine_config: SomdConfig, default: None + Configuration for the SOMD engine. If None, the default configuration is used. update_paths: bool, Optional, default: True If true, if the simulation runner is loaded by unpickling, then update_paths() is called. 
@@ -134,6 +138,7 @@ def __init__( update_paths=update_paths, slurm_config=slurm_config, analysis_slurm_config=analysis_slurm_config, + engine_config=engine_config, dump=False, ) @@ -190,6 +195,7 @@ def __init__( stream_log_level=stream_log_level, slurm_config=self.slurm_config, analysis_slurm_config=self.analysis_slurm_config, + engine_config=self.engine_config, ) ) diff --git a/a3fe/run/leg.py b/a3fe/run/leg.py index b9f9ff7..f935202 100644 --- a/a3fe/run/leg.py +++ b/a3fe/run/leg.py @@ -37,6 +37,7 @@ from ..configuration import ( SystemPreparationConfig as _SystemPreparationConfig, SlurmConfig as _SlurmConfig, + SomdConfig as _SomdConfig, ) @@ -51,7 +52,6 @@ class Leg(_SimulationRunner): required_input_files[leg_type] = {} for prep_stage in _PreparationStage: required_input_files[leg_type][prep_stage] = [ - "template_config.cfg", ] + prep_stage.get_simulation_input_files(leg_type) required_stages = { @@ -71,6 +71,7 @@ def __init__( stream_log_level: int = _logging.INFO, slurm_config: _Optional[_SlurmConfig] = None, analysis_slurm_config: _Optional[_SlurmConfig] = None, + engine_config: _Optional[_SomdConfig] = None, update_paths: bool = True, ) -> None: """ @@ -113,6 +114,8 @@ def __init__( Configuration for the SLURM job scheduler for the analysis. This is helpful e.g. if you want to submit analysis to the CPU partition, but the main simulation to the GPU partition. If None, + engine_config: SomdConfig, default: None + Configuration for the SOMD engine. If None, the default configuration is used. update_paths: bool, optional, default: True if true, if the simulation runner is loaded by unpickling, then update_paths() is called. 
@@ -130,6 +133,7 @@ def __init__( stream_log_level=stream_log_level, slurm_config=slurm_config, analysis_slurm_config=analysis_slurm_config, + engine_config=engine_config, ensemble_size=ensemble_size, update_paths=update_paths, dump=False, @@ -279,6 +283,7 @@ def setup( stream_log_level=self.stream_log_level, slurm_config=self.slurm_config, analysis_slurm_config=self.analysis_slurm_config, + engine_config=self.engine_config, ) ) @@ -736,16 +741,11 @@ def write_input_files( # If we have a charged ligand, make sure that SOMD is using PME if lig_charge != 0: - try: - cuttoff_type = _read_simfile_option( - f"{self.input_dir}/template_config.cfg", "cutoff type" - ) - except ValueError: # Will get this if the option is not present (but the default is not PME) - cuttoff_type = None - if cuttoff_type != "PME": + cutoff_type = self.engine_config.cutoff_type + if cutoff_type != "PME": raise ValueError( f"The ligand has a non-zero charge ({lig_charge}), so SOMD must use PME for the electrostatics. " - "Please set the 'cutoff type' option in the somd.cfg file to 'PME'." + "Please set the 'cutoff type' option in the engine_config to 'PME'." 
) self._logger.info( @@ -808,19 +808,24 @@ def write_input_files( restraint_file, f"{stage_input_dir}/restraint_{i + 1}.txt" ) - # Update the template-config.cfg file with the perturbed residue number generated + # Update the somd.cfg file with the perturbed residue number generated # by BSS, as well as the restraints options - _shutil.copy(f"{self.input_dir}/template_config.cfg", stage_input_dir) + + # generate the somd.cfg file + config_path = self.engine_config.get_somd_config( + run_dir=stage_input_dir, + config_name="somd" + ) try: use_boresch_restraints = _read_simfile_option( - f"{stage_input_dir}/somd.cfg", "use boresch restraints" + config_path, "use boresch restraints" ) except ValueError: use_boresch_restraints = False try: turn_on_receptor_ligand_restraints_mode = _read_simfile_option( - f"{stage_input_dir}/somd.cfg", + config_path, "turn on receptor-ligand restraints mode", ) except ValueError: @@ -837,24 +842,14 @@ def write_input_files( for option, value in options_to_write.items(): _write_simfile_option( - f"{stage_input_dir}/template_config.cfg", option, value + config_path, option, value ) - # Now overwrite the SOMD generated config file with the updated template - _subprocess.run( - [ - "mv", - f"{stage_input_dir}/template_config.cfg", - f"{stage_input_dir}/somd.cfg", - ], - check=True, - ) - # Set the default lambda windows based on the leg and stage types lam_vals = config.lambda_values[self.leg_type][stage_type] lam_vals_str = ", ".join([str(lam_val) for lam_val in lam_vals]) _write_simfile_option( - f"{stage_input_dir}/somd.cfg", "lambda array", lam_vals_str + config_path, "lambda array", lam_vals_str ) # We no longer need to store the large BSS restraint classes. 
diff --git a/a3fe/run/simulation.py b/a3fe/run/simulation.py index faabe54..b17a6d8 100644 --- a/a3fe/run/simulation.py +++ b/a3fe/run/simulation.py @@ -22,7 +22,7 @@ from ._virtual_queue import VirtualQueue as _VirtualQueue from .enums import JobStatus as _JobStatus from ..configuration import SlurmConfig as _SlurmConfig - +from ..configuration import SomdConfig as _SomdConfig class Simulation(_SimulationRunner): """Class to store information about a single SOMD simulation.""" @@ -60,6 +60,7 @@ def __init__( stream_log_level: int = _logging.INFO, slurm_config: _Optional[_SlurmConfig] = None, analysis_slurm_config: _Optional[_SlurmConfig] = None, + engine_config: _Optional[_SomdConfig] = None, update_paths: bool = True, ) -> None: """ @@ -92,6 +93,8 @@ def __init__( Configuration for the SLURM job scheduler for the analysis. This is helpful e.g. if you want to submit analysis to the CPU partition, but the main simulation to the GPU partition. If None, + engine_config: SomdConfig, default: None + Configuration for the SOMD engine. If None, the default configuration is used. update_paths: bool, Optional, default: True If True, if the simulation runner is loaded by unpickling, then update_paths() is called. 
@@ -112,6 +115,7 @@ def __init__( stream_log_level=stream_log_level, slurm_config=slurm_config, analysis_slurm_config=analysis_slurm_config, + engine_config=engine_config, update_paths=update_paths, dump=False, ) diff --git a/a3fe/run/stage.py b/a3fe/run/stage.py index 95f8074..91e36e3 100644 --- a/a3fe/run/stage.py +++ b/a3fe/run/stage.py @@ -54,6 +54,7 @@ from .enums import StageType as _StageType from .lambda_window import LamWindow as _LamWindow from ..configuration.slurm_config import SlurmConfig as _SlurmConfig +from ..configuration.engine_config import SomdConfig as _SomdConfig class Stage(_SimulationRunner): @@ -89,6 +90,7 @@ def __init__( stream_log_level: int = _logging.INFO, slurm_config: _Optional[_SlurmConfig] = None, analysis_slurm_config: _Optional[_SlurmConfig] = None, + engine_config: _Optional[_SomdConfig] = None, update_paths: bool = True, ) -> None: """ @@ -137,6 +139,8 @@ def __init__( Configuration for the SLURM job scheduler for the analysis. This is helpful e.g. if you want to submit analysis to the CPU partition, but the main simulation to the GPU partition. If None, + engine_config: SomdConfig, default: None + Configuration for the SOMD engine. If None, the default configuration is used. update_paths: bool, Optional, default: True If True, if the simulation runner is loaded by unpickling, then update_paths() is called. 
@@ -156,6 +160,7 @@ def __init__( stream_log_level=stream_log_level, slurm_config=slurm_config, analysis_slurm_config=analysis_slurm_config, + engine_config=engine_config, ensemble_size=ensemble_size, update_paths=update_paths, dump=False, @@ -196,6 +201,7 @@ def __init__( stream_log_level=self.stream_log_level, slurm_config=self.slurm_config, analysis_slurm_config=self.analysis_slurm_config, + engine_config=self.engine_config, ) ) diff --git a/a3fe/tests/conftest.py b/a3fe/tests/conftest.py index 1d82626..e172fba 100644 --- a/a3fe/tests/conftest.py +++ b/a3fe/tests/conftest.py @@ -96,15 +96,7 @@ def t4l_calc(): ) # Copy over remaining input files - for file in ["template_config.cfg"]: - subprocess.run( - [ - "cp", - os.path.join("a3fe/data/example_run_dir/input/", file), - os.path.join(dirname, "input"), - ], - check=True, - ) + # No files need to be copied calc = a3.Calculation( base_dir=dirname, diff --git a/a3fe/tests/test_engine_configuration.py b/a3fe/tests/test_engine_configuration.py new file mode 100644 index 0000000..4e92758 --- /dev/null +++ b/a3fe/tests/test_engine_configuration.py @@ -0,0 +1,81 @@ +"""Unit and regression tests for the SomdConfig class.""" + +from tempfile import TemporaryDirectory + +from a3fe import SomdConfig + +import os + + +def test_create_config(): + """Test that the config can be created.""" + config = SomdConfig() + assert isinstance(config, SomdConfig) + + +def test_config_pickle_and_load(): + """Test that the config can be pickled and loaded.""" + with TemporaryDirectory() as dirname: + config = SomdConfig() + config.dump(dirname) + config2 = SomdConfig.load(dirname) + assert config == config2 + + +def test_get_somd_config(): + """ + Test that the SOMD configuration file is generated correctly + and that the file is written correctly. 
+ """ + # Tmpdir to store the config + with TemporaryDirectory() as dirname: + config = SomdConfig( + integrator="langevinmiddle", + nmoves=25000, + ncycles=60, + timestep=4.0, + cutoff_type="PME", + cutoff_distance=10.0, + ) + config_path = config.get_somd_config( + run_dir=dirname, + config_name="test" + ) + assert config_path == os.path.join(dirname, "test.cfg") + + expected_config = ( + "### Integrator ###\n" + "nmoves = 25000\n" + "ncycles = 60\n" + "timestep = 4.0 * femtosecond\n" + "constraint = hbonds\n" + "hydrogen mass repartitioning factor = 3.0\n" + "integrator = langevinmiddle\n" + "inverse friction = 1.0 * picosecond\n" + "temperature = 25.0 * celsius\n" + "thermostat = False\n" + "\n" + "### Barostat ###\n" + "barostat = True\n" + "pressure = 1.0 * atm\n" + "\n" + "### Non-Bonded Interactions ###\n" + "cutoff type = PME\n" + "cutoff distance = 10.0 * angstrom\n" + "\n" + "### Trajectory ###\n" + "buffered coordinates frequency = 500\n" + "center solute = True\n" + "\n" + "### Minimisation ###\n" + "minimise = True\n" + "\n" + "### Alchemistry ###\n" + "perturbed residue number = 1\n" + "energy frequency = 500\n" + ) + + with open(config_path, "r") as f: + config_content = f.read() + + assert config_content == expected_config diff --git a/docs/getting_started.rst b/docs/getting_started.rst index f4f74e3..9676cd1 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -13,7 +13,6 @@ Quick Start - Activate your a3fe conda environment - Create a base directory for the calculation and create an directory called ``input`` within this - Move your input files into the the input directory. For example, if you have parameterised AMBER-format input files, name these bound_param.rst7, bound_param.prm7, free_param.rst7, and free_param.prm7. **Ensure that the ligand is named LIG and is the first molecule in the system.** For more details see :ref:`Preparing Input for a3fe`. 
Alternatively, copy the pre-provided input from ``a3fe/a3fe/data/example_run_dir/input`` to your input directory. -- Copy template_config.cfg from ``a3fe/a3fe/data/example_run_dir`` to your ``input`` directory. - In the calculation base directory, run the following python code, either through ipython or as a python script (you will likely want to run this with ``nohup``/ through tmux to ensure that the calculation is not killed when you lose connection). Running though ipython will let you interact with the calculation while it's running. .. code-block:: python @@ -22,6 +21,7 @@ Quick Start calc = a3.Calculation( ensemble_size=5, # Use 5 (independently equilibrated) replicate runs slurm_config=a3.SlurmConfig(partition=""), # Set your desired partition! + engine_config=a3.SomdConfig(), ) calc.setup() calc.get_optimal_lam_vals() diff --git a/docs/guides.rst b/docs/guides.rst index ebbcbe8..0a4cafc 100644 --- a/docs/guides.rst +++ b/docs/guides.rst @@ -67,8 +67,6 @@ You can also find out which input files are required for a given preparation sta - free_preequil.prm7, free_preequil.rst7 - The solvated ligand after heating and short initial equilibration steps -In addition, for every preparation stage, **template_config.cfg must be present in the input -directory.** Please note that if you are suppling parameterised input files, **the ligand must be the first molecule in the system and the ligand must be named "LIG"**. The former can be achieved by reordering the system with BioSimSpace, and the latter @@ -140,7 +138,7 @@ Individual Simulation settings ------------------------------- To customise the specifics of how each lambda window is run (e.g. timestep), you can use the ``set_simfile_option`` method. For example, to set the timestep to 2 fs, run -``calc.set_simfile_option("timestep", "2 * femtosecond")``. 
This will change parameters from the defaults given in ``template_config.cfg`` in the ``input`` directory, and warn +``calc.set_simfile_option("timestep", "2 * femtosecond")``. This will change parameters from the defaults generated by ``engine_config``, and warn you if you are trying to set a parameter that is not present in the template config file. To see a list of available options, run ``somd-freenrg --help-config``. SLURM Options @@ -304,7 +302,7 @@ You can run sets of calculations using the :class:`a3fe.run.CalcSet` class. To d ABFE with Charged Ligands ************************* -Since A3FE 0.2.0, ABFE calculations with charged ligands are supported using a co-alchemical ion approach. The charge of the ligand will be automatically detected, assuming that this is correctly specified in the input sdf. The only change in the input required is that the use of PME, rather than reaction field electrostatics, should be specified in ``template_config.cfg`` as e.g.: +Since A3FE 0.2.0, ABFE calculations with charged ligands are supported using a co-alchemical ion approach. The charge of the ligand will be automatically detected, assuming that this is correctly specified in the input sdf. The only change in the input required is that the use of PME, rather than reaction field electrostatics, should be specified in ``somd_config.cfg`` as e.g.: .. code-block:: bash @@ -312,4 +310,4 @@ Since A3FE 0.2.0, ABFE calculations with charged ligands are supported using a c cutoff type = PME cutoff distance = 10 * angstrom -The default `template_config.cfg` uses reaction field instead of PME. This is faster (around twice as fast for some of our systems) and has been shown to give equivalent results for neutral ligands in RBFE calculations - see https://pubs.acs.org/doi/full/10.1021/acs.jcim.0c01424 . \ No newline at end of file +The default `somd_config.cfg` uses reaction field instead of PME. 
This is faster (around twice as fast for some of our systems) and has been shown to give equivalent results for neutral ligands in RBFE calculations - see https://pubs.acs.org/doi/full/10.1021/acs.jcim.0c01424 . From b654a04640d542bc90aca6d8b058557db6712393 Mon Sep 17 00:00:00 2001 From: Roy-Haolin-Du Date: Mon, 6 Jan 2025 15:17:14 +0000 Subject: [PATCH 2/8] test and format for removing run_somd.sh and template_config.cfg --- a3fe/configuration/engine_config.py | 5 ++- a3fe/configuration/slurm_config.py | 4 +-- .../mdm2_pip2_short/input/template_config.cfg | 36 ------------------- .../t4l/input/template_config.cfg | 36 ------------------- .../example_run_dir/input/template_config.cfg | 35 ------------------ a3fe/tests/test_engine_configuration.py | 4 +-- 6 files changed, 8 insertions(+), 112 deletions(-) delete mode 100644 a3fe/data/example_calc_set/mdm2_pip2_short/input/template_config.cfg delete mode 100644 a3fe/data/example_calc_set/t4l/input/template_config.cfg delete mode 100644 a3fe/data/example_run_dir/input/template_config.cfg diff --git a/a3fe/configuration/engine_config.py b/a3fe/configuration/engine_config.py index f883526..1ad5b7b 100644 --- a/a3fe/configuration/engine_config.py +++ b/a3fe/configuration/engine_config.py @@ -81,7 +81,7 @@ class SomdConfig(_BaseModel): description="Frequency of energy output" ) extra_options: _Dict[str, str] = _Field( - default_factory=_Dict, + default_factory=dict, description="Extra options to pass to the SOMD engine" ) model_config = _ConfigDict(validate_assignment=True) @@ -92,6 +92,8 @@ def get_somd_config(self, run_dir: str, config_name: str = "somd_config") -> str Parameters ---------- + #content : str + Content to write to the configuration file. run_dir : str Directory to write the configuration file to. 
@@ -135,6 +137,7 @@ def get_somd_config(self, run_dir: str, config_name: str = "somd_config") -> str "### Alchemistry ###", f"perturbed residue number = {self.perturbed_residue_number}", f"energy frequency = {self.energy_frequency}", + "", ] # Add any extra options if self.extra_options: diff --git a/a3fe/configuration/slurm_config.py b/a3fe/configuration/slurm_config.py index c2dfdd1..6ef9cbb 100644 --- a/a3fe/configuration/slurm_config.py +++ b/a3fe/configuration/slurm_config.py @@ -22,7 +22,7 @@ class SlurmConfig(_BaseModel): Pydantic model for holding a SLURM configuration. """ - partition: str = _Field("default", description="SLURM partition to submit to.") + partition: str = _Field("main", description="SLURM partition to submit to.") time: str = _Field("24:00:00", description="Time limit for the SLURM job.") gres: str = _Field("gpu:1", description="Resources to request - normally one GPU.") nodes: int = _Field(1, ge=1) @@ -80,7 +80,7 @@ def get_submission_cmds( with open(script_path, "w") as f: f.write(script) - return ["rbatch", f"--chdir={run_dir}", script_path] + return ["sbatch", f"--chdir={run_dir}", script_path] def get_slurm_output_file_base(self, run_dir: str) -> str: """ diff --git a/a3fe/data/example_calc_set/mdm2_pip2_short/input/template_config.cfg b/a3fe/data/example_calc_set/mdm2_pip2_short/input/template_config.cfg deleted file mode 100644 index 2913177..0000000 --- a/a3fe/data/example_calc_set/mdm2_pip2_short/input/template_config.cfg +++ /dev/null @@ -1,36 +0,0 @@ -### For information on options and defaults, run `somd-freenrg --help-config` - -### Integrator - ncycles modified as required by EnsEquil ### -nmoves = 25000 -ncycles = 60 -timestep = 4 * femtosecond -constraint = hbonds -hydrogen mass repartitioning factor = 3.0 -integrator = langevinmiddle -inverse friction = 1 * picosecond -temperature = 25 * celsius -# Thermostatting already handled by langevin integrator -thermostat = False - -### Barostat ### -barostat = True -pressure = 1 * 
atm - -### Non-Bonded Interactions ### -cutoff type = cutoffperiodic -cutoff distance = 12 * angstrom -reaction field dielectric = 78.3 - -### Trajectory ### -buffered coordinates frequency = 5000 -center solute = True - -### Minimisation ### -minimise = True - -### Alchemistry - restraints added by EnsEquil ### -perturbed residue number = 1 -energy frequency = 200 - -### Added by EnsEquil ### - diff --git a/a3fe/data/example_calc_set/t4l/input/template_config.cfg b/a3fe/data/example_calc_set/t4l/input/template_config.cfg deleted file mode 100644 index 2913177..0000000 --- a/a3fe/data/example_calc_set/t4l/input/template_config.cfg +++ /dev/null @@ -1,36 +0,0 @@ -### For information on options and defaults, run `somd-freenrg --help-config` - -### Integrator - ncycles modified as required by EnsEquil ### -nmoves = 25000 -ncycles = 60 -timestep = 4 * femtosecond -constraint = hbonds -hydrogen mass repartitioning factor = 3.0 -integrator = langevinmiddle -inverse friction = 1 * picosecond -temperature = 25 * celsius -# Thermostatting already handled by langevin integrator -thermostat = False - -### Barostat ### -barostat = True -pressure = 1 * atm - -### Non-Bonded Interactions ### -cutoff type = cutoffperiodic -cutoff distance = 12 * angstrom -reaction field dielectric = 78.3 - -### Trajectory ### -buffered coordinates frequency = 5000 -center solute = True - -### Minimisation ### -minimise = True - -### Alchemistry - restraints added by EnsEquil ### -perturbed residue number = 1 -energy frequency = 200 - -### Added by EnsEquil ### - diff --git a/a3fe/data/example_run_dir/input/template_config.cfg b/a3fe/data/example_run_dir/input/template_config.cfg deleted file mode 100644 index 02ee2fb..0000000 --- a/a3fe/data/example_run_dir/input/template_config.cfg +++ /dev/null @@ -1,35 +0,0 @@ -### For information on options and defaults, run `somd-freenrg --help-config` - -### Integrator - ncycles modified as required by a3fe ### -nmoves = 25000 -ncycles = 60 -timestep = 4 * 
femtosecond -constraint = hbonds -hydrogen mass repartitioning factor = 3.0 -integrator = langevinmiddle -inverse friction = 1 * picosecond -temperature = 25 * celsius -# Thermostatting already handled by langevin integrator -thermostat = False - -### Barostat ### -barostat = True -pressure = 1 * atm - -### Non-Bonded Interactions ### -cutoff type = cutoffperiodic -cutoff distance = 12 * angstrom -reaction field dielectric = 78.3 - -### Trajectory ### -buffered coordinates frequency = 5000 -center solute = True - -### Minimisation ### -minimise = True - -### Alchemistry - restraints added by a3fe ### -perturbed residue number = 1 -energy frequency = 200 - -### Added by a3fe ### diff --git a/a3fe/tests/test_engine_configuration.py b/a3fe/tests/test_engine_configuration.py index 4e92758..205fbec 100644 --- a/a3fe/tests/test_engine_configuration.py +++ b/a3fe/tests/test_engine_configuration.py @@ -64,7 +64,7 @@ def test_get_somd_config(): "cutoff distance = 10.0 * angstrom\n" "\n" "### Trajectory ###\n" - "buffered coordinates frequency = 500\n" + "buffered coordinates frequency = 5000\n" "center solute = True\n" "\n" "### Minimisation ###\n" @@ -72,7 +72,7 @@ def test_get_somd_config(): "\n" "### Alchemistry ###\n" "perturbed residue number = 1\n" - "energy frequency = 500\n" + "energy frequency = 200\n" ) with open(config_path, "r") as f: From fb7346c3facfc8b5d952e1c88a6371ad882e9480 Mon Sep 17 00:00:00 2001 From: Roy-Haolin-Du Date: Wed, 8 Jan 2025 17:22:52 +0000 Subject: [PATCH 3/8] Update engine_config validation, ncycle, runtime and test passed --- a3fe/configuration/engine_config.py | 210 +++++++++++++++--- a3fe/tests/test_engine_configuration.py | 274 ++++++++++++++++++------ 2 files changed, 399 insertions(+), 85 deletions(-) diff --git a/a3fe/configuration/engine_config.py b/a3fe/configuration/engine_config.py index 1ad5b7b..1114a5d 100644 --- a/a3fe/configuration/engine_config.py +++ b/a3fe/configuration/engine_config.py @@ -6,11 +6,16 @@ import yaml as 
_yaml import os as _os -from typing import Dict as _Dict - -from pydantic import BaseModel as _BaseModel -from pydantic import Field as _Field -from pydantic import ConfigDict as _ConfigDict +from typing import Dict as _Dict, Literal as _Literal, List as _List, Union as _Union, Optional as _Optional +from math import isclose as _isclose +from pydantic import ( + BaseModel as _BaseModel, + Field as _Field, + ConfigDict as _ConfigDict, + validator as _validator, + root_validator as _root_validator, + field_validator as _field_validator +) class SomdConfig(_BaseModel): @@ -20,40 +25,167 @@ class SomdConfig(_BaseModel): ### Integrator - ncycles modified as required by a3fe ### nmoves: int = _Field(25000, description="Number of moves per cycle") - ncycles: int = _Field(60, description="Number of cycles") - timestep: float = _Field(4.0, description="Timestep in femtoseconds") - constraint: str = _Field("hbonds", description="Constraint type") + timestep: float = _Field(4.0, description="Timestep in femtoseconds(fs)") + runtime: _Union[int, float] = _Field(..., description="Runtime in nanoseconds(ns)") + + input_dir: str = _Field(..., description="Input directory containing simulation config files") + @staticmethod + def _calculate_ncycles(runtime: float, time_per_cycle: float) -> int: + """ + Calculate the number of cycles given a runtime and time per cycle. + + Parameters + ---------- + runtime : float + Runtime in nanoseconds. + time_per_cycle : float + Time per cycle in nanoseconds. + + Returns + ------- + int + Number of cycles. + """ + return round(runtime / time_per_cycle) + + def _set_n_cycles(self, n_cycles: int) -> None: + """ + Set the number of cycles in the SOMD config file. + + Parameters + ---------- + n_cycles : int + Number of cycles to set in the somd config file. 
+ """ + with open(_os.path.join(self.input_dir, "somd.cfg"), "r") as ifile: + lines = ifile.readlines() + for i, line in enumerate(lines): + if line.startswith("ncycles ="): + lines[i] = "ncycles = " + str(n_cycles) + "\n" + break + #Now write the new file + with open(_os.path.join(self.input_dir, "somd.cfg"), "w+") as ofile: + for line in lines: + ofile.write(line) + + def _validate_runtime_and_update_config(self) -> None: + """ + Validate the runtime and update the simulation configuration. + + Need to make sure that runtime is a multiple of the time per cycle + otherwise actual time could be quite different from requested runtime + + Raises + ------ + ValueError + If runtime is not a multiple of the time per cycle. + """ + time_per_cycle = self.timestep * self.nmoves / 1_000_000 # Convert fs to ns + # Convert both to float for division + remainder = float(self.runtime) / float(time_per_cycle) + if not _isclose(remainder - round(remainder), 0, abs_tol=1e-5): + raise ValueError( + f"Runtime must be a multiple of the time per cycle. " + f"Runtime: {self.runtime} ns, Time per cycle: {time_per_cycle} ns." 
+ ) + # Need to modify the config file to set the correction n_cycles + n_cycles = self._calculate_ncycles(self.runtime, time_per_cycle) + self._set_n_cycles(n_cycles) + print(f"Updated ncycles to {n_cycles} in somd.cfg") + + constraint: str = _Field("hbonds", description="Constraint type, must be hbonds or all-bonds") + + @_validator('constraint') + def _check_constraint(cls, v): + if v not in ['hbonds', 'all-bonds']: + raise ValueError('constraint must be hbonds or all-bonds') + return v + hydrogen_mass_factor: float = _Field( 3.0, alias="hydrogen mass repartitioning factor", description="Hydrogen mass repartitioning factor" ) - integrator: str = _Field("langevinmiddle", description="Integration algorithm") + integrator: _Literal["langevinmiddle", "leapfrogverlet"] = _Field("langevinmiddle", description="Integration algorithm") + # Thermostatting already handled by langevin integrator + thermostat: bool = _Field(False, description="Enable thermostat") + + @_root_validator(pre=True) + def _validate_integrator_thermostat(cls, v): + ''' + Make sure that if integrator is 'langevinmiddle' then thermostat must be False + ''' + integrator = v.get('integrator') + thermostat = v.get('thermostat', False) # Default to False if not provided + + if integrator == "langevinmiddle" and thermostat is not False: + raise ValueError("thermostat must be False when integrator is langevinmiddle") + return v + inverse_friction: float = _Field( 1.0, description="Inverse friction in picoseconds", alias="inverse friction" ) temperature: float = _Field(25.0, description="Temperature in Celsius") - # Thermostatting already handled by langevin integrator - thermostat: bool = _Field(False, description="Enable thermostat") ### Barostat ### barostat: bool = _Field(True, description="Enable barostat") pressure: float = _Field(1.0, description="Pressure in atm") ### Non-Bonded Interactions ### - cutoff_type: str = _Field( + cutoff_type: _Literal["PME", "cutoffperiodic"] = _Field( "PME", - 
alias="cutoff type", - description="Type of cutoff to use" + description="Type of cutoff to use. Options: PME, cutoffperiodic" ) cutoff_distance: float = _Field( - 10.0, + 10.0, # Default to PME cutoff distance alias="cutoff distance", description="Cutoff distance in angstroms" ) + reaction_field_dielectric: float | None = _Field( + None, + alias="reaction field dielectric", + description="Reaction field dielectric constant(only for cutoffperiodic)" + ) + @_validator('cutoff_type') + def _check_cutoff_type(cls, v): + if v not in ['PME', 'cutoffperiodic']: + raise ValueError('cutoff type must be PME or cutoffperiodic') + return v + @_validator('cutoff_distance', always=True) + def _set_cutoff_distance(cls, v, values): + if values.get('cutoff_type') == 'PME': + return 10.0 if v is None else v + elif values.get('cutoff_type') == 'cutoffperiodic': + return 12.0 if v is None else v + return v + + @_validator('reaction_field_dielectric',always=True) + def _set_reaction_field_dielectric(cls, v, values): + cutoff_type = values.get('cutoff_type') + if cutoff_type == 'PME' and v is not None: + raise ValueError('reaction field dielectric should not be provided when cutoff type is PME') + elif cutoff_type == 'cutoffperiodic' and v is None: + return 78.3 # Default dielectric constant for cutoffperiodic + return v + + @_field_validator("cutoff_distance", mode="before") + def _validate_cutoff_distance(cls, v, values): + """ + Validate cutoff distance based on cutoff type. 
+ """ + cutoff_type = values.data.get("cutoff_type") + if cutoff_type == "cutoffperiodic": + return 12.0 # Default for cutoffperiodic + return v # Default for PME (10.0) + + def __init__(self, **data): + super().__init__(**data) + self._validate_runtime_and_update_config() + + ### Trajectory ### buffered_coords_freq: int = _Field( 5000, @@ -69,6 +201,16 @@ class SomdConfig(_BaseModel): ### Minimisation ### minimise: bool = _Field(True, description="Perform energy minimisation") + ### Restraints ### + use_boresch_restraints: bool = _Field( + False, + description="UseBoresch restraints mode" + ) + receptor_ligand_restraints: bool = _Field( + False, + description="Turn on receptor-ligand restraints mode" + ) + ### Alchemistry - restraints added by a3fe ### perturbed_residue_number: int = _Field( 1, @@ -80,6 +222,28 @@ class SomdConfig(_BaseModel): alias="energy frequency", description="Frequency of energy output" ) + + ### Lambda ### + lambda_array: _List[float] = _Field( + default_factory=list, + description="Lambda array for alchemical perturbation, varies from 0.0 to 1.0 across stages" + ) + lambda_val: _Optional[float] = _Field( + None, + description="Lambda value for current stage" + ) + + ### Alchemical files ### + morphfile: _Optional[str] = _Field( + None, description="Path to morph file containing alchemical transformation" + ) + topfile: _Optional[str] = _Field( + None, description="Path to topology file for the system" + ) + crdfile: _Optional[str] = _Field( + None, description="Path to coordinate file for the system" + ) + extra_options: _Dict[str, str] = _Field( default_factory=dict, description="Extra options to pass to the SOMD engine" @@ -92,8 +256,6 @@ def get_somd_config(self, run_dir: str, config_name: str = "somd_config") -> str Parameters ---------- - #content : str - Content to write to the configuration file. run_dir : str Directory to write the configuration file to. 
@@ -110,17 +272,16 @@ def get_somd_config(self, run_dir: str, config_name: str = "somd_config") -> str config_lines = [ "### Integrator ###", f"nmoves = {self.nmoves}", - f"ncycles = {self.ncycles}", f"timestep = {self.timestep} * femtosecond", f"constraint = {self.constraint}", f"hydrogen mass repartitioning factor = {self.hydrogen_mass_factor}", f"integrator = {self.integrator}", f"inverse friction = {self.inverse_friction} * picosecond", f"temperature = {self.temperature} * celsius", - f"thermostat = {str(self.thermostat)}", + f"thermostat = {self.thermostat}", "", "### Barostat ###", - f"barostat = {str(self.barostat)}", + f"barostat = {self.barostat}", f"pressure = {self.pressure} * atm", "", "### Non-Bonded Interactions ###", @@ -129,15 +290,14 @@ def get_somd_config(self, run_dir: str, config_name: str = "somd_config") -> str "", "### Trajectory ###", f"buffered coordinates frequency = {self.buffered_coords_freq}", - f"center solute = {str(self.center_solute)}", + f"center solute = {self.center_solute}", "", "### Minimisation ###", - f"minimise = {str(self.minimise)}", + f"minimise = {self.minimise}", "", "### Alchemistry ###", f"perturbed residue number = {self.perturbed_residue_number}", - f"energy frequency = {self.energy_frequency}", - "", + f"energy frequency = {self.energy_frequency}" ] # Add any extra options if self.extra_options: @@ -148,7 +308,7 @@ def get_somd_config(self, run_dir: str, config_name: str = "somd_config") -> str # Write the configuration to a file config_path = _os.path.join(run_dir, f"{config_name}.cfg") with open(config_path, "w") as f: - f.write("\n".join(config_lines)) + f.write("\n".join(config_lines) + "\n") return config_path diff --git a/a3fe/tests/test_engine_configuration.py b/a3fe/tests/test_engine_configuration.py index 205fbec..0012cfd 100644 --- a/a3fe/tests/test_engine_configuration.py +++ b/a3fe/tests/test_engine_configuration.py @@ -1,81 +1,235 @@ """Unit and regression tests for the SomdConfig class.""" from 
tempfile import TemporaryDirectory +import os from a3fe import SomdConfig -import os - +def create_test_input_dir(): + """Create a temporary directory with a mock somd.cfg file.""" + temp_dir = TemporaryDirectory() + with open(os.path.join(temp_dir.name, "somd.cfg"), "w") as f: + f.write("ncycles = 1000\n") + return temp_dir def test_create_config(): """Test that the config can be created.""" - config = SomdConfig() - assert isinstance(config, SomdConfig) - + with create_test_input_dir() as input_dir: + # Test with integer runtime + config = SomdConfig( + runtime=1, # Integer runtime + input_dir=input_dir + ) + assert isinstance(config, SomdConfig) + + # Test with float runtime + config = SomdConfig( + runtime=0.3, # Float runtime + input_dir=input_dir + ) + assert isinstance(config, SomdConfig) def test_config_pickle_and_load(): """Test that the config can be pickled and loaded.""" - with TemporaryDirectory() as dirname: - config = SomdConfig() - config.dump(dirname) - config2 = SomdConfig.load(dirname) - assert config == config2 - + with create_test_input_dir() as input_dir: + with TemporaryDirectory() as dirname: + config = SomdConfig( + runtime=1, # Integer runtime + input_dir=input_dir + ) + config.dump(dirname) + config2 = SomdConfig.load(dirname) + assert config == config2 def test_get_somd_config(): - """ - Test that the SOMD configuration file is generated correctly - and that the file is written correctly. 
- """ - # Tmpdir to store the config - with TemporaryDirectory() as dirname: + """Test that the SOMD configuration file is generated correctly.""" + with create_test_input_dir() as input_dir: + with TemporaryDirectory() as dirname: + config = SomdConfig( + integrator="langevinmiddle", + nmoves=25000, + timestep=4.0, + runtime=1, # Integer runtime + input_dir=input_dir, + cutoff_type="PME", + thermostat=False + ) + config_path = config.get_somd_config( + run_dir=dirname, + config_name="test" + ) + assert config_path == os.path.join(dirname, "test.cfg") + + expected_config = ( + "### Integrator ###\n" + "nmoves = 25000\n" + "timestep = 4.0 * femtosecond\n" + "constraint = hbonds\n" + "hydrogen mass repartitioning factor = 3.0\n" + "integrator = langevinmiddle\n" + "inverse friction = 1.0 * picosecond\n" + "temperature = 25.0 * celsius\n" + "thermostat = False\n" + "\n" + "### Barostat ###\n" + "barostat = True\n" + "pressure = 1.0 * atm\n" + "\n" + "### Non-Bonded Interactions ###\n" + "cutoff type = PME\n" + "cutoff distance = 10.0 * angstrom\n" + "\n" + "### Trajectory ###\n" + "buffered coordinates frequency = 5000\n" + "center solute = True\n" + "\n" + "### Minimisation ###\n" + "minimise = True\n" + "\n" + "### Alchemistry ###\n" + "perturbed residue number = 1\n" + "energy frequency = 200\n" + ) + + with open(config_path, "r") as f: + config_content = f.read() + + assert config_content == expected_config + +def test_constraint_validation(): + """Test constraint type validation.""" + with create_test_input_dir() as input_dir: + config = SomdConfig( + constraint="hbonds", + runtime=1, # Integer runtime + input_dir=input_dir + ) + assert config.constraint == "hbonds" + + config = SomdConfig( + constraint="all-bonds", + runtime=0.3, # Float runtime + input_dir=input_dir + ) + assert config.constraint == "all-bonds" + + try: + SomdConfig( + constraint="invalid", + runtime=1, # Integer runtime + input_dir=input_dir + ) + assert False, "Should raise ValueError" + 
except ValueError: + pass + +def test_cutoff_type_validation(): + """Test cutoff type validation.""" + with create_test_input_dir() as input_dir: + # Test PME cutoff type config = SomdConfig( - integrator="langevinmiddle", - nmoves=25000, - ncycles=60, - timestep=4.0, cutoff_type="PME", - cutoff_distance=10.0, + runtime=1, # Integer runtime + input_dir=input_dir ) - config_path = config.get_somd_config( - run_dir=dirname, - config_name="test" + assert config.cutoff_type == "PME" + assert config.cutoff_distance == 10.0 + assert config.reaction_field_dielectric is None + + # Test cutoffperiodic type + config = SomdConfig( + cutoff_type="cutoffperiodic", + runtime=0.3, # Float runtime + input_dir=input_dir ) - assert config_path == os.path.join(dirname, "test.cfg") - - expected_config = ( - "### Integrator ###\n" - "nmoves = 25000\n" - "ncycles = 60\n" - "timestep = 4.0 * femtosecond\n" - "constraint = hbonds\n" - "hydrogen mass repartitioning factor = 3.0\n" - "integrator = langevinmiddle\n" - "inverse friction = 1.0 * picosecond\n" - "temperature = 25.0 * celsius\n" - "thermostat = False\n" - "\n" - "### Barostat ###\n" - "barostat = True\n" - "pressure = 1.0 * atm\n" - "\n" - "### Non-Bonded Interactions ###\n" - "cutoff type = PME\n" - "cutoff distance = 10.0 * angstrom\n" - "\n" - "### Trajectory ###\n" - "buffered coordinates frequency = 5000\n" - "center solute = True\n" - "\n" - "### Minimisation ###\n" - "minimise = True\n" - "\n" - "### Alchemistry ###\n" - "perturbed residue number = 1\n" - "energy frequency = 200\n" + assert config.cutoff_type == "cutoffperiodic" + assert config.cutoff_distance == 12.0 + assert config.reaction_field_dielectric == 78.3 + +def test_integrator_thermostat_validation(): + """Test integrator and thermostat validation.""" + with create_test_input_dir() as input_dir: + # Valid configuration + config = SomdConfig( + integrator="langevinmiddle", + thermostat=False, + runtime=1, # Integer runtime + input_dir=input_dir ) + assert 
config.integrator == "langevinmiddle" + assert config.thermostat is False + + # Invalid configuration + try: + SomdConfig( + integrator="langevinmiddle", + thermostat=True, + runtime=1, # Integer runtime + input_dir=input_dir + ) + assert False, "Should raise ValueError" + except ValueError: + pass - with open(config_path, "r") as f: - config_content = f.read() +def test_lambda_validation(): + """Test lambda array and value validation.""" + with create_test_input_dir() as input_dir: + config = SomdConfig( + lambda_array=[0.0, 0.5, 1.0], + lambda_val=0.5, + runtime=1, # Integer runtime + input_dir=input_dir + ) + assert config.lambda_array == [0.0, 0.5, 1.0] + assert config.lambda_val == 0.5 + +def test_restraints_configuration(): + """Test restraints configuration options.""" + with create_test_input_dir() as input_dir: + config = SomdConfig( + use_boresch_restraints=True, + receptor_ligand_restraints=True, + runtime=1, # Integer runtime + input_dir=input_dir + ) + assert config.use_boresch_restraints is True + assert config.receptor_ligand_restraints is True + +def test_alchemical_files(): + """Test alchemical transformation file paths.""" + with create_test_input_dir() as input_dir: + config = SomdConfig( + morphfile="/path/to/morph.pert", + topfile="/path/to/system.top", + crdfile="/path/to/system.crd", + runtime=1, # Integer runtime + input_dir=input_dir + ) + assert config.morphfile == "/path/to/morph.pert" + assert config.topfile == "/path/to/system.top" + assert config.crdfile == "/path/to/system.crd" - assert config_content == expected_config +def test_get_somd_config_with_extra_options(): + """Test SOMD configuration generation with extra options.""" + with create_test_input_dir() as input_dir: + with TemporaryDirectory() as dirname: + config = SomdConfig( + integrator="langevinmiddle", + nmoves=25000, + timestep=4.0, + runtime=1, # Integer runtime + input_dir=input_dir, + cutoff_type="PME", + thermostat=False, + extra_options={"custom_option": "value"} + 
) + config_path = config.get_somd_config( + run_dir=dirname, + config_name="test_extra" + ) + + with open(config_path, "r") as f: + config_content = f.read() + + assert "### Extra Options ###" in config_content + assert "custom_option = value" in config_content \ No newline at end of file From ec4858f5d80f773c2844fadbf86ebcf11394e721 Mon Sep 17 00:00:00 2001 From: Roy-Haolin-Du Date: Wed, 8 Jan 2025 20:20:33 +0000 Subject: [PATCH 4/8] keep partition as default --- a3fe/configuration/slurm_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/a3fe/configuration/slurm_config.py b/a3fe/configuration/slurm_config.py index 6ef9cbb..f9cb5d6 100644 --- a/a3fe/configuration/slurm_config.py +++ b/a3fe/configuration/slurm_config.py @@ -22,7 +22,7 @@ class SlurmConfig(_BaseModel): Pydantic model for holding a SLURM configuration. """ - partition: str = _Field("main", description="SLURM partition to submit to.") + partition: str = _Field("default", description="SLURM partition to submit to.") time: str = _Field("24:00:00", description="Time limit for the SLURM job.") gres: str = _Field("gpu:1", description="Resources to request - normally one GPU.") nodes: int = _Field(1, ge=1) From abc0b9e892e504454004bcef1a6eaa3500d359ef Mon Sep 17 00:00:00 2001 From: Roy-Haolin-Du Date: Thu, 9 Jan 2025 11:04:12 +0000 Subject: [PATCH 5/8] Flexibly handle SOMD configuration files and create an abstract base class _engine_runner_config. 
Write out the A3FE version, and make the output logs more concise and suitable for debugging --- a3fe/configuration/_engine_runner_config.py | 139 ++++++++++++++++++++ a3fe/configuration/engine_config.py | 90 ++++++------- a3fe/run/_simulation_runner.py | 81 +++++++++++- a3fe/run/calculation.py | 32 ++++- a3fe/run/leg.py | 48 ++----- a3fe/run/simulation.py | 42 ------ a3fe/run/stage.py | 2 +- a3fe/tests/test_engine_configuration.py | 26 ++-- docs/guides.rst | 7 +- 9 files changed, 324 insertions(+), 143 deletions(-) create mode 100644 a3fe/configuration/_engine_runner_config.py diff --git a/a3fe/configuration/_engine_runner_config.py b/a3fe/configuration/_engine_runner_config.py new file mode 100644 index 0000000..94f1f63 --- /dev/null +++ b/a3fe/configuration/_engine_runner_config.py @@ -0,0 +1,139 @@ +"""Abstract base class for engine configurations.""" + +from __future__ import annotations + +import copy as _copy +import logging as _logging +import yaml as _yaml +from abc import ABC, abstractmethod +from typing import Any as _Any +from typing import Dict as _Dict + +from ..run._logging_formatters import _A3feStreamFormatter + + +class EngineRunnerConfig(ABC): + """An abstract base class for engine configurations.""" + + def __init__( + self, + stream_log_level: int = _logging.INFO, + ) -> None: + """ + Initialize the engine configuration. + + Parameters + ---------- + stream_log_level : int, Optional, default: logging.INFO + Logging level to use for the steam file handlers. 
+ """ + # Set up logging + self._stream_log_level = stream_log_level + self._set_up_logging() + + def _set_up_logging(self) -> None: + """Set up logging for the configuration.""" + # If logger exists, remove it and start again + if hasattr(self, "_logger"): + handlers = self._logger.handlers[:] + for handler in handlers: + self._logger.removeHandler(handler) + handler.close() + del self._logger + + # Create a new logger + self._logger = _logging.getLogger(f"{self.__class__.__name__}") + self._logger.propagate = False + self._logger.setLevel(_logging.DEBUG) + + # For the stream handler, we want to log at the user-specified level + stream_handler = _logging.StreamHandler() + stream_handler.setFormatter(_A3feStreamFormatter()) + stream_handler.setLevel(self._stream_log_level) + self._logger.addHandler(stream_handler) + + @abstractmethod + def get_config(self) -> _Dict[str, _Any]: + """ + Get the configuration dictionary. + + Returns + ------- + config : Dict[str, Any] + The configuration dictionary. + """ + pass + + def dump(self, file_path: str) -> None: + """ + Dump the configuration to a YAML file. + + Parameters + ---------- + file_path : str + Path to dump the configuration to. + """ + config = self.get_config() + with open(file_path, "w") as f: + _yaml.safe_dump(config, f, default_flow_style=False) + self._logger.info(f"Configuration dumped to {file_path}") + + @classmethod + def load(cls, file_path: str) -> EngineRunnerConfig: + """ + Load a configuration from a YAML file. + + Parameters + ---------- + file_path : str + Path to load the configuration from. + + Returns + ------- + config : EngineRunnerConfig + The loaded configuration. + """ + with open(file_path, "r") as f: + config_dict = _yaml.safe_load(f) + return cls(**config_dict) + + @abstractmethod + def get_file_name(self) -> str: + """ + Get the name of the configuration file. + + Returns + ------- + file_name : str + The name of the configuration file. 
+ """ + pass + + def __eq__(self, other: object) -> bool: + """ + Check if two configurations are equal. + + Parameters + ---------- + other : object + The other configuration to compare with. + + Returns + ------- + equal : bool + Whether the configurations are equal. + """ + if not isinstance(other, EngineRunnerConfig): + return NotImplemented + return self.get_config() == other.get_config() + + def copy(self) -> EngineRunnerConfig: + """ + Create a deep copy of the configuration. + + Returns + ------- + config : EngineRunnerConfig + A deep copy of the configuration. + """ + return _copy.deepcopy(self) diff --git a/a3fe/configuration/engine_config.py b/a3fe/configuration/engine_config.py index 1114a5d..c42198b 100644 --- a/a3fe/configuration/engine_config.py +++ b/a3fe/configuration/engine_config.py @@ -4,9 +4,9 @@ "SomdConfig", ] -import yaml as _yaml import os as _os -from typing import Dict as _Dict, Literal as _Literal, List as _List, Union as _Union, Optional as _Optional +import logging as _logging +from typing import Dict as _Dict, Literal as _Literal, List as _List, Union as _Union, Optional as _Optional, Any as _Any from math import isclose as _isclose from pydantic import ( BaseModel as _BaseModel, @@ -17,8 +17,10 @@ field_validator as _field_validator ) +from ._engine_runner_config import EngineRunnerConfig as _EngineRunnerConfig -class SomdConfig(_BaseModel): + +class SomdConfig(_EngineRunnerConfig, _BaseModel): """ Pydantic model for holding SOMD engine configuration. """ @@ -88,10 +90,14 @@ def _validate_runtime_and_update_config(self) -> None: f"Runtime must be a multiple of the time per cycle. " f"Runtime: {self.runtime} ns, Time per cycle: {time_per_cycle} ns." 
) - # Need to modify the config file to set the correction n_cycles - n_cycles = self._calculate_ncycles(self.runtime, time_per_cycle) - self._set_n_cycles(n_cycles) - print(f"Updated ncycles to {n_cycles} in somd.cfg") + + # Only try to modify the config file if it exists + cfg_file = _os.path.join(self.input_dir, "somd.cfg") + if _os.path.exists(cfg_file): + # Need to modify the config file to set the correction n_cycles + n_cycles = self._calculate_ncycles(self.runtime, time_per_cycle) + self._set_n_cycles(n_cycles) + print(f"Updated ncycles to {n_cycles} in somd.cfg") constraint: str = _Field("hbonds", description="Constraint type, must be hbonds or all-bonds") @@ -181,10 +187,34 @@ def _validate_cutoff_distance(cls, v, values): return 12.0 # Default for cutoffperiodic return v # Default for PME (10.0) - def __init__(self, **data): - super().__init__(**data) + model_config = _ConfigDict(arbitrary_types_allowed=True) + + def __init__(self, stream_log_level: int = _logging.INFO, **data): + _BaseModel.__init__(self, **data) + _EngineRunnerConfig.__init__(self, stream_log_level=stream_log_level) self._validate_runtime_and_update_config() + def get_config(self) -> _Dict[str, _Any]: + """ + Get the SOMD configuration as a dictionary. + + Returns + ------- + config : Dict[str, Any] + The SOMD configuration dictionary. + """ + return self.model_dump() + + def get_file_name(self) -> str: + """ + Get the name of the SOMD configuration file. + + Returns + ------- + file_name : str + The name of the SOMD configuration file. 
+ """ + return "somd.cfg" ### Trajectory ### buffered_coords_freq: int = _Field( @@ -248,7 +278,6 @@ def __init__(self, **data): default_factory=dict, description="Extra options to pass to the SOMD engine" ) - model_config = _ConfigDict(validate_assignment=True) def get_somd_config(self, run_dir: str, config_name: str = "somd_config") -> str: """ @@ -311,44 +340,3 @@ def get_somd_config(self, run_dir: str, config_name: str = "somd_config") -> str f.write("\n".join(config_lines) + "\n") return config_path - - def dump(self, save_dir: str) -> None: - """ - Dumps the configuration to a YAML file. - - Parameters - ---------- - save_dir : str - Directory to save the YAML file to. - """ - model_dict = self.model_dump() - - save_path = save_dir + "/" + self.get_file_name() - with open(save_path, "w") as f: - _yaml.dump(model_dict, f, default_flow_style=False) - - @classmethod - def load(cls, load_dir: str) -> "SomdConfig": - """ - Loads the configuration from a YAML file. - - Parameters - ---------- - load_dir : str - Directory to load the YAML file from. - - Returns - ------- - SomdConfig - The loaded configuration. - """ - with open(load_dir + "/" + cls.get_file_name(), "r") as f: - model_dict = _yaml.safe_load(f) - return cls(**model_dict) - - @staticmethod - def get_file_name() -> str: - """ - Get the name of the SOMD configuration file. 
- """ - return "somd_config.yaml" diff --git a/a3fe/run/_simulation_runner.py b/a3fe/run/_simulation_runner.py index 27c457b..84edee9 100644 --- a/a3fe/run/_simulation_runner.py +++ b/a3fe/run/_simulation_runner.py @@ -7,6 +7,7 @@ import os as _os import pathlib as _pathlib import pickle as _pkl +import shutil as _shutil import subprocess as _subprocess from abc import ABC from itertools import count as _count @@ -176,7 +177,10 @@ def __init__( # Create the SOMD config with default values if none is provided if engine_config is None: - self.engine_config = _SomdConfig() + self.engine_config = _SomdConfig( + runtime=2.5, # Default runtime of 2.5 ns + input_dir=self.input_dir # Use the simulation runner's input directory + ) else: self.engine_config = engine_config @@ -1082,3 +1086,78 @@ def _load(self, update_paths: bool = True) -> None: # Record that the object was loaded from a pickle file self.loaded_from_pickle = True + + def _update_logging_options( + self, + stream_log_level: _Optional[int] = None, + file_log_level: _Optional[int] = None, + track_time: bool = False, + ) -> None: + """ + Update logging options for this simulation runner and all sub-runners recursively. + + Parameters + ---------- + stream_log_level : int, Optional + The log level for the stream handler. If None, keeps current level. + file_log_level : int, Optional + The log level for the file handler. If None, keeps current level. + track_time : bool, default=False + Whether to track time for logging operations. 
+ """ + if stream_log_level is not None: + self.recursively_set_attr("stream_log_level", stream_log_level, force=True) + + if hasattr(self, "_logger"): + # Update existing logger settings + for handler in self._logger.handlers: + if isinstance(handler, _logging.StreamHandler) and stream_log_level is not None: + handler.setLevel(stream_log_level) + elif isinstance(handler, _logging.FileHandler) and file_log_level is not None: + handler.setLevel(file_log_level) + + # Add time tracking if requested + if track_time: + self._logger.track_time = True + + # Recursively update sub-runners + if hasattr(self, "_sub_sim_runners"): + for sub_runner in self._sub_sim_runners: + sub_runner._update_logging_options( + stream_log_level=stream_log_level, + file_log_level=file_log_level, + track_time=track_time + ) + + def _copy_to_test(self, test_dir: str) -> None: + """ + Copy generated files to a test directory for validation. + + Parameters + ---------- + test_dir : str + Path to the test directory where files should be copied. 
+ """ + if not _os.path.exists(test_dir): + _os.makedirs(test_dir) + + self._logger.info(f"Copying generated files to test directory: {test_dir}") + + # Copy all files from the run directory to test directory + for root, _, files in _os.walk(self.run_dir): + for file in files: + src_path = _os.path.join(root, file) + # Get relative path from run_dir + rel_path = _os.path.relpath(src_path, self.run_dir) + dst_path = _os.path.join(test_dir, rel_path) + + # Create destination directory if it doesn't exist + dst_dir = _os.path.dirname(dst_path) + if not _os.path.exists(dst_dir): + _os.makedirs(dst_dir) + + # Copy the file + _shutil.copy2(src_path, dst_path) + self._logger.debug(f"Copied {rel_path} to test directory") + + self._logger.info("Finished copying files to test directory") diff --git a/a3fe/run/calculation.py b/a3fe/run/calculation.py index 9abc4e9..e9ae785 100644 --- a/a3fe/run/calculation.py +++ b/a3fe/run/calculation.py @@ -5,6 +5,7 @@ import logging as _logging import os as _os +import time as _time from typing import List as _List from typing import Optional as _Optional @@ -12,6 +13,7 @@ from .enums import LegType as _LegType from .enums import PreparationStage as _PreparationStage from .leg import Leg as _Leg +from .._version import __version__ as _version from ..configuration import ( SystemPreparationConfig as _SystemPreparationConfig, SlurmConfig as _SlurmConfig, @@ -99,6 +101,9 @@ def __init__( ------- None """ + # Store the version + self.version = _version + super().__init__( base_dir=base_dir, input_dir=input_dir, @@ -108,7 +113,7 @@ def __init__( update_paths=update_paths, slurm_config=slurm_config, analysis_slurm_config=analysis_slurm_config, - engine_config=engine_config, + engine_config=engine_config.copy() if engine_config else None, dump=False, ) @@ -125,6 +130,8 @@ def __init__( self._update_log() self._dump() + self._logger.info(f"Initializing calculation with A3fe version: {self.version}") + @property def legs(self) -> _List[_Leg]: 
return self._sub_sim_runners @@ -191,7 +198,7 @@ def setup( """ if self.setup_complete: - self._logger.info("Setup already complete. Skipping...") + self._logger.debug("Setup already complete. Skipping...") return configs = { @@ -199,10 +206,24 @@ def setup( _LegType.FREE: free_leg_sysprep_config, } + # Set up logging options for detailed setup tracking + self._update_logging_options( + stream_log_level=self.stream_log_level, + file_log_level=_logging.DEBUG, # Always keep detailed logs in file + track_time=True # Track time during setup + ) + + self._logger.info(f"Setting up calculation with A3fe version: {self.version}") + + self._logger.info("Starting calculation setup...") + setup_start = _time.time() + # Set up the legs self.legs = [] for leg_type in reversed(Calculation.required_legs): self._logger.info(f"Setting up {leg_type.name.lower()} leg...") + leg_start = _time.time() + leg = _Leg( leg_type=leg_type, equil_detection=self.equil_detection, @@ -218,6 +239,13 @@ def setup( ) self.legs.append(leg) leg.setup(configs[leg_type]) + + self._logger.debug( + f"Completed {leg_type.name.lower()} leg setup in {_time.time() - leg_start:.2f}s" + ) + + total_time = _time.time() - setup_start + self._logger.info(f"Calculation setup completed in {total_time:.2f}s") # Save the state self.setup_complete = True diff --git a/a3fe/run/leg.py b/a3fe/run/leg.py index f935202..51b8a18 100644 --- a/a3fe/run/leg.py +++ b/a3fe/run/leg.py @@ -22,8 +22,6 @@ from ..analyse.plot import plot_convergence as _plot_convergence from ..analyse.plot import plot_rmsds as _plot_rmsds from ..analyse.plot import plot_sq_sem_convergence as _plot_sq_sem_convergence -from ..read._process_somd_files import read_simfile_option as _read_simfile_option -from ..read._process_somd_files import write_simfile_option as _write_simfile_option from . 
import system_prep as _system_prep from ._restraint import A3feRestraint as _A3feRestraint from ._simulation_runner import SimulationRunner as _SimulationRunner @@ -812,45 +810,23 @@ def write_input_files( # by BSS, as well as the restraints options # generate the somd.cfg file - config_path = self.engine_config.get_somd_config( + somd_config = self.engine_config.get_somd_config( run_dir=stage_input_dir, config_name="somd" ) - - try: - use_boresch_restraints = _read_simfile_option( - config_path, "use boresch restraints" - ) - except ValueError: - use_boresch_restraints = False - try: - turn_on_receptor_ligand_restraints_mode = _read_simfile_option( - config_path, - "turn on receptor-ligand restraints mode", - ) - except ValueError: - turn_on_receptor_ligand_restraints_mode = False - - # Now write simfile options - options_to_write = { - "perturbed_residue number": str(perturbed_resnum), - "use boresch restraints": use_boresch_restraints, - "turn on receptor-ligand restraints mode": turn_on_receptor_ligand_restraints_mode, - # This automatically uses the co-alchemical ion approach when there is a charge difference - "charge difference": str(-lig_charge), - } - - for option, value in options_to_write.items(): - _write_simfile_option( - config_path, option, value - ) - + + # Set configuration options + somd_config.perturbed_residue_number = perturbed_resnum + somd_config.use_boresch_restraints = self.leg_type == _LegType.BOUND + somd_config.turn_on_receptor_ligand_restraints_mode = self.leg_type == _LegType.BOUND + somd_config.charge_difference = -lig_charge # Use co-alchemical ion approach when there is a charge difference + # Set the default lambda windows based on the leg and stage types lam_vals = config.lambda_values[self.leg_type][stage_type] - lam_vals_str = ", ".join([str(lam_val) for lam_val in lam_vals]) - _write_simfile_option( - config_path, "lambda array", lam_vals_str - ) + somd_config.lambda_array = lam_vals + + # Write the updated configuration + 
somd_config.write(stage_input_dir) # We no longer need to store the large BSS restraint classes. self._lighten_restraints() diff --git a/a3fe/run/simulation.py b/a3fe/run/simulation.py index b17a6d8..6a96a12 100644 --- a/a3fe/run/simulation.py +++ b/a3fe/run/simulation.py @@ -7,7 +7,6 @@ import os as _os import pathlib as _pathlib import subprocess as _subprocess -from decimal import Decimal as _Decimal from typing import List as _List from typing import Optional as _Optional from typing import Tuple as _Tuple @@ -222,17 +221,14 @@ def _add_attributes_from_simfile(self) -> None: """ timestep = None # ns - nmoves = None # number of moves per cycle nrg_freq = None # number of timesteps between energy calculations timestep = float( _read_simfile_option(self.simfile_path, "timestep").split()[0] ) # Need to remove femtoseconds from the end - nmoves = float(_read_simfile_option(self.simfile_path, "nmoves")) nrg_freq = float(_read_simfile_option(self.simfile_path, "energy frequency")) self.timestep = timestep / 1_000_000 # fs to ns self.nrg_freq = nrg_freq - self.time_per_cycle = timestep * nmoves / 1_000_000 # fs to ns def _select_input_files(self) -> None: """Select the correct rst7 and, if supplied, restraints, @@ -342,19 +338,6 @@ def run(self, runtime: float = 2.5) -> None: ------- None """ - # Need to make sure that runtime is a multiple of the time per cycle - # otherwise actual time could be quite different from requested runtime - remainder = _Decimal(str(runtime)) % _Decimal(str(self.time_per_cycle)) - if round(float(remainder), 4) != 0: - raise ValueError( - ( - "Runtime must be a multiple of the time per cycle. " - f"Runtime is {runtime} ns, and time per cycle is {self.time_per_cycle} ns." 
- ) - ) - # Need to modify the config file to set the correction n_cycles - n_cycles = round(runtime / self.time_per_cycle) - self._set_n_cycles(n_cycles) # Run SOMD - note that command excludes sbatch as this is added by the virtual queue cmd = f"somd-freenrg -C somd.cfg -l {self.lam} -p CUDA" @@ -500,31 +483,6 @@ def lighten(self) -> None: self._logger.info(f"Deleting {file}") _subprocess.run(["rm", file]) - def _set_n_cycles(self, n_cycles: int) -> None: - """ - Set the number of cycles in the SOMD config file. - - Parameters - ---------- - n_cycles : int - Number of cycles to set in the config file. - - Returns - ------- - None - """ - # Find the line with n_cycles and replace - with open(_os.path.join(self.input_dir, "somd.cfg"), "r") as ifile: - lines = ifile.readlines() - for i, line in enumerate(lines): - if line.startswith("ncycles ="): - lines[i] = "ncycles = " + str(n_cycles) + "\n" - break - - # Now write the new file - with open(_os.path.join(self.input_dir, "somd.cfg"), "w+") as ofile: - for line in lines: - ofile.write(line) def read_gradients( self, equilibrated_only: bool = False, endstate: bool = False diff --git a/a3fe/run/stage.py b/a3fe/run/stage.py index 91e36e3..130d58d 100644 --- a/a3fe/run/stage.py +++ b/a3fe/run/stage.py @@ -201,7 +201,7 @@ def __init__( stream_log_level=self.stream_log_level, slurm_config=self.slurm_config, analysis_slurm_config=self.analysis_slurm_config, - engine_config=self.engine_config, + engine_config=self.engine_config.copy() if self.engine_config else None, ) ) diff --git a/a3fe/tests/test_engine_configuration.py b/a3fe/tests/test_engine_configuration.py index 0012cfd..2f6ea2e 100644 --- a/a3fe/tests/test_engine_configuration.py +++ b/a3fe/tests/test_engine_configuration.py @@ -2,6 +2,7 @@ from tempfile import TemporaryDirectory import os +import logging from a3fe import SomdConfig @@ -13,7 +14,7 @@ def create_test_input_dir(): return temp_dir def test_create_config(): - """Test that the config can be created.""" 
+ """Test that the config can be created with different parameters.""" with create_test_input_dir() as input_dir: # Test with integer runtime config = SomdConfig( @@ -28,17 +29,26 @@ def test_create_config(): input_dir=input_dir ) assert isinstance(config, SomdConfig) + + # Test with custom stream_log_level + config = SomdConfig( + runtime=1, + input_dir=input_dir, + stream_log_level=logging.DEBUG + ) + assert config._stream_log_level == logging.DEBUG -def test_config_pickle_and_load(): - """Test that the config can be pickled and loaded.""" +def test_config_yaml_save_and_load(): + """Test that the config can be saved to and loaded from YAML.""" with create_test_input_dir() as input_dir: with TemporaryDirectory() as dirname: config = SomdConfig( runtime=1, # Integer runtime input_dir=input_dir ) - config.dump(dirname) - config2 = SomdConfig.load(dirname) + yaml_path = os.path.join(dirname, "config.yaml") + config.dump(yaml_path) + config2 = SomdConfig.load(yaml_path) assert config == config2 def test_get_somd_config(): @@ -124,7 +134,7 @@ def test_constraint_validation(): pass def test_cutoff_type_validation(): - """Test cutoff type validation.""" + """Test cutoff type and related parameters validation.""" with create_test_input_dir() as input_dir: # Test PME cutoff type config = SomdConfig( @@ -149,7 +159,7 @@ def test_cutoff_type_validation(): def test_integrator_thermostat_validation(): """Test integrator and thermostat validation.""" with create_test_input_dir() as input_dir: - # Valid configuration + # Test valid configuration config = SomdConfig( integrator="langevinmiddle", thermostat=False, @@ -159,7 +169,7 @@ def test_integrator_thermostat_validation(): assert config.integrator == "langevinmiddle" assert config.thermostat is False - # Invalid configuration + # Test invalid configuration try: SomdConfig( integrator="langevinmiddle", diff --git a/docs/guides.rst b/docs/guides.rst index 0a4cafc..083cb58 100644 --- a/docs/guides.rst +++ b/docs/guides.rst @@ 
-67,6 +67,9 @@ You can also find out which input files are required for a given preparation sta - free_preequil.prm7, free_preequil.rst7 - The solvated ligand after heating and short initial equilibration steps +The default simulation engine is SOMD (``engine_config = a3.SomdConfig()``), and the details of its configuration can be customised freely. +For example, ``engine_config = a3.SomdConfig(constraint="hbonds", cutoff_type="PME", integrator="langevinmiddle", thermostat=False, ...)``. +But you can also use Gromacs ``engine_config = a3.GromacsConfig()``. Please note that if you are suppling parameterised input files, **the ligand must be the first molecule in the system and the ligand must be named "LIG"**. The former can be achieved by reordering the system with BioSimSpace, and the latter @@ -189,8 +192,8 @@ three replicates. Note that this is expected to produce an erroneously favourabl cfg.runtime_npt = 50 # ps cfg.ensemble_equilibration_time = 100 # ps calc = a3.Calculation(ensemble_size = 3) - calc_set.setup(bound_leg_sysprep_config = cfg, free_leg_sysprep_config = cfg) - calc_set.run(adaptive = False, runtime=0.1) # ns + calc.setup(bound_leg_sysprep_config = cfg, free_leg_sysprep_config = cfg) + calc.run(adaptive = False, runtime=0.1) # ns calc.wait() # Wait for the simulations to finish calc.set_equilibration_time(1) # Discard the first ns of simulation time calc.analyse() # Fast analyses From 624576e7d7e24dd080f77d9fa7418ca7567f3701 Mon Sep 17 00:00:00 2001 From: Roy-Haolin-Du Date: Sun, 12 Jan 2025 19:45:24 +0000 Subject: [PATCH 6/8] engine_confi and test --- a3fe/configuration/_engine_runner_config.py | 123 +---- a3fe/configuration/engine_config.py | 476 ++++++++++---------- a3fe/tests/test_engine_configuration.py | 368 ++++++--------- 3 files changed, 404 insertions(+), 563 deletions(-) diff --git a/a3fe/configuration/_engine_runner_config.py b/a3fe/configuration/_engine_runner_config.py index 94f1f63..0bb1ef8 100644 --- 
a/a3fe/configuration/_engine_runner_config.py +++ b/a3fe/configuration/_engine_runner_config.py @@ -2,81 +2,43 @@ from __future__ import annotations -import copy as _copy -import logging as _logging import yaml as _yaml -from abc import ABC, abstractmethod +from pydantic import BaseModel as _BaseModel from typing import Any as _Any from typing import Dict as _Dict -from ..run._logging_formatters import _A3feStreamFormatter +class EngineRunnerConfig(_BaseModel): + """Base class for engine runner configurations.""" - -class EngineRunnerConfig(ABC): - """An abstract base class for engine configurations.""" - - def __init__( - self, - stream_log_level: int = _logging.INFO, - ) -> None: - """ - Initialize the engine configuration. - - Parameters - ---------- - stream_log_level : int, Optional, default: logging.INFO - Logging level to use for the steam file handlers. - """ - # Set up logging - self._stream_log_level = stream_log_level - self._set_up_logging() - - def _set_up_logging(self) -> None: - """Set up logging for the configuration.""" - # If logger exists, remove it and start again - if hasattr(self, "_logger"): - handlers = self._logger.handlers[:] - for handler in handlers: - self._logger.removeHandler(handler) - handler.close() - del self._logger - - # Create a new logger - self._logger = _logging.getLogger(f"{self.__class__.__name__}") - self._logger.propagate = False - self._logger.setLevel(_logging.DEBUG) - - # For the stream handler, we want to log at the user-specified level - stream_handler = _logging.StreamHandler() - stream_handler.setFormatter(_A3feStreamFormatter()) - stream_handler.setLevel(self._stream_log_level) - self._logger.addHandler(stream_handler) - - @abstractmethod def get_config(self) -> _Dict[str, _Any]: """ Get the configuration dictionary. - - Returns - ------- - config : Dict[str, Any] - The configuration dictionary. """ pass + def get_file_name(self) -> str: + """ + Get the name of the configuration file. 
+ """ + pass + def dump(self, file_path: str) -> None: """ - Dump the configuration to a YAML file. + Dump the configuration to a YAML file using `self.model_dump()`. Parameters ---------- file_path : str Path to dump the configuration to. """ - config = self.get_config() - with open(file_path, "w") as f: - _yaml.safe_dump(config, f, default_flow_style=False) - self._logger.info(f"Configuration dumped to {file_path}") + try: + config = self.model_dump() + with open(file_path, "w") as f: + _yaml.safe_dump(config, f) + except Exception as e: + print(f"Error dumping configuration: {e}") + return + print(f"Configuration dumped to {file_path}") @classmethod def load(cls, file_path: str) -> EngineRunnerConfig: @@ -93,47 +55,10 @@ def load(cls, file_path: str) -> EngineRunnerConfig: config : EngineRunnerConfig The loaded configuration. """ - with open(file_path, "r") as f: - config_dict = _yaml.safe_load(f) + try: + with open(file_path, "r") as f: + config_dict = _yaml.safe_load(f) + except Exception as e: + print(f"Error loading configuration: {e}") + return return cls(**config_dict) - - @abstractmethod - def get_file_name(self) -> str: - """ - Get the name of the configuration file. - - Returns - ------- - file_name : str - The name of the configuration file. - """ - pass - - def __eq__(self, other: object) -> bool: - """ - Check if two configurations are equal. - - Parameters - ---------- - other : object - The other configuration to compare with. - - Returns - ------- - equal : bool - Whether the configurations are equal. - """ - if not isinstance(other, EngineRunnerConfig): - return NotImplemented - return self.get_config() == other.get_config() - - def copy(self) -> EngineRunnerConfig: - """ - Create a deep copy of the configuration. - - Returns - ------- - config : EngineRunnerConfig - A deep copy of the configuration. 
- """ - return _copy.deepcopy(self) diff --git a/a3fe/configuration/engine_config.py b/a3fe/configuration/engine_config.py index c42198b..0f20f42 100644 --- a/a3fe/configuration/engine_config.py +++ b/a3fe/configuration/engine_config.py @@ -5,6 +5,7 @@ ] import os as _os +from decimal import Decimal as _Decimal import logging as _logging from typing import Dict as _Dict, Literal as _Literal, List as _List, Union as _Union, Optional as _Optional, Any as _Any from math import isclose as _isclose @@ -12,274 +13,193 @@ BaseModel as _BaseModel, Field as _Field, ConfigDict as _ConfigDict, - validator as _validator, - root_validator as _root_validator, - field_validator as _field_validator + field_validator as _field_validator, + model_validator as _model_validator, + ValidationInfo as _ValidationInfo ) - from ._engine_runner_config import EngineRunnerConfig as _EngineRunnerConfig - +DEFAULT_LAM_VALS_SOMD = { + "BOUND": { + "RESTRAIN": [0.0, 1.0], + "DISCHARGE": [0.0, 0.291, 0.54, 0.776, 1.0], + "VANISH": [ + 0.0, 0.026, 0.054, 0.083, 0.111, 0.14, 0.173, 0.208, + 0.247, 0.286, 0.329, 0.373, 0.417, 0.467, 0.514, + 0.564, 0.623, 0.696, 0.833, 1.0, + ], + }, + "FREE": { + "DISCHARGE": [0.0, 0.222, 0.447, 0.713, 1.0], + "VANISH": [ + 0.0, 0.026, 0.055, 0.09, 0.126, 0.164, 0.202, 0.239, + 0.276, 0.314, 0.354, 0.396, 0.437, 0.478, 0.518, + 0.559, 0.606, 0.668, 0.762, 1.0, + ], + }, +} class SomdConfig(_EngineRunnerConfig, _BaseModel): """ Pydantic model for holding SOMD engine configuration. 
- """ - + """ ### Integrator - ncycles modified as required by a3fe ### nmoves: int = _Field(25000, description="Number of moves per cycle") timestep: float = _Field(4.0, description="Timestep in femtoseconds(fs)") - runtime: _Union[int, float] = _Field(..., description="Runtime in nanoseconds(ns)") + runtime: _Union[int, float] = _Field(5.0, description="Runtime in nanoseconds(ns)") - input_dir: str = _Field(..., description="Input directory containing simulation config files") - @staticmethod - def _calculate_ncycles(runtime: float, time_per_cycle: float) -> int: - """ - Calculate the number of cycles given a runtime and time per cycle. + ### Constraints ### + constraint: str = _Field("hbonds", description="Constraint type, must be hbonds or allbonds") + hydrogen_mass_factor: float = _Field(3.0, alias="hydrogen mass repartitioning factor", description="Hydrogen mass repartitioning factor") + integrator: _Literal["langevinmiddle", "leapfrogverlet"] = _Field("langevinmiddle", description="Integration algorithm") + + ### Thermostatting already handled by langevin integrator + thermostat: bool = _Field(False, description="Enable thermostat") + inverse_friction: float = _Field(1.0, description="Inverse friction in picoseconds", alias="inverse friction") + temperature: float = _Field(25.0, description="Temperature in Celsius") - Parameters - ---------- - runtime : float - Runtime in nanoseconds. - time_per_cycle : float - Time per cycle in nanoseconds. - - Returns - ------- - int - Number of cycles. - """ - return round(runtime / time_per_cycle) + ### Barostat ### + barostat: bool = _Field(True, description="Enable barostat") + pressure: float = _Field(1.0, description="Pressure in atm") - def _set_n_cycles(self, n_cycles: int) -> None: - """ - Set the number of cycles in the SOMD config file. + ### Non-Bonded Interactions ### + cutoff_type: _Literal["cutoffperiodic", "PME"] = _Field("cutoffperiodic", description="Type of cutoff to use. 
Options: PME, cutoffperiodic") + cutoff_distance: float = _Field(12.0, alias="cutoff distance", ge=6.0, le=18.0, description="Cutoff distance in angstroms (6-18). Default 12.0 for cutoffperiodic and 10.0 for PME") + reaction_field_dielectric: float = _Field(78.3, alias="reaction field dielectric", + description="Reaction field dielectric constant (only for cutoffperiodic). " + "If cutoff type is PME, this value is ignored" + ) + ### Trajectory ### + buffered_coords_freq: int = _Field(5000,alias="buffered coordinates frequency",description="Frequency of buffered coordinates output") + center_solute: bool = _Field(True, alias="center solute", description="Center solute in box") - Parameters - ---------- - n_cycles : int - Number of cycles to set in the somd config file. - """ - with open(_os.path.join(self.input_dir, "somd.cfg"), "r") as ifile: - lines = ifile.readlines() - for i, line in enumerate(lines): - if line.startswith("ncycles ="): - lines[i] = "ncycles = " + str(n_cycles) + "\n" - break - #Now write the new file - with open(_os.path.join(self.input_dir, "somd.cfg"), "w+") as ofile: - for line in lines: - ofile.write(line) - - def _validate_runtime_and_update_config(self) -> None: - """ - Validate the runtime and update the simulation configuration. + ### Minimisation ### + minimise: bool = _Field(True, description="Perform energy minimisation") - Need to make sure that runtime is a multiple of the time per cycle - otherwise actual time could be quite different from requested runtime + ### Restraints ### + use_boresch_restraints: bool = _Field(False, description="Use Boresch restraints mode") + turn_on_receptor_ligand_restraints: bool = _Field(False, description="Turn on receptor-ligand restraints mode") - Raises - ------ - ValueError - If runtime is not a multiple of the time per cycle. 
- """ - time_per_cycle = self.timestep * self.nmoves / 1_000_000 # Convert fs to ns - # Convert both to float for division - remainder = float(self.runtime) / float(time_per_cycle) - if not _isclose(remainder - round(remainder), 0, abs_tol=1e-5): + ### Alchemistry - restraints added by a3fe ### + perturbed_residue_number: int = _Field(1,alias="perturbed residue number",ge=1, description="Residue number to perturb. Must be >= 1") + energy_frequency: int = _Field(200,alias="energy frequency",description="Frequency of energy output") + ligand_charge: int = _Field(0, description="Net charge of the ligand. If non-zero, must use PME for electrostatics.") + + ### Lambda ### + lambda_array: _List[float] = _Field(default_factory=list,description="Lambda array for alchemical perturbation, varies from 0.0 to 1.0 across stages") + lambda_val: _Optional[float] = _Field(None, description="Lambda value for current stage") + + ### Alchemical files ### + morphfile: _Optional[str] = _Field(None, description="Path to morph file containing alchemical transformation") + topfile: _Optional[str] = _Field(None, description="Path to topology file for the system") + crdfile: _Optional[str] = _Field(None, description="Path to coordinate file for the system") + + boresch_restraints_dictionary: _Optional[str] = _Field( + None, + #alias="boresch restraints dictionary", + description="Optional string to hold boresch restraints dictionary content" + ) + ### Extra options ### + extra_options: _Dict[str, str] = _Field(default_factory=dict, description="Extra options to pass to the SOMD engine") + + @_field_validator('runtime') + def validate_runtime(cls, v: float, info: _ValidationInfo) -> float: + """Validate that runtime is a multiple of time per cycle using Decimal for precise division""" + data = info.data + if not ('timestep' in data and 'nmoves' in data): + return v + + time_per_cycle = _Decimal(str(data['timestep'])) * _Decimal(str(data['nmoves'])) / _Decimal('1000000') + runtime_decimal = 
_Decimal(str(v)) + + if runtime_decimal % time_per_cycle != 0: raise ValueError( f"Runtime must be a multiple of the time per cycle. " - f"Runtime: {self.runtime} ns, Time per cycle: {time_per_cycle} ns." + f"Runtime: {v} ns, Time per cycle: {float(time_per_cycle)} ns" ) - - # Only try to modify the config file if it exists - cfg_file = _os.path.join(self.input_dir, "somd.cfg") - if _os.path.exists(cfg_file): - # Need to modify the config file to set the correction n_cycles - n_cycles = self._calculate_ncycles(self.runtime, time_per_cycle) - self._set_n_cycles(n_cycles) - print(f"Updated ncycles to {n_cycles} in somd.cfg") + return float(v) - constraint: str = _Field("hbonds", description="Constraint type, must be hbonds or all-bonds") + @_model_validator(mode="after") + def _check_cutoff_values(self): + """ + Issue warnings if the user supplies certain contradictory or unusual combos. + """ + cutoff_type = self.cutoff_type + cutoff_distance = self.cutoff_distance + rfd = self.reaction_field_dielectric + + # 1) Only warn if user set reaction_field_dielectric != 78.3 + if cutoff_type == "cutoffperiodic" and rfd != 78.3: + warnings.warn( + "You have cutoff_type=cutoffperiodic but set a reaction_field_dielectric. " + f"This value ({rfd}) will be ignored by the engine." + ) - @_validator('constraint') + # 2) Only warn if user picks e.g. cutoff_distance < 6 or > 18 + if cutoff_type == "PME" and not (6.0 <= cutoff_distance <= 18.0): + warnings.warn( + f"For PME, we recommend cutoff_distance in [6.0, 18.0], but you have {cutoff_distance}." + "we'll still accept it." 
+ ) + return self + + @_field_validator('constraint') def _check_constraint(cls, v): - if v not in ['hbonds', 'all-bonds']: - raise ValueError('constraint must be hbonds or all-bonds') + if v not in ['hbonds', 'allbonds']: + raise ValueError('constraint must be hbonds or allbonds') + return v + + @_field_validator('hydrogen_mass_factor') + def _check_hmf_range(cls, v): + if not (1.0 <= v <= 4.0): + raise ValueError('hydrogen_mass_factor must be between 1 and 4.') return v - hydrogen_mass_factor: float = _Field( - 3.0, - alias="hydrogen mass repartitioning factor", - description="Hydrogen mass repartitioning factor" - ) - integrator: _Literal["langevinmiddle", "leapfrogverlet"] = _Field("langevinmiddle", description="Integration algorithm") - # Thermostatting already handled by langevin integrator - thermostat: bool = _Field(False, description="Enable thermostat") - - @_root_validator(pre=True) - def _validate_integrator_thermostat(cls, v): - ''' - Make sure that if integrator is 'langevinmiddle' then thermostat must be False - ''' - integrator = v.get('integrator') - thermostat = v.get('thermostat', False) # Default to False if not provided - - if integrator == "langevinmiddle" and thermostat is not False: - raise ValueError("thermostat must be False when integrator is langevinmiddle") - return v + @_model_validator(mode="before") + def _validate_integrator_and_thermo(cls,v): + integrator = v.get("integrator") + thermostat = v.get("thermostat") + temperature = v.get("temperature", 25.0) # Use default value if None + pressure = v.get("pressure", 1.0) # Use default value if None - inverse_friction: float = _Field( - 1.0, - description="Inverse friction in picoseconds", - alias="inverse friction" - ) - temperature: float = _Field(25.0, description="Temperature in Celsius") + # 1) integrator='langevinmiddle' => thermostat must be False + # 2) integrator='leapfrogverlet' => thermostat must be True + if integrator == "langevinmiddle" and thermostat is True: + raise 
ValueError("If integrator is 'langevinmiddle', thermostat must be False.") + elif integrator == "leapfrogverlet" and thermostat is False: + raise ValueError("If integrator is 'leapfrogverlet', thermostat must be True.") - ### Barostat ### - barostat: bool = _Field(True, description="Enable barostat") - pressure: float = _Field(1.0, description="Pressure in atm") + # check temperature is in range [-200, 1000] + if not (-200 <= temperature <= 1000): + raise ValueError(f"Temperature must be between -200 and 1000 Celsius, got {temperature}") - ### Non-Bonded Interactions ### - cutoff_type: _Literal["PME", "cutoffperiodic"] = _Field( - "PME", - description="Type of cutoff to use. Options: PME, cutoffperiodic" - ) - cutoff_distance: float = _Field( - 10.0, # Default to PME cutoff distance - alias="cutoff distance", - description="Cutoff distance in angstroms" - ) - reaction_field_dielectric: float | None = _Field( - None, - alias="reaction field dielectric", - description="Reaction field dielectric constant(only for cutoffperiodic)" - ) - @_validator('cutoff_type') - def _check_cutoff_type(cls, v): - if v not in ['PME', 'cutoffperiodic']: - raise ValueError('cutoff type must be PME or cutoffperiodic') - return v - - @_validator('cutoff_distance', always=True) - def _set_cutoff_distance(cls, v, values): - if values.get('cutoff_type') == 'PME': - return 10.0 if v is None else v - elif values.get('cutoff_type') == 'cutoffperiodic': - return 12.0 if v is None else v - return v - - @_validator('reaction_field_dielectric',always=True) - def _set_reaction_field_dielectric(cls, v, values): - cutoff_type = values.get('cutoff_type') - if cutoff_type == 'PME' and v is not None: - raise ValueError('reaction field dielectric should not be provided when cutoff type is PME') - elif cutoff_type == 'cutoffperiodic' and v is None: - return 78.3 # Default dielectric constant for cutoffperiodic - return v - - @_field_validator("cutoff_distance", mode="before") - def 
_validate_cutoff_distance(cls, v, values): - """ - Validate cutoff distance based on cutoff type. - """ - cutoff_type = values.data.get("cutoff_type") - if cutoff_type == "cutoffperiodic": - return 12.0 # Default for cutoffperiodic - return v # Default for PME (10.0) + # check pressure is in range [0, 1000] atm + if not (0 <= pressure <= 1000): + raise ValueError("pressure must be in range [0, 1000] atm.") - model_config = _ConfigDict(arbitrary_types_allowed=True) - - def __init__(self, stream_log_level: int = _logging.INFO, **data): - _BaseModel.__init__(self, **data) - _EngineRunnerConfig.__init__(self, stream_log_level=stream_log_level) - self._validate_runtime_and_update_config() + return v - def get_config(self) -> _Dict[str, _Any]: + @_model_validator(mode="after") + def _check_charge_and_cutoff(self): """ - Get the SOMD configuration as a dictionary. - - Returns - ------- - config : Dict[str, Any] - The SOMD configuration dictionary. + Validate that if ligand_charge != 0, then cutoff_type must be PME. """ - return self.model_dump() + ligand_charge = self.ligand_charge + cutoff_type = self.cutoff_type + + if ligand_charge != 0 and cutoff_type != "PME": + raise ValueError( + "Ligand charge is non-zero. Must use PME for electrostatics." + ) + + return self def get_file_name(self) -> str: """ Get the name of the SOMD configuration file. - - Returns - ------- - file_name : str - The name of the SOMD configuration file. 
""" return "somd.cfg" - ### Trajectory ### - buffered_coords_freq: int = _Field( - 5000, - alias="buffered coordinates frequency", - description="Frequency of buffered coordinates output" - ) - center_solute: bool = _Field( - True, - alias="center solute", - description="Center solute in box" - ) - - ### Minimisation ### - minimise: bool = _Field(True, description="Perform energy minimisation") - - ### Restraints ### - use_boresch_restraints: bool = _Field( - False, - description="UseBoresch restraints mode" - ) - receptor_ligand_restraints: bool = _Field( - False, - description="Turn on receptor-ligand restraints mode" - ) - - ### Alchemistry - restraints added by a3fe ### - perturbed_residue_number: int = _Field( - 1, - alias="perturbed residue number", - description="Residue number to perturb" - ) - energy_frequency: int = _Field( - 200, - alias="energy frequency", - description="Frequency of energy output" - ) - - ### Lambda ### - lambda_array: _List[float] = _Field( - default_factory=list, - description="Lambda array for alchemical perturbation, varies from 0.0 to 1.0 across stages" - ) - lambda_val: _Optional[float] = _Field( - None, - description="Lambda value for current stage" - ) - - ### Alchemical files ### - morphfile: _Optional[str] = _Field( - None, description="Path to morph file containing alchemical transformation" - ) - topfile: _Optional[str] = _Field( - None, description="Path to topology file for the system" - ) - crdfile: _Optional[str] = _Field( - None, description="Path to coordinate file for the system" - ) - - extra_options: _Dict[str, str] = _Field( - default_factory=dict, - description="Extra options to pass to the SOMD engine" - ) - - def get_somd_config(self, run_dir: str, config_name: str = "somd_config") -> str: + def get_somd_config(self, run_dir: str) -> str: """ Generates the SOMD configuration file and returns its path. 
@@ -287,16 +207,8 @@ def get_somd_config(self, run_dir: str, config_name: str = "somd_config") -> str ---------- run_dir : str Directory to write the configuration file to. - - config_name : str, optional, default="somd_config" - Name of the configuration file to write. Note that when running many jobs from the - same directory, this should be unique to avoid overwriting the config file. - - Returns - ------- - str - Path to the generated configuration file. """ + config_filename = self.get_file_name() # First, generate the configuration string config_lines = [ "### Integrator ###", @@ -316,6 +228,11 @@ def get_somd_config(self, run_dir: str, config_name: str = "somd_config") -> str "### Non-Bonded Interactions ###", f"cutoff type = {self.cutoff_type}", f"cutoff distance = {self.cutoff_distance} * angstrom", + ] + if self.cutoff_type == "cutoffperiodic" and self.reaction_field_dielectric is not None: + config_lines.append(f"reaction field dielectric = {self.reaction_field_dielectric}") + + config_lines.extend([ "", "### Trajectory ###", f"buffered coordinates frequency = {self.buffered_coords_freq}", @@ -326,8 +243,36 @@ def get_somd_config(self, run_dir: str, config_name: str = "somd_config") -> str "", "### Alchemistry ###", f"perturbed residue number = {self.perturbed_residue_number}", - f"energy frequency = {self.energy_frequency}" - ] + f"energy frequency = {self.energy_frequency}", + f"ligand charge = {self.ligand_charge}", + "", + "### Restraints ###", + f"use boresch restraints = {self.use_boresch_restraints}", + f"turn on receptor-ligand restraints mode = {self.turn_on_receptor_ligand_restraints}" + ]) + # 2) Lambda parameters + config_lines.extend(["", "### Lambda / Alchemical Settings ###"]) + + if self.lambda_array: + lambda_str = ", ".join(str(x) for x in self.lambda_array) + config_lines.append(f"lambda array = {lambda_str}") + if self.lambda_val is not None: + config_lines.append(f"lambda_val = {self.lambda_val}") + + # 3) Alchemical files path + 
config_lines.extend(["", "### Alchemical Files ###"]) + if self.morphfile: + config_lines.append(f"morphfile = {self.morphfile}") + if self.topfile: + config_lines.append(f"topfile = {self.topfile}") + if self.crdfile: + config_lines.append(f"crdfile = {self.crdfile}") + + # 5) Boresch restraints + if self.boresch_restraints_dictionary is not None: + config_lines.extend(["", "### Boresch Restraints Dictionary ###"]) + config_lines.append(f"boresch restraints dictionary = {self.boresch_restraints_dictionary}") + # Add any extra options if self.extra_options: config_lines.extend(["", "### Extra Options ###"]) @@ -335,8 +280,51 @@ def get_somd_config(self, run_dir: str, config_name: str = "somd_config") -> str config_lines.append(f"{key} = {value}") # Write the configuration to a file - config_path = _os.path.join(run_dir, f"{config_name}.cfg") + config_path = _os.path.join(run_dir, config_filename) with open(config_path, "w") as f: f.write("\n".join(config_lines) + "\n") - + return config_path + + @classmethod + def _from_config_file(cls, config_path: str) -> "SomdConfig": + """Create a SomdConfig instance from an existing configuration file.""" + with open(config_path, "r") as f: + config_content = f.read() + + config_dict = {} + for line in config_content.split("\n"): + line = line.strip() + if line and not line.startswith("#") and "=" in line: + key, value = [x.strip() for x in line.split("=", 1)] + + # 处理 lambda array + if key == "lambda array": + value = [float(x.strip()) for x in value.split(",")] + # 处理带星号的值(如 "12*angstrom") + elif "*" in value: + value = value.split("*")[0].strip() + try: + value = float(value) + except ValueError: + pass + # 处理布尔值 + elif value.lower() == "true": + value = True + elif value.lower() == "false": + value = False + # 处理其他数值 + else: + try: + if "." 
in value: + value = float(value) + elif value.isdigit(): + value = int(value) + except ValueError: + pass # 保持为字符串 + + # 处理键名中的空格 + key = key.replace(" ", "_") + config_dict[key] = value + + return cls(**config_dict) \ No newline at end of file diff --git a/a3fe/tests/test_engine_configuration.py b/a3fe/tests/test_engine_configuration.py index 2f6ea2e..2273951 100644 --- a/a3fe/tests/test_engine_configuration.py +++ b/a3fe/tests/test_engine_configuration.py @@ -3,243 +3,171 @@ from tempfile import TemporaryDirectory import os import logging +import pytest from a3fe import SomdConfig -def create_test_input_dir(): - """Create a temporary directory with a mock somd.cfg file.""" - temp_dir = TemporaryDirectory() - with open(os.path.join(temp_dir.name, "somd.cfg"), "w") as f: - f.write("ncycles = 1000\n") - return temp_dir - -def test_create_config(): - """Test that the config can be created with different parameters.""" - with create_test_input_dir() as input_dir: - # Test with integer runtime - config = SomdConfig( - runtime=1, # Integer runtime - input_dir=input_dir - ) - assert isinstance(config, SomdConfig) - - # Test with float runtime - config = SomdConfig( - runtime=0.3, # Float runtime - input_dir=input_dir - ) - assert isinstance(config, SomdConfig) - - # Test with custom stream_log_level - config = SomdConfig( - runtime=1, - input_dir=input_dir, - stream_log_level=logging.DEBUG - ) - assert config._stream_log_level == logging.DEBUG - def test_config_yaml_save_and_load(): """Test that the config can be saved to and loaded from YAML.""" - with create_test_input_dir() as input_dir: - with TemporaryDirectory() as dirname: - config = SomdConfig( - runtime=1, # Integer runtime - input_dir=input_dir - ) - yaml_path = os.path.join(dirname, "config.yaml") - config.dump(yaml_path) - config2 = SomdConfig.load(yaml_path) - assert config == config2 + with TemporaryDirectory() as dirname: + config = SomdConfig(runtime=1) + yaml_path = os.path.join(dirname, "config.yaml") + 
config.dump(yaml_path) + config2 = SomdConfig.load(yaml_path) + assert config == config2 def test_get_somd_config(): """Test that the SOMD configuration file is generated correctly.""" - with create_test_input_dir() as input_dir: - with TemporaryDirectory() as dirname: - config = SomdConfig( - integrator="langevinmiddle", - nmoves=25000, - timestep=4.0, - runtime=1, # Integer runtime - input_dir=input_dir, - cutoff_type="PME", - thermostat=False - ) - config_path = config.get_somd_config( - run_dir=dirname, - config_name="test" - ) - assert config_path == os.path.join(dirname, "test.cfg") - - expected_config = ( - "### Integrator ###\n" - "nmoves = 25000\n" - "timestep = 4.0 * femtosecond\n" - "constraint = hbonds\n" - "hydrogen mass repartitioning factor = 3.0\n" - "integrator = langevinmiddle\n" - "inverse friction = 1.0 * picosecond\n" - "temperature = 25.0 * celsius\n" - "thermostat = False\n" - "\n" - "### Barostat ###\n" - "barostat = True\n" - "pressure = 1.0 * atm\n" - "\n" - "### Non-Bonded Interactions ###\n" - "cutoff type = PME\n" - "cutoff distance = 10.0 * angstrom\n" - "\n" - "### Trajectory ###\n" - "buffered coordinates frequency = 5000\n" - "center solute = True\n" - "\n" - "### Minimisation ###\n" - "minimise = True\n" - "\n" - "### Alchemistry ###\n" - "perturbed residue number = 1\n" - "energy frequency = 200\n" - ) - - with open(config_path, "r") as f: - config_content = f.read() - - assert config_content == expected_config - -def test_constraint_validation(): - """Test constraint type validation.""" - with create_test_input_dir() as input_dir: + with TemporaryDirectory() as dirname: config = SomdConfig( - constraint="hbonds", + integrator="langevinmiddle", + nmoves=25000, + timestep=4.0, runtime=1, # Integer runtime - input_dir=input_dir - ) - assert config.constraint == "hbonds" - - config = SomdConfig( - constraint="all-bonds", - runtime=0.3, # Float runtime - input_dir=input_dir - ) - assert config.constraint == "all-bonds" - - try: - 
SomdConfig( - constraint="invalid", - runtime=1, # Integer runtime - input_dir=input_dir + cutoff_type="PME", + thermostat=False ) - assert False, "Should raise ValueError" - except ValueError: - pass + config_path = config.get_somd_config(run_dir=dirname) + assert config_path == os.path.join(dirname, "somd.cfg") + + with open(config_path, "r") as f: + config_content = f.read() + + assert "integrator = langevinmiddle" in config_content + assert "nmoves = 25000" in config_content + assert "cutoff type = PME" in config_content + assert "thermostat = False" in config_content + +@pytest.mark.parametrize("integrator,thermostat,should_pass", [ + ("langevinmiddle", False, True), + ("langevinmiddle", True, False), + ("leapfrogverlet", True, True), + ("leapfrogverlet", False, False), +]) +def test_integrator_thermostat_validation(integrator, thermostat, should_pass): + """Test integrator and thermostat combination validation.""" + if should_pass: + config = SomdConfig(integrator=integrator, thermostat=thermostat, runtime=1) + assert config.integrator == integrator + assert config.thermostat == thermostat + else: + with pytest.raises(ValueError): + SomdConfig(integrator=integrator, thermostat=thermostat, runtime=1) + +@pytest.mark.parametrize("charge,cutoff,should_pass", [ + (0, "cutoffperiodic", True), + (0, "PME", True), + (1, "PME", True), + (-1, "PME", True), + (1, "cutoffperiodic", False), + (-1, "cutoffperiodic", False), +]) +def test_charge_cutoff_validation(charge, cutoff, should_pass): + """ + Test ligand charge & cutoff type combination validation: + if ligand_charge!=0 => must use PME. 
+ """ + if should_pass: + config = SomdConfig(ligand_charge=charge, cutoff_type=cutoff, runtime=1) + assert config.ligand_charge == charge + assert config.cutoff_type == cutoff + else: + with pytest.raises(ValueError): + SomdConfig(ligand_charge=charge, cutoff_type=cutoff, runtime=1) -def test_cutoff_type_validation(): - """Test cutoff type and related parameters validation.""" - with create_test_input_dir() as input_dir: - # Test PME cutoff type +def test_get_somd_config_with_extra_options(): + """ + Test SOMD config generation with some extra_options. + """ + with TemporaryDirectory() as dirname: config = SomdConfig( + integrator="langevinmiddle", + nmoves=25000, + timestep=4.0, + runtime=1, cutoff_type="PME", - runtime=1, # Integer runtime - input_dir=input_dir - ) - assert config.cutoff_type == "PME" - assert config.cutoff_distance == 10.0 - assert config.reaction_field_dielectric is None - - # Test cutoffperiodic type - config = SomdConfig( - cutoff_type="cutoffperiodic", - runtime=0.3, # Float runtime - input_dir=input_dir + thermostat=False, + extra_options={"custom_option": "value"} ) - assert config.cutoff_type == "cutoffperiodic" - assert config.cutoff_distance == 12.0 - assert config.reaction_field_dielectric == 78.3 - -def test_integrator_thermostat_validation(): - """Test integrator and thermostat validation.""" - with create_test_input_dir() as input_dir: - # Test valid configuration + path = config.get_somd_config(run_dir=dirname) + with open(path, "r") as f: + content = f.read() + assert "### Extra Options ###" in content + assert "custom_option = value" in content + +def test_compare_with_reference_config(): + """Test that we can generate a config file that matches a reference config.""" + reference_lines = [ + "nmoves = 25000", + "timestep = 4.0 * femtosecond", + "constraint = hbonds", + "hydrogen mass repartitioning factor = 3.0", + "integrator = langevinmiddle", + "inverse friction = 1.0 * picosecond", + "temperature = 25.0 * celsius", + 
"thermostat = False", + "barostat = True", + "pressure = 1.0 * atm", + "cutoff type = cutoffperiodic", + "cutoff distance = 12.0 * angstrom", + "reaction field dielectric = 78.3", + "buffered coordinates frequency = 5000", + "center solute = True", + "minimise = True", + "use boresch restraints = True", + "turn on receptor-ligand restraints mode = True", + "perturbed residue number = 1", + "energy frequency = 200", + "lambda array = 0.0, 0.125, 0.25, 0.375, 0.5, 1.0" + ] + with TemporaryDirectory() as dirname: config = SomdConfig( + nmoves=25000, + timestep=4.0, + runtime=1, + constraint="hbonds", + hydrogen_mass_factor=3.0, integrator="langevinmiddle", + inverse_friction=1.0, + temperature=25.0, thermostat=False, - runtime=1, # Integer runtime - input_dir=input_dir - ) - assert config.integrator == "langevinmiddle" - assert config.thermostat is False - - # Test invalid configuration - try: - SomdConfig( - integrator="langevinmiddle", - thermostat=True, - runtime=1, # Integer runtime - input_dir=input_dir - ) - assert False, "Should raise ValueError" - except ValueError: - pass - -def test_lambda_validation(): - """Test lambda array and value validation.""" - with create_test_input_dir() as input_dir: - config = SomdConfig( - lambda_array=[0.0, 0.5, 1.0], - lambda_val=0.5, - runtime=1, # Integer runtime - input_dir=input_dir - ) - assert config.lambda_array == [0.0, 0.5, 1.0] - assert config.lambda_val == 0.5 - -def test_restraints_configuration(): - """Test restraints configuration options.""" - with create_test_input_dir() as input_dir: - config = SomdConfig( + barostat=True, + pressure=1.0, + cutoff_type="cutoffperiodic", + cutoff_distance=12.0, + reaction_field_dielectric=78.3, + buffered_coords_freq=5000, + center_solute=True, + minimise=True, use_boresch_restraints=True, - receptor_ligand_restraints=True, - runtime=1, # Integer runtime - input_dir=input_dir - ) - assert config.use_boresch_restraints is True - assert config.receptor_ligand_restraints is True - 
-def test_alchemical_files(): - """Test alchemical transformation file paths.""" - with create_test_input_dir() as input_dir: - config = SomdConfig( - morphfile="/path/to/morph.pert", - topfile="/path/to/system.top", - crdfile="/path/to/system.crd", - runtime=1, # Integer runtime - input_dir=input_dir + turn_on_receptor_ligand_restraints=True, + perturbed_residue_number=1, + energy_frequency=200, + lambda_array=[0.0, 0.125, 0.25, 0.375, 0.5, 1.0] ) - assert config.morphfile == "/path/to/morph.pert" - assert config.topfile == "/path/to/system.top" - assert config.crdfile == "/path/to/system.crd" - -def test_get_somd_config_with_extra_options(): - """Test SOMD configuration generation with extra options.""" - with create_test_input_dir() as input_dir: - with TemporaryDirectory() as dirname: - config = SomdConfig( - integrator="langevinmiddle", - nmoves=25000, - timestep=4.0, - runtime=1, # Integer runtime - input_dir=input_dir, - cutoff_type="PME", - thermostat=False, - extra_options={"custom_option": "value"} - ) - config_path = config.get_somd_config( - run_dir=dirname, - config_name="test_extra" - ) - - with open(config_path, "r") as f: - config_content = f.read() - - assert "### Extra Options ###" in config_content - assert "custom_option = value" in config_content \ No newline at end of file + cfg_path = config.get_somd_config(run_dir=dirname) + with open(cfg_path, "r") as f: + cfg_content = f.read() + for line in reference_lines: + assert line in cfg_content, f"Expected '{line}' in generated config." 
+ +def test_copy_from_existing_config(): + """Test that we can copy from an existing somd.cfg file.""" + reference_config = "/home/roy/software/deve/a3fe/a3fe/data/example_run_dir/bound/discharge/output/lambda_0.000/run_01/somd.cfg" + if not os.path.isfile(reference_config): + pytest.skip("Reference config not found, skipping test.") + c = SomdConfig._from_config_file(reference_config) + + assert c.use_boresch_restraints is True + assert c.turn_on_receptor_ligand_restraints is False + assert c.topfile.endswith("somd.prm7") + expected_lambda = [0.0, 0.068, 0.137, 0.199, 0.261, 0.317, 0.368, 0.419, 0.472, + 0.524, 0.577, 0.627, 0.677, 0.727, 0.775, 0.824, 0.877, 0.938, 1.0] + assert c.lambda_array == expected_lambda # Boresch restraints dictionary + expected_boresch_dict = ( + '{"anchor_points":{"r1":4900, "r2":4888, "r3":4902, "l1":3, "l2":5, "l3":11}, ' + '"equilibrium_values":{"r0":7.67, "thetaA0":2.55, "thetaB0":1.48,"phiA0":-0.74, ' + '"phiB0":-1.53, "phiC0":3.09}, "force_constants":{"kr":3.74, "kthetaA":28.06, ' + '"kthetaB":9.98, "kphiA":16.70, "kphiB":24.63, "kphiC":5.52}}' + ) + assert c.boresch_restraints_dictionary == expected_boresch_dict From f0730aa4ec0c762f65b2b3bc75427a89eb1a1ec6 Mon Sep 17 00:00:00 2001 From: Roy-Haolin-Du Date: Sun, 12 Jan 2025 19:51:55 +0000 Subject: [PATCH 7/8] engine_confi and test --- a3fe/configuration/engine_config.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/a3fe/configuration/engine_config.py b/a3fe/configuration/engine_config.py index 0f20f42..535f572 100644 --- a/a3fe/configuration/engine_config.py +++ b/a3fe/configuration/engine_config.py @@ -298,22 +298,18 @@ def _from_config_file(cls, config_path: str) -> "SomdConfig": if line and not line.startswith("#") and "=" in line: key, value = [x.strip() for x in line.split("=", 1)] - # 处理 lambda array if key == "lambda array": value = [float(x.strip()) for x in value.split(",")] - # 处理带星号的值(如 "12*angstrom") elif "*" in value: value = 
value.split("*")[0].strip() try: value = float(value) except ValueError: pass - # 处理布尔值 elif value.lower() == "true": value = True elif value.lower() == "false": value = False - # 处理其他数值 else: try: if "." in value: @@ -321,9 +317,7 @@ def _from_config_file(cls, config_path: str) -> "SomdConfig": elif value.isdigit(): value = int(value) except ValueError: - pass # 保持为字符串 - - # 处理键名中的空格 + pass key = key.replace(" ", "_") config_dict[key] = value From 8d7b957bec466f63c9e947f17f18cdfc003a2783 Mon Sep 17 00:00:00 2001 From: Roy-Haolin-Du Date: Thu, 16 Jan 2025 15:27:00 +0000 Subject: [PATCH 8/8] only consult somd_config object, not read from files --- a3fe/configuration/engine_config.py | 109 ++++++++++++++------- a3fe/run/_simulation_runner.py | 92 ++---------------- a3fe/run/calculation.py | 19 +--- a3fe/run/enums.py | 8 ++ a3fe/run/lambda_window.py | 2 +- a3fe/run/leg.py | 61 +++++------- a3fe/run/simulation.py | 122 +++++++++++------------- a3fe/run/stage.py | 21 ++-- a3fe/tests/test_engine_configuration.py | 30 +++++- 9 files changed, 218 insertions(+), 246 deletions(-) diff --git a/a3fe/configuration/engine_config.py b/a3fe/configuration/engine_config.py index 535f572..6b3d8a5 100644 --- a/a3fe/configuration/engine_config.py +++ b/a3fe/configuration/engine_config.py @@ -6,42 +6,65 @@ import os as _os from decimal import Decimal as _Decimal -import logging as _logging -from typing import Dict as _Dict, Literal as _Literal, List as _List, Union as _Union, Optional as _Optional, Any as _Any -from math import isclose as _isclose +from typing import Dict as _Dict, Literal as _Literal, List as _List, Union as _Union, Optional as _Optional from pydantic import ( BaseModel as _BaseModel, Field as _Field, - ConfigDict as _ConfigDict, field_validator as _field_validator, model_validator as _model_validator, ValidationInfo as _ValidationInfo ) from ._engine_runner_config import EngineRunnerConfig as _EngineRunnerConfig -DEFAULT_LAM_VALS_SOMD = { - "BOUND": { - "RESTRAIN": 
[0.0, 1.0], - "DISCHARGE": [0.0, 0.291, 0.54, 0.776, 1.0], - "VANISH": [ - 0.0, 0.026, 0.054, 0.083, 0.111, 0.14, 0.173, 0.208, - 0.247, 0.286, 0.329, 0.373, 0.417, 0.467, 0.514, - 0.564, 0.623, 0.696, 0.833, 1.0, - ], - }, - "FREE": { - "DISCHARGE": [0.0, 0.222, 0.447, 0.713, 1.0], - "VANISH": [ - 0.0, 0.026, 0.055, 0.09, 0.126, 0.164, 0.202, 0.239, - 0.276, 0.314, 0.354, 0.396, 0.437, 0.478, 0.518, - 0.559, 0.606, 0.668, 0.762, 1.0, - ], - }, -} +def _get_default_lambda_array(leg: str = "BOUND", stage: str = "VANISH") -> _List[float]: + + default_values = { + "BOUND": { + "RESTRAIN": [0.0, 1.0], + "DISCHARGE": [0.0, 0.291, 0.54, 0.776, 1.0], + "VANISH": [ + 0.0, 0.026, 0.054, 0.083, 0.111, 0.14, 0.173, 0.208, + 0.247, 0.286, 0.329, 0.373, 0.417, 0.467, 0.514, + 0.564, 0.623, 0.696, 0.833, 1.0, + ], + }, + "FREE": { + "DISCHARGE": [0.0, 0.222, 0.447, 0.713, 1.0], + "VANISH": [ + 0.0, 0.026, 0.055, 0.09, 0.126, 0.164, 0.202, 0.239, + 0.276, 0.314, 0.354, 0.396, 0.437, 0.478, 0.518, + 0.559, 0.606, 0.668, 0.762, 1.0, + ], + }, + } + return default_values[leg][stage] + class SomdConfig(_EngineRunnerConfig, _BaseModel): """ Pydantic model for holding SOMD engine configuration. 
""" + default_lambda_values: _Dict[str, _Dict[str, _List[float]]] = _Field( + default={ + "BOUND": { + "RESTRAIN": [0.0, 1.0], + "DISCHARGE": [0.0, 0.291, 0.54, 0.776, 1.0], + "VANISH": [ + 0.0, 0.026, 0.054, 0.083, 0.111, 0.14, 0.173, 0.208, + 0.247, 0.286, 0.329, 0.373, 0.417, 0.467, 0.514, + 0.564, 0.623, 0.696, 0.833, 1.0, + ], + }, + "FREE": { + "DISCHARGE": [0.0, 0.222, 0.447, 0.713, 1.0], + "VANISH": [ + 0.0, 0.026, 0.055, 0.09, 0.126, 0.164, 0.202, 0.239, + 0.276, 0.314, 0.354, 0.396, 0.437, 0.478, 0.518, + 0.559, 0.606, 0.668, 0.762, 1.0, + ], + }, + }, + description="Default lambda values for each leg and stage type" + ) ### Integrator - ncycles modified as required by a3fe ### nmoves: int = _Field(25000, description="Number of moves per cycle") timestep: float = _Field(4.0, description="Timestep in femtoseconds(fs)") @@ -83,9 +106,13 @@ class SomdConfig(_EngineRunnerConfig, _BaseModel): perturbed_residue_number: int = _Field(1,alias="perturbed residue number",ge=1, description="Residue number to perturb. Must be >= 1") energy_frequency: int = _Field(200,alias="energy frequency",description="Frequency of energy output") ligand_charge: int = _Field(0, description="Net charge of the ligand. If non-zero, must use PME for electrostatics.") + charge_difference: int = _Field(0, description="Charge difference between the ligand and the system. 
If non-zero, must use co-alchemical ion approach.") ### Lambda ### - lambda_array: _List[float] = _Field(default_factory=list,description="Lambda array for alchemical perturbation, varies from 0.0 to 1.0 across stages") + lambda_array: _List[float] = _Field( + default_factory=_get_default_lambda_array, + description="Lambda array for alchemical perturbation, varies from 0.0 to 1.0 across stages" + ) lambda_val: _Optional[float] = _Field(None, description="Lambda value for current stage") ### Alchemical files ### @@ -129,19 +156,25 @@ def _check_cutoff_values(self): # 1) Only warn if user set reaction_field_dielectric != 78.3 if cutoff_type == "cutoffperiodic" and rfd != 78.3: - warnings.warn( + self._logger.warning( "You have cutoff_type=cutoffperiodic but set a reaction_field_dielectric. " f"This value ({rfd}) will be ignored by the engine." ) # 2) Only warn if user picks e.g. cutoff_distance < 6 or > 18 if cutoff_type == "PME" and not (6.0 <= cutoff_distance <= 18.0): - warnings.warn( - f"For PME, we recommend cutoff_distance in [6.0, 18.0], but you have {cutoff_distance}." + self._logger.warning( + f"For PME, we recommend cutoff_distance in [6.0, 18.0], but you have {cutoff_distance}. " "we'll still accept it." 
) return self - + + @_model_validator(mode="after") + def _check_charge_difference(self): + if self.charge_difference != 0 and self.cutoff_type != "PME": + raise ValueError("Charge difference is non-zero but cutoff type is not PME.") + return self + @_field_validator('constraint') def _check_constraint(cls, v): if v not in ['hbonds', 'allbonds']: @@ -244,7 +277,8 @@ def get_somd_config(self, run_dir: str) -> str: "### Alchemistry ###", f"perturbed residue number = {self.perturbed_residue_number}", f"energy frequency = {self.energy_frequency}", - f"ligand charge = {self.ligand_charge}", + f"ligand charge = {self.ligand_charge}", + f"charge difference = {self.charge_difference}", "", "### Restraints ###", f"use boresch restraints = {self.use_boresch_restraints}", @@ -254,8 +288,8 @@ def get_somd_config(self, run_dir: str) -> str: config_lines.extend(["", "### Lambda / Alchemical Settings ###"]) if self.lambda_array: - lambda_str = ", ".join(str(x) for x in self.lambda_array) - config_lines.append(f"lambda array = {lambda_str}") + lambda_str = f"[{', '.join(str(x) for x in self.lambda_array)}]" + config_lines.append(f"lambda_array = {lambda_str}") if self.lambda_val is not None: config_lines.append(f"lambda_val = {self.lambda_val}") @@ -321,4 +355,15 @@ def _from_config_file(cls, config_path: str) -> "SomdConfig": key = key.replace(" ", "_") config_dict[key] = value - return cls(**config_dict) \ No newline at end of file + return cls(**config_dict) + + def write_somd_config(self, output_dir: str, lam: float) -> str: + """ + Write out the SOMD configuration file with the current settings. 
+ """ + self.lambda_val = lam + + # Generate somd.cfg file using the current configuration + config_path = self.get_somd_config(run_dir=output_dir) + + return config_path diff --git a/a3fe/run/_simulation_runner.py b/a3fe/run/_simulation_runner.py index 84edee9..0bfbbb5 100644 --- a/a3fe/run/_simulation_runner.py +++ b/a3fe/run/_simulation_runner.py @@ -7,7 +7,6 @@ import os as _os import pathlib as _pathlib import pickle as _pkl -import shutil as _shutil import subprocess as _subprocess from abc import ABC from itertools import count as _count @@ -30,6 +29,7 @@ from ..configuration import SlurmConfig as _SlurmConfig from ..configuration import SomdConfig as _SomdConfig +from .._version import __version__ as _version class SimulationRunner(ABC): """An abstract base class for simulation runners. Note that @@ -100,6 +100,11 @@ def __init__( dump: bool, Optional, default: True If True, the state of the simulation runner is saved to a pickle file. """ + # Set the version of the simulation runner + self._logger = _logging.getLogger(self.__class__.__name__) + self._version = _version + self._logger.info(f"Initializing simulation runner with A3fe version: {self._version}") + # Set up the directories (which may be overwritten if the # simulation runner is subsequently loaded from a pickle file) # Make sure that we always use absolute paths @@ -176,13 +181,9 @@ def __init__( ) # Create the SOMD config with default values if none is provided - if engine_config is None: - self.engine_config = _SomdConfig( - runtime=2.5, # Default runtime of 2.5 ns - input_dir=self.input_dir # Use the simulation runner's input directory - ) - else: - self.engine_config = engine_config + self.engine_config = engine_config if engine_config is not None else _SomdConfig( + input_dir=self.input_dir # Use the simulation runner's input directory + ) # Save state if dump: @@ -1086,78 +1087,3 @@ def _load(self, update_paths: bool = True) -> None: # Record that the object was loaded from a pickle file 
self.loaded_from_pickle = True - - def _update_logging_options( - self, - stream_log_level: _Optional[int] = None, - file_log_level: _Optional[int] = None, - track_time: bool = False, - ) -> None: - """ - Update logging options for this simulation runner and all sub-runners recursively. - - Parameters - ---------- - stream_log_level : int, Optional - The log level for the stream handler. If None, keeps current level. - file_log_level : int, Optional - The log level for the file handler. If None, keeps current level. - track_time : bool, default=False - Whether to track time for logging operations. - """ - if stream_log_level is not None: - self.recursively_set_attr("stream_log_level", stream_log_level, force=True) - - if hasattr(self, "_logger"): - # Update existing logger settings - for handler in self._logger.handlers: - if isinstance(handler, _logging.StreamHandler) and stream_log_level is not None: - handler.setLevel(stream_log_level) - elif isinstance(handler, _logging.FileHandler) and file_log_level is not None: - handler.setLevel(file_log_level) - - # Add time tracking if requested - if track_time: - self._logger.track_time = True - - # Recursively update sub-runners - if hasattr(self, "_sub_sim_runners"): - for sub_runner in self._sub_sim_runners: - sub_runner._update_logging_options( - stream_log_level=stream_log_level, - file_log_level=file_log_level, - track_time=track_time - ) - - def _copy_to_test(self, test_dir: str) -> None: - """ - Copy generated files to a test directory for validation. - - Parameters - ---------- - test_dir : str - Path to the test directory where files should be copied. 
- """ - if not _os.path.exists(test_dir): - _os.makedirs(test_dir) - - self._logger.info(f"Copying generated files to test directory: {test_dir}") - - # Copy all files from the run directory to test directory - for root, _, files in _os.walk(self.run_dir): - for file in files: - src_path = _os.path.join(root, file) - # Get relative path from run_dir - rel_path = _os.path.relpath(src_path, self.run_dir) - dst_path = _os.path.join(test_dir, rel_path) - - # Create destination directory if it doesn't exist - dst_dir = _os.path.dirname(dst_path) - if not _os.path.exists(dst_dir): - _os.makedirs(dst_dir) - - # Copy the file - _shutil.copy2(src_path, dst_path) - self._logger.debug(f"Copied {rel_path} to test directory") - - self._logger.info("Finished copying files to test directory") diff --git a/a3fe/run/calculation.py b/a3fe/run/calculation.py index e9ae785..090e088 100644 --- a/a3fe/run/calculation.py +++ b/a3fe/run/calculation.py @@ -13,7 +13,6 @@ from .enums import LegType as _LegType from .enums import PreparationStage as _PreparationStage from .leg import Leg as _Leg -from .._version import __version__ as _version from ..configuration import ( SystemPreparationConfig as _SystemPreparationConfig, SlurmConfig as _SlurmConfig, @@ -100,10 +99,7 @@ def __init__( Returns ------- None - """ - # Store the version - self.version = _version - + """ super().__init__( base_dir=base_dir, input_dir=input_dir, @@ -130,8 +126,6 @@ def __init__( self._update_log() self._dump() - self._logger.info(f"Initializing calculation with A3fe version: {self.version}") - @property def legs(self) -> _List[_Leg]: return self._sub_sim_runners @@ -198,22 +192,13 @@ def setup( """ if self.setup_complete: - self._logger.debug("Setup already complete. Skipping...") + self._logger.info("Setup already complete. 
Skipping...") return configs = { _LegType.BOUND: bound_leg_sysprep_config, _LegType.FREE: free_leg_sysprep_config, } - - # Set up logging options for detailed setup tracking - self._update_logging_options( - stream_log_level=self.stream_log_level, - file_log_level=_logging.DEBUG, # Always keep detailed logs in file - track_time=True # Track time during setup - ) - - self._logger.info(f"Setting up calculation with A3fe version: {self.version}") self._logger.info("Starting calculation setup...") setup_start = _time.time() diff --git a/a3fe/run/enums.py b/a3fe/run/enums.py index 3ca7d4d..b40094b 100644 --- a/a3fe/run/enums.py +++ b/a3fe/run/enums.py @@ -82,6 +82,10 @@ def bss_perturbation_type(self) -> str: else: raise ValueError("Unknown stage type.") + @property + def config_key(self) -> str: + return self.name + class LegType(_YamlSerialisableEnum): """The type of leg in the calculation.""" @@ -89,6 +93,10 @@ class LegType(_YamlSerialisableEnum): BOUND = 1 FREE = 2 + @property + def config_key(self) -> str: + return self.name + class PreparationStage(_YamlSerialisableEnum): """The stage of preparation of the input files.""" diff --git a/a3fe/run/lambda_window.py b/a3fe/run/lambda_window.py index b8fae28..ed56781 100644 --- a/a3fe/run/lambda_window.py +++ b/a3fe/run/lambda_window.py @@ -195,7 +195,7 @@ def __init__( stream_log_level=stream_log_level, slurm_config=self.slurm_config, analysis_slurm_config=self.analysis_slurm_config, - engine_config=self.engine_config, + engine_config=self.engine_config.copy() if engine_config else None ) ) diff --git a/a3fe/run/leg.py b/a3fe/run/leg.py index 51b8a18..817bf05 100644 --- a/a3fe/run/leg.py +++ b/a3fe/run/leg.py @@ -253,7 +253,7 @@ def setup( system = self.run_ensemble_equilibration(sysprep_config=cfg) # Write input files - self.write_input_files(system, config=cfg) + self.setup_stages(system, config=cfg) # Make sure the stored restraints reflect the restraints used. 
TODO: # make this more robust my using the SOMD functionality to extract @@ -281,7 +281,7 @@ def setup( stream_log_level=self.stream_log_level, slurm_config=self.slurm_config, analysis_slurm_config=self.analysis_slurm_config, - engine_config=self.engine_config, + engine_config=self.engine_config.copy() if self.engine_config else None ) ) @@ -717,13 +717,13 @@ def run_ensemble_equilibration( else: # Free leg return pre_equilibrated_system - def write_input_files( + def setup_stages( self, pre_equilibrated_system: _BSS._SireWrappers._system.System, # type: ignore config: _SystemPreparationConfig, - ) -> None: + ) -> _Dict[_StageType, _SomdConfig]: """ - Write the required input files to all of the stage input directories. + Set up the SOMD configurations for each stage of the leg. Parameters ---------- @@ -732,24 +732,21 @@ def write_input_files( are then used as input for each of the individual runs. config: SystemPreparationConfig Configuration object for the setup of the leg. + + Returns + ------- + Dict[StageType, SomdConfig] + Dictionary mapping stage types to their SOMD configurations """ # Get the charge of the ligand lig = _get_single_mol(pre_equilibrated_system, "LIG") lig_charge = round(lig.charge().value()) - # If we have a charged ligand, make sure that SOMD is using PME - if lig_charge != 0: - cutoff_type = self.engine_config.cutoff_type - if cutoff_type != "PME": - raise ValueError( - f"The ligand has a non-zero charge ({lig_charge}), so SOMD must use PME for the electrostatics. " - "Please set the 'cutoff type' option in the engine_config to 'PME'." - ) - + if lig_charge != 0: self._logger.info( - f"Ligand has charge {lig_charge}. Using co-alchemical ion approach to maintain neutrality." + f"The ligand has a charge of {lig_charge}. Using co-alchemical ion approach to maintain neutrality. " + "Please note: The cutoff type should be PME and the cutoff length can be adjusted to other values, e.g., 10 Å." 
) - # Figure out where the ligand is in the system perturbed_resnum = pre_equilibrated_system.getIndex(lig) + 1 @@ -759,9 +756,10 @@ def write_input_files( if not hasattr(self, "stage_input_dirs"): raise AttributeError("No stage input directories have been set.") + stage_configs = {} for stage_type, stage_input_dir in self.stage_input_dirs.items(): self._logger.info( - f"Writing input files for {self.leg_type.name} leg {stage_type.name} stage" + f"Setting up {self.leg_type.name} leg {stage_type.name} stage" ) restraint = self.restraints[0] if self.leg_type == _LegType.BOUND else None protocol = _BSS.Protocol.FreeEnergy( @@ -805,31 +803,24 @@ def write_input_files( _shutil.copy( restraint_file, f"{stage_input_dir}/restraint_{i + 1}.txt" ) - - # Update the somd.cfg file with the perturbed residue number generated - # by BSS, as well as the restraints options - - # generate the somd.cfg file - somd_config = self.engine_config.get_somd_config( - run_dir=stage_input_dir, - config_name="somd" - ) + # Create a new config for this stage + stage_config = self.engine_config.copy() # Set configuration options - somd_config.perturbed_residue_number = perturbed_resnum - somd_config.use_boresch_restraints = self.leg_type == _LegType.BOUND - somd_config.turn_on_receptor_ligand_restraints_mode = self.leg_type == _LegType.BOUND - somd_config.charge_difference = -lig_charge # Use co-alchemical ion approach when there is a charge difference + stage_config.perturbed_residue_number = perturbed_resnum + stage_config.use_boresch_restraints = self.leg_type == _LegType.BOUND + stage_config.turn_on_receptor_ligand_restraints = self.leg_type == _LegType.BOUND + stage_config.charge_difference = -lig_charge # Use co-alchemical ion approach when there is a charge difference - # Set the default lambda windows based on the leg and stage types - lam_vals = config.lambda_values[self.leg_type][stage_type] - somd_config.lambda_array = lam_vals + # Set lambda values from default dictionary + 
stage_config.lambda_array = stage_config.default_lambda_values[self.leg_type.config_key][stage_type.config_key] - # Write the updated configuration - somd_config.write(stage_input_dir) + stage_configs[stage_type] = stage_config # We no longer need to store the large BSS restraint classes. self._lighten_restraints() + + return stage_configs def _run_slurm( self, sys_prep_fn: _Callable, wait: bool, run_dir: str, job_name: str diff --git a/a3fe/run/simulation.py b/a3fe/run/simulation.py index 6a96a12..68e3c9e 100644 --- a/a3fe/run/simulation.py +++ b/a3fe/run/simulation.py @@ -14,8 +14,6 @@ import numpy as _np from sire.units import k_boltz as _k_boltz -from ..read._process_somd_files import read_simfile_option as _read_simfile_option -from ..read._process_somd_files import write_simfile_option as _write_simfile_option from ._simulation_runner import SimulationRunner as _SimulationRunner from ._virtual_queue import Job as _Job from ._virtual_queue import VirtualQueue as _VirtualQueue @@ -132,7 +130,7 @@ def __init__( # as well as the restraints if necessary self._update_simfile() # Now read useful parameters from the simulation file options - self._add_attributes_from_simfile() + #self._add_attributes_from_simfile() # Get slurm file base self._get_slurm_file_base() @@ -213,26 +211,13 @@ def _validate_input(self) -> None: if not _os.path.isfile(_os.path.join(self.input_dir, file)): raise FileNotFoundError("Required input file " + file + " not found.") - def _add_attributes_from_simfile(self) -> None: - """ - Read the SOMD simulation option file and - add useful attributes to the Simulation object. - All times in ns. 
- """ - - timestep = None # ns - nrg_freq = None # number of timesteps between energy calculations - timestep = float( - _read_simfile_option(self.simfile_path, "timestep").split()[0] - ) # Need to remove femtoseconds from the end - nrg_freq = float(_read_simfile_option(self.simfile_path, "energy frequency")) - - self.timestep = timestep / 1_000_000 # fs to ns - self.nrg_freq = nrg_freq - def _select_input_files(self) -> None: """Select the correct rst7 and, if supplied, restraints, according to the run number.""" + + # First ensure the most up-to-date SOMD configuration is loaded + self.engine_config.get_somd_config(self.input_dir) + # Check if we have multiple rst7 files, or only one rst7_files = _glob.glob(_os.path.join(self.input_dir, "*.rst7")) if len(rst7_files) == 0: @@ -279,44 +264,7 @@ def _update_simfile(self) -> None: """Set the lambda value in the simulation file, as well as some paths to input files.""" - # Check that the lambda value has been set - if not hasattr(self, "lam"): - raise AttributeError("Lambda value not set for simulation") - - # Check that the set lambda value is in the list of lamvals in the simfile - lamvals = [ - float(lam_val) - for lam_val in _read_simfile_option( - self.simfile_path, "lambda array" - ).split(",") - ] - if self.lam not in lamvals: - raise ValueError( - f"Lambda value {self.lam} not in list of lambda values in simfile" - ) - - # Set the lambda value in the simfile - _write_simfile_option(self.simfile_path, "lambda_val", str(self.lam)) - - # Set the paths to the input files - input_paths = { - "morphfile": "somd.pert", - "topfile": "somd.prm7", - "crdfile": "somd.rst7", - } - - for option, name in input_paths.items(): - _write_simfile_option( - self.simfile_path, option, _os.path.join(self.input_dir, name) - ) - - # Add the restraints file if it exists - if _os.path.isfile(_os.path.join(self.input_dir, "restraint.txt")): - with open(_os.path.join(self.input_dir, "restraint.txt"), "r") as f: - restraint = 
f.readlines()[0].split("=")[1] - _write_simfile_option( - self.simfile_path, "boresch restraints dictionary", restraint - ) + self.engine_config.write_somd_config(self.output_dir, self.lam) def _get_slurm_file_base(self) -> None: """Find out what the slurm output file will be called and save it.""" @@ -338,6 +286,23 @@ def run(self, runtime: float = 2.5) -> None: ------- None """ + # update engine_config + self.engine_config.lambda_val = self.lam + self.engine_config.morphfile = _os.path.join(self.input_dir, "somd.pert") + self.engine_config.topfile = _os.path.join(self.input_dir, "somd.prm7") + self.engine_config.crdfile = _os.path.join(self.input_dir, "somd.rst7") + + # if restraint file exists, read and set + restraint_file = _os.path.join(self.input_dir, "restraint.txt") + if _os.path.isfile(restraint_file): + with open(restraint_file, "r") as f: + content = f.read().strip() + if "=" in content: + restraint = content.split("=", 1)[1].strip() + self.engine_config.boresch_restraints_dictionary = restraint + + # write config to output directory + self.engine_config.write_somd_config(self.output_dir, self.lam) # Run SOMD - note that command excludes sbatch as this is added by the virtual queue cmd = f"somd-freenrg -C somd.cfg -l {self.lam} -p CUDA" @@ -560,7 +525,7 @@ def read_gradients( def update_paths(self, old_sub_path: str, new_sub_path: str) -> None: """ Replace the old sub-path with the new sub-path in the base, input, and output directory - paths. Also update the slurm file base and the paths in the simfile. + paths. Also update the slurm file base. 
Parameters ---------- @@ -571,8 +536,7 @@ def update_paths(self, old_sub_path: str, new_sub_path: str) -> None: """ super().update_paths(old_sub_path, new_sub_path) - # Also need to update the slurm file base and the paths in the simfile - self.simfile_path = _os.path.join(self.base_dir, "somd.cfg") + # Also need to update the slurm file base if self.slurm_file_base: self.slurm_file_base = self.slurm_file_base.replace( old_sub_path, new_sub_path @@ -585,13 +549,41 @@ def update_paths(self, old_sub_path: str, new_sub_path: str) -> None: "crdfile": "somd.rst7", } for option, name in input_paths.items(): - _write_simfile_option( - self.simfile_path, option, _os.path.join(self.input_dir, name) - ) + setattr(self.engine_config, option, _os.path.join(self.input_dir, name)) + + # Write updated config to file + self.engine_config.write_somd_config(self.output_dir, self.lam) def set_simfile_option(self, option: str, value: str) -> None: """Set the value of an option in the simulation configuration file.""" - _write_simfile_option(self.simfile_path, option, value, logger=self._logger) + # Read the simfile and check if the option is already present + with open(self.simfile_path, "r") as f: + lines = f.readlines() + option_line_idx = None + for i, line in enumerate(lines): + if line.split("=")[0].strip() == option: + option_line_idx = i + break + + # If the option is not present, append it to the end of the file + if option_line_idx is None: + self._logger.warning( + f"Option {option} not found in simfile {self.simfile_path}. Appending new option to the end of the file." 
+ ) + + # Ensure the previous line ends with a newline + if lines[-1][-1] != "\n": + lines[-1] += "\n" + + lines.append(f"{option} = {value}\n") + + # Otherwise, replace the line with the new value + else: + lines[option_line_idx] = f"{option} = {value}\n" + + # Write the updated simfile + with open(self.simfile_path, "w") as f: + f.writelines(lines) def analyse(self) -> None: raise NotImplementedError( diff --git a/a3fe/run/stage.py b/a3fe/run/stage.py index 130d58d..7cca854 100644 --- a/a3fe/run/stage.py +++ b/a3fe/run/stage.py @@ -48,7 +48,6 @@ from ..analyse.plot import plot_rmsds as _plot_rmsds from ..analyse.plot import plot_sq_sem_convergence as _plot_sq_sem_convergence from ..analyse.process_grads import GradientData as _GradientData -from ..read._process_somd_files import write_simfile_option as _write_simfile_option from ._simulation_runner import SimulationRunner as _SimulationRunner from ._virtual_queue import VirtualQueue as _VirtualQueue from .enums import StageType as _StageType @@ -201,7 +200,7 @@ def __init__( stream_log_level=self.stream_log_level, slurm_config=self.slurm_config, analysis_slurm_config=self.analysis_slurm_config, - engine_config=self.engine_config.copy() if self.engine_config else None, + engine_config=self.engine_config.copy(), ) ) @@ -1164,9 +1163,9 @@ def _mv_output(self, save_name: str) -> None: _os.rename(self.output_dir, _os.path.join(base_dir, save_name)) def set_simfile_option(self, option: str, value: str) -> None: - """Set the value of an option in the simulation configuration file.""" - simfile = _os.path.join(self.input_dir, "somd.cfg") - _write_simfile_option(simfile, option, value, logger=self._logger) + + setattr(self.engine_config, option, value) + self.engine_config.get_somd_config(self.input_dir) super().set_simfile_option(option, value) def wait(self) -> None: @@ -1214,13 +1213,10 @@ def update(self, save_name: str = "output_saved") -> None: raise RuntimeError("Can't update while ensemble is running") if 
_os.path.isdir(self.output_dir): self._mv_output(save_name) - # Update the list of lambda windows in the simfile - _write_simfile_option( - simfile=f"{self.input_dir}/somd.cfg", - option="lambda array", - value=", ".join([str(lam) for lam in self.lam_vals]), - ) - # Store the previous lambda window attributes that we want to preserve + + self.engine_config.lambda_array = ", ".join([str(lam) for lam in self.lam_vals]) + self.engine_config.get_somd_config(self.input_dir) + old_lam_vals_attrs = self.lam_windows[0].__dict__ self._logger.info("Deleting old lambda windows and creating new ones...") self._sub_sim_runners = [] @@ -1239,6 +1235,7 @@ def update(self, save_name: str = "output_saved") -> None: stream_log_level=self.stream_log_level, slurm_config=self.slurm_config, analysis_slurm_config=self.analysis_slurm_config, + engine_config=self.engine_config.copy(), ) # Overwrite the default equilibration detection algorithm new_lam_win.check_equil = old_lam_vals_attrs["check_equil"] diff --git a/a3fe/tests/test_engine_configuration.py b/a3fe/tests/test_engine_configuration.py index 2273951..a7d7dcb 100644 --- a/a3fe/tests/test_engine_configuration.py +++ b/a3fe/tests/test_engine_configuration.py @@ -2,7 +2,6 @@ from tempfile import TemporaryDirectory import os -import logging import pytest from a3fe import SomdConfig @@ -75,6 +74,35 @@ def test_charge_cutoff_validation(charge, cutoff, should_pass): with pytest.raises(ValueError): SomdConfig(ligand_charge=charge, cutoff_type=cutoff, runtime=1) +def test_charge_difference_validation(): + """Test that charge difference validation works correctly.""" + + #test charge_difference=0, any cutoff_type + valid_config_cutoff = SomdConfig( + charge_difference=0, + cutoff_type="cutoffperiodic", + runtime=1 + ) + assert valid_config_cutoff.charge_difference == 0 + assert valid_config_cutoff.cutoff_type == "cutoffperiodic" + + + valid_config_charge = SomdConfig( + charge_difference=1, + cutoff_type="PME", + runtime=1 + ) + assert 
valid_config_charge.charge_difference == 1 + assert valid_config_charge.cutoff_type == "PME" + + with pytest.raises(ValueError): + SomdConfig( + charge_difference=1, + cutoff_type="cutoffperiodic", + runtime=1 + ) + + def test_get_somd_config_with_extra_options(): """ Test SOMD config generation with some extra_options.