-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from openforcefield/openmm
Updates for OpenMM 7.6
- Loading branch information
Showing
5 changed files
with
279 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import ast | ||
import operator as op | ||
import warnings | ||
from typing import TYPE_CHECKING, List | ||
|
||
from openff.utilities import has_package, requires_package | ||
|
||
from openff.units import unit | ||
|
||
if has_package("openmm.unit") or TYPE_CHECKING: | ||
from openmm import unit as openmm_unit | ||
elif has_package("simtk.unit"): | ||
warnings.warn( | ||
"Found units module in simtk namespace, not openmm. Use openff.units.simtk instead." | ||
) | ||
|
||
|
||
@requires_package("openmm.unit") | ||
def openmm_unit_to_string(input_unit: "openmm_unit.Unit") -> str: | ||
""" | ||
Convert a openmm.unit.Unit to a string representation. | ||
Parameters | ||
---------- | ||
input_unit : A openmm.unit | ||
The unit to serialize | ||
Returns | ||
------- | ||
unit_string : str | ||
The serialized unit. | ||
""" | ||
if input_unit == openmm_unit.dimensionless: | ||
return "dimensionless" | ||
|
||
# Decompose output_unit into a tuples of (base_dimension_unit, exponent) | ||
unit_string = "" | ||
|
||
for unit_component in input_unit.iter_base_or_scaled_units(): | ||
unit_component_name = unit_component[0].name | ||
# Convert, for example "elementary charge" --> "elementary_charge" | ||
unit_component_name = unit_component_name.replace(" ", "_") | ||
if unit_component[1] == 1: | ||
contribution = "{}".format(unit_component_name) | ||
else: | ||
contribution = "{}**{}".format(unit_component_name, int(unit_component[1])) | ||
if unit_string == "": | ||
unit_string = contribution | ||
else: | ||
unit_string += " * {}".format(contribution) | ||
|
||
return unit_string | ||
|
||
|
||
def _ast_eval(node): | ||
""" | ||
Performs an algebraic syntax tree evaluation of a unit. | ||
Parameters | ||
---------- | ||
node : An ast parsing tree node | ||
""" | ||
|
||
operators = { | ||
ast.Add: op.add, | ||
ast.Sub: op.sub, | ||
ast.Mult: op.mul, | ||
ast.Div: op.truediv, | ||
ast.Pow: op.pow, | ||
ast.BitXor: op.xor, | ||
ast.USub: op.neg, | ||
} | ||
|
||
if isinstance(node, ast.Num): # <number> | ||
return node.n | ||
elif isinstance(node, ast.BinOp): # <left> <operator> <right> | ||
return operators[type(node.op)](_ast_eval(node.left), _ast_eval(node.right)) | ||
elif isinstance(node, ast.UnaryOp): # <operator> <operand> e.g., -1 | ||
return operators[type(node.op)](_ast_eval(node.operand)) | ||
elif isinstance(node, ast.Name): | ||
# see if this is a openmm unit | ||
b = getattr(openmm_unit, node.id) | ||
return b | ||
# TODO: This toolkit code that had a hack to cover some edge behavior; not clear which tests trigger it | ||
elif isinstance(node, ast.List): | ||
return ast.literal_eval(node) | ||
else: | ||
raise TypeError(node) | ||
|
||
|
||
def string_to_openmm_unit(unit_string: str) -> "openmm_unit.Quantity": | ||
""" | ||
Deserializes a openmm.unit.Quantity from a string representation, for | ||
example: "kilocalories_per_mole / angstrom ** 2" | ||
Parameters | ||
---------- | ||
unit_string : dict | ||
Serialized representation of a openmm.unit.Quantity. | ||
Returns | ||
------- | ||
output_unit: openmm.unit.Quantity | ||
The deserialized unit from the string | ||
""" | ||
|
||
output_unit = _ast_eval(ast.parse(unit_string, mode="eval").body) # type: ignore | ||
return output_unit | ||
|
||
|
||
@requires_package("openmm.unit") | ||
def from_openmm(openmm_quantity: "openmm_unit.Quantity"): | ||
if isinstance(openmm_quantity, List): | ||
openmm_quantity = openmm_unit.Quantity(openmm_quantity) | ||
openmm_unit_ = openmm_quantity.unit | ||
openmm_value = openmm_quantity.value_in_unit(openmm_unit_) | ||
|
||
target_unit = openmm_unit_to_string(openmm_unit_) | ||
target_unit = unit.Unit(target_unit) | ||
|
||
return openmm_value * target_unit | ||
|
||
|
||
@requires_package("openmm.unit") | ||
def to_openmm(quantity) -> "openmm_unit.Quantity": | ||
value = quantity.m | ||
|
||
unit_string = str(quantity.units._units) | ||
openmm_unit_ = string_to_openmm_unit(unit_string) | ||
|
||
return value * openmm_unit_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import pytest | ||
from openff.utilities.testing import skip_if_missing | ||
from openff.utilities.utilities import has_package | ||
|
||
from openff.units import unit | ||
from openff.units.openmm import from_openmm | ||
|
||
if has_package("openmm.unit"): | ||
from openmm import unit as openmm_unit | ||
|
||
from openff.units.openmm import ( | ||
openmm_unit_to_string, | ||
string_to_openmm_unit, | ||
to_openmm, | ||
) | ||
|
||
openmm_quantitites = [ | ||
4.0 * openmm_unit.nanometer, | ||
5.0 * openmm_unit.angstrom, | ||
1.0 * openmm_unit.elementary_charge, | ||
0.5 * openmm_unit.erg, | ||
1.0 * openmm_unit.dimensionless, | ||
] | ||
|
||
pint_quantities = [ | ||
4.0 * unit.nanometer, | ||
5.0 * unit.angstrom, | ||
1.0 * unit.elementary_charge, | ||
0.5 * unit.erg, | ||
1.0 * unit.dimensionless, | ||
] | ||
else: | ||
# Must be defined as something, despite not being used, because pytest | ||
# inspect the contents of pytest.mark.parametrize during collection; | ||
# otherwise NameErrors will be raised. Finding a way to skip _collection_ | ||
# would be more elegant than mocking a module | ||
class openmm_unit: # type: ignore[no-redef] | ||
kilojoule = 1 | ||
kilojoule_per_mole = 1 | ||
kilocalories_per_mole = 1 | ||
angstrom = 1 | ||
nanometer = 1 | ||
meter = 1 | ||
picosecond = 1 | ||
joule = 1 | ||
mole = 1 | ||
dimensionless = 1 | ||
second = 1 | ||
kelvin = 1 | ||
|
||
openmm_quantitites = [] | ||
pint_quantities = [] | ||
|
||
|
||
@pytest.mark.xfail | ||
@skip_if_missing("openmm.unit") | ||
class TestOpenMMUnits: | ||
@pytest.mark.parametrize( | ||
"openmm_quantity,pint_quantity", | ||
[(s, p) for s, p in zip(openmm_quantitites, pint_quantities)], | ||
) | ||
def test_openmm_to_pint(self, openmm_quantity, pint_quantity): | ||
"""Test conversion from OpenMM Quantity to pint Quantity.""" | ||
converted_pint_quantity = from_openmm(openmm_quantity) | ||
|
||
assert pint_quantity == converted_pint_quantity | ||
|
||
@skip_if_missing("openmm.unit") | ||
@pytest.mark.parametrize( | ||
"openmm_unit_,unit_str", | ||
[ | ||
(openmm_unit.kilojoule_per_mole, "mole**-1 * kilojoule"), | ||
( | ||
openmm_unit.kilocalories_per_mole / openmm_unit.angstrom ** 2, | ||
"angstrom**-2 * mole**-1 * kilocalorie", | ||
), | ||
( | ||
openmm_unit.joule / (openmm_unit.mole * openmm_unit.nanometer ** 2), | ||
"nanometer**-2 * mole**-1 * joule", | ||
), | ||
( | ||
openmm_unit.picosecond ** (-1), | ||
"picosecond**-1", | ||
), | ||
(openmm_unit.dimensionless, "dimensionless"), | ||
(openmm_unit.second, "second"), | ||
(openmm_unit.angstrom, "angstrom"), | ||
], | ||
) | ||
def test_openmm_unit_string_roundtrip(self, openmm_unit_, unit_str): | ||
assert openmm_unit_to_string(openmm_unit_) == unit_str | ||
|
||
assert unit_str == openmm_unit_to_string(string_to_openmm_unit(unit_str)) | ||
|
||
@pytest.mark.parametrize( | ||
"openff_quantity,openmm_quantity", | ||
[ | ||
(300.0 * unit.kelvin, 300.0 * openmm_unit.kelvin), | ||
( | ||
1.5 * unit.kilojoule, | ||
1.5 * openmm_unit.kilojoule, | ||
), | ||
(1.0 / unit.meter, 1.0 / openmm_unit.meter), | ||
], | ||
) | ||
def test_openmm_roundtrip(self, openff_quantity, openmm_quantity): | ||
assert openmm_quantity == to_openmm(openff_quantity) | ||
assert openff_quantity == from_openmm(openmm_quantity) | ||
|
||
assert openff_quantity == from_openmm(to_openmm(openff_quantity)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters