Skip to content

Commit

Permalink
example 01 update
Browse files Browse the repository at this point in the history
  • Loading branch information
Tetracarbonylnickel committed Mar 1, 2024
1 parent 4b0e11f commit 95ee47b
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 19 deletions.
132 changes: 113 additions & 19 deletions examples/01_Model_Training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,23 @@
"\n",
"## Acquiring a dataset\n",
"\n",
"# TODO tmpdir\n",
"\n",
"```bash\n",
"mkdir project\n",
"cd project\n",
"```\n",
"You can obtain the benzene dataset either by running the following command or manually from this website.\n",
"\n",
"`curl ... ...`\n",
"\n",
"apax uses ASE to read in datasets, so make sure to convert your own data into an ASE readable format (extxyz, traj etc)."
"You can obtain the benzene dataset with DFT labels either by running the following command or manually from this [link](http://www.quantum-machine.org/gdml/data/xyz/benzene2018_dft.zip). Apax uses ASE to read in datasets, so make sure to convert your own data into an ASE readable format (extxyz, traj etc). Be carefull the downloaded dataset has to be modified like in the `apax.untils.dataset.mop_md17` function in order to be readable."
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"TODO Dataset splitting"
"from pathlib import Path\n",
"from apax.utils.datasets import download_md17_benzene_DFT, mod_md17\n",
"\n",
"data_path = Path(\"project\")\n",
"\n",
"file_path = download_md17_benzene_DFT(data_path)\n",
"file_path = mod_md17(file_path)\n",
"\n"
]
},
{
Expand All @@ -41,12 +40,99 @@
"In order to get users quickly up and running, our command line interface provides an easy way to generate input templates.\n",
"The provided templates come in in two levels of verbosity: minimal and full.\n",
"In the following we are going to use a minimal input file. To see a complete list and explanation of all parameters, consult the documentation page LINK.\n",
"For more information on the CLI, simply run `apax -h`.\n",
"For more information on the CLI, simply run `apax -h`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/linux3_i1/segreto/miniconda3/envs/apax/lib/python3.11/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
" pid, fd = os.forkpty()\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m \u001b[0m\n",
"\u001b[1m \u001b[0m\u001b[1;33mUsage: \u001b[0m\u001b[1mapax [OPTIONS] COMMAND [ARGS]...\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m\n",
"\u001b[1m \u001b[0m\n",
"\u001b[2m╭─\u001b[0m\u001b[2m Options \u001b[0m\u001b[2m───────────────────────────────────────────────────────────────────\u001b[0m\u001b[2m─╮\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36m-\u001b[0m\u001b[1;36m-version\u001b[0m \u001b[1;32m-V\u001b[0m \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36m-\u001b[0m\u001b[1;36m-install\u001b[0m\u001b[1;36m-completion\u001b[0m Install completion for the current shell. \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36m-\u001b[0m\u001b[1;36m-show\u001b[0m\u001b[1;36m-completion\u001b[0m Show completion for the current shell, to \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m copy it or customize the installation. \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36m-\u001b[0m\u001b[1;36m-help\u001b[0m \u001b[1;32m-h\u001b[0m Show this message and exit. \u001b[2m│\u001b[0m\n",
"\u001b[2m╰──────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
"\u001b[2m╭─\u001b[0m\u001b[2m Commands \u001b[0m\u001b[2m──────────────────────────────────────────────────────────────────\u001b[0m\u001b[2m─╮\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36mdocs \u001b[0m\u001b[1;36m \u001b[0m Opens the documentation website in your browser. \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36meval \u001b[0m\u001b[1;36m \u001b[0m Starts performing the evaluation of the test dataset with \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36m \u001b[0m parameters provided by a configuration file. \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36mmd \u001b[0m\u001b[1;36m \u001b[0m Starts performing a molecular dynamics simulation (currently only \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36m \u001b[0m NHC thermostat) with parameters provided by a configuration file. \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36mtemplate \u001b[0m\u001b[1;36m \u001b[0m Create configuration file templates. \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36mtrain \u001b[0m\u001b[1;36m \u001b[0m Starts the training of a model with parameters provided by a \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36m \u001b[0m configuration file. \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36mvalidate \u001b[0m\u001b[1;36m \u001b[0m Validate training or MD config files. \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36mvisualize\u001b[0m\u001b[1;36m \u001b[0m Visualize a model based on a configuration file. A CO molecule is \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36m \u001b[0m taken as sample input (influences number of atoms, number of \u001b[2m│\u001b[0m\n",
"\u001b[2m│\u001b[0m \u001b[1;36m \u001b[0m species is set to 10). \u001b[2m│\u001b[0m\n",
"\u001b[2m╰──────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
"\n"
]
}
],
"source": [
"!apax -h"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"There is already a config file in the working directory.\n"
]
}
],
"source": [
"!apax template train"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from apax.utils.helpers import mod_config\n",
"\n",
"apax template train --minimal\n",
"config_path = Path(\"config.yaml\")\n",
"\n",
"Open the resulting `config_minimal.yaml` file in an editor of your choice and make sure to fill in the data path field with the name of the data set you just downloaded.\n",
"For the purposes of this tutorial we will train on 1000 data points and validate the model on 200 more during the training.\n",
"config_updates = {\n",
" \"data\": {\n",
" \"energy_unit\": \"kcal/mol\",\n",
" }\n",
"}\n",
"config_dict = mod_config(config_path, config_updates)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Open the resulting `config.yaml` file in an editor of your choice and make sure to fill in the data path field with the name of the data set you just downloaded.\n",
"For the purposes of this tutorial we will train on 1000 data points and validate the model on 200 more during the training. Random splitting is done by apax but it is also possible to input a pre-splitted training and validation dataset\n",
"\n",
"The filled in configuration file should look similar to this one.\n",
"\n",
Expand Down Expand Up @@ -105,7 +191,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -153,8 +239,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.11.5"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
Expand Down
21 changes: 21 additions & 0 deletions examples/03_Transfer_learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,27 @@
"We can now fine tune the model by running\n",
"`apax train config.yaml`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from apax.utils.datasets import download_md17_benzene_CCSDT, mod_md17\n",
"import os\n",
"\n",
"data_path = Path(\"project\")\n",
"file_path = download_md17_benzene_CCSDT(data_path)\n",
"os.remove(data_path / \"benzene_ccsd_t-test.xyz\")\n",
"\n",
"file_path = mod_md17(file_path)"
]
}
],
"metadata": {
Expand Down

0 comments on commit 95ee47b

Please sign in to comment.