Skip to content

Commit

Permalink
Training script with jacobian
Browse files Browse the repository at this point in the history
  • Loading branch information
Pperezhogin committed Jun 11, 2024
1 parent bd51f3f commit a87d849
Show file tree
Hide file tree
Showing 5 changed files with 1,467 additions and 26 deletions.
3 changes: 3 additions & 0 deletions experiments/ANN-Results/helpers/state_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,9 @@ def sample_grid_harmonic(self, grid_harmonic='plane_wave'):

u = np.sin(freq_x * i + freq_y * j + phase_u)
v = np.sin(freq_x * i + freq_y * j + phase_v)
elif grid_harmonic == 'white_noise':
u = np.random.randn(ny,nx)
v = np.random.randn(ny,nx)
else:
print('Error: wrong grid harmonic')

Expand Down
72 changes: 59 additions & 13 deletions experiments/ANN-Results/helpers/train_ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@ def MSE(batch, SGSx, SGSy, SGS_norm, ann_Txy, ann_Txx_Tyy, ann_Tall,
rotation=0, reflect_x=False, reflect_y=False,
short_waves_dissipation=False, short_waves_zero=False,
batch_perturbed=None,
response_norm=None, smagx_response=None, smagy_response=None):
response_norm=None, smagx_response=None, smagy_response=None,
jacobian_trace=False, Cs_biharm=0.06,
perturbed_inputs=False):
prediction = batch.state.Apply_ANN(ann_Txy, ann_Txx_Tyy, ann_Tall,
stencil_size=stencil_size, dimensional_scaling=dimensional_scaling,
feature_functions=feature_functions, gradient_features=gradient_features,
rotation=rotation, reflect_x=reflect_x, reflect_y=reflect_y)
rotation=rotation, reflect_x=reflect_x, reflect_y=reflect_y,
jacobian_trace=jacobian_trace)

ANNx = prediction['ZB20u'] * SGS_norm
ANNy = prediction['ZB20v'] * SGS_norm
Expand Down Expand Up @@ -62,8 +65,33 @@ def fltr(x):
MSE_short_zero = (annx_sharpen**2 + anny_sharpen**2).mean()
else:
MSE_short_zero = torch.tensor(0)

if jacobian_trace:
# First, define the mean Jacobian trace (per unit grid element)
# of the Smagorinsky model, which is known analytically
target_Jtr = - Cs_biharm * 8 * np.sqrt(prediction['sh_xx']**2 + prediction['sh_xy']**2).mean()

MSE_jacobian_trace = \
(1 - (prediction['dTxx_du'] / target_Jtr).mean())**2 + \
(1 - (prediction['dTyy_dv'] / target_Jtr).mean())**2 + \
(1 - (prediction['dTxy_du'] / target_Jtr).mean())**2 + \
(1 - (prediction['dTxy_dv'] / target_Jtr).mean())**2
else:
MSE_jacobian_trace = torch.tensor(0)

if perturbed_inputs:
perturbed_prediction = batch_perturbed.state.Apply_ANN(ann_Txy, ann_Txx_Tyy, ann_Tall,
stencil_size=stencil_size, dimensional_scaling=dimensional_scaling,
feature_functions=feature_functions, gradient_features=gradient_features,
rotation=rotation, reflect_x=reflect_x, reflect_y=reflect_y)

ANNx = perturbed_prediction['ZB20u'] * SGS_norm
ANNy = perturbed_prediction['ZB20v'] * SGS_norm
MSE_perturbed = ((ANNx-SGSx)**2 + (ANNy-SGSy)**2).mean()
else:
MSE_perturbed = torch.tensor(0)

return MSE_train, MSE_plane_waves, MSE_short_zero
return MSE_train, MSE_plane_waves, MSE_short_zero, MSE_jacobian_trace, MSE_perturbed

def train_ANN(factors=[12,15],
stencil_size = 3,
Expand All @@ -80,6 +108,9 @@ def train_ANN(factors=[12,15],
permute_factors_and_depth=False,
short_waves_dissipation=False,
short_waves_zero=False,
jacobian_trace=False,
perturbed_inputs=False,
Cs_biharm=0.06,
load=False,
subfilter='subfilter',
FGR=3):
Expand All @@ -95,7 +126,7 @@ def train_ANN(factors=[12,15],

########## Init logger ###########
logger = xr.Dataset()
for key in ['MSE_train', 'MSE_plain_waves', 'MSE_short_zero', 'MSE_validate']:
for key in ['MSE_train', 'MSE_plain_waves', 'MSE_short_zero', 'MSE_jacobian_trace', 'MSE_validate', 'MSE_perturbed']:
logger[key] = xr.DataArray(np.zeros([time_iters, len(factors), len(depth_idx)]),
dims=['iter', 'factor', 'depth'],
coords={'factor': factors, 'depth': depth_idx})
Expand Down Expand Up @@ -173,36 +204,45 @@ def iterator(x,y):

if short_waves_dissipation:
batch_perturbed = batch.perturb_velocities()
smag = batch.state.Smagorinsky()
smag_perturbed = batch_perturbed.state.Smagorinsky()
smag = batch.state.Smagorinsky(Cs_biharm=Cs_biharm)
smag_perturbed = batch_perturbed.state.Smagorinsky(Cs_biharm=Cs_biharm)

smagx_response = tensor_from_xarray(smag_perturbed['smagx']) - tensor_from_xarray(smag['smagx'])
smagy_response = tensor_from_xarray(smag_perturbed['smagy']) - tensor_from_xarray(smag['smagy'])
response_norm = 1. / torch.sqrt((smagx_response**2 + smagy_response**2).mean())
smagx_response = smagx_response * response_norm
smagy_response = smagy_response * response_norm
elif perturbed_inputs:
batch_perturbed = batch.perturb_velocities('white_noise', amp=0.01)
response_norm=None; smagx_response=None; smagy_response=None
else:
batch_perturbed = None; response_norm=None; smagx_response=None; smagy_response=None

############## Training step ###############
SGSx, SGSy, SGS_norm = get_SGS(batch)

######## Optionally, apply symmetries by data augmentation #########
for rotation, reflect_x, reflect_y in augment():
optimizer.zero_grad()
MSE_train, MSE_plain_waves, MSE_short_zero = \
MSE_train, MSE_plain_waves, MSE_short_zero, MSE_jacobian_trace, MSE_perturbed = \
MSE(batch, SGSx, SGSy, SGS_norm, ann_Txy, ann_Txx_Tyy, ann_Tall,
stencil_size=stencil_size, dimensional_scaling=dimensional_scaling,
feature_functions=feature_functions, gradient_features=gradient_features,
rotation=rotation, reflect_x=reflect_x, reflect_y=reflect_y,
short_waves_dissipation=short_waves_dissipation, short_waves_zero=short_waves_zero,
batch_perturbed=batch_perturbed,
response_norm=response_norm, smagx_response=smagx_response, smagy_response=smagy_response
response_norm=response_norm, smagx_response=smagx_response, smagy_response=smagy_response,
jacobian_trace=jacobian_trace, Cs_biharm=Cs_biharm,
perturbed_inputs=perturbed_inputs
)
if short_waves_dissipation:
(MSE_train + MSE_plain_waves).backward()
elif short_waves_zero:
(MSE_train + MSE_short_zero).backward()
elif jacobian_trace:
(MSE_train + MSE_jacobian_trace).backward()
elif perturbed_inputs:
(MSE_perturbed).backward()
else:
MSE_train.backward()
optimizer.step()
Expand All @@ -214,18 +254,24 @@ def iterator(x,y):
batch = dataset[f'validate-{factor}'].select2d(zl=depth)
SGSx, SGSy, SGS_norm = get_SGS(batch)
with torch.no_grad():
MSE_validate, _, _ = MSE(batch, SGSx, SGSy, SGS_norm, ann_Txy, ann_Txx_Tyy, ann_Tall,
MSE_validate, _, _, _, _ = MSE(batch, SGSx, SGSy, SGS_norm, ann_Txy, ann_Txx_Tyy, ann_Tall,
stencil_size=stencil_size, dimensional_scaling=dimensional_scaling,
feature_functions=feature_functions, gradient_features=gradient_features)

del batch

########### Logging ############
MSE_train = float(MSE_train.data); MSE_validate = float(MSE_validate.data); MSE_plain_waves = float(MSE_plain_waves.data); MSE_short_zero = float(MSE_short_zero)
for key in ['MSE_train', 'MSE_plain_waves', 'MSE_short_zero', 'MSE_validate']:
MSE_train = float(MSE_train.data)
MSE_validate = float(MSE_validate.data)
MSE_perturbed = float(MSE_perturbed.data)
MSE_plain_waves = float(MSE_plain_waves.data)
MSE_short_zero = float(MSE_short_zero.data)
MSE_jacobian_trace = float(MSE_jacobian_trace.data)

for key in ['MSE_train', 'MSE_plain_waves', 'MSE_short_zero', 'MSE_jacobian_trace', 'MSE_validate', 'MSE_perturbed']:
logger[key].loc[{'iter': time_iter, 'factor': factor, 'depth': depth}] = eval(key)
if (time_iter+1) % print_iters == 0:
print(f'Factor: {factor}, depth: {depth}, '+'MSE train/validate/waves/short: [%.6f, %.6f, %.6f, %.6f]' % (MSE_train, MSE_validate, MSE_plain_waves, MSE_short_zero))
print(f'Factor: {factor}, depth: {depth}, '+'MSE train/validate/perturbed/waves/short/trace: [%.6f, %.6f, %.6f, %.6f, %.6f, %.6f]' % (MSE_train, MSE_validate, MSE_perturbed, MSE_plain_waves, MSE_short_zero, MSE_jacobian_trace))
t = time()
if (time_iter+1) % print_iters == 0:
print(f'Iter/num_iters [{time_iter+1}/{time_iters}]. Iter time/Remaining time in seconds: [%.2f/%.1f]' % (t-t_e, (t-t_s)*(time_iters/(time_iter+1)-1)))
Expand Down
Loading

0 comments on commit a87d849

Please sign in to comment.