Skip to content

Commit

Permalink
Update gencast mini demo notebook defaults.
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewlkd committed Dec 13, 2024
1 parent f71882b commit 10fa386
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions gencast_mini_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@
"## Load the model params\n",
"\n",
"Choose one of the two ways of getting model params:\n",
"- **random**: You'll get random predictions, but you can change the model architecture, which may run faster or fit on your device.\n",
"- **checkpoint**: You'll get sensible predictions, but are limited to the model architecture that it was trained with, which may not fit on your device.\n",
"- **random**: You'll get random predictions, but you can change the model architecture and data resolution which may run faster or fit on your device.\n",
"\n"
]
},
Expand Down Expand Up @@ -272,6 +272,8 @@
" options=[int(2**i) for i in range(0, 3)], value=4,description=\"Num heads:\")\n",
"random_attention_k_hop = widgets.Dropdown(\n",
" options=[int(2**i) for i in range(2, 5)], value=16,description=\"Attn k hop:\")\n",
"random_resolution = widgets.Dropdown(\n",
" options=[\"1p0\", \"0p25\"], value=\"1p0\", description=\"Resolution:\")\n",
"\n",
"def update_latent_options(*args):\n",
" def _latent_valid_for_attn(attn, latent, heads):\n",
Expand Down Expand Up @@ -299,17 +301,18 @@
" layout={\"width\": \"max-content\"})\n",
"\n",
"source_tab = widgets.Tab([\n",
" params_file,\n",
" widgets.VBox([\n",
" random_attention_type,\n",
" random_mesh_size,\n",
" random_num_heads,\n",
" random_latent_size,\n",
" random_attention_k_hop\n",
" random_attention_k_hop,\n",
" random_resolution\n",
" ]),\n",
" params_file,\n",
"])\n",
"source_tab.set_title(0, \"Random\")\n",
"source_tab.set_title(1, \"Checkpoint\")\n",
"source_tab.set_title(0, \"Checkpoint\")\n",
"source_tab.set_title(1, \"Random\")\n",
"widgets.VBox([\n",
" source_tab,\n",
" widgets.Label(value=\"Run the next cell to load the model. Rerunning this cell clears your selection.\")\n",
Expand Down Expand Up @@ -388,7 +391,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "5XGzOww0y_BC"
},
"outputs": [],
Expand All @@ -405,10 +407,11 @@
"\n",
"def data_valid_for_model(file_name: str, params_file_name: str):\n",
" \"\"\"Check data type and resolution matches.\"\"\"\n",
" if source == \"Random\":\n",
" return True\n",
" data_file_parts = parse_file_parts(file_name.removesuffix(\".nc\"))\n",
" res_matches = data_file_parts[\"res\"].replace(\".\", \"p\") in params_file_name.lower()\n",
" data_res = data_file_parts[\"res\"].replace(\".\", \"p\")\n",
" if source == \"Random\":\n",
" return random_resolution.value == data_res\n",
" res_matches = data_res in params_file_name.lower()\n",
" source_matches = \"Operational\" in params_file_name\n",
" if data_file_parts[\"source\"] == \"era5\":\n",
" source_matches = not source_matches\n",
Expand Down

0 comments on commit 10fa386

Please sign in to comment.