Away from coast
Pperezhogin committed Oct 11, 2024
1 parent 29c7465 commit b501d30
Showing 6 changed files with 3,585 additions and 19 deletions.
19 changes: 19 additions & 0 deletions experiments/ANN-Results/helpers/cm26.py
@@ -53,6 +53,14 @@ def discard_land(x, percentile=1):
        return (x==1).astype('float32')
    else:
        return (x>percentile).astype('float32')

def propagate_mask(wet0, grid, niter=1):
    # Erode the wet mask: each interpolation pass mixes land (0) values into
    # neighbouring wet points, which discard_land then drops, so niter passes
    # remove roughly niter rows of points adjacent to the coast.
    wet = wet0.copy()

    for _ in range(niter):
        wet = grid.interp(grid.interp(wet, ['X', 'Y']), ['X', 'Y'])

    return discard_land(wet, percentile=1)
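A minimal usage sketch (hypothetical call, assuming the xgcm-style grid returned by create_grid(param) and the 0/1 U-point mask param.wet_u used elsewhere in this module): eroding the mask by two cells before computing coast-free statistics.

# Hypothetical illustration: drop points within ~2 grid cells of the coast.
wet_u_interior = propagate_mask(param.wet_u, grid, niter=2)
fraction_kept = float(wet_u_interior.sum() / param.wet_u.sum())  # share of U points retained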

def create_grid(param):
'''
@@ -452,6 +460,7 @@ def SGS_skill(self):
        in a few regions
        '''
        grid = self.grid
        data = self.data
        param = self.param
        SGSx = self.data.SGSx
        SGSy = self.data.SGSy
@@ -529,4 +538,14 @@ def M2v(x,y=None,centered=False,dims='time'):
            skill['power_ZB_'+region] = power.rename({'freq_r': 'freq_r_'+region})
            skill['power_time_ZB_'+region] = power_time

        ########### Global energy analysis ###############
        areaT = param.dxT * param.dyT
        areaU = param.dxCu * param.dyCu
        areaV = param.dxCv * param.dyCv
        skill['dEdt_map'] = ((grid.interp(data.SGSx * data.u * areaU,'X') + grid.interp(data.SGSy * data.v * areaV,'Y')) * param.wet / areaT).mean('time')
        skill['dEdt_map_ZB'] = ((grid.interp(data.ZB20u * data.u * areaU,'X') + grid.interp(data.ZB20v * data.v * areaV,'Y')) * param.wet / areaT).mean('time')

        skill['dEdt'] = (skill['dEdt_map'] * areaT).sum(['xh', 'yh']) / (areaT).sum(['xh', 'yh'])
        skill['dEdt_ZB'] = (skill['dEdt_map_ZB'] * areaT).sum(['xh', 'yh']) / (areaT).sum(['xh', 'yh'])

        return skill.compute()
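The scalar diagnostics dEdt and dEdt_ZB are area-weighted global means of the local energy tendency u·S on the tracer grid; negative values indicate that the subgrid forcing drains kinetic energy on average. A minimal sketch of the same reduction, assuming xarray DataArrays field and areaT on the ('xh', 'yh') dimensions:

def global_mean(field, areaT):
    # Area-weighted mean over the tracer grid, matching the dEdt reduction above.
    return (field * areaT).sum(['xh', 'yh']) / areaT.sum(['xh', 'yh'])

# e.g. dEdt = global_mean(skill['dEdt_map'], areaT)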
18 changes: 10 additions & 8 deletions experiments/ANN-Results/helpers/selectors.py
@@ -121,17 +121,17 @@ def compare(tested, control, mask=None, vmax=None, vmin = None, selector=select_

    central_latitude = float(y_coord(control).mean())
    central_longitude = float(x_coord(control).mean())
    fig, axes = plt.subplots(2,2, figsize=(12, 10), subplot_kw={'projection': ccrs.Orthographic(central_latitude=central_latitude, central_longitude=central_longitude)})
    fig, axes = plt.subplots(2,2, figsize=(12, 10))
    cmap.set_bad('gray')

    ax = axes[0][0]; ax.coastlines(); gl = ax.gridlines(); gl.bottom_labels=True; gl.left_labels=True;
    im = tested.plot(ax=ax, vmax=vmax, vmin=vmin, transform=ccrs.PlateCarree(), cmap=cmap, add_colorbar=False)
    ax = axes[0][0];# ax.coastlines(); gl = ax.gridlines(); gl.bottom_labels=True; gl.left_labels=True;
    im = tested.plot(ax=ax, vmax=vmax, vmin=vmin, cmap=cmap, add_colorbar=False)
    ax.set_title('Tested field')
    ax = axes[0][1]; ax.coastlines(); gl = ax.gridlines(); gl.bottom_labels=True; gl.left_labels=True;
    control.plot(ax=ax, vmax=vmax, vmin=vmin, transform=ccrs.PlateCarree(), cmap=cmap, add_colorbar=False)
    ax = axes[0][1];# ax.coastlines(); gl = ax.gridlines(); gl.bottom_labels=True; gl.left_labels=True;
    control.plot(ax=ax, vmax=vmax, vmin=vmin, cmap=cmap, add_colorbar=False)
    ax.set_title('Control field')
    ax = axes[1][0]; ax.coastlines(); gl = ax.gridlines(); gl.bottom_labels=True; gl.left_labels=True;
    (tested-control).plot(ax=ax, vmax=vmax-control_mean, vmin=vmin-control_mean, transform=ccrs.PlateCarree(), cmap=cmap, add_colorbar=False)
    ax = axes[1][0];# ax.coastlines(); gl = ax.gridlines(); gl.bottom_labels=True; gl.left_labels=True;
    (tested-control).plot(ax=ax, vmax=vmax-control_mean, vmin=vmin-control_mean, cmap=cmap, add_colorbar=False)
    ax.set_title('Tested-control')
    plt.tight_layout()
    plt.colorbar(im, ax=axes, shrink=0.9, aspect=30, extend='both')
@@ -150,4 +150,6 @@ def compare(tested, control, mask=None, vmax=None, vmin = None, selector=select_
    print('R2 = ', float(R2))
    print('R2 max = ', float(R2_max))
    print('Optimal scaling:', float(optimal_scaling))
    print(f'Nans [test/control]: [{int(np.sum(np.isnan(tested)))}, {int(np.sum(np.isnan(control)))}]')
    print(f'Nans [test/control]: [{int(np.sum(np.isnan(tested)))}, {int(np.sum(np.isnan(control)))}]')

    return axes
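Since compare now returns the axes, callers can post-process the figure. A hypothetical usage sketch, assuming the two dEdt maps computed by SGS_skill above are available in skill:

# Hypothetical usage: compare the ZB20 energy-tendency map with the diagnosed one
# and annotate the unused fourth panel.
axes = compare(skill['dEdt_map_ZB'], skill['dEdt_map'])
axes[1][1].set_title('(unused panel)')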
32 changes: 21 additions & 11 deletions experiments/ANN-Results/helpers/train_ann.py
@@ -1,7 +1,7 @@
import sys
import numpy as np
import xarray as xr
from helpers.cm26 import read_datasets
from helpers.cm26 import read_datasets, propagate_mask
from helpers.ann_tools import ANN, export_ANN, tensor_from_xarray, torch_pad
import torch
import torch.optim as optim
@@ -36,16 +36,25 @@ def MSE(batch, SGSx, SGSy, SGS_norm, ann_Txy, ann_Txx_Tyy, ann_Tall,
        batch_perturbed=None,
        response_norm=None, smagx_response=None, smagy_response=None,
        jacobian_trace=False, Cs_biharm=0.06,
        perturbed_inputs=False, jacobian_reduction='component'):
        perturbed_inputs=False, jacobian_reduction='component',
        away_from_coast=0):
    prediction = batch.state.Apply_ANN(ann_Txy, ann_Txx_Tyy, ann_Tall,
        stencil_size=stencil_size, dimensional_scaling=dimensional_scaling,
        feature_functions=feature_functions, gradient_features=gradient_features,
        rotation=rotation, reflect_x=reflect_x, reflect_y=reflect_y,
        jacobian_trace=jacobian_trace)

    # If away_from_coast=0, all wet points are included in the loss
    wet_u = tensor_from_xarray(propagate_mask(batch.param.wet_u, batch.grid, niter=away_from_coast))
    wet_v = tensor_from_xarray(propagate_mask(batch.param.wet_v, batch.grid, niter=away_from_coast))

    def reduction(x,y):
        return (x * wet_u + y * wet_v).mean()

    ANNx = prediction['ZB20u'] * SGS_norm
    ANNy = prediction['ZB20v'] * SGS_norm
    MSE_train = ((ANNx-SGSx)**2 + (ANNy-SGSy)**2).mean()

    MSE_train = reduction((ANNx-SGSx)**2, (ANNy-SGSy)**2)
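A self-contained sketch of what the masked reduction above does (toy numbers, torch only, hypothetical values): errors at points flagged as coastal contribute zero to the sum, while the denominator stays the total number of points, so the loss simply ignores the coastal band rather than renormalizing over the remaining wet points.

import torch

wet = torch.tensor([0., 1., 1., 1.])   # 0 = within away_from_coast cells of land
err = torch.tensor([9., 1., 1., 1.])   # squared error at each point

plain_mse  = err.mean()          # 3.00, dominated by the coastal point
masked_mse = (err * wet).mean()  # 0.75, coastal point excluded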

    if short_waves_dissipation:
        perturbed_prediction = batch_perturbed.state.Apply_ANN(ann_Txy, ann_Txx_Tyy, ann_Tall,
@@ -56,10 +65,7 @@ def MSE(batch, SGSx, SGSy, SGS_norm, ann_Txy, ann_Txx_Tyy, ann_Tall,
        ANNx_response = (perturbed_prediction['ZB20u'] - prediction['ZB20u']) * response_norm
        ANNy_response = (perturbed_prediction['ZB20v'] - prediction['ZB20v']) * response_norm

        MSE_plane_waves = (
            (ANNx_response - smagx_response)**2 +
            (ANNy_response - smagy_response)**2
        ).mean()
        MSE_plane_waves = reduction((ANNx_response - smagx_response)**2, (ANNy_response - smagy_response)**2)
    else:
        MSE_plane_waves = torch.tensor(0)

@@ -69,10 +75,11 @@ def fltr(x):
            return (4 * x[1:-1,1:-1] + 2 * (x[2:,1:-1] + x[:-2,1:-1] + x[1:-1,2:] + x[1:-1,:-2]) + (x[2:,2:] + x[2:,:-2] + x[:-2,2:] + x[:-2,:-2])) / 16.
        annx_sharpen = ANNx - fltr(ANNx)
        anny_sharpen = ANNy - fltr(ANNy)
        MSE_short_zero = (annx_sharpen**2 + anny_sharpen**2).mean()
        MSE_short_zero = reduction(annx_sharpen**2, anny_sharpen**2)
    else:
        MSE_short_zero = torch.tensor(0)

    # Note this should not be used with away_from_coast > 0
    if jacobian_trace:
        # First we define mean jacobian trace (per unit grid element)
        # for Smagorinsky model which is analytical
@@ -114,7 +121,7 @@ def fltr(x):

        ANNx = perturbed_prediction['ZB20u'] * SGS_norm
        ANNy = perturbed_prediction['ZB20v'] * SGS_norm
        MSE_perturbed = ((ANNx-SGSx)**2 + (ANNy-SGSy)**2).mean()
        MSE_perturbed = reduction((ANNx-SGSx)**2, (ANNy-SGSy)**2)
    else:
        MSE_perturbed = torch.tensor(0)

@@ -143,6 +150,7 @@ def train_ANN(factors=[12,15],
              Cs_biharm=0.06,
              load=False,
              subfilter='subfilter',
              away_from_coast=0,
              FGR=3):
    '''
    time_iters is the number of time snapshots
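A hypothetical call sketch, leaving every other keyword argument at the defaults shown in the signature above; away_from_coast=1 excludes grid points adjacent to the coast from the MSE loss, while 0 reproduces the previous behaviour:

# Hypothetical usage: train on coarsening factors 12 and 15 while ignoring
# the coastal band in the loss.
train_ANN(factors=[12, 15], away_from_coast=1, FGR=3)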
@@ -263,7 +271,8 @@ def iterator(x,y):
            batch_perturbed=batch_perturbed,
            response_norm=response_norm, smagx_response=smagx_response, smagy_response=smagy_response,
            jacobian_trace=jacobian_trace, Cs_biharm=Cs_biharm,
            perturbed_inputs=perturbed_inputs, jacobian_reduction=jacobian_reduction
            perturbed_inputs=perturbed_inputs, jacobian_reduction=jacobian_reduction,
            away_from_coast=away_from_coast
            )
        if short_waves_dissipation:
            (MSE_train + MSE_plain_waves).backward()
@@ -286,7 +295,8 @@ def iterator(x,y):
        with torch.no_grad():
            MSE_validate, _, _, _, _ = MSE(batch, SGSx, SGSy, SGS_norm, ann_Txy, ann_Txx_Tyy, ann_Tall,
                stencil_size=stencil_size, dimensional_scaling=dimensional_scaling,
                feature_functions=feature_functions, gradient_features=gradient_features)
                feature_functions=feature_functions, gradient_features=gradient_features,
                away_from_coast=away_from_coast)

        del batch

