Skip to content

Commit

Permalink
Implemented Collocated ANN with arbitrary stencil size
Browse files Browse the repository at this point in the history
  • Loading branch information
Pperezhogin committed May 9, 2024
1 parent 3f0c9c0 commit 310b6b5
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 4 deletions.
113 changes: 113 additions & 0 deletions src/parameterizations/lateral/MOM_Zanna_Bolton.F90
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ module MOM_Zanna_Bolton
! KE backscatter
logical, public :: GM_conserv !< If true, adds GM dissipation to equilibrate bakcscatter of KE

logical, public :: true_vorticity !< Use correct curvilinear coordinate expression for vorticity

!> Control structure for Zanna-Bolton-2020 parameterization.
type, public :: ZB2020_CS ; private
! Parameters
Expand Down Expand Up @@ -97,6 +99,7 @@ module MOM_Zanna_Bolton
real :: backscatter_ratio !< The ratio of backscattered energy to the dissipated energy

integer :: use_ann !< 0: ANN is turned off, 1: default ANN with ZB20 model, 2: two separate ANNs on stencil 3x3 for corner and center
integer :: stencil_size !< Default is 3x3
logical :: rotation_invariant !< If true, the ANN is rotation invariant
logical :: ann_smag_conserv !< Energy-conservative ANN by imposing Smagorinsky model
logical :: smag_conserv_lagrangian !< Energy conservation is imposed by introducing Smagorinsky model and performing averaging in Lagrangian frame
Expand All @@ -106,9 +109,11 @@ module MOM_Zanna_Bolton
type(ANN_CS) :: ann_instance !< ANN instance
type(ANN_CS) :: ann_Txy !< ANN instance for Txy
type(ANN_CS) :: ann_Txx_Tyy !< ANN instance for diagonal stress
type(ANN_CS) :: ann_Tall !< ANN instance for off-diagonal and diagonal stress
character(len=200) :: ann_file = "/home/pp2681/MOM6-examples/src/MOM6/experiments/ANN-Results/trained_models/ANN_64_neurons_ZB-ver-1.2.nc" !< Default ANN with ZB20 model
character(len=200) :: ann_file_Txy
character(len=200) :: ann_file_Txx_Tyy
character(len=200) :: ann_file_Tall
real :: subroundoff_shear

type(diag_ctrl), pointer :: diag => NULL() !< A type that regulates diagnostics output
Expand Down Expand Up @@ -194,6 +199,13 @@ subroutine ZB2020_init(Time, G, GV, US, param_file, diag, CS, use_ZB2020)
call get_param(param_file, mdl, "USE_ANN", CS%use_ann, &
"ANN inference of momentum fluxes: 0 off, 1: single ANN 2x2, 2: two ANNs 3x3", default=0)

call get_param(param_file, mdl, "ANN_STENCIL_SIZE", CS%stencil_size, &
"ANN stencil size", default=3)

call get_param(param_file, mdl, "ANN_TRUE_VORTICITY", true_vorticity, &
"Use correct curvilinear approximation of vorticity", &
default=.False.)

call get_param(param_file, mdl, "ANN_SMAG_CONSERV", CS%ann_smag_conserv, &
"Smagorinsky model makes SGS parameterization energy-conservative", default=.False.)

Expand Down Expand Up @@ -226,6 +238,10 @@ subroutine ZB2020_init(Time, G, GV, US, param_file, diag, CS, use_ZB2020)
"ANN parameters for prediction of Txx and Tyy netcdf input", &
default="/scratch/pp2681/mom6/CM26_ML_models/Gauss-FGR2/hdn-64-64/model/Txx_Tyy_epoch_2000.nc")

call get_param(param_file, mdl, "ANN_FILE_TALL", CS%ann_file_Tall, &
"ANN parameters for prediction of Txy, Txx and Tyy netcdf input", &
default="/scratch/pp2681/mom6/CM26_ML_models/ocean3d/Gauss-FGR3/EXP-32-32/repeat/model/Tall.nc")

call get_param(param_file, mdl, "ROT_INV", CS%rotation_invariant, &
"If true, rotation invariance is imposed as hard constraint", default=.false.)

Expand Down Expand Up @@ -366,6 +382,8 @@ subroutine ZB2020_init(Time, G, GV, US, param_file, diag, CS, use_ZB2020)
elseif (CS%use_ann == 2) then
call ANN_init(CS%ann_Txy, CS%ann_file_Txy)
call ANN_init(CS%ann_Txx_Tyy, CS%ann_file_Txx_Tyy)
elseif (CS%use_ann == 3) then
call ANN_init(CS%ann_Tall, CS%ann_file_Tall)
endif

! Allocate memory
Expand Down Expand Up @@ -629,6 +647,8 @@ subroutine ZB2020_lateral_stress(u, v, h, diffu, diffv, G, GV, CS, &
call compute_stress_ANN(G, GV, CS)
elseif (CS%use_ann==2) then
call compute_stress_ANN_3x3(G, GV, CS)
elseif (CS%use_ann==3) then
call compute_stress_ANN_collocated(G, GV, CS)
endif

! Smooth the stress tensor if specified
Expand Down Expand Up @@ -1591,6 +1611,99 @@ subroutine compute_stress_ANN_3x3(G, GV, CS)

end subroutine compute_stress_ANN_3x3

!> Compute stress tensor T =
!! (Txx, Txy;
!! Txy, Tyy)
!! with ANN
subroutine compute_stress_ANN_collocated(G, GV, CS)
type(ocean_grid_type), intent(in) :: G !< The ocean's grid structure.
type(verticalGrid_type), intent(in) :: GV !< The ocean's vertical grid structure
type(ZB2020_CS), intent(inout) :: CS !< ZB2020 control structure.

integer :: is, ie, js, je, Isq, Ieq, Jsq, Jeq, nz
integer :: i, j, k, n

real :: x(3*CS%stencil_size**2), y(3)
real :: input_norm
integer :: shift, stencil_points

real, dimension(SZI_(G),SZJ_(G),SZK_(GV)) :: &
sh_xy_h, & ! sh_xy interpolated to the center [T-1 ~ s-1]
vort_xy_h, & ! vort_xy interpolated to the center [T-1 ~ s-1]
norm_h ! Norm in h points [T-1 ~ s-1]

real, dimension(SZI_(G),SZJ_(G)) :: &
sqr_h, & ! Sum of squares in h points
Txy ! Predicted Txy in center points to be interpolated to coreners

call cpu_clock_begin(CS%id_clock_stress_ANN)

is = G%isc ; ie = G%iec ; js = G%jsc ; je = G%jec ; nz = GV%ke
Isq = G%IscB ; Ieq = G%IecB ; Jsq = G%JscB ; Jeq = G%JecB

sh_xy_h = 0.
vort_xy_h = 0.
norm_h = 0.

call pass_var(CS%sh_xy, G%Domain, clock=CS%id_clock_mpi, position=CORNER)
call pass_var(CS%sh_xx, G%Domain, clock=CS%id_clock_mpi)
call pass_var(CS%vort_xy, G%Domain, clock=CS%id_clock_mpi, position=CORNER)

shift = (CS%stencil_size-1)/2
stencil_points = CS%stencil_size**2

! Interpolate input features
do k=1,nz
do j=js-2,je+2 ; do i=is-2,ie+2
! It is assumed that B.C. is applied to sh_xy and vort_xy
sh_xy_h(i,j,k) = 0.25 * ( (CS%sh_xy(I-1,J-1,k) + CS%sh_xy(I,J,k)) &
+ (CS%sh_xy(I-1,J,k) + CS%sh_xy(I,J-1,k)) ) * G%mask2dT(i,j)

vort_xy_h(i,j,k) = 0.25 * ( (CS%vort_xy(I-1,J-1,k) + CS%vort_xy(I,J,k)) &
+ (CS%vort_xy(I-1,J,k) + CS%vort_xy(I,J-1,k)) ) * G%mask2dT(i,j)

sqr_h(i,j) = CS%sh_xx(i,j,k)**2 + sh_xy_h(i,j,k)**2 + vort_xy_h(i,j,k)**2
enddo; enddo

do j=js,je ; do i=is,ie
norm_h(i,j,k) = sqrt(SUM(sqr_h(i-shift:i+shift,j-shift:j+shift)))
enddo; enddo
enddo

call pass_var(sh_xy_h, G%Domain, clock=CS%id_clock_mpi)
call pass_var(vort_xy_h, G%Domain, clock=CS%id_clock_mpi)
call pass_var(norm_h, G%Domain, clock=CS%id_clock_mpi)

do k=1,nz
do j=js-2,je+2 ; do i=is-2,ie+2
x(1:stencil_points) = RESHAPE(sh_xy_h(i-shift:i+shift,j-shift:j+shift,k), (/stencil_points/))
x(stencil_points+1:2*stencil_points) = RESHAPE(CS%sh_xx(i-shift:i+shift,j-shift:j+shift,k), (/stencil_points/))
x(2*stencil_points+1:3*stencil_points) = RESHAPE(vort_xy_h(i-shift:i+shift,j-shift:j+shift,k), (/stencil_points/))

input_norm = norm_h(i,j,k)

x(1:3*stencil_points) = x(1:3*stencil_points) / (input_norm + CS%subroundoff_shear)

call ANN_apply(x, y, CS%ann_Tall)

y = y * input_norm * input_norm * CS%kappa_h(i,j)

Txy(i,j) = y(1)
CS%Txx(i,j,k) = y(2)
CS%Tyy(i,j,k) = y(3)
enddo ; enddo

do J=Jsq-1,Jeq+1 ; do I=Isq-1,Ieq+1
CS%Txy(I,J,k) = 0.25 * ( (Txy(i+1,j+1) + Txy(i,j)) &
+ (Txy(i+1,j) + Txy(i,j+1))) * G%mask2dBu(I,J)
enddo; enddo

enddo ! end of k loop

call cpu_clock_end(CS%id_clock_stress_ANN)

end subroutine compute_stress_ANN_collocated

!> Compute the divergence of subgrid stress
!! weighted with thickness, i.e.
!! (fx,fy) = 1/h Div(h * [Txx, Txy; Txy, Tyy])
Expand Down
18 changes: 14 additions & 4 deletions src/parameterizations/lateral/MOM_hor_visc.F90
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ module MOM_hor_visc
use MOM_verticalGrid, only : verticalGrid_type
use MOM_variables, only : accel_diag_ptrs
use MOM_Zanna_Bolton, only : ZB2020_lateral_stress, ZB2020_init, ZB2020_end, &
ZB2020_CS, ZB2020_copy_gradient_and_thickness
ZB2020_CS, ZB2020_copy_gradient_and_thickness, &
true_vorticity

implicit none ; private

Expand Down Expand Up @@ -830,9 +831,18 @@ subroutine horizontal_viscosity(u, v, h, diffu, diffv, MEKE, VarMix, G, GV, US,
vort_xy(I,J) = (2.0-G%mask2dBu(I,J)) * ( dvdx(I,J) - dudy(I,J) )
enddo ; enddo
else
do J=Jsq-2,Jeq+2 ; do I=Isq-2,Ieq+2
vort_xy(I,J) = G%mask2dBu(I,J) * ( dvdx(I,J) - dudy(I,J) )
enddo ; enddo
if (true_vorticity) then
do J=Jsq-2,Jeq+2 ; do I=Isq-2,Ieq+2
vort_xy(I,J) = G%mask2dBu(I,J) * G%IareaBu(I,J) * ( &
(v(i+1,J,k)*G%dyCv(i+1,J) - v(i,J,k)*G%dyCv(i,J)) &
- (u(I,j+1,k)*G%dxCu(I,j+1) - u(I,j,k)*G%dxCu(I,j)) &
)
enddo ; enddo
else
do J=Jsq-2,Jeq+2 ; do I=Isq-2,Ieq+2
vort_xy(I,J) = G%mask2dBu(I,J) * ( dvdx(I,J) - dudy(I,J) )
enddo ; enddo
endif
endif

if (CS%use_Leithy) then
Expand Down

0 comments on commit 310b6b5

Please sign in to comment.