#include "obl_def.fi"

module obl_pph
    !< @brief standard Pacanowski-Philander scheme
    ! --------------------------------------------------------------------------------

    ! modules used
    ! --------------------------------------------------------------------------------
#ifdef USE_CONFIG_PARSER
    use iso_c_binding, only: C_NULL_CHAR
    use config_parser
#endif

    use obl_grid
    use obl_bc
    use obl_turb_common

    ! directives list
    ! --------------------------------------------------------------------------------
    implicit none
    private

    ! public interface
    ! --------------------------------------------------------------------------------
    public :: init_turbulence_closure
    public :: advance_turbulence_closure
    public :: define_stability_functions

    public :: get_eddy_viscosity, get_eddy_diffusivity
    public :: get_eddy_viscosity_vec, get_eddy_diffusivity_vec

    public :: set_config_param
    ! --------------------------------------------------------------------------------

    !> @brief Pacanowski-Philander parameters
    !>
    !> How the components are used by the procedures of this module:
    !>   - Ri_grad >= 0 : Km = Km_0 / (1 + alpha * Ri_grad)**n + Km_b
    !>                    Kh = Kh_0 / (1 + alpha * Ri_grad)**(n + 1) + Kh_b
    !>   - otherwise    : Km = Km_unstable, Kh = Kh_unstable
    !>   - kappa and PrT0 are handed to get_N2, kappa to get_S2 and
    !>     c_S2_min to get_Ri_gradient (see define_stability_functions)
    !>
    !> The values below are the defaults; set_config_param can replace each of them
    !> through the configuration key "<tag>.<component name>" when the module is
    !> built with USE_CONFIG_PARSER.
    type, public :: pphParamType
        real :: Km_0 = 2.0 * 0.01   !< neutral eddy viscosity, [m**2 / s] *: INMCM: 7.0 * 0.01
        real :: Kh_0 = 2.0 * 0.01   !< neutral eddy diffusivity, [m**2 / s], *: INMCM: 5.0 * 0.01
        real :: alpha = 5.0         !< constant multiplying Ri_grad in the denominator
        real :: n = 2.0             !< constant exponent of the denominator (n for Km, n + 1 for Kh)
        real :: Km_b = 0.0001       !< background eddy viscosity, [m**2 / s], *: INMCM: 5 * 10**(-6)
        real :: Kh_b = 0.00001      !< background eddy diffusivity, [m**2 / s]

        real :: Km_unstable = 100.0 !< unstable eddy viscosity, [m**2 / s], used where Ri_grad < 0
        ! NOTE(review): the INMCM value quoted below repeats the one quoted for Kh_0 -- confirm
        real :: Kh_unstable = 100.0 !< unstable eddy diffusivity, [m**2 / s], *: INMCM: 5.0 * 0.01

        real :: kappa = 0.4         !< von Karman constant, [n/d]
        real :: PrT0 = 1.0          !< neutral Prandtl number, [n/d]

        real :: c_S2_min = 1e-5     !< min shear frequency, [(1/s)**2]
    end type

    contains


    ! --------------------------------------------------------------------------------
    subroutine define_stability_functions(param, bc, grid)
        !< @brief refresh stability functions in this order: N**2, S**2, Ri-gradient
        !<
        !< get_N2 receives N2 and Rho, get_S2 receives S2, U and V, and
        !< get_Ri_gradient receives Ri_grad together with the same N2 and S2.
        !< These variables are not arguments of this routine (presumably they are
        !< the state fields brought in by `use obl_state` -- confirm).
        ! ----------------------------------------------------------------------------
        use obl_state

        type (gridDataType), intent(in) :: grid     !< forwarded to every helper call
        type (oblBcType), intent(in) :: bc          !< supplies rho_dyn0, rho_dynH, U_dyn0, U_dynH
        type (pphParamType), intent(in) :: param    !< supplies kappa, PrT0, c_S2_min
        ! ----------------------------------------------------------------------------

        associate (vk => param%kappa, prt => param%PrT0, s2_min => param%c_S2_min)
            call get_N2(N2, Rho, bc%rho_dyn0, bc%rho_dynH, vk, prt, grid)
            call get_S2(S2, U, V, bc%U_dyn0, bc%U_dynH, vk, grid)

            ! last: takes the N2 and S2 passed to the two calls above
            call get_Ri_gradient(Ri_grad, N2, S2, s2_min, grid)
        end associate
    end subroutine define_stability_functions

    ! --------------------------------------------------------------------------------
    subroutine init_turbulence_closure(param, bc, grid)
        !< @brief initialize turbulence closure: evaluate Km and Kh from Ri_grad
        !<
        !< Ri_grad is read as it stands on entry; this routine does not compute it.
        ! ----------------------------------------------------------------------------
        use obl_state

        type (gridDataType), intent(in) :: grid
        type (oblBcType), intent(in) :: bc          !< not referenced here
        type (pphParamType), intent(in) :: param
        ! ----------------------------------------------------------------------------

        call get_eddy_viscosity(Km = Km, Ri_grad = Ri_grad, param = param, grid = grid)
        call get_eddy_diffusivity(Kh = Kh, Ri_grad = Ri_grad, param = param, grid = grid)
    end subroutine init_turbulence_closure

    ! --------------------------------------------------------------------------------
    subroutine advance_turbulence_closure(param, bc, grid, dt)
        !< @brief advance turbulence closure
        !<
        !< Statement order matters: the TKE production and buoyancy terms are handed
        !< the Km and Kh held on entry, and only afterwards are Km and Kh recomputed
        !< from Ri_grad.
        ! ----------------------------------------------------------------------------
        use obl_state

        real, intent(in) :: dt                      !< not referenced here
        type (gridDataType), intent(in) :: grid
        type (oblBcType), intent(in) :: bc          !< not referenced here
        type (pphParamType), intent(in) :: param
        ! ----------------------------------------------------------------------------

        ! diagnostics first: they take Km and Kh before the update below
        call get_TKE_production(TKE_production, Km, S2, grid)
        call get_TKE_buoyancy(TKE_buoyancy, Kh, N2, grid)

        ! closure update
        call get_eddy_viscosity(Km = Km, Ri_grad = Ri_grad, param = param, grid = grid)
        call get_eddy_diffusivity(Kh = Kh, Ri_grad = Ri_grad, param = param, grid = grid)
    end subroutine advance_turbulence_closure
    
    ! --------------------------------------------------------------------------------
    subroutine get_eddy_viscosity(Km, Ri_grad, param, grid)
        !< @brief eddy viscosity over grid%cz entries; thin front end of get_eddy_viscosity_vec
        ! ----------------------------------------------------------------------------
        type (gridDataType), intent(in) :: grid
        type (pphParamType), intent(in) :: param

        real, intent(out) :: Km(grid%cz)            !< eddy viscosity, [m**2 / s]
        real, intent(in) :: Ri_grad(grid%cz)        !< Richardson gradient number
        ! ----------------------------------------------------------------------------

        call get_eddy_viscosity_vec(Km = Km, Ri_grad = Ri_grad, param = param, n = grid%cz)
    end subroutine get_eddy_viscosity

    subroutine get_eddy_viscosity_vec(Km, Ri_grad, param, n)
        !< @brief eddy viscosity from the Richardson gradient number, element by element
        !<
        !<   Ri_grad >= 0 :  Km = Km_0 / (1 + alpha * Ri_grad)**(param%n) + Km_b
        !<   otherwise    :  Km = Km_unstable
        ! ----------------------------------------------------------------------------
        integer, intent(in) :: n                    !< vector size
        type (pphParamType), intent(in) :: param

        real, intent(out) :: Km(n)                  !< eddy viscosity, [m**2 / s]
        real, intent(in) :: Ri_grad(n)              !< Richardson gradient number
        ! ----------------------------------------------------------------------------

        ! masked assignment: the right-hand side is evaluated only where the mask holds
        where (Ri_grad >= 0.0)
            Km = param%Km_0 / (1.0 + param%alpha * Ri_grad)**(param%n) + param%Km_b
        elsewhere
            Km = param%Km_unstable
        end where
    end subroutine get_eddy_viscosity_vec

    ! --------------------------------------------------------------------------------
    subroutine get_eddy_diffusivity(Kh, Ri_grad, param, grid)
        !< @brief eddy diffusivity over grid%cz entries; thin front end of get_eddy_diffusivity_vec
        ! ----------------------------------------------------------------------------
        type (gridDataType), intent(in) :: grid
        type (pphParamType), intent(in) :: param

        real, intent(out) :: Kh(grid%cz)            !< eddy diffusivity, [m**2 / s]
        real, intent(in) :: Ri_grad(grid%cz)        !< Richardson gradient number
        ! ----------------------------------------------------------------------------

        call get_eddy_diffusivity_vec(Kh = Kh, Ri_grad = Ri_grad, param = param, n = grid%cz)
    end subroutine get_eddy_diffusivity

    subroutine get_eddy_diffusivity_vec(Kh, Ri_grad, param, n)
        !< @brief eddy diffusivity from the Richardson gradient number, element by element
        !<
        !<   Ri_grad >= 0 :  Kh = Kh_0 / (1 + alpha * Ri_grad)**(param%n + 1) + Kh_b
        !<   otherwise    :  Kh = Kh_unstable
        ! ----------------------------------------------------------------------------
        integer, intent(in) :: n                    !< vector size
        type (pphParamType), intent(in) :: param

        real, intent(out) :: Kh(n)                  !< eddy diffusivity, [m**2 / s]
        real, intent(in) :: Ri_grad(n)              !< Richardson gradient number
        ! ----------------------------------------------------------------------------

        ! masked assignment: the right-hand side is evaluated only where the mask holds
        where (Ri_grad >= 0.0)
            Kh = param%Kh_0 / (1.0 + param%alpha * Ri_grad)**(param%n + 1.0) + param%Kh_b
        elsewhere
            Kh = param%Kh_unstable
        end where
    end subroutine get_eddy_diffusivity_vec

    ! --------------------------------------------------------------------------------
    subroutine set_config_param(param, tag, ierr)
        !< @brief set turbulence closure parameters
        !<
        !< With USE_CONFIG_PARSER defined, every component of `param` is looked up
        !< under the key "<tag>.<component name>", in declaration order. A key for
        !< which c_config_is_varname returns status 0 is skipped. A key that passes
        !< that check but for which c_config_get_float returns status 0 ends the
        !< routine at once with ierr = 1. Without USE_CONFIG_PARSER nothing is read.
        !<
        !< `param` is intent(out) and all components of its type carry default
        !< initialization, so every component that is not read holds its default.
        ! ----------------------------------------------------------------------------
        type (pphParamType), intent(out) :: param
        character(len = *), intent(in) :: tag       !< key prefix used in the configuration
        integer, intent(out) :: ierr                !< 0 = OK, 1 = ERROR
        ! ----------------------------------------------------------------------------

        ierr = 0        ! = OK

#ifdef USE_CONFIG_PARSER
        call fetch_if_present(".Km_0", param%Km_0)
        if (ierr /= 0) return
        call fetch_if_present(".Kh_0", param%Kh_0)
        if (ierr /= 0) return
        call fetch_if_present(".alpha", param%alpha)
        if (ierr /= 0) return
        call fetch_if_present(".n", param%n)
        if (ierr /= 0) return
        call fetch_if_present(".Km_b", param%Km_b)
        if (ierr /= 0) return
        call fetch_if_present(".Kh_b", param%Kh_b)
        if (ierr /= 0) return
        call fetch_if_present(".Km_unstable", param%Km_unstable)
        if (ierr /= 0) return
        call fetch_if_present(".Kh_unstable", param%Kh_unstable)
        if (ierr /= 0) return
        call fetch_if_present(".kappa", param%kappa)
        if (ierr /= 0) return
        call fetch_if_present(".PrT0", param%PrT0)
        if (ierr /= 0) return
        call fetch_if_present(".c_S2_min", param%c_S2_min)
        if (ierr /= 0) return
#else
        !> *: just skipping setup without configuration file
#endif

#ifdef USE_CONFIG_PARSER
    contains

        subroutine fetch_if_present(suffix, value)
            !< @brief read one optional real key "<tag><suffix>"; raise the host's ierr on failure
            character(len = *), intent(in) :: suffix    !< ".<component name>"
            real, intent(inout) :: value                !< handed to c_config_get_float when the key is known

            integer :: status

            call c_config_is_varname(trim(tag)//suffix//C_NULL_CHAR, status)
            if (status == 0) return         ! key not known to the parser: leave value alone

            call c_config_get_float(trim(tag)//suffix//C_NULL_CHAR, value, status)
            if (status == 0) ierr = 1       ! signal ERROR
        end subroutine fetch_if_present
#endif

    end subroutine set_config_param
    
end module