Skip to content

Commit

Permalink
Update cloud_vm_setup with fuller explanations of TPU vs GPU comparis…
Browse files Browse the repository at this point in the history
…ons.
  • Loading branch information
andrewlkd committed Dec 13, 2024
1 parent 10fa386 commit 273af28
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
24 changes: 14 additions & 10 deletions docs/cloud_vm_setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,23 +74,27 @@ This document describes how to run `gencast_demo_cloud_vm.ipynb` through [Colabo
denoiser_architecture_config.sparse_transformer_config.attention_type = "triblockdiag_mha"
denoiser_architecture_config.sparse_transformer_config.mask_type = "full
```
- We tried running the model on a H100 with this attention mechanism and found that
while the performance is comparable, there is a small degradation (on average ~0.3% on unbiased Ensemble Mean RMSE and ~0.4% on unbiased CRPS). We suspect that
this originates from the attention mechanisms being algebraically equivalent, but
not numerically equivalent. For reference, a scorecard comparing GenCast
forecasts produced on a TPUv4 with `splash attention` vs. on a H100 with `triblockdiag_mha`
can be found in [docs/](https://github.com/google-deepmind/graphcast/blob/main/docs/GenCast_0p25deg_accelerator_scorecard.png).
Note that this scorecard differs from those found in the GenCast paper
in a number of ways:
**Skill comparison vs. TPU**
- We tried running the model on a H100 using the `triblockdiag_mha` attention implementation and found that, while the performance is comparable, there is a small degradation (on average ~0.3% on unbiased Ensemble Mean RMSE and ~0.4% on unbiased CRPS). For reference, a scorecard comparing GenCast forecasts produced on a TPUv4 with `splash attention` vs. on a H100 with `triblockdiag_mha` can be found in [docs/](https://github.com/google-deepmind/graphcast/blob/main/docs/GenCast_0p25deg_accelerator_scorecard.png). Note that this scorecard differs from those found in the GenCast paper in a number of ways:
- 8 member ensembles (vs. 50 in the paper)
- 30 hour initialisation strides starting from 01-01-2019T00, i.e. a comparison
of 292 initialisations (vs. 730 in the paper)
of 292 initialisations (vs. 730 in the paper)
- Colorbar limits of +-3% (vs. +-20% in the paper)
- There are two possible sources of this discrepancy. The first is the fact that the `splash` and `triblockdiag_mha` attention implementations are not exactly numerically equivalent (despite being algebraically equivalent). We have tested the isolated impact of these numerical differences by comparing performance with each attention implementation, both running on TPU. This comparison (scorecard [here](https://github.com/google-deepmind/graphcast/blob/main/docs/GenCast_0p25deg_attention_implementation_scorecard.png)) shows that there is very little difference caused by numerical differences between attention implementations. This implies that the minor degradation is caused primarily by running on GPU instead of TPU, and our initial investigations suggest that the root cause is the difference in the default precision of matmul operations on GPU compared to TPU.
** Memory requirement comparison vs. TPU **
- `triblockdiag_mha` also requires more memory, as such running inference on GPU
requires:
- 0.25deg GenCast: ~300GB of System Memory and ~60GB of vRAM
- 1deg GenCast: ~24GB of System Memory and ~16GB vRAM
** Inference time comparison vs. TPU **
- We have observed that running inference on H100 is slower than expected. Specifically we saw that a 30-step rollout of 0.25deg GenCast takes ~8min on TPUv5 with `splash_attention` (once compiled) whereas it takes ~25min on GPU with `triblockdiag_mha` attention.
- Part of this runtime discrepancy is caused by the fact that using `triblockdiag_mha` attention makes inference ~2x slower, such that running on TPU with `triblockdiag_mha` takes about ~15min, compared to the ~8min using `splash_attention`. However, there remains a discrepancy between the ~15min on a TPU and ~25min on GPU when using `triblockdiag_mha`.
## Prerequisites
### Create a Google Cloud account
Expand Down Expand Up @@ -210,7 +214,7 @@ Summary:
- `<location>` is the zone used for the TPU, using the names listed as per https://cloud.google.com/storage/docs/locations#available-locations,
- E,g, `us-south1-a` → `US-SOUTH-1`
```
gcloud storage buckets create gs://<bucket_name> --location <location>
gcloud storage buckets create gs://<bucket_name> --location <location>
```
- Create a Cloud TPU service account
```
Expand Down

0 comments on commit 273af28

Please sign in to comment.