module obl_inmom
  !! Vertical mixing driver built on the obl_legacy / obl_pph / obl_pph_dyn modules.
  !!
  !! init_vermix recomputes density, optionally recomputes the Richardson
  !! number rit (from n2 and s2), and fills the vertical mixing coefficients
  !! anzt (kh) / anzu (km) with the scheme selected by kh_km_mode.
  !! sync_xy_border_2d/3d overwrite the outermost rows/columns of a field with
  !! values taken from the adjacent interior rows/columns.
  implicit none

  !> Border offset forwarded to the obl_legacy stencil routines.
  integer, parameter :: border_shift = 1
  !> Factor applied to obl_pph / obl_pph_dyn output before it is stored in
  !> anzt/anzu (presumably m^2/s -> cm^2/s; confirm against obl_pph units).
  real, parameter :: pph_coef_scale = 10000.0
  !> Offset added to ss before calling legacy_denp (presumably ss is stored
  !> relative to 35; confirm with the caller).
  real, parameter :: salinity_offset = 35.0
  private :: border_shift, pph_coef_scale, salinity_offset

  contains

  !> Compute vertical eddy diffusivity (anzt, "kh") and viscosity (anzu, "km").
  !>
  !> den is always recomputed from tt and ss.
  !> richnum_mode == 1 : rit is recomputed from the velocity shear (s2) and
  !>                     n2, with the top level from legacy_rit_top; otherwise
  !>                     rit is used as supplied by the caller.
  !> kh_km_mode selects the mixing scheme:
  !>   1     - obl_legacy formulas (kh_0/km_0 = anumaxt/anumaxu,
  !>           background kh_b0/km_b0 = anubgrt/anubgru)
  !>   2     - obl_pph, "lake" constants (only the unstable values overridden)
  !>   3     - obl_pph, "inmom" constants
  !>   4     - obl_pph_dyn, driven by neutral_mld and u_dynH
  !>   5     - obl_pph, "dasha" constants
  !>   other - anzt/anzu are returned unchanged.
  !> Only horizontal points with lu > lu_min are updated; all other points keep
  !> the values anzt/anzu had on entry.
  subroutine init_vermix(richnum_mode, kh_km_mode, uu, vv, lu, dx, dy, dxh, dyh, hhu, hhv, hhq, zw, g, rh0, &
    rit, den, rlh, taux, tauy, aice0, tt, ss, anzt, anzu, anumaxt, anumaxu, anubgrt, anubgru)
    use obl_legacy
    use obl_pph, only: pph_kh => get_eddy_diffusivity_vec, &
                        pph_km => get_eddy_viscosity_vec, &
                        pphParamType
    use obl_pph_dyn, only: pph_dyn_kh => get_eddy_diffusivity_vec, &
                            pph_dyn_km => get_eddy_viscosity_vec, &
                            pphDynParamType
    integer, intent(in) :: richnum_mode, kh_km_mode
    real, intent(inout) :: anzt(:,:,:) ! kh
    real, intent(inout) :: anzu(:,:,:) ! km
    real, intent(inout) :: rit(:,:,:)
    real, intent(inout) :: den(:,:,:)
    real, intent(in) :: uu(:,:,:), vv(:,:,:)
    real, intent(in) :: tt(:,:,:)
    real, intent(in) :: ss(:,:,:)
    real, intent(in) :: taux(:,:), tauy(:,:)
    real, intent(in) :: lu(:,:)
    real, intent(in) :: dx(:,:), dy(:,:), dxh(:,:), dyh(:,:)
    real, intent(in) :: hhu(:,:), hhv(:,:), hhq(:,:)
    real, intent(in) :: aice0(:,:)
    real, intent(in) :: rlh(:,:)
    real, intent(in) :: zw(:)
    real, intent(in) :: g, rh0
    real, intent(in) :: anumaxt ! kh_0
    real, intent(in) :: anumaxu ! km_0
    real, intent(in) :: anubgrt ! kh_b0
    real, intent(in) :: anubgru ! km_b0

    ! Local variables
    integer :: nx, ny, nz
    integer :: i, j

    !! obl_legacy work arrays and scalars
    real, allocatable :: u2(:,:,:), v2(:,:,:)
    real, allocatable :: s2(:,:,:)
    real, allocatable :: n2(:,:,:)
    real, allocatable :: kh_b(:)
    real, allocatable :: km_b(:)
    real, allocatable :: kh(:,:,:)
    real, allocatable :: km(:,:,:)
    real :: kh_str, km_str
    real :: kh_undimdepth, km_undimdepth
    real :: kh_unstable
    real :: kh_b0, km_b0
    real :: kh_0, km_0

    !! obl_pph / obl_pph_dyn parameters and inputs
    type(pphParamType) :: pphParams
    type(pphDynParamType) :: pphDynParams
    ! Allocatable rather than automatic: nx*ny fields can be large for the stack.
    real, allocatable :: neutral_mld(:,:)
    real, allocatable :: u_dynH(:,:)

    nx = size(rit, 1)
    ny = size(rit, 2)
    nz = size(rit, 3)

    ! kh/km start as copies of anzt/anzu, so points skipped by the lu mask
    ! keep the caller's values.
    allocate(kh(size(anzt, 1), size(anzt, 2), size(anzt, 3)))
    allocate(km(size(anzu, 1), size(anzu, 2), size(anzu, 3)))
    kh = anzt
    km = anzu

    kh_0 = anumaxt
    km_0 = anumaxu
    kh_b0 = anubgrt
    km_b0 = anubgru

    den = legacy_denp(tt, ss + salinity_offset, 0.0)

    if (richnum_mode == 1) then
      allocate(u2(size(uu, 1), size(uu, 2), size(uu, 3)))
      allocate(v2(size(vv, 1), size(vv, 2), size(vv, 3)))
      allocate(s2(size(u2, 1), size(u2, 2), size(u2, 3) - 1))
      allocate(n2(size(den, 1), size(den, 2), size(den, 3) - 1))

      call legacy_u2(uu, dy, dyh, hhu, hhq, border_shift, lu, u2)
      call legacy_v2(vv, dx, dxh, hhv, hhq, border_shift, lu, v2)
      call legacy_s2(u2, v2, lu, s2)
      call legacy_n2(den, hhq, zw, g, lu, n2)
      ! Levels 2..nz come from n2/s2; level 1 is set separately from surface forcing.
      call legacy_rit(n2, s2, border_shift, lu, rit(:,:,2:size(rit, 3)))
      call legacy_rit_top(rlh, taux, tauy, border_shift, lu, rit(:,:,1))
      call sync_xy_border_3d(rit)
    end if

    ! neutral_mld and u_dynH depend only on rlh/taux/tauy/lu. Compute them
    ! whenever the obl_pph_dyn branch (mode 4) will read them, not only when
    ! rit is recomputed, so they are never used uninitialized.
    allocate(neutral_mld(nx, ny))
    allocate(u_dynH(nx, ny))
    if (richnum_mode == 1 .or. kh_km_mode == 4) then
      call legacy_neutral_mld(rlh, taux, tauy, border_shift, lu, neutral_mld, u_dynH)
      call sync_xy_border_2d(neutral_mld)
      call sync_xy_border_2d(u_dynH)
    end if

    select case (kh_km_mode)

    case (1)
      ! obl_legacy mixing mode
      allocate(kh_b(size(kh, 3)))
      allocate(km_b(size(km, 3)))
      do j = 1, ny
        do i = 1, nx
          if (lu(i, j) > lu_min) then
            ! kh_unstable is passed to legacy_kh_b below, so it must be
            ! computed for this point first (not carried over from the
            ! previous point or left undefined on the first one).
            call legacy_kh_unstable(tt(i,j,1), kh_unstable)
            call legacy_str("kh", taux(i,j), tauy(i,j), rh0, kh_str)
            call legacy_undimdepth("kh", kh_str, hhq(i,j), aice0(i,j), kh_undimdepth)
            call legacy_kh_b(zw(:), hhq(i,j), kh_unstable, kh_undimdepth, kh_b0, kh_b)
            call legacy_kh(rit(i,j,:), kh_0, kh_b(:), kh_unstable, kh(i,j,:))

            call legacy_str("km", taux(i,j), tauy(i,j), rh0, km_str)
            call legacy_undimdepth("km", km_str, hhq(i,j), aice0(i,j), km_undimdepth)
            ! NOTE(review): km_unstable is not declared in this procedure; it
            ! must be provided by obl_legacy (implicit none is in effect).
            call legacy_km_b(zw, km_unstable, km_undimdepth, km_b0, km_b)
            call legacy_km(rit(i,j,:), km_0, km_b(:), km(i,j,:))
          end if
        end do
      end do

    case (2)
      ! obl_pph mixing mode (lake constants)
      pphParams%Kh_unstable = 0.05
      pphParams%Km_unstable = 0.05
      call apply_pph_scheme()

    case (3)
      ! obl_pph mixing mode (inmom constants)
      pphParams%Km_0 = 7.0 * 0.01
      pphParams%Kh_0 = 5.0 * 0.01
      pphParams%alpha = 5.0
      pphParams%n = 2.0
      pphParams%Km_b = 5e-6
      pphParams%Kh_b = 0.00001
      pphParams%Kh_unstable = 0.05
      pphParams%Km_unstable = 0.05
      call apply_pph_scheme()

    case (4)
      ! obl_pph_dyn mixing mode
      pphDynParams%Kh_unstable = 0.05
      pphDynParams%Km_unstable = 0.05
      do j = 1, ny
        do i = 1, nx
          if (lu(i, j) > lu_min) then
            call pph_dyn_kh(kh(i,j,:), rit(i,j,:), u_dynH(i,j), neutral_mld(i,j), pphDynParams, nz)
            call pph_dyn_km(km(i,j,:), rit(i,j,:), u_dynH(i,j), neutral_mld(i,j), pphDynParams, nz)
            ! Scale only computed points; masked points still hold the
            ! caller's anzt/anzu values and must not be rescaled every call.
            kh(i,j,:) = kh(i,j,:) * pph_coef_scale
            km(i,j,:) = km(i,j,:) * pph_coef_scale
          end if
        end do
      end do

    case (5)
      ! obl_pph mixing mode (dasha constants)
      pphParams%Km_0 = 7.0 * 0.01
      pphParams%Kh_0 = 5.0 * 0.01
      pphParams%alpha = 25.0 / 7.0 ! intentional value (original note: "it must be so")
      pphParams%Kh_unstable = 0.05
      pphParams%Km_unstable = 0.05
      call apply_pph_scheme()

    case default
      ! Unknown mode: anzt/anzu are written back unchanged.

    end select

    anzt = kh
    anzu = km

  contains

    !> Evaluate obl_pph with the host's pphParams at every point where
    !> lu > lu_min, then scale the result by pph_coef_scale at those points
    !> only (points outside the mask keep the caller's values untouched).
    subroutine apply_pph_scheme()
      integer :: ii, jj

      do jj = 1, ny
        do ii = 1, nx
          if (lu(ii, jj) > lu_min) then
            call pph_kh(kh(ii,jj,:), rit(ii,jj,:), pphParams, nz)
            call pph_km(km(ii,jj,:), rit(ii,jj,:), pphParams, nz)
            kh(ii,jj,:) = kh(ii,jj,:) * pph_coef_scale
            km(ii,jj,:) = km(ii,jj,:) * pph_coef_scale
          end if
        end do
      end do
    end subroutine apply_pph_scheme

  end subroutine init_vermix

  !> Apply sync_xy_border_2d independently to every level array(:,:,k).
  !> Each level's edges and corners depend only on that level, so this is
  !> equivalent to updating all edges first and then all corners.
  subroutine sync_xy_border_3d(array)
    implicit none
    real, intent(inout) :: array(:,:,:)
    integer :: k

    do k = 1, size(array, 3)
      call sync_xy_border_2d(array(:,:,k))
    end do

  end subroutine sync_xy_border_3d

  !> Overwrite the outermost rows/columns of a 2-D field from the interior.
  !> Edge points (excluding corners) copy the adjacent interior row/column;
  !> each corner is then set to the mean of its two just-updated edge
  !> neighbours. Arrays with fewer than 2 points along either dimension have
  !> no interior neighbour to copy from and are left unchanged.
  subroutine sync_xy_border_2d(array)
    implicit none
    real, intent(inout) :: array(:,:)
    integer :: nx, ny

    nx = size(array, 1)
    ny = size(array, 2)

    ! Indexing nx-1 / ny-1 below would fall outside the array.
    if (nx < 2 .or. ny < 2) return

    ! x boundaries (first_x and end_x)
    array(1,2:ny-1) = array(2,2:ny-1)
    array(nx,2:ny-1) = array(nx-1,2:ny-1)

    ! y boundaries (first_y and end_y)
    array(2:nx-1,1) = array(2:nx-1,2)
    array(2:nx-1,ny) = array(2:nx-1,ny-1)

    ! Corners: average of the two neighbouring (already updated) edge points.
    array(1,1) = (array(2,1) + array(1,2)) / 2.0
    array(1,ny) = (array(2,ny) + array(1,ny-1)) / 2.0
    array(nx,1) = (array(nx-1,1) + array(nx,2)) / 2.0
    array(nx,ny) = (array(nx-1,ny) + array(nx,ny-1)) / 2.0

  end subroutine sync_xy_border_2d
end module obl_inmom