Skip to content

Commit

Permalink
dataset download in untils
Browse files Browse the repository at this point in the history
  • Loading branch information
Tetracarbonylnickel committed Mar 1, 2024
1 parent 2fbaf43 commit 4b0e11f
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 35 deletions.
4 changes: 2 additions & 2 deletions apax/utils/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import convert, data, jax_md_reduced, math, random
from . import convert, data, jax_md_reduced, math, random, datasets

__all__ = ["convert", "data", "math", "random", "jax_md_reduced"]
__all__ = ["convert", "data", "math", "random", "jax_md_reduced", datasets]
77 changes: 77 additions & 0 deletions apax/utils/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import os
import urllib
import zipfile

def download_md22_stachyose(data_path):
url = "http://www.quantum-machine.org/gdml/repo/static/md22_stachyose.zip"
file_path = data_path / "md22_stachyose.zip"

os.makedirs(data_path, exist_ok=True)
urllib.request.urlretrieve(url, file_path)

with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(data_path)

file_path = modify_xyz_file(
file_path.with_suffix(".xyz"), target_string="Energy", replacement_string="energy"
)

return file_path


def download_md17_benzene_DFT(data_path):
url = "http://www.quantum-machine.org/gdml/data/xyz/benzene2018_dft.zip"
file_path = data_path / "benzene2018_dft.zip"

os.makedirs(data_path, exist_ok=True)
urllib.request.urlretrieve(url, file_path)

with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(data_path)

new_file_path = data_path / "benzene.xyz"
os.remove(file_path)

return new_file_path


def download_md17_benzene_CCSDT(data_path):
url = "http://www.quantum-machine.org/gdml/data/xyz/benzene_ccsd_t.zip"
file_path = data_path / "benzene_ccsdt.zip"

os.makedirs(data_path, exist_ok=True)
urllib.request.urlretrieve(url, file_path)

with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(data_path)

train_file_path = data_path / "benzene_ccsd_t-train.xyz"
os.remove(file_path)

return train_file_path


def modify_xyz_file(file_path, target_string, replacement_string):
new_file_path = file_path.with_name(file_path.stem + "_mod" + file_path.suffix)

with open(file_path, "r") as input_file, open(new_file_path, "w") as output_file:
for line in input_file:
# Replace all occurrences of the target string with the replacement string
modified_line = line.replace(target_string, replacement_string)
output_file.write(modified_line)
return new_file_path


def mod_md17(file_path):
new_file_path = file_path.with_name(file_path.stem + "_mod" + file_path.suffix)
with open(file_path, "r") as input_file, open(new_file_path, "w") as output_file:
for line in input_file:
if line.startswith("-"):
modified_line = f"Properties=species:S:1:pos:R:3:forces:R:3 energy={line}"
else:
modified_line = line
output_file.write(modified_line)

os.remove(file_path)

return new_file_path
12 changes: 12 additions & 0 deletions apax/utils/helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import yaml


def setup_ase():
"""Add uncertainty keys to ASE all properties.
from https://github.com/zincware/IPSuite/blob/main/ipsuite/utils/helpers.py#L10
Expand All @@ -7,3 +10,12 @@ def setup_ase():
for val in ["forces_uncertainty", "energy_uncertainty", "stress_uncertainty"]:
if val not in all_properties:
all_properties.append(val)


def mod_config(config_path, updated_config):
with open(config_path.as_posix(), "r") as stream:
config_dict = yaml.safe_load(stream)

for key, new_value in updated_config.items():
config_dict[key].update(new_value)
return config_dict
38 changes: 5 additions & 33 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import urllib
import zipfile
from typing import List

import jax
Expand All @@ -15,7 +13,8 @@
from apax.model.builder import ModelBuilder
from apax.train.run import run
from apax.utils.random import seed_py_np_tf

from apax.utils.datasets import download_md22_stachyose
from apax.utils.helpers import mod_config

@pytest.fixture(autouse=True)
def set_radom_seeds():
Expand Down Expand Up @@ -114,33 +113,10 @@ def tmp_data_path(tmp_path_factory):

@pytest.fixture(scope="session")
def get_md22_stachyose(tmp_data_path):
url = "http://www.quantum-machine.org/gdml/repo/static/md22_stachyose.zip"
file_path = tmp_data_path / "md22_stachyose.zip"

os.makedirs(tmp_data_path, exist_ok=True)
urllib.request.urlretrieve(url, file_path)

with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(tmp_data_path)

file_path = modify_xyz_file(
file_path.with_suffix(".xyz"), target_string="Energy", replacement_string="energy"
)

file_path = download_md22_stachyose(tmp_data_path)
return file_path


def modify_xyz_file(file_path, target_string, replacement_string):
new_file_path = file_path.with_name(file_path.stem + "_mod" + file_path.suffix)

with open(file_path, "r") as input_file, open(new_file_path, "w") as output_file:
for line in input_file:
# Replace all occurrences of the target string with the replacement string
modified_line = line.replace(target_string, replacement_string)
output_file.write(modified_line)
return new_file_path


@pytest.fixture()
def get_sample_input():
positions = np.array([
Expand Down Expand Up @@ -179,9 +155,5 @@ def load_and_dump_config(config_path, dump_path):


def load_config_and_run_training(config_path, updated_config):
with open(config_path.as_posix(), "r") as stream:
config_dict = yaml.safe_load(stream)

for key, new_value in updated_config.items():
config_dict[key].update(new_value)
run(config_dict)
config_dict = mod_config(config_path, updated_config)
run(config_dict)

0 comments on commit 4b0e11f

Please sign in to comment.