Commit 2b46807: offline analysis

Pperezhogin committed Apr 30, 2024
1 parent 58140db commit 2b46807
Showing 4 changed files with 1,345 additions and 83 deletions.
132 changes: 51 additions & 81 deletions experiments/ANN-Results/helpers/cm26.py
@@ -409,23 +409,23 @@ def compute_subfilter_forcing(self, factor=4, FGR_multiplier=2,
 
         return ds_coarse
 
-    def predict_ANN(self, ann_Txy, ann_Txx_Tyy):
+    def predict_ANN(self, ann_Txy, ann_Txx_Tyy, **kw):
         '''
         This function makes ANN inference on the whole dataset
         '''
-        # Just a (lazy) copy of dataset
-        ds = DatasetCM26(self.data, self.param)
-        ds.data = ds.data.load() # This data will anyway be needed for later evaluation
-        ds.data['ZB20u'] = xr.zeros_like(ds.data.u)
-        ds.data['ZB20v'] = xr.zeros_like(ds.data.v)
-
-        for time in range(len(self)):
-            batch = ds.split(time=time)
-            prediction = batch.state.ANN(ann_Txy, ann_Txx_Tyy)
-            for key in ['ZB20u', 'ZB20v']:
-                ds.data[key][{'time':time}] = prediction[key]
-
-        return ds
+        data = self.data[['SGSx', 'SGSy', 'u', 'v']]
+        data['ZB20u'] = xr.zeros_like(data.SGSx)
+        data['ZB20v'] = xr.zeros_like(data.SGSy)
+
+        for time in range(len(self.data.time)):
+            for zl in range(len(self.data.zl)):
+                batch = self.select2d(time=time,zl=zl)
+                prediction = batch.state.ANN(ann_Txy, ann_Txx_Tyy, **kw)
+                data['ZB20u'][{'time':time, 'zl':zl}] = prediction['ZB20u']
+                data['ZB20v'][{'time':time, 'zl':zl}] = prediction['ZB20v']
+
+        return DatasetCM26(data, self.param)
 
     def SGS_skill(self):
         '''
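The rewritten predict_ANN no longer clones and loads the whole dataset; it allocates ZB20u/ZB20v once and fills them by running the networks on one (time, zl) slice at a time, returning a fresh DatasetCM26. A minimal usage sketch of the resulting offline workflow, assuming the two trained networks ann_Txy and ann_Txx_Tyy are already loaded (the constructor arguments and variable names below are illustrative assumptions, not part of this commit):

# Hypothetical offline-analysis workflow (a sketch under the assumptions stated above)
ds = DatasetCM26(data, param)                  # coarse dataset holding SGSx, SGSy, u, v
pred = ds.predict_ANN(ann_Txy, ann_Txx_Tyy)    # new DatasetCM26 with ZB20u, ZB20v filled slice by slice
skill = pred.SGS_skill()                       # xr.Dataset of offline metrics, kept per vertical level
print(skill['R2'].values, skill['corr'].values)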
@@ -444,7 +444,13 @@ def SGS_skill(self):
 
         ############# R-squared and correlation ##############
         # Here we define second moments
-        def M2(x,y=None,centered=False,dims=None):
+        def M2(x,y=None,centered=False,dims=None,exclude_dims='zl'):
+            if dims is None and exclude_dims is not None:
+                dims = []
+                for dim in x.dims:
+                    if dim not in exclude_dims:
+                        dims.append(dim)
+
             if y is None:
                 y = x
             if centered:
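The new exclude_dims argument makes M2 reduce, by default, over every dimension except the vertical level 'zl', so the global metrics below are reported per level rather than as a single number. The rest of the helper is not shown in this hunk; a plausible complete form, inferred from how M2 is used in SGS_skill (the reduction body is an assumption):

def M2(x, y=None, centered=False, dims=None, exclude_dims='zl'):
    # Default: average over all dims of x except the vertical level
    if dims is None and exclude_dims is not None:
        dims = [dim for dim in x.dims if dim not in exclude_dims]
    if y is None:
        y = x
    if centered:
        # Centered second moment (covariance); used for correlations
        x = x - x.mean(dims)
        y = y - y.mean(dims)
    # Uncentered second moment by default; M2(x) is then the mean square
    return (x * y).mean(dims)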
Expand All @@ -460,87 +466,51 @@ def M2v(x,y=None,centered=False,dims='time'):
errx = SGSx - ZB20u
erry = SGSy - ZB20v

ds = param.copy()
skill = xr.Dataset()
######## Simplest statistics ##########
ds['SGSx_mean'] = SGSx.mean('time')
ds['SGSy_mean'] = SGSy.mean('time')
ds['ZB20u_mean'] = ZB20u.mean('time')
ds['ZB20v_mean'] = ZB20v.mean('time')
ds['SGSx_std'] = SGSx.std('time')
ds['SGSy_std'] = SGSy.std('time')
ds['ZB20u_std'] = ZB20u.std('time')
ds['ZB20v_std'] = ZB20v.std('time')
skill['SGSx_mean'] = SGSx.mean('time')
skill['SGSy_mean'] = SGSy.mean('time')
skill['ZB20u_mean'] = ZB20u.mean('time')
skill['ZB20v_mean'] = ZB20v.mean('time')
skill['SGSx_std'] = SGSx.std('time')
skill['SGSy_std'] = SGSy.std('time')
skill['ZB20u_std'] = ZB20u.std('time')
skill['ZB20v_std'] = ZB20v.std('time')

# These metrics are same as in GZ21 work
# Note: eveything is uncentered
ds['R2u_map'] = 1 - M2u(errx) / M2u(SGSx)
ds['R2v_map'] = 1 - M2v(erry) / M2v(SGSy)
ds['R2_map'] = 1 - (M2u(errx) + M2v(erry)) / (M2u(SGSx) + M2v(SGSy))
skill['R2u_map'] = 1 - M2u(errx) / M2u(SGSx)
skill['R2v_map'] = 1 - M2v(erry) / M2v(SGSy)
skill['R2_map'] = 1 - (M2u(errx) + M2v(erry)) / (M2u(SGSx) + M2v(SGSy))

# Here everything is centered according to definition of correlation
ds['corru_map'] = M2u(SGSx,ZB20u,centered=True) / np.sqrt(M2u(SGSx,centered=True) * M2u(ZB20u,centered=True))
ds['corrv_map'] = M2v(SGSy,ZB20v,centered=True) / np.sqrt(M2v(SGSy,centered=True) * M2v(ZB20v,centered=True))
skill['corru_map'] = M2u(SGSx,ZB20u,centered=True) / np.sqrt(M2u(SGSx,centered=True) * M2u(ZB20u,centered=True))
skill['corrv_map'] = M2v(SGSy,ZB20v,centered=True) / np.sqrt(M2v(SGSy,centered=True) * M2v(ZB20v,centered=True))
# It is complicated to derive a single true formula, so use simplest one
ds['corr_map'] = (ds['corru_map'] + ds['corrv_map']) * 0.5
skill['corr_map'] = (skill['corru_map'] + skill['corrv_map']) * 0.5

########### Global metrics ############
ds['R2u'] = 1 - M2(errx) / M2(SGSx)
ds['R2v'] = 1 - M2(erry) / M2(SGSy)
ds['R2'] = 1 - (M2(errx) + M2(erry)) / (M2(SGSx) + M2(SGSy))
ds['corru'] = M2(SGSx,ZB20u,centered=True) \
skill['R2u'] = 1 - M2(errx) / M2(SGSx)
skill['R2v'] = 1 - M2(erry) / M2(SGSy)
skill['R2'] = 1 - (M2(errx) + M2(erry)) / (M2(SGSx) + M2(SGSy))
skill['corru'] = M2(SGSx,ZB20u,centered=True) \
/ np.sqrt(M2(SGSx,centered=True) * M2(ZB20u,centered=True))
ds['corrv'] = M2(SGSy,ZB20v,centered=True) \
skill['corrv'] = M2(SGSy,ZB20v,centered=True) \
/ np.sqrt(M2(SGSy,centered=True) * M2(ZB20v,centered=True))
ds['corr'] = (ds['corru'] + ds['corrv']) * 0.5

########## Optimal scaling analysis ###########
ds['opt_scaling_map'] = (M2u(SGSx,ZB20u) + M2v(SGSy,ZB20v)) / (M2u(ZB20u) + M2v(ZB20v))
# Maximum achievable R2 if scaling was optimal
scaling_u = grid.interp(ds['opt_scaling_map'], 'X')
scaling_v = grid.interp(ds['opt_scaling_map'], 'Y')
errx = SGSx - ZB20u * scaling_u
erry = SGSy - ZB20v * scaling_v
ds['R2_max_map'] = 1 - (M2u(errx) + M2v(erry)) / (M2u(SGSx) + M2v(SGSy))

ds['opt_scaling'] = (M2(SGSx,ZB20u) + M2(SGSy,ZB20v)) / (M2(ZB20u) + M2(ZB20v))
errx = SGSx - ZB20u * ds['opt_scaling']
erry = SGSy - ZB20v * ds['opt_scaling']
ds['R2_max'] = 1 - (M2(errx) + M2(erry)) / (M2(SGSx) + M2(SGSy))

############### Dissipation analysis ###############
d = self.state.compute_EZ_source(SGSx, SGSy)
ds['Esource_map'] = d['dEdt_local'].mean('time')
ds['Zsource_map'] = d['dZdt_local'].mean('time')
ds['Psource_map'] = d['dPdt_local'].mean('time')
d = self.state.compute_EZ_source(ZB20u, ZB20v)
ds['Esource_ZB_map'] = d['dEdt_local'].mean('time')
ds['Zsource_ZB_map'] = d['dZdt_local'].mean('time')
ds['Psource_ZB_map'] = d['dPdt_local'].mean('time')

######## Domain-averaged energy/enstrophy sources #########
# We integrate sources away from the land
wet = param.wet.copy()
for i in range(3):
wet = discard_land(grid.interp(grid.interp(wet, ['X', 'Y']), ['X', 'Y']))
ds['wet_extended'] = wet

for key in ['Esource', 'Zsource', 'Psource',
'Esource_ZB', 'Zsource_ZB', 'Psource_ZB']:
areaT = param.dxT * param.dyT
ds[key+'_extend'] = (ds[key+'_map'] * areaT * ds['wet_extended']).mean()
ds[key] = (ds[key+'_map'] * areaT * ds['wet']).mean()
skill['corr'] = (skill['corru'] + skill['corrv']) * 0.5
skill['opt_scaling'] = (M2(SGSx,ZB20u) + M2(SGSy,ZB20v)) / (M2(ZB20u) + M2(ZB20v))

############### Spectral analysis ##################
for region in ['NA', 'Pacific', 'Equator', 'ACC']:
transfer, power, KE_spec, power_time, KE_time = self.state.transfer(SGSx, SGSy, region=region, additional_spectra=True)
ds['transfer_'+region] = transfer.rename({'freq_r': 'freq_r_'+region})
ds['power_'+region] = power.rename({'freq_r': 'freq_r_'+region})
ds['KE_spec_'+region] = KE_spec.rename({'freq_r': 'freq_r_t'+region})
ds['power_time_'+region] = power_time
ds['KE_time_'+region] = KE_time
skill['transfer_'+region] = transfer.rename({'freq_r': 'freq_r_'+region})
skill['power_'+region] = power.rename({'freq_r': 'freq_r_'+region})
skill['KE_spec_'+region] = KE_spec.rename({'freq_r': 'freq_r_t'+region})
skill['power_time_'+region] = power_time
skill['KE_time_'+region] = KE_time
transfer, power, KE_spec, power_time, KE_time = self.state.transfer(ZB20u, ZB20v, region=region, additional_spectra=True)
ds['transfer_ZB_'+region] = transfer.rename({'freq_r': 'freq_r_'+region})
ds['power_ZB_'+region] = power.rename({'freq_r': 'freq_r_'+region})
ds['power_time_ZB_'+region] = power_time
skill['transfer_ZB_'+region] = transfer.rename({'freq_r': 'freq_r_'+region})
skill['power_ZB_'+region] = power.rename({'freq_r': 'freq_r_'+region})
skill['power_time_ZB_'+region] = power_time

return ds.compute()
return skill.compute()
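For reference, the skill dataset returned above boils down to a few second-moment ratios: an uncentered R2 (as in the GZ21-style map metrics), a centered correlation averaged over the two components, and an optimal scaling, i.e. the amplitude alpha that minimizes |SGS - alpha*ZB|^2. A self-contained xarray sketch of these global metrics; the variable names mirror the diff, everything else (the standalone function and its reduction rule) is illustrative:

import numpy as np
import xarray as xr

def offline_metrics(SGSx, SGSy, ZB20u, ZB20v):
    """Illustrative re-implementation of the global metrics in SGS_skill."""
    def M2(x, y=None, centered=False):
        # Reduce over everything except the vertical level, as in the diff
        dims = [d for d in x.dims if d != 'zl']
        y = x if y is None else y
        if centered:
            x, y = x - x.mean(dims), y - y.mean(dims)
        return (x * y).mean(dims)

    errx, erry = SGSx - ZB20u, SGSy - ZB20v
    skill = xr.Dataset()
    # Uncentered R2: 1 - |error|^2 / |truth|^2, summed over both components
    skill['R2'] = 1 - (M2(errx) + M2(erry)) / (M2(SGSx) + M2(SGSy))
    # Centered correlation per component, then the simple average
    corru = M2(SGSx, ZB20u, centered=True) / np.sqrt(M2(SGSx, centered=True) * M2(ZB20u, centered=True))
    corrv = M2(SGSy, ZB20v, centered=True) / np.sqrt(M2(SGSy, centered=True) * M2(ZB20v, centered=True))
    skill['corr'] = 0.5 * (corru + corrv)
    # Least-squares amplitude: alpha = <SGS, ZB> / <ZB, ZB>
    skill['opt_scaling'] = (M2(SGSx, ZB20u) + M2(SGSy, ZB20v)) / (M2(ZB20u) + M2(ZB20v))
    return skill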
3 changes: 1 addition & 2 deletions experiments/ANN-Results/helpers/selectors.py
@@ -62,7 +62,7 @@ def select_Equator(array, time=None):
 def select_ACC(array, time=None):
     return select_LatLon(array, Lat=(-70,-30), Lon=(-40,0), time=time)
 
-def plot(control, mask=None, vmax=None, vmin=None, selector=select_NA, cartopy=True):
+def plot(control, mask=None, vmax=None, vmin=None, selector=select_NA, cartopy=True, cmap=cmocean.cm.balance):
     if mask is not None:
         mask_nan = selector(mask).data.copy()
         mask_nan[mask_nan==0.] = np.nan
@@ -90,7 +90,6 @@ def plot(control, mask=None, vmax=None, vmin=None, selector=select_NA, cartopy=T
     else:
         ax = plt.gca()
         kw = {}
-    cmap = cmocean.cm.balance
     cmap.set_bad('gray')
     im = selector(control).plot(ax=ax, vmax=vmax, vmin=vmin, cmap=cmap, add_colorbar=True, **kw)
     plt.title('')
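With this change the colormap is a keyword argument of plot (defaulting to cmocean.cm.balance, as before) instead of being hard-coded inside the function. A hypothetical call, where field2d and wet_mask stand for any 2D field and wet mask from the dataset (both names are placeholders, not from this commit):

import cmocean
# default colormap (cmocean.cm.balance), default North Atlantic selector
plot(field2d, mask=wet_mask)
# override the colormap through the new keyword argument
plot(field2d, mask=wet_mask, selector=select_ACC, cmap=cmocean.cm.thermal)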
571 changes: 571 additions & 0 deletions experiments/ANN-Results/offline_analysis/5-training.ipynb

Large diffs are not rendered by default.
