Skip to content

Commit

Permalink
Merge pull request #6 from openforcefield/openmm
Browse files Browse the repository at this point in the history
Updates for OpenMM 7.6
  • Loading branch information
mattwthompson authored Sep 7, 2021
2 parents 8ad7934 + 22ef0ce commit cae8aea
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 2 deletions.
7 changes: 6 additions & 1 deletion devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ dependencies:
- pytest
- pytest-cov
- codecov
- mypy
- uncertainties
- openmm

# Typing
- mypy
- typing-extensions
- types-setuptools
- types-pkg_resources
132 changes: 132 additions & 0 deletions openff/units/openmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import ast
import operator as op
import warnings
from typing import TYPE_CHECKING, List

from openff.utilities import has_package, requires_package

from openff.units import unit

if has_package("openmm.unit") or TYPE_CHECKING:
from openmm import unit as openmm_unit
elif has_package("simtk.unit"):
warnings.warn(
"Found units module in simtk namespace, not openmm. Use openff.units.simtk instead."
)


@requires_package("openmm.unit")
def openmm_unit_to_string(input_unit: "openmm_unit.Unit") -> str:
"""
Convert a openmm.unit.Unit to a string representation.
Parameters
----------
input_unit : A openmm.unit
The unit to serialize
Returns
-------
unit_string : str
The serialized unit.
"""
if input_unit == openmm_unit.dimensionless:
return "dimensionless"

# Decompose output_unit into a tuples of (base_dimension_unit, exponent)
unit_string = ""

for unit_component in input_unit.iter_base_or_scaled_units():
unit_component_name = unit_component[0].name
# Convert, for example "elementary charge" --> "elementary_charge"
unit_component_name = unit_component_name.replace(" ", "_")
if unit_component[1] == 1:
contribution = "{}".format(unit_component_name)
else:
contribution = "{}**{}".format(unit_component_name, int(unit_component[1]))
if unit_string == "":
unit_string = contribution
else:
unit_string += " * {}".format(contribution)

return unit_string


def _ast_eval(node):
"""
Performs an algebraic syntax tree evaluation of a unit.
Parameters
----------
node : An ast parsing tree node
"""

operators = {
ast.Add: op.add,
ast.Sub: op.sub,
ast.Mult: op.mul,
ast.Div: op.truediv,
ast.Pow: op.pow,
ast.BitXor: op.xor,
ast.USub: op.neg,
}

if isinstance(node, ast.Num): # <number>
return node.n
elif isinstance(node, ast.BinOp): # <left> <operator> <right>
return operators[type(node.op)](_ast_eval(node.left), _ast_eval(node.right))
elif isinstance(node, ast.UnaryOp): # <operator> <operand> e.g., -1
return operators[type(node.op)](_ast_eval(node.operand))
elif isinstance(node, ast.Name):
# see if this is a openmm unit
b = getattr(openmm_unit, node.id)
return b
# TODO: This toolkit code that had a hack to cover some edge behavior; not clear which tests trigger it
elif isinstance(node, ast.List):
return ast.literal_eval(node)
else:
raise TypeError(node)


def string_to_openmm_unit(unit_string: str) -> "openmm_unit.Quantity":
"""
Deserializes a openmm.unit.Quantity from a string representation, for
example: "kilocalories_per_mole / angstrom ** 2"
Parameters
----------
unit_string : dict
Serialized representation of a openmm.unit.Quantity.
Returns
-------
output_unit: openmm.unit.Quantity
The deserialized unit from the string
"""

output_unit = _ast_eval(ast.parse(unit_string, mode="eval").body) # type: ignore
return output_unit


@requires_package("openmm.unit")
def from_openmm(openmm_quantity: "openmm_unit.Quantity"):
if isinstance(openmm_quantity, List):
openmm_quantity = openmm_unit.Quantity(openmm_quantity)
openmm_unit_ = openmm_quantity.unit
openmm_value = openmm_quantity.value_in_unit(openmm_unit_)

target_unit = openmm_unit_to_string(openmm_unit_)
target_unit = unit.Unit(target_unit)

return openmm_value * target_unit


@requires_package("openmm.unit")
def to_openmm(quantity) -> "openmm_unit.Quantity":
value = quantity.m

unit_string = str(quantity.units._units)
openmm_unit_ = string_to_openmm_unit(unit_string)

return value * openmm_unit_
30 changes: 29 additions & 1 deletion openff/units/simtk.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,42 @@
import ast
import operator as op
import warnings
from typing import TYPE_CHECKING, List

from openff.utilities import has_package, requires_package

from openff.units import unit

if has_package("simtk.unit") or TYPE_CHECKING:
simtk_to_openmm = {
"from_simtk": "from_openmm",
"to_simtk": "to_openmm",
"simtk_unit_to_string": "openmm_unit_to_string",
"string_to_simtk_unit": "simtk_unit_to_string",
}


def __getattr__(name):
if has_package("openmm.unit"):
if name in simtk_to_openmm.keys():
warnings.warn(
"Found units module in openmm namespace, not simtk. Returning a "
"corresponding method in openff.units.openmm, but you should import "
"from that module directly."
)

return simtk_to_openmm[name]
raise AttributeError(f"module {__name__} has no attribute {name}")


if TYPE_CHECKING:
from simtk import unit as simtk_unit
elif has_package("simtk.unit"):
from simtk import unit as simtk_unit

warnings.warn(
"The openff.units.simtk module is deprecated. Use openff.units.openmm instead."
)


@requires_package("simtk.unit")
def simtk_unit_to_string(input_unit: "simtk_unit.Unit") -> str:
Expand Down
110 changes: 110 additions & 0 deletions openff/units/tests/test_openmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pytest
from openff.utilities.testing import skip_if_missing
from openff.utilities.utilities import has_package

from openff.units import unit
from openff.units.openmm import from_openmm

if has_package("openmm.unit"):
from openmm import unit as openmm_unit

from openff.units.openmm import (
openmm_unit_to_string,
string_to_openmm_unit,
to_openmm,
)

openmm_quantitites = [
4.0 * openmm_unit.nanometer,
5.0 * openmm_unit.angstrom,
1.0 * openmm_unit.elementary_charge,
0.5 * openmm_unit.erg,
1.0 * openmm_unit.dimensionless,
]

pint_quantities = [
4.0 * unit.nanometer,
5.0 * unit.angstrom,
1.0 * unit.elementary_charge,
0.5 * unit.erg,
1.0 * unit.dimensionless,
]
else:
# Must be defined as something, despite not being used, because pytest
# inspect the contents of pytest.mark.parametrize during collection;
# otherwise NameErrors will be raised. Finding a way to skip _collection_
# would be more elegant than mocking a module
class openmm_unit: # type: ignore[no-redef]
kilojoule = 1
kilojoule_per_mole = 1
kilocalories_per_mole = 1
angstrom = 1
nanometer = 1
meter = 1
picosecond = 1
joule = 1
mole = 1
dimensionless = 1
second = 1
kelvin = 1

openmm_quantitites = []
pint_quantities = []


@pytest.mark.xfail
@skip_if_missing("openmm.unit")
class TestOpenMMUnits:
@pytest.mark.parametrize(
"openmm_quantity,pint_quantity",
[(s, p) for s, p in zip(openmm_quantitites, pint_quantities)],
)
def test_openmm_to_pint(self, openmm_quantity, pint_quantity):
"""Test conversion from OpenMM Quantity to pint Quantity."""
converted_pint_quantity = from_openmm(openmm_quantity)

assert pint_quantity == converted_pint_quantity

@skip_if_missing("openmm.unit")
@pytest.mark.parametrize(
"openmm_unit_,unit_str",
[
(openmm_unit.kilojoule_per_mole, "mole**-1 * kilojoule"),
(
openmm_unit.kilocalories_per_mole / openmm_unit.angstrom ** 2,
"angstrom**-2 * mole**-1 * kilocalorie",
),
(
openmm_unit.joule / (openmm_unit.mole * openmm_unit.nanometer ** 2),
"nanometer**-2 * mole**-1 * joule",
),
(
openmm_unit.picosecond ** (-1),
"picosecond**-1",
),
(openmm_unit.dimensionless, "dimensionless"),
(openmm_unit.second, "second"),
(openmm_unit.angstrom, "angstrom"),
],
)
def test_openmm_unit_string_roundtrip(self, openmm_unit_, unit_str):
assert openmm_unit_to_string(openmm_unit_) == unit_str

assert unit_str == openmm_unit_to_string(string_to_openmm_unit(unit_str))

@pytest.mark.parametrize(
"openff_quantity,openmm_quantity",
[
(300.0 * unit.kelvin, 300.0 * openmm_unit.kelvin),
(
1.5 * unit.kilojoule,
1.5 * openmm_unit.kilojoule,
),
(1.0 / unit.meter, 1.0 / openmm_unit.meter),
],
)
def test_openmm_roundtrip(self, openff_quantity, openmm_quantity):
assert openmm_quantity == to_openmm(openff_quantity)
assert openff_quantity == from_openmm(openmm_quantity)

assert openff_quantity == from_openmm(to_openmm(openff_quantity))
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,5 @@ ignore_missing_imports = True
[mypy-simtk]
ignore_missing_imports = True

[mypy-openmm]
ignore_missing_imports = True

0 comments on commit cae8aea

Please sign in to comment.