Trainer
The training can be launched with:

```shell
srun python atmorep/core/train.py
```

Due to memory constraints, the full configuration with 6 fields requires different settings in Juelich. To train the coupled system we advise you to use:

```shell
srun python atmorep/core/train_multi.py
```
See the glossary page for details on each of the options.
The supported training strategies are declared in `atmorep/training/bert.py`:

```python
if not BERT_strategy :
  BERT_strategy = cf.BERT_strategy            # fall back to the strategy from the config
if BERT_strategy == 'BERT' :
  bert_f = prepare_batch_BERT_field           # random masking of tokens (BERT-style)
elif BERT_strategy == 'forecast' :
  bert_f = prepare_batch_BERT_forecast_field  # mask the last time step(s) completely
elif BERT_strategy == 'temporal_interpolation' :
  bert_f = prepare_batch_BERT_temporal_field  # mask intermediate tokens in time
else :
  assert False                                # unknown strategy
```
An intuitive sketch of the two most commonly used strategies is shown in the figure above. Each supported strategy is briefly described below.
The `BERT` option is an adaptation of the BERT-style masking training protocol. Tokens are masked randomly within the loaded source cube according to the masking rates defined in the `cf.fields[field][-1]` parameter (a hypothetical entry is sketched at the end of this section). The parameters of interest are `[total masking rate, rate masking, rate noising, rate for multi-res distortion]`, which control:
- `total masking rate`: the overall fraction of tokens in the source cube that are masked or perturbed.
- `rate masking`: the fraction of masked tokens over the total. Very high masking ratios, e.g. 90%, have been shown to produce more robust results.
- `rate noising`: the fraction of masked tokens that are masked using Gaussian noise.
- `rate for multi-res distortion`: the fraction of masked tokens that are distorted to a coarser resolution (the maximum reduction is controlled by `cf.BERT_mr_max`, see the example below).
Example:

```python
cf.BERT_strategy = 'BERT'
cf.BERT_fields_synced = False # apply synchronized / identical masking to all fields
                              # (fields need to have same BERT params for this to have effect)
cf.BERT_mr_max = 2            # maximum reduction rate for resolution
```
This is the default strategy used for training the AtmoRep core model.
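For orientation with the indices used above, below is a hypothetical `cf.fields` entry; only the two positions referenced on this page are annotated, and the remaining elements are placeholders documented in the glossary.

```python
# Hypothetical field specification, for illustration only; see the glossary
# for the meaning of the placeholder entries.
cf.fields = [
    [ 'temperature',           # field name (placeholder)
      ...,                     # placeholder (see glossary)
      ...,                     # placeholder (see glossary)
      [12, 6, 12],             # cf.fields[i][3]: tokens per dimension; [0] = time
      ...,                     # placeholder (see glossary)
      [0.9, 0.9, 0.1, 0.05] ], # cf.fields[i][-1]: [total masking rate, rate masking,
]                              #   rate noising, rate for multi-res distortion]
```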
The `forecast` option aims at optimising the training for the forecasting task. It has been used to fine-tune AtmoRep for forecasting. It can be considered a special case of the BERT strategy in which the last time step(s) of the source cube are completely masked for prediction.
Parameters:

- number of loaded tokens in time: `cf.fields[i][3][0]` (see the glossary for details)
- number of forecasted tokens: `cf.forecast_num_tokens = 2` (default = 2)

Important: the first parameter controls the total number of tokens loaded into `source`. The second parameter controls the number of masked tokens in time within the loaded cube. No roll-out is implemented at the moment, so `forecast_num_tokens` cannot be larger than `cf.fields[i][3][0]`! A configuration sketch is given below.
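Here is a minimal configuration sketch for forecast fine-tuning, based on the parameters above; the sanity-check loop at the end is illustrative and not part of AtmoRep:

```python
# Select the forecast training protocol and mask the last 2 time tokens.
cf.BERT_strategy = 'forecast'
cf.forecast_num_tokens = 2

# Illustrative sanity check: no roll-out is implemented, so the number of
# forecast tokens cannot exceed the tokens loaded along the time dimension.
for field in cf.fields:
    assert cf.forecast_num_tokens <= field[3][0]
```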
The `temporal_interpolation` option aims at optimising the training for the temporal interpolation task. It has been used to fine-tune AtmoRep for temporal interpolation. It is again a special case of the BERT strategy, in which the intermediate tokens in time are masked. Masked tokens:

```python
idx_time_mask = int( np.floor(num_tokens[0] / 2.)) # TODO: masking of multiple time steps
```

where `num_tokens[0]` represents the number of loaded tokens along the time dimension (12 by default).
Important: be careful if you use this option; it might be outdated!
**Parameters:**

- number of loaded tokens in time: `cf.fields[i][3][0]`
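As a worked example of the midpoint computation above, assuming the default of 12 time tokens:

```python
import numpy as np

num_tokens = [12, 6, 12]  # tokens per dimension; [0] is the time dimension
idx_time_mask = int(np.floor(num_tokens[0] / 2.))
print(idx_time_mask)      # -> 6, i.e. the 7th of the 12 time tokens is masked
```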
The loss computation is modular and defined by the following parameter in the config:

```python
cf.losses = ['mse', 'stats']
```

The final loss in this case will be the sum of the two terms. The supported losses are defined in `trainer.py` (`def loss( self, preds, batch_idx = 0)`):
| option | description |
| --- | --- |
| `mse` | MSE loss based on the usual mean square error. |
| `mse_ensemble` | MSE loss computed for each ensemble member separately and then averaged. |
| `stats` | Statistical loss: a generalized cross-entropy loss for continuous distributions. Refer to the paper for a detailed explanation. |
| `stats_area` | Based on `torch.special.erf` (error function), computed as link. |
| `crps` | Loss based on the continuous ranked probability score; see Eq. A2 in S. Rasp and S. Lerch, "Neural networks for postprocessing ensemble weather forecasts", Monthly Weather Review, 146(11):3885–3900, 2018. |
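These terms can be combined freely through `cf.losses`. The snippet below is a minimal sketch of such a modular sum, not AtmoRep's `trainer.py` implementation: the function names and the ensemble-dimension convention are assumptions, and the `crps` term uses the closed-form Gaussian CRPS from Eq. A2 of the Rasp & Lerch paper cited above.

```python
import math
import torch
import torch.nn.functional as F

def crps_gaussian(mu, sigma, target):
    # Closed-form CRPS of a Gaussian forecast (Rasp & Lerch 2018, Eq. A2):
    # CRPS = sigma * ( z*(2*Phi(z) - 1) + 2*phi(z) - 1/sqrt(pi) ), z = (y - mu)/sigma
    z = (target - mu) / sigma
    cdf = 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
    pdf = torch.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)
    return (sigma * (z * (2.0 * cdf - 1.0) + 2.0 * pdf - 1.0 / math.sqrt(math.pi))).mean()

def total_loss(preds, target, losses):
    # preds: (ensemble, ...) tensor of ensemble-member predictions (assumed layout).
    mu, sigma = preds.mean(dim=0), preds.std(dim=0)
    terms = []
    for name in losses:
        if name == 'mse':
            terms.append(F.mse_loss(mu, target))  # MSE of the ensemble mean
        elif name == 'mse_ensemble':
            terms.append(torch.stack([F.mse_loss(p, target) for p in preds]).mean())
        elif name == 'crps':
            terms.append(crps_gaussian(mu, sigma, target))
        else:
            raise ValueError(f'unsupported loss: {name}')
    return sum(terms)  # the final loss is the sum of the configured terms

# Hypothetical usage with a 4-member ensemble:
preds  = torch.randn(4, 8, 8)
target = torch.randn(8, 8)
print(total_loss(preds, target, ['mse', 'crps']))
```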
The output of `train.py` is:

- A new directory inside `atmorep/results` named with the `wandb_id` (Weights and Biases ID) of the run. The folder contains a `json` file with the configuration.
- A new directory inside `atmorep/models` named with the `wandb_id` (Weights and Biases ID) of the run and containing the model weights.
Depending on the script you use, when you launch the runs on slurm you might have three additional files:

3. An output file `output/output_XXXXXX.txt` with the log. `XXXXXX` here is the number associated with the slurm ID of the job (remember to note it down when you launch your job!). The `wandb_id` is saved within this file, just `grep` it.
4. Two log files, `logs/*XXXXX*.err` and `logs/*XXXXX*.out`, again named with the slurm ID of the job, containing the errors and the output from the shell scripts.
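As an illustration, a hedged sketch for locating the run artifacts described above (the directory layout follows the description on this page; the exact file names inside are not fixed here, so the sketch simply lists them):

```python
from pathlib import Path

wandb_id = 'abc123'  # hypothetical run id; it is recorded in the output file
results_dir = Path('atmorep/results') / wandb_id  # configuration json
models_dir = Path('atmorep/models') / wandb_id    # model weights

for p in sorted(results_dir.glob('*.json')) + sorted(models_dir.iterdir()):
    print(p)
```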
The AtmoRep Collaboration - last update: April 2024